
rq-vae-transformer's Introduction

Autoregressive Image Generation using Residual Quantization (CVPR 2022)

The official implementation of "Autoregressive Image Generation using Residual Quantization"
Doyup Lee*, Chiheon Kim*, Saehoon Kim, Minsu Cho, Wook-Shin Han (* Equal contribution)
CVPR 2022

Examples of images generated by RQ-Transformer using class conditions and text conditions.
Note that the text conditions of these examples were not used at training time.

TL;DR For autoregressive (AR) modeling of high-resolution images, we propose a two-stage framework consisting of RQ-VAE and RQ-Transformer. Our framework precisely approximates the feature map of an image and represents the image as a stack of discrete codes, enabling effective generation of high-quality images.
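The core idea of RQ-VAE is to quantize the residual that remains after each quantization step, so that a few codes per position describe a feature vector with increasing precision. Below is a minimal illustrative sketch of residual quantization of a single feature vector with a shared codebook; the function and variable names are hypothetical and do not mirror this repository's API.

# Minimal sketch of residual quantization (illustration only; names are hypothetical).
import torch

def residual_quantize(feature, codebook, depth=4):
    # feature: (C,) feature vector, codebook: (K, C) code embeddings shared across depths.
    residual = feature.clone()
    quantized = torch.zeros_like(feature)
    codes = []
    for _ in range(depth):
        # pick the code closest to the current residual
        dists = torch.cdist(residual.unsqueeze(0), codebook).squeeze(0)
        idx = torch.argmin(dists)
        codes.append(idx.item())
        quantized = quantized + codebook[idx]
        residual = residual - codebook[idx]  # quantize whatever error remains
    return codes, quantized

# usage: approximate a 256-dim feature with a stack of 4 codes from a 16384-entry codebook
codebook = torch.randn(16384, 256)
feature = torch.randn(256)
codes, approx = residual_quantize(feature, codebook)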

Requirements

We have tested our code in the environment below.

  • Python 3.7.10 / PyTorch 1.9.0 / torchvision 0.10.0 / CUDA 11.1 / Ubuntu 18.04

Please run the following command to install the necessary dependencies

pip install -r requirements.txt

Coverage of Released Codes

  • Implementation of RQ-VAE and RQ-Transformer
  • Pretrained checkpoints of RQ-VAEs and RQ-Transformers
  • Training and evaluation pipelines of RQ-VAE
  • Image generation and evaluation pipelines of RQ-Transformer
  • Jupyter notebook for text-to-image generation of RQ-Transformer

Pretrained Checkpoints

Checkpoints Used in the Original Paper

We provide pretrained checkpoints of RQ-VAEs and RQ-Transformers to reproduce the results in the paper. Please use the links below to download tar.gz files and unzip the pretrained checkpoints. Each link contains pretrained checkpoints of RQ-VAE and RQ-Transformer and their model configurations.

Dataset | RQ-VAE & RQ-Transformer | # params of RQ-Transformer | FID
--- | --- | --- | ---
FFHQ | link | 355M | 10.38
LSUN-Church | link | 370M | 7.45
LSUN-Cat | link | 612M | 8.64
LSUN-Bedroom | link | 612M | 3.04
ImageNet (cIN) | link | 480M | 15.72
ImageNet (cIN) | link | 821M | 13.11
ImageNet (cIN) | link | 1.4B | 11.56 (4.45)
ImageNet (cIN) | link | 1.4B | 8.71 (3.89)
ImageNet (cIN) | link | 3.8B | 7.55 (3.80)
CC-3M | link | 654M | 12.33

FID scores above are measured between the original samples and the generated images; the scores in parentheses are measured with 5% rejection sampling via a pretrained ResNet-101. We do not provide the rejection-sampling pipeline in this repository; a rough illustrative sketch follows.
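The sketch below keeps the 95% of generated images that a pretrained classifier scores highest for the conditioning class and discards the rest; it is an assumption-laden illustration, not the authors' pipeline.

# Rough sketch of class-conditional rejection sampling (not the authors' pipeline).
# `classifier` is assumed to be a pretrained ResNet-101 returning class logits.
import torch

@torch.no_grad()
def reject_samples(images, class_label, classifier, keep_ratio=0.95):
    # keep the images with the highest classifier probability for the conditioning class
    probs = classifier(images).softmax(dim=-1)[:, class_label]
    n_keep = int(len(images) * keep_ratio)
    keep_idx = probs.topk(n_keep).indices
    return images[keep_idx]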

(NOTE) Large-Scale RQ-Transformer for Text-to-Image Generation

We also provide a pretrained checkpoint of a large-scale RQ-Transformer for text-to-image (T2I) generation. Our paper does not include results for this large-scale RQ-Transformer, since we trained it (3.9B parameters) on about 30 million text-image pairs from CC-3M, CC-12M, and the YFCC-subset after the paper submission. Please use the link below to download the checkpoint of the large-scale T2I model. We emphasize that any commercial use of our checkpoints is strictly prohibited.

Download of Pretrained RQ-Transformer on 30M text-image pairs

Dataset | RQ-VAE & RQ-Transformer | # params
--- | --- | ---
CC-3M + CC-12M + YFCC-subset | link | 3.9B

Evaluation of Large-Scale RQ-Transformer on MS-COCO

In this repository, we evaluate the pretrained RQ-Transformer with 3.9B parameters on MS-COCO. Following the evaluation protocol of DALL-Eval, we randomly select 30K text captions from the val2014 split of MS-COCO and generate 256x256 images from the selected captions. We use (1024, 0.95) for top-(k, p) sampling (a sketch of this filtering is given after the table below); the FID scores of the other models are taken from Table 2 of the DALL-Eval paper.

Model | # params | # data | Image / Grid Size | FID on 2014val
--- | --- | --- | --- | ---
X-LXMERT | 228M | 180K | 256x256 / 8x8 | 37.4
DALL-E small | 120M | 15M | 256x256 / 16x16 | 45.8
ruDALL-E-XL | 1.3B | 120M | 256x256 / 32x32 | 18.6
minDALL-E | 1.3B | 15M | 256x256 / 16x16 | 24.6
RQ-Transformer (ours) | 3.9B | 30M | 256x256 / 8x8x4 | 16.9

Note that some text captions in MS-COCO also appear in the YFCC-subset, but the FIDs differ little whether or not the duplicated captions are removed from the evaluation. See this paper for more details.
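For reference, the top-(k, p) = (1024, 0.95) setting above means that, at each step, the logits are restricted to the 1024 highest-scoring codes and then further to the smallest set whose cumulative probability exceeds 0.95. A generic sketch of this filtering (not this repository's implementation):

# Minimal sketch of combined top-k / top-p (nucleus) filtering of logits.
import torch

def top_k_top_p_filter(logits, top_k=1024, top_p=0.95):
    # mask out everything outside the top-k logits
    if top_k > 0:
        kth = torch.topk(logits, min(top_k, logits.size(-1))).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    # mask out everything outside the smallest set with cumulative probability > top_p
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token crossing the threshold
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float('-inf'))
    return logits

# sample the next code from the filtered distribution, e.g.
# next_code = torch.multinomial(top_k_top_p_filter(logits).softmax(dim=-1), 1)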

Examples of Text-to-Image (T2I) Generation using RQ-Transformer

We provide a Jupyter notebook so that you can easily try text-to-image (T2I) generation with the pretrained RQ-Transformers and inspect the results. After downloading the pretrained checkpoints for T2I generation, open notebooks/T2I_sampling.ipynb and follow the instructions in the notebook. Considering the model size, we recommend a GPU with at least 32GB of memory, such as an NVIDIA V100 or A100.

We attach some examples of T2I generation from the provided Jupyter notebook.

Examples of Generated Images from Text Conditions

a painting by Vincent Van Gogh
a painting by RENÉ MAGRITTE
Eiffel tower on a desert.
Eiffel tower on a mountain.
a painting of a cat with sunglasses in the frame.
a painting of a dog with sunglasses in the frame.

Training and Evaluation of RQ-VAE

Training of RQ-VAEs

Our implementation uses DistributedDataParallel in PyTorch for efficient training in multi-node and multi-GPU environments. Four NVIDIA A100 GPUs are used to train all RQ-VAEs in our paper. You can adjust --nnodes, --nproc_per_node, and --node_rank according to your GPU environment (see also the note on torchrun after the multi-node example below).

  • Training 8x8x4 RQ-VAE on ImageNet 256x256 with a single node having four GPUs

    python -m torch.distributed.launch \
        --master_addr=$MASTER_ADDR \
        --master_port=$PORT \
        --nnodes=1 --nproc_per_node=4 --node_rank=0 \
        main_stage1.py \
        -m=configs/imagenet256/stage1/in256-rqvae-8x8x4.yaml -r=$SAVE_DIR
  • If you want to train 8x8x4 RQ-VAE on ImageNet using four nodes, where each node has one GPU, run the following scripts at each node with $RANK being the node rank (0, 1, 2, 3). Here, we assume that the master node corresponds to the node with rank 0.

    python -m torch.distributed.launch \
        --master_addr=$MASTER_ADDR \
        --master_port=$PORT \
        --nnodes=4 --nproc_per_node=1 --node_rank=$RANK \
        main_stage1.py \
        -m=configs/imagenet256/stage1/in256-rqvae-8x8x4.yaml -r=$SAVE_DIR
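  • (Optional) On newer PyTorch versions, torch.distributed.launch is deprecated in favor of torchrun. An equivalent torchrun invocation should look roughly like the command below; note that torchrun passes the local rank via the LOCAL_RANK environment variable rather than a --local_rank argument, so the entry script may need a small adjustment (the environment above was tested with torch.distributed.launch on PyTorch 1.9.0).

    torchrun \
        --master_addr=$MASTER_ADDR \
        --master_port=$PORT \
        --nnodes=4 --nproc_per_node=1 --node_rank=$RANK \
        main_stage1.py \
        -m=configs/imagenet256/stage1/in256-rqvae-8x8x4.yaml -r=$SAVE_DIR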

Finetuning of Pretrained RQ-VAE

  • To finetune a pretrained RQ-VAE on other datasets such as the LSUNs, you have to load the pretrained checkpoint by giving the -l=$RQVAE_CKPT argument.
  • For example, when a pretrained RQ-VAE is finetuned on LSUN-Church, you can run the command below:
    python -m torch.distributed.launch \
        --master_addr=$MASTER_ADDR \
        --master_port=$PORT \
        --nnodes=1 --nproc_per_node=4 --node_rank=0 \
        main_stage1.py \
        -m=configs/lsun-church/stage1/church256-rqvae-8x8x4.yaml -r=$SAVE_DIR -l=$RQVAE_CKPT 

Evaluation of RQ-VAEs

Run compute_rfid.py to evaluate the reconstruction FID (rFID) of learned RQ-VAEs.

python compute_rfid.py --split=val --vqvae=$RQVAE_CKPT
  • The model checkpoint of RQ-VAE and its configuration yaml file have to be located in the same directory.
  • compute_rfid.py evaluates rFID of RQ-VAE on the dataset in the configuration file.
  • Adjust --batch-size according to the GPU memory of your environment.

Evaluation of RQ-Transformer

The quantitative results in the paper can be reproduced with the evaluation code for RQ-Transformer in this repository. Before evaluating RQ-Transformer on a dataset, the dataset has to be prepared so that the feature vectors of its samples can be computed. Since extracting feature vectors is computationally costly and time-consuming, we provide the feature statistics of each dataset to reproduce the results in the paper. You can also prepare the datasets used in our paper by following the instructions in data/README.md.

  • Download the feature statistics of datasets as follows:
    cd assets
    wget https://twg.kakaocdn.net/brainrepo/etc/RQVAE/8b325b628f49bf60a3094fcf9419398c/fid_stats.tar.gz
    tar -zxvf fid_stats.tar.gz
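For reference, FID is the Fréchet distance between Gaussians fitted to Inception features of the reference and generated images, so it only needs the feature means and covariances such as those provided above. A minimal generic sketch (the exact on-disk format of the fid_stats files is not assumed here):

# Minimal sketch of the FID computation from precomputed feature statistics.
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

# usage (hypothetical variable names):
# fid = frechet_distance(mu_ref, sigma_ref, mu_gen, sigma_gen)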

FFHQ, LSUN-{Church, Bedroom, Cat}, (conditional) ImageNet

  • After the pretrained RQ-Transformer generates 50K images, FID (and IS) between the generated images and the training samples is computed.
  • You can pass --save-dir to specify the directory where the generated images are saved. If --save-dir is not given, the generated images are saved in the directory of the checkpoint.
  • When four GPUs in a single node are used, run the command below
    python -m torch.distributed.launch \
      --master_addr=$MASTER_ADDR \
      --master_port=$PORT \
      --nnodes=1 --nproc_per_node=4 --node_rank=0 \
      main_sampling_fid.py \
      -v=$RQVAE_CKPT -a=$RQTRANSFORMER_CKPT --save-dir=$SAVE_IMG_DIR

CC-3M

  • After the pretrained RQ-Transformer generates images from the text captions of the CC-3M validation set, FID between the validation images and the generated images is computed, together with the CLIP score between the generated images and their text conditions (a rough sketch of the CLIP score is given after the command below).
  • Evaluating RQ-Transformer requires the text prompts of CC-3M. Please refer to data/README.md and prepare the dataset first.
  • When four GPUs in a single node are used, run the command below
    python -m torch.distributed.launch \
      --master_addr=$MASTER_ADDR \
      --master_port=$PORT \
      --nnodes=1 --nproc_per_node=4 --node_rank=0 \
      main_sampling_txt2img.py \
      -v=$RQVAE_CKPT -a=$RQTRANSFORMER_CKPT --dataset="cc3m" --save-dir=$SAVE_IMG_DIR

MS-COCO

  • We follow the protocol of DALL-Eval to evaluate RQ-Transformer on MS-COCO: we use 30K captions randomly selected from the MS-COCO 2014val split and provide the selected captions as a json file.
  • Evaluating RQ-Transformer requires the text prompts of MS-COCO. Please refer to data/README.md and prepare the dataset first.
  • When four GPUs in a single node are used, run the command below
    python -m torch.distributed.launch \
      --master_addr=$MASTER_ADDR \
      --master_port=$PORT \
      --nnodes=1 --nproc_per_node=4 --node_rank=0 \
      main_sampling_txt2img.py \
      -v=$RQVAE_CKPT -a=$RQTRANSFORMER_CKPT --dataset="coco_2014val" --save-dir=$SAVE_IMG_DIR

NOTE

  • Unfortunately, we do not provide the training code of RQ-Transformer to avoid unexpected misuse through finetuning of our checkpoints. We note that any commercial use of our checkpoints is strictly prohibited.
  • To accurately reproduce the reported results, the checkpoints of RQ-VAE and RQ-Transformer must be correctly matched as described above.
  • The generated images are saved as .pkl files in the directory given by --save-dir (referred to as $DIR_SAVED_IMG below).
  • For top-k and top-p sampling, the settings saved in the configuration file of the pretrained checkpoint are used. If you want to use different top-(k, p) settings, pass --top-k and --top-p when running the sampling scripts.
  • Once generated images are saved, compute_metrics.py can be used to evaluate the images again as follows:
python compute_metrics.py fake_path=$DIR_SAVED_IMG ref_dataset=$DATASET_NAME

Sampling speed benchmark

We provide code to measure the sampling speed of RQ-Transformer depending on the code shape of the RQ-VAE, such as 8x8x4 or 16x16x1, as shown in Figure 4 of the paper (a short note on why the code shape matters follows the commands). To reproduce the figure, run the following commands on an NVIDIA A100 GPU:

# RQ-Transformer (1.4B) on 16x16x1 RQ-VAE (corresponds to VQ-GAN 1.4B model)
python -m measure_throughput f=16 d=1 c=16384 model=huge batch_size=100
python -m measure_throughput f=16 d=1 c=16384 model=huge batch_size=200
python -m measure_throughput f=16 d=1 c=16384 model=huge batch_size=500  # this will result in OOM.

# RQ-Transformer (1.4B) on 8x8x4 RQ-VAE
python -m measure_throughput f=32 d=4 c=16384 model=huge batch_size=100
python -m measure_throughput f=32 d=4 c=16384 model=huge batch_size=200
python -m measure_throughput f=32 d=4 c=16384 model=huge batch_size=500
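Roughly speaking, the sampling-speed gap comes from the number of spatial positions the large spatial transformer must process autoregressively: both code shapes below encode the same number of codes, but 8x8x4 has only a quarter as many spatial positions, with the remaining depth handled by the much smaller depth transformer.

# Code shape (H, W, D) -> spatial positions processed autoregressively vs. total codes.
for h, w, d in [(16, 16, 1), (8, 8, 4)]:
    print(f"{h}x{w}x{d}: {h * w} spatial positions, {h * w * d} codes in total")
# 16x16x1: 256 spatial positions, 256 codes in total
# 8x8x4: 64 spatial positions, 256 codes in total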

BibTeX

@inproceedings{lee2022autoregressive,
  title={Autoregressive Image Generation using Residual Quantization},
  author={Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11523--11532},
  year={2022}
}

Licenses

Contact

If you would like to collaborate with us or provide us feedback, please contact us at [email protected].

Acknowledgement

Our transformer-related implementation is inspired by minGPT and minDALL-E. We appreciate the authors of VQGAN for making their code publicly available.

Limitations

Since RQ-Transformer is trained on publicly available datasets, some generated images can include socially unacceptable content depending on the text conditions. If this problem occurs, please let us know the pair of "text condition" and "generated image".

rq-vae-transformer's People

Contributors

chiheonk, leedoyup


rq-vae-transformer's Issues

question about target of compute_loss function

Thanks for your work!
I want to know what the target of the following function is when use_soft_target==True.
Given the implementation "targets = targets.reshape(-1, targets.shape[-1])", I guess the target may be the one-hot version of the code or the soft_code from your rqvae/quantizations.get_soft_codes function.
Hope you can answer!

def compute_loss(self, logits, targets, use_soft_target=False):

Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward(). Please use .backward() and do not pass its `inputs` argument.

I hope this letter finds you well. I am one of the users of your [project/research], and I wanted to bring to your attention an issue I encountered while using your code.

In my application, I attempted to utilize PyTorch's checkpointing feature to reduce memory usage and optimize the training process of my model. However, when I tried to pass the inputs parameter to the .backward() method while performing backpropagation, I encountered a RuntimeError:

RuntimeError: Checkpointing is not compatible with .grad() or when an inputs parameter is passed to .backward(). Please use .backward() and do not pass its inputs argument.

I believe this issue may be related to the incompatibility of the checkpointing feature with the .grad() method or when passing the inputs parameter to the .backward() method simultaneously. While I did not encounter any problems when calling the .grad() method, passing the inputs parameter to the .backward() method resulted in this error.

I was wondering if you could provide some guidance or suggestions on how to address this issue. I am highly interested in your work, and I hope to fully leverage your code and apply it to my project.

Thank you very much for your time and assistance. I look forward to hearing from you.

add web demo/model to Huggingface

Hi, would you be interested in adding rq-vae-transformer to Hugging Face? The Hub offers free hosting, and it would make your work more accessible and visible to the rest of the ML community. There is already a kakaobrain organization on Hugging Face (https://huggingface.co/kakaobrain) to add models/datasets/spaces(web demos) to.

Example from other organizations:
Keras: https://huggingface.co/keras-io
Microsoft: https://huggingface.co/microsoft
Facebook: https://huggingface.co/facebook

Example spaces with repos:
github: https://github.com/salesforce/BLIP
Spaces: https://huggingface.co/spaces/salesforce/BLIP

github: https://github.com/facebookresearch/omnivore
Spaces: https://huggingface.co/spaces/akhaliq/omnivore

and here are guides for adding spaces/models/datasets to your org

How to add a Space: https://huggingface.co/blog/gradio-spaces
how to add models: https://huggingface.co/docs/hub/adding-a-model
uploading a dataset: https://huggingface.co/docs/datasets/upload_dataset.html

Please let us know if you would be interested and if you have any questions, we can also help with the technical implementation.

ImageNet model has missing keys

Hello, I am using the imagenet_1.4B_rqvae_50e model with the T2I_sampling.ipynb notebook, and I have encountered this error:


---------------------------------------------------------------------------
ConfigAttributeError                      Traceback (most recent call last)
<ipython-input-6-5dd6698162d2> in <module>
      1 # prepare text encoder to tokenize natual languages
----> 2 text_encoder = TextEncoder(tokenizer_name=config.dataset.txt_tok_name, 
      3                            context_length=config.dataset.context_length)

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\dictconfig.py in __getattr__(self, key)
    356         except ConfigKeyError as e:
    357             self._format_and_raise(
--> 358                 key=key, value=None, cause=e, type_override=ConfigAttributeError
    359             )
    360         except Exception as e:

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\base.py in _format_and_raise(self, key, value, cause, msg, type_override)
    215             msg=str(cause) if msg is None else msg,
    216             cause=cause,
--> 217             type_override=type_override,
    218         )
    219         assert False

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\_utils.py in format_and_raise(node, key, value, msg, cause, type_override)
    842         ex.ref_type_str = ref_type_str
    843 
--> 844     _raise(ex, cause)
    845 
    846 

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\_utils.py in _raise(ex, cause)
    740     else:
    741         ex.__cause__ = None
--> 742     raise ex.with_traceback(sys.exc_info()[2])  # set env var OC_CAUSE=1 for full trace
    743 
    744 

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\dictconfig.py in __getattr__(self, key)
    352         try:
    353             return self._get_impl(
--> 354                 key=key, default_value=_DEFAULT_MARKER_, validate_key=False
    355             )
    356         except ConfigKeyError as e:

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\dictconfig.py in _get_impl(self, key, default_value, validate_key)
    443         try:
    444             node = self._get_node(
--> 445                 key=key, throw_on_missing_key=True, validate_key=validate_key
    446             )
    447         except (ConfigAttributeError, ConfigKeyError):

c:\Users\super\Desktop\rq-vae-transformer-main\python_env\lib\site-packages\omegaconf\dictconfig.py in _get_node(self, key, validate_access, validate_key, throw_on_missing_value, throw_on_missing_key)
    480         if value is None:
    481             if throw_on_missing_key:
--> 482                 raise ConfigKeyError(f"Missing key {key!s}")
    483         elif throw_on_missing_value and value._is_missing():
    484             raise MissingMandatoryValue("Missing mandatory value: $KEY")

ConfigAttributeError: Missing key txt_tok_name
    full_key: dataset.txt_tok_name
    object_type=dict

Minimum GPU memory size for training RQ-Transformer

First of all, thank you to all the authors for releasing this remarkable research and these models!

I tried to finetune the RQ-Transformer model (3.9B) on a certain domain. (I'm already aware that it is impossible to release the official training code.) In my training code, a 'CUDA out of memory' error occurred during the training phase (optimizer step) with 8 NVIDIA RTX A6000 (48GB) GPUs and a batch size of 1 per device. I'm trying to find the reason for the error and alternative solutions.

So I have a question about the minimum GPU memory size for this training task. I saw that NVIDIA A100 GPUs were used in your research paper. Were those the 80GB version? (I ask because the A100 comes in two versions, 40GB and 80GB.)

And should I implement model parallelism for this task with these resources? If your opinion is that training is possible with 48GB, I will look for the wrong part in my code.

Confused by this LogitMask

rqvae/models/rqtransformer/primitives.py
class LogitMask(nn.Module):
    def __init__(self, vocab_size: Iterable[int], value=-1e6):
        super().__init__()

        self.vocab_size = vocab_size
        self.mask_cond = [vocab_size[0]] * len(vocab_size) != vocab_size
        self.value = value

    def forward(self, logits: Tensor) -> Tensor:
        if not self.mask_cond:
            return logits
        else:
            for idx, vocab_size in enumerate(self.vocab_size):
                logits[:, idx, vocab_size:].fill_(-float('Inf'))
            return logits

The masking in LogitMask should probably be written as logits[:, idx, vocab_size:] = logits[:, idx, vocab_size:].fill_(-float('Inf')).

using MSE loss to update codebook?

In your paper, you said that you use EMA to update the codebook, but in your code, you use an MSE loss to update the latent code.

loss_total = loss_recon + self.latent_loss_weight * loss_latent
