MLwork

This is the group project for the Machine Learning course.

SimpleHGN [KDD 2021]

Basic Idea

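  • The model extends the original graph attention mechanism in GAT by incorporating edge type information into the attention calculation.
  • At each layer, we calculate the attention coefficient, normalized over the neighbors $\mathcal{N}_i$ of node $i$:

$$ \alpha_{ij} = \frac{\exp\left(\text{LeakyReLU}\left(a^T[Wh_i \| Wh_j \| W_r r_{\psi(\langle i,j \rangle)}]\right)\right)}{\sum_{k\in\mathcal{N}_i}\exp\left(\text{LeakyReLU}\left(a^T[Wh_i \| Wh_k \| W_r r_{\psi(\langle i,k \rangle)}]\right)\right)} $$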

  • Residual connections are used, including a node residual:

$$ h_i^{(l)} = \sigma\left(\sum_{j\in \mathcal{N}_i} \alpha_{ij}^{(l)}W^{(l)}h_j^{(l-1)} + h_i^{(l-1)}\right) $$

  • where $h_i$ and $h_j$ are the features of the source and target nodes, and $r_{\psi(e)}$ is a $d$-dimensional embedding for each edge type $\psi(e) \in T_e$.

  • and an edge residual:

$$ \alpha_{ij}^{(l)} = (1-\beta)\alpha_{ij}^{(l)}+\beta\alpha_{ij}^{(l-1)} $$

  • Finally, multi-head attention is used.
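The steps above can be sketched for a single node and a single attention head in plain Python. This is only an illustrative sketch, not the code in this repository: the helper names (`edge_type_attention`, `node_update`), the toy dimensions, the random weights, and the choice of ReLU for $\sigma$ are all assumptions. It also assumes $W$ is square so that the node residual $h_i^{(l-1)}$ can be added directly, and it reuses the same weights for both attention calls where a real model would use the previous layer's coefficients.

```python
import math
import random


def leaky_relu(x, negative_slope=0.2):
    return x if x > 0 else negative_slope * x


def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def edge_type_attention(h_i, neighbors, W, W_r, edge_emb, a,
                        prev_alpha=None, beta=0.0):
    """Attention of node i over its neighbors.

    neighbors: list of (h_j, edge_type) pairs, one per edge <i, j>.
    prev_alpha: coefficients from the previous layer, or None.
    """
    Wh_i = matvec(W, h_i)
    scores = []
    for h_j, etype in neighbors:
        # Concatenate [W h_i || W h_j || W_r r_psi(<i,j>)].
        z = Wh_i + matvec(W, h_j) + matvec(W_r, edge_emb[etype])
        scores.append(leaky_relu(sum(a_k * z_k for a_k, z_k in zip(a, z))))
    # Softmax over the neighbors of i (shifted by the max for stability).
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Edge residual: mix with the previous layer's coefficients.
    if prev_alpha is not None:
        alpha = [(1 - beta) * new + beta * old
                 for new, old in zip(alpha, prev_alpha)]
    return alpha


def node_update(h_i, neighbors, alpha, W):
    """Node residual: sigma(sum_j alpha_ij W h_j + h_i), sigma = ReLU here."""
    agg = [0.0] * len(h_i)
    for (h_j, _), a_ij in zip(neighbors, alpha):
        agg = [s + a_ij * x for s, x in zip(agg, matvec(W, h_j))]
    return [max(0.0, s + r) for s, r in zip(agg, h_i)]


def rand_vector(n):
    return [random.uniform(-1, 1) for _ in range(n)]


def rand_matrix(rows, cols):
    return [rand_vector(cols) for _ in range(rows)]


random.seed(0)
d, d_e = 4, 2                       # toy node and edge-embedding sizes
W = rand_matrix(d, d)               # square, so the node residual fits
W_r = rand_matrix(d_e, d_e)
edge_emb = {0: rand_vector(d_e), 1: rand_vector(d_e)}  # one per edge type
a = rand_vector(2 * d + d_e)
h_i = rand_vector(d)
neighbors = [(rand_vector(d), 0), (rand_vector(d), 1), (rand_vector(d), 0)]

alpha_first = edge_type_attention(h_i, neighbors, W, W_r, edge_emb, a)
alpha_second = edge_type_attention(h_i, neighbors, W, W_r, edge_emb, a,
                                   prev_alpha=alpha_first, beta=0.2)
h_new = node_update(h_i, neighbors, alpha_second, W)
print(alpha_first, alpha_second, h_new)
```

Because the edge residual is a convex combination of two normalized distributions, the mixed coefficients still sum to one. Multi-head attention would repeat this computation with independent weights per head and combine the results.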

Dataset information

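| Dataset | Author | Paper | Subject | Paper-Author | Paper-Subject | Features | Train | Val | Test |
|---|---|---|---|---|---|---|---|---|---|
| acm4GTN | 5,912 | 3,025 | 57 | 9,936 | 3,025 | 1,902 | 600 | 300 | 2,125 |

| Dataset | Author | Conference | Paper | Author-Paper | Conference-Paper | Features | Train | Val | Test |
|---|---|---|---|---|---|---|---|---|---|
| dblp4GTN | 4,057 | 20 | 14,328 | 19,645 | 14,328 | 334 | 800 | 400 | 2,857 |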
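As a quick sanity check on the statistics above, the split sizes can be compared against the node counts. The sketch below assumes that the labeled (target) node type is `paper` for acm4GTN and `author` for dblp4GTN. The README does not state this explicitly, but the train/val/test sizes add up to exactly those node counts. The dictionary layout and the `split_summary` helper are made up for illustration.

```python
# Node counts and split sizes copied from the dataset table.
# The "target" field is an assumption, not something the README states.
STATS = {
    "acm4GTN": {
        "target": "paper",
        "nodes": {"author": 5912, "paper": 3025, "subject": 57},
        "split": {"train": 600, "val": 300, "test": 2125},
    },
    "dblp4GTN": {
        "target": "author",
        "nodes": {"author": 4057, "conference": 20, "paper": 14328},
        "split": {"train": 800, "val": 400, "test": 2857},
    },
}


def split_summary(name):
    """Return (number of labeled nodes, fraction of labels in each split)."""
    split = STATS[name]["split"]
    labeled = sum(split.values())
    fractions = {part: size / labeled for part, size in split.items()}
    return labeled, fractions


for dataset, info in STATS.items():
    labeled, fractions = split_summary(dataset)
    target_count = info["nodes"][info["target"]]
    print(dataset, labeled, target_count,
          {part: round(frac, 3) for part, frac in fractions.items()})
```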

Accuracy (%)

| Model | acm4GTN valid | acm4GTN test | dblp4GTN valid | dblp4GTN test |
|---|---|---|---|---|
| GTN (paper) | - | 92.68 | - | 94.18 |
| RGCN | 95.67 | 95.15 | 94.50 | 93.91 |
| SimpleHGN | 98.67 | 98.21 | 95.75 | 95.90 |

Requirements

  • Python >= 3.6
  • PyTorch >= 1.9.0
  • DGL >= 0.8.0
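One possible way to verify these requirements before training is a small standalone check. This sketch is not part of the repository. It assumes the packages are installed under the distribution names `torch` and `dgl` (CUDA builds of DGL may be published under a different name), and it relies on `importlib.metadata`, so the check itself needs Python 3.8 or newer. The helper names are hypothetical.

```python
import sys
from importlib import metadata

MIN_PYTHON = (3, 6)
# Distribution names are an assumption; adjust them for CUDA-specific builds.
MIN_PACKAGES = {"torch": (1, 9, 0), "dgl": (0, 8, 0)}


def parse_version(text):
    """Turn a string such as '1.9.0+cu111' into the tuple (1, 9, 0)."""
    parts = []
    for piece in text.split("+")[0].split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 3:       # pad '0.8' to (0, 8, 0)
        parts.append(0)
    return tuple(parts)


def check_requirements():
    """Return a list of human-readable problems (empty if all is fine)."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append("Python is older than 3.6")
    for name, minimum in MIN_PACKAGES.items():
        try:
            installed = parse_version(metadata.version(name))
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if installed < minimum:
            wanted = ".".join(str(n) for n in minimum)
            problems.append(f"{name} is older than {wanted}")
    return problems


for problem in check_requirements():
    print("requirement not met:", problem)
```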

How to run

python trainer.py --model SimpleHGN --dataset acm4GTN --n_epoch 200 --num_heads 4 --in_dim 256 --edge_dim 64 --hidden_dim 128 --out_dim 64 --num_layers 2 --feat_drop 0.2 --negative_slope 0.2 --beta 0.2 --clip 1.0 --max_lr 1e-3
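To run the same configuration on both datasets, the command can be assembled programmatically. The sketch below only builds and prints the command lines; the flag names and values are copied from the command above, while the `build_command` helper is hypothetical. Reusing the acm4GTN hyperparameters for dblp4GTN is an assumption, since only the acm4GTN invocation is documented here.

```python
import shlex
import sys

# Flags and values copied from the documented acm4GTN command.
BASE_ARGS = {
    "model": "SimpleHGN",
    "n_epoch": 200,
    "num_heads": 4,
    "in_dim": 256,
    "edge_dim": 64,
    "hidden_dim": 128,
    "out_dim": 64,
    "num_layers": 2,
    "feat_drop": 0.2,
    "negative_slope": 0.2,
    "beta": 0.2,
    "clip": 1.0,
    "max_lr": "1e-3",
}


def build_command(dataset, **overrides):
    """Return the trainer.py invocation as an argument list."""
    args = {**BASE_ARGS, "dataset": dataset, **overrides}
    command = [sys.executable, "trainer.py"]
    for flag, value in args.items():
        command += [f"--{flag}", str(value)]
    return command


for dataset in ("acm4GTN", "dblp4GTN"):
    command = build_command(dataset)
    print(" ".join(shlex.quote(part) for part in command))
```

Passing the resulting list to `subprocess.run(command, check=True)` from the repository root would launch the run. Individual values can be overridden per call, for example `build_command("acm4GTN", beta=0.05)`.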

Contributors

dddg617, lazishu2000
