AMPL

This is the PyTorch implementation of "AMPL: An Adaptive Meta-Prompt Learner for Few-Shot Image Classification".

  • The code will be updated soon.

Installation

Python 3.8, PyTorch 1.11, CUDA 11.3. The code is tested on Ubuntu 20.04.

We have prepared a conda YAML file that contains all the Python dependencies.

conda env create -f environment.yml

To activate this conda environment:

conda activate ampl

We optionally use wandb to log training statistics.

Datasets

  • miniImageNet

    The miniImageNet dataset was proposed by Vinyals et al. for few-shot learning evaluation. It retains the complexity of ImageNet images but requires fewer resources and less infrastructure than running on the full ImageNet dataset. In total, there are 100 classes with 600 color images per class. These 100 classes are divided into 64, 16, and 20 classes for sampling tasks for meta-training, meta-validation, and meta-testing, respectively. To generate this dataset from ImageNet, you may use the repository miniImageNet tools.

  • tieredImageNet

    The tieredImageNet dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy. To generate this dataset from ImageNet, you may use the repository tieredImageNet tools.

  • CIFAR-FS and FC100

    CIFAR-FS and FC100 can be downloaded using the scripts from DeepEMD.
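The class splits described above are used to sample few-shot tasks (episodes). The following is a minimal, dependency-free sketch of N-way K-shot episode sampling, not the sampler used in this repository. The function name `sample_episode`, the default of 15 query images per class, and the toy label list are assumptions made for illustration.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way=5, k_shot=1, n_query=15, rng=None):
    """Sample one N-way K-shot episode from a list of integer class labels.

    Returns (support, query): two lists of (sample_index, episode_label)
    pairs, where episode_label is in range(n_way).
    """
    rng = rng or random.Random()

    # Group sample indices by class.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with enough samples can take part in an episode.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query samples")

    support, query = [], []
    for episode_label, cls in enumerate(rng.sample(eligible, n_way)):
        # Draw disjoint support and query samples for this class.
        picked = rng.sample(by_class[cls], k_shot + n_query)
        support += [(i, episode_label) for i in picked[:k_shot]]
        query += [(i, episode_label) for i in picked[k_shot:]]
    return support, query


# Toy label list shaped like the miniImageNet meta-test split:
# 20 classes with 600 images each.
labels = [c for c in range(20) for _ in range(600)]
support, query = sample_episode(labels, n_way=5, k_shot=5, n_query=15,
                                rng=random.Random(0))
print(len(support), len(query))  # 25 75
```

A 5-way 5-shot episode therefore contains 25 support images and, under the assumed 15 queries per class, 75 query images.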

Training

Our models are trained on 8 RTX 3090 GPUs (24 GB memory each) by default. Set the argument --nproc_per_node in the command files below to the number of GPUs available on your server, and increase or decrease the argument --batch_size_per_gpu if your GPUs have more or less memory.
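As a rough guide to how these two arguments interact, the sketch below assumes one training process per GPU, so the global batch size is the product of --nproc_per_node and --batch_size_per_gpu. The helper `per_gpu_batch_size` and the global batch size of 512 are illustrative assumptions, not the repository's defaults.

```python
def per_gpu_batch_size(global_batch_size, num_gpus):
    """Split a target global batch size evenly across GPUs (one process per GPU)."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if global_batch_size % num_gpus != 0:
        raise ValueError("global batch size must be divisible by the number of GPUs")
    return global_batch_size // num_gpus


# Illustrative numbers only: keep a global batch of 512 images
# while changing the number of GPUs.
for gpus in (8, 4, 2):
    size = per_gpu_batch_size(512, gpus)
    print(f"--nproc_per_node {gpus} --batch_size_per_gpu {size}")
```

If the resulting per-GPU value does not fit in memory, the global batch size has to shrink instead.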

  • Pre-training (self-supervised)

    In this phase, we pretrain our model using the self-supervised learning methods iBOT and SMKD (use the ). All models are trained for a maximum of 1600 epochs. We evaluate the model on the validation set every 50 epochs and report the best result. The 1-shot and 5-shot evaluation results with the Prototype method are given in the following table. We also provide full checkpoints and test-set features for the pretrained models, along with the commands to replicate the results.

    --data_path: must be set to the location of the training set of the chosen dataset (e.g. miniImageNet).
    --output_dir: location where the phase 1 checkpoints and evaluation files are stored.

    Dataset          1-shot    5-shot    Download
    miniImageNet     60.93%    80.38%    checkpoint, features, command
    tieredImageNet   71.36%    83.28%    checkpoint, features, command
    CIFAR-FS         65.70%    83.45%    checkpoint, features, command
    FC100            44.20%    61.64%    checkpoint, features, command

  • Adaptive Meta-Prompt Learner (supervised)

    In this second phase, we start from the phase 1 checkpoint and further train the model using the supervised knowledge distillation method proposed in our paper. All models are trained for a maximum of 150 epochs. We evaluate the model on the validation set every 5 epochs and report the best result. As before, the 1-shot and 5-shot evaluation results with the Prototype method are given in the following table. We also provide checkpoints and features for these models.

    --pretrained_dino_path: should be set to the same location as --output_dir in phase 1.
    --pretrained_dino_file: the checkpoint file to resume from (e.g. checkpoint1250.pth).
    --output_dir: location where the phase 2 checkpoints and evaluation files are stored.

    Dataset          1-shot    5-shot    Download
    miniImageNet     74.82%    88.47%    checkpoint, features, command
    tieredImageNet   78.98%    91.61%    checkpoint, features, command
    CIFAR-FS         78.69%    90.68%    checkpoint, features, command
    FC100            58.34%    72.25%    checkpoint, features, command
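The tables above report accuracy with the Prototype method. As a rough illustration of the general idea (nearest class mean under cosine similarity), here is a dependency-free sketch. It is not the repository's evaluation code: the function names and the toy 2-D features are assumptions, and in practice the features would come from the trained backbone.

```python
import math


def mean_vector(vectors):
    """Element-wise mean of equally sized feature vectors."""
    n = len(vectors)
    return [sum(dim) / n for dim in zip(*vectors)]


def cosine(a, b):
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def prototype_predict(support_feats, support_labels, query_feats):
    """Assign each query feature to the class with the most similar prototype."""
    # Collect the support features of each class.
    by_class = {}
    for feat, label in zip(support_feats, support_labels):
        by_class.setdefault(label, []).append(feat)
    # A prototype is the mean support feature of its class.
    prototypes = {label: mean_vector(feats) for label, feats in by_class.items()}
    return [max(prototypes, key=lambda c: cosine(q, prototypes[c]))
            for q in query_feats]


# Toy 2-way 2-shot episode with 2-D features.
support_feats = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.2, 0.8]]
support_labels = [0, 0, 1, 1]
query_feats = [[0.8, 0.2], [0.1, 0.9]]
query_labels = [0, 1]

predictions = prototype_predict(support_feats, support_labels, query_feats)
accuracy = sum(p == t for p, t in zip(predictions, query_labels)) / len(query_labels)
print(predictions, accuracy)  # [0, 1] 1.0
```

No parameters are trained at evaluation time in this scheme, so 1-shot and 5-shot results differ only in how many support features are averaged into each prototype.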

Evaluation

We use eval_ampl.py to evaluate a trained model. Before running the evaluation code, specify the image data path in server_dict in this Python file.

For example, the following commands run 5-way 5-shot evaluation on the model trained in Pre-training on miniImageNet:

  • prototype:

    python eval_ampl.py --server mini --num_shots 5 --ckp_path /root/autodl-nas/FSVIT_results/MINI480_phase2 --ckpt_filename checkpoint0040.pth --output_dir /root/autodl-nas/FSVIT_results/MINI480_prototype --evaluation_method cosine --iter_num 10000

  • classifier:

    python eval_ampl.py --server mini --num_shots 5 --ckp_path /root/autodl-nas/FSVIT_results/MINI480_phase2 --ckpt_filename checkpoint0040.pth --output_dir /root/autodl-nas/FSVIT_results/MINI480_classifier --evaluation_method classifier --iter_num 1000
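To run the same evaluation for several settings, the command line can be assembled programmatically. The sketch below uses only the flags shown in the commands above. The helper `build_eval_command` and the `/path/to/...` locations are hypothetical, and it assumes that --num_shots also accepts 1 for the 1-shot setting.

```python
import shlex


def build_eval_command(server, num_shots, ckp_path, ckpt_filename,
                       output_dir, evaluation_method, iter_num):
    """Return the eval_ampl.py invocation as an argument list."""
    return [
        "python", "eval_ampl.py",
        "--server", server,
        "--num_shots", str(num_shots),
        "--ckp_path", ckp_path,
        "--ckpt_filename", ckpt_filename,
        "--output_dir", output_dir,
        "--evaluation_method", evaluation_method,
        "--iter_num", str(iter_num),
    ]


# Hypothetical paths; replace them with your own checkpoint and output folders.
for shots in (1, 5):
    cmd = build_eval_command(
        server="mini",
        num_shots=shots,
        ckp_path="/path/to/phase2",
        ckpt_filename="checkpoint0040.pth",
        output_dir=f"/path/to/prototype_{shots}shot",
        evaluation_method="cosine",
        iter_num=10000,
    )
    # Print the command; pass cmd to subprocess.run(cmd, check=True) to execute it.
    print(shlex.join(cmd))
```

Building the command as a list rather than a single string avoids quoting problems when a path contains spaces.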

Citation

@inproceedings{,
      title={AMPL: An Adaptive Meta-Prompt Learner for Few-Shot Image Classification}, 
      author={},
      booktitle={},
      year={},
      pages={},
}
