
Poisson Flow Generative Models

PyTorch implementation of the NeurIPS 2022 paper Poisson Flow Generative Models,

by Yilun Xu*, Ziming Liu*, Max Tegmark, Tommi S. Jaakkola

We propose a new Poisson flow generative model (PFGM) that maps a uniform distribution on a high-dimensional hemisphere into any data distribution. We interpret the data points as electrical charges on the $z=0$ hyperplane in a space augmented with an additional dimension $z$, generating a high-dimensional electric field (the gradient of the solution to the Poisson equation). We prove that if these charges flow upward along electric field lines, their initial distribution in the $z=0$ plane transforms into a distribution on the hemisphere of radius $r$ that becomes uniform in the $r \to \infty$ limit. To learn the bijective transformation, we estimate the normalized field in the augmented space. For sampling, we devise a backward ODE that is anchored by the physically meaningful additional dimension: the samples hit the (unaugmented) data manifold when $z$ reaches zero.
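
The field and the backward ODE described above can be illustrated on toy data by treating a batch of data points as charges. The following is a minimal NumPy sketch, not the repository's PyTorch implementation. It assumes data of dimension $N$, so that in the augmented $(N+1)$-dimensional space each point charge contributes a field proportional to $(\tilde{x} - \tilde{y})/\|\tilde{x} - \tilde{y}\|^{N+1}$, and it integrates the backward ODE $dx/dz = E_x / E_z$ (as we read it from the paper) from $z_{max}$ down toward zero with fixed-step forward Euler or improved Euler (Heun) updates. All function names are hypothetical. Querying a small batch of points against a larger batch of charges is, we assume, the role of the training.small_batch_size / training.batch_size split described under Usage.

```python
import numpy as np


def empirical_poisson_field(x_aug, charges):
    """Normalized Poisson field at augmented points, from a finite set of charges.

    x_aug:   (M, N+1) query points (x, z) with z > 0.
    charges: (B, N) data points, placed on the z = 0 hyperplane.
    Returns an (M, N+1) array of unit vectors.
    """
    n = charges.shape[1]
    charges_aug = np.concatenate([charges, np.zeros((len(charges), 1))], axis=1)
    diff = x_aug[:, None, :] - charges_aug[None, :, :]     # (M, B, N+1)
    dist = np.linalg.norm(diff, axis=-1, keepdims=True)    # (M, B, 1)
    # Each charge contributes diff / dist**(N+1). Work with log-weights and
    # subtract the per-query maximum so that large N does not underflow; the
    # common scale factor cancels when the field is normalized below.
    log_w = -(n + 1) * np.log(dist)
    log_w -= log_w.max(axis=1, keepdims=True)
    field = (diff * np.exp(log_w)).sum(axis=1)             # (M, N+1)
    return field / np.linalg.norm(field, axis=-1, keepdims=True)


def dx_dz(x, z, charges):
    """Right-hand side of the backward ODE, dx/dz = E_x / E_z, at height z."""
    x_aug = np.concatenate([x, np.full((len(x), 1), z)], axis=1)
    e = empirical_poisson_field(x_aug, charges)
    return e[:, :-1] / e[:, -1:]


def sample_backward_ode(x_start, z_max, charges, n_steps=100, z_min=1e-3,
                        solver="forward_euler"):
    """Integrate x from z = z_max down to z = z_min with a fixed-step solver."""
    zs = np.linspace(z_max, z_min, n_steps + 1)
    x = np.array(x_start, dtype=float)
    for z, z_next in zip(zs[:-1], zs[1:]):
        h = z_next - z                         # negative: z decreases
        k1 = dx_dz(x, z, charges)
        if solver == "forward_euler":
            x = x + h * k1
        elif solver == "improved_euler":       # Heun's method
            k2 = dx_dz(x + h * k1, z_next, charges)
            x = x + 0.5 * h * (k1 + k2)
        else:
            raise ValueError(f"unknown solver: {solver}")
    return x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = rng.standard_normal((256, 2))        # toy 2-D "dataset" used as charges
    start = 5.0 * rng.standard_normal((8, 2))   # arbitrary starting points at z = z_max
    samples = sample_backward_ode(start, z_max=5.0, charges=data,
                                  solver="improved_euler")
    print(samples.shape)                        # (8, 2)
```

Because every charge sits at $z=0$ and queries have $z>0$, $E_z$ stays positive, so the ratio $E_x/E_z$ is always defined along the trajectory.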

[Figure: PFGM schematic]

Experimentally, PFGM achieves current state-of-the-art performance among normalizing flow models on CIFAR-10, with an Inception score of 9.68 and an FID score of 2.35. It also performs on par with state-of-the-art SDE approaches (e.g., score-based SDEs or diffusion models) while offering 10x to 20x acceleration on image generation tasks. Additionally, PFGM appears more tolerant of estimation errors on weaker network architectures, more robust to the step size in the Euler method, and capable of scaling up to higher-resolution datasets.


Acknowledgement: Our implementation relies on the repo https://github.com/yang-song/score_sde_pytorch.

Dependencies

We provide two requirement files, each installing a subset of the Python packages needed by our code. Please choose the one that best fits your setup.

  1. The old dependencies from the repo https://github.com/yang-song/score_sde_pytorch:
pip install -r requirements_old.txt
  2. Our dependencies (Python 3.9.12, CUDA 11.6):
pip install -r requirements.txt

Usage

Train and evaluate our models through main.py.

python3 main.py:
  --config: Training configuration.
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory

For example, to train a new PFGM w/ DDPM++ model on the CIFAR-10 dataset, run

python3 main.py --config ./configs/poisson/cifar10_ddpmpp.py --mode train \
--workdir poisson_ddpmpp
  • config is the path to the config file. The prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the path of a config file is a combination of the following dimensions:

    • Method: 🌟PFGM: poisson; Score-based models: ve, vp, sub_vp
    • dataset: One of cifar10, celeba64, celebahq_256, ffhq_256, celebahq, ffhq.
    • model: One of ncsnpp, ddpmpp.
    • continuous: train the model with continuously sampled time steps (only for score-based models).

    🌟Important Note 1: We use a large batch (e.g., the current training.batch_size=4096 for CIFAR-10, ~25 GB of GPU memory) to calculate the Poisson field for the samples in each mini-batch (e.g., training.small_batch_size=128 for CIFAR-10). To adjust the GPU memory cost, please modify the training.batch_size parameter in the config files.

    🌟Important Note 2: If the rk45 solver exhibits instability for your dataset/neural network, please try the forward Euler method or the Improved Euler method by setting the config.sampling.ode_solver parameter to forward_euler or improved_euler.

    To set key hyper-parameters for a specific dataset, run

    python3 hyper-parameters.py 
    	--data_norm: Average data norm of the dataset 
    	--data_dim: Data dimension

    We also list a few other useful tips in the Tips section.

  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts training a new model, or resumes training an existing model if its meta-checkpoints (used to resume after pre-emption in a cloud environment) exist in workdir/checkpoints-meta.

  • Below is the list of evaluation command-line flags:

    --config.eval.enable_sampling: Generate samples and evaluate sample quality, measured by FID and Inception score.

    --config.eval.enable_bpd: Compute log-likelihoods.

    --config.eval.dataset=train/test: Indicate whether to compute the likelihoods on the training or test dataset.

    --config.eval.enable_interpolate: Image interpolation.

    --config.eval.enable_rescale: Temperature scaling.
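
The hyper-parameters.py script described above takes the dataset's average data norm (--data_norm) and its dimension (--data_dim). A minimal NumPy sketch for computing both on an image dataset follows. It assumes pixel values in [0, 255] that are rescaled to [-1, 1] before training; this scaling is an assumption, so match whatever preprocessing your config actually applies. The function name is hypothetical.

```python
import numpy as np


def average_data_norm(images):
    """Return (data_norm, data_dim) for an array of uint8 images.

    images: array of shape (num_images, ...) with pixel values in [0, 255].
    Assumes training rescales pixels to [-1, 1]; adjust `scaled` otherwise.
    """
    flat = images.reshape(images.shape[0], -1).astype(np.float64)
    scaled = flat / 255.0 * 2.0 - 1.0        # assumed [-1, 1] scaling
    norms = np.linalg.norm(scaled, axis=1)   # L2 norm of each flattened image
    return float(norms.mean()), flat.shape[1]
```

For CIFAR-10-sized images (32×32×3), data_dim is 3072. If the full dataset does not fit in memory, averaging over a random subset gives an estimate of data_norm; both values are then passed to python3 hyper-parameters.py via --data_norm and --data_dim.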

Tips

  • 🌟Important: We use a large batch (e.g., the current training.batch_size=4096 for CIFAR-10, ~25 GB of GPU memory) to calculate the Poisson field for the samples in each mini-batch (e.g., training.small_batch_size=128 for CIFAR-10). To adjust the GPU memory cost, please modify the training.batch_size parameter in the config files.

  • 🌟 How to set the hyper-parameters: The prior distribution on the $z=z_{max}$ hyperplane is long-tailed. We recommend clipping the sample norm with the hyper-parameter sampling.upper_norm. Please refer to Appendix B.1.1 and Appendix B.2.1 of the paper (https://arxiv.org/abs/2209.11178) for our recommended settings of the hyper-parameters training.M, sampling.z_max and sampling.upper_norm for general datasets.

    We provide a script to calculate these hyper-parameters easily:

    python3 hyper-parameters.py 
    	--data_norm: Average data norm of the dataset 
    	--data_dim: Data dimension
    
  • If the rk45 solver exhibits instability for your dataset/neural network, please try the forward Euler method or the Improved Euler method by setting the config.sampling.ode_solver parameter to forward_euler or improved_euler.

  • TODO
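
To illustrate the tip on the long-tailed prior and sampling.upper_norm, here is a minimal NumPy sketch of drawing starting points on the $z=z_{max}$ hyperplane. It assumes, following our reading of the paper, that the prior is the radial projection of a uniform distribution on the upper unit hemisphere onto that hyperplane, which gives a density on the hyperplane proportional to $(\|x\|^2 + z_{max}^2)^{-(N+1)/2}$. Clipping by rescaling oversized samples down to norm upper_norm is also an assumption and may differ from what the repository does. The function name sample_prior is hypothetical.

```python
import numpy as np


def sample_prior(num_samples, data_dim, z_max, upper_norm, rng=None):
    """Draw (x, z_max) starting points for the backward ODE.

    Samples a direction uniformly on the upper unit hemisphere in data_dim + 1
    dimensions, projects it radially onto the plane z = z_max, then rescales
    any sample whose x-part has norm above upper_norm down to upper_norm.
    """
    rng = np.random.default_rng() if rng is None else rng
    # A normalized Gaussian vector is uniform on the unit sphere.
    u = rng.standard_normal((num_samples, data_dim + 1))
    u /= np.linalg.norm(u, axis=1, keepdims=True)
    u[:, -1] = np.abs(u[:, -1])              # fold onto the upper hemisphere (z > 0)
    # Radial projection onto z = z_max: directions with a tiny z-component land
    # very far out, which produces the long tail.
    x = u[:, :-1] * (z_max / u[:, -1:])
    # Clip the norm by rescaling (assumed clipping rule).
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    x *= np.minimum(1.0, upper_norm / norms)
    z = np.full((num_samples, 1), float(z_max))
    return np.concatenate([x, z], axis=1)
```

For example, sample_prior(16, 3072, z_max, upper_norm) would produce CIFAR-10-sized starting points, with z_max and upper_norm taken from the values suggested by hyper-parameters.py. Without the clipping step, the rare samples whose direction is nearly parallel to the hyperplane would start the ODE far from the data.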

Checkpoints

Please place the pretrained checkpoints under the directory workdir/checkpoints, e.g., cifar10_ddpmpp/checkpoints.

To generate 10k samples from the PFGM w/ DDPM++ model and evaluate their FID/IS, you could execute:

python3 main.py --config ./configs/poisson/cifar10_ddpmpp.py --mode eval \
--workdir cifar10_ddpmpp --config.eval.enable_sampling --config.eval.num_samples 10000

To only generate and visualize 100 samples of the PFGM w/ DDPM++ model, you could execute:

python3 main.py --config ./configs/poisson/cifar10_ddpmpp.py --mode eval \
--workdir cifar10_ddpmpp --config.eval.enable_sampling --config.eval.save_images --config.eval.batch_size 100

The samples will be saved to cifar10_ddpmpp/eval/ode_images_{ckpt}.png.

All checkpoints are provided in this Google drive folder.

| Dataset | Checkpoint path | Invertible? | IS | FID | NFE (RK45) |
|---|---|---|---|---|---|
| CIFAR-10 | poisson/cifar10_ddpmpp/ | ✔️ | 9.65 | 2.48 | ~104 |
| CIFAR-10 | poisson/cifar10_ddpmpp_deep/ | ✔️ | 9.68 | 2.35 | ~110 |
| LSUN bedroom $256^2$ | poisson/bedroom_ddpmpp/ | ✔️ | - | 13.66 | ~122 |
| CelebA $64^2$ | poisson/celeba_ddpmpp/ | ✔️ | - | 3.68 | ~110 |

FID statistics

Please find the statistics for FID scores in the following links:

CIFAR-10, CelebA 64, LSUN bedroom 256

