TRPN

Introduction

A PyTorch implementation of the IJCAI 2020 paper "Transductive Relation-Propagation Network for Few-shot Learning". The code is based on the implementation of Edge-labeling Graph Neural Network for Few-shot Learning.

Authors: Yuqing Ma, Shihao Bai, Shan An, Wei Liu, Aishan Liu, Xiantong Zhen and Xianglong Liu

Abstract: Few-shot learning, aiming to learn novel concepts from few labeled examples, is an interesting and very challenging problem with many practical advantages. To accomplish this task, one should concentrate on revealing the accurate relations of the support-query pairs. We propose a transductive relation-propagation graph neural network (TRPN) to explicitly model and propagate such relations across support-query pairs. Our TRPN treats the relation of each support-query pair as a graph node, named relational node, and resorts to the known relations between support samples, including both intra-class commonality and inter-class uniqueness, to guide the relation propagation in the graph, generating the discriminative relation embeddings for support-query pairs. A pseudo relational node is further introduced to propagate the query characteristics, and a fast, yet effective transductive learning strategy is devised to fully exploit the relation information among different queries. To the best of our knowledge, this is the first work that explicitly takes the relations of support-query pairs into consideration in few-shot learning, which might offer a new way to solve the few-shot learning problem. Extensive experiments conducted on several benchmark datasets demonstrate that our method can significantly outperform a variety of state-of-the-art few-shot learning methods.
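The bookkeeping behind relational nodes can be made concrete with a small sketch. It is illustrative only and assumes a plain-Python representation: it enumerates one relational node per support-query pair and builds the known support-support relations (1 for the same class, i.e. intra-class commonality; 0 for different classes, i.e. inter-class uniqueness). It does not implement the relation-propagation network, and the helper names `build_relational_nodes` and `support_relations` are hypothetical, not taken from the TRPN code.

```python
from itertools import product


def build_relational_nodes(support_labels, num_queries):
    """Enumerate one relational node per (support, query) pair.

    Hypothetical helper for illustration: node s * num_queries + q
    stands for the relation between support sample s and query q.
    """
    return list(product(range(len(support_labels)), range(num_queries)))


def support_relations(support_labels):
    """Known support-support relations (hypothetical helper):
    1 if two support samples share a class, 0 otherwise."""
    return [[1 if a == b else 0 for b in support_labels] for a in support_labels]


# Illustrative 3-way 2-shot episode with 4 queries
labels = [0, 0, 1, 1, 2, 2]
nodes = build_relational_nodes(labels, num_queries=4)
rel = support_relations(labels)
print(len(nodes))  # 24 = 6 support samples x 4 queries
print(rel[0])      # [1, 1, 0, 0, 0, 0]
```

The number of relational nodes grows as (number of support samples) x (number of queries), which is why the pairs, rather than the individual samples, are the graph's nodes in this view.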

Requirements

  • Python 3
  • Python packages
    • pytorch 1.0.0
    • torchvision 0.2.2
    • matplotlib
    • numpy
    • pillow
    • tensorboardX

An NVIDIA GPU and CUDA 9.0 or higher.

Getting started

mini-ImageNet

You can download the mini-ImageNet dataset from here.

tiered-ImageNet

You can download the tiered-ImageNet dataset from here.

Because WRN has a large number of parameters, you can save the features extracted before the classification layer to speed up training and testing. Here we provide the features extracted by WRN:

You can also use our pretrained WRN model to generate the features for mini-ImageNet or tiered-ImageNet yourself.
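The idea of caching features is to run the expensive backbone once per image and reuse the stored result in every later episode. The sketch below only illustrates that pattern under stated assumptions: `toy_backbone` is a hypothetical stand-in for the pretrained WRN, and the pickle file and the `load_or_extract` helper are illustrative choices, not the format or API used by this repository.

```python
import os
import pickle
import tempfile


def toy_backbone(image):
    """Hypothetical stand-in for the WRN feature extractor: maps an
    'image' (a list of pixel values) to a small feature vector."""
    return [sum(image) / len(image), max(image), min(image)]


def load_or_extract(images, cache_path, backbone):
    """Return features for `images`, reading them from `cache_path` if it
    exists, otherwise extracting them once and saving them there."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    features = {name: backbone(img) for name, img in images.items()}
    with open(cache_path, "wb") as f:
        pickle.dump(features, f)
    return features


# Count backbone calls to show that extraction only happens on the first pass
counter = {"calls": 0}


def counting_backbone(image):
    counter["calls"] += 1
    return toy_backbone(image)


images = {"img_a": [1, 2, 3], "img_b": [4, 5, 6]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "features.pkl")
    first = load_or_extract(images, path, counting_backbone)   # extracts and saves
    second = load_or_extract(images, path, counting_backbone)  # loads from disk
print(counter["calls"])  # 2: one backbone call per image, first pass only
print(first == second)   # True
```

In a real pipeline the backbone would be the pretrained network and the cached features would typically be tensors, but the trade-off is the same: disk space for the cache in exchange for skipping the backbone's forward pass during training and testing.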

Training

# ************************** miniImagenet, 5way 1shot  *****************************
$ python3 conv4_train.py --dataset mini --num_ways 5 --num_shots 1 
$ python3 WRN_train.py --dataset mini --num_ways 5 --num_shots 1 

# ************************** miniImagenet, 5way 5shot *****************************
$ python3 conv4_train.py --dataset mini --num_ways 5 --num_shots 5 
$ python3 WRN_train.py --dataset mini --num_ways 5 --num_shots 5 

# ************************** tieredImagenet, 5way 1shot *****************************
$ python3 conv4_train.py --dataset tiered --num_ways 5 --num_shots 1 
$ python3 WRN_train.py --dataset tiered --num_ways 5 --num_shots 1 

# ************************** tieredImagenet, 5way 5shot *****************************
$ python3 conv4_train.py --dataset tiered --num_ways 5 --num_shots 5 
$ python3 WRN_train.py --dataset tiered --num_ways 5 --num_shots 5 

# **************** miniImagenet, 5way 5shot, 20% labeled (semi) *********************
$ python3 conv4_train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4

You can download our pretrained model from here to reproduce the results of the paper.
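The `--num_ways`, `--num_shots` and `--num_unlabeled` flags describe how each training episode is built. As a rough, hypothetical sketch (not the repository's data loader), an N-way K-shot episode can be sampled as below; the `num_queries` parameter (queries per class) and the choice of which shots are marked unlabeled are assumptions. With 5 shots and 4 of them unlabeled, only 1 of the 5 support samples per class keeps its label, which matches the 20% labeled setting.

```python
import random


def sample_episode(class_to_items, num_ways, num_shots, num_queries,
                   num_unlabeled=0, rng=random):
    """Sample one N-way K-shot episode (illustrative sketch only).

    Returns (support, queries). Each support entry is (item, label, is_labeled);
    the first `num_unlabeled` shots of every class are marked unlabeled.
    `num_queries` is a hypothetical per-class query count.
    """
    classes = rng.sample(sorted(class_to_items), num_ways)
    support, queries = [], []
    for label, cls in enumerate(classes):
        items = rng.sample(class_to_items[cls], num_shots + num_queries)
        for i, item in enumerate(items[:num_shots]):
            support.append((item, label, i >= num_unlabeled))
        queries.extend((item, label) for item in items[num_shots:])
    return support, queries


# Toy dataset: 10 classes with 20 items each
rng = random.Random(0)
data = {f"class_{c}": [f"c{c}_img{i}" for i in range(20)] for c in range(10)}
support, queries = sample_episode(data, num_ways=5, num_shots=5, num_queries=1,
                                  num_unlabeled=4, rng=rng)
labeled = sum(1 for _, _, is_labeled in support if is_labeled)
print(len(support), len(queries))  # 25 5
print(labeled / len(support))      # 0.2, i.e. 20% of the support set is labeled
```

Labels are re-indexed from 0 to `num_ways - 1` inside each episode, so the model only ever has to distinguish among the classes sampled for that episode.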

Testing

# ************************** miniImagenet, Cway Kshot *****************************
$ python3 conv4_eval.py --test_model your_path --dataset mini --num_ways C --num_shots K 
$ python3 WRN_eval.py --test_model your_path --dataset mini --num_ways C --num_shots K 

