
zhengkaitu commented on June 15, 2024

Thanks for circling back, and glad that helped. The answer to the sharing question is yes! The updated version, including code for distributed training, will be released very soon as part of the bigger project we're working on (ASKCOS version 2).

zhengkaitu commented on June 15, 2024

This is semi-consistent with our original implementation (in this repo). Graph2SMILES is slower because its architecture is roughly an MPNN encoder plus a Molecular Transformer, on top of the heavy all-pair graph distance computation. However, our runs on a V100 were only about 50% slower (e.g., from 70 s per 100 steps to 100 s per 100 steps, as far as I can remember). A 5X slowdown doesn't look typical.
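To see why the all-pair graph distance step is expensive, here is a minimal pure-Python sketch (not the repository's code) that runs one breadth-first search per atom to get hop-count distances between every pair of atoms. The function name and the `max_distance` clipping, which puts long and unreachable pairs into one bucket, are illustrative assumptions. The cost is O(N(N+E)) per molecule, repeated for every molecule in every batch, which is the kind of loop a compiled implementation speeds up.

```python
from collections import deque
from typing import Sequence, Tuple

import numpy as np


def all_pair_graph_distances(
    n_atoms: int, bonds: Sequence[Tuple[int, int]], max_distance: int = 15
) -> np.ndarray:
    """Hop-count distances between all atom pairs, one BFS per source atom."""
    # Undirected adjacency list built from the bond list.
    neighbors = [[] for _ in range(n_atoms)]
    for i, j in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)

    # -1 marks "not reached yet" from the current source atom.
    distances = np.full((n_atoms, n_atoms), -1, dtype=np.int64)
    for source in range(n_atoms):
        distances[source, source] = 0
        queue = deque([source])
        while queue:
            atom = queue.popleft()
            for nb in neighbors[atom]:
                if distances[source, nb] < 0:  # first visit is the shortest path
                    distances[source, nb] = distances[source, atom] + 1
                    queue.append(nb)

    # Assumption: unreachable pairs (separate fragments) and long paths
    # share a single clipped bucket at max_distance.
    distances[distances < 0] = max_distance
    return np.minimum(distances, max_distance)
```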

Also, we've reimplemented the distance calculator in C++, which speeds things up quite a bit (to within roughly 10% of MT's training time). Other parts of the code have changed significantly, so I'm a bit hesitant to update the whole repo, but the C calculator should be plug-and-play: just drop the precompiled c_calculate.so under ./utils and update collate_graph_distances() in data_utils.py to use it.

See https://github.com/coleygroup/Graph2SMILES/tree/c-calculate/utils
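The plug-and-play idea can be sketched as a collate step that takes the distance routine as a parameter, so a Python fallback and a compiled calculator become interchangeable. This is a hypothetical stand-in: `collate_distances`, its arguments, and the padding value are assumptions, not the actual signature of collate_graph_distances() in data_utils.py.

```python
from typing import Callable, List

import numpy as np

# Any callable mapping one molecule's adjacency matrix to its distance matrix,
# e.g. a pure-Python BFS or a wrapper around a compiled C++ library.
DistanceFn = Callable[[np.ndarray], np.ndarray]


def collate_distances(
    adjacencies: List[np.ndarray], distance_fn: DistanceFn, pad_value: int = -1
) -> np.ndarray:
    """Hypothetical batch collation: compute per-molecule distances, then pad.

    Returns an int32 array of shape (batch, max_atoms, max_atoms); entries
    beyond a molecule's own atom count hold pad_value.
    """
    max_atoms = max(a.shape[0] for a in adjacencies)
    batch = np.full((len(adjacencies), max_atoms, max_atoms), pad_value, dtype=np.int32)
    for k, adjacency in enumerate(adjacencies):
        n = adjacency.shape[0]
        batch[k, :n, :n] = distance_fn(adjacency)
    return batch
```

Swapping implementations then only means passing a different `distance_fn`.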

AslantheAslan commented on June 15, 2024

Thanks for the support in making the graph distance calculation faster. However, I don't understand how data_utils.py can import utils.ctypes_calculator, since there is no such module. I made the changes you mentioned, but it raises ModuleNotFoundError. When I try to import c_calculate instead of utils.ctypes_calculator, I get ImportError: dynamic module does not define module export function (PyInit_c_calculate).

zhengkaitu commented on June 15, 2024

Apologies, I forgot to include ctypes_calculator.py. It's now in the branch. This file defines a DistanceCalculator class that wraps the C++ module in the .so, using numpy.ctypeslib so that data can be passed into and out of the C++ module as numpy buffers.
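The ImportError above is expected: `import c_calculate` only works for CPython extension modules that export `PyInit_c_calculate`, whereas a plain C/C++ shared library has to be loaded through `ctypes`. Below is a minimal sketch of such a wrapper, assuming a hypothetical exported function `calculate_distances(const int* adjacency, int n_atoms, int* distances)`; the real symbol names and signature in c_calculate.so may differ. `numpy.ctypeslib.ndpointer` declares the argument types, so ctypes rejects arrays with the wrong dtype, dimensionality, or memory layout before they reach C++.

```python
import ctypes

import numpy as np
from numpy.ctypeslib import ndpointer

# A C-contiguous 2-D int32 array, passed to C as an `int*`.
INT_MATRIX = ndpointer(dtype=np.int32, ndim=2, flags="C_CONTIGUOUS")


class DistanceCalculator:
    """Sketch of a ctypes wrapper; the exported symbol below is hypothetical."""

    def __init__(self, lib_path: str):
        # Load the shared library directly instead of `import`-ing it.
        self._lib = ctypes.CDLL(lib_path)
        fn = self._lib.calculate_distances  # hypothetical symbol name
        fn.argtypes = [INT_MATRIX, ctypes.c_int, INT_MATRIX]
        fn.restype = None
        self._fn = fn

    def __call__(self, adjacency: np.ndarray) -> np.ndarray:
        # Convert to the exact dtype and layout the C signature expects.
        adjacency = np.ascontiguousarray(adjacency, dtype=np.int32)
        n_atoms = adjacency.shape[0]
        distances = np.zeros((n_atoms, n_atoms), dtype=np.int32)
        # The C side is assumed to fill `distances` in place.
        self._fn(adjacency, n_atoms, distances)
        return distances
```

Usage would look like `DistanceCalculator("./utils/c_calculate.so")(adjacency)`, again assuming the symbol and signature above.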

AslantheAslan commented on June 15, 2024

Again, thanks for your very rapid reply, but there is another import error: data_utils tries to import log_rank_0 from utils.train_utils. I think log_rank_0 is not defined in train_utils, and I can't write that function from scratch since I don't know exactly how it works.

zhengkaitu commented on June 15, 2024

I think it's just another helper function that we added while experimenting with distributed training (in our private version). Adding the definition to train_utils should hopefully be a quick fix.

```python
import logging

import numpy as np
import torch


def log_tensor(tensor, tensor_name: str):
    logging.info(f"--------------------------{tensor_name}--------------------------")
    logging.info(tensor)
    # Tensors and arrays both expose .shape.
    if isinstance(tensor, (torch.Tensor, np.ndarray)):
        logging.info(tensor.shape)
    elif isinstance(tensor, list):
        try:
            for item in tensor:
                logging.info(item.shape)
        except Exception as e:
            logging.info(f"Error: {e}")
            logging.info("List items are not tensors, skip shape logging.")
```

AslantheAslan commented on June 15, 2024

I don't think that's the solution, since the problem is with log_rank_0, not log_tensor. There is supposed to be a log_rank_0 in train_utils, since your data_utils tries to import it. Replacing all log_rank_0 calls with log_tensor does not work. I think I need to add the definition of log_rank_0 to train_utils explicitly, so it would be really helpful if you could provide that function. I guess you also updated train_utils locally but forgot to push it to the remote.

(Note: Even after commenting out the lines that include log_rank_0, I have trouble with args.mask_rel_chirality, since mask_rel_chirality is not defined as an argument. I think you added another argument to the training script that hasn't been pushed either.)

(Note 2: After commenting out the if statement that checks args.mask_rel_chirality == 1, training runs without problems. But the elapsed time for 100 steps only decreased from 106 sec to 94 sec.)
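If you would rather define the missing argument than comment out the check, a minimal sketch is to register it with argparse. Only the name `mask_rel_chirality` and the `== 1` comparison come from this thread; the helper name, the integer type, and the default of 0 (which skips the branch, mimicking the workaround above) are assumptions.

```python
import argparse


def add_chirality_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical definition: data_utils only checks `args.mask_rel_chirality == 1`.
    # Defaulting to 0 skips that branch, like commenting the check out.
    parser.add_argument(
        "--mask_rel_chirality",
        type=int,
        default=0,
        choices=[0, 1],
        help="(assumed semantics) 1 enables the masking branch in data_utils",
    )
    return parser
```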

zhengkaitu commented on June 15, 2024

Here you go.

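```python
import logging
import sys

import torch


def log_rank_0(message):
    # In distributed runs, only rank 0 logs; otherwise always log.
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            logging.info(message)
            sys.stdout.flush()
    else:
        logging.info(message)
        sys.stdout.flush()
```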

zhengkaitu commented on June 15, 2024

> (Note 2: After commenting out the if statement that checks args.mask_rel_chirality == 1, training runs without problems. But the elapsed time for 100 steps only decreased from 106 sec to 94 sec.)

I have little idea about this one. There are many possible causes, such as torch/CUDA versions or even the devices themselves. We've seen some performance variation among different GPUs, but nothing like a 5X difference.

AslantheAslan commented on June 15, 2024

Thanks again for your reply. Applying your distance calculator reduced the training time from 100 sec to 90 sec per 100 steps. Considering that it took around 70 sec in your setup, I think my speed is reasonable now. The 5X difference came from measuring the training speed of Molecular Transformer differently. In my setup, 1k iterations take around 200 sec for Molecular Transformer. I am using the same torch and CUDA versions for both models. In conclusion, I think the issue can be closed, since I need to find a way to use the GPU more efficiently in my setting.

Also, I'm guessing the distributed version won't be shared publicly, but would you consider sharing it in the future?
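Since the 5X figure came from comparing timings measured in different ways, one way to avoid that is to time both models with the same helper and the same window. Below is a minimal sketch of such a hypothetical helper (not part of either codebase) that reports wall-clock seconds per fixed number of training steps; the class name and the injectable `clock` are assumptions made for illustration and testing.

```python
import time
from typing import Callable, Optional


class StepTimer:
    """Report wall-clock seconds per `window` training steps.

    Using the same window for both models (e.g. 100 steps) avoids comparing
    "sec per 100 steps" against "sec per 1k iterations".
    """

    def __init__(self, window: int = 100, clock: Callable[[], float] = time.perf_counter):
        self.window = window
        self.clock = clock
        self._start = clock()
        self._steps = 0

    def step(self) -> Optional[float]:
        """Call once per optimizer step; returns elapsed seconds every `window` steps."""
        self._steps += 1
        if self._steps < self.window:
            return None
        now = self.clock()
        elapsed = now - self._start
        # Start the next window from this boundary.
        self._start, self._steps = now, 0
        return elapsed
```

For GPU training, call `torch.cuda.synchronize()` before the step that closes each window, since CUDA kernels run asynchronously and the clock could otherwise be read before the queued work has finished.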
