
Class Incremental Learning via Likelihood Ratio Based Task Prediction

This repository contains the code for our ICLR2024 paper Class Incremental Learning via Likelihood Ratio Based Task Prediction by Haowei Lin, Yijia Shao, Weinan Qian, Ningxin Pan, Yiduo Guo, and Bing Liu.

Update [2024.2.10]: Now we support DER++, Non-CL, and more pre-trained visual encoders!

Quick Links

Overview

Requirements

First, install PyTorch by following the instructions from the official website. We ran the experiments on PyTorch 2.0.1, but any version of 1.6.0 or higher should also work. For example, if you use Linux and CUDA 11 (how to check your CUDA version), install PyTorch 1.6.0 with the following command,

pip install torch==1.6.0+cu110 -f https://download.pytorch.org/whl/torch_stable.html

If you instead use CUDA <11 or a CPU, install PyTorch with the following command,

pip install torch==1.6.0

Then run the following script to install the remaining dependencies,

pip install -r requirements.txt

Attention: Our model is based on timm==0.4.12. Using other versions of timm may cause unexpected bugs.
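As a sanity check before training, you can verify the pinned versions at runtime; a minimal sketch using only the standard library (only the timm pin is taken from this README, the rest is an assumption):

```python
# Sketch: warn about dependency-version mismatches before training.
# Only the timm==0.4.12 pin comes from this README; extend PINNED as needed.
from importlib.metadata import version, PackageNotFoundError

PINNED = {"timm": "0.4.12"}  # versions this code base was tested with

def parse(v):
    """Turn a version string like '0.4.12' into a comparable tuple."""
    return tuple(int(p) for p in v.split(".") if p.isdigit())

def check(pkg, pinned):
    """Return True if the installed version of pkg matches the pinned one."""
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg} is not installed")
        return False
    ok = parse(installed) == parse(pinned)
    print(f"{pkg}: installed {installed}, pinned {pinned} -> {'OK' if ok else 'MISMATCH'}")
    return ok

for pkg, pin in PINNED.items():
    check(pkg, pin)
```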

Training

This section describes how to train the TPL model with our code.

Data

Before training and evaluation, please download the datasets (CIFAR-10, CIFAR-100, TinyImageNet). The default data directory is ~/data in our code; you can modify it as needed.
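If you relocate the data, make sure the directory exists before launching training; a small helper sketch (the ~/data default mirrors this README, everything else is an assumption):

```python
# Sketch: resolve and create the data root used by the training code.
# ~/data is the README default; pass another path to relocate the datasets.
import os

def ensure_data_root(path="~/data"):
    """Expand '~', create the directory if missing, and return the absolute path."""
    root = os.path.abspath(os.path.expanduser(path))
    os.makedirs(root, exist_ok=True)
    return root

print(ensure_data_root())
```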

Pre-trained Model

We use the pre-trained DeiT model provided by MORE. Please download it and save it as ./ckpt/pretrained/deit_small_patch16_224_in661.pth. If you would like to test other pre-trained visual encoders, download them to the same directory (the pre-trained weights are available in timm or on Hugging Face). We provide scripts for DINO, MAE, CLIP, ViT (small, tiny), and DeiT (small, tiny).
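A hedged sketch of loading the downloaded checkpoint into a timm model (the model name and the strict=False handling are assumptions; the repo's own loading code may differ):

```python
# Sketch: load the MORE DeiT checkpoint if it has been downloaded.
# The checkpoint path comes from this README; the timm model name is an assumption.
import os

CKPT = "./ckpt/pretrained/deit_small_patch16_224_in661.pth"

def load_backbone(ckpt_path=CKPT):
    """Return a DeiT-small backbone with the checkpoint weights, or None if absent."""
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found; download it to {ckpt_path} first.")
        return None
    import timm, torch
    model = timm.create_model("deit_small_patch16_224", pretrained=False)
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state, strict=False)  # tolerate head/key differences
    return model

load_backbone()
```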

Training scripts

We provide an example training and evaluation script, deit_small_in661.sh. Run the following command to reproduce the results:

bash scripts/deit_small_in661.sh

This script performs both training and testing. By default, it trains TPL with 5 random seeds. During training, results are logged in ckpt; these are the $HAT_{CIL}$ results, obtained without the TPLR inference techniques. After evaluation runs, they are replaced with the final results. If you get poor results, check that eval.py ran correctly. The results for the first run, with seed=2023, are saved in ./ckpt/seq0/seed2023/progressive_main_2023.
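The per-seed output layout follows a simple pattern; a sketch of a helper reproducing it (the template generalizes the single seed=2023 path given here and is otherwise an assumption):

```python
# Sketch: build the directory where results for a given sequence/seed are saved.
# Generalized from the seed=2023 example in this README; treat it as an assumption.
def result_dir(seq: int, seed: int) -> str:
    return f"./ckpt/seq{seq}/seed{seed}/progressive_main_{seed}"

print(result_dir(0, 2023))
```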

For the results in the paper, we used NVIDIA A100 GPUs with CUDA 11.7. Different devices or different versions of CUDA and other software may lead to slightly different performance.

Extension

Our repo also supports running baselines such as DER++. If you are interested in other baselines, follow the DER++ integration as a template for your new code. If you want to test TIL+OOD methods, modify the inference code and add the OOD score computation in baseline.py. Our code base is very extensible.
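As an illustration of the kind of OOD score you might plug in, here is a standard energy score in plain Python (the function and its integration point are assumptions, not part of baseline.py):

```python
# Sketch: a standard (negative free) energy OOD score over class logits.
# Higher values suggest in-distribution; wiring it into baseline.py is up to you.
import math

def energy_score(logits):
    """Numerically stable log-sum-exp of the logits."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))

print(energy_score([2.0, 1.0, 0.5]))
```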

Bugs or questions?

If you have any questions about the code or the paper, feel free to email Haowei. If you encounter a problem when using the code, or want to report a bug, you can open an issue. Please describe the problem in detail so we can help you better and more quickly!

Acknowledgements

We thank PyContinual for providing an extensible framework for continual learning. We use their code structure as a reference when developing this code base.

Citation

Please cite our paper if you use this code or part of it in your work:

@inproceedings{lin2024class,
      title={Class Incremental Learning via Likelihood Ratio Based Task Prediction}, 
      author={Haowei Lin and Yijia Shao and Weinan Qian and Ningxin Pan and Yiduo Guo and Bing Liu},
      year={2024},
      booktitle={International Conference on Learning Representations}
}


Issues

About equation (9) in the paper.

I noticed that the score is negated in the code. Why is the final score computed this way?
Equation (9) in the paper is:
The equation (9) in the paper is:
[Screenshot of Eq. (9) from the paper]
The following is the code implementation:

composition = -torch.logsumexp(torch.stack((-e1, -e2), dim=0), dim=0)
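One way to read this line: -logsumexp(-e1, -e2) is a smooth minimum of the two scores, which is why the negations appear. A minimal stdlib check (a pure-Python rewrite of the torch line above, for illustration only):

```python
# Sketch: -logsumexp(-a, -b) acts as a smooth (soft) minimum of a and b.
import math

def soft_min(a, b):
    m = max(-a, -b)  # stabilize the log-sum-exp
    return -(m + math.log(math.exp(-a - m) + math.exp(-b - m)))

# soft_min lies within log(2) below the true minimum:
# min(a, b) - log(2) <= soft_min(a, b) <= min(a, b)
print(soft_min(1.0, 3.0))
```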

ImageNet-380 list

Hello, I was wondering if you could provide the list of classes (611) used to pretrain the network, and the ones used for CIL (380).
Thanks

Unable to reproduce the results in the original paper.

I am trying to reproduce CIFAR100-10T, but get much lower accuracy:
These are the results of the first two tasks:
{0: {'tp_acc': 0.681, 'acc': 0.3}, 1: {'tp_acc': 0.651, 'acc': 0.259}, 'auroc': 0.57520375, 'fpr@95': 0.9095, 'aupr': 0.5692886335071314}

I found that vit_hat cannot learn each task correctly (the TIL accuracy is low). What could be the problem?

Non-pretrained results

Hello, I'm trying to reproduce your results without using a pre-trained network (i.e., using ResNet-18).
I've followed the https://github.com/k-gyuhak/WPTP implementation to use HAT + TPL's adapter, but I'm unable to reproduce the results reported in your paper (Table 8, Appendix).
Could you upload the training code for such results?

Thanks

Request for Open-Sourcing Experimental Code for Comparison Methods

Hello, I am very interested in your research work and greatly appreciate you sharing the open-source code. I am currently conducting related research and would like to explore further. Could you please also share the code for the comparison experiments mentioned in your paper? It would greatly assist the progress of my research. Thank you very much for your help and support.
