Towards Practical Plug-and-Play Diffusion Models (CVPR 2023)

[Arxiv] [Open Access] [BibTex]

Official PyTorch implementation of the paper "Towards Practical Plug-and-Play Diffusion Models". This repository contains the code for guidance with 1) models finetuned on forward-diffused data, 2) the multi-expert strategy, and 3) PPAP, as used in the paper.

This repository is based on the following repositories, with some modifications:

Plan

  • Release code.
  • Make checkpoints available.
  • Make PPAP data available.

Requirements

For distributed training, install MPICH; CLIP is installed directly from its GitHub repository:

apt install mpich
pip install git+https://github.com/openai/CLIP.git --no-deps

To install the required Python packages, run:

pip install -r requirements.txt 

Imagenet Class Guidance for ADM

A. Prepare pre-trained diffusion models.

For the pre-trained diffusion model, we use ADM trained on the ImageNet 256x256 dataset. The checkpoint for this model is available as 256x256_diffusion_uncond.pt.

Download it and save it to the path [diffusion_path].

B. Train

Our code supports training 1) a finetuned model, 2) multi-experts, and 3) PPAP. The commands for each are listed below.

  1. Finetune off-the-shelf models on forward-diffused data.
    export PYTHONPATH=$PYTHONPATH:$(pwd)
    MODEL_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 1e-4 --weight_decay 0.05 --save_interval 10000"
    CLASSIFIER_FLAGS="--image_size 256 --classifier_name [classifier name: ResNet18, ResNet50, ResNet152, DEIT]"
    python python_scripts/classifier_train.py --log_path [directory for logging] --data_dir [ImageNet1k training dataset path] --method "finetune" $MODEL_FLAGS $CLASSIFIER_FLAGS --gpus 0
    
  2. Multi-experts trained with supervision.
    export PYTHONPATH=$PYTHONPATH:$(pwd)
    MODEL_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 1e-4 --weight_decay 0.05 --save_interval 10000"
    CLASSIFIER_FLAGS="--image_size 256 --classifier_name [classifier name: ResNet18, ResNet50, ResNet152, DEIT]"
    python python_scripts/classifier_train.py --log_path [directory for logging] --data_dir [ImageNet1k training dataset path] $MODEL_FLAGS $CLASSIFIER_FLAGS --gpus 0 --n_experts [Number of experts] --method "multi_experts"
    
  3. PPAP.
    • To finetune off-the-shelf models with the PPAP framework, we first generate synthetic images from an unconditional diffusion model.
    • The following command generates this data from the ADM unconditional 256x256 diffusion model:
      SAMPLE_FLAGS="--batch_size 100 --num_samples 500000  --timestep_respacing ddim25 --use_ddim True"
      MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
      mpiexec -n [number of gpus] python python_scripts/generate_dataset.py --log_path [path for saving dataset] $MODEL_FLAGS $SAMPLE_FLAGS --gpus [Gpu ids] --model_path [diffusion_path]
      
    • Alternatively, you can download data generated by the ADM unconditional 256x256 diffusion model from link
    • The following command then trains PPAP on the synthetic data.
      export PYTHONPATH=$PYTHONPATH:$(pwd)
      MODEL_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 1e-4 --weight_decay 0.05 --save_interval 10000"
      CLASSIFIER_FLAGS="--image_size 256 --classifier_name [classifier name: ResNet18, ResNet50, ResNet152, DEIT] --lora_alpha 8 --gamma 16"
      python python_scripts/classifier_train.py --log_path [directory for logging] --data_dir [Synthetic data path] $MODEL_FLAGS $CLASSIFIER_FLAGS  --gpus 0 --n_experts [Number of experts] --method "ppap"
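
The "forward-diffused data" in step 1 refers to clean images noised to a diffusion timestep t. As a rough illustration only, the following pure-Python sketch applies the standard DDPM forward process under a linear noise schedule with 1000 steps, matching `--noise_schedule linear --diffusion_steps 1000` above. The function names and the beta endpoints (1e-4 and 0.02, the commonly used defaults) are assumptions for this example and are not taken from this repository's code.

```python
import math
import random


def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule; the endpoints are assumed common defaults."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """Running product of (1 - beta_t), usually written alpha-bar_t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def forward_diffuse(x0, t, alpha_bar, rng=random):
    """Sample x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps, eps ~ N(0, 1)."""
    signal = math.sqrt(alpha_bar[t])
    noise = math.sqrt(1.0 - alpha_bar[t])
    return [signal * v + noise * rng.gauss(0.0, 1.0) for v in x0]


alpha_bar = cumulative_alphas(linear_betas())
pixels = [0.5, -0.25, 1.0]  # a toy "image" flattened to three values
noisy = forward_diffuse(pixels, t=500, alpha_bar=alpha_bar)
```

At small t the sample stays close to the clean input, while near t = 999 it is almost pure Gaussian noise, which is why a classifier trained only on clean images tends to give poor guidance at large timesteps.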
      

B.1 Enabling DDP for training

If MPICH is installed, distributed data parallel (DDP) training can be enabled. For DDP with k GPUs, divide --batch_size by k, prepend mpiexec -n k to the python command, and set the --gpus option to the IDs of the GPUs to use.

For example, the finetuning command above can be run with DDP on GPUs 0, 1, 2, and 3 as follows:

 export PYTHONPATH=$PYTHONPATH:$(pwd)
 MODEL_FLAGS="--iterations 300000 --anneal_lr True --batch_size 64 --lr 1e-4 --weight_decay 0.05 --save_interval 10000"
 CLASSIFIER_FLAGS="--image_size 256 --classifier_name [classifier name: ResNet18, ResNet50, ResNet152, DEIT]"
 mpiexec -n 4 python python_scripts/classifier_train.py --log_path [directory for logging] --data_dir [ImageNet1k training dataset path] --method "finetune" $MODEL_FLAGS $CLASSIFIER_FLAGS --gpus 0 1 2 3
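
The DDP rule above (divide the batch size by k, prepend `mpiexec -n k`, pass the GPU IDs) can be sketched as a small helper that assembles the argument list. `ddp_command` is a hypothetical helper written for this example, not part of the repository; it only uses the flags shown in the commands above.

```python
def ddp_command(script, total_batch_size, gpu_ids, extra_args=()):
    """Build an mpiexec argument list for DDP on the given GPUs.

    Hypothetical helper: splits the total batch size evenly across GPUs,
    as the README describes (256 on 4 GPUs becomes 64 per process).
    """
    k = len(gpu_ids)
    if k == 0:
        raise ValueError("at least one GPU id is required")
    if total_batch_size % k != 0:
        raise ValueError("total batch size must be divisible by the GPU count")
    per_process = total_batch_size // k
    return [
        "mpiexec", "-n", str(k),
        "python", script,
        "--batch_size", str(per_process),
        *extra_args,
        "--gpus", *[str(g) for g in gpu_ids],
    ]


cmd = ddp_command(
    "python_scripts/classifier_train.py",
    total_batch_size=256,
    gpu_ids=[0, 1, 2, 3],
    extra_args=["--method", "finetune"],
)
print(" ".join(cmd))
```

Raising an error when the batch size is not divisible by k is a design choice of this sketch; it avoids silently training with a different effective batch size than the single-GPU command.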

C. Trained checkpoints

| Model | Finetune | Multi-experts-5 | PPAP-5 |
|---|---|---|---|
| ResNet50 | Model | experts [0, 200], [200, 400], [400, 600], [600, 800], [800, 1000] | experts [0, 200], [200, 400], [400, 600], [600, 800], [800, 1000] |
| DeiT-S | Model | experts [0, 200], [200, 400], [400, 600], [600, 800], [800, 1000] | experts [0, 200], [200, 400], [400, 600], [600, 800], [800, 1000] |
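
The expert checkpoints above are labeled by timestep intervals of width 200 over the 1000 diffusion steps. A minimal sketch of how a timestep could be mapped to the expert that covers it is shown below. `expert_index` is a hypothetical helper, and treating each interval as half-open, [lo, hi), is an assumption, since the labels do not say which expert owns a boundary timestep.

```python
def expert_index(t, n_experts=5, diffusion_steps=1000):
    """Return the index of the expert whose interval contains timestep t.

    Assumes equal-width, half-open intervals [lo, hi); this boundary
    convention is an assumption, not confirmed by the repository.
    """
    if not 0 <= t < diffusion_steps:
        raise ValueError("timestep out of range")
    if diffusion_steps % n_experts != 0:
        raise ValueError("diffusion_steps must be divisible by n_experts")
    width = diffusion_steps // n_experts
    return t // width


for t in (0, 199, 200, 999):
    print(t, expert_index(t))
```

Under this convention, the order of the checkpoints passed to `--classifier_path` would presumably follow the same index order, from the [0, 200] expert to the [800, 1000] expert.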

D. Sampling with classifier guidance

Our code supports sampling with guidance from 1) a finetuned model, 2) multi-experts, and 3) PPAP.

  1. Finetune
    export PYTHONPATH=$PYTHONPATH:$(pwd)
    SAMPLE_FLAGS="--batch_size 100 --num_samples 10000  --timestep_respacing ddim25 --use_ddim True"
    MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
    MODEL_PATH_FLAGS="--model_path [diffusion_path] --classifier_path [ckpt_path]"
    python python_scripts/classifier_sample.py --log_path [sampling_path] $MODEL_FLAGS $SAMPLE_FLAGS $MODEL_PATH_FLAGS --method "finetune" --gpus 0
    
  2. Multi-experts
    export PYTHONPATH=$PYTHONPATH:$(pwd)
    SAMPLE_FLAGS="--batch_size 100 --num_samples 10000  --timestep_respacing ddim25 --use_ddim True"
    MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
    MODEL_PATH_FLAGS="--model_path [diffusion_path] --classifier_path [ckpt_path_0] [ckpt_path_1] ... [ckpt_path_N]"
    python python_scripts/classifier_sample.py --log_path [sampling_path] $MODEL_FLAGS $SAMPLE_FLAGS $MODEL_PATH_FLAGS --method "multi_experts" --gpus 0
    
  3. PPAP
    export PYTHONPATH=$PYTHONPATH:$(pwd)
    SAMPLE_FLAGS="--batch_size 100 --num_samples 10000  --timestep_respacing ddim25 --use_ddim True"
    MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
    MODEL_PATH_FLAGS="--model_path [diffusion_path] --classifier_path [ckpt_path_0] [ckpt_path_1] ... [ckpt_path_N]"
    python python_scripts/classifier_sample.py --log_path [sampling_path] $MODEL_FLAGS $SAMPLE_FLAGS $MODEL_PATH_FLAGS --method "ppap" --gpus 0
    

D.1 Sampling configuration.

  1. DDIM: To sample with DDIM using t steps, set --timestep_respacing to ddim followed by t (for example, ddim25 for 25 steps).
  2. DDPM: To sample with DDPM using t steps, set --timestep_respacing to t.
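
The two settings above can be illustrated with a short sketch. `respacing_flag` builds the value passed to `--timestep_respacing`, and `ddim_timesteps` shows one way an evenly strided subset of the 1000 training timesteps can be chosen. Both functions are hypothetical; the stride search is modeled on how the upstream guided-diffusion code is understood to handle `ddimN` strings, and whether this repository behaves identically is an assumption.

```python
def respacing_flag(sampler, steps):
    """Return the --timestep_respacing value for a sampler and step count."""
    if steps <= 0:
        raise ValueError("steps must be positive")
    if sampler == "ddim":
        return f"ddim{steps}"
    if sampler == "ddpm":
        return str(steps)
    raise ValueError(f"unknown sampler: {sampler}")


def ddim_timesteps(num_timesteps, steps):
    """Pick the first integer stride that yields exactly `steps` timesteps."""
    for stride in range(1, num_timesteps):
        if len(range(0, num_timesteps, stride)) == steps:
            return list(range(0, num_timesteps, stride))
    raise ValueError("no integer stride gives the requested number of steps")


print(respacing_flag("ddim", 25))
print(ddim_timesteps(1000, 25)[:3])
```

With 1000 training steps and 25 DDIM steps, the stride found by this sketch is 40, so the retained timesteps run from 0 to 960.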

D.2 DDP for sampling.

Because sampling is slow, we recommend using DDP for sampling. To use DDP with k GPUs, prepend mpiexec -n k to the python command and set the --gpus option to the IDs of the GPUs to use.

E. Evaluation

Check evaluations/Readme.md.

PPAP with various models for DeepFloyd-IF.

| Base image depth map | PPAP depth-guided image | PPAP depth-guided + "Dog" prompt image |
|---|---|---|
| depth.png | uncond_depth_guidance.png | text_to_img_depth_guidance_dog.png |

We provide code for depth guidance with MiDaS for DeepFloyd-IF. DeepFloyd-IF is similar to the GLIDE model used in our paper, but it can generate 1024x1024 images of higher quality than GLIDE. For this reason, the released code uses DeepFloyd-IF as the target diffusion model so that it produces high-quality images.

A. Prepare pre-trained model weight of DeepFloyd-IF.

The first step is to prepare the pretrained DeepFloyd-IF checkpoint. Please refer to the DeepFloyd-IF repository (Link) and obtain a Hugging Face access token.

Then, set the --hf_token argument of the following python commands to your access token.

B. Generate unconditional image dataset for PPAP.

To finetune off-the-shelf models with the PPAP framework, we first generate synthetic images from an unconditional diffusion model. The following command generates this data from DeepFloyd-IF:

export PYTHONPATH=$PYTHONPATH:$(pwd)
mpiexec -n [number_of_gpus] python python_scripts/generate_dataset_deepfloyd.py --stage 2 --num_samples 500000 --gpus ["gpu_ids"] --log_path ["Directory for saving the dataset"] --batch_size [batch_size] --hf_token ["your token"]

C. PPAP-finetune Midas

The following command finetunes MiDaS as the guidance model with the PPAP framework.

export PYTHONPATH=$PYTHONPATH:$(pwd)
mpiexec -n [number_of_gpus] python python_scripts/deepfloyd_guidance_ppap.py --iterations 300000 --batch_size 64 --gpus ["gpu_ids"] --log_path ["path for logging directory"]

D. Trained checkpoints

experts [0,200] [200,400] [400,600] [600,800] [800,1000]

E. Generating samples

Please refer to deepfloyd_guidance_ppap.ipynb, which contains examples of depth guidance with PPAP.

F. Used dataset in DeepFloyd-IF PPAP

The dataset generated in step B (Generate unconditional image dataset for PPAP) can be downloaded from link.

BibTex

@inproceedings{go2023towards,
  title={Towards Practical Plug-and-Play Diffusion Models},
  author={Go, Hyojun and Lee, Yunsung and Kim, Jin-Young and Lee, Seunghyun and Jeong, Myeongho and Lee, Hyun Seung and Choi, Seungtaek},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={1962--1971},
  year={2023}
}

ppap's Issues

Correct checkpoints for multi-expert weights

Thanks so much for your work and code!

However, I have found that the ResNet50 checkpoints for the multi-expert method may be incorrect: they appear to be the LoRA-finetuned PPAP checkpoints. Could you kindly share the correct ResNet50 checkpoints trained with supervision by the multi-expert method?
