
๐Œ๐ž๐ญ๐š-๐“๐ซ๐š๐ง๐ฌ๐Ÿ๐ž๐ซ ๐‹๐ž๐š๐ซ๐ง๐ข๐ง๐  ๐“๐ž๐ง๐ฌ๐จ๐ซ๐…๐ฅ๐จ๐ฐ


This repository contains the TensorFlow implementation for the CVPR 2019 paper "Meta-Transfer Learning for Few-Shot Learning" by Qianru Sun*, Yaoyao Liu*, Tat-Seng Chua and Bernt Schiele.

If you have any problems when running this repository, feel free to send me an email or open an issue. I will reply as soon as I see it. (Email: liuyaoyao at tju.edu.cn)

๐’๐ฎ๐ฆ๐ฆ๐š๐ซ๐ฒ

๐ˆ๐ง๐ญ๐ซ๐จ๐๐ฎ๐œ๐ญ๐ข๐จ๐ง

Meta-learning has been proposed as a framework to address the challenging few-shot learning setting. The key idea is to leverage a large number of similar few-shot tasks in order to learn how to adapt a base-learner to a new task for which only a few labeled samples are available. As deep neural networks (DNNs) tend to overfit when given only a few samples, meta-learning typically uses shallow neural networks (SNNs), thus limiting its effectiveness. In this paper, we propose a novel few-shot learning method called meta-transfer learning (MTL), which learns to adapt a deep NN to few-shot learning tasks. Specifically, "meta" refers to training on multiple tasks, and "transfer" is achieved by learning scaling and shifting functions of DNN weights for each task. In addition, we introduce the hard task (HT) meta-batch scheme as an effective learning curriculum for MTL. We conduct experiments using (5-class, 1-shot) and (5-class, 5-shot) recognition tasks on two challenging few-shot learning benchmarks: miniImageNet and Fewshot-CIFAR100. Extensive comparisons to related works validate that our meta-transfer learning approach trained with the proposed HT meta-batch scheme achieves top performance. An ablation study also shows that both components contribute to fast convergence and high accuracy.

Figure: Meta-Transfer Learning. (a) Parameter-level fine-tuning (FT) is a conventional meta-training operation, e.g. in MAML. Its update works for all neuron parameters, W and b. (b) Our neuron-level scaling and shifting (SS) operations in meta-transfer learning. They reduce the number of learning parameters and avoid overfitting problems. In addition, they keep the large-scale trained parameters (in yellow) frozen, preventing "catastrophic forgetting".
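
To make the SS operation concrete, here is a minimal TensorFlow 1.x sketch of scaling and shifting applied on top of a frozen convolutional layer. The function and variable names are illustrative, not the repository's actual ones:

import tensorflow as tf

def conv_ss(inputs, frozen_w, frozen_b, name):
    # Neuron-level scaling and shifting (SS) on top of a frozen conv layer.
    # frozen_w and frozen_b are the large-scale pre-trained parameters and
    # stay fixed; only the per-filter scale s1 and shift s2 are meta-learned.
    out_channels = frozen_w.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        s1 = tf.get_variable('scale', [out_channels],
                             initializer=tf.ones_initializer())
        s2 = tf.get_variable('shift', [out_channels],
                             initializer=tf.zeros_initializer())
    # Scale the frozen weights per output channel, then shift the frozen bias.
    conv = tf.nn.conv2d(inputs, frozen_w * s1,
                        strides=[1, 1, 1, 1], padding='SAME')
    return conv + frozen_b + s2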

๐ˆ๐ง๐ฌ๐ญ๐š๐ฅ๐ฅ๐š๐ญ๐ข๐จ๐ง

To run this repository, we advise installing Python 2.7 and TensorFlow 1.3.0 with Anaconda.

You may download Anaconda and read the installation instructions on the official website: https://www.anaconda.com/download/

Create a new environment and install TensorFlow in it:

conda create --name mtl python=2.7
conda activate mtl
conda install tensorflow-gpu==1.3.0

Clone this repository:

git clone https://github.com/y2l/meta-transfer-learning-tensorflow.git 
cd meta-transfer-learning-tensorflow

Install other requirements:

pip install scipy
pip install tqdm
pip install opencv-python

๐ƒ๐š๐ญ๐š๐ฌ๐ž๐ญ๐ฌ

๐’Ž๐’Š๐’๐’Š๐ˆ๐ฆ๐š๐ ๐ž๐๐ž๐ญ

The miniImageNet dataset was proposed by Vinyals et al. for few-shot learning evaluation. Its complexity is high due to the use of ImageNet images, but it requires fewer resources and less infrastructure than the full ImageNet dataset. In total, there are 100 classes with 600 samples of 84×84 color images per class. These 100 classes are divided into 64, 16, and 20 classes respectively for sampling tasks for meta-training, meta-validation, and meta-test.

To generate this dataset from ImageNet, you may use the repository miniImageNet tools. You may also directly download processed images. [Download Page]
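
For intuition, below is a minimal sketch of how a (5-class, 1-shot) task can be sampled from such a split. It assumes one sub-folder of images per class (the layout of the processed downloads); the function name is illustrative:

import os
import random

def sample_task(split_dir, way=5, shot=1, query=15):
    # Draw `way` classes, then split their images into support and query sets.
    classes = random.sample(os.listdir(split_dir), way)
    support, query_set = [], []
    for label, cls in enumerate(classes):
        folder = os.path.join(split_dir, cls)
        picked = random.sample(os.listdir(folder), shot + query)
        support += [(os.path.join(folder, f), label) for f in picked[:shot]]
        query_set += [(os.path.join(folder, f), label) for f in picked[shot:]]
    return support, query_set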

๐…๐ž๐ฐ๐ฌ๐ก๐จ๐ญ-๐‚๐ˆ๐…๐€๐‘๐Ÿ๐ŸŽ๐ŸŽ

Fewshot-CIFAR100 (FC100) is based on the popular object classification dataset CIFAR100. The splits were proposed by TADAM. It offers a more challenging scenario with lower image resolution and more challenging meta-training/test splits that are separated according to object super-classes. It contains 100 object classes, and each class has 600 samples of 32×32 color images. The 100 classes belong to 20 super-classes. Meta-training data are from 60 classes belonging to 12 super-classes. The meta-validation and meta-test sets each contain 20 classes belonging to 4 super-classes.

You may directly download processed images. [Download Page]

tieredImageNet

The tieredImageNet dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy.

To generate this dataset from ImageNet, you may use the tieredImageNet tools repository. You may also directly download processed images. [Download Page]

๐‘๐ž๐ฉ๐จ ๐€๐ซ๐œ๐ก๐ข๐ญ๐ž๐œ๐ญ๐ฎ๐ซ๐ž

.
├── data_generator              # dataset generators
|   ├── pre_data_generator.py   # data generator for the pre-train phase
|   └── meta_data_generator.py  # data generator for the meta-train phase
├── models                      # TensorFlow model files
|   ├── models.py               # basic model class
|   ├── pre_model.py            # pre-train model class
|   └── meta_model.py           # meta-train model class
├── trainer                     # TensorFlow trainer files
|   ├── pre.py                  # pre-train trainer class
|   └── meta.py                 # meta-train trainer class
├── utils                       # tools used in this repo
|   └── misc.py                 # miscellaneous tool functions
├── main.py                     # main function and parameter settings
└── run_experiment.py           # script to run the whole experiment

Usage

To run the experiments:

python run_experiment.py

You may edit the run_experiment.py file to change the hyperparameters and options.

  • LOG_DIR Name of the folder to save the log files
  • GPU_ID GPU device id
  • PRE_TRA_LABEL Additional label for the pre-train model
  • PRE_TRA_ITER_MAX Number of iterations for the pre-train phase
  • PRE_TRA_DROP Dropout keep rate for the pre-train phase
  • PRE_DROP_STEP Number of iterations between pre-train learning-rate reductions
  • PRE_LR Pre-train learning rate
  • SHOT_NUM Number of samples per class
  • WAY_NUM Number of classes in the few-shot tasks
  • MAX_MAX_ITER Number of iterations for the meta-train phase
  • META_BATCH_SIZE Meta batch size
  • PRE_ITER Iteration of the pre-train model used in the meta-train phase
  • UPDATE_NUM Number of epochs for the base learning
  • SAVE_STEP Number of iterations between meta-model checkpoints
  • META_LR Meta learning rate
  • META_LR_MIN Minimum value of the meta learning rate
  • LR_DROP_STEP Number of iterations between meta learning-rate reductions
  • BASE_LR Base learning rate
  • PRE_TRA_DIR Directory for the pre-train phase images
  • META_TRA_DIR Directory for the meta-train images
  • META_VAL_DIR Directory for the meta-validation images
  • META_TES_DIR Directory for the meta-test images

The file run_experiment.py is just a script to generate commands for main.py. If you want to change other settings, please see the comments and descriptions in main.py.
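
For example, a stripped-down sketch of what such a script does, assuming hypothetical flag names (check main.py for the real ones):

import os

GPU_ID = 0
WAY_NUM = 5
SHOT_NUM = 1
META_LR = 0.001

# Build the command for main.py from the settings above and run it.
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
command = ('python main.py'
           ' --way_num=' + str(WAY_NUM) +
           ' --shot_num=' + str(SHOT_NUM) +
           ' --meta_lr=' + str(META_LR))
os.system(command)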

In the default setting, if you run python run_experiment.py, the pre-train phase will run before the meta-train phase starts. If you want to use the model pre-trained by us, you may download it from the link below and then change the pre-trained model loading directory in trainer/meta.py.

Download Pretrain Model (miniImageNet): [Google Drive] [Baidu Drive] (extraction code: efsv)
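
As a rough guide, restoring a TF 1.x checkpoint usually looks like the sketch below; the path is a placeholder for wherever you unpack the downloaded model, and this is not the repository's exact loading code:

import tensorflow as tf

PRETRAIN_CKPT = './logs/pretrain/model.ckpt'  # placeholder path

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore all saved variables from the downloaded checkpoint.
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, PRETRAIN_CKPT)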

We will release more pre-trained models later.

Todo

  • ๐‡๐š๐ซ๐ ๐ญ๐š๐ฌ๐ค ๐ฆ๐ž๐ญ๐š-๐›๐š๐ญ๐œ๐ก. The implementation of hard task meta-batch is not included in the published code. I still need time to rewrite the hard task meta batch code for the current framework.
  • ๐Œ๐จ๐ซ๐ž ๐ง๐ž๐ญ๐ฐ๐จ๐ซ๐ค ๐š๐ซ๐œ๐ก๐ข๐ญ๐ž๐œ๐ญ๐ฎ๐ซ๐ž๐ฌ. We will add new backbones to the framework like ResNet18 and ResNet34.
  • ๐๐ฒ๐“๐จ๐ซ๐œ๐ก ๐ฏ๐ž๐ซ๐ฌ๐ข๐จ๐ง. We will release the code for MTL on pytorch. It may takes several months to be completed.
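
Until then, here is a conceptual sketch of the HT meta-batch idea from the paper (re-sampling new tasks from the classes with the lowest accuracy in the previous meta-batch). This is an illustration only, not the repository's implementation:

import random

def choose_hard_classes(class_acc, num_hard):
    # class_acc maps class id -> accuracy measured on the last meta-batch;
    # return the num_hard worst-recognized classes.
    return sorted(class_acc, key=class_acc.get)[:num_hard]

def build_hard_task(all_classes, hard_classes, way=5):
    # Mix the hard classes with randomly drawn ones to form a new task.
    rest = [c for c in all_classes if c not in hard_classes]
    return list(hard_classes) + random.sample(rest, way - len(hard_classes))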

๐‚๐ข๐ญ๐š๐ญ๐ข๐จ๐ง

Please cite our paper if it is helpful to your work:

@inproceedings{sun2019mtl,
  title={Meta-Transfer Learning for Few-Shot Learning},
  author={Qianru Sun and Yaoyao Liu and Tat{-}Seng Chua and Bernt Schiele},
  booktitle={CVPR},
  year={2019}
}

๐€๐œ๐ค๐ง๐จ๐ฐ๐ฅ๐ž๐๐ ๐ž๐ฆ๐ž๐ง๐ญ๐ฌ

Our implementation uses the source code from the following repositories:

Model-Agnostic Meta-Learning

Optimization as a Model for Few-Shot Learning
