Refining Generative Process with Discriminator Guidance in Score-based Diffusion Models (DG) (ICML 2023 Oral)
Official PyTorch implementation of Discriminator Guidance.

Dongjun Kim *, Yeongmin Kim *, Se Jung Kwon, Wanmo Kang, and Il-Chul Moon
* Equal contribution

| paper |

Overview

Teaser image

Step-by-Step Running of Discriminator Guidance

1) Prepare a pre-trained score network

  • Download edm-cifar10-32x32-uncond-vp.pkl from EDM for the unconditional model.
  • Download edm-cifar10-32x32-cond-vp.pkl from EDM for the conditional model.
  • Place the EDM checkpoints in the directory shown below.
${project_page}/DG/
├── checkpoints
│   ├── pretrained_score/edm-cifar10-32x32-uncond-vp.pkl
│   ├── pretrained_score/edm-cifar10-32x32-cond-vp.pkl
├── ...
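
Before moving on, it can help to confirm that the checkpoints landed where the later commands expect them. The following is a minimal sketch, not part of the repository: `missing_files` and `EXPECTED_CHECKPOINTS` are hypothetical names, and the paths are copied from the tree above.

```python
from pathlib import Path

# Relative paths copied from the directory tree above.
EXPECTED_CHECKPOINTS = [
    "checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl",
    "checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl",
]


def missing_files(root, relative_paths=EXPECTED_CHECKPOINTS):
    """Return the expected files that do not exist under `root` (hypothetical helper)."""
    root = Path(root)
    return [p for p in relative_paths if not (root / p).is_file()]
```

Calling `missing_files(".")` from `${project_page}/DG/` should return an empty list once both checkpoints are in place. The same helper can be reused for the data, classifier, and stats files in the later steps by passing a different path list.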

2) Generate fake samples

  • To draw 50k unconditional samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl --outdir=samples/cifar_uncond_vanilla --dg_weight_1st_order=0
  • To draw 50k conditional samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond_vanilla --dg_weight_1st_order=0

3) Prepare real data

${project_page}/DG/
├── data
│   ├── true_data.npz
│   ├── true_data_label.npz
├── ...
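
To sanity-check the downloaded archives, one option is to list the arrays they contain. This is a small sketch using NumPy only; `describe_npz` is a hypothetical helper, and it makes no assumption about which array names the files actually use.

```python
import numpy as np


def describe_npz(path):
    """Map each array name stored in an .npz archive to its (shape, dtype)."""
    with np.load(path) as archive:
        return {
            name: (archive[name].shape, str(archive[name].dtype))
            for name in archive.files
        }
```

For example, `describe_npz("data/true_data.npz")` run from `${project_page}/DG/` prints whatever names, shapes, and dtypes the archive holds, which is a quick way to confirm the download is intact.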

4) Prepare a pre-trained classifier

${project_page}/DG/
├── checkpoints
│   ├── ADM_classifier/32x32_classifier.pt
├── ...

5) Train a discriminator

${project_page}/DG/
├── checkpoints/discriminator
│   ├── cifar_uncond/discriminator_60.pt
│   ├── cifar_cond/discriminator_250.pt
├── ...
  • To train the unconditional discriminator from scratch, run:
python3 train.py
  • To train the conditional discriminator from scratch, run:
python3 train.py --savedir=/checkpoints/discriminator/cifar_cond --gendir=/samples/cifar_cond_vanilla --datadir=/data/true_data_label.npz --cond=1 
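
For intuition about what the discriminator learns, here is a toy sketch under stated assumptions: a binary cross-entropy objective that separates real from generated samples after both are perturbed at the same noise level. The 1-D logistic model, the function names, and all hyperparameters are illustrative stand-ins, not the architecture or settings used by `train.py`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_loss(logits, labels):
    """Mean binary cross-entropy; label 1 = real, label 0 = generated."""
    # log(1 + exp(z)) - y * z, computed stably with logaddexp.
    return float(np.mean(np.logaddexp(0.0, logits) - labels * logits))


def train_toy_discriminator(real, fake, sigma, steps=500, lr=0.5, seed=0):
    """Fit d(x) = sigmoid(w * x + b) on noise-perturbed real and fake samples."""
    rng = np.random.default_rng(seed)
    x = np.concatenate([real, fake])
    # Perturb real and generated samples with the same noise level.
    x = x + sigma * rng.standard_normal(x.shape)
    y = np.concatenate([np.ones(len(real)), np.zeros(len(fake))])
    w, b = 0.0, 0.0
    for _ in range(steps):
        p = sigmoid(w * x + b)
        # Gradient of the mean BCE with respect to each logit is (p - y).
        w -= lr * np.mean((p - y) * x)
        b -= lr * np.mean(p - y)
    return w, b, bce_loss(w * x + b, y)
```

With real samples drawn from N(0, 1) and fake samples from N(1, 1), the fitted slope `w` comes out negative: smaller values of `x` are judged more likely to be real, which is the signal that guidance later exploits.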

6) Generate discriminator-guided samples

  • To generate 50k unconditional discriminator-guided samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl --outdir=samples/cifar_uncond
  • To generate 50k conditional discriminator-guided samples, run:
python3 generate.py --network checkpoints/pretrained_score/edm-cifar10-32x32-cond-vp.pkl --outdir=samples/cifar_cond --dg_weight_1st_order=1 --cond=1 --discriminator_ckpt=/checkpoints/discriminator/cifar_cond/discriminator_250.pt --boosting=1
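
The idea behind guided sampling can be illustrated with a toy 1-D sketch: the model score is corrected by the gradient of the discriminator's log-odds, log d(x) / (1 - d(x)). The Gaussian setup, the closed-form optimal discriminator, and the names below are assumptions made for illustration. `dg_weight` loosely mirrors the role the `--dg_weight_1st_order` flag appears to play (0 for vanilla sampling in step 2), but the sketch does not reproduce the sampler in `generate.py`.

```python
import numpy as np


def model_score(x, mu):
    """Score of an imperfect model density N(mu, 1)."""
    return -(x - mu)


def log_density_ratio(x, mu):
    """log d*(x) / (1 - d*(x)) for the optimal discriminator between
    real data N(0, 1) and model samples N(mu, 1)."""
    return -0.5 * x**2 + 0.5 * (x - mu) ** 2


def guided_score(x, mu, dg_weight=1.0, eps=1e-4):
    """Model score plus a weighted gradient of the discriminator log-odds."""
    # Central finite difference stands in for automatic differentiation.
    grad = (log_density_ratio(x + eps, mu) - log_density_ratio(x - eps, mu)) / (2 * eps)
    return model_score(x, mu) + dg_weight * grad
```

In this toy case the correction term equals `-mu`, so with `dg_weight=1.0` the guided score is `-x`, the exact score of the real data distribution N(0, 1). With `dg_weight=0.0` the function falls back to the uncorrected model score.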

7) Evaluate FID

${project_page}/DG/
├── stats
│   ├── cifar10-32x32.npz
├── ...
  • Run:
python3 fid_npzs.py --ref=/stats/cifar10-32x32.npz --num_samples=50000 --images=/samples/cifar_uncond/
python3 fid_npzs.py --ref=/stats/cifar10-32x32.npz --num_samples=50000 --images=/samples/cifar_cond/
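
For reference, FID is the Fréchet distance between two Gaussians fitted to feature statistics of real and generated images. The sketch below computes that distance from a pair of means and covariances using NumPy only. `frechet_distance` is a hypothetical helper, not the code in `fid_npzs.py`, and it leaves out feature extraction entirely.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 S2)) for Gaussian statistics."""
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1, sigma2 = np.asarray(sigma1, dtype=float), np.asarray(sigma2, dtype=float)
    diff = mu1 - mu2
    # Tr(sqrt(S1 S2)) equals the sum of square roots of the eigenvalues of S1 S2,
    # which are real and non-negative for positive semi-definite inputs.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)
```

Identical statistics give a distance of zero, and with identity covariances the distance reduces to the squared distance between the means.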

Experimental Results

EDM-G++

| FID-50k | Cifar-10 | Cifar-10 (conditional) | FFHQ64 |
|---|---|---|---|
| EDM | 2.03 | 1.82 | 2.39 |
| EDM-G++ | 1.77 | 1.64 | 1.98 |

Other backbones

| FID-50k | Cifar-10 | CelebA64 |
|---|---|---|
| Backbone | 2.10 | 1.90 |
| Backbone-G++ | 1.94 | 1.34 |

Note that the Cifar-10 backbone is LSGM and the CelebA64 backbone is Soft-Truncation.
See alsdudrla10/DG_imagenet for the ImageNet256 results and released code.

Samples from unconditional Cifar-10

Teaser image

Samples from conditional Cifar-10

Teaser image

Reference

If you find the code useful for your research, please consider citing

@article{kim2022refining,
  title={Refining Generative Process with Discriminator Guidance in Score-based Diffusion Models},
  author={Kim, Dongjun and Kim, Yeongmin and Kwon, Se Jung and Kang, Wanmo and Moon, Il-Chul},
  journal={arXiv preprint arXiv:2211.17091},
  year={2022}
}

This work is heavily built upon the code from

  • Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the design space of diffusion-based generative models. arXiv preprint arXiv:2206.00364.
  • Dhariwal, P., & Nichol, A. (2021). Diffusion models beat GANs on image synthesis. Advances in Neural Information Processing Systems, 34, 8780-8794.
  • Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.
