
MVPool

Hierarchical Multi-View Graph Pooling with Structure Learning (paper).

This is a PyTorch implementation of the MVPool algorithm, which has been accepted by TKDE. MVPool performs the pooling operation using multi-view information. A structure learning layer is then stacked on the pooling operation to learn a refined graph structure that best preserves the essential topological information. It is a general operator that can be used in various architectures, for both node-level and graph-level representation learning.

Requirements

  • python3.6
  • pytorch==1.3.0
  • torch-scatter==1.4.0
  • torch-sparse==0.4.3
  • torch-cluster==1.4.5
  • torch-geometric==1.3.2

Note: A version of torch-sparse lower than 0.4.4 is required. This repository builds heavily on pytorch_geometric, a Geometric Deep Learning Extension Library for PyTorch. Please refer here for how to install and use the library.

Node Classification Datasets

The input contains:

  • x, the feature vectors of the labeled training instances
  • y, the one-hot labels of the labeled training instances
  • allx, the feature vectors of both labeled and unlabeled training instances (a superset of x)
  • graph, a dict in the format {index: [index_of_neighbor_nodes]}.

Let n be the number of both labeled and unlabeled training instances. These n instances should be indexed from 0 to n - 1 in graph with the same order as in allx.
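As a minimal sketch of how a graph dict in the format above could be turned into a list of directed edges: the helper name graph_dict_to_edges and the toy data are hypothetical and not part of this repository, and num_nodes is assumed to be the total number of nodes that appear in graph.

```python
def graph_dict_to_edges(graph, num_nodes):
    """Convert {index: [index_of_neighbor_nodes]} into a sorted list of (src, dst) pairs.

    Hypothetical helper for illustration; duplicate neighbor entries are dropped.
    """
    edges = set()
    for src, neighbors in graph.items():
        for dst in neighbors:
            # Every index is expected to lie in 0 .. num_nodes - 1.
            if not (0 <= src < num_nodes and 0 <= dst < num_nodes):
                raise ValueError(f"edge ({src}, {dst}) is outside 0..{num_nodes - 1}")
            edges.add((src, dst))
    return sorted(edges)


# Toy graph with three nodes; node 2 lists neighbor 0 twice.
toy_graph = {0: [1, 2], 1: [0], 2: [0, 0]}
edges = graph_dict_to_edges(toy_graph, num_nodes=3)
print(edges)
```

A list of (src, dst) pairs like this can then be converted into whatever edge format the downstream library expects.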

In addition to x, y, allx, and graph as described above, the preprocessed datasets also include:

  • tx, the feature vectors of the test instances
  • ty, the one-hot labels of the test instances
  • test.index, the indices of test instances in graph, for the inductive setting
  • ally, the labels for instances in allx.

The indices of test instances in graph for the transductive setting are from #x to #x + #tx - 1, with the same order as in tx.
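The transductive index range can be sketched as follows, where num_x and num_tx stand for #x and #tx (the function name and the toy sizes are hypothetical):

```python
def transductive_test_indices(num_x, num_tx):
    """Return the graph indices #x .. #x + #tx - 1, in the same order as tx."""
    return list(range(num_x, num_x + num_tx))


# Toy sizes: 3 labeled training instances and 2 test instances.
test_indices = transductive_test_indices(num_x=3, num_tx=2)
print(test_indices)
```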

You can use cPickle.load(open(filename)) under Python 2, or pickle.load(open(filename, 'rb'), encoding='latin1') under Python 3, to load the numpy/scipy objects x, y, tx, ty, allx, ally, and graph. test.index is stored as a text file. More details can be found here.
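A minimal sketch of the loading step under Python 3, assuming the pickled files were written by Python 2 (hence encoding="latin1") and that test.index holds one integer per line. Both helper names and the toy file names are hypothetical, and the demo writes its own small files rather than reading a real dataset.

```python
import os
import pickle
import tempfile


def load_pickled_object(path):
    """Load one pickled object (x, y, tx, ty, allx, ally, or graph)."""
    with open(path, "rb") as f:
        # encoding="latin1" lets Python 3 read pickles produced by Python 2.
        return pickle.load(f, encoding="latin1")


def load_index_file(path):
    """Read a text file of indices, assuming one integer per line."""
    with open(path) as f:
        return [int(line) for line in f if line.strip()]


# Round-trip demo on toy files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    graph_path = os.path.join(tmp, "toy.graph")
    with open(graph_path, "wb") as f:
        pickle.dump({0: [1], 1: [0]}, f, protocol=2)

    index_path = os.path.join(tmp, "toy.test.index")
    with open(index_path, "w") as f:
        f.write("1\n0\n")

    graph = load_pickled_object(graph_path)
    test_index = load_index_file(index_path)

print(graph, test_index)
```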

Node Classification

Execute the following command to run the node classification task:

python main_node_classification.py

Parameter settings for node classification

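| Dataset  | lr    | weight_decay | batch_size | pool_ratio       | lambda | net_layers |
|----------|-------|--------------|------------|------------------|--------|------------|
| Cora     | 0.01  | 0.01         | Full       | 0.5/0.5/0.8/0.5  | 0.9    | 4          |
| Citeseer | 0.005 | 0.1          | Full       | 0.7              | 0.0    | 1          |
| Pubmed   | 0.01  | 0.001        | Full       | 0.05/0.6/0.5/0.9 | 1.0    | 4          |
| CS       | 0.01  | 0.01         | Full       | 0.05/0.5/0.5/0.5 | 0.0    | 4          |
| Physics  | 0.01  | 0.01         | Full       | 0.05/0.8/0.8/0.8 | 0.0    | 4          |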

Graph Classification Datasets

Graph classification benchmarks are publicly available here.

Each dataset folder contains the following comma-separated text files (replace DS with the name of the dataset):

n = total number of nodes

m = total number of edges

N = number of graphs

(1) DS_A.txt (m lines)

sparse (block diagonal) adjacency matrix for all graphs, each line corresponds to (row, col) resp. (node_id, node_id)

(2) DS_graph_indicator.txt (n lines)

column vector of graph identifiers for all nodes of all graphs, the value in the i-th line is the graph_id of the node with node_id i

(3) DS_graph_labels.txt (N lines)

class labels for all graphs in the dataset, the value in the i-th line is the class label of the graph with graph_id i

(4) DS_node_labels.txt (n lines)

column vector of node labels, the value in the i-th line corresponds to the node with node_id i

There are OPTIONAL files if the respective information is available:

(5) DS_edge_labels.txt (m lines; same size as DS_A.txt)

labels for the edges in DS_A.txt

(6) DS_edge_attributes.txt (m lines; same size as DS_A.txt)

attributes for the edges in DS_A.txt

(7) DS_node_attributes.txt (n lines)

matrix of node attributes, the comma-separated values in the i-th line are the attribute vector of the node with node_id i

(8) DS_graph_attributes.txt (N lines)

regression values for all graphs in the dataset, the value in the i-th line is the attribute of the graph with graph_id i
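The mandatory files (1) to (3) can be read with a few lines of standard-library Python. The following is only a sketch under stated assumptions: node_id and graph_id are taken to be 1-based (the i-th line describes id i), class labels are taken to be integers, and the function name read_tu_dataset and the toy dataset "TOY" are hypothetical rather than part of this repository.

```python
import os
import tempfile
from collections import defaultdict


def read_tu_dataset(folder, ds):
    """Read DS_A.txt, DS_graph_indicator.txt and DS_graph_labels.txt.

    Returns (node_to_graph, graph_labels, edges_per_graph), where
    edges_per_graph maps graph_id -> list of (row, col) node_id pairs.
    """
    def read_ints(name):
        with open(os.path.join(folder, f"{ds}_{name}.txt")) as f:
            return [int(line) for line in f if line.strip()]

    node_to_graph = read_ints("graph_indicator")  # line i -> graph_id of node i
    graph_labels = read_ints("graph_labels")      # line i -> label of graph i

    edges_per_graph = defaultdict(list)
    with open(os.path.join(folder, f"{ds}_A.txt")) as f:
        for line in f:
            if not line.strip():
                continue
            row, col = (int(value) for value in line.split(","))
            # Assumption: ids are 1-based, so node `row` is at list index row - 1.
            edges_per_graph[node_to_graph[row - 1]].append((row, col))
    return node_to_graph, graph_labels, dict(edges_per_graph)


# Toy dataset: two graphs with two nodes each, one undirected edge per graph.
with tempfile.TemporaryDirectory() as tmp:
    toy_files = {
        "TOY_A.txt": "1, 2\n2, 1\n3, 4\n4, 3\n",
        "TOY_graph_indicator.txt": "1\n1\n2\n2\n",
        "TOY_graph_labels.txt": "0\n1\n",
    }
    for name, content in toy_files.items():
        with open(os.path.join(tmp, name), "w") as f:
            f.write(content)
    node_to_graph, graph_labels, edges_per_graph = read_tu_dataset(tmp, "TOY")

print(edges_per_graph)
```

Because the adjacency matrix is block diagonal, both endpoints of an edge belong to the same graph, so looking up the graph of the row node alone is sufficient to group the edges.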

Run Graph Classification

Execute the following command to run the graph classification task:

python main_graph_classification.py

Citing

If you find MVPool useful for your research, please consider citing the following paper:

@article{zhang2021hierarchical,
  title={Hierarchical Multi-View Graph Pooling with Structure Learning},
  author={Zhang, Zhen and Bu, Jiajun and Ester, Martin and Zhang, Jianfeng and Li, Zhao and Yao, Chengwei and Huifen, Dai and Yu, Zhi and Wang, Can},
  journal={IEEE Transactions on Knowledge and Data Engineering},
  year={2021},
  publisher={IEEE}
}

Contributors

cszhangzhen

mvpool's Issues

error in sparse_softmax.py

Hello, how to solve this problem?

File "/home/adminm/gm2/MVPool/sparse_softmax.py", line 29, in scatter_sort
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

Can you give the hyperparameters for each node classification dataset?

Hi, this paper seems good, but I cannot reproduce the reported performance with the given code.
On some node classification datasets in particular, the gap between our reproduced results and those reported in the paper is large.
Is there anything that needs extra attention? For example, the number of pooling layers, the pooling ratio of each layer, or other important settings.
