
geomol's People

Contributors: hannesstark, pattanaikl


geomol's Issues

A question about the direction of the alpha angle

Dear authors, thanks for your wonderful work.

In the paper, the alpha angle is the sum of many different torsion angles around the rotatable X-Y bond. But when you use such an alpha angle to rotate fragments of the molecule, you run into a question: is the alpha angle applied in the X->Y direction or in the Y->X direction?

The figure below describes the problem. I rotate each of the fragments (LS of X) using the X<-Y direction: the bottom-left part is OK, but the top-left part is wrong. If you use the X->Y direction to rotate the top-left part, it becomes correct again (which means the alpha angle somehow has two directions). I could not figure this out for a while.

In this example, X has a larger ID than Y.
[image]
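One way to see why the direction matters: a right-handed rotation by alpha about the axis X->Y is the same as a rotation by -alpha about Y->X, so a signed alpha only has a meaning once the axis direction is fixed. The sketch below is a generic Rodrigues rotation in NumPy, not GeoMol's code; the function name `rotate_about_bond` and the toy coordinates are hypothetical.

```python
import numpy as np


def rotate_about_bond(points, start, end, alpha):
    """Rotate `points` by `alpha` radians about the line through `start` and `end`.

    The axis points from `start` to `end`; by the right-hand rule, a positive alpha
    is counterclockwise when viewed from `end` looking back toward `start`.
    """
    k = (end - start) / np.linalg.norm(end - start)  # unit axis, start -> end
    v = points - start                               # coordinates relative to a point on the axis
    cos_a, sin_a = np.cos(alpha), np.sin(alpha)
    # Rodrigues' formula: v cos(a) + (k x v) sin(a) + k (k . v)(1 - cos(a))
    rotated = v * cos_a + np.cross(k, v) * sin_a + np.outer(v @ k, k) * (1.0 - cos_a)
    return rotated + start


x = np.array([0.0, 0.0, 0.0])
y = np.array([0.0, 0.0, 1.0])
p = np.array([[1.0, 0.0, 0.5]])   # one atom of a fragment
alpha = np.pi / 2

a = rotate_about_bond(p, x, y, alpha)    # axis X -> Y
b = rotate_about_bond(p, y, x, alpha)    # axis Y -> X, same alpha: opposite sense
c = rotate_about_bond(p, y, x, -alpha)   # axis Y -> X with the sign flipped: same as `a`
```

Code that sometimes builds the axis as X->Y and sometimes as Y->X (for example, depending on which atom has the larger index) may therefore flip the sense of rotation for some fragments unless it also flips the sign of alpha.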

Code Problem in permutations for symmetric hydrogens

Hi, thanks for the insights in this great work and for releasing the code!
However, when reproducing the training, I encountered the following errors:

  1. In model/model.py, in GeoMol.assign_neighborhoods, I got:

File "/home/dgxtest/3D-pretrain/GeoMol-main/model/model.py", line 180, in assign_neighborhoods
    self.leaf_hydrogens[a] = self.leaf_hydrogens[a] * True if self.leaf_hydrogens[a].sum() > 1 else self.leaf_hydrogens[a] * False
RuntimeError: "mul_cuda" not implemented for 'Bool'

I can see that this code is intended to perform an XNOR operation (I'm less convinced of that now because of error 2), so I changed the logic to the following, which fixed the error:

self.leaf_hydrogens[a] = ~(self.leaf_hydrogens[a] ^ True) if self.leaf_hydrogens[a].sum() > 1 else ~(self.leaf_hydrogens[a] ^ False)

  2. But then the following error occurred:

File "/home/dgxtest/3D-pretrain/GeoMol-main/model/model.py", line 332, in ground_truth_local_stats
    n_perms[0:len(perms), self.leaf_hydrogens[a]] = perms
RuntimeError: shape mismatch: value tensor of shape [24, 4] cannot be broadcast to indexing result of shape [6, 4]

In this case, self.leaf_hydrogens[a] is [True, True, True, True], which leads to 24 permutations in perms, while n_perms is hard-coded to shape [6, 4].
I am not sure whether my modification for error 1 produces a wrong self.leaf_hydrogens in error 2. Could you please help me figure it out? Much appreciated.

By the way, I am using torch 1.7.0+cu110 and torch-geometric 1.6.3, as mentioned in issue #2.
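The rewrite above is not equivalent to the original line: `x * True` keeps the mask and `x * False` clears it, but `~(x ^ False)` inverts it. For a center with four neighbors and no leaf hydrogens, the inverted mask becomes [True, True, True, True], and permuting four entries gives 4! = 24 orderings, which is consistent with the [24, 4] vs. [6, 4] mismatch. The sketch below uses NumPy boolean arrays as a stand-in for torch bool tensors (an assumption: they agree for `^`, `~` and `&`); `bool_safe_rule` is a hypothetical replacement, not the maintainers' fix.

```python
from itertools import permutations

import numpy as np


def original_rule(mask):
    # Intent of the original line: keep the mask if the atom has more than one
    # leaf hydrogen, otherwise clear it (all False).
    return mask * True if mask.sum() > 1 else mask * False


def xnor_rewrite(mask):
    # The rewrite from this report: ~(m ^ True) equals m, but ~(m ^ False) equals ~m,
    # so the else-branch inverts the mask instead of clearing it.
    return ~(mask ^ True) if mask.sum() > 1 else ~(mask ^ False)


def bool_safe_rule(mask):
    # Hypothetical replacement using only boolean ops (no bool multiplication).
    return mask & (mask.sum() > 1)


no_h = np.array([False, False, False, False])   # center atom without leaf hydrogens
one_h = np.array([True, False, False, False])   # exactly one leaf hydrogen
three_h = np.array([True, True, True, False])   # e.g. a methyl carbon

for mask in (no_h, one_h, three_h):
    assert (bool_safe_rule(mask) == original_rule(mask)).all()

# The rewrite flips an all-False mask to all-True: 4! = 24 orderings,
# more than the 6 rows that n_perms has room for.
flipped = xnor_rewrite(no_h)
n_orderings = len(list(permutations(range(int(flipped.sum())))))
```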

OS Error with torch-sparse

I was trying to run this repository with the QM9 dataset. First I ran into the issue reported in issues #2 and #4.

Based on that, I tried downgrading torch to 1.7.0 and torch-geometric to both 1.6.3 and 1.7.2. However, I was unable to get past the error below. I looked for other solutions to this error but could not find many resources apart from this one here.

If the owner of this repository could share a requirements file, I would be able to create an environment where this code runs.

Let me know if more info is needed from my side.

:~/Code/geo_mol/GeoMol$ python train.py --data_dir /home/vishwesh/Code/geo_mol/GeoMol/data/QM9/qm9 --split_path /home/vishwesh/Code/geo_mol/GeoMol/data/QM9/splits/split0.npy --log_dir ./test_run --n_epochs 250 --dataset qm9
Traceback (most recent call last):
  File "train.py", line 9, in <module>
    from model.model import GeoMol
  File "/home/vishwesh/Code/geo_mol/GeoMol/model/model.py", line 5, in <module>
    import torch_geometric as tg
  File "/home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/__init__.py", line 5, in <module>
    import torch_geometric.data
  File "/home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/__init__.py", line 1, in <module>
    from .data import Data
  File "/home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/data.py", line 8, in <module>
    from torch_sparse import coalesce, SparseTensor
  File "/home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/site-packages/torch_sparse/__init__.py", line 19, in <module>
    torch.ops.load_library(spec.origin)
  File "/home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/site-packages/torch/_ops.py", line 105, in load_library
    ctypes.CDLL(path)
  File "/home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/ctypes/__init__.py", line 364, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/vishwesh/anaconda3/envs/GeoMol/lib/python3.7/site-packages/torch_sparse/_version_cpu.so: undefined symbol: _ZN3c106detail12infer_schema20make_function_schemaENS_8ArrayRefINS1_11ArgumentDefEEES4_
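An `undefined symbol` error raised while loading `torch_sparse/_version_cpu.so` typically means the torch-sparse binary was compiled against a different PyTorch build than the one installed, so downgrading torch alone may not be enough: the PyG extension packages also have to come from wheels built for that exact torch version. The helper below is a hypothetical convenience that builds a wheel-index URL in the format used by the PyG installation docs (`https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html`); whether that index still hosts wheels for an old release such as torch 1.7.0 is an assumption to check.

```python
def pyg_wheel_index(torch_version: str, cuda_tag: str = "") -> str:
    """Return the PyG wheel-index URL for a torch build such as '1.7.0+cu110'.

    torch_version: the string reported by torch.__version__.
    cuda_tag: needed only when the version has no '+<tag>' suffix, e.g. 'cpu' or 'cu102'.
    """
    base, _, local_tag = torch_version.partition("+")
    tag = local_tag or cuda_tag
    if not tag:
        raise ValueError("no CUDA tag in torch_version; pass cuda_tag, e.g. 'cpu' or 'cu110'")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"


# Example: pip uninstall torch-sparse, then pip install torch-sparse -f <url>
url = pyg_wheel_index("1.7.0+cu110")
```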

Expected all tensors to be on the same device, but found two devices, cuda:0 and cpu, when running generate_confs.py

Hello,

I am facing an issue when trying to run generate_confs.py with the given pretrained models: I run into the error shared below. Please share your insights, including whether there is a preference between GPU and CPU when running inference.

I also tried switching the model between CPU and GPU, but no luck so far.

  0%|          | 0/1000 [02:14<?, ?it/s]
Traceback (most recent call last):
  File "/home/vishwesh/Software/pycharm-community-2021.1.1/plugins/python-ce/helpers/pydev/pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/vishwesh/Software/pycharm-community-2021.1.1/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/vishwesh/Code/geo_mol/GeoMol/generate_confs.py", line 63, in <module>
    model(data, inference=True, n_model_confs=n_confs*2)
  File "/home/vishwesh/anaconda3/envs/geomol_v2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/vishwesh/Code/geo_mol/GeoMol/model/model.py", line 81, in forward
    self.generate_model_prediction(data.x, data.edge_index, data.edge_attr, data.batch, data.chiral_tag)
  File "/home/vishwesh/Code/geo_mol/GeoMol/model/model.py", line 686, in generate_model_prediction
    x1, x2, h_mol = self.embed(x, edge_index, edge_attr, batch)
  File "/home/vishwesh/Code/geo_mol/GeoMol/model/model.py", line 228, in embed
    x1, _ = self.gnn(x, edge_index, edge_attr)
  File "/home/vishwesh/anaconda3/envs/geomol_v2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/vishwesh/Code/geo_mol/GeoMol/model/GNN.py", line 126, in forward
    x = self.node_init(x)
  File "/home/vishwesh/anaconda3/envs/geomol_v2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/vishwesh/Code/geo_mol/GeoMol/model/GNN.py", line 40, in forward
    x = self.layers[i](x)
  File "/home/vishwesh/anaconda3/envs/geomol_v2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/vishwesh/anaconda3/envs/geomol_v2/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/vishwesh/anaconda3/envs/geomol_v2/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument mat2 in method wrapper_mm)

Process finished with exit code 1

Questions about stereoisomer issues in the evaluation of GeoMol

from rdkit import Chem


def clean_confs(smi, confs):
    # Keep only conformers whose non-isomeric canonical SMILES matches the input SMILES
    good_ids = []
    smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi), isomericSmiles=False)
    for i, c in enumerate(confs):
        conf_smi = Chem.MolToSmiles(Chem.RemoveHs(c), isomericSmiles=False)
        if conf_smi == smi:
            good_ids.append(i)
    return [confs[i] for i in good_ids]

  • This function is used to filter out conformers whose SMILES is inconsistent with the given SMILES (in this script, corrected_smi). In my reproduction, most of the inconsistent cases are molecules with a Z/E double bond. These cases are not filtered out when isomericSmiles=False, which confuses me, and I'm not sure whether this is a mistake.
  • For example, conformers with the SMILES Cc1cc(C(=O)c2cnc(/N=C/N(C)C)s2)c(F)cc1Cl and Cc1cc(C(=O)c2cnc(/N=C\N(C)C)s2)c(F)cc1Cl in the reference data are currently all kept for comparison, although GeoMol was only used to generate conformers for Cc1cc(C(=O)c2cnc(/N=C\N(C)C)s2)c(F)cc1Cl.

if conf_canonical_smi != canonical_smi:
    continue

  • By contrast, the code in model/featurization.py filters out conformers whose SMILES is inconsistent with the SMILES in the dataset.
  • So if I use compare_confs.py to calculate the performance with isomericSmiles=False, conformers with different isomeric SMILES are not filtered out, and the performance is the same as or even worse than before (since GeoMol was used to generate only one stereoisomer, based on the given SMILES).
  • Performance comparison between GeoMol predictions and the reference data (before using clean_confs; using clean_confs; with isomericSmiles=True):
**Before**
Recall Coverage: Mean = 74.78, Median = 85.00
Recall AMR: Mean = 0.9471, Median = 0.9176
Precision Coverage: Mean = 71.84, Median = 87.50
Precision AMR: Mean = 1.0035, Median = 0.9649

**After (with clean_confs, more confs are included than before)**
Recall Coverage: Mean = 74.30, Median = 90.00
Recall AMR: Mean = 0.9489, Median = 0.8797
Precision Coverage: Mean = 65.50, Median = 81.80
Precision AMR: Mean = 1.1044, Median = 1.0041

**isomericSmiles=True**
Recall Coverage: Mean = 83.38, Median = 100.00
Recall AMR: Mean = 0.8233, Median = 0.8079
Precision Coverage: Mean = 72.73, Median = 87.50
Precision AMR: Mean = 0.9833, Median = 0.8895

As you can see, with isomericSmiles=True, the results reported in the GeoMol paper can be reproduced.
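For readers comparing these numbers: recall-side metrics look at each reference conformer's closest generated conformer, and precision-side metrics at each generated conformer's closest reference conformer, so which reference conformers survive clean_confs directly changes the recall rows. The sketch below computes per-molecule coverage (percentage of minimum RMSDs below a threshold) and AMR (average minimum RMSD) from an RMSD matrix; the function name, the strict `<` comparison and the threshold value are assumptions for illustration, not compare_confs.py itself, and the Mean/Median rows above would then be taken over molecules.

```python
import numpy as np


def coverage_and_amr(rmsd, threshold):
    """Per-molecule COV (%) and AMR from rmsd[i, j] between reference i and generated j."""
    recall_min = rmsd.min(axis=1)      # closest generated conformer for each reference one
    precision_min = rmsd.min(axis=0)   # closest reference conformer for each generated one
    return {
        "recall_coverage": 100.0 * float(np.mean(recall_min < threshold)),
        "recall_amr": float(recall_min.mean()),
        "precision_coverage": 100.0 * float(np.mean(precision_min < threshold)),
        "precision_amr": float(precision_min.mean()),
    }


# Toy matrix: 3 reference conformers x 2 generated conformers (RMSD in angstrom).
# Suppose the last reference row is the other E/Z isomer, which the model never generates.
rmsd = np.array([[0.2, 0.9],
                 [0.7, 0.6],
                 [1.5, 1.2]])

with_other_isomer = coverage_and_amr(rmsd, threshold=0.5)
same_isomer_only = coverage_and_amr(rmsd[:2], threshold=0.5)
```

In this toy case, dropping the unmatched reference row raises recall coverage from 33.3% to 50% and lowers recall AMR from about 0.67 to 0.4.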


When I looked further into this issue, I found another odd behavior: given input SMILES that differ in stereochemistry, GeoMol generates conformers that are close in 3D geometry, even though those conformers correspond to different stereoisomers in their SMILES. This issue does not occur with RDKit ETKDG, and I am not sure whether it affects GeoMol's performance on these molecules. Here are two examples:

SMILES | GeoMol (trans) | GeoMol (cis) | ETKDG (trans) | ETKDG (cis)
O=S(=O)(N=C(c1ccccc1)N1CCOCC1)c1ccc(Br)cc1 | [image] | [image] | [image] | [image]
Cc1cc(C(=O)c2cnc(N=CN(C)C)s2)c(F)cc1Cl | [image] | [image] | [image] | [image]

RuntimeError: Cannot re-initialize CUDA in forked subprocess (solved)

Just in case anyone else has the same issue: I received the following error during training.

Starting training...
  0%|                                                                         | 0/625 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/grads/e/ethanycx/workspace/GeoMol/train.py", line 73, in <module>
    train_loss = train(model, train_loader, optimizer, device, scheduler, logger if args.verbose else None, epoch, writer)
  File "/home/grads/e/ethanycx/workspace/GeoMol/model/training.py", line 18, in train
    for i, data in tqdm(enumerate(loader), total=len(loader)):
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
    return self._process_data(data)
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
    data.reraise()
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch_geometric/data/dataset.py", line 187, in __getitem__
    data = self.get(self.indices()[idx])
  File "/home/grads/e/ethanycx/workspace/GeoMol/model/featurization.py", line 74, in get
    data.edge_index_dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data)
  File "/home/grads/e/ethanycx/workspace/GeoMol/model/utils.py", line 122, in get_dihedral_pairs
    keep = [t.to(device) for t in keep]
  File "/home/grads/e/ethanycx/workspace/GeoMol/model/utils.py", line 122, in <listcomp>
    keep = [t.to(device) for t in keep]
  File "/home/grads/e/ethanycx/miniconda3/envs/torch/lib/python3.9/site-packages/torch/cuda/__init__.py", line 163, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Versions: torch==1.7.1, torch_geometric==1.7.0

This seems to be a PyTorch issue with the DataLoader. I fixed it by inserting the following lines at line 18 in train.py (and indenting the later lines accordingly):

if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn', force=True)

and changing line 240 in featurization.py to `num_workers=1,`.

Runtime Error when enumerating train_loader during training

Hi! I really appreciate your fantastic work and code, and I've reproduced your work by following the guidance in README.md.
However, I received this error when running the training process with train.py.

Describe the error

Starting training...
  0%|                                                                                                                                                       | 0/625 [00:00<?, ?it/s][11:18:30] Explicit valence for atom # 0 N, 4, is greater than permitted
  0%|                                                                                                                                                       | 0/625 [22:56<?, ?it/s]
Traceback (most recent call last):
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
    run()
  File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/pubhome/qcxia02/git-repo/AI-CONF/GeoMol/train.py", line 74, in <module>
    train_loss = train(model, train_loader, optimizer, device, scheduler, logger if args.verbose else None, epoch, writer)
  File "/pubhome/qcxia02/git-repo/AI-CONF/GeoMol/model/training.py", line 18, in train
    for i, data in tqdm(enumerate(loader), total=len(loader)):
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 39, in __call__
    return self.collate(batch)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 20, in collate
    self.exclude_keys)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/batch.py", line 75, in from_data_list
    exclude_keys=exclude_keys,
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 86, in collate
    increment)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 142, in _collate
    data_list, stores, increment)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 162, in _collate
    value = torch.cat(values, dim=cat_dim or 0)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 19 but got size 21 for tensor number 1 in the list.

To Reproduce

`python train.py --data_dir data/QM9/qm9/ --split_path data/QM9/splits/split0.npy --log_dir ./test_run --n_epochs 250 --dataset qm9`

Expected behavior

Training completes smoothly without errors.

Environments:

The environment is based on the provided environment.yml file; the versions of the torch-related packages are listed below:
- OS: CentOS Linux release 8.4.2105
- Package Version:

  • python=3.7.10
  • pytorch=1.10.0=py3.7_cpu_0
  • torchaudio=0.10.0=py37_cpu
  • torchvision=0.11.1=py37_cpu
  • pytorch-cluster=1.5.9=py37_torch_1.10.0_cpu
  • pytorch-mutex=1.0=cpu
  • pytorch-scatter=2.0.9=py37_torch_1.10.0_cpu
  • pytorch-sparse=0.6.12=py37_torch_1.10.0_cpu
  • pytorch-spline-conv=1.2.1=py37_torch_1.10.0_cpu
  • torch-geometric=2.0.2

Additional context:

This error was raised during DataLoader enumeration in training, i.e. at `for i, data in tqdm(enumerate(loader), total=len(loader)):`. The "Expected size 19 but got size 21" error in torch.cat arises because it tried to concatenate tensor B (the 2nd molecule), with shape 10x21x3, to tensor A (the 1st molecule), with shape 10x19x3, along dimension 0 (size 10), which requires the other dimensions (19 vs. 21) to match. I'm not sure whether this is expected on your side, and I don't know where to make modifications (if any are needed).

Looking forward to your reply :)
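The failure described above can be reproduced without PyG: stacking two (n_confs, n_atoms, 3) arrays along the conformer axis requires equal atom counts, while stacking along the atom axis only requires equal conformer counts. PyG chooses the concatenation axis per attribute through `Data.__cat_dim__`, which returns 0 for ordinary (non-index) attributes, so one possibility, not verified against GeoMol, is that a multi-conformer position attribute needs a different concatenation axis (or different storage) under torch-geometric 2.x. The arrays below are toy stand-ins with the shapes from the error message.

```python
import numpy as np

# Toy stand-ins for two molecules' conformer coordinates, shaped (n_confs, n_atoms, 3)
# as in the error message: 10 conformer slots, 19 and 21 atoms.
mol_a = np.zeros((10, 19, 3))
mol_b = np.zeros((10, 21, 3))


def concat_shape(arrays, axis):
    """Return the concatenated shape, or the error text if the shapes are incompatible."""
    try:
        return np.concatenate(arrays, axis=axis).shape
    except ValueError as err:
        return f"error: {err}"


along_confs = concat_shape([mol_a, mol_b], axis=0)  # fails: atom counts 19 vs 21 differ
along_atoms = concat_shape([mol_a, mol_b], axis=1)  # atoms of both molecules, shared conformer axis
```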

Getting errors during training and while running inference with the model

  1. I created a new environment using your .sh file and ran the training script with the same datasets, but I am getting this error:
    RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 21 but got size 19 for tensor number 1 in the list.

  2. While running your generate_confs script, I get this error at the line data = Batch.from_data_list(data_list=[tg_data]):
    TypeError: argument of type 'int' is not iterable

  3. If I pass the data directly to the model, bypassing the line above, then at the line
    model(tg_data, inference=True, n_model_confs=n_confs*2) I get this error:
    AttributeError: 'GlobalStorage' object has no attribute 'bincount'
    NOTE: When passing the data directly to the model, I changed n_atoms_per_mol = data.batch.bincount() to
    n_atoms_per_mol = data.bincount() in the get_neighbor_ids function of the model.utils script. If I don't change this line, the error is "'NoneType' object has no attribute 'bincount'".

[Screenshots from 2021-12-31: 15-24-12, 15-24-17, 15-24-33]

About the true angle computation

Hi, I have a question about why there is a [6, 6] matrix in the true angle computation. For each central atom X, there are at most 6 permutations, such as T1-X-T2, T1-X-T3, T2-X-T1, T2-X-T3, T3-X-T1, and T3-X-T2. I understand the meaning of the first 6, but I'm unsure why there is a second 6. Could you please explain?
def ground_truth_local_stats(self, pos):
    """
    Compute true one-hop, two-hop, and angle local stats. Note that the second dimension of the local coordinates
    is 6 to account for possible symmetric hydrogens. The max number of symmetric leaf hydrogens is 3, which leads
    to a max of 6 permutations (our model doesn't work for methane). This dimension captures these symmetric
    hydrogen permutations.

    :param pos: coordinates (n_atoms, n_true_confs, 3)
    :return: tuple of true stats (one-hop, two-hop, and angles)
        true_one_hop (n_neighborhoods, 6, 4, n_true_confs)
        true_two_hop (n_neighborhoods, 6, 4, 4, n_true_confs)
        true_angles (n_neighborhoods, 6, 6, n_true_confs)
    """

    n_neighborhoods = len(self.neighbors)
    self.true_local_coords = torch.zeros(n_neighborhoods, 6, 4, self.n_true_confs, 3).to(self.device)

    for i, (a, n) in enumerate(self.neighbors.items()):

        # permutations for symmetric hydrogens
        n_perms = n.unsqueeze(0).repeat(6, 1)
        perms = torch.tensor(list(permutations(n[self.leaf_hydrogens[a]]))).to(self.device)
        if perms.size(1) != 0:
            n_perms[0:len(perms), self.leaf_hydrogens[a]] = perms

        # keep it local
        self.true_local_coords[i, :, 0:len(n)] = pos[n_perms] - pos[a]

    # calculate true local stats
    true_one_hop, true_two_hop, true_angles = batch_local_stats_from_coords(self.true_local_coords, self.neighbor_mask)

    return true_one_hop, true_two_hop, true_angles
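A possible reading of the two axes, based on the docstring and on the `angle: (n_neighborhoods, 6)` shape in `batch_molecule_loss`: the first 6 enumerates orderings of up to three symmetric leaf hydrogens (3! = 6), while the second 6 would index the bond angles T_i-X-T_j of a center with at most four neighbors. Since T_i-X-T_j and T_j-X-T_i are the same angle, there are C(4, 2) = 6 of them. The NumPy sketch below is an assumption-labeled illustration, not GeoMol's `batch_local_stats_from_coords`; `local_angles` is a hypothetical helper.

```python
from itertools import combinations, permutations

import numpy as np

MAX_NEIGHBORS = 4

# Axis 1 (size 6): orderings of up to three symmetric leaf hydrogens, 3! = 6.
hydrogen_orderings = list(permutations(range(3)))

# Axis 2 (size 6, assumed): unordered neighbor pairs (i, j) with i < j, C(4, 2) = 6.
neighbor_pairs = list(combinations(range(MAX_NEIGHBORS), 2))


def local_angles(local_coords):
    """Bond angles (radians) T_i-X-T_j for every neighbor pair.

    local_coords: (4, 3) neighbor positions relative to the center atom X.
    Centers with fewer than 4 neighbors would need masking (GeoMol passes a
    neighbor mask for this); that is omitted here.
    """
    unit = local_coords / np.linalg.norm(local_coords, axis=-1, keepdims=True)
    return np.array([np.arccos(np.clip(unit[i] @ unit[j], -1.0, 1.0))
                     for i, j in neighbor_pairs])


# Ideal tetrahedron: each of the 6 angles is about 109.47 degrees.
tetrahedral = np.array([[1.0, 1.0, 1.0], [1.0, -1.0, -1.0],
                        [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]])
angles_deg = np.degrees(local_angles(tetrahedral))
```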

Problem with the loss computation

Hi, great work! Could you please tell me why the angle loss and the dihedral loss are subtracted (at the bottom of the code)? Thank you!
def batch_molecule_loss(self, true_stats, model_stats, ignore_neighbors):
    """
    Compute loss for one pair of model/true molecules

    :param true_stats: tuple of masked true stat tensors (len 5)
    :param model_stats: tuple of masked model stat tensors (len 5)
        one-hop: (n_neighborhoods, 4)
        two-hop: (n_neighborhoods, 4, 4)
        angle: (n_neighborhoods, 6)
        dihedral: (2, n_dihedral_pairs, 9)
        three-hop: (n_dihedral_pairs, 9)
    :return: molecular loss for the batch (n_batch)
    """

    # unpack stats
    model_one_hop, model_two_hop, model_angles, model_dihedrals, model_three_hop = model_stats
    true_one_hop, true_two_hop, true_angles, true_dihedrals, true_three_hop = true_stats

    # calculate losses
    one_hop_loss, two_hop_loss, angle_loss = self.local_loss(true_one_hop, true_two_hop, true_angles,
                                                             model_one_hop, model_two_hop, model_angles)
    dihedral_loss, three_hop_loss = self.pair_loss(true_dihedrals, model_dihedrals, true_three_hop, model_three_hop)

    # writing
    self.one_hop_loss.append(one_hop_loss)
    self.two_hop_loss.append(two_hop_loss)
    self.angle_loss.append(angle_loss)
    self.dihedral_loss.append(dihedral_loss)
    self.three_hop_loss.append(three_hop_loss)

    if ignore_neighbors:
        return one_hop_loss + two_hop_loss - angle_loss
    else:
        return one_hop_loss + two_hop_loss - angle_loss + three_hop_loss - dihedral_loss
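One common reason for subtracting a term (an assumption here: we have not checked what `local_loss` and `pair_loss` return) is that the term is an agreement score to be maximized, for example the mean cosine of the angular error, which equals 1 for a perfect match. Minimizing `... - angle_loss` then maximizes that agreement, and the cosine also respects periodicity, so dihedrals of 179° and -179° count as nearly identical. `cosine_agreement` below is a hypothetical helper, not GeoMol's code.

```python
import numpy as np


def cosine_agreement(true_angles, model_angles):
    # Mean cosine of the angular error: 1.0 for a perfect match, lower otherwise.
    return float(np.mean(np.cos(true_angles - model_angles)))


true_angles = np.array([1.91, 1.91, 2.09])   # radians
close_pred = true_angles + 0.01
far_pred = true_angles + 1.0

# Other loss terms held at zero so only the subtracted angle term varies.
one_hop_loss = two_hop_loss = 0.0
total_close = one_hop_loss + two_hop_loss - cosine_agreement(true_angles, close_pred)
total_far = one_hop_loss + two_hop_loss - cosine_agreement(true_angles, far_pred)

# Periodicity: 179 and -179 degrees differ by 358 degrees but agree closely.
wraparound = cosine_agreement(np.radians([179.0]), np.radians([-179.0]))
```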
