
Fast Training of Diffusion Models with Masked Transformers

Official implementation of the paper Fast Training of Diffusion Models with Masked Transformers

Abstract: We propose an efficient approach to train large diffusion models with masked transformers. While masked transformers have been extensively explored for representation learning, their application to generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (e.g., 50%) of patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches. Experiments on ImageNet-256x256 show that our approach achieves the same performance as the state-of-the-art Diffusion Transformer (DiT) model, using only 31% of its original training time. Thus, our method allows for efficient training of diffusion models without sacrificing the generative performance.
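The random patch masking described above can be sketched in a few lines. This is an illustrative NumPy sketch, not the repository's actual implementation; names and shapes are made up for the example:

```python
import numpy as np

def mask_patches(latent_patches, mask_ratio=0.5, rng=None):
    """Randomly hide a fraction of patches; the encoder sees only the rest.

    latent_patches: array of shape (num_patches, patch_dim)
    Returns (visible_patches, visible_idx, masked_idx).
    """
    if rng is None:
        rng = np.random.default_rng(0)
    n = latent_patches.shape[0]
    num_masked = int(n * mask_ratio)
    perm = rng.permutation(n)          # random split of patch indices
    masked_idx = perm[:num_masked]     # reconstructed by the auxiliary MAE task
    visible_idx = perm[num_masked:]    # processed by the transformer encoder
    return latent_patches[visible_idx], visible_idx, masked_idx

# Example: 256 patches at 50% mask ratio -> the encoder processes only 128.
patches = np.random.randn(256, 768)
visible, vis_idx, msk_idx = mask_patches(patches, mask_ratio=0.5)
```

Because the encoder never touches the masked half, its per-step compute roughly halves, which is where the training-cost savings come from.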

Architecture

Training efficiency

MaskDiT uses Automatic Mixed Precision (AMP) by default. We also include a MaskDiT variant trained without AMP (Ours_ft32) for reference.

Requirements

  • We recommend training MaskDiT on 8 A100 GPUs; 2M updates with a batch size of 1024 take around 260 hours.
  • Sampling requires at least one high-end GPU.
  • A Dockerfile is provided to reproduce the exact software environment.
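For a rough sense of scale, the figures above imply the following back-of-the-envelope throughput (an estimate derived from the quoted numbers, not a measured benchmark):

```python
# Numbers quoted in the requirements above.
updates = 2_000_000
batch_size = 1024
hours = 260
gpus = 8

images_seen = updates * batch_size              # total training images processed
imgs_per_sec = images_seen / (hours * 3600)     # aggregate throughput over 8 GPUs
imgs_per_sec_per_gpu = imgs_per_sec / gpus      # per-GPU throughput
```

That is roughly 2 billion image presentations in total, or on the order of 2,200 images/s across the 8 GPUs.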

Prepare dataset

We first encode the ImageNet dataset into latent space with a pre-trained VAE. You can download the pre-trained VAE with download_assets.py:

python3 download_assets.py --name vae --dest assets

Alternatively, you can directly download the pre-encoded dataset we have prepared by running

python3 download_assets.py --name imagenet-latent-data --dest [destination directory]
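Encoding into latent space shrinks the training data substantially. Assuming the commonly used Stable Diffusion VAE (8x spatial downsampling, 4 latent channels; check the repository's config for the exact model), a 256x256 RGB image maps to a 4x32x32 latent:

```python
H = W = 256      # input image resolution
down = 8         # VAE spatial downsampling factor (assumed)
latent_ch = 4    # VAE latent channels (assumed)

latent_shape = (latent_ch, H // down, W // down)    # (4, 32, 32)
pixel_elems = 3 * H * W                             # elements per RGB image
latent_elems = latent_ch * (H // down) * (W // down)
compression = pixel_elems / latent_elems            # element-count reduction
```

So the diffusion transformer operates on 48x fewer elements per image than it would in pixel space.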

Train

We first train MaskDiT with a 50% mask ratio and AMP enabled.

python3 train_latent.py --config configs/train/maskdit-latent-imagenet.yaml --num_process_per_node 8

We then finetune with unmasking, i.e., training on full patches. For example:

python3 train_latent.py --config configs/finetune/maskdit-latent-imagenet-const.yaml --ckpt_path [path to checkpoint] --use_ckpt_path False --use_strict_load False --no_amp
Train on the original ImageNet

We also provide code in train.py for training MaskDiT without a pre-encoded dataset. It is included for reference only and has not been fully tested. After preparing the original ImageNet dataset, run

python3 train.py --config configs/train/maskdit-imagenet.yaml --num_process_per_node 8

Generate samples

To generate samples from the provided checkpoints, run, for example:

python3 generate.py --config configs/train/maskdit-latent-imagenet.yaml --ckpt_path results/2075000.pt --class_idx 388 --cfg_scale 2.5
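The --cfg_scale flag controls classifier-free guidance (CFG). Conceptually, using the standard formulation (not necessarily the repository's exact code), the unconditional and class-conditional predictions are combined as:

```python
import numpy as np

def cfg_combine(eps_uncond, eps_cond, scale):
    """Classifier-free guidance: extrapolate from the unconditional toward
    the conditional prediction by `scale`."""
    return eps_uncond + scale * (eps_cond - eps_uncond)

# scale = 1 recovers the purely conditional prediction (no guidance);
# scale > 1 pushes samples further toward the class condition.
e_u = np.zeros(4)
e_c = np.ones(4)
guided = cfg_combine(e_u, e_c, scale=2.5)
```

Larger scales typically trade sample diversity for fidelity to the class label.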

MaskDiT checkpoints can be downloaded with download_assets.py. For example,

python3 download_assets.py --name maskdit-finetune0 --dest results

We provide the following checkpoints.

Generated samples from MaskDiT. Upper panel: without CFG. Lower panel: with CFG (scale=1.5).

Evaluation

First, download the reference batch directly from the ADM repo. You can also use download_assets.py by running

python3 download_assets.py --name imagenet256 --dest [destination directory]

Then evaluate the generated samples with evaluator.py from the ADM repo or fid.py from the EDM repo.
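Both evaluators report FID, which fits a Gaussian to Inception features of the generated and reference sets and measures the Frechet distance between the two fits. A NumPy-only sketch of that distance (using the symmetric form with a PSD matrix square root; the official evaluators differ in detail):

```python
import numpy as np

def psd_sqrt(mat):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return vecs @ np.diag(np.sqrt(vals)) @ vecs.T

def fid(mu1, cov1, mu2, cov2):
    """Frechet distance between Gaussians N(mu1, cov1) and N(mu2, cov2):
    ||mu1 - mu2||^2 + Tr(cov1 + cov2 - 2 (cov1^{1/2} cov2 cov1^{1/2})^{1/2})."""
    s1 = psd_sqrt(cov1)
    middle = psd_sqrt(s1 @ cov2 @ s1)
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(cov1 + cov2 - 2.0 * middle))

# Identical feature statistics give FID 0.
mu, cov = np.zeros(3), np.eye(3)
```

In practice the statistics come from Inception-v3 activations over 50k generated samples versus the reference batch.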

Generative performance on ImageNet-256x256. The area of each bubble indicates the FLOPs for a single forward pass during training.

Acknowledgements

Thanks to the open-source codebases DiT, MAE, U-ViT, ADM, and EDM, on which our codebase is built.
