Comments (10)
Thanks for circling back, and glad that helps. And the answer to the sharing question is yes! The updated version with code for distributed training will be released as part of the bigger project we're working on (ASKCOS version 2), very soon.
from graph2smiles.
This is semi-consistent with our original implementation (in this repo). Graph2SMILES is slower since it is roughly equivalent to an MPNN encoder + Molecular Transformer in terms of architecture, on top of the heavy all-pair graph distance computation. However, our runs on a V100 were only about 50% slower (e.g., 70 s per 100 steps -> 100 s per 100 steps, as far as I can remember). 5x slower doesn't look typical.
Also, we've since reimplemented the distance calculator in C++, which speeds things up quite a bit (to within ~10% of MT's time). The rest of the code has undergone significant changes, so I'm hesitant to update the whole repo, but the C calculator should be plug-and-play: just drop the precompiled c_calculate.so under ./utils and update collate_graph_distances() in data_utils.py to use it.
See https://github.com/coleygroup/Graph2SMILES/tree/c-calculate/utils
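For context on what the compiled calculator replaces: the all-pair graph distance is essentially an all-pair shortest-path computation over the molecular graph, which can be done with a BFS from each atom. A minimal pure-Python sketch (function and parameter names here are illustrative, not the repo's actual API):

```python
from collections import deque

def all_pair_distances(n_atoms, bonds, max_dist=10):
    """All-pair graph distances via BFS from each atom.
    Unreachable pairs are capped at max_dist. This O(N * E)
    loop is the hot spot the C++ calculator accelerates."""
    adj = [[] for _ in range(n_atoms)]
    for u, v in bonds:
        adj[u].append(v)
        adj[v].append(u)
    # initialize every distance to the cap, then relax via BFS
    dist = [[max_dist] * n_atoms for _ in range(n_atoms)]
    for src in range(n_atoms):
        dist[src][src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if dist[src][u] + 1 < dist[src][v]:
                    dist[src][v] = dist[src][u] + 1
                    queue.append(v)
    return dist
```

Since this is pure per-pair integer work over small graphs, moving it into C++ (operating on contiguous buffers) is what buys the speedup.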
from graph2smiles.
Thanks for the work on making the graph distance calculation faster. However, I don't actually understand how data_utils.py can import utils.ctypes_calculator, since there is no such module. I made the changes you mentioned but got a "ModuleNotFoundError". To work around it, when I try to import c_calculate directly instead of utils.ctypes_calculator, I get: ImportError: dynamic module does not define module export function (PyInit_c_calculate).
from graph2smiles.
Apologies, I forgot to include ctypes_calculator.py. It's now in the branch. This file essentially defines a DistanceCalculator class as a wrapper around the C++ module in the .so, using numpy.ctypeslib so that data can easily be passed into and out of the C++ module as numpy buffers.
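For anyone reading along before checking out the branch, a wrapper of that shape might look roughly like this (the exported C function name and signature below are assumptions for illustration; the actual file in the branch is authoritative):

```python
import ctypes

import numpy as np
from numpy.ctypeslib import load_library, ndpointer

class DistanceCalculator:
    """Thin wrapper around a compiled c_calculate.so, passing numpy
    arrays straight into the C++ routine via numpy.ctypeslib."""

    def __init__(self, lib_path="./utils"):
        # load_library resolves platform-specific extensions (.so, .dylib, ...)
        self.lib = load_library("c_calculate", lib_path)
        # hypothetical C signature: void calculate(int* adj, int* dist, int n)
        self.lib.calculate.argtypes = [
            ndpointer(ctypes.c_int, flags="C_CONTIGUOUS"),
            ndpointer(ctypes.c_int, flags="C_CONTIGUOUS"),
            ctypes.c_int,
        ]
        self.lib.calculate.restype = None

    def __call__(self, adjacency: np.ndarray) -> np.ndarray:
        n = adjacency.shape[0]
        dist = np.zeros((n, n), dtype=np.int32)  # output buffer filled by C++
        self.lib.calculate(
            np.ascontiguousarray(adjacency, dtype=np.int32), dist, n
        )
        return dist
```

The key point is the ndpointer argtypes: they let the C++ side read and write the numpy buffers in place, with no per-element marshalling, which is why this path avoids the ImportError above (the .so is a plain shared library loaded via ctypes, not a CPython extension module).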
from graph2smiles.
Thanks again for the very rapid reply, but there is still another import error: data_utils tries to import log_rank_0 from utils.train_utils, and log_rank_0 is not defined in train_utils. I can't write that function from scratch since I don't know exactly how it works.
from graph2smiles.
I think it's just another helper function we added while experimenting with distributed training (in our private version). Adding the definition to train_utils should hopefully be a quick fix.
def log_tensor(tensor, tensor_name: str):
    logging.info(f"--------------------------{tensor_name}--------------------------")
    logging.info(tensor)
    if isinstance(tensor, torch.Tensor):
        logging.info(tensor.shape)
    elif isinstance(tensor, np.ndarray):
        logging.info(tensor.shape)
    elif isinstance(tensor, list):
        try:
            for item in tensor:
                logging.info(item.shape)
        except Exception as e:
            logging.info(f"Error: {e}")
            logging.info("List items are not tensors, skip shape logging.")
from graph2smiles.
I don't think this is the solution, since the problem is not with log_tensor but with log_rank_0. There is supposed to be a log_rank_0 in train_utils, since your data_utils tries to import it. When I replace all log_rank_0 calls with log_tensor, it does not work. I think I should explicitly add the definition of log_rank_0 to train_utils, so it would be really helpful if you could provide that function. I guess you also updated train_utils locally but forgot to push it to the remote.
(Note: even after commenting out the lines that include log_rank_0, I am having trouble with args.mask_rel_chirality, since mask_rel_chirality is not defined as an argument. I think you specified another argument in the train script that is not updated here.)
(Note 2: after commenting out the if statement that checks args.mask_rel_chirality == 1, training runs without problems. But the elapsed time per 100 steps only decreased from 106 s to 94 s.)
from graph2smiles.
Here you go.
def log_rank_0(message):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            logging.info(message)
            sys.stdout.flush()
    else:
        logging.info(message)
        sys.stdout.flush()
from graph2smiles.
> (Note 2: after commenting out the if statement that checks args.mask_rel_chirality == 1, training runs without problems. But the elapsed time per 100 steps only decreased from 106 s to 94 s.)
I have little idea about this one. There are many possibilities, like torch/CUDA versions or even the devices themselves. We've seen some performance variation among different GPUs, but nothing like a 5x difference.
from graph2smiles.
Thanks for your reply, again. Applying your distance calculator reduced the training time from 100 s to 90 s per 100 steps. Given that it took around 70 s in your setup, I think the speed I'm getting is fair now. The 5x difference was because we measured the training speed of Molecular Transformer differently: in my setup, 1k iterations take around 200 s for Molecular Transformer. I am using the same torch and CUDA versions for both models. In conclusion, I think the issue can be closed, since I should find a way to utilize the GPU more efficiently in my setup.
Also, I am guessing the distributed version will not be shared publicly, but would you consider sharing it in the future?
from graph2smiles.