megvii-research / mdistiller

The official implementation of [CVPR2022] Decoupled Knowledge Distillation https://arxiv.org/abs/2203.08679 and [ICCV2023] DOT: A Distillation-Oriented Trainer https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf

Python 100.00%
pytorch knowledge-distillation computer-vision deep-learning cifar coco cvpr2022 imagenet iccv2023

mdistiller's Introduction

This repo is

(1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks,

(2) the official implementation of the CVPR-2022 paper: Decoupled Knowledge Distillation,

(3) the official implementation of the ICCV-2023 paper: DOT: A Distillation-Oriented Trainer.

DOT: A Distillation-Oriented Trainer

Framework

Main Benchmark Results

On CIFAR-100:

Teacher   ResNet32x4  VGG13  ResNet32x4
Student   ResNet8x4   VGG8   ShuffleNet-V2
KD        73.33       72.98  74.45
KD+DOT    75.12       73.77  75.55

On Tiny-ImageNet:

Teacher   ResNet18      ResNet18
Student   MobileNet-V2  ShuffleNet-V2
KD        58.35         62.26
KD+DOT    64.01         65.75

On ImageNet:

Teacher   ResNet34  ResNet50
Student   ResNet18  MobileNet-V1
KD        71.03     70.50
KD+DOT    71.72     73.09

Decoupled Knowledge Distillation

Framework & Performance

Main Benchmark Results

On CIFAR-100:

Teacher   ResNet56  ResNet110  ResNet32x4  WRN-40-2  WRN-40-2  VGG13
Student   ResNet20  ResNet32   ResNet8x4   WRN-16-2  WRN-40-1  VGG8
KD        70.66     73.08      73.33       74.92     73.54     72.98
DKD       71.97     74.11      76.32       76.23     74.81     74.68

Teacher   ResNet32x4     WRN-40-2       VGG13         ResNet50      ResNet32x4
Student   ShuffleNet-V1  ShuffleNet-V1  MobileNet-V2  MobileNet-V2  ShuffleNet-V2
KD        74.07          74.83          67.37         67.35         74.45
DKD       76.45          76.70          69.71         70.35         77.07

On ImageNet:

Teacher   ResNet34  ResNet50
Student   ResNet18  MobileNet-V1
KD        71.03     70.50
DKD       71.70     72.05

MDistiller

Introduction

MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:

Method Paper Link CIFAR-100 ImageNet MS-COCO
KD https://arxiv.org/abs/1503.02531
FitNet https://arxiv.org/abs/1412.6550
AT https://arxiv.org/abs/1612.03928
NST https://arxiv.org/abs/1707.01219
PKT https://arxiv.org/abs/1803.10837
KDSVD https://arxiv.org/abs/1807.06819
OFD https://arxiv.org/abs/1904.01866
RKD https://arxiv.org/abs/1904.05068
VID https://arxiv.org/abs/1904.05835
SP https://arxiv.org/abs/1907.09682
CRD https://arxiv.org/abs/1910.10699
ReviewKD https://arxiv.org/abs/2104.09044
DKD https://arxiv.org/abs/2203.08679

Installation

Environments:

  • Python 3.6
  • PyTorch 1.9.0
  • torchvision 0.10.0

Install the package:

sudo pip3 install -r requirements.txt
sudo python3 setup.py develop

Getting started

  1. Wandb as the logger
  • Register at https://wandb.ai/home.
  • If you don't want wandb as your logger, set CFG.LOG.WANDB to False in mdistiller/engine/cfg.py.
  2. Evaluation
  • You can evaluate the performance of our models or of models you trained yourself.

  • Our models are at https://github.com/megvii-research/mdistiller/releases/tag/checkpoints, please download the checkpoints to ./download_ckpts

  • To test the models on ImageNet, please download the dataset at https://image-net.org/ and put it in ./data/imagenet

    # evaluate teachers
    python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
    python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet
    
    # evaluate students
    python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
    python3 tools/eval.py -m MobileNetV1 -c download_ckpts/imgnet_dkd_mv1 -d imagenet # dkd-mv1 on imagenet
    python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
  3. Training on CIFAR-100
  • Download cifar_teachers.tar at https://github.com/megvii-research/mdistiller/releases/tag/checkpoints and untar it to ./download_ckpts via tar xvf cifar_teachers.tar.

    # for instance, our DKD method.
    python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml
    
    # you can also change settings at command line
    python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
  4. Training on ImageNet
  • Download the dataset at https://image-net.org/ and put it in ./data/imagenet

    # for instance, our DKD method.
    python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
  5. Training on MS-COCO
  6. Extension: Visualizations
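
The CIFAR-100 training step passes trailing KEY VALUE pairs (SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1) after the --cfg file. As a rough illustration of how such dotted overrides can be merged into a nested config, here is a minimal sketch in plain Python. The function apply_overrides and the default values are hypothetical and chosen for illustration; this is not the repository's actual parsing code.

```python
import ast
from copy import deepcopy


def apply_overrides(cfg, opts):
    """Return a copy of the nested dict `cfg` with KEY VALUE pairs from `opts` applied.

    Hypothetical helper for illustration only.
    """
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    merged = deepcopy(cfg)
    for dotted_key, raw_value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(dotted_key)  # refuse to create new keys silently
        try:
            value = ast.literal_eval(raw_value)  # "128" -> 128, "0.1" -> 0.1
        except (ValueError, SyntaxError):
            value = raw_value  # keep plain strings as they are
        node[leaf] = value
    return merged


# Illustrative defaults, not the values shipped with the repo.
defaults = {"SOLVER": {"BATCH_SIZE": 64, "LR": 0.05}}
cfg = apply_overrides(defaults, ["SOLVER.BATCH_SIZE", "128", "SOLVER.LR", "0.1"])
print(cfg)
```

Rejecting unknown keys, as the sketch does, catches typos in override names early instead of silently training with the defaults.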

Custom Distillation Method

  1. Create a Python file in mdistiller/distillers/ and define the distiller:
from ._base import Distiller

class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super(MyDistiller, self).__init__(student, teacher)
        self.hyper1 = cfg.MyDistiller.hyper1
        ...

    def forward_train(self, image, target, **kwargs):
        # return the output logits and a Dict of losses
        ...

    # override get_learnable_parameters if the distiller adds more nn modules.
    # override get_extra_parameters if you want to report the extra cost.

  2. Register the distiller in distiller_dict at mdistiller/distillers/__init__.py

  3. Register the corresponding hyper-parameters at mdistiller/engine/cfg.py

  4. Create a new config file and test it.
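
The registration step above follows a common name-to-class registry pattern. The sketch below shows that pattern in plain Python under stated assumptions: the Distiller stub, the build_distiller helper, and the config keys MYDISTILLER and HYPER1 are all hypothetical stand-ins, and the real base class and config object in the library will differ.

```python
class Distiller:
    """Hypothetical stand-in for the library's base class."""

    def __init__(self, student, teacher):
        self.student = student
        self.teacher = teacher


class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super().__init__(student, teacher)
        # Hyper-parameters are read from the section named after the method.
        self.hyper1 = cfg["MYDISTILLER"]["HYPER1"]


# Registry: maps the DISTILLER.TYPE string in a config to a class.
distiller_dict = {"MYDISTILLER": MyDistiller}


def build_distiller(cfg, student, teacher):
    """Hypothetical factory: look up the class by name and construct it."""
    name = cfg["DISTILLER"]["TYPE"]
    try:
        cls = distiller_dict[name]
    except KeyError:
        raise ValueError(f"unknown distiller type: {name!r}") from None
    return cls(student, teacher, cfg)


cfg = {"DISTILLER": {"TYPE": "MYDISTILLER"}, "MYDISTILLER": {"HYPER1": 0.5}}
distiller = build_distiller(cfg, student="student-model", teacher="teacher-model")
print(type(distiller).__name__, distiller.hyper1)
```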

Citation

If this repo is helpful for your research, please consider citing the paper:

@article{zhao2022dkd,
  title={Decoupled Knowledge Distillation},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
  journal={arXiv preprint arXiv:2203.08679},
  year={2022}
}
@article{zhao2023dot,
  title={DOT: A Distillation-Oriented Trainer},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
  journal={arXiv preprint arXiv:2307.08436},
  year={2023}
}

License

MDistiller is released under the MIT license. See LICENSE for details.

Acknowledgement

  • Thanks to CRD and ReviewKD. We built this library on the CRD and ReviewKD codebases.

  • Thanks to Yiyu Qiu and Yi Shi for their code contributions during their internships at MEGVII Technology.

  • Thanks to Xin Jin for the discussion about DKD.

mdistiller's People

Contributors

zzzzz1

mdistiller's Issues

About the coefficient 1000.0 in DKD.py

In lines 21-26 of DKD.py:

pred_teacher_part2 = F.softmax(
    logits_teacher / temperature - 1000.0 * gt_mask, dim=1
)
log_pred_student_part2 = F.log_softmax(
    logits_student / temperature - 1000.0 * gt_mask, dim=1
)

What does the coefficient 1000.0 represent? Is it the number of classes?
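
One reading of the snippet (an interpretation of the code, not a statement from the authors) is that 1000.0 is simply a large constant unrelated to the class count: subtracting it from the ground-truth logit drives that class's softmax probability to zero, so the remaining probabilities are renormalized over the non-target classes. The plain-Python sketch below checks this on made-up logits; the softmax helper, the logit values, and the temperature of 4.0 are illustrative assumptions.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


temperature = 4.0                 # illustrative value
logits = [2.0, 1.0, 0.5, -1.0]    # made-up logits for one sample
gt_mask = [1, 0, 0, 0]            # class 0 is the ground truth

# Same form as the snippet: logits / temperature - 1000.0 * gt_mask
masked = [z / temperature - 1000.0 * g for z, g in zip(logits, gt_mask)]
probs_masked = softmax(masked)

# Softmax computed over the non-target classes only.
probs_non_target = softmax([z / temperature for z in logits[1:]])

print(probs_masked)
print(probs_non_target)
```

The first entry of probs_masked underflows to zero and the other three match probs_non_target, which is why any sufficiently large constant would behave the same way.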

Can't reproduce the accuracy from wrn_40_2 to wrn_16_2

Hello, I have reproduced the experimental results on ImageNet and some results on CIFAR-100, for example from VGG13 to MobileNet-V2.

However, the best result from wrn_40_2 to wrn_16_2 is only 75.15, which is significantly lower than the 76.24 in the paper (Table 11).

All the hyper-parameters used are from the config folder of this repository. The platform is a 1080 Ti on Ubuntu 18.04.

Is there anything I have overlooked?

Or could you share the log for this experiment?

On the issue of code reproduction

Hello, I am around 0.2~0.7 percentage points lower than the results given in the paper on all models. May I know the possible reason? I used one GPU for the analysis. Thank you for your answer.

code problem

Hello, when running your code, the printed results include the top-1 and top-5 accuracies for each epoch. Is this the accuracy of the student network, the teacher network, or the distilled student network? At the end, a best_acc is also reported. Which model does this best_acc belong to? I would be grateful for your reply.

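Reproducing the detection results

Hello! I tried to reproduce your Faster R-CNN experiments within the mmdetection framework, but the results keep differing. Could you share the log for FasterRCNN r101-r50?
I also have a few questions about details. First, the paper says KD is applied to the cls results of the proposals, but in the code it is applied to the output obtained by passing the rpn head results through the bbox_head in roi_head. Second, is inherit used here?
Thanks!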

About image segmentation

Hello, can DKD be used for distillation in image segmentation tasks?

Is Distributed Training Supported?

Hi,

Thank you for the repository, it makes it really easy to try out different approaches. One question: does this repo support distributed training? If so, how do I run it in DDP mode?

Thanks.

About the implementation of FitNets

Hello, your work on knowledge distillation is great!
However, I have some questions about the FitNets code.
I found that you simply backpropagate the sum of the losses: loss_feat and loss_ce are passed together to the trainer directly. According to the original paper, I think the intermediate layers should first be trained with the feature loss to initialize their weights, and the whole student model should then be trained with the CE loss. Did I get something wrong, or did I misunderstand the process? Looking forward to your reply.

Different Dataset on DKD transforms.Resize error

Hi, I was trying to use a different dataset with two classes and image size 1024 in this repo. However, when I use transforms.Resize(128) in the dataloader transforms, I get the error RuntimeError: Function AddmmBackward returned an invalid gradient at index 1 - got [2, 256] but expected shape compatible with [2, 4096]. If I use transforms.Resize(32) instead, the pipeline works. Could you help me out here?
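
The two shapes in the error differ by a factor of 16, which is (128 / 32) squared. One situation that would produce exactly these numbers is a network whose final pooling layer uses a fixed kernel size, so the flattened feature length grows with the input resolution. The arithmetic sketch below assumes a total downsampling stride of 4, a pooling kernel of 8, and 256 output channels; all three are assumptions picked to match the reported shapes, not values confirmed from the repository.

```python
def flattened_features(input_size, total_stride=4, pool_kernel=8, channels=256):
    """Length of the flattened feature vector under the assumed architecture.

    All default values are assumptions chosen to match the shapes in the error.
    """
    feature_map = input_size // total_stride   # spatial size before pooling
    pooled = feature_map // pool_kernel        # spatial size after fixed-kernel pooling
    return channels * pooled * pooled


for size in (32, 128):
    print(size, flattened_features(size))
```

Under these assumptions a 32-pixel input yields 256 features and a 128-pixel input yields 4096, so a classifier layer sized for one cannot consume the other.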

Some questions about the masks

Dear authors, I have a question:
While reproducing the DKD code, I found three mask-related functions:

def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt

These three functions do not appear in the pseudocode in the paper. What is their purpose?
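
One way to read these helpers, offered as an interpretation rather than the authors' explanation: the two masks select the ground-truth class and all other classes, and cat_mask sums the probabilities under each mask to give a two-entry distribution (target, non-target). The plain-Python sketch below mirrors that logic for a single sample; the list-based function names and the logit values are illustrative, not the library's API.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gt_mask(num_classes, target):
    """1.0 at the ground-truth class, 0.0 elsewhere."""
    return [1.0 if c == target else 0.0 for c in range(num_classes)]


def other_mask(num_classes, target):
    """0.0 at the ground-truth class, 1.0 elsewhere."""
    return [0.0 if c == target else 1.0 for c in range(num_classes)]


def cat_mask(probs, mask1, mask2):
    """Sum the probabilities selected by each mask: [p_target, p_non_target]."""
    t1 = sum(p * m for p, m in zip(probs, mask1))
    t2 = sum(p * m for p, m in zip(probs, mask2))
    return [t1, t2]


probs = softmax([2.0, 1.0, 0.5, -1.0])   # made-up logits for one sample
target = 0
binary = cat_mask(probs, gt_mask(4, target), other_mask(4, target))
print(binary)
```

The first entry equals the probability of the ground-truth class and the two entries sum to one, which is the binary distribution a target-versus-rest comparison needs.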

code question

Is there a caching mechanism? After I modify the line that raised the error and rerun the code, why is the previous error still reported?

Can't reproduce the accuracy.

I have tried many times, but I could not reproduce the result reported in the paper (76.32%) for res32x4_res8x4.
The results of five runs are: 76.20, 75.80, 75.89, 75.85, 76.15.
The environment is the same as in readme.md: Python 3.6, torch 1.9.0, torchvision 0.10.0.
The GPU is a 2080 Ti and the operating system is Ubuntu 20.04.
Is there anything I have overlooked?
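
When comparing several runs against a single reported number, it can help to summarize them by mean and sample standard deviation first. The short sketch below does this for the five accuracies quoted in this issue; it only restates the numbers given above and makes no claim about how the paper's figure was computed.

```python
import statistics

runs = [76.20, 75.80, 75.89, 75.85, 76.15]   # accuracies quoted in this issue
reported = 76.32                              # figure quoted from the paper

mean_acc = statistics.mean(runs)
std_acc = statistics.stdev(runs)              # sample standard deviation
gap = reported - mean_acc

print(f"mean={mean_acc:.2f} std={std_acc:.2f} gap={gap:.2f}")
```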

Training ResNet weights on custom dataset for further distilling

Given a custom object detection dataset in coco format, I would like to re-train a ResNet101 as a teacher model, and a ResNet18 as a student model for further distilling. But I don't see how to turn off the distillation in this repo.

Any methods for only training detection models?

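About the weights of DKD and CE

Section A.4 (basic setting) in the appendix of the paper mentions that the weights of the KD and CE losses are both set to 1.0. My understanding is that the final loss function looks like this:
LOSS = 1.0 * (alpha * TCKD + beta * NCKD) + 1.0 * CE
If alpha=1 and beta=8, the KD loss is effectively amplified many times. If CE itself is small, does CE then have almost no effect?
So I would like to know whether the magnitudes of the CE and KD losses differed substantially in the paper's experiments.
Thanks!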

Question about the alpha weight in ImageNet training

First, thanks for the great work. We find the alpha is set to 0.5 in the paper for ImageNet, but the config files in the code: configs/imagenet/r34_r18/dkd.yaml, configs/imagenet/r50_mv1/dkd.yaml seem to use alpha=1 (the default value in cfg.py). What is the real alpha value used in the report?

Example for custom models

The framework shows really good results. Are there any examples of the distillation process for custom models?

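About data augmentation

Have you tried data augmentation, such as mixup, to improve model accuracy?
Also, when using timm's mixup, the target shape changes (3140001 >> 3141000). How should the mask be computed in that case?
Thanks~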

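Please check the ResNet50-MobileNetV2 setting in the paper

@Zzzzz1 Hi, could you please check whether the ResNet50-MobileNetV2 setting in your paper is compared correctly? The other methods in the table actually all use the MobileNetV1 model, which can be verified from the following papers:

ReviewKD: the citation in the paper is probably wrong. Although it cites mbv2, the open-source code actually uses mbv1, see: https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/mobilenet.py
OFD: both the results reported in the paper and the cited paper are mbv1

If your student model is the wrong one, please correct the table. Several recent papers have already been found to mistakenly use the mbv2 model when comparing against earlier mbv1 methods. Many thanks!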

Originally posted by @hunto in #4 (comment)

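Where is the DOT code?

Hello, could you tell me which directory contains the changes related to "DOT: A Distillation-Oriented Trainer"? I looked through the code and could not find them. Thank you!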

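About the top-1 acc of MobileNet-v2 on ImageNet

Hello, in Table 9 of the paper "Decoupled Knowledge Distillation" you report a top-1 acc of 68.87% for MobileNet-v2 on ImageNet. Is this a result you trained yourselves? The officially reported figure for MobileNet-v2 is 72.0%, so why not base the KD experiments on the 72.0% accuracy? I hope you can clear up my confusion.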

About validation

Hello, I downloaded your code and wanted to see its results. I downloaded the pre-trained weight files to the specified folder as required. Running
python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
gives the following result:
[screenshot]
And when I run
python3 tools/eval.p -m resnet8x4 -c download_ckpts/dkd_resnet8x4
it reports that no corresponding file exists.
I don't know why this is happening. What results should I expect? I look forward to your reply.

Validation Splits

Hi,

Thanks for the great repo!

I was wondering whether, for CIFAR-100 and ImageNet, a "validation" set (a held-out subset of the training set) is created somewhere in the code. As far as I can see, there are only a train_dataset and a test_dataset.

All the best.
Thanks again for the code :)

Cannot execute ReviewKD with wrn_40_2 & wrn_16_2

I tried running the code for the ReviewKD method with wrn_40_2 as the teacher and wrn_16_2 as the student, but I get the following.

RuntimeError: Given groups=1, weight of size [256, 256, 1, 1], expected input[64, 128, 1, 1] to have 256 channels, but got 128 channels instead

Is the WRN architecture set up correctly? Has anyone else faced this issue?

OS: Ubuntu 20.04
GPU: GTX 1060 6GB

how to understand nckd loss ?

[screenshot]
I think the NCKD loss is the KL loss between the teacher and student predictions over the non-target output probabilities, whose shape should be (n, c-1), as in Algorithm 1, the pseudocode of DKD in your paper.
But why do you compute it like this? Is it equivalent? Could you give a further explanation? Thanks!
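
A quick numerical check can help with the equivalence question. The plain-Python sketch below compares (a) the KL divergence over C-way softmax outputs where the target logit has been pushed down by a large constant, with (b) the KL divergence over a (C-1)-way softmax that drops the target class. The logits, the helper names, and the convention that zero-probability terms contribute nothing to the KL sum are assumptions of this sketch, not the repository's code.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def kl_div(p, q):
    """KL(p || q); terms with p_i == 0 contribute 0 by convention."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


teacher = [3.0, 1.0, 0.2, -0.5]   # made-up teacher logits
student = [1.5, 0.8, 0.1, 0.3]    # made-up student logits
target = 0
mask = [1.0 if c == target else 0.0 for c in range(len(teacher))]

# (a) C-way softmax with the target logit suppressed by a large constant.
p_t_masked = softmax([z - 1000.0 * g for z, g in zip(teacher, mask)])
p_s_masked = softmax([z - 1000.0 * g for z, g in zip(student, mask)])
kl_masked = kl_div(p_t_masked, p_s_masked)

# (b) (C-1)-way softmax over the non-target classes only.
p_t_sub = softmax([z for c, z in enumerate(teacher) if c != target])
p_s_sub = softmax([z for c, z in enumerate(student) if c != target])
kl_sub = kl_div(p_t_sub, p_s_sub)

print(kl_masked, kl_sub)
```

On this example the two values agree, since the suppressed class carries zero probability in both distributions and therefore adds nothing to the sum.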

Training on CIFAR-100

When I train on CIFAR-100, it reports "RuntimeError: Numpy is not available". Why?

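Can DKD be used in other domains?

DKD appears to distill the logits output by the fully connected layer. Can it be applied to networks without a fully connected layer? It seems a channel alignment operation would be needed, is that right?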

code question

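During training, does the code first train the teacher network and then run distillation training? If the teacher network is a pre-trained model, does it go straight to distillation training?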

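Definition of the target in TCKD in the paper

For a multi-class task, does the paper's definition of target mean that every class has its own target class, with the remaining classes forming the NCKD for that class? For example, in a four-class task with "cat, dog, person, car", are there four target classes, each with a corresponding NCKD? Or is one class chosen as the target while the other three are non-target?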

Reproduction issue for WRN-40-2 / WRN-40-1 on CIFAR-100

Hi, I tried to reproduce the WRN-40-2 / WRN-40-1 result on CIFAR-100, but I could only reach 73.3, which is 1.5 lower than the result reported in the paper. I used the original yaml file and didn't change any hyper-parameter. The other experiments I tried on CIFAR-100 were always within 0.5 of the reported number, which looks fine to me. But a gap of 1.5 on WRN-40-2 / WRN-40-1 seems a bit too large.

Train a single model on cifar 100

Hi, is there a way to train a single model like resnet32x4 on cifar 100 on this repo? Want to train the models from scratch without using the pretrained models.

Cannot reproduce ReviewKD ACC with ResNet50 to MobileNetV2 on CIFAR100

I used this yaml and got 66.82 ACC@1 on CIFAR100 (69.89 in the paper). Could you give me some suggestions about the yaml or other noteworthy details?

EXPERIMENT:
  NAME: ""
  TAG: "reviewkd,res50,mv2"
  PROJECT: "cifar100_baselines"
DISTILLER:
  TYPE: "REVIEWKD"
  TEACHER: "ResNet50"
  STUDENT: "MobileNetV2"
REVIEWKD:
  REVIEWKD_WEIGHT: 7.0
  SHAPES: [1, 2, 4, 8, 16]
  OUT_SHAPES: [1, 4, 8, 16, 32]
  IN_CHANNELS: [12, 16, 48, 160, 1280]
  OUT_CHANNELS: [256, 512, 1024, 2048, 2048]
SOLVER:
  BATCH_SIZE: 128
  EPOCHS: 240
  LR: 0.05
  LR_DECAY_STAGES: [150, 180, 210]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"

OS: Ubuntu 20.04
GPU: V100 32G

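DKD code question

Hello authors: where exactly is the code for the DKD part located? Could you point it out, ideally in detail, from which line to which line?
Best wishes!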

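Training different models

How were the authors' models trained? I trained my own ResNet and used it to replace the network I downloaded from the link, but an error occurred. Could you share the .py file used for model training?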

code question

When I loaded my own training set for distillation training, the normal training was performed during the first training, the loss decreased normally, and the model could gradually fit. But when I train for the second time, if I change the number of iterations or the decline stage of the loss function, it will fail to fit. The verification TOP-1 is always 0.98, and the loss cannot be reduced normally. Can you solve this problem? Thank you so much

ReviewKD with VGG architecture

@Zzzzz1
I tried running the code for the ReviewKD method with vgg_13 as the teacher and vgg_8 as the student, but I get the following.

RuntimeError: Given groups=1, weight of size [256, 256, 1, 1], expected input[64, 512, 1, 1] to have 256 channels, but got 512 channels instead

Is the VGG architecture set up correctly? Has anyone else faced this issue?

OS: Ubuntu 20.04
GPU: GTX 1060 6GB

Which SHAPES values should be used for it to run properly? Can you upload them for all the different architectures?

Thank you in advance!

About _get_gt_mask()

Hello authors, I don't quite understand your _get_gt_mask() method and hope you can help me.
In your code:
def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask
However, in def scatter_(self, dim, index, src, reduce=None), the src of scatter_ is a tensor. Why do you pass just a 1 here? I also get an error when execution reaches this point. And why use target = target.reshape(-1).unsqueeze(1)? Doesn't that make the shapes of target and logits differ? I would appreciate an explanation.
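
As background for the question: PyTorch's scatter_ also accepts a scalar value in place of a src tensor, and the index tensor needs the same number of dimensions as the destination but may be smaller along each of them. The plain-Python sketch below mimics that row-wise write for a 2-D case to show why an index of shape (N, 1) yields a one-hot mask of shape (N, C). The helper scatter_rows is a hypothetical stand-in written for this illustration, not a PyTorch function.

```python
def scatter_rows(base, index, value):
    """Mimic tensor.scatter_(1, index, value) for a 2-D list `base`.

    `index` has one inner list per row; each entry is a column to overwrite.
    Hypothetical helper for illustration only.
    """
    out = [row[:] for row in base]          # copy so `base` is left untouched
    for i, cols in enumerate(index):
        for j in cols:
            out[i][j] = value
    return out


logits = [[0.1, 2.0, -1.0], [1.5, 0.3, 0.2]]   # made-up logits, shape (2, 3)
target = [1, 0]                                # class index per sample
index = [[t] for t in target]                  # like target.reshape(-1).unsqueeze(1), shape (2, 1)

zeros = [[0.0] * len(row) for row in logits]   # like torch.zeros_like(logits)
gt_mask = [[bool(v) for v in row] for row in scatter_rows(zeros, index, 1.0)]
print(gt_mask)  # [[False, True, False], [True, False, False]]
```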

ignore index in DKD

Hi! Thanks for the great work.
But I found a problem in the dkd_loss code.
When the dataset has an ignore index, the error RuntimeError: cuda runtime error (59) : device-side triggered at ... occurs in the _get_gt_mask and _get_other_mask functions.
How can I fix it? Thanks.
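
A plausible cause, stated as an assumption rather than a diagnosis, is that the ignore label falls outside the valid class range, so using it as a scatter index is out of bounds. One possible workaround is to drop ignored samples before building the masks. The plain-Python sketch below shows the idea; the helper drop_ignored and the ignore value 255 are hypothetical.

```python
def drop_ignored(logits, targets, ignore_index=255):
    """Keep only samples whose target is a real class index.

    Hypothetical helper; 255 is an assumed ignore label.
    """
    num_classes = len(logits[0]) if logits else 0
    kept_logits, kept_targets = [], []
    for row, t in zip(logits, targets):
        if t == ignore_index:
            continue                      # skip samples marked as ignored
        if not 0 <= t < num_classes:
            raise ValueError(f"target {t} is outside [0, {num_classes})")
        kept_logits.append(row)
        kept_targets.append(t)
    return kept_logits, kept_targets


logits = [[0.2, 1.1, -0.3], [0.9, 0.1, 0.4], [1.0, 1.0, 1.0]]   # made-up logits
targets = [1, 255, 0]                                           # second sample is ignored
kept_logits, kept_targets = drop_ignored(logits, targets)
print(kept_targets)  # [1, 0]
```

With tensors, the same filtering is usually written with a boolean mask such as target != ignore_index applied to both logits and targets before the loss is computed.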

How is the acc reported calculated?

Hello, I want to know how the acc in the paper is calculated.
The paper states that all results are the average over 5/3 trials on CIFAR100/ImageNet, and I want to know whether each trial's acc is the best acc or something else.
Many thanks!

Question about the hyper-parameters used in other KD methods on different cases

First of all, thank you for the excellent work. We are currently attempting to reproduce the performance of various KD methods, including FitNet, RKD, CRD, ReviewKD, and others, as detailed in the DKD paper. We have a question regarding the hyperparameters used on CIFAR-100 for the different KD methods. Specifically, we are curious about the values used across different teachers and students for these KD methods (except DKD). Would you mind posting these hyperparameters🥰?

Hyperparameter on Resnet50 - wrn16x2

Hi, in table 12 of your paper, the ACC from ResNet50 - wrn16x2 with DKD is 76.60. However, I use the Resnet50 ckpt provided in this repo and hyperparameters as follows:

SOLVER:
  BATCH_SIZE: 64
  EPOCHS: 240
  LR: 0.05
  LR_DECAY_STAGES: [150, 180, 210]
  LR_DECAY_RATE: 0.1
  WEIGHT_DECAY: 0.0005
  MOMENTUM: 0.9
  TYPE: "SGD"

The best ACC I got is only 75.52. Could you please share your config?
