
Label-Noise Robust Diffusion Models (TDSM) (ICLR 2024)

| openreview | arXiv | datasets | checkpoints |


This repo contains the official PyTorch implementation of the paper "Label-Noise Robust Diffusion Models" (ICLR 2024).

Byeonghu Na, Yeongmin Kim, HeeSun Bae, Jung Hyun Lee, Se Jung Kwon, Wanmo Kang, and Il-Chul Moon


This paper proposes the Transition-aware weighted Denoising Score Matching (TDSM) objective for training conditional diffusion models on datasets with noisy labels.

(a) Examples from the noisy labeled MNIST (top) and CIFAR-10 (bottom) datasets, and (b-c) randomly generated images from the baseline and our models, both trained on the noisy labeled datasets.

The training procedure of the proposed approach. Solid black arrows indicate forward propagation, and dashed red arrows indicate the gradient flow. The filled circle denotes the dot product, and the dashed operation denotes the L2 loss. The noisy-label classifier $\tilde{\mathbf{h}}_{\boldsymbol{\phi}^*}$ is obtained by minimizing the cross-entropy loss on the noisy labeled dataset $\tilde{D}$.
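The figure describes the weighted combination only at a high level. Below is a minimal NumPy sketch of how transition-aware weights can be formed, assuming class-conditional label noise with a known, invertible transition matrix `S`, where `S[i, j] = P(Ỹ = j | Y = i)`. Under that assumption the noisy-label class probabilities satisfy p(ỹ | x_t) = Σ_y S[y, ỹ] p(y | x_t), so a clean-label posterior can be recovered by a linear solve, and Bayes' rule gives the weight w_y ∝ S[y, ỹ] p(y | x_t). The dot product in the figure can then be read as Σ_y w_y times the per-class network output. All function names, the clipping step, and the use of NumPy instead of PyTorch are illustrative assumptions, not the repository's implementation.

```python
import numpy as np

# Hypothetical helpers sketching transition-aware weighting; not the repo's API.
# S[i, j] = P(noisy label j | clean label i), assumed known and invertible.


def clean_posterior(noisy_probs, S):
    """Recover p(Y | x_t) from noisy-label classifier probabilities p(Ỹ | x_t).

    Class-conditional noise gives p(Ỹ | x_t) = S^T p(Y | x_t); solve for p(Y | x_t).
    Estimation error can make entries negative, so clip and renormalize.
    """
    p = np.linalg.solve(S.T, noisy_probs)
    p = np.clip(p, 0.0, None)
    return p / p.sum()


def transition_weights(noisy_probs, S, noisy_label):
    """w_y = p(Y=y | Ỹ=ỹ, x_t), proportional to S[y, ỹ] * p(Y=y | x_t) by Bayes' rule."""
    unnorm = S[:, noisy_label] * clean_posterior(noisy_probs, S)
    return unnorm / unnorm.sum()


def combine_over_classes(per_class, weights):
    """Dot product over the class axis: sum_y w_y * per_class[y]."""
    return np.tensordot(weights, per_class, axes=1)


def weighted_l2_loss(per_class_denoised, weights, x0):
    """L2 loss between the weighted per-class denoiser outputs and the clean sample x0."""
    combined = combine_over_classes(per_class_denoised, weights)
    return float(np.mean((combined - x0) ** 2))


if __name__ == "__main__":
    C, r = 3, 0.4
    S = np.full((C, C), r / (C - 1))
    np.fill_diagonal(S, 1.0 - r)          # toy symmetric noise with rate 0.4
    p_clean = np.array([0.7, 0.2, 0.1])   # toy clean-label posterior
    p_noisy = S.T @ p_clean               # what a noisy-label classifier would output
    print(transition_weights(p_noisy, S, noisy_label=0))
```

With `S` equal to the identity (no label noise), the weights collapse to a one-hot vector at ỹ and the combination reduces to ordinary conditional training. In EDM-style code the network predicts a denoised image D(x_t, σ, y) rather than a score; since the score equals (D − x_t)/σ² and the weights sum to 1, weighting the denoiser outputs is equivalent to weighting the scores, which is what `weighted_l2_loss` relies on.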

Requirements

The requirements for this code are the same as those outlined for EDM.

For our experiments, we trained on 8 NVIDIA Tesla P40 GPUs with CUDA 11.4 and PyTorch 1.12.

Datasets

Datasets follow the same format as in StyleGAN and EDM: they are stored as uncompressed ZIP archives containing uncompressed PNG files, together with a metadata file dataset.json that holds the label information.
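To inspect the labels stored in such an archive, a small sketch with the Python standard library is enough. It assumes the dataset.json layout used by the StyleGAN2-ADA/EDM dataset tools, `{"labels": [["00000/img00000000.png", 6], ...]}`; the helper name `read_labels` is hypothetical.

```python
import io
import json
import zipfile


def read_labels(zip_source):
    """Return {image path: integer label} from the dataset.json inside a dataset ZIP.

    Assumes the StyleGAN2-ADA/EDM metadata layout
    {"labels": [["00000/img00000000.png", 6], ...]}; a null "labels" entry
    (an unconditional dataset) gives an empty dict.
    """
    with zipfile.ZipFile(zip_source) as zf:
        meta = json.loads(zf.read("dataset.json"))
    return {name: label for name, label in (meta.get("labels") or [])}


if __name__ == "__main__":
    # Build a tiny in-memory archive in the same layout, then read it back.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as zf:
        zf.writestr("dataset.json", json.dumps({"labels": [["00000/img00000000.png", 6]]}))
    print(read_labels(buf))
```

On a real archive, call it with the ZIP path, e.g. `read_labels("datasets/cifar10_sym_40-32x32.zip")`.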

Noisy Labeled Dataset

For the benchmark datasets, dataset_tool.py accepts additional arguments that control the label noise: --noise_type ('sym' or 'asym') and --noise_rate (between 0 and 1).

For example, the following command constructs the CIFAR-10 dataset with 40% symmetric label noise:

python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
    --dest=datasets/cifar10_sym_40-32x32.zip --noise_type=sym --noise_rate=0.4
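For intuition about what the two noise types mean, here is a NumPy sketch that builds a label transition matrix and samples noisy labels from it. It is not dataset_tool.py's code, and conventions vary across the literature: some implementations draw the symmetric replacement uniformly from all C classes (so the effective flip rate is r(C − 1)/C), and asymmetric noise on CIFAR-10 often uses a fixed class mapping (truck → automobile, bird → airplane, deer → horse, cat ↔ dog). The sketch assumes symmetric noise flips only to other classes and uses a simple pair flip i → (i + 1) mod C for 'asym'; both helper names are hypothetical.

```python
import numpy as np


def noise_transition_matrix(num_classes, noise_type, noise_rate):
    """Hypothetical helper: S[i, j] = P(noisy label j | clean label i).

    'sym'  keeps the label with prob 1 - r and otherwise flips it uniformly
           to one of the other C - 1 classes.
    'asym' flips class i to (i + 1) % C with prob r. This pair-flip rule is only
           an illustration; dataset_tool.py may use a dataset-specific mapping.
    """
    C, r = num_classes, noise_rate
    if noise_type == "sym":
        S = np.full((C, C), r / (C - 1))
        np.fill_diagonal(S, 1.0 - r)
    elif noise_type == "asym":
        S = (1.0 - r) * np.eye(C)
        S[np.arange(C), (np.arange(C) + 1) % C] += r
    else:
        raise ValueError(f"unknown noise_type: {noise_type!r}")
    return S


def corrupt_labels(labels, S, seed=0):
    """Sample a noisy label for each clean label from row S[label] (inverse-CDF sampling)."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    cdf = np.cumsum(S[labels], axis=1)          # shape [N, C]
    u = rng.random((len(labels), 1))
    noisy = (u > cdf).sum(axis=1)               # first index where u <= cdf
    return np.minimum(noisy, S.shape[0] - 1)   # guard against round-off in cdf[:, -1]


if __name__ == "__main__":
    S = noise_transition_matrix(10, "sym", 0.4)  # analogous to --noise_type=sym --noise_rate=0.4
    clean = np.random.default_rng(1).integers(0, 10, size=50_000)
    noisy = corrupt_labels(clean, S)
    print("flipped fraction:", np.mean(noisy != clean))
```

Sampling every label from its row of `S` keeps the noise model explicit, so the same matrix can be reused wherever the transition probabilities are needed.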

Additionally, the noisy labeled datasets we used are available at this link.

To use them, first download each dataset ZIP archive, then replace the dataset.json file inside the archive with the corresponding JSON file.
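Python's `zipfile` module cannot overwrite a member in place, so one way to do this replacement is to rewrite the archive. The sketch below copies every entry except dataset.json into a new archive and then adds the new metadata, storing all entries uncompressed to match the dataset format. The helper name and the choice to take the JSON as bytes are assumptions for illustration.

```python
import io
import zipfile


def replace_dataset_json(src_zip, new_json: bytes, dst_zip):
    """Rewrite src_zip into dst_zip with dataset.json replaced by new_json.

    Every other entry is copied unchanged; all entries are written with
    ZIP_STORED (no compression), matching the uncompressed-ZIP dataset format.
    src_zip and dst_zip may be paths or file-like objects.
    """
    with zipfile.ZipFile(src_zip) as src, \
         zipfile.ZipFile(dst_zip, "w", compression=zipfile.ZIP_STORED) as dst:
        for name in src.namelist():
            if name != "dataset.json":
                dst.writestr(name, src.read(name))
        dst.writestr("dataset.json", new_json)


if __name__ == "__main__":
    # Tiny in-memory demo: one image entry plus an original dataset.json.
    src = io.BytesIO()
    with zipfile.ZipFile(src, "w") as zf:
        zf.writestr("00000/img00000000.png", b"png-bytes")
        zf.writestr("dataset.json", b'{"labels": [["00000/img00000000.png", 3]]}')
    dst = io.BytesIO()
    replace_dataset_json(src, b'{"labels": [["00000/img00000000.png", 5]]}', dst)
    with zipfile.ZipFile(dst) as zf:
        print(zf.read("dataset.json").decode())
```

For real files, pass the downloaded archive path, the bytes of the provided JSON file (e.g. read with `open(path, "rb").read()`), and a new output path, so the original archive is left untouched.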

Training

Classifiers

You can train new classifiers using train_classifier.py. For example:

torchrun --standalone --nproc_per_node=1 train_classifier.py --outdir=classifier-runs \
    --data=datasets/cifar10_sym_40-32x32.zip --cond=1 --arch=ddpmpp --batch 1024

Label-Noise Robust Diffusion Models

You can train the diffusion models with the TDSM objective using train_noise.py. For example:

torchrun --standalone --nproc_per_node=8 train_noise.py --outdir=noise-runs \
    --data=datasets/cifar10_sym_40-32x32.zip --cond=1 --arch=ddpmpp \
    --cls=/path/to/classifier/network-snapshot-200000.pkl \
    --noise_type=sym --noise_rate=0.4
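The command passes the same --noise_type and --noise_rate that were used to build the dataset, presumably so the objective can use the corresponding transition matrix. Since a mismatch between these flags and the labels actually stored in the ZIP could skew training, it may be worth checking the dataset before a long multi-GPU run. A minimal sketch, assuming the clean and noisy labels are available as integer arrays aligned by image (for example, taken from the dataset.json of a clean build and of the noisy build); both helpers are hypothetical.

```python
import numpy as np


def empirical_transition(clean, noisy, num_classes):
    """Row-normalized confusion matrix: T[i, j] ≈ P(noisy label j | clean label i)."""
    clean, noisy = np.asarray(clean), np.asarray(noisy)
    counts = np.zeros((num_classes, num_classes))
    np.add.at(counts, (clean, noisy), 1)        # count each (clean, noisy) pair
    rows = counts.sum(axis=1, keepdims=True)
    # Classes that never occur keep an all-zero row instead of dividing by zero.
    return np.divide(counts, rows, out=np.zeros_like(counts), where=rows > 0)


def observed_noise_rate(clean, noisy):
    """Fraction of samples whose noisy label differs from the clean one."""
    return float(np.mean(np.asarray(clean) != np.asarray(noisy)))


if __name__ == "__main__":
    clean = [0, 0, 1, 1, 2, 2]
    noisy = [0, 1, 1, 1, 2, 0]
    print(empirical_transition(clean, noisy, 3))
    print("noise rate:", observed_noise_rate(clean, noisy))
```

For a symmetric setting where flips always move to a different class, the diagonal of the estimated matrix should be close to 1 − noise_rate.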

Pre-trained Models

We provide pre-trained classifiers, baseline models, and our models for the noisy labeled datasets at this link.

Generate Samples

You can generate samples using generate.py. For example:

python generate.py --seeds=0-63 --steps=18 --class=0 --outdir=/path/to/output \
    --network=/path/to/score
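Since the README points to EDM for requirements and data format, the generation flags likely follow EDM's conventions; this is an assumption, not confirmed for this repository. In EDM, `--seeds=0-63` is an inclusive range (64 images, one per seed) and `--steps=18` is the number of noise levels in the default schedule with σ_min = 0.002, σ_max = 80, and ρ = 7; with EDM's Heun sampler, 18 steps cost 35 network evaluations. The sketch below reproduces both pieces; the function names are hypothetical.

```python
import re

import numpy as np


def parse_int_list(spec):
    """Parse a seed spec such as '0-63' or '1,2,5-7' (ranges are inclusive)."""
    values = []
    for part in spec.split(","):
        m = re.fullmatch(r"(\d+)-(\d+)", part.strip())
        if m:
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            values.append(int(part))
    return values


def edm_sigmas(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM-style noise levels: uniform in sigma**(1/rho) from sigma_max to sigma_min, then 0."""
    i = np.arange(num_steps)
    a, b = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (a + i / (num_steps - 1) * (b - a)) ** rho
    return np.append(sigmas, 0.0)


if __name__ == "__main__":
    seeds = parse_int_list("0-63")
    print(len(seeds), "images;", len(edm_sigmas()) - 1, "sampling steps")  # prints "64 images; 18 sampling steps"
```

Spacing the levels uniformly in σ^(1/ρ) concentrates steps at low noise levels, where fine image detail is resolved.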

Acknowledgements

This work builds heavily on the code from:

Citation

@inproceedings{
na2024labelnoise,
title={Label-Noise Robust Diffusion Models},
author={Byeonghu Na and Yeongmin Kim and HeeSun Bae and Jung Hyun Lee and Se Jung Kwon and Wanmo Kang and Il-Chul Moon},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=HXWTXXtHNl}
}
