BayesDiff: Estimating Pixel-wise Uncertainty in Diffusion via Bayesian Inference

This repository is our codebase for BayesDiff.

Installation

conda create --name BayesDiff python==3.8
conda activate BayesDiff
conda install pip
git clone https://github.com/karrykkk/BayesDiff.git
cd BayesDiff
pip install -r requirements.txt

Framework

This repository integrates uncertainty quantification into three models, each in its own folder:

  1. ddpm_and_guided - Guided Diffusion Repository Link
  2. sd - Stable Diffusion Repository Link
  3. uvit - U-ViT Repository Link

Each folder contains a custom_model.py in which the original model is merged with uncertainty quantification techniques.
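The usage steps in this README fit a last-layer Laplace approximation (LLLA) to obtain uncertainty estimates. As a rough illustration of that idea only, the sketch below fits a diagonal Laplace approximation to a linear last layer with a scalar output. It assumes a Gaussian likelihood with unit noise and an isotropic Gaussian prior, and the function names are hypothetical. It is not the code in custom_model.py.

```python
from typing import List, Sequence


def fit_diag_last_layer_laplace(
    features: Sequence[Sequence[float]], prior_precision: float = 1.0
) -> List[float]:
    """Return the diagonal posterior variance of the last-layer weights.

    Assumes a scalar output y = w . phi with unit-variance Gaussian noise,
    so the diagonal of the Hessian is sum_n phi_nj^2 plus the prior precision.
    """
    dim = len(features[0])
    precision = [prior_precision] * dim
    for phi in features:
        for j, value in enumerate(phi):
            precision[j] += value * value
    return [1.0 / p for p in precision]


def predictive_variance(
    phi_new: Sequence[float], weight_variance: Sequence[float]
) -> float:
    """Variance of w . phi_new under the diagonal Gaussian posterior over w."""
    return sum(v * v * s for v, s in zip(phi_new, weight_variance))


# Toy last-layer features for three training inputs (illustrative values).
features = [[1.0, 0.0], [1.0, 2.0], [0.0, 2.0]]
weight_var = fit_diag_last_layer_laplace(features, prior_precision=1.0)

var_near = predictive_variance([1.0, 1.0], weight_var)
var_far = predictive_variance([3.0, 3.0], weight_var)
```

In this toy setting, inputs whose features have a larger magnitude receive a larger predictive variance.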

Usage

1. Guided Diffusion

cd ddpm_and_guided

Download pre-trained model checkpoint

cd configs
vim imagenet128_guided.yml

Download data to fit last-layer Laplace (LLLA)

  • Please download ImageNet to your_local_image_path.
  • Change the self.image_path attribute of class imagenet_dataset in la_train_datasets.py to your_local_image_path.
vim la_train_datasets.py
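Before editing la_train_datasets.py, it can help to confirm that your_local_image_path actually contains images. The helper below is a minimal sketch for that check. The function name and the set of accepted file suffixes are assumptions, and it does not reproduce the repository's dataset class.

```python
from pathlib import Path
from typing import List

# Suffixes treated as images here; this set is an assumption.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def list_images(image_path: str) -> List[Path]:
    """Return all image files found under image_path, sorted by path."""
    root = Path(image_path)
    if not root.is_dir():
        raise FileNotFoundError(f"image path is not a directory: {image_path}")
    return sorted(
        p
        for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
```

For example, `len(list_images("your_local_image_path"))` should be greater than zero before the path is written into self.image_path.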

Sample and estimate corresponding pixel-wise uncertainty

In the file dpm.sh, you will find a usage template for the UQ-integrated DPM-Solver-2 sampler. Running this bash script produces sorted_sample.png, in which the samples are sorted by the image-wise uncertainty metric.

bash dpm.sh

For other samplers, just change dpm.sh to ddpm.sh or ddim.sh.
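The sorting behind sorted_sample.png can be pictured as reducing each pixel-wise uncertainty map to one score per image and ordering the samples by that score. The sketch below assumes the score is the sum of the pixel-wise variances and that samples are ordered from lowest to highest. Both choices, and the function names, are assumptions rather than the scripts' documented behavior.

```python
from typing import List, Sequence

VarianceMap = Sequence[Sequence[float]]  # one variance value per pixel


def image_wise_uncertainty(variance_map: VarianceMap) -> float:
    """Collapse a pixel-wise variance map into a single score (assumed: sum)."""
    return sum(sum(row) for row in variance_map)


def sort_by_uncertainty(variance_maps: Sequence[VarianceMap]) -> List[int]:
    """Return sample indices ordered from lowest to highest score."""
    scores = [image_wise_uncertainty(m) for m in variance_maps]
    return sorted(range(len(scores)), key=lambda i: scores[i])


# Three toy 2x2 variance maps (illustrative values).
maps = [
    [[0.2, 0.1], [0.3, 0.4]],
    [[0.0, 0.1], [0.0, 0.1]],
    [[0.5, 0.5], [0.5, 0.5]],
]
order = sort_by_uncertainty(maps)
```

Here `order` lists the sample with the smallest summed variance first.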

2. Stable Diffusion

cd sd

Download pre-trained model checkpoint

Download Stable Diffusion v1.5 to your_local_model_path

Download data to fit last-layer Laplace (LLLA)

Please download a subset of LAION-Art images to your_local_image_path. These images are a subset of the LAION-Art dataset; store that dataset in your_laion_art_path so that you can retrieve the corresponding prompts for the downloaded images. Note that a subset of approximately 1000 images is sufficient for fitting the LLLA.
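One way to pick such a subset is to sample image names reproducibly from whatever prompt metadata you have. The sketch below assumes the metadata has already been loaded into a dictionary mapping file names to prompts. That dictionary, the function name, and the toy data are hypothetical, and the real LAION-Art metadata format is not shown here.

```python
import random
from typing import Dict, List, Tuple


def pick_llla_subset(
    prompts_by_file: Dict[str, str], subset_size: int = 1000, seed: int = 0
) -> List[Tuple[str, str]]:
    """Pick a reproducible random subset of (file name, prompt) pairs."""
    # Keep only images that have a non-empty prompt.
    candidates = sorted(
        name for name, prompt in prompts_by_file.items() if prompt.strip()
    )
    rng = random.Random(seed)
    chosen = rng.sample(candidates, min(subset_size, len(candidates)))
    return [(name, prompts_by_file[name]) for name in sorted(chosen)]


# Toy metadata standing in for real records (hypothetical names and prompts).
toy_metadata = {f"img_{i:05d}.jpg": f"prompt {i}" for i in range(5000)}
toy_metadata["img_00000.jpg"] = ""  # an image without a usable prompt
subset = pick_llla_subset(toy_metadata, subset_size=1000, seed=0)
```

Fixing the seed keeps the selected images and prompts identical across runs.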

Sample and estimate corresponding pixel-wise uncertainty

In the file sd.sh, you will find a template for usage. Please adjust this template to match your local file path and the specific prompt you intend to use.

bash sd.sh

3. U-ViT

cd uvit

Download pre-trained model checkpoint

  • Download the autoencoder checkpoint from this link, which contains image autoencoders converted from Stable Diffusion, to your_local_encoder_path.
  • Download the ImageNet 256x256 (U-ViT-H/2) checkpoint to your_local_uvit_path.

Download data to fit last-layer Laplace (LLLA)

  • Please download ImageNet to your_local_image_path.
  • Change the self.image_path attribute of class imagenet_feature_dataset in la_train_datasets.py to your_local_image_path.
vim la_train_datasets.py

Sample and estimate corresponding pixel-wise uncertainty

In the file dpm.sh, you will find a template for usage. Please adjust this template to match your local file path.

bash dpm.sh

Acknowledgements

This codebase is based on remarkable projects from the community, including DPM-Solver, U-ViT, and Stable Diffusion.

Citation

If you find our work useful, please cite our paper:

@inproceedings{kou2023bayesdiff,
  title={BayesDiff: Estimating Pixel-wise Uncertainty in Diffusion via Bayesian Inference},
  author={Kou, Siqi and Gan, Lei and Wang, Dequan and Li, Chongxuan and Deng, Zhijie},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}
