geo-gcn's Introduction

Spatial Graph Convolutional Networks

This repository contains an implementation of Spatial Graph Convolutional Networks (SGCN).
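Unlike a plain GCN, SGCN lets each message depend on where the neighbor sits in space. The following is a minimal NumPy sketch of that idea. It assumes messages of the form ReLU(U(p_j - p_i) + b) ⊙ x_j, summed over neighbors with self-loops added. The function name `spatial_conv_sketch` and the single-filter setup are illustrative assumptions; this is not the repository's `SpatialGraphConv` layer.

```python
import numpy as np

def spatial_conv_sketch(x, pos, edge_index, U, b):
    """Position-aware sum aggregation (illustrative sketch, not the repo's layer).

    x: [N, F] node features, pos: [N, D] node coordinates,
    edge_index: [2, E] with row 0 = source j and row 1 = target i,
    U: [D, F] and b: [F] map relative positions to per-feature weights.
    """
    n = x.shape[0]
    # Add self-loops so every node also receives its own features.
    loops = np.arange(n)
    src = np.concatenate([edge_index[0], loops])
    dst = np.concatenate([edge_index[1], loops])
    # Per-edge weights computed from relative positions p_j - p_i.
    rel = pos[src] - pos[dst]                # [E + N, D]
    weights = np.maximum(rel @ U + b, 0.0)   # ReLU, [E + N, F]
    messages = weights * x[src]              # elementwise scaling of x_j
    out = np.zeros_like(x, dtype=float)
    np.add.at(out, dst, messages)            # sum incoming messages at each target i
    return out

# Usage: two nodes on the x-axis, connected in both directions.
x = np.array([[1.0, 2.0], [3.0, 4.0]])
pos = np.array([[0.0, 0.0], [1.0, 0.0]])
edge_index = np.array([[0, 1], [1, 0]])
out = spatial_conv_sketch(x, pos, edge_index, np.eye(2), np.ones(2))
```

Because the weights depend on `p_j - p_i`, moving a neighbor changes how much each of its features contributes. A plain GCN cannot express this.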

Dependencies

  • PyTorch >= 1.1
  • PyTorch Geometric >= 1.1.2

Running the code

To run geo-GCN on MNISTSuperpixels with the default parameters, change into the src directory and run:

python train_models.py MNISTSuperpixels

To train on the chemical datasets (FreeSolv, ESOL or BBBP), load the splits as follows:

try:
    from torch_geometric.loader import DataLoader  # PyTorch Geometric >= 2.0
except ImportError:
    from torch_geometric.data import DataLoader  # older PyTorch Geometric releases
from chem import load_dataset

batch_size = 64
dataset_name = ...  # 'freesolv' / 'esol' / 'bbbp'

# Load the train / validation / test splits
train_dataset = load_dataset(dataset_name, 'train')
val_dataset = load_dataset(dataset_name, 'val')
test_dataset = load_dataset(dataset_name, 'test')

# Shuffle only the training data; evaluation order does not matter
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# training loop
...

Other options

The code lets you change several parameters (for example, the model variant, the learning rate, or the optimizer type). For the full list of available arguments, see src/train_models.py.

Reference

If you make use of our results or code in your research, please cite the following:

@InProceedings{10.1007/978-3-030-63823-8_76,
author="Danel, Tomasz
and Spurek, Przemys{\l}aw
and Tabor, Jacek
and {\'{S}}mieja, Marek
and Struski, {\L}ukasz
and S{\l}owik, Agnieszka
and Maziarka, {\L}ukasz",
editor="Yang, Haiqin
and Pasupa, Kitsuchart
and Leung, Andrew Chi-Sing
and Kwok, James T.
and Chan, Jonathan H.
and King, Irwin",
title="Spatial Graph Convolutional Networks",
booktitle="Neural Information Processing",
year="2020",
publisher="Springer International Publishing",
address="Cham",
pages="668--675",
abstract="Graph Convolutional Networks (GCNs) have recently become the primary choice for learning from graph-structured data, superseding hash fingerprints in representing chemical compounds. However, GCNs lack the ability to take into account the ordering of node neighbors, even when there is a geometric interpretation of the graph vertices that provides an order based on their spatial positions. To remedy this issue, we propose Spatial Graph Convolutional Network (SGCN) which uses spatial features to efficiently learn from graphs that can be naturally located in space. Our contribution is threefold: we propose a GCN-inspired architecture which (i) leverages node positions, (ii) is a proper generalization of both GCNs and Convolutional Neural Networks (CNNs), (iii) benefits from augmentation which further improves the performance and assures invariance with respect to the desired properties. Empirically, SGCN outperforms state-of-the-art graph-based methods on image classification and chemical tasks.",
isbn="978-3-030-63823-8"
}

geo-gcn's People

Contributors

mokosaur, slowika


geo-gcn's Issues

Issue with chemical dataset

Hello!

I tried to run the code with the chemical datasets. However, there are always dimension problems in the SpatialGraphConv layer. For instance, when I ran the 'freesolv' dataset, the line 'aggr_out = self.lin_out(aggr_out)' fails with:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (597x1600 and 64x64)
Have you come across a similar problem, and do you have any idea how to solve it? Note that I used only the default parameters and modified train_models.py with the lines suggested in the README.

By the way, I would also like to feed the edge_attr features into the conv layer. Could you point out where I can do that?

Thanks a lot. I'm trying to apply this to my current work :)
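The shapes in this error factor as 1600 = 64 × 25. One possible reading, which is only a guess, is that `lin_out` receives the per-node output flattened from `[N, out_channels, label_dim]` (the shape noted in the comment on the `propagate` call in the index_select issue below), while the layer was built for 64 input features. The NumPy sketch below reproduces that kind of mismatch. `label_dim = 25` is a hypothetical value chosen only to match the reported shape.

```python
import numpy as np

# Hypothetical sizes: 64 * 25 == 1600, matching the reported (597x1600) input.
n_nodes, out_channels, label_dim = 597, 64, 25

# Stand-in for a conv output shaped [N, out_channels, label_dim], flattened per node.
features = np.zeros((n_nodes, out_channels, label_dim))
flat = features.reshape(n_nodes, -1)  # [597, 1600]

# A linear layer built for 64 input features cannot consume 1600 features.
w_64_in = np.zeros((64, 64))
mismatch = False
try:
    flat @ w_64_in
except ValueError:
    mismatch = True  # NumPy's analogue of PyTorch's "cannot be multiplied" error

# Sizing the input dimension from the tensor that actually arrives avoids the error.
w_sized = np.zeros((flat.shape[1], 64))
out = flat @ w_sized  # [597, 64]
```

Printing the tensor shape just before `lin_out` would confirm or rule out this explanation.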

index_select() received an invalid combination of arguments

Hello!

I'm trying to run the MNIST example, but running into the following error. Any insights?

`---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
in
82
83 for epoch in range(1, 20):
---> 84 loss = train(epoch)
85 train_acc = test(train_loader)
86 test_acc = test(test_loader)

in train(epoch)
59 # data = rotation_2(data)
60 print(data.edge_index)
---> 61 output = model(data)
62 loss = F.nll_loss(output, data.y)
63 loss.backward()

/opt/tljh/user/lib/python3.6/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)

in forward(self, data)
24 def forward(self, data):
25 for i in range(self.layers_num):
---> 26 data.x = self.conv_layers[i](data.x, data.pos, data.edge_index)
27
28 if self.use_cluster_pooling:

/opt/tljh/user/lib/python3.6/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)

in forward(self, x, pos, edge_index)
22 edge_index = add_self_loops(edge_index, num_nodes=x.size(0)) # num_edges = num_edges + num_nodes
23
---> 24 return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add') # [N, out_channels, label_dim]
25
26 def message(self, pos_i, pos_j, x_j):

~/.local/lib/python3.6/site-packages/torch_geometric/nn/conv/message_passing.py in propagate(self, edge_index, size, **kwargs)
165 assert len(size) == 2
166
--> 167 kwargs = self.collect(edge_index, size, kwargs)
168
169 msg_kwargs = self.distribute(self.msg_params, kwargs)

~/.local/lib/python3.6/site-packages/torch_geometric/nn/conv/message_passing.py in collect(self, edge_index, size, kwargs)
113
114 self.set_size(size, idx, data)
--> 115 out[arg] = data.index_select(self.node_dim, edge_index[idx])
116
117 size[0] = size[1] if size[0] is None else size[0]

TypeError: index_select() received an invalid combination of arguments - got (int, NoneType), but expected one of:

  • (name dim, Tensor index)
    didn't match because some of the arguments have invalid types: (!int!, !NoneType!)
  • (int dim, Tensor index)
    didn't match because some of the arguments have invalid types: (int, !NoneType!)`
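One likely explanation (not confirmed here): in later PyTorch Geometric releases, `add_self_loops` returns a tuple `(edge_index, edge_weight)`. If that tuple is assigned directly to `edge_index`, as in the `forward` shown in the traceback, then `edge_index[idx]` inside `collect` can pick up the `None` weight slot instead of the target-node row. That produces exactly `(int, NoneType)`. The sketch below shows the pattern with a hypothetical NumPy stand-in, `add_self_loops_np`. In the layer itself, the corresponding change would be `edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))`.

```python
import numpy as np

def add_self_loops_np(edge_index, num_nodes):
    """Hypothetical stand-in mimicking a tuple-returning add_self_loops."""
    loops = np.arange(num_nodes)
    loop_index = np.stack([loops, loops])
    # Returns (edge_index, edge_weight); no weights were given, so the second item is None.
    return np.concatenate([edge_index, loop_index], axis=1), None

edge_index = np.array([[0, 1], [1, 0]])

# Buggy pattern: the whole tuple is bound to one name, so index 1 is the
# None weight slot rather than the row of target nodes.
buggy = add_self_loops_np(edge_index, num_nodes=2)

# Fix: unpack the tuple and keep only the index array.
fixed, _ = add_self_loops_np(edge_index, num_nodes=2)
```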

Regression code missing

Hi, I am really interested in your method. I am attempting to apply it to some chemical data of my own.

In your paper, you mentioned that you used this method on regression problems for the ESOL and FreeSolv datasets. Do you have the code for that posted somewhere? There are multiple places in your repository that assume a classification problem.

Thank you!
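The classification assumption is visible in the MNIST traceback above, where the loss is `F.nll_loss(output, data.y)` on class log-probabilities. A common way to adapt a graph network to ESOL/FreeSolv-style regression works in three steps. First, pool node features into one vector per molecule. Then map that vector to a single real value. Finally, train with mean squared error (`F.mse_loss` in PyTorch). The NumPy sketch below illustrates such a readout. `global_mean_pool_np` and `regression_head` are hypothetical helpers that imitate PyTorch Geometric's `global_mean_pool` but are not it. This is not the authors' regression code.

```python
import numpy as np

def global_mean_pool_np(x, batch, num_graphs):
    """Average node features per graph; batch[k] is the graph id of node k."""
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)
    counts = np.bincount(batch, minlength=num_graphs).reshape(-1, 1)
    return sums / counts

def regression_head(x, batch, num_graphs, w, b):
    """One real-valued prediction per graph (no log_softmax, unlike classification)."""
    return global_mean_pool_np(x, batch, num_graphs) @ w + b

def mse_loss(pred, target):
    """Mean squared error, the usual regression loss for ESOL/FreeSolv-style targets."""
    return float(np.mean((pred - target) ** 2))

# Usage: three nodes, the first two belong to graph 0 and the last to graph 1.
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = np.array([0, 0, 1])
pred = regression_head(x, batch, 2, np.array([1.0, 0.0]), 0.0)
loss = mse_loss(pred, np.array([2.0, 3.0]))
```

Accuracy-based evaluation would also have to be replaced by a regression metric such as RMSE.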

example with a chemical dataset

In order to reproduce the work on chemicals, could you share an example of data preparation:

  • with dummy 3D coordinates or RDKit 3D coordinates
  • for one molecule

What does your data preparation / dataloader look like?

BR
Guillaume
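For a single molecule, a position-aware layer needs three inputs: per-atom features `x`, 3D coordinates `pos`, and a directed `edge_index`. The sketch below builds these arrays with NumPy under stated assumptions. It uses a tiny hypothetical atom vocabulary `ATOM_TYPES`, one-hot atom features only, and random dummy coordinates when none are given. It is not the repository's `load_dataset` preprocessing. With RDKit, real coordinates could instead come from an embedded conformer, for example `AllChem.EmbedMolecule(mol)` followed by `mol.GetConformer().GetPositions()`. The arrays would then be converted to tensors and wrapped in `torch_geometric.data.Data(x=..., pos=..., edge_index=..., y=...)`.

```python
import numpy as np

ATOM_TYPES = ["C", "N", "O"]  # hypothetical, tiny vocabulary for the example

def molecule_to_graph(symbols, bonds, coords=None, seed=0):
    """Build PyG-style arrays (x, pos, edge_index) for one molecule.

    symbols: atom symbols; bonds: list of (a, b) atom-index pairs;
    coords: [N, 3] 3D coordinates, or None to use random dummy coordinates.
    """
    n = len(symbols)
    x = np.zeros((n, len(ATOM_TYPES)))
    for k, s in enumerate(symbols):
        x[k, ATOM_TYPES.index(s)] = 1.0  # one-hot atom type
    if coords is None:
        coords = np.random.default_rng(seed).normal(size=(n, 3))  # dummy 3D positions
    pos = np.asarray(coords, dtype=float)
    # Each undirected bond becomes two directed edges, as PyTorch Geometric expects.
    src = [a for a, _ in bonds] + [b for _, b in bonds]
    dst = [b for _, b in bonds] + [a for a, _ in bonds]
    edge_index = np.array([src, dst], dtype=np.int64)
    return x, pos, edge_index

# Usage: heavy atoms of ethanol (C-C-O) with dummy coordinates.
x, pos, edge_index = molecule_to_graph(["C", "C", "O"], [(0, 1), (1, 2)])
```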

Couldn't reproduce the MNISTSuperpixels results

I tried to reproduce the accuracies from your paper, but the vanilla model run straight from the repo (without any tweaks) doesn't seem to be learning:

Epoch: 001, Loss: 6.97393, Train Acc: 0.10218, Test Acc: 0.10100
Epoch: 002, Loss: 2.30191, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 003, Loss: 2.30144, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 004, Loss: 2.30132, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 005, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 006, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 007, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 008, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 009, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 010, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 011, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 012, Loss: 2.30125, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 013, Loss: 2.30821, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 014, Loss: 2.30259, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 015, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 016, Loss: 2.30351, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 017, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 018, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 019, Loss: 2.30125, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 020, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 021, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 022, Loss: 2.30125, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 023, Loss: 2.30119, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 024, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 025, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 026, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 027, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 028, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 029, Loss: 2.30232, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 030, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 031, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 032, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 033, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 034, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 035, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 036, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 037, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350

I tried with and without data augmentations and varied learning rate as well, but no success in fixing it.
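The logged numbers fit a model that ignores its input. Three observations point that way:

  • The loss plateau near 2.3012 sits just below ln 10 ≈ 2.3026. It is roughly the entropy of the MNIST label distribution.
  • The accuracies 0.11237 / 0.11350 equal the share of digit 1 in the standard MNIST train/test splits.
  • The epoch-1 values 0.10218 / 0.10100 match the share of digit 3.

The network therefore appears to predict one constant class. The arithmetic below can be checked with the standard MNIST label counts; MNISTSuperpixels is derived from MNIST's 60,000/10,000 split. This explains the symptom, not its cause.

```python
import math

# Standard MNIST label counts for digits 0-9 (train: 60,000 / test: 10,000 images).
train_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
test_counts = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]

uniform_loss = math.log(10)  # NLL when predicting 1/10 for every class

# Accuracy of a model that always predicts the same digit.
train_acc_const_1 = train_counts[1] / sum(train_counts)
test_acc_const_1 = test_counts[1] / sum(test_counts)
train_acc_const_3 = train_counts[3] / sum(train_counts)
test_acc_const_3 = test_counts[3] / sum(test_counts)

# Lowest loss reachable without using the input: entropy of the label prior.
priors = [c / sum(train_counts) for c in train_counts]
prior_entropy = -sum(p * math.log(p) for p in priors)

print(f"uniform loss {uniform_loss:.5f}, prior entropy {prior_entropy:.5f}")
print(f"always '1': train {train_acc_const_1:.5f}, test {test_acc_const_1:.5f}")
print(f"always '3': train {train_acc_const_3:.5f}, test {test_acc_const_3:.5f}")
```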

Running time of MNISTSuperpixels

Hi, thank you very much for the package.

Could you comment on the running times? I cloned the repo and ran train_models.py on MNISTSuperpixels. It loads the dataset and starts training, but (1) training takes forever (30-core CPU, no GPU available) and (2) it doesn't seem to improve after a few epochs:

$ python train_models.py MNISTSuperpixels
True
Epoch: 001, Loss: 7.64400, Train Acc: 0.10218, Test Acc: 0.10100
Epoch: 002, Loss: 2.30219, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 003, Loss: 2.30138, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 004, Loss: 2.30126, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 005, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 006, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 007, Loss: 2.30123, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 008, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 009, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350
Epoch: 010, Loss: 2.30124, Train Acc: 0.11237, Test Acc: 0.11350

It took about 2 days to train 10 epochs. Any help is much appreciated.
