
phased-consistency-model's Introduction

⚡️Phased Consistency Model⚡️

[Paper] [Project Page ✨] [Pre-trained Models in 🤗Hugging Face] [Demo] [Civitai]

by Fu-Yun Wang1, Zhaoyang Huang2, Alexander William Bergman3,6, Dazhong Shen4, Peng Gao4, Michael Lingelbach3,6, Keqiang Sun1, Weikang Bian1, Guanglu Song5, Yu Liu4, Hongsheng Li1, Xiaogang Wang1

1CUHK-MMLab 2Avolution AI 3Hedra 4Shanghai AI Lab 5SenseTime 6Stanford University

@article{wang2024phased,
  title={Phased Consistency Model},
  author={Wang, Fu-Yun and Huang, Zhaoyang and Bergman, Alexander William and Shen, Dazhong and Gao, Peng and Lingelbach, Michael and Sun, Keqiang and Bian, Weikang and Song, Guanglu and Liu, Yu and others},
  journal={arXiv preprint arXiv:2405.18407},
  year={2024}
}

News

  • [2024.07.27]: Release Training Scripts of PCM-LoRA with Stable Diffusion XL.
  • [2024.07.14]: Fix an inference bug caused by the default parameters of DDIM.
  • [2024.06.19]: Release the training script of PCM-LoRA with Stable Diffusion 3. See text_to_image_sd3. Release the weights of PCM-LoRA with Stable Diffusion 3. See PCM_Weights.
PCM-SD3-2step-Deterministic PCM-SD3-4step-Deterministic PCM-SD3-Stochastic (treat it as a clearer LCM)
  • [2024.06.04]: Hugging Face Demo is available. Thanks @radames for the commit!
  • [2024.06.01]: Release PCM-LoRA weights of Stable Diffusion v1.5 and Stable Diffusion XL on huggingface.
  • [2024.06.01]: Release the training script of PCM-LoRA with Stable Diffusion v1.5. See train_pcm_lora_sd15.sh.

    We train the weights with 8 A800 GPUs, but my tentative experimental results suggest that a single GPU can still achieve good results.

    Happy Children's Day! Never too old to celebrate the joys of childhood!

  • [2024.05.30]: Technical report is available on arXiv.
Figure: one-step generation comparison between HyperSD and PCM (ours).

Our model has clearly better generation diversity than the concurrent work HyperSD.

Introduction

Phased Consistency Model (PCM) is (probably) currently one of the most powerful sampling acceleration strategies for fast text-conditioned image generation in large diffusion models.

Consistency Model (CM), proposed by Yang Song et al., is a promising new family of generative models that can generate high-fidelity images in very few steps (generally 2 steps) under unconditional and class-conditional settings. Previous work, the latent consistency model (LCM), tried to replicate the power of consistency models for text-conditioned generation, but generally failed to achieve pleasing results, especially in the low-step regime (1~4 steps). In contrast, we believe PCM is a much more successful extension of the original consistency models to high-resolution, text-conditioned image generation, better replicating their power in more advanced generation settings.

Generally, we identify three main limitations of (L)CMs:

  • LCM lacks flexibility in the choice of CFG scale and is insensitive to negative prompts.
  • LCM fails to produce consistent results under different numbers of inference steps. Its results are blurry when the number of steps is too large (stochastic sampling error) or too small (inability).
  • LCM produces poor, blurry results in the low-step regime.

These limitations can be seen clearly in the following figure.

We generalize the design space of consistency models for high-resolution text-conditioned image generation, analyzing and tackling the limitations in the previous work LCM.

teaser

PF-ODE

From a continuous-time perspective, a diffusion model defines a forward conditional probability path. The intermediate distribution $\mathbb P_{t}(\mathbf x | \mathbf x_0)$ conditioned on $\mathbf x_0$ has the general form $\alpha_t \mathbf x_0 + \sigma_t \boldsymbol \epsilon \sim \mathcal N(\alpha_t\mathbf x_0, \sigma_{t}^2\mathbf I)$, which is equivalent to the stochastic differential equation $\mathrm d\mathbf x_{t} = f_{t} \mathbf x_{t} \mathrm d t + g_{t} \mathrm d \boldsymbol w_{t}$, with $\boldsymbol w_{t}$ denoting the standard Wiener process.
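As a minimal sketch of the forward path above, the snippet below draws $\mathbf x_t = \alpha_t \mathbf x_0 + \sigma_t \boldsymbol \epsilon$ and checks the mean and variance empirically. The linear schedule $\alpha_t = 1 - t$, $\sigma_t = t$ and the helper names are assumptions chosen for illustration; they are not the schedule used by Stable Diffusion or by this repository.

```python
import random

def alpha_sigma(t):
    """Hypothetical linear schedule (an assumption): alpha_t = 1 - t, sigma_t = t."""
    return 1.0 - t, t

def perturb(x0, t, rng):
    """Draw x_t ~ N(alpha_t * x0, sigma_t^2 * I) for a list of floats x0."""
    alpha, sigma = alpha_sigma(t)
    return [alpha * xi + sigma * rng.gauss(0.0, 1.0) for xi in x0]

rng = random.Random(0)
x0 = [1.0, -2.0, 0.5]

# Perturb the same x0 many times at t = 0.5 and inspect the first coordinate.
samples = [perturb(x0, 0.5, rng) for _ in range(20000)]
mean_first = sum(s[0] for s in samples) / len(samples)
var_first = sum((s[0] - mean_first) ** 2 for s in samples) / len(samples)

# The mean should be close to alpha_t * x0[0] = 0.5 and the variance close to sigma_t^2 = 0.25.
print(mean_first, var_first)
```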

A remarkable property of the forward SDE is that there exists a reverse-time ODE trajectory, termed the probability flow ODE (PF-ODE) by Song et al., which introduces no additional stochasticity and still satisfies the predefined marginal distributions, that is,

$\mathrm d \mathbf x = \left(f_t \mathbf x - \frac{1}{2} g_{t}^2 \nabla_{\mathbf x} \log \mathbb P_{t}(\mathbf x)\right) \mathrm d t$,

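where $\mathbb P_{t}(\mathbf x)= \mathbb E_{\mathbf x_{0}}\left[\mathbb P_{t}(\mathbf x|\mathbf x_{0})\right]$ is the marginal distribution obtained by averaging over the data $\mathbf x_{0}$. The diffusion training process inherently trains a score estimator, a deep neural network $\boldsymbol s_{\theta}$, that approximates $\nabla_{\mathbf x} \log \mathbb P_{t}(\mathbf x)$.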
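To make the PF-ODE concrete, here is a hedged toy sketch in one dimension where the score is known in closed form, so no network is needed. It assumes Gaussian data $x_0 \sim \mathcal N(0, S^2)$ and the linear schedule $\alpha_t = 1 - t$, $\sigma_t = t$ (both assumptions for illustration), and uses the standard relations $f_t = \mathrm d \log \alpha_t / \mathrm d t$ and $g_t^2 = \mathrm d \sigma_t^2 / \mathrm d t - 2 f_t \sigma_t^2$, which are not stated in the text above. For this toy case the marginal is $\mathcal N(0, v_t)$ with $v_t = \alpha_t^2 S^2 + \sigma_t^2$, and the exact ODE solution is $x_s = x_t \sqrt{v_s / v_t}$, which the Euler integration should reproduce.

```python
import math

S = 2.0  # assumed standard deviation of the toy data x_0 ~ N(0, S^2)

def alpha(t): return 1.0 - t
def sigma(t): return t
def d_alpha(t): return -1.0
def d_sigma(t): return 1.0

def marginal_var(t):
    # Variance of the marginal P_t for Gaussian data: alpha_t^2 S^2 + sigma_t^2
    return (alpha(t) * S) ** 2 + sigma(t) ** 2

def f_coef(t):
    # f_t = d log(alpha_t) / dt
    return d_alpha(t) / alpha(t)

def g_sq(t):
    # g_t^2 = d(sigma_t^2)/dt - 2 f_t sigma_t^2
    return 2.0 * sigma(t) * d_sigma(t) - 2.0 * f_coef(t) * sigma(t) ** 2

def score(x, t):
    # Analytic score of N(0, v_t); a trained s_theta would replace this
    return -x / marginal_var(t)

def pf_ode_drift(x, t):
    # Right-hand side of the PF-ODE: f_t x - 0.5 g_t^2 * score
    return f_coef(t) * x - 0.5 * g_sq(t) * score(x, t)

def solve_pf_ode(x, t_start, t_end, n_steps=20000):
    # Plain Euler integration from t_start to t_end (dt is negative when going backward)
    dt = (t_end - t_start) / n_steps
    t = t_start
    for _ in range(n_steps):
        x += pf_ode_drift(x, t) * dt
        t += dt
    return x

x_start = 0.7  # a point at time t = 0.9 (we stop short of t = 1, where f_t diverges)
x_end = solve_pf_ode(x_start, 0.9, 0.0)
exact = x_start * math.sqrt(marginal_var(0.0) / marginal_var(0.9))
print(x_end, exact)  # the two values should agree closely
```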

Generally speaking, there are infinitely many possible paths for reversing the SDE. However, the ODE trajectory, being free of stochasticity, is more stable for sampling. Most schedulers used in the Stable Diffusion community, including DDIM, DPM-Solver, Euler, and Heun, are based on the principle of better approximating the ODE trajectory. Most distillation-based methods, including rectified flow and guided distillation, can also be seen as approximating the ODE trajectory with larger steps (though most of them do not discuss this connection).
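The point about schedulers can be illustrated with a small, hedged experiment: integrate a toy PF-ODE with only 8 steps using a first-order Euler solver and a second-order Heun solver, and compare both with the exact solution. The setup is an assumption for illustration only: 1-D Gaussian data $x_0 \sim \mathcal N(0, S^2)$ and the schedule $\alpha_t = 1 - t$, $\sigma_t = t$, for which $f_t = -1/(1-t)$, $g_t^2 = 2t/(1-t)$, and the score is $-x / v_t$ with $v_t = (1-t)^2 S^2 + t^2$. It does not reproduce the actual DDIM or DPM-Solver update rules.

```python
import math

S = 2.0  # assumed standard deviation of the toy data

def var(t):
    return ((1.0 - t) * S) ** 2 + t ** 2

def drift(x, t):
    # PF-ODE drift f_t x - 0.5 g_t^2 * score for the toy Gaussian setup
    f = -1.0 / (1.0 - t)
    g_sq = 2.0 * t / (1.0 - t)
    return f * x - 0.5 * g_sq * (-x / var(t))

def euler(x, t_start, t_end, n_steps):
    dt = (t_end - t_start) / n_steps
    for i in range(n_steps):
        t = t_start + i * dt
        x = x + dt * drift(x, t)
    return x

def heun(x, t_start, t_end, n_steps):
    dt = (t_end - t_start) / n_steps
    for i in range(n_steps):
        t = t_start + i * dt
        d1 = drift(x, t)
        x_pred = x + dt * d1            # Euler predictor
        d2 = drift(x_pred, t + dt)      # slope at the predicted point
        x = x + 0.5 * dt * (d1 + d2)    # trapezoidal corrector
    return x

x_start, t_start, t_end, n_steps = 0.7, 0.9, 0.0, 8
exact = x_start * math.sqrt(var(t_end) / var(t_start))
euler_err = abs(euler(x_start, t_start, t_end, n_steps) - exact) / abs(exact)
heun_err = abs(heun(x_start, t_start, t_end, n_steps) - exact) / abs(exact)
print(euler_err, heun_err)  # the higher-order solver tracks the trajectory more closely
```

With the same step budget, the higher-order solver stays closer to the ODE trajectory, which is the sense in which such schedulers are "better approximations".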

Consistency models aim to directly learn the solution point of the ODE trajectory, either through distillation or through training.

In PCM, we focus on distillation, which is generally easier to learn. We leave training for future research.

Learning Paradigm Comparison

Consistency Trajectory Model (CTM) points out that CMs suffer from accumulated stochasticity error when multistep sampling is used to improve sample quality, and proposes a more general framework that allows moving between arbitrary pairs of points along the ODE trajectory. Yet it requires an additional target-timestep embedding, which is not aligned with the design space of traditional diffusion models. Additionally, CTM is harder to train. Say we discretize the ODE trajectory into $N$ points: the number of learning objectives of diffusion models and consistency models is $\mathcal O(N)$, whereas that of CTM is $\mathcal O(N^2)$. Our proposed PCM also resolves the stochasticity error accumulation, but is much easier to train.
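The $\mathcal O(N)$ versus $\mathcal O(N^2)$ argument can be checked by simply enumerating (start, target) timestep pairs. The sketch below is an illustrative count under stated assumptions: timesteps are the integers $1..N$ with $0$ as the data end, CTM is counted as every pair with target strictly earlier than start, and PCM maps each timestep to the lower edge of its phase in a uniform partition. The function name and the counting convention are hypothetical, not taken from the paper or the repository.

```python
def count_objectives(n_points, n_phases):
    """Count distinct (start, target) pairs per paradigm; assumes n_points % n_phases == 0."""
    times = range(1, n_points + 1)
    # Diffusion model: one denoising (score) target per noisy timestep.
    dm = len(times)
    # Consistency model: every timestep maps to the solution point at time 0.
    cm = len({(t, 0) for t in times})
    # CTM: any timestep t may jump to any earlier time s < t.
    ctm = len({(t, s) for t in times for s in range(t)})
    # PCM: every timestep maps to the lower edge of the phase that contains it.
    phase_len = n_points // n_phases
    pcm = len({(t, ((t - 1) // phase_len) * phase_len) for t in times})
    return {"DM": dm, "CM": cm, "CTM": ctm, "PCM": pcm}

print(count_objectives(8, 4))
print(count_objectives(100, 4))
```

Under these assumptions the CTM count is $N(N+1)/2$, while the other three grow linearly in $N$.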

The core idea of our method is to phase the whole ODE trajectory into multiple sub-trajectories. The following figure illustrates the difference in learning paradigm among diffusion models (DMs), consistency models (CMs), consistency trajectory models (CTMs), and our proposed phased consistency models (PCMs).

teaser
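The phasing idea can be sketched on a toy problem. Assuming a uniform partition of $[0, 1]$ into $M$ phases (the repository may choose edges differently), each time $t$ is mapped to the lower edge of its phase, and sampling hops from edge to edge. In the sketch, `exact_jump` is a closed-form stand-in for the learned per-phase consistency function; it is exact only for the assumed 1-D Gaussian data $x_0 \sim \mathcal N(0, S^2)$ with schedule $\alpha_t = 1 - t$, $\sigma_t = t$. All helper names are hypothetical.

```python
import math

S = 2.0  # assumed standard deviation of the toy data

def marginal_var(t):
    return ((1.0 - t) * S) ** 2 + t ** 2

def phase_edges(n_phases, t_max=1.0):
    """Uniform edges 0 = s_0 < s_1 < ... < s_M = t_max (uniformity is an assumption)."""
    return [t_max * k / n_phases for k in range(n_phases + 1)]

def phase_target(t, edges):
    """Lower edge of the phase containing t, i.e. the largest edge strictly below t."""
    return max(e for e in edges if e < t)

def exact_jump(x, t, s):
    """Closed-form PF-ODE solution map from time t to time s for the toy Gaussian.
    A trained PCM would approximate this map within each phase."""
    return x * math.sqrt(marginal_var(s) / marginal_var(t))

def phased_sample(x_T, n_phases):
    """Deterministic multistep sampling: one jump per phase, from t_max down to 0."""
    edges = phase_edges(n_phases)
    x, t = x_T, edges[-1]
    while t > 0.0:
        s = phase_target(t, edges)
        x = exact_jump(x, t, s)
        t = s
    return x

edges = phase_edges(4)
print(edges, phase_target(0.6, edges))
print(phased_sample(0.3, 4), exact_jump(0.3, 1.0, 0.0))
```

Because the stand-in map is exact, chaining the four phase jumps gives the same result as a single jump from $t = 1$ to $t = 0$; with a learned model the per-phase maps are only approximations, so this equality would hold only approximately.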

For a better comparison, we also implement a baseline, which we term simpleCTM. We adapt the high-level idea of CTM from the k-diffusion framework to the DDPM framework with Stable Diffusion and compare its performance. When trained with the same resources, our method achieves significantly better performance.

Samples of PCM

PCM can achieve text-conditioned image synthesis with good quality in 1, 2, 4, 8, 16 steps.

teaser

Comparison

PCM achieves strong generation results compared with currently available open fast-generation models, including the GAN-based methods SDXL-Turbo, SD-Turbo, and SDXL-Lightning; the rectified-flow-based method InstaFlow; and the CM-based methods LCM and SimpleCTM.

comparison

Contact & Collaboration

If you have any questions about the code, please do not hesitate to contact me!

Email: [email protected]


phased-consistency-model's Issues

[Inference Issue] ValueError when trying to load LoRA weights with diffusers

Hey!

Congrats on your work, and thanks a lot for sharing it 🤗
When trying to use the sd1.5 and sdxl checkpoints on the hub for inference with diffusers, I got the following error when calling load_lora_weights:

import torch
from diffusers import AutoPipelineForText2Image

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
adapter_id = "wangfuyun/PCM_SDXL_LoRAs"

pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
pipe.load_lora_weights(adapter_id, weight_name="pcm_sdxl_normalcfg_16step.safetensors")

ValueError: Target modules {'base_model.model.up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0', 'base_model.model.up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0', 'base_model.model.down_blocks.0.attentions.1.proj_in', 'base_model.model.up_blocks.1.attentions.1.proj_in', 'base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0', 'base_model.model.up_blocks.3.resnets.0.conv_shortcut', 'base_model.model.down_blocks.3.resnets.0.conv1', 'base_model.model.down_blocks.3.resnets.0.time_emb_proj', 'base_model.model.up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0', 'base_model.model.up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj', 'base_model.model.down_blocks.3.resnets.0.conv2', '
....
, 'base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v', 'base_model.model.up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q', 'base_model.model.up_blocks.2.attentions.1.proj_out', 'base_model.model.up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v', 'base_model.model.up_blocks.3.attentions.0.proj_out', 'base_model.model.up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v', , 'base_model.model.down_blocks.0.resnets.1.time_emb_proj', 'base_model.model.down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v'} not found in the base model. Please check the target modules and try again.

PCM-LORA

Incredible results over LCM guys, well done! I'm curious to know if you plan to also add a PCM-LORA, so we can easily integrate your approach with existing models, instead of having to perform training.

Also, is there any ETA?

About num_h_per_head=4

Hi,

Your work is truly impressive and inspiring!

Could you please explain why num_h_per_head=4 is used, allocating four DiscriminatorHeads per feature? Understanding the reasoning behind this would be very helpful.

Thanks!

PCM accelerates LORA's impact on the diversity of SDXL generative art?

I noticed that there are two kinds of acceleration LoRAs, normal CFG and small CFG. Setting CFG to 1-2 with the small-CFG acceleration LoRA has almost no effect on the artistic diversity of SDXL, which is great. How is this done? In this case, does the normal-CFG acceleration LoRA still need to exist?

sd3 pcm problem

I tried to use PCM with SD3, but found that the value of d_loss was basically always 2, and inference errors occurred after the saved LoRA was loaded. There was no problem with the model validation during training.

I also tried training the entire transformer instead of using LoRA, and found that the loss was NaN. Can you give me some suggestions?

Training code for SDXL

Hi,

I'm impressed by your amazing work !!

Do you plan to open the training code for SDXL?

It would be helpful for the open-source community.

PCM One Step Inference Question

Thanks for sharing your work! I have a question regarding PCM inference: does PCM one-step inference require evaluating all $M$ consistency models that the PCM model was trained with? That is, after sampling initial noise $\hat{\boldsymbol{x}}_T$, do we run

$$\boldsymbol{x} \gets f_\theta^{M - 1, 0}(\hat{\boldsymbol{x}}_T, T) = f_\theta^0(\cdots f_\theta^{M - 2}(f_\theta^{M - 1}(\hat{\boldsymbol{x}}_T, T), s_{M - 1}) \cdots, s_1)$$

or can we go from $T$ to $0$ in one application of $F_\theta(\boldsymbol{x}, t, s)$ like a normal consistency model? (For example, it's not obvious to me that something like $\boldsymbol{x} \gets F_\theta(\hat{\boldsymbol{x}}_T, T, 0)$ should work.) I read through the paper and could not figure it out (apologies if I missed the explanation).

Put another way, in e.g. code/text_to_image_sd15/train_pcm_lora_sd15_adv.py's log_validation function with args.multiphase == num_inference_step == 8, when we do

images = pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_step,
    num_images_per_prompt=4,
    generator=generator,
    guidance_scale=cfg,
).images

is this one-step inference or 8-step inference?

Format of the dataset?

Hi, what awesome work. Can you please share the dataset you used to distill SD1.5, or provide some sample data that we can try? Thank you so much!
