
neurai-lab / cls-er

The official PyTorch code for the ICLR'22 paper "Learning Fast, Learning Slow: A General Continual Learning Method based on Complementary Learning System"

License: MIT License

Python 100.00%
continual-learning deep-learning deep-neural-networks incremental-learning online-continual-learning lifelong-machine-learning brain-inspired

cls-er's Introduction

Learning Fast, Learning Slow

Official Repository for ICLR'22 Paper "Learning Fast, Learning Slow: A General Continual Learning Method based on Complementary Learning System"


We extended the Mammoth framework with our method (CLS-ER) and the GCIL-CIFAR-100 dataset.
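At a high level, CLS-ER maintains two exponential-moving-average (EMA) copies of the working model (a fast "plastic" and a slow "stable" semantic memory) alongside an episodic replay buffer, and uses the memories to provide consistency targets. Below is a minimal, illustrative sketch of this kind of dual-EMA update; it is not the official implementation, and the model, losses, and decay rates are placeholders:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def ema_update(ema_model, model, decay):
    """In-place exponential moving average of model parameters."""
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

# Toy working model and its two semantic memories (placeholders).
model = nn.Linear(4, 3)
plastic = copy.deepcopy(model)
stable = copy.deepcopy(model)

x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

logits = model(x)
# Task loss plus a consistency loss against the stable memory; the paper
# selects between the plastic and stable memories per sample, which is
# simplified away here.
loss = F.cross_entropy(logits, y) + F.mse_loss(logits, stable(x).detach())
opt.zero_grad()
loss.backward()
opt.step()

# The plastic memory adapts faster (lower decay), the stable one slower.
ema_update(plastic, model, decay=0.99)
ema_update(stable, model, decay=0.999)
```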

Additional Results

For a more extensive evaluation of our method and for benchmarking, we evaluated CLS-ER on S-CIFAR-100 with 5 tasks and also provide the Task-IL results for all settings. Note that, similar to DER, the Task-IL results merely use logit masking at inference.

| Buffer Size | S-MNIST Class-IL | S-MNIST Task-IL | S-CIFAR-10 Class-IL | S-CIFAR-10 Task-IL | S-CIFAR-100 Class-IL | S-CIFAR-100 Task-IL | S-TinyImg Class-IL | S-TinyImg Task-IL |
|---|---|---|---|---|---|---|---|---|
| 200 | 89.54±0.21 | 97.97±0.17 | 66.19±0.75 | 93.90±0.60 | 43.80±1.89 | 73.49±1.04 | 23.47±0.80 | 49.60±0.72 |
| 500 | 92.05±0.32 | 98.95±0.10 | 75.22±0.71 | 94.94±0.53 | 51.40±1.00 | 78.12±0.24 | 31.03±0.56 | 60.41±0.50 |
| 5120 | 95.73±0.11 | 99.40±0.04 | 86.78±0.17 | 97.08±0.09 | 65.77±0.49 | 84.46±0.45 | 46.74±0.31 | 75.81±0.35 |

Setup

  • Use python main.py to run experiments.

  • Use the argument --load_best_args to use the best hyperparameters for each evaluation setting from the paper.

  • To reproduce the results in the paper, run the following:

    python main.py --dataset <dataset> --model <model> --buffer_size <buffer_size> --load_best_args

Examples:

python main.py --dataset seq-mnist --model clser --buffer_size 500 --load_best_args

python main.py --dataset seq-cifar10 --model clser --buffer_size 500 --load_best_args

python main.py --dataset seq-tinyimg --model clser --buffer_size 500 --load_best_args

python main.py --dataset perm-mnist --model clser --buffer_size 500 --load_best_args

python main.py --dataset rot-mnist --model clser --buffer_size 500 --load_best_args

python main.py --dataset mnist-360 --model clser --buffer_size 500 --load_best_args
  • For GCIL-CIFAR-100 experiments:

    python main.py --dataset <dataset> --weight_dist <weight_dist> --model <model> --buffer_size <buffer_size> --load_best_args

Example:

python main.py --dataset gcil-cifar100 --weight_dist unif --model clser --buffer_size 500 --load_best_args

python main.py --dataset gcil-cifar100 --weight_dist longtail --model clser --buffer_size 500 --load_best_args

Requirements

  • torch==1.7.0

  • torchvision==0.9.0

  • quadprog==0.1.7
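For convenience, these pins can be collected into a requirements.txt and installed in one step (exact versions as listed above; compatibility with newer releases is untested here):

```
torch==1.7.0
torchvision==0.9.0
quadprog==0.1.7
```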

Cite Our Work

If you find the code useful in your research, please consider citing our paper:

@inproceedings{
  arani2022learning,
  title={Learning Fast, Learning Slow: A General Continual Learning Method based on Complementary Learning System},
  author={Elahe Arani and Fahad Sarfraz and Bahram Zonooz},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=uxxFrDwrE7Y}
}

cls-er's People

Contributors: elahearani, fahad92virgo

cls-er's Issues

Questions about reproducing

Hi! Thanks for the awesome work! I am trying to reproduce your results but I have some questions about the details. Thanks for your answers in advance.

  1. I found that, within each task, you use the model from the last epoch rather than the best one. Are there any special considerations behind this choice?

  2. I reproduced your work on a single RTX 3090 GPU with the default seed 1993, but I am not sure whether I chose the right value as the final accuracy. I took the accuracy after the last task, e.g., in the seq-mnist experiment, the accuracy after task 5, which is also the final accuracy shown in the terminal. Is that right?

  3. Could you please let us know the exact seeds you used, so that we might obtain closer results?

  4. We found that, for most continual learning approaches, your code reports an accuracy before training starts. I guessed this comes from pre-training, but I did not find a pre-training setup. Could you please explain how these values are obtained?

Thank you very much again!

Problems with iCaRL

Thanks for sharing!
While debugging the "icarl" model, I found that the forward function of the ICarL class is never called.
I wonder whether this is intentional or whether something is wrong.

Dataset issue

Hello, I am trying to reproduce your work as part of the 'ML Reproducibility Challenge 2022', but I ran into a few problems with the code base.
- There is a problem with the compressed SEQ-TINYIMG file: unzip tools return a 'tiny-imagenet-processed.zip not an archive' error. I tried different operating systems but was unable to extract the files from the archive.
- For the gcil-cifar100 dataset, the given --weight_dist parameter is not recognized by main.py, so the experiments are not reproducible.

Could you help me with these issues?

Questions about Experience Replay

Hello, your paper has greatly inspired me, and I have some questions about it. In general class-incremental learning scenarios, reservoir sampling is performed after training on each batch. Does that mean that, during the training of the first task, data from that same task is also replayed?
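For context, a minimal reservoir-sampling buffer (an illustrative sketch, not the repository's buffer class) shows why first-task samples can re-enter training: every incoming sample, including those from the task currently being trained, is a candidate for insertion into the buffer.

```python
import random

class ReservoirBuffer:
    """Toy reservoir-sampling buffer: keeps a uniform random subset
    of all samples seen so far, up to a fixed capacity."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        self.num_seen = 0

    def add(self, sample):
        if len(self.data) < self.capacity:
            # Buffer not yet full: always store.
            self.data.append(sample)
        else:
            # Replace a random slot with probability capacity / (num_seen + 1).
            idx = random.randint(0, self.num_seen)
            if idx < self.capacity:
                self.data[idx] = sample
        self.num_seen += 1

buf = ReservoirBuffer(capacity=10)
for i in range(100):
    buf.add(i)
```

Because insertion happens per batch, samples from the first task are in the buffer while that task is still training, so they can indeed be replayed alongside current-task data.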
