
Test-Time Prompt Tuning (TPT) for zero-shot generalization in Vision-Language Models

This repository provides the official PyTorch implementation of our NeurIPS 2022 paper:

Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models
Authors: Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, Chaowei Xiao

For more details, please check out our project page and paper.

Overview

This repository contains the implementation of TPT for image classification with a pre-trained CLIP model. We consider three different initializations for test-time prompt tuning:

  • Using a hand-crafted prompt as initialization (e.g., "a photo of a ___")
  • Using a learned soft prompt (CoOp) as initialization.
  • Using the output of a trained conditional prompt learner (CoCoOp) as initialization.
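Regardless of initialization, TPT adapts the prompt at test time by minimizing the marginal entropy of the averaged predictions over augmented views of the test image, after discarding low-confidence views. Below is a minimal NumPy sketch of that objective for intuition only; the actual implementation optimizes the prompt with PyTorch autograd, and the function name and the `top_frac` threshold here are illustrative, not the repository's API.

```python
import numpy as np

def tpt_objective(logits, top_frac=0.1):
    """Marginal entropy of the averaged prediction over the most
    confident augmented views (a sketch of TPT's selection + entropy).

    logits: array of shape (num_views, num_classes).
    """
    # Softmax per augmented view.
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    # Per-view entropy; keep only the lowest-entropy (most confident) views.
    ent = -(probs * np.log(probs + 1e-12)).sum(axis=1)
    k = max(1, int(top_frac * len(logits)))
    keep = np.argsort(ent)[:k]
    # Marginal distribution over the selected views.
    avg = probs[keep].mean(axis=0)
    # Entropy of the marginal: the quantity minimized w.r.t. the prompt.
    return -(avg * np.log(avg + 1e-12)).sum()
```

Minimizing this quantity with respect to the prompt tokens encourages the model to make consistent, confident predictions across augmentations of a single test image, without any labels.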

Prerequisites

Hardware

This implementation is for the single-GPU configuration.

To evaluate on ImageNet, ImageNet-V2, and ImageNet-Sketch (which have 1000 classes), you will need a GPU with more than 16GB of memory (16GB is not enough). This codebase was tested on a GPU with 24GB of memory. For the other datasets (each with at most a few hundred classes), a GPU with 16GB of memory is sufficient.

Environment

The code is tested on PyTorch 1.7.1.

Datasets

We suggest downloading all datasets into a single root directory (${data_root}) and renaming each dataset's directory as specified by ${ID_to_DIRNAME} in ./data/datautils.py. This allows you to evaluate multiple datasets within the same run.
If this is not feasible, you can evaluate each dataset separately and change ${data_root} accordingly in the bash script.
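For example, the resulting layout would look roughly like the following (the directory names below are placeholders; the authoritative names are the values of ${ID_to_DIRNAME} in ./data/datautils.py):

```
${data_root}/
├── <dirname-for-dataset-1>/
├── <dirname-for-dataset-2>/
└── ...
```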

For out-of-distribution generalization, we consider 5 datasets: ImageNet, ImageNet-A, ImageNet-V2, ImageNet-R, and ImageNet-Sketch.

For cross-datasets generalization, we consider 10 datasets:

For cross-dataset generalization, we adopt the same train/val/test splits as CoOp. Please refer to this page, look for the download links for split_zhou_${dataset_name}.json, and put the JSON files under ./data/data_splits/.

Run TPT

We provide three bash scripts under ./scripts. You can modify the paths and other args in the scripts.

An example to run TPT with CoOp initialization on out-of-distribution datasets:

```
bash ./scripts/test_coop.sh I/A/V/R/K
```

The command-line argument ${testsets} can specify multiple test datasets separated by "/" (all stored under the same root directory ${data_root}).
Note that for simplicity, we use set_id to denote different datasets. A complete list of set_id can be found in ${ID_to_DIRNAME} in ./data/datautils.py.

Main Results

Out-of-Distribution Generalization

| Method | ImageNet (IN) | IN-A | IN-V2 | IN-R | IN-Sketch | Average | OOD Average |
|---|---|---|---|---|---|---|---|
| CLIP-RN50 | 58.16 | 21.83 | 51.41 | 56.15 | 33.37 | 44.18 | 40.69 |
| Ensembled prompt | 59.81 | 23.24 | 52.91 | 60.72 | 35.48 | 46.43 | 43.09 |
| CoOp | 63.33 | 23.06 | 55.40 | 56.60 | 34.67 | 46.61 | 42.43 |
| CoCoOp | 62.81 | 23.32 | 55.72 | 57.74 | 34.48 | 46.81 | 42.82 |
| TPT (ours) | 60.74 | 26.67 | 54.70 | 59.11 | 35.09 | 47.26 | 43.89 |
| TPT + CoOp | 64.73 | 30.32 | 57.83 | 58.99 | 35.86 | 49.55 | 45.75 |
| TPT + CoCoOp | 62.93 | 27.40 | 56.60 | 59.88 | 35.43 | 48.45 | 44.83 |

Cross-Dataset Generalization

In each matrix $A$, the entry $A_{i,j}$ is the normalized relative improvement on the $j$-th dataset of the prompt tuned on the $i$-th dataset, i.e., how well a method trained on source dataset $i$ performs on target dataset $j$ in comparison with a zero-shot CLIP baseline (using a hand-crafted prompt); higher is better. The last row is the performance of TPT, which is not tuned on any source dataset. The last column is the average improvement over the 10 datasets, measuring overall generalization ability.

Cross-dataset improvement normalized by the zero-shot baseline performance.
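One plausible reading of "improvement normalized by the zero-shot baseline performance" is the relative accuracy gain over the baseline, in percent. The helper below is a hypothetical illustration of that reading, not a function from this repository.

```python
def normalized_improvement(acc, zero_shot_acc):
    """Relative improvement (in %) of `acc` over the zero-shot baseline.

    Positive values mean the tuned prompt beats zero-shot CLIP on that
    target dataset; negative values mean it falls behind.
    """
    return 100.0 * (acc - zero_shot_acc) / zero_shot_acc
```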

Citation

If you find our code useful or our work relevant, please consider citing:

@inproceedings{shu2022tpt,
  author    = {Shu, Manli and Nie, Weili and Huang, De-An and Yu, Zhiding and Goldstein, Tom and Anandkumar, Anima and Xiao, Chaowei},
  title     = {Test-Time Prompt Tuning for Zero-shot Generalization in Vision-Language Models},
  booktitle = {NeurIPS},
  year      = {2022},
}

Acknowledgements

We thank the authors of CoOp/CoCoOp for their open-source implementation and instructions on data preparation.
