
evoprompt's Introduction

EvoPrompt

The official PyTorch implementation of our AAAI 2024 (Oral) paper:

Evolving Parameterized Prompt Memory for Continual Learning

Muhammad Rifki Kurniawan, Xiang Song, Zhiheng Ma, Yuhang He, Yihong Gong, Qi Yang, Xing Wei.

GitHub maintainer: Muhammad Rifki Kurniawan

🔖Brief Introduction

Recent studies have demonstrated the potency of leveraging prompts in Transformers for continual learning (CL). Nevertheless, employing a discrete key-prompt bottleneck can lead to selection mismatches and inappropriate prompt associations during testing. Furthermore, this approach hinders adaptive prompting due to the lack of shareability among nearly identical instances at a more granular level. To address these challenges, we introduce the Evolving Parameterized Prompt Memory (EvoPrompt), a novel method that performs adaptive and continuous prompting on a pre-trained Vision Transformer (ViT), conditioned on each specific instance. We formulate a continuous prompt function as a neural bottleneck and encode the collection of prompts in the network weights. We establish a paired prompt memory system consisting of a stable reference prompt memory and a flexible working prompt memory. Inspired by linear mode connectivity, we progressively fuse the working prompt memory and the reference prompt memory during inter-task periods, resulting in a continually evolved prompt memory. This fusion aligns functionally equivalent prompts using optimal transport and aggregates them in parameter space with an adjustable bias based on prompt node attribution. Additionally, to enhance backward compatibility, we propose compositional classifier initialization, which leverages prior prototypes from pre-trained models to guide the initialization of new classifiers in a subspace-aware manner.
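The align-then-aggregate fusion step can be illustrated with a toy sketch. This is not the authors' implementation. It assumes each prompt memory is a single layer whose rows are prompt nodes, uses squared Euclidean distance between rows as the matching cost, and relies on the fact that with uniform weights and equal node counts an optimal transport plan reduces to a permutation (found here by brute force, so only tiny layers are practical). The per-node attribution scores `attr_ref` and `attr_work`, and the mixing rule built from them, are hypothetical stand-ins for the paper's attribution-based bias.

```python
from itertools import permutations


def sq_dist(u, v):
    """Squared Euclidean distance between two weight vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def best_permutation(reference, working):
    """Exact node matching by brute force: perm[i] is the working row matched
    to reference row i. With uniform marginals and equal sizes, the optimal
    transport plan is attained at such a permutation."""
    n = len(reference)
    best_perm, best_cost = None, float("inf")
    for perm in permutations(range(n)):
        cost = sum(sq_dist(reference[i], working[perm[i]]) for i in range(n))
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return best_perm


def fuse(reference, working, attr_ref, attr_work):
    """Align working nodes to reference nodes, then take a per-node convex
    combination. Hypothetical rule: the weight on the reference node is
    attr_ref / (attr_ref + attr_work)."""
    perm = best_permutation(reference, working)
    fused = []
    for i, ref_row in enumerate(reference):
        work_row = working[perm[i]]
        a_r, a_w = attr_ref[i], attr_work[perm[i]]
        lam = a_r / (a_r + a_w)
        fused.append([lam * r + (1.0 - lam) * w for r, w in zip(ref_row, work_row)])
    return fused


if __name__ == "__main__":
    reference = [[1.0, 0.0], [0.0, 1.0]]
    working = [[0.0, 1.2], [0.9, 0.0]]  # similar nodes, stored in swapped order
    print(fuse(reference, working, attr_ref=[1.0, 1.0], attr_work=[1.0, 1.0]))
```

Without the alignment step, a row-by-row average would mix unrelated nodes (here `[1.0, 0.0]` with `[0.0, 1.2]`); matching first is what lets the aggregation in parameter space combine functionally equivalent prompts.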

Preparing Environment

Create the Anaconda environment and install the dependencies and library (we use CUDA 12.2):

# conda environment
conda env create -f environment.yml
conda activate cl

# install inclearn library
pip install -e .
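After installation, a quick sanity check can confirm that the packages are importable. This is a hypothetical helper, not part of the repository; it assumes the editable install exposes a top-level `inclearn` package (as the comment above suggests) and only queries CUDA if PyTorch is installed.

```python
import importlib.util


def check_environment(packages=("torch", "inclearn")):
    """Map each top-level package name to whether it can be imported.

    If torch is among the packages and importable, also report CUDA
    availability and the CUDA version PyTorch was built with (None for
    CPU-only builds).
    """
    report = {name: importlib.util.find_spec(name) is not None for name in packages}
    if report.get("torch"):
        import torch  # imported lazily so the check also runs without PyTorch

        report["cuda_available"] = torch.cuda.is_available()
        report["torch_cuda_version"] = torch.version.cuda
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```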

Training EvoPrompt

Run the following scripts to reproduce our experiments on Split CIFAR100 and Split ImageNet-R with 5, 10, and 20 tasks, and on CORe50.

  • Split CIFAR100
# 5 tasks
bash scripts/evoprompt/train-cifar100-5_tasks.sh

# 10 tasks
bash scripts/evoprompt/train-cifar100-10_tasks.sh

# 20 tasks
bash scripts/evoprompt/train-cifar100-20_tasks.sh
  • Split ImageNet-R
# 5 tasks
bash scripts/evoprompt/train-imagenetr-5_tasks.sh

# 10 tasks
bash scripts/evoprompt/train-imagenetr-10_tasks.sh

# 20 tasks
bash scripts/evoprompt/train-imagenetr-20_tasks.sh
  • CORe50
bash scripts/evoprompt/train-core50.sh
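To queue all seven runs from Python, a small driver like the one below could be used. It is a hypothetical convenience wrapper, not part of the repository: it only assumes the script paths listed above and that it is launched from the repository root. By default it prints the commands instead of running them.

```python
import subprocess
from pathlib import Path

SCRIPT_DIR = Path("scripts/evoprompt")


def training_scripts():
    """Return the training scripts in the order listed in this README."""
    scripts = [
        SCRIPT_DIR / f"train-{dataset}-{n_tasks}_tasks.sh"
        for dataset in ("cifar100", "imagenetr")
        for n_tasks in (5, 10, 20)
    ]
    scripts.append(SCRIPT_DIR / "train-core50.sh")
    return scripts


def run_all(dry_run=True):
    """Print (dry run) or sequentially execute each script with bash,
    stopping at the first failing run."""
    for script in training_scripts():
        cmd = ["bash", script.as_posix()]
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```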

You can also reproduce our experiments on the following simple baseline using these scripts:

Acknowledgement

Our continual trainer is built on top of the following projects, and we sincerely appreciate their great open-source code:

Citation

If you find our paper or code useful for your research, we'd be thrilled if you could cite our work.

@article{kurniawan2024evoprompt,
  title={Evolving Parameterized Prompt Memory for Continual Learning},
  volume={38},
  url={https://ojs.aaai.org/index.php/AAAI/article/view/29231},
  DOI={10.1609/aaai.v38i12.29231},
  number={12},
  journal={Proceedings of the AAAI Conference on Artificial Intelligence},
  author={Kurniawan, Muhammad Rifki and Song, Xiang and Ma, Zhiheng and He, Yuhang and Gong, Yihong and Qi, Yang and Wei, Xing},
  year={2024},
  month={Mar.},
  pages={13301-13309}
}

evoprompt's Issues

Forgetting metric calculation issue

Congratulations on the great work. While going through the released code, I encountered a small issue that I thought might be worth bringing to your attention.

def forgetting(accuracies):
    if len(accuracies) == 1:
        return 0.0
    last_accuracies = accuracies[-1]
    usable_tasks = last_accuracies.keys()
    forgetting = 0.0
    for task in usable_tasks:
        if task == "total" or task == "average_accuracy":
            continue
        max_task = 0.0
        for task_accuracies in accuracies[:-1]:
            if task in task_accuracies:
                max_task = max(max_task, task_accuracies[task])
        forgetting += max_task - last_accuracies[task]
    return forgetting / len(usable_tasks)

This is the forgetting metric calculation in metric.py. For an experiment with T = 10 tasks, at the 10th (last) task the variable last_accuracies is a dict with keys {'total': , '00-19': , '20-39':, '40-59': , '60-79':, '80-99': , '100-119': , '120-139': , '140-159': , '160-179': , '180-199': , 'average_accuracy': }, hence len(usable_tasks) = 12. In the last line of the code, the accumulated forgetting is divided by len(usable_tasks), which is 12 in this case. However, according to the definition of the forgetting metric, it should be divided by T - 1 instead. Please correct me if I made any mistakes. Thank you for your attention.
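A minimal sketch of the fix proposed in this issue (not an official patch from the maintainers): ignore the aggregate keys `total` and `average_accuracy`, skip the newest task, which has no earlier accuracy to compare against, and average over the remaining T - 1 tasks, i.e. F = (1 / (T - 1)) * sum over j = 1..T-1 of (max over l < T of a[l][j] - a[T][j]). It assumes, as in the example above, that each entry of `accuracies` only holds keys for tasks seen so far.

```python
def forgetting(accuracies):
    """Average forgetting over every task except the most recent one.

    accuracies: list of dicts, one per learning step. Each maps task keys
    (e.g. "00-19") to accuracy, plus the aggregate keys "total" and
    "average_accuracy", which are ignored here.
    """
    if len(accuracies) < 2:
        return 0.0
    aggregate_keys = {"total", "average_accuracy"}
    last_accuracies = accuracies[-1]
    previous = accuracies[:-1]
    # Only tasks already evaluated before the last step can be forgotten,
    # which gives T - 1 tasks when keys are added one task at a time.
    seen_before = [
        task
        for task in last_accuracies
        if task not in aggregate_keys and any(task in acc for acc in previous)
    ]
    if not seen_before:
        return 0.0
    total = 0.0
    for task in seen_before:
        best = max(acc[task] for acc in previous if task in acc)
        total += best - last_accuracies[task]
    return total / len(seen_before)
```

For example, with per-task accuracies {00-19: 90} after step 1, {00-19: 85, 20-39: 75} after step 2 and {00-19: 70, 20-39: 60, 40-59: 80} after step 3 (plus the two aggregate keys), this version returns (20 + 15) / 2 = 17.5, whereas the original loop returns (20 + 15 - 80) / 5 = -9.0: the newest task contributes 0.0 - 80, and the aggregate keys are counted in the divisor.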

When will the code be released?

The paper says that the code will be released, but it has been a long time. Could you give a definite release date so that I can try out the method from the paper? I have been following this project closely.
