
DKT

The official PyTorch implementation of our CVPR 2023 poster paper:

DKT: Diverse Knowledge Transfer Transformer for Class Incremental Learning

GitHub maintainer: Xinyuan Gao

Requirements

We use the following versions:
python == 3.9
torch == 1.11.0
torchvision == 0.12.0
timm == 0.5.4
continuum == 1.2.3
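Version mismatches are an easy source of launch errors (the issue below, for example, was run under Python 3.10), so it can help to check the environment against these pins first. The following is a minimal sketch that uses only the standard library (`importlib.metadata`). `installed_versions` and `mismatches` are hypothetical helpers and are not part of the repository. Local build tags such as `+cu113` are ignored when comparing.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the Requirements list above.
PINNED = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "timm": "0.5.4",
    "continuum": "1.2.3",
}
PINNED_PYTHON = (3, 9)


def installed_versions(packages):
    """Return {package: installed version, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            # Drop local build tags such as "+cu113" before comparing.
            found[name] = version(name).split("+", 1)[0]
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, pinned=PINNED):
    """List (package, expected, found) for every package that differs from its pin."""
    return [
        (name, expected, installed.get(name))
        for name, expected in pinned.items()
        if installed.get(name) != expected
    ]


if __name__ == "__main__":
    if sys.version_info[:2] != PINNED_PYTHON:
        print(f"python: expected 3.9, found {sys.version.split()[0]}")
    for name, expected, found in mismatches(installed_versions(PINNED)):
        print(f"{name}: expected {expected}, found {found}")
```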

Accuracy

The tables below report the accuracy (%) after every incremental phase for each setting. The same numbers can also be found in the logs. (These numbers come from re-running the official code, so they may differ slightly from those in the paper.)

CIFAR-100, 20-20 (5 phases)
Phase     1      2      3      4      5      AVG
Acc. (%)  88.3   80.2   76.92  71.95  67.17  76.91

CIFAR-100, 10-10 (10 phases)
Phase     1      2      3      4      5      6      7      8      9      10     AVG
Acc. (%)  94.2   86.95  83.0   77.53  74.12  74.05  70.53  67.9   65.12  63.45  75.69

CIFAR-100, 5-5 (20 phases)
Phase     1      2      3      4      5      6      7      8      9      10
Acc. (%)  97.8   94.0   90.27  87.3   84.16  81.67  78.54  75.38  73.91  72.42
Phase     11     12     13     14     15     16     17     18     19     20     AVG
Acc. (%)  70.36  70.42  67.82  66.46  65.45  64.8   63.96  62.48  61.03  59.2   74.37

ImageNet-100, 10-10 (10 phases)
Phase     1      2      3      4      5      6      7      8      9      10     AVG
Acc. (%)  91.6   85.8   81.53  79.35  77.28  76.57  73.49  71.6   70.2   68.74  77.62

ImageNet-1000, 100-100 (10 phases)
Phase     1      2      3      4      5      6      7      8      9      10     AVG
Acc. (%)  85.02  80.12  76.5   73.7   70.26  68.36  66.35  64.1   61.81  58.93  70.52
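In every setting, the AVG column equals the unweighted mean of the per-phase accuracies, up to rounding. This quantity is often called the average incremental accuracy. The following sketch recomputes it from the table under that assumption. The helper name is ours and is not taken from the codebase.

```python
from statistics import mean

# Per-phase accuracy (%) and the reported AVG, copied from the table above.
RESULTS = {
    "CIFAR-100 20-20": ([88.3, 80.2, 76.92, 71.95, 67.17], 76.91),
    "CIFAR-100 10-10": ([94.2, 86.95, 83.0, 77.53, 74.12, 74.05, 70.53, 67.9, 65.12, 63.45], 75.69),
    "CIFAR-100 5-5": ([97.8, 94.0, 90.27, 87.3, 84.16, 81.67, 78.54, 75.38, 73.91, 72.42,
                       70.36, 70.42, 67.82, 66.46, 65.45, 64.8, 63.96, 62.48, 61.03, 59.2], 74.37),
    "ImageNet-100 10-10": ([91.6, 85.8, 81.53, 79.35, 77.28, 76.57, 73.49, 71.6, 70.2, 68.74], 77.62),
    "ImageNet-1000 100-100": ([85.02, 80.12, 76.5, 73.7, 70.26, 68.36, 66.35, 64.1, 61.81, 58.93], 70.52),
}


def average_incremental_accuracy(per_phase):
    """Unweighted mean of the accuracy measured after each phase."""
    return mean(per_phase)


for setting, (per_phase, reported) in RESULTS.items():
    computed = average_incremental_accuracy(per_phase)
    print(f"{setting}: computed {computed:.3f}, reported {reported}")
```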

Notice

If you run our experiments on a different number of GPUs, keep batch_size × number_of_GPUs = 512 for CIFAR-100 and ImageNet-100. For example, use a batch size of 512 on one GPU or a batch size of 256 on two GPUs. If you want to use a different total batch size, you will likely need to retune the hyperparameters.

For CIFAR-100, you can use a single GPU with batch size 512 or two GPUs with batch size 256 (the accuracy of both runs is in the logs).
For ImageNet-100, we use two GPUs with batch size 256.
For ImageNet-1000, we use four GPUs with batch size 256.
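To choose `--batch-size` for a given GPU count, divide the total of 512 by the number of GPUs. The following is a minimal sketch. It assumes that `--batch-size` (listed in main.py's usage output) is the per-GPU batch size, as the rule above implies. It also assumes that the first argument to train.sh is the comma-separated GPU list. `per_gpu_batch_size` is a hypothetical helper and is not part of the repository.

```python
# Total batch size (per-GPU batch size x number of GPUs) recommended for CIFAR-100 and ImageNet-100.
TOTAL_BATCH_SIZE = 512


def per_gpu_batch_size(num_gpus: int, total: int = TOTAL_BATCH_SIZE) -> int:
    """Split the total batch size evenly across GPUs (hypothetical helper)."""
    if num_gpus < 1:
        raise ValueError("need at least one GPU")
    if total % num_gpus:
        raise ValueError(f"total batch size {total} is not divisible by {num_gpus} GPUs")
    return total // num_gpus


# Assumed format: the same comma-separated GPU list passed to train.sh.
gpu_ids = "0,1"
num_gpus = len(gpu_ids.split(","))
print(f"--batch-size {per_gpu_batch_size(num_gpus)}")  # prints --batch-size 256
```

By the same arithmetic, the ImageNet-1000 setup (four GPUs at 256) uses an effective batch size of 1024. The 512 rule therefore covers only CIFAR-100 and ImageNet-100.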

The code was organized in a hurry, so if you run into any problems, please contact me by email [[email protected]]. Thanks!

Acknowledgement

Our code is heavily based on the excellent DyTox codebase. Many thanks for its framework.

Part of our code is also inspired by CSCCT. Thanks to its authors for their code.

Trainer

As with DyTox, you can run the code with the following command:

bash train.sh 0,1 \
    --options options/data/cifar100_10-10.yaml options/data/cifar100_order1.yaml options/model/cifar_DKT.yaml \
    --name DKT \
    --data-path MY_PATH_TO_DATASET \
    --output-basedir PATH_TO_SAVE_CHECKPOINTS \
    --memory-size 2000

Citation

If any part of our paper or code helps your research, please consider citing us and starring our repository.

@InProceedings{Gao_2023_CVPR, 
    author    = {Gao, Xinyuan and He, Yuhang and Dong, Songlin and Cheng, Jie and Wei, Xing and Gong, Yihong}, 
    title     = {DKT: Diverse Knowledge Transfer Transformer for Class Incremental Learning}, 
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 
    month     = {June}, 
    year      = {2023}, 
    pages     = {24236-24245} 
}

dkt's Issues

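Question about launching distributed training

Hello,
I reproduced CIFAR100-10-10 by launching your code with a different distributed-training launcher, and the results were very close. However, on two V100s each epoch takes about 51 s. When I launched train.sh with the command from the GitHub README, I got the following error:
When I tried to email you, the message bounced with "address rejected". Please take a look whenever you have time. Your research has been very helpful to me.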
root@cd98f4d76410:/code/DKT_source# bash train.sh 0,1 --options /code/DKT_source/options/data/cifar100_10-10.yaml /code/DKT_source/options/data/cifar100_order1.yaml /code/DKT_source/options/model/cifar_DKT.yaml --name DKT --data-path /data/Logic888/CIFAR100/cifar100 --output-basedir /data/Logic888/CIFAR100/DKT/save_checkpoints --memory-size 2000
Launching exp on 0,1...
/opt/conda/lib/python3.10/site-packages/torch/distributed/launch.py:181: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects --local-rank argument to be set, please
change it to read from os.environ['LOCAL_RANK'] instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions

warnings.warn(
WARNING:torch.distributed.run:


Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.


/code/DKT_source/continual/robust_models_ImageNet.py:391: UserWarning: Overwriting rvt_tiny in registry with continual.robust_models_ImageNet.rvt_tiny. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
def rvt_tiny(pretrained, **kwargs):
/code/DKT_source/continual/robust_models_ImageNet.py:411: UserWarning: Overwriting rvt_tiny_plus in registry with continual.robust_models_ImageNet.rvt_tiny_plus. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
def rvt_tiny_plus(pretrained, **kwargs):
/code/DKT_source/continual/robust_models_ImageNet.py:433: UserWarning: Overwriting rvt_small in registry with continual.robust_models_ImageNet.rvt_small. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
def rvt_small(pretrained, **kwargs):
/code/DKT_source/continual/robust_models_ImageNet.py:453: UserWarning: Overwriting rvt_small_plus in registry with continual.robust_models_ImageNet.rvt_small_plus. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
def rvt_small_plus(pretrained, **kwargs):
/code/DKT_source/continual/robust_models_ImageNet.py:475: UserWarning: Overwriting rvt_base in registry with continual.robust_models_ImageNet.rvt_base. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
def rvt_base(pretrained, **kwargs):
/code/DKT_source/continual/robust_models_ImageNet.py:495: UserWarning: Overwriting rvt_base_plus in registry with continual.robust_models_ImageNet.rvt_base_plus. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
def rvt_base_plus(pretrained, **kwargs):
usage: DKT training and evaluation script [-h] [--batch-size BATCH_SIZE] [--incremental-batch-size INCREMENTAL_BATCH_SIZE] [--epochs EPOCHS]
[--base-epochs BASE_EPOCHS] [--no-amp] [--model MODEL] [--input-size INPUT_SIZE] [--patch-size PATCH_SIZE]
[--embed-dim EMBED_DIM] [--depth DEPTH] [--num-heads NUM_HEADS] [--drop PCT] [--drop-path PCT]
[--norm {layer,scale}] [--opt OPTIMIZER] [--opt-eps EPSILON] [--opt-betas BETA [BETA ...]] [--clip-grad NORM]
[--momentum M] [--weight-decay WEIGHT_DECAY] [--sched SCHEDULER] [--lr LR] [--incremental-lr INCREMENTAL_LR]
[--lr-noise pct, pct [pct, pct ...]] [--lr-noise-pct PERCENT] [--lr-noise-std STDDEV] [--warmup-lr LR]
[--incremental-warmup-lr LR] [--min-lr LR] [--decay-epochs N] [--warmup-epochs N] [--cooldown-epochs N]
[--patience-epochs N] [--decay-rate RATE] [--color-jitter PCT] [--aa NAME] [--smoothing SMOOTHING]
[--train-interpolation TRAIN_INTERPOLATION] [--repeated-aug] [--no-repeated-aug] [--reprob PCT]
[--remode REMODE] [--recount RECOUNT] [--resplit] [--auto-kd] [--kd KD] [--distillation-tau DISTILLATION_TAU]
[--resnet] [--data-path DATA_PATH] [--data-set {CIFAR,IMNET,INAT,INAT19}]
[--data-path-subTrain DATA_PATH_SUBTRAIN] [--data-path-subVal DATA_PATH_SUBVAL]
[--inat-category {kingdom,phylum,class,order,supercategory,family,genus,name}] [--output-dir OUTPUT_DIR]
[--output-basedir OUTPUT_BASEDIR] [--device DEVICE] [--seed SEED] [--start_epoch N] [--eval] [--dist-eval]
[--num_workers NUM_WORKERS] [--pin-mem] [--no-pin-mem] [--initial-increment INITIAL_INCREMENT]
[--increment INCREMENT] [--class-order CLASS_ORDER [CLASS_ORDER ...]] [--eval-every EVAL_EVERY] [--debug]
[--retrain-scratch] [--max-task MAX_TASK] [--name NAME] [--options [OPTIONS ...]] [--DKT]
[--duplex-clf DUPLEX_CLF] [--memory-size MEMORY_SIZE] [--distributed-memory] [--global-memory]
[--oversample-memory OVERSAMPLE_MEMORY] [--oversample-memory-ft OVERSAMPLE_MEMORY_FT] [--rehearsal-test-trsf]
[--rehearsal-modes REHEARSAL_MODES] [--fixed-memory]
[--rehearsal {random,closest_token,closest_all,icarl_token,icarl_all,furthest_token,furthest_all}]
[--sep-memory] [--replay-memory REPLAY_MEMORY] [--finetuning {balanced}] [--finetuning-mode FINETUNING_MODE]
[--finetuning-lr FINETUNING_LR] [--finetuning-teacher] [--finetuning-resetclf] [--only-ft] [--ft-no-sampling]
[--freeze-task [FREEZE_TASK ...]] [--freeze-ft [FREEZE_FT ...]] [--freeze-eval] [--log-path LOG_PATH]
[--log-category LOG_CATEGORY] [--bce-loss] [--local_rank LOCAL_RANK] [--world_size WORLD_SIZE]
[--dist_url DIST_URL] [--resume RESUME] [--start-task START_TASK] [--start-epoch START_EPOCH]
[--save-every-epoch SAVE_EVERY_EPOCH] [--validation VALIDATION]
DKT training and evaluation script: error: unrecognized arguments: --local-rank=1
DKT training and evaluation script: error: unrecognized arguments: --local-rank=0
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 2) local_rank: 0 (pid: 633) of binary: /opt/conda/bin/python
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launch.py", line 196, in <module>
    main()
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launch.py", line 192, in main
    launch(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launch.py", line 177, in launch
    run(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

/code/DKT_source/main.py FAILED

Failures:
[1]:
time : 2023-12-05_07:36:20
host : cd98f4d76410
rank : 1 (local_rank: 1)
exitcode : 2 (pid: 634)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Root Cause (first observed failure):
[0]:
time : 2023-12-05_07:36:20
host : cd98f4d76410
rank : 0 (local_rank: 0)
exitcode : 2 (pid: 633)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
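The failure happens before any DKT code runs. The usage message shows that main.py only accepts `--local_rank` (underscore), but the launcher in this environment (Python 3.10, and a newer PyTorch than the pinned torch 1.11.0, as far as we can tell) passes `--local-rank=N` (hyphen). Recent PyTorch releases switched torch.distributed.launch to the hyphenated spelling. There are two possible workarounds: install the versions listed under Requirements, or make the parser accept both spellings.

The following is a minimal sketch of the second option. It assumes that main.py builds a standard `argparse.ArgumentParser`, as its usage message suggests. `add_local_rank_argument` is a hypothetical helper and is not the maintainers' fix. It would replace the existing `--local_rank` definition rather than sit next to it, because argparse rejects duplicate option strings.

```python
import argparse
import os


def add_local_rank_argument(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Accept both --local_rank and --local-rank (hypothetical helper)."""
    parser.add_argument(
        "--local_rank",   # dest is taken from the first long option: "local_rank"
        "--local-rank",   # spelling passed by newer torch.distributed.launch
        type=int,
        # torchrun does not pass the flag at all; it exports LOCAL_RANK instead.
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    return parser


parser = add_local_rank_argument(argparse.ArgumentParser("DKT training and evaluation script"))
print(parser.parse_args(["--local-rank=1"]).local_rank)  # prints 1
```

Because the default is read from `LOCAL_RANK`, the same parser also works when the script is started with torchrun, which sets that environment variable for each worker.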
