
Fisher information embedding for node and graph learning

This repository implements the Fisher information embedding (FIE) described in the following paper:

Dexiong Chen*, Paolo Pellizzoni*, and Karsten Borgwardt. Fisher Information Embedding for Node and Graph Learning. ICML 2023.
(*Equal contribution)

TL;DR: a class of node embeddings with an information-geometric interpretation, together with both unsupervised and supervised algorithms for learning them.

Citation

Please use the following to cite our work:

@inproceedings{chen23fie,
    author = {Dexiong Chen and Paolo Pellizzoni and Karsten Borgwardt},
    title = {Fisher Information Embedding for Node and Graph Learning},
    year = {2023},
    booktitle = {International Conference on Machine Learning~(ICML)},
    series = {Proceedings of Machine Learning Research}
}

A short description of FIE

In this work, we propose a novel attention-based node embedding framework for graphs. Our framework builds upon a hierarchical kernel for multisets of subgraphs around nodes (e.g., neighborhoods). Each kernel leverages the geometry of a smooth statistical manifold to compare pairs of multisets by “projecting” them onto the manifold. By explicitly computing node embeddings with a manifold of Gaussian mixtures, our method leads to a new attention mechanism for neighborhood aggregation.

An illustration of the Fisher Information Embedding for nodes. (a) Multisets $h(\mathcal{S}_G(\cdot))$ of node features are obtained from the neighborhoods of each node. (b) Multisets are transformed into parametric distributions, e.g. $p_\theta$ and $p_{\theta'}$, via maximum likelihood estimation. (c) The node embeddings are obtained by estimating the parameters of each distribution using the EM algorithm, with an anchor distribution $p_{\theta_0}$ as the starting point. The last panel shows a representation of the parametric distribution manifold $\mathcal{M}$ and its tangent space $T_{\theta_0}\mathcal{M}$ at the anchor point $\theta_0$. The points $p_{\theta}$ and $p_{\theta'}$ represent probability distributions on $\mathcal{M}$, and the gray dashed line between them represents their geodesic distance. The red dashed lines represent the retraction mapping $R_{\theta_0}^{-1}$.
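
To make the mechanism concrete, below is a minimal NumPy sketch of the core computation described in the figure: a single EM step for a spherical Gaussian mixture, started at an anchor distribution, maps a multiset of neighborhood features to a parameter vector. This is an illustration only, not the repository's implementation; the names (fie_embed_one_step, anchor_means, anchor_logw) and the unit-variance assumption are ours.

import numpy as np

def fie_embed_one_step(features, anchor_means, anchor_logw):
    """features: (n, d) multiset of neighborhood features;
    anchor_means: (K, d) anchor component means; anchor_logw: (K,) log weights."""
    # E-step: responsibility of each anchor component for each feature.
    # For unit-variance Gaussians this is a softmax over negative squared
    # distances plus log weights -- an attention weighting of the neighborhood.
    sq_dists = ((features[:, None, :] - anchor_means[None, :, :]) ** 2).sum(-1)
    logits = anchor_logw[None, :] - 0.5 * sq_dists
    resp = np.exp(logits - logits.max(axis=1, keepdims=True))
    resp /= resp.sum(axis=1, keepdims=True)                  # (n, K)
    # M-step: re-estimate the means as responsibility-weighted averages.
    new_means = resp.T @ features / resp.sum(axis=0)[:, None]  # (K, d)
    # Embedding: the updated parameters read in the tangent space at the
    # anchor (here simply the displacement from the anchor means).
    return (new_means - anchor_means).reshape(-1)

The E-step responsibilities play the role of attention weights over the neighborhood features, and the resulting displacement from the anchor lives in the tangent space $T_{\theta_0}\mathcal{M}$ shown in the figure.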

Quickstart

from torch_geometric import datasets
from torch_geometric.loader import DataLoader

# FIENet ships with this repository's package; the exact import path may
# differ depending on your checkout.
from fisher_information_embedding import FIENet

# Construct data loader
dataset = datasets.Planetoid('./datasets/citation', name='Cora', split='public')
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
input_size = dataset.num_node_features

# Build FIE model
model = FIENet(
    input_size,
    num_layers=2,
    hidden_size=16,
    num_mixtures=8,
    pooling=None,
    concat=True
)

# Train model parameters using k-means
model.unsup_train(data_loader)

# Compute node embeddings
X = model.predict(data_loader)
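
The resulting embeddings can be fed to any downstream classifier. A minimal sketch with scikit-learn, assuming X is an (num_nodes, dim) array-like aligned with the standard Planetoid train/test masks:

import numpy as np
from sklearn.linear_model import LogisticRegression

data = dataset[0]  # the single Cora graph
emb = np.asarray(X)
train, test = data.train_mask.numpy(), data.test_mask.numpy()
y = data.y.numpy()

# Fit a logistic classifier on the training nodes and evaluate on the test nodes.
clf = LogisticRegression(max_iter=1000).fit(emb[train], y[train])
print(f"Test accuracy: {clf.score(emb[test], y[test]):.3f}")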

Installation

Dependencies are managed with miniconda. Run the following to install them:

# For CPU only
conda env create -f env.yml
# Or if you have a GPU
conda env create -f env_cuda.yml
# Then activate the environment
conda activate fie

Then, install our fisher_information_embedding package:

pip install -e .

Training models

Please see Tables 3 and 4 in our paper for the search grid of each hyperparameter. Note that we used only minimal hyperparameter tuning in our paper.

Training models on semi-supervised learning tasks using citation networks

  • Unsupervised node embedding mode with logistic classifier:
    python train_citation.py --dataset Cora --hidden-size 512 --num-mixtures 8 --num-layers 4
  • Supervised node embedding mode:
    python train_citation_sup.py --dataset Cora --hidden-size 64 --num-mixtures 8 --num-layers 4

Training models on supervised learning tasks using large OGB datasets

  • Unsupervised node embedding mode with FLAML:
    python train_ogb_node.py --save-memory --dataset ogbn-arxiv --hidden-size 256 --num-mixtures 8 --num-layers 5
  • Supervised node embedding mode:
    python train_ogb_node_sup_ns.py --dataset ogbn-arxiv --hidden-size 256 --num-mixtures 4 --num-layers 3
