
harryjo97 / gdss

131 stars | 2 watchers | 22 forks | 47.06 MB

Official Code Repository for the paper "Score-based Generative Modeling of Graphs via the System of Stochastic Differential Equations" (ICML 2022)

Home Page: https://arxiv.org/abs/2202.02514

Python 63.43% C++ 36.52% Shell 0.04%
graph-generation score-based-generative-modeling diffusion-models stochastic-differential-equations molecule-generation

gdss's People

Contributors

harryjo97, seullee05


gdss's Issues

Question on the data format for generated samples of ENZYMES dataset

Dear Authors, this project really helps a lot. I have a question: how can we transform the samples in the pkl file into DGL graphs with node features, for use with classifiers built on the standard PyG or DGL ENZYMES datasets? It seems the networkx graphs we generated have no node features. Thanks.
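
A minimal conversion sketch (an editor's illustration, not the authors' code): it assumes the pkl holds a list of networkx graphs, as the sampler writes, and attaches one-hot degree features before converting to DGL. The file path and the degree cap are placeholders.

import pickle
import torch.nn.functional as F
import networkx as nx
import dgl

with open('samples_enzymes.pkl', 'rb') as f:  # hypothetical path to the generated samples
    nx_graphs = pickle.load(f)

MAX_DEG = 10  # assumed cap for the one-hot degree encoding

dgl_graphs = []
for g in nx_graphs:
    g = nx.convert_node_labels_to_integers(g)
    dg = dgl.from_networkx(g)
    deg = dg.in_degrees().clamp(max=MAX_DEG)  # node degrees as surrogate features
    dg.ndata['feat'] = F.one_hot(deg, MAX_DEG + 1).float()
    dgl_graphs.append(dg)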

Variance of the training/testing results

Hi there,

Thanks for sharing the code for your wonderful project. I have a question about the variance of the sampling results. I ran the training on the grid dataset using the default config file (with an arbitrary seed).

The testing-time performance metrics I got are:
MMD_full {'degree': 0.460601, 'cluster': 0.008495, 'orbit': 0.126024, 'spectral': 0.681714}

On the other hand, the MMD results reported in the paper are: deg: 0.111, clus: 0.005, orbit: 0.070 for GDSS, and deg: 0.171, clus: 0.011, orbit: 0.223 for GDSS-seq.

The MMD results of the samples generated by the provided checkpoint model are:
MMD_full {'degree': 0.093013, 'cluster': 0.00718, 'orbit': 0.101709, 'spectral': 0.793645}

I understand that the random seed can affect the sampling results, but this variance seems rather large to me (especially for the network I trained myself). Do you have any insights about this? The earlier EDP-GNN baseline seems to have large variance when the number of generated samples is small. Do you think it could be attributed to intrinsic properties of score-based models?
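
One way to quantify the variance (an editor's sketch; run_sampling_and_eval is a hypothetical wrapper around the repo's sampler, not an existing function) is to repeat sampling over several seeds and report per-metric mean and standard deviation:

import numpy as np

mmds = {'degree': [], 'cluster': [], 'orbit': [], 'spectral': []}
for seed in range(5):                          # e.g., 5 independent runs
    stats = run_sampling_and_eval(seed=seed)   # hypothetical wrapper around sampler.py
    for k in mmds:
        mmds[k].append(stats[k])
for k, v in mmds.items():
    print(f'{k}: mean {np.mean(v):.4f}, std {np.std(v):.4f}')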

Best,
Qi

Including node coordinates as node feature

Thank you for the wonderful and inspiring work. I have a question about generating graphs with 3D node coordinates. If I want the generated graphs to also capture the orientation information (in 3D space) of the underlying distribution (by generating raw coordinates for the nodes along with other node features), should I

  1. Use the node coordinates as additional node features (these would be continuous and would be concatenated with the one-hot node degree features; see the sketch after this list). I imagine this might be sub-optimal, but I am not exactly sure.
  2. Have them as a third component in the SDE, i.e., treat the coordinates as a third part of the graph representation and have the SDE work with three terms.
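
A minimal sketch of option 1 (an editor's illustration under stated assumptions, not the authors' recommendation): concatenate the raw 3D coordinates with the one-hot degree features.

import torch
import torch.nn.functional as F

def build_node_features(coords, degrees, max_deg=10):
    # coords: (N, 3) float tensor of raw 3D positions
    # degrees: (N,) long tensor of node degrees
    onehot = F.one_hot(degrees.clamp(max=max_deg), max_deg + 1).float()
    return torch.cat([coords, onehot], dim=-1)  # (N, 3 + max_deg + 1)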

Your insights would be extremely helpful.

Thanks,
Chinmay

Question about permutation equivariance

Thank you for the wonderful and inspiring work. I have a question about the permutation equivariance of GNN score function.
In the last paragraph of Section 3.2, it writes:

Note that since the message-passing operations of GNNs and the attention function used in GMH are permutation equivariant.

I am wondering what permutation equivariant means here. To my understanding, a GNN being permutation equivariant means:
$$PH = GNN(PX, PAP^T)$$
where $H$ is the node embedding matrix.

However, for the score function with respect to $A_t$, the output of $s(X_t, A_t, t)$ has the same dimensionality as $A_t$; if we treat that output as a "node embedding" of dimension $n$, then permutation equivariance means:
$$s(PX_t, PA_tP^T, t) = Ps(X_t, A_t, t)$$
not
$$s(PX_t, PA_tP^T, t) = Ps(X_t, A_t, t)P^T$$
right?
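
A toy numerical check of equivariance in the node-embedding sense described above (an illustration with a stand-in score function, not the paper's network): for $s(X, A) = AX$, we have $s(PX, PAP^T) = PAP^T PX = PAX = Ps(X, A)$.

import torch

n, d = 5, 3
X, A = torch.randn(n, d), torch.randn(n, n)
P = torch.eye(n)[torch.randperm(n)]  # random permutation matrix
s = lambda X, A: A @ X               # stand-in score producing node embeddings
assert torch.allclose(s(P @ X, P @ A @ P.T), P @ s(X, A), atol=1e-5)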

Looking forward to your reply. Thanks a lot.
Xikun

Questions about mol dataset transforming

Hi, nice work and thanks for sharing the code. I have some questions about the transformation of the molecular datasets. Specifically, in utils/data_loader_mol.py line 40, function get_transform_fn, you create x_ with the last slot reserved for virtual nodes. However, this slot is then removed by the following x = x[:, :-1], which confuses me a little. Another question: you do not consider aromatic bonds, as they are excluded by adj[:3]. Although I notice that MoFlow did the same, I wonder why, and further, how aromatic bonds can be formed during the generation process. Thanks in advance.
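
For context on the aromatic-bond question: the ZINC250k file used by the repo is named zinc250k_kekulized.npz, which suggests the molecules are kekulized, i.e., aromatic rings are re-expressed as alternating single/double bonds before the bond-type channels are built. A small RDKit illustration (an editor's sketch, not the repo's preprocessing code):

from rdkit import Chem

mol = Chem.MolFromSmiles('c1ccccc1')         # benzene, parsed with aromatic bonds
Chem.Kekulize(mol, clearAromaticFlags=True)  # rewrite as alternating single/double
print([b.GetBondType() for b in mol.GetBonds()])  # no AROMATIC entries remain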

Having Error 'c10::HIPError'

Hi! I encountered an error when trying to train the model. I tried different datasets, but the error remains. Thank you for helping me!
My env: Python 3.7.0 and PyTorch 1.10.1

The complete error message goes as follows:

terminate called after throwing an instance of 'c10::HIPError'
what(): HIP error: hipErrorNoDevice
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing HIP_LAUNCH_BLOCKING=1.
Exception raised from deviceCount at /pytorch/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h:102 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f9eff1e7212 in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: + 0x5618da (0x7f9f14f4a8da in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_hip.so)
frame #2: torch::autograd::Engine::start_device_threads() + 0x21a (0x7f9f4b4d82ca in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: + 0xf907 (0x7f9f61c86907 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #4: torch::autograd::Engine::initialize_device_threads_pool() + 0xcd (0x7f9f4b4d70bd in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::Engine::execute_with_graph_task(std::shared_ptr<torch::autograd::GraphTask> const&, std::shared_ptr<torch::autograd::Node>, torch::autograd::InputBuffer&&) + 0x28 (0x7f9f4b4ddf78 in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::autograd::python::PythonEngine::execute_with_graph_task(std::shared_ptr<torch::autograd::GraphTask> const&, std::shared_ptr<torch::autograd::Node>, torch::autograd::InputBuffer&&) + 0x3c (0x7f9f5ed7ba3c in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #7: torch::autograd::Engine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) + 0x900 (0x7f9f4b4dc1a0 in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::autograd::python::PythonEngine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) + 0x56 (0x7f9f5ed7b996 in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #9: THPEngine_run_backward(_object*, _object*, _object*) + 0x9d4 (0x7f9f5ed7c4a4 in /nethome/hsun409/anaconda3/envs/GDSS_env/lib/python3.7/site-packages/torch/lib/libtorch_python.so)

frame #31: __libc_start_main + 0xe7 (0x7f9f618a7c87 in /lib/x86_64-linux-gnu/libc.so.6)
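
A quick diagnostic (an editor's suggestion, not from the repo): hipErrorNoDevice is raised by a ROCm (HIP) build of PyTorch that cannot find an AMD GPU, so checking the installed build and device visibility usually narrows the problem down.

import torch

print(torch.__version__)          # a '+rocm' suffix indicates a ROCm build of PyTorch
print(torch.cuda.is_available())  # False here is consistent with hipErrorNoDevice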

File missing when sampling

Hello, congratulations on your nice work, and thanks for sharing the code. When I sampled graphs with the command "CUDA_VISIBLE_DEVICES=${gpu_ids} python main.py --type sample --config sample" as stated in the README, I found that at sampler.py line 125:
"with open(f'data/{self.configt.data.data.lower()}_test_nx.pkl', 'rb') as f:"
the "*_test_nx.pkl" files are missing from the data folder. How can I fix this? Thanks in advance.
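
A heavily hedged workaround sketch: if you can load the test split as networkx graphs (load_test_graphs below is hypothetical; use whatever loader your copy of the repo provides), the expected pickle can be written directly.

import pickle

test_graphs = load_test_graphs(config)  # hypothetical: returns a list of networkx graphs
with open(f'data/{config.data.data.lower()}_test_nx.pkl', 'wb') as f:
    pickle.dump(test_graphs, f)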

Size mismatching error occurs when sampling.

Run CUDA_VISIBLE_DEVICES=0 python main.py --type sample --config sample.
The output is as follows; an error about a size mismatch occurs:

./checkpoints/ZINC250k/gdss_zinc250k.pth loaded
----------------------------------------------------------------------------------------------------
Make Directory ZINC250k/test in Logs
----------------------------------------------------------------------------------------------------
[ZINC250k]   init=atom (9)   seed=42   batch_size=1024
----------------------------------------------------------------------------------------------------
lr=0.005 schedule=True ema=0.999 epochs=1000 reduce=False eps=1e-05
(ScoreNetworkX)+(ScoreNetworkA=GCN,4)   : depth=2 adim=16 nhid=16 layers=6 linears=3 c=(2 8 4)
(x:VP)=(0.10, 1.00) N=1000 (adj:VE)=(0.20, 1.00) N=1000
----------------------------------------------------------------------------------------------------
(Reverse)+(Langevin): eps=0.0001 denoise=True ema=False || snr=0.2 seps=0.9 n_steps=1
----------------------------------------------------------------------------------------------------
GEN SEED: 42


Loading file ./data/zinc250k_kekulized.npz
Number of training mols: 224568 | Number of test mols: 24887
Traceback (most recent call last):
  File "main.py", line 39, in <module>
    main(work_type_parser.parse_known_args()[0])
  File "main.py", line 30, in main
    sampler.sample()
  File "/root/projects/GDSS/sampler.py", line 129, in sample
    x, adj, _ = self.sampling_fn(self.model_x, self.model_adj, self.init_flags)
  File "/root/projects/GDSS/solver.py", line 177, in pc_sampler
    x = mask_x(x, flags)
  File "/root/projects/GDSS/utils/graph_utils.py", line 12, in mask_x
    return x * flags[:,:,None]
RuntimeError: The size of tensor a (3000) must match the size of tensor b (10000) at non-singleton dimension 0
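
For reference, the failing helper simply broadcasts the node flags over the feature dimension, so x and flags must agree on the batch dimension; the 3000-vs-10000 mismatch suggests init_flags was built for a different number of samples than the batch being masked. The shapes below are the editor's inference:

def mask_x(x, flags):
    # x: (B, N, F) node features; flags: (B, N) 0/1 node mask
    # broadcasting requires matching batch (and node-count) dimensions
    return x * flags[:, :, None]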

Possible mistake in validity w/o correction

Hi,

Thanks for sharing your code. I have a question concerning your code to compute validity without correction.

This is your procedure to compute this metric:

gen_mols, num_mols_wo_correction = gen_mol(x, adj, self.configt.data.data)
num_mols = len(gen_mols)
...
logger.log(f'validity w/o correction: {num_mols_wo_correction / num_mols}')

and here is the code for the gen_mol function:

def gen_mol(x, adj, dataset, largest_connected_comp=True):    
    # x: 32, 9, 5; adj: 32, 4, 9, 9
    x = x.detach().cpu().numpy()
    adj = adj.detach().cpu().numpy()

    if dataset == 'QM9':
        atomic_num_list = [6, 7, 8, 9, 0]
    else:
        atomic_num_list = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
    mols, num_no_correct = [], 0
    for x_elem, adj_elem in zip(x, adj):
        mol = construct_mol(x_elem, adj_elem, atomic_num_list)
        cmol, no_correct = correct_mol(mol)
        if no_correct: num_no_correct += 1
        vcmol = valid_mol_can_with_seg(cmol, largest_connected_comp=largest_connected_comp)
        mols.append(vcmol)
    mols = [mol for mol in mols if mol is not None]
    return mols, num_no_correct

While num_mols_wo_correction (num_no_correct in the function) is indeed the number of molecules that were valid before correction, num_mols is not the number of generated molecules. Because of the way mols is built in the function (it becomes gen_mols in your evaluation procedure), molecules that are invalid and cannot be corrected are excluded, leading to a smaller denominator in your validity computation. The metric you are actually computing is the number of molecules valid before correction divided by the final number of valid molecules.

If I'm correct, your function should instead output something like mols, num_no_correct, num_generated, where num_generated is len(x), and your validity should be num_no_correct / num_generated.

Is that a mistake inherited from the MoFlow paper, or is there something I haven't understood?
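
A minimal sketch of the fix proposed above (the editor's reading of the suggestion, not a confirmed patch): have gen_mol also return the total number of inputs, and divide by that.

# inside gen_mol, additionally return the total count:
#     return mols, num_no_correct, len(x)
gen_mols, num_mols_wo_correction, num_generated = gen_mol(x, adj, self.configt.data.data)
logger.log(f'validity w/o correction: {num_mols_wo_correction / num_generated}')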

num_workers > 0

Hi, and thanks for the code!

I suggest using num_workers > 0 in the dataloaders. It can speed up training significantly; using 4 workers made training about 30% faster in my case.
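
For concreteness, a sketch of the suggestion (the values are illustrative, not the repo's config):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=128, shuffle=True,
                    num_workers=4,    # parallel data-loading workers
                    pin_memory=True)  # often pairs well with GPU training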

Computing likelihood

Hello. Is there any option to compute the likelihood for the given graph in this model?

code associated with GDSS-seq missing?

Hi Jo, thank you for sharing such an excellent work!
I noticed that you introduced both the GDSS and GDSS-seq algorithms in your paper, and I'm quite interested in the performance difference between the two models. However, I didn't find the code associated with GDSS-seq in this repo (maybe because I didn't check the code very carefully). Could you please kindly point out the code implementing GDSS-seq?
Thank you for your reply! Best Regards!

bug with orca file

I faced an issue with the orca file while evaluating the model when I ran the code on the ENZYMES dataset. It also did not proceed to further evaluation rounds.
[screenshot of the orca evaluation error]

How to select DDPM or Denoised Score-based model?

First, thanks for your great work~
I noticed you used a score-based model in this paper, while many people use DDPM (denoising diffusion). I wonder whether you ever considered using DDPM to implement this method, and why you ultimately chose the score-based model.

macrocycles

Hello,
I have been trying out GDSS, I think it works well - great job!
One observation: I notice the 10K sampled compounds usually contain many macrocycles and large-ring molecules even though the training set has few. Depending on the desired use, this could be good or not so good.
Do you have an idea of how to minimize the number of macrocycles sampled, e.g., more training epochs?
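
Whatever mitigation is tried, a helper like the one below (an editor's sketch; the ring-size threshold is an assumption) can measure the macrocycle rate in the sampled SMILES:

from rdkit import Chem

def has_macrocycle(smiles, min_ring_size=9):
    # True if the molecule contains a ring of at least min_ring_size atoms
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return False
    return any(len(r) >= min_ring_size for r in mol.GetRingInfo().AtomRings())

rate = sum(map(has_macrocycle, sampled_smiles)) / len(sampled_smiles)  # sampled_smiles: your 10K samples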
Cheers,

Simple Conditional Generation

Dear authors,

Thank you for the great work. In my case, I want to do a simple conditional generation task, where each graph is associated with a label from {0,1}. During the sampling process, I would like to control the generated graph with the label. My initial thought is to add a new parameter for the label in the two score models (GNNs) during training. Could you tell me how you would realize this task?
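
A minimal sketch of that idea (an editor's illustration, not the repo's API): embed the binary label and add it to the node features fed to each score network, then pass the desired label at sampling time.

import torch
import torch.nn as nn

class LabelConditioner(nn.Module):
    # adds a learned class embedding to every node's feature vector
    def __init__(self, feat_dim, num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(num_classes, feat_dim)

    def forward(self, x, label):
        # x: (B, N, F) node features; label: (B,) long tensor of {0, 1}
        return x + self.emb(label)[:, None, :]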

Best,
Allen Wnag

error in generating QM9 samples

First of all, thanks for your code!
My problem is: when I run the command python main.py --type sample --config sample_qm9, I get the following error:
[screenshot of the error traceback]
I'm using the mini-moses repository, but when I switched to the moses repository, the same error occurred.
I'd appreciate it if you could help me with it :)
