
spectral_graph_convnets's Introduction

Graph ConvNets in PyTorch

October 15, 2017

Xavier Bresson

http://www.ntu.edu.sg/home/xbresson
https://github.com/xbresson
https://twitter.com/xbresson

Description

Prototype implementation in PyTorch of the NIPS'16 paper:
Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering
M Defferrard, X Bresson, P Vandergheynst
Advances in Neural Information Processing Systems, 3844-3852, 2016
ArXiv preprint: arXiv:1606.09375

Code objective

The code provides a simple example of graph ConvNets for the MNIST classification task.
The graph is an 8-nearest-neighbor graph of a 2D grid.
The signals on the graph are the MNIST images, vectorized as $28^2 \times 1$ vectors.
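
As an illustration of the graph described above, here is a minimal NumPy sketch that builds an 8-nearest-neighbor graph on a 28 x 28 pixel grid. The function name `grid_knn_graph` is hypothetical, and the unweighted, symmetrized adjacency matrix is a simplifying assumption; the notebooks may construct and weight the graph differently.

```python
import numpy as np

def grid_knn_graph(side=28, k=8):
    """Symmetric 0/1 adjacency matrix of a k-NN graph on a side x side grid.

    Hypothetical helper for illustration only (unweighted edges assumed).
    """
    # Pixel coordinates, one row per node: shape (side*side, 2).
    rows, cols = np.meshgrid(np.arange(side), np.arange(side), indexing="ij")
    coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)

    # Squared Euclidean distance between every pair of nodes.
    diff = coords[:, None, :] - coords[None, :, :]
    dist2 = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(dist2, np.inf)  # a node is not its own neighbor

    # Indices of the k closest nodes for each node.
    nearest = np.argsort(dist2, axis=1, kind="stable")[:, :k]

    n = side * side
    A = np.zeros((n, n), dtype=np.float32)
    A[np.repeat(np.arange(n), k), nearest.ravel()] = 1.0

    # Symmetrize: keep an edge if either endpoint selected the other.
    return np.maximum(A, A.T)

A = grid_knn_graph()

# An MNIST image becomes a graph signal by flattening it to 784 values.
image = np.zeros((28, 28), dtype=np.float32)  # placeholder image
signal = image.reshape(-1, 1)                 # shape (784, 1)
```

For an interior pixel, the 8 nearest neighbors are the 4 axis-aligned and 4 diagonal pixels around it; pixels on the border pick up neighbors that are farther away.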

Installation

git clone https://github.com/xbresson/graph_convnets_pytorch.git
cd graph_convnets_pytorch
pip install -r requirements.txt # installation for python 3.6.2
python check_install.py
jupyter notebook # run the 2 notebooks

Results

GPU Quadro M4000

  • Standard ConvNets: 01_standard_convnet_lenet5_mnist_pytorch.ipynb, accuracy = 99.31, speed = 6.9 sec/epoch.
  • Graph ConvNets: 02_graph_convnet_lenet5_mnist_pytorch.ipynb, accuracy = 99.19, speed = 100.8 sec/epoch.

Note

PyTorch has not yet implemented the function torch.mm(sparse, dense) for variables: pytorch/pytorch#2389. It will certainly be implemented, but in the meantime I defined a new autograd function for sparse variables, called "my_sparse_mm", by subclassing torch.autograd.Function and implementing the forward and backward passes.

import torch

class my_sparse_mm(torch.autograd.Function):
    """
    Implementation of a new autograd function for sparse variables,
    called "my_sparse_mm", by subclassing torch.autograd.Function
    and implementing the forward and backward passes.
    """

    def forward(self, W, x):  # W is SPARSE
        self.save_for_backward(W, x)
        y = torch.mm(W, x)  # y = W x
        return y

    def backward(self, grad_output):
        W, x = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input_dL_dW = torch.mm(grad_input, x.t())  # dL/dW = dL/dy x^T
        grad_input_dL_dx = torch.mm(W.t(), grad_input)  # dL/dx = W^T dL/dy
        return grad_input_dL_dW, grad_input_dL_dx
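
The backward pass relies on two identities for y = W x: dL/dW = (dL/dy) x^T and dL/dx = W^T (dL/dy). The following is a small sanity check of those identities with central finite differences. It is a sketch using dense NumPy arrays rather than sparse PyTorch variables, and all names in it (`loss`, `numerical_grad`, `G`) are made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4))   # dense stand-in for the sparse matrix
x = rng.standard_normal((4, 2))
G = rng.standard_normal((3, 2))   # stand-in for the upstream gradient dL/dy

def loss(W, x):
    # A scalar whose gradient with respect to y = W @ x is exactly G.
    return float((G * (W @ x)).sum())

# Analytic gradients, same formulas as in the backward pass.
grad_W = G @ x.T
grad_x = W.T @ G

def numerical_grad(f, A, eps=1e-6):
    """Central finite-difference gradient of the scalar function f at A."""
    g = np.zeros_like(A)
    for idx in np.ndindex(*A.shape):
        A_plus, A_minus = A.copy(), A.copy()
        A_plus[idx] += eps
        A_minus[idx] -= eps
        g[idx] = (f(A_plus) - f(A_minus)) / (2 * eps)
    return g

num_W = numerical_grad(lambda M: loss(M, x), W)
num_x = numerical_grad(lambda M: loss(W, M), x)

print(np.allclose(grad_W, num_W, atol=1e-6))
print(np.allclose(grad_x, num_x, atol=1e-6))
```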

When to use this algorithm?

Use it for any problem that can be cast as analyzing a set of signals on a fixed graph, when you want to apply ConvNets to that analysis.




spectral_graph_convnets's Issues

Implementation ambiguity

Hello @xbresson,

Let me first thank you for this work.

I went through your implementation and have some questions:

1- You set coarsening_levels = 4, but only L[0] and L[2] are used:

x = self.graph_conv_cheby(x, self.cl1, L[0], lmax[0], self.CL1_F, self.CL1_K) and
x = self.graph_conv_cheby(x, self.cl2, L[2], lmax[2], self.CL2_F, self.CL2_K)
What about L[1], L[3] and L[4]?

2- Why are lmax[0] and lmax[2] recalculated in graph_conv_cheby, even though lmax is an input parameter of this function?

3- Do you mind explaining why you rescale the Laplacian eigenvalues to [-1,1]?
rescale_L(L, lmax)

4- concat(x, x_) is not used at all in graph_conv_cheby():

        def concat(x, x_):
            x_ = x_.unsqueeze(0)  # 1 x V x Fin*B
            return torch.cat((x, x_), 0)  # K x V x Fin*B

5- What if K=1 in graph_conv_cheby()?
I noticed that K must be at least 2.
How can I use the function with K=1?

Thank you for your consideration.

What is the need for the linear layer after the Chebyshev multiplications?

Hello @xbresson ,

Thank you for your work and for making a PyTorch version of GCN available.

I have a question related to the graph_conv_cheby(self, x, cl, L, lmax, Fout, K) function (https://github.com/xbresson/spectral_graph_convnets/blob/master/02_graph_convnet_lenet5_mnist_pytorch.ipynb)

  1. Why is the linear layer x = cl(x) needed after the Chebyshev multiplications?
  2. What if we don't use the linear layer and keep only the result of the Chebyshev multiplications, as follows?

if K > 1:
    x1 = my_sparse_mm()(L, x0)               # V x Fin*B
    x = torch.cat((x, x1.unsqueeze(0)), 0)   # 2 x V x Fin*B
for k in range(2, K):
    x2 = 2 * my_sparse_mm()(L, x1) - x0
    x = torch.cat((x, x2.unsqueeze(0)), 0)   # M x Fin*B
    x0, x1 = x1, x2

x = x.view([K, V, Fin, B])              # K x V x Fin x B
x = x.permute(3, 1, 2, 0).contiguous()  # B x V x Fin x K
x = x.view([B*V, Fin*K])

Thanks a lot for your answer.
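
For readers following this question, the recurrence in the snippet above is the Chebyshev one: T_0(L)x = x, T_1(L)x = Lx, and T_k(L)x = 2 L T_{k-1}(L)x - T_{k-2}(L)x. In the paper's formulation the filter output is a weighted sum of these K terms with learned coefficients, which appears to be the role of the linear layer in the notebook. Below is a minimal dense NumPy sketch under that reading; `chebyshev_basis`, `chebyshev_filter` and `theta` are hypothetical names, and it handles a single input feature rather than the notebook's batched layout.

```python
import numpy as np

def chebyshev_basis(L_scaled, x, K):
    """Stack T_0(L)x, ..., T_{K-1}(L)x along a new leading axis (K x V x F)."""
    terms = [x]                                   # T_0(L) x = x
    if K > 1:
        terms.append(L_scaled @ x)                # T_1(L) x = L x
    for _ in range(2, K):
        # T_k(L) x = 2 L T_{k-1}(L) x - T_{k-2}(L) x
        terms.append(2 * (L_scaled @ terms[-1]) - terms[-2])
    return np.stack(terms, axis=0)

def chebyshev_filter(L_scaled, x, theta):
    """Weighted sum of the Chebyshev terms; theta stands in for learned weights."""
    basis = chebyshev_basis(L_scaled, x, len(theta))
    return np.tensordot(theta, basis, axes=1)     # V x F

# Tiny example: a 2-node graph whose rescaled Laplacian has eigenvalues +-0.5.
L_scaled = np.array([[0.0, 0.5],
                     [0.5, 0.0]])
x = np.array([[1.0], [2.0]])

basis = chebyshev_basis(L_scaled, x, K=3)         # shape (3, 2, 1)
y = chebyshev_filter(L_scaled, x, theta=np.array([1.0, 0.0, 0.0]))
```

Without the weighted sum there is nothing to train in the layer: the stacked terms depend only on the graph and the input. With K=1 the basis reduces to the input itself, so the filter only rescales each node's value and uses no neighborhood information.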

loss_train=loss.data[0] ERROR

I am getting the following error. Can you please help?
File "full_gcn.py", line 365, in
loss_train = loss.data[0]
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
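
As the error message suggests, from PyTorch 0.4 onward a scalar loss is a 0-dim tensor and can no longer be indexed with [0]; calling .item() returns the Python number instead. Here is a minimal sketch of the change, assuming the loss is a scalar tensor. The logits and targets are made-up placeholders, not the contents of full_gcn.py.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 4 samples, 10 classes.
logits = torch.randn(4, 10, requires_grad=True)
target = torch.tensor([3, 1, 0, 7])

loss = F.cross_entropy(logits, target)  # 0-dim tensor

# Old code: loss_train = loss.data[0]   (raises IndexError on 0-dim tensors)
loss_train = loss.item()                # plain Python float
```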

something wrong with rescale_L

To rescale the Laplacian eigenvalues to [-1, 1], the formula should be $L' = 2L/l_{\max} - I$.
I think line 33 in coarsening.py should read as follows:

L /= lmax / 2
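
The formula can be checked numerically. The sketch below applies $L' = 2L/l_{\max} - I$ to the combinatorial Laplacian of a 4-node path graph and inspects the resulting spectrum. It uses dense NumPy arrays, and `rescale_laplacian` is a hypothetical stand-in for illustration, not the repository's rescale_L.

```python
import numpy as np

def rescale_laplacian(L, lmax):
    """Map the spectrum of L from [0, lmax] to [-1, 1]: L' = 2 L / lmax - I."""
    return 2.0 * L / lmax - np.eye(L.shape[0])

# Path graph on 4 nodes: 0 - 1 - 2 - 3.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
L = np.diag(A.sum(axis=1)) - A           # combinatorial Laplacian D - A

lmax = np.linalg.eigvalsh(L).max()       # largest eigenvalue of L
eig = np.linalg.eigvalsh(rescale_laplacian(L, lmax))

print(eig.min(), eig.max())              # endpoints of the rescaled spectrum
```

Dividing by lmax / 2 is the same as multiplying by 2 / lmax, so the eigenvalue 0 maps to -1 and the eigenvalue lmax maps to 1, which is the interval on which the Chebyshev polynomials are defined.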
