
Adversarial-Learned Loss for Domain Adaptation

By Minghao Chen, Shuai Zhao, Haifeng Liu, Deng Cai.

Introduction

A PyTorch implementation of our AAAI 2020 paper "Adversarial-Learned Loss for Domain Adaptation" (ALDA). In ALDA, we use a domain discriminator to correct the noise in pseudo-labels. ALDA outperforms state-of-the-art approaches on four standard unsupervised domain adaptation datasets.
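To make "correcting the noise in pseudo-labels" concrete, here is a framework-free sketch of the general idea. It assumes a simple uniform-noise model: a scalar `xi` estimates the probability that a hard pseudo-label is correct, and the remaining probability mass is spread evenly over the other classes. In ALDA, a correction signal of this kind comes from the domain discriminator, but the exact formulation is in the paper. The function `corrected_target` is a hypothetical name, and this is an illustration, not the loss implemented in this repository.

```python
from typing import List


def corrected_target(pseudo_label: int, xi: float, num_classes: int) -> List[float]:
    """Turn a hard pseudo-label into a noise-corrected soft target.

    Illustrative assumption: with probability `xi` the pseudo-label is
    correct; otherwise the true class is uniformly one of the other classes.
    """
    if num_classes < 2:
        raise ValueError("need at least two classes")
    if not 0 <= pseudo_label < num_classes:
        raise ValueError("pseudo_label out of range")
    if not 0.0 <= xi <= 1.0:
        raise ValueError("xi must be a probability")
    # Probability mass given to each class other than the pseudo-label.
    off = (1.0 - xi) / (num_classes - 1)
    target = [off] * num_classes
    target[pseudo_label] = xi
    return target


if __name__ == "__main__":
    # 4 classes; the pseudo-label (class 2) is trusted with probability 0.7.
    print(corrected_target(2, 0.7, 4))
```

The corrected target is still a probability distribution, so it can be used in place of the one-hot pseudo-label in a cross-entropy-style loss.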


Requirements

The code is implemented with Python 3.6 and PyTorch 1.0.0.

Install the latest PyTorch from https://pytorch.org/.

To install the required Python packages, run

pip install -r requirements.txt
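An optional sanity check before training, using only the standard library. It checks that the interpreter is at least Python 3.6 and that a `torch` package is importable; it does not verify the exact PyTorch 1.0.0 version or CUDA support. `check_environment` is a hypothetical helper, not part of the repository.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 6)):
    """Report whether the interpreter and PyTorch look usable for this code."""
    return {
        "python": ".".join(str(p) for p in sys.version_info[:3]),
        "python_ok": sys.version_info[:2] >= min_python,
        # find_spec returns None when the package is not installed;
        # it does not import torch, so the check stays fast.
        "torch_installed": importlib.util.find_spec("torch") is not None,
    }


if __name__ == "__main__":
    print(check_environment())
```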

Setup

Digits

Download the SVHN dataset and unzip it into data/svhn2mnist.

Office-31

Download the Office-31 dataset and unzip it into data/office.

Office-Home

Download the Office-Home dataset and unzip it into data/office-home.

VisDA-2017

Download the VisDA-2017 dataset.
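A small check that the dataset directories named above exist before launching training. It assumes it is run from the repository root; VisDA-2017 is left out because its target location is not given here. `missing_dataset_dirs` is a hypothetical helper.

```python
from pathlib import Path

# Directories named in the Setup section above.
EXPECTED_DIRS = ["data/svhn2mnist", "data/office", "data/office-home"]


def missing_dataset_dirs(root="."):
    """Return the expected dataset directories that do not exist under `root`."""
    base = Path(root)
    return [d for d in EXPECTED_DIRS if not (base / d).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dataset directories found.")
```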

Training

Digits

SVHN->MNIST
python train_svhnmnist.py ALDA --gpu_id 0 --epochs 50 --loss_type all --start_epoch 2 --threshold 0.6

USPS->MNIST
python train_uspsmnist.py ALDA --gpu_id 0 --epochs 50 --task USPS2MNIST --loss_type all --start_epoch 2 --threshold 0.6

MNIST->USPS
python train_uspsmnist.py ALDA --gpu_id 0 --epochs 50 --task MNIST2USPS --loss_type all --start_epoch 2 --threshold 0.6
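No batch script is mentioned for the Digits tasks. A hypothetical Python launcher that replays the three commands above in sequence could look like this; it assumes it is run from the repository root, and `run_all` is our own name. With `dry_run=True` it only prints the commands.

```python
import shlex
import subprocess

# The three Digits commands, copied verbatim from this README.
DIGIT_COMMANDS = [
    "python train_svhnmnist.py ALDA --gpu_id 0 --epochs 50 --loss_type all --start_epoch 2 --threshold 0.6",
    "python train_uspsmnist.py ALDA --gpu_id 0 --epochs 50 --task USPS2MNIST --loss_type all --start_epoch 2 --threshold 0.6",
    "python train_uspsmnist.py ALDA --gpu_id 0 --epochs 50 --task MNIST2USPS --loss_type all --start_epoch 2 --threshold 0.6",
]


def run_all(commands, dry_run=True):
    """Run each command in turn; with dry_run=True, only print it."""
    for cmd in commands:
        argv = shlex.split(cmd)
        if dry_run:
            print(" ".join(argv))
        else:
            # check=True stops at the first run that exits with an error.
            subprocess.run(argv, check=True)
    return len(commands)


if __name__ == "__main__":
    run_all(DIGIT_COMMANDS, dry_run=True)
```

Passing a list of arguments to `subprocess.run` (instead of a single string with `shell=True`) avoids shell quoting issues.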

Office-31

Amazon->Webcam
python train.py ALDA --gpu_id 0 --net ResNet50 --dset office --test_interval 500 --s_dset_path ./data/office/amazon_list.txt --t_dset_path ./data/office/webcam_list.txt --batch_size 36 --trade_off 1 --output_dir "A2W_ALDA_all_thresh=0.9_test" --loss_type all --threshold 0.9
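The `--threshold` flag (0.6 for Digits, 0.9 here) is not explained in this README. A common use of such a value in pseudo-labelling methods is a confidence cutoff: target samples whose top predicted probability falls below it are left out of the pseudo-label loss. The sketch below illustrates only that interpretation; it is an assumption, not a description of what `train.py` does, and `confident_pseudo_labels` is a hypothetical name.

```python
from typing import List, Tuple


def confident_pseudo_labels(probs: List[List[float]], threshold: float) -> List[Tuple[int, int]]:
    """Return (sample_index, pseudo_label) pairs whose max probability >= threshold.

    `probs` holds one softmax probability vector per target sample.
    Using >= (rather than >) is itself an assumption.
    """
    selected = []
    for i, p in enumerate(probs):
        # Pseudo-label = index of the most probable class.
        label = max(range(len(p)), key=p.__getitem__)
        if p[label] >= threshold:
            selected.append((i, label))
    return selected


if __name__ == "__main__":
    batch = [
        [0.95, 0.03, 0.02],  # confident -> kept
        [0.50, 0.30, 0.20],  # uncertain -> dropped at threshold 0.9
        [0.05, 0.05, 0.90],  # exactly at the threshold -> kept
    ]
    print(confident_pseudo_labels(batch, threshold=0.9))  # [(0, 0), (2, 2)]
```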

We provide a shell script that trains all six Office-31 adaptation tasks at once:

sh train.sh
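The contents of `train.sh` are not shown in this README. A rough Python equivalent that generates one command per Office-31 task might look like this. It reuses every flag from the Amazon->Webcam example and assumes that the DSLR list file is named `dslr_list.txt` and that output directories follow the `A2W_...` naming pattern; `office31_command` is a hypothetical helper.

```python
from itertools import permutations

# Office-31 domains. The dslr list-file name is assumed by analogy with
# amazon_list.txt and webcam_list.txt.
DOMAINS = ["amazon", "webcam", "dslr"]


def office31_command(src: str, tgt: str, threshold: float = 0.9) -> list:
    """Build the argv for one Office-31 task, mirroring the Amazon->Webcam example."""
    out_dir = "{}2{}_ALDA_all_thresh={}_test".format(src[0].upper(), tgt[0].upper(), threshold)
    return [
        "python", "train.py", "ALDA", "--gpu_id", "0", "--net", "ResNet50",
        "--dset", "office", "--test_interval", "500",
        "--s_dset_path", "./data/office/{}_list.txt".format(src),
        "--t_dset_path", "./data/office/{}_list.txt".format(tgt),
        "--batch_size", "36", "--trade_off", "1",
        "--output_dir", out_dir, "--loss_type", "all",
        "--threshold", str(threshold),
    ]


# Every ordered (source, target) pair: 3 domains give 6 tasks.
TASKS = [office31_command(s, t) for s, t in permutations(DOMAINS, 2)]

if __name__ == "__main__":
    for argv in TASKS:
        print(" ".join(argv))
```

To launch training rather than print the commands, pass each argv to `subprocess.run(argv, check=True)`.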

Office-Home

Train all twelve adaptation tasks at once:

sh train_home.sh
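Office-Home has four domains (Art, Clipart, Product, Real World), and each ordered source->target pair is one task, which gives 4 × 3 = 12 tasks. The snippet below only enumerates them; the names are display names and not necessarily the file or directory names that `train_home.sh` uses.

```python
from itertools import permutations

# The four Office-Home domains (display names).
HOME_DOMAINS = ["Art", "Clipart", "Product", "Real World"]

# Ordered (source, target) pairs; a domain is never adapted to itself.
HOME_TASKS = list(permutations(HOME_DOMAINS, 2))

if __name__ == "__main__":
    print(len(HOME_TASKS))  # 12
    for src, tgt in HOME_TASKS:
        print("{} -> {}".format(src, tgt))
```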

VisDA-2017

The code for the VisDA-2017 dataset is still in progress.

Results

The code was tested on a GTX 1080 with CUDA 9.0.

The results presented in the paper:

[Figures: results from the paper, not reproduced here]

Citation

If you use this code in your research, please cite:

@article{chen2020adversariallearned,
    title={Adversarial-Learned Loss for Domain Adaptation},
    author={Minghao Chen and Shuai Zhao and Haifeng Liu and Deng Cai},
    journal={arXiv},
    year={2020},
    volume={abs/2001.01046}
}

Acknowledgment

The structure of this code is largely based on CDAN. We are very grateful to the CDAN authors for open-sourcing their code.
