Undistillable: Making A Nasty Teacher That CANNOT teach students

License: MIT

Haoyu Ma, Tianlong Chen, Ting-Kuei Hu, Chenyu You, Xiaohui Xie, Zhangyang Wang
In ICLR 2021 Spotlight Oral

Overview

  • We propose the Nasty Teacher, a defensive approach that prevents knowledge leakage and unauthorized model cloning through knowledge distillation (KD) without sacrificing performance.
  • We propose a simple yet efficient algorithm, called self-undermining knowledge distillation, that builds a nasty teacher directly through self-training and requires neither an additional dataset nor an auxiliary network.

Prerequisite

We use PyTorch 1.4.0 and CUDA 10.1. You can install them with

conda install pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.1 -c pytorch

The code should also work with other PyTorch and CUDA versions.

Then install the other packages with

pip install -r requirements.txt

Usage

Teacher networks

Step 1: Train a normal teacher network
python train_scratch.py --save_path [XXX]

Here, [XXX] specifies the directory that holds params.json, which contains all hyperparameters needed to train a network. The experiments directory already includes the hyperparameters required to reproduce the results in our paper.

For example, to train a normal ResNet18 on CIFAR-10

python train_scratch.py --save_path experiments/CIFAR10/baseline/resnet18

After training finishes, you will find training.log and best_model.tar in that directory.

The normal teacher network will serve as the adversarial network for the training of the nasty teacher.
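
The directory convention described above can be sketched as follows. This is only an illustration, not the repository's own loader: the helper names `load_params` and `output_paths` are hypothetical, and the only facts taken from this README are the file names params.json, training.log, and best_model.tar.

```python
import json
import os


def load_params(save_path):
    """Read params.json from the directory passed as --save_path.

    Hypothetical helper: the repository may load its hyperparameters
    differently.
    """
    json_path = os.path.join(save_path, "params.json")
    if not os.path.isfile(json_path):
        raise FileNotFoundError(f"No params.json found in {save_path}")
    with open(json_path) as f:
        return json.load(f)


def output_paths(save_path):
    """Files this README says appear in the same directory after training."""
    return {
        "log": os.path.join(save_path, "training.log"),
        "checkpoint": os.path.join(save_path, "best_model.tar"),
    }
```

Keeping the configuration, the log, and the checkpoint in one directory means that a single path identifies a whole experiment.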

Step 2: Train a nasty teacher network
python train_nasty.py --save_path [XXX]

Again, [XXX] specifies the directory that holds params.json, which contains the adversarial network settings and the training hyperparameters. You need to specify the architecture of the adversarial network and the path to its checkpoint in this file.

For example, to train a nasty ResNet18

python train_nasty.py --save_path experiments/CIFAR10/kd_nasty_resnet18/nasty_resnet18
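
As a rough illustration of what training against an adversarial network means, the sketch below evaluates a self-undermining objective for one example in plain Python: cross-entropy on the true label, minus a weighted KL term between the nasty teacher and the normally trained adversary. It is not the code in train_nasty.py. The names `nasty_objective`, `omega`, and `temperature`, their default values, and the KL direction are assumptions made for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def nasty_objective(nasty_logits, adversary_logits, label,
                    omega=0.5, temperature=4.0):
    """Cross-entropy minus a weighted KL term (all defaults are assumed).

    Minimizing this value keeps the nasty teacher accurate on the label
    while pushing its softened outputs away from the adversary's.
    """
    ce = -math.log(softmax(nasty_logits)[label])
    p_nasty = softmax(nasty_logits, temperature)
    p_adv = softmax(adversary_logits, temperature)
    kl = kl_divergence(p_nasty, p_adv)
    return ce - omega * temperature ** 2 * kl
```

When both networks produce identical logits, the KL term vanishes and the objective reduces to ordinary cross-entropy. Any disagreement in the softened outputs lowers the objective.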

Knowledge Distillation for Student networks

You can train a student by distilling from a normal or a nasty teacher with

python train_kd.py --save_path [XXX]

Again, [XXX] specifies the directory that holds params.json, which contains the settings of the student and teacher networks.

For example,

  • Train a plain CNN by distilling from a nasty ResNet18
python train_kd.py --save_path experiments/CIFAR10/kd_nasty_resnet18/cnn
  • Train a plain CNN by distilling from a normal ResNet18
python train_kd.py --save_path experiments/CIFAR10/kd_normal_resnet18/cnn
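
For readers unfamiliar with distillation, the sketch below computes a standard KD loss for one example in plain Python: a temperature-softened KL term between teacher and student, blended with cross-entropy on the true label. It is a generic formulation, not necessarily what train_kd.py implements. The names `kd_loss`, `alpha`, and `temperature` and their default values are assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(student_logits, teacher_logits, label,
            alpha=0.9, temperature=4.0):
    """Generic KD loss: soft-target KL blended with hard-label CE.

    alpha and temperature are assumed values for illustration only.
    """
    ce = -math.log(softmax(student_logits)[label])
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft = kl_divergence(p_teacher, p_student)
    return alpha * temperature ** 2 * soft + (1 - alpha) * ce
```

In this formulation the student is pulled toward the teacher's softened distribution, so a teacher whose soft outputs are deliberately misleading degrades the student even when the teacher's top-1 predictions remain accurate.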

Citation

@inproceedings{
ma2021undistillable,
title={Undistillable: Making A Nasty Teacher That {CANNOT} teach students},
author={Haoyu Ma and Tianlong Chen and Ting-Kuei Hu and Chenyu You and Xiaohui Xie and Zhangyang Wang},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=0zvfm-nZqQs}
}

Acknowledgement

nasty-teacher's Issues

The asymmetry of KL divergence.

Hi, I noticed that in train_nasty.py, when the KL divergence is computed, the normal teacher's output (output_stu) is passed as the input and the nasty teacher's output (output_tch) as the target. In standard KD, however, the fixed model (the teacher) is usually the target and the model being updated is the input.

Why did you adopt the opposite order in the KL loss function? Is there a reason for it? Thanks!
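
To make the asymmetry concrete, here is a small framework-free sketch showing that swapping the two arguments changes the value of the KL divergence. The distributions `p` and `q` are made-up numbers chosen for illustration. For reference, PyTorch's `torch.nn.functional.kl_div(input, target)` expects `input` as log-probabilities and computes KL(target || input), which is why the argument order matters in practice.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Made-up distributions: p is peaked, q is flatter.
p = [0.9, 0.05, 0.05]
q = [0.4, 0.3, 0.3]

forward = kl_divergence(p, q)  # KL(p || q)
reverse = kl_divergence(q, p)  # KL(q || p)

print(f"KL(p || q) = {forward:.4f}")
print(f"KL(q || p) = {reverse:.4f}")
```

The two printed values differ, so choosing which model's output acts as the target selects a different objective and is not just a change of notation.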

The accuracy of resnet family

Hi, thanks for your excellent work!
When training baseline models (train_scratch.py) on CIFAR-100, resnet18 achieved a higher accuracy than resnet34 and resnet50. I did not change anything.
Do you know what is wrong? Looking forward to your reply!

About the setting.

Hi there, thanks for your great work!!!
I have a question about the setting of this work. Does it assume that the 'stealer' can access only the logit outputs of the teacher model, and that model resources such as the architecture and parameters are hidden?

The accuracy of the nasty resnext29 on CIFAR-100.

Thanks for your excellent work!
When I use the nasty resnext29 as the teacher, the accuracy of the distilled resnet18, shufflenetV2, and resnext29 is 73.83, 64.82, and 79.87 respectively. I did not change the json file. Do some parameters need additional settings? Looking forward to your reply!
