
gradientgating's Introduction

Gradient Gating for Deep Multi-Rate Learning on Graphs

This repository contains the code to reproduce the numerical experiments of the ICLR 2023 paper Gradient Gating for Deep Multi-Rate Learning on Graphs.


Requirements

Main dependencies (with python >= 3.7):
torch==1.9.0
torch-cluster==1.5.9
torch-geometric==2.0.3
torch-scatter==2.0.9
torch-sparse==0.6.12
torch-spline-conv==1.2.1

Commands to install all dependencies in a fresh conda environment
(Python 3.7 and CUDA 10.2 -- for other CUDA versions, adjust the wheel URLs accordingly):

conda create --name gradientgating python=3.7
conda activate gradientgating

pip install torch==1.9.0

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
pip install torch-geometric
pip install scipy
pip install numpy
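After installing, a quick sanity check can confirm that the pinned versions resolved correctly. This is a minimal sketch; the expected version strings below mirror the dependency list above, and CUDA builds of torch report suffixed versions such as `1.9.0+cu102`:

```python
# Report installed vs. expected versions for the pinned dependencies.
import importlib

expected = {
    "torch": "1.9.0",
    "torch_geometric": "2.0.3",
    "torch_scatter": "2.0.9",
    "torch_sparse": "0.6.12",
}

for module_name, expected_version in expected.items():
    try:
        module = importlib.import_module(module_name)
        installed = getattr(module, "__version__", "unknown")
        print(f"{module_name}: installed {installed}, expected {expected_version}")
    except ImportError:
        print(f"{module_name}: NOT INSTALLED")
```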

Citation

If you found our work useful in your research, please cite our paper:

@inproceedings{rusch2022gradient,
  title={Gradient Gating for Deep Multi-Rate Learning on Graphs},
  author={Rusch, T Konstantin and Chamberlain, Benjamin P and Mahoney, Michael W and Bronstein, Michael M and Mishra, Siddhartha},
  booktitle={International Conference on Learning Representations},
  year={2023}
}

(Also consider starring the project on GitHub.)

gradientgating's People

Contributors

tk-rusch


gradientgating's Issues

hyperparameter settings

Dear authors,

Thanks for the awesome work and for sharing the code. I noticed that the hyperparameters were set based on a random search. Could you kindly share the hyperparameter settings used in 'run_GNN.py', including:
--GNN, --nhid, --nlayers, --lr, --drop_in, --drop, --weight_decay, --G2_exp
for the following datasets:
Texas, Wisconsin, Cornell, Cora, Citeseer, Pubmed.

Thank you so much!

I cannot find the .npz file

import numpy as np
from torch_geometric.datasets import Planetoid

def get_data(hom_level=0, graph=1):
  seed = 123456
  dataset = Planetoid('../data', 'Cora')

  if(hom_level==0):
    loader = np.load('data/h0.0' + str(round(hom_level, 3)) + '-r' + str(graph) + '.npz')
  else:
    loader = np.load('data/h0.'+str(round(hom_level,3))+'-r'+str(graph)+'.npz')

I was trying to find the .npz file, but there is no explanation of it anywhere in the repository.
I also tried to understand hom_level, but I cannot find where it is used.
Can you tell me how to deal with these problems?
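For what it's worth, .npz files are plain NumPy archives, so their contents can be inspected directly once obtained. A minimal sketch (with made-up array names; the keys in the repo's actual files may differ):

```python
import io
import numpy as np

# Build a toy .npz archive in memory (the real files would live under data/).
buf = io.BytesIO()
np.savez(buf, edges=np.array([[0, 1], [1, 2]]), labels=np.array([0, 1, 0]))
buf.seek(0)

# np.load on an .npz returns a lazy archive; .files lists the stored arrays.
loader = np.load(buf)
print(loader.files)        # ['edges', 'labels']
print(loader["labels"])    # [0 1 0]
```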

Detailed settings for reproducing the results in paper.

Hi Rusch,

This is really exciting work! I tried to reproduce the results on the heterophilic graphs from the paper, but the results I get are much worse than those reported. Could you provide the detailed settings for reproducing them?

Code Question

First of all, thank you for your paper and code; they have helped me a lot.
I have a question about your code. This is the code of your model:

# Imports needed to run this snippet (G2, the gating module, is defined
# elsewhere in the repo):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv

class G2_GNN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', p=2., drop_in=0, drop=0, use_gg_conv=True):
        super(G2_GNN, self).__init__()
        self.conv_type = conv_type
        self.enc = nn.Linear(nfeat, nhid)
        self.dec = nn.Linear(nhid, nclass)
        self.drop_in = drop_in
        self.drop = drop
        self.nlayers = nlayers
        if conv_type == 'GCN':
            self.conv = GCNConv(nhid, nhid)
            if use_gg_conv:
                self.conv_gg = GCNConv(nhid, nhid)
        elif conv_type == 'GAT':
            self.conv = GATConv(nhid, nhid, heads=4, concat=True)
            if use_gg_conv:
                self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
        else:
            print('specified graph conv not implemented')

        if use_gg_conv:
            self.G2 = G2(self.conv_gg, p, conv_type, activation=nn.ReLU())
        else:
            self.G2 = G2(self.conv, p, conv_type, activation=nn.ReLU())

    def forward(self, data):
        X = data.x
        n_nodes = X.size(0)
        edge_index = data.edge_index
        X = F.dropout(X, self.drop_in, training=self.training)
        X = torch.relu(self.enc(X))

        for i in range(self.nlayers):
            if self.conv_type == 'GAT':
                X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X_ = torch.relu(self.conv(X, edge_index))
            tau = self.G2(X, edge_index)
            X = (1 - tau) * X + tau * X_
        X = F.dropout(X, self.drop, training=self.training)

        return self.dec(X)

I thought nlayers was the number of layers, but looking at the code I realized that a single layer is reused nlayers times. This means that whether I set nlayers to 16 or 32, there is always only one set of layer weights.
May I ask why you implemented the code like this? Or did I misunderstand?

I also want to ask: if I want to check the model's performance, what is the structure of the model? Do I just stack the G2 layers?
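The weight-sharing behavior asked about above can be seen in a toy NumPy sketch of the update rule X = (1 - tau) * X + tau * X_. This is a simplified stand-in, not the authors' code: it assumes a dense row-normalized adjacency, a plain linear map as the "conv", and a tanh of p-th-power feature differences as the gate.

```python
import numpy as np

rng = np.random.default_rng(0)
n_nodes, nhid, nlayers, p = 4, 8, 16, 2.0

# Toy path graph with self-loops, row-normalized.
A = np.array([[1, 1, 0, 0],
              [1, 1, 1, 0],
              [0, 1, 1, 1],
              [0, 0, 1, 1]], dtype=float)
A /= A.sum(axis=1, keepdims=True)

W = 0.1 * rng.normal(size=(nhid, nhid))  # ONE shared weight matrix
X = rng.normal(size=(n_nodes, nhid))

for _ in range(nlayers):  # the same W is reused in every iteration
    X_ = np.maximum(A @ X @ W, 0.0)                       # shared "conv" step
    diffs = np.abs(X_[None, :, :] - X_[:, None, :]) ** p  # |X_j - X_i|^p
    tau = np.tanh((A[:, :, None] * diffs).sum(axis=1))    # gate values in [0, 1)
    X = (1.0 - tau) * X + tau * X_                        # multi-rate update

# Depth is nlayers, but the trainable parameters are just W: nhid * nhid.
print(X.shape, W.size)  # (4, 8) 64
```

Stacking independent layers (one weight matrix per layer) would instead use an nn.ModuleList of convolutions; as far as one can tell from the quoted code, the repo reuses self.conv, so nlayers controls the number of iterations, not the parameter count.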
