
MaskGIT PyTorch


Welcome to the unofficial MaskGIT PyTorch repository. This project aims to provide an external reproduction of the results from MaskGIT: Masked Generative Image Transformer, a PyTorch reimplementation of the models, and pretrained weights. The official JAX implementation of MaskGIT can be found here.

Repository Structure

Here's an overview of the repository structure:

  ├ MaskGIT-pytorch/
  |    ├── Metrics/                               <- evaluation tool
  |    |      ├── inception_metrics.py                  
  |    |      └── sample_and_eval.py
  |    |    
  |    ├── Network/                             
  |    |      ├── Taming/                         <- VQGAN architecture   
  |    |      └── transformer.py                  <- Transformer architecture  
  |    |
  |    ├── Trainer/                               <- Main class for training
  |    |      ├── trainer.py                      <- Abstract trainer     
  |    |      └── vit.py                          <- Trainer of maskgit
  |    ├── save_img/                              <- Image samples         
  |    |
  |    ├── colab_demo.ipynb                       <- Inference demo 
  |    ├── download_models.py                     <- Download the pretrained models
  |    ├── LICENSE.txt                            <- MIT license
  |    ├── requirements.yaml                      <- Conda environment file
  |    ├── README.md                              <- Me :) 
  |    └── main.py                                <- Main

Usage

To get started with this project, follow these steps:

  1. Clone the repository:

    git clone https://github.com/valeoai/MaskGIT-pytorch.git
    cd MaskGIT-pytorch
    
  2. Install the requirements:

    conda env create -f environment.yaml
    conda activate maskgit
    
  3. (Optional) Download the pretrained models:

    python download_models.py
    
  4. Resume training for 1 additional epoch:

    data_folder="/datasets_local/ImageNet/"
    vit_folder="./pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    vqgan_folder="./pretrained_maskgit/VQGAN/"
    writer_log="./logs/"
    num_worker=16
    bsize=64
    # Single GPU
    python main.py  --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume
    # Multiple GPUs single node
    torchrun --standalone --nnodes=1 --nproc_per_node=gpu main.py  --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume

Demo

Interested only in running inference with the model? You can run the colab_demo.ipynb notebook in Google Colab! Open In Colab

Training Details

The model consists of a total of 246.303M parameters, with 174.161M for the transformer and 72.142M for the VQGAN. The VQGAN reduces a 256x256 (resp. 512x512) image to a 16x16 (resp. 32x32) grid of tokens, drawn from a codebook of 1024 possible codes. During the masked transformer training, I used a batch size of 512 over 300 epochs, leveraging 8 GPUs (~768 GPU-hours on NVIDIA A100) for 755,200 iterations on ImageNet 256x256. I then fine-tuned the same model for ~750,000 iterations on ImageNet 512x512 with a batch size of 128 (~384 GPU-hours on NVIDIA A100).
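
As a quick sanity check on the shapes involved, here is a back-of-the-envelope sketch (illustrative only, not code from this repository):

```python
# Token counts implied by the f=16 VQGAN described above.
codebook_size = 1024                 # size of the VQGAN codebook
for img_size in (256, 512):
    grid = img_size // 16            # the VQGAN downsamples by a factor of 16
    n_tokens = grid * grid           # number of tokens seen by the masked transformer
    print(f"{img_size}x{img_size} image -> {grid}x{grid} grid = {n_tokens} tokens, each in [0, {codebook_size})")
```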

The transformer architecture hyperparameters:

| Hidden Dimension | Codebook Size | Depth | Attention Heads | MLP Dimension | Dropout Rate |
|------------------|---------------|-------|-----------------|---------------|--------------|
| 768              | 1024          | 24    | 16              | 3072          | 0.1          |

The optimizer is Adam with a learning rate of 1e-4, using an 'arccos' scheduler for masking. Additionally, during training, I applied a 10% label dropout for classifier-free guidance (CFG).
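
For illustration, here is a minimal sketch of what training-time arccos masking with 10% label dropout can look like (hypothetical helper names; not the exact code of this repository):

```python
import math
import torch

def mask_tokens(code, mask_token_id, drop_label_p=0.1):
    """Illustrative arccos masking step for one training batch.

    code: LongTensor of shape (B, N) with VQGAN token indices.
    Returns the masked input, the boolean mask, and a label-drop mask for CFG.
    """
    b, n = code.shape
    # Sample a masking ratio per sample from the arccos schedule.
    r = torch.rand(b)
    mask_ratio = torch.arccos(r) / (math.pi * 0.5)          # in [0, 1]
    # Mask that fraction of token positions uniformly at random.
    mask = torch.rand(b, n) < mask_ratio.unsqueeze(1)
    masked_code = torch.where(mask, torch.full_like(code, mask_token_id), code)
    # Drop the class label for ~10% of samples to train the unconditional branch (CFG).
    drop_label = torch.rand(b) < drop_label_p
    return masked_code, mask, drop_label

# Example usage with random tokens (1024-code codebook, mask token id 1024, 16x16 grid):
code = torch.randint(0, 1024, (4, 256))
masked_code, mask, drop_label = mask_tokens(code, mask_token_id=1024)
```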

For all the details, please refer to our technical report.

Performance on ImageNet

Using the following hyperparameters for sampling:

| Image Size | Softmax Temp | Gumbel Temp | CFG (w) | Randomization | Schedule | Schedule Steps |
|------------|--------------|-------------|---------|---------------|----------|----------------|
| 256x256    | 1            | 4.5         | 3       | "linear"      | "arccos" | 8              |
| 512x512    | 1            | 7           | 2.8     | "linear"      | "arccos" | 15             |

We reach this performance on ImageNet:

| Metric                           | Ours 256x256 | Paper 256x256 | Ours 512x512 | Paper 512x512 |
|----------------------------------|--------------|---------------|--------------|---------------|
| FID (Fréchet Inception Distance) | 6.80         | 6.18          | 7.26         | 7.32          |
| IS (Inception Score)             | 214.0        | 182.1         | 223.1        | 156.0         |
| Precision                        | 0.82         | 0.80          | 0.85         | 0.78          |
| Recall                           | 0.51         | 0.51          | 0.49         | 0.50          |
| Density                          | 1.25         | -             | 1.33         | -             |
| Coverage                         | 0.84         | -             | 0.86         | -             |

The IS rises monotonically throughout training while the FID decreases:

(figure: FID and IS curves over training)

For visualization, to boost image quality, we increase the number of steps (32), the softmax temperature (1.3), and the CFG weight (9) to trade diversity for fidelity.
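
As a reminder of how a CFG weight enters the sampling step, here is a minimal sketch of one common classifier-free guidance formulation (the exact combination used in this repository may differ):

```python
import torch

def apply_cfg(logit_cond, logit_uncond, w):
    """One common CFG formulation: w = 0 gives the unconditional logits,
    w = 1 the conditional ones, and w > 1 extrapolates further away from
    the unconditional prediction."""
    return logit_uncond + w * (logit_cond - logit_uncond)

# Example shapes: batch of 2, 16x16 = 256 token positions, 1024 codes + 1 mask token.
logit_cond = torch.randn(2, 256, 1025)
logit_uncond = torch.randn(2, 256, 1025)
guided = apply_cfg(logit_cond, logit_uncond, w=3.0)
```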

Performance on ImageNet 256

(figure: samples on ImageNet 256x256)

Performance on ImageNet 512

(figure: samples on ImageNet 512x512)

And the generation process: (figures: step-by-step generation)

Inpainting

The model demonstrates good capabilities at inpainting ImageNet classes into existing scenes: (figure: inpainting examples)

Pretrained Model

You can download the pretrained MaskGIT models on Hugging Face.

Contribute

The reproduction process might encounter bugs or issues, or there could be mistakes on my part. If you're interested in collaborating or have suggestions, please feel free to reach out (by creating an issue). Your input and collaboration are highly valued!

License

This project is licensed under the MIT License. See the LICENSE file for details.

Acknowledgement

This project is powered by the IT4I Karolina cluster, located in the Czech Republic.

The pretrained VQGAN for ImageNet (f=16, 1024-entry codebook): the implementation and the pretrained model come from the official VQGAN repository.

BibTeX

If you find our work beneficial for your research, please consider citing both our work and the original source.

@misc{besnier2023MaskGit_pytorch,
      title={A Pytorch Reproduction of Masked Generative Image Transformer}, 
      author={Victor Besnier and Mickael Chen},
      year={2023},
      eprint={2310.14400},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@InProceedings{chang2022maskgit,
  title = {MaskGIT: Masked Generative Image Transformer},
  author={Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
  booktitle = {CVPR},
  month = {June},
  year = {2022}
}

maskgit-pytorch's People

Contributors

kifarid, llvictorll, mickaelchen


maskgit-pytorch's Issues

Clarification on Additional Token Usage and Embedding in Maskgit-pytorch Transformer

Hi. Thanks for the great work. I have two questions.

  1. Can you please clarify what the second 1 is used for?
    codebook_size is 1024, so its indices are in [0, 1023]. The first 1 in the code is for the mask token, which is index 1024. nclass is 1000 for ImageNet. I do not understand the purpose of increasing the size of the nn.Embedding by another 1.
    Link to code

  2. In the following code, why is self.codebook_size+1 used instead of self.codebook_size? What is the purpose of the additional token, given that we then compute the cross-entropy loss?
    Link to code

About the training intermediate result.

Hello, thanks a lot for your great work.
I have a question about intermediate results. I am trying to reproduce the results in the video domain, but it is really hard to train: the loss does not drop significantly, and the model keeps producing images that contain only a single color. I wanted to ask whether this also happens with the intermediate results of MaskGIT?
Really looking forward to your reply, and thanks.

Warm-up of CFG weight

First of all, thank you for providing such great codes and materials. I was also struggling to reproduce MaskGIT, so it has been a tremendous help.

I noticed an implementation detail that was not mentioned in the report: the warm-up of the CFG weight during sampling.

_w = w * (indice / len(scheduler))

If you don't mind, could you please provide insights into the differences in results when this warm-up is applied versus not applied?

Here's another minor point, but would it be more in line with the intended behavior if the weight calculation were modified as follows?
_w = w * (indice / (len(scheduler)-1))
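
For concreteness, here is a small numeric sketch of the difference between the two expressions (assuming `indice` runs from 0 to len(scheduler) - 1):

```python
w, steps = 3.0, 8  # e.g. a CFG weight of 3 and an 8-step scheduler
for indice in range(steps):                  # assuming indice runs from 0 to steps - 1
    w_current = w * (indice / steps)         # current expression: peaks at w * (steps - 1) / steps
    w_proposed = w * (indice / (steps - 1))  # proposed expression: reaches exactly w at the last step
    print(indice, round(w_current, 3), round(w_proposed, 3))
```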

Sampling with CFG = 0

Hello,

in the vit.py, on line 381, there is

logit = self.vit(code.clone(), labels, drop_label=~drop)

When debugging, I found that drop is Tensor([True, True, ...]), so it is turned to Tensor([False, False, ...]), meaning the labels are not dropped.
I'm wondering whether this is working as expected, since a CFG of 0 usually means that the label is ignored, right?

Question about the masking scheduler during inference

Hello, I noticed that in your adap_sche function, you normalize the obtained mask-ratio function so that the sum of the mask ratios over all steps equals one. I can roughly understand your intention: the total number of tokens retained over all steps equals the final number of tokens (for example, 16x16 = 256).
However, this seems to be different from the official JAX implementation of MaskGIT (https://github.com/google-research/maskgit), where the mask ratio goes from 1 to 0. This means that it predicts all tokens at once in the last decoding step and retains all tokens obtained at that step. I'm not sure if I misunderstood it. Could you please clarify? Thanks a lot!
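
For readers following along, here is a rough sketch of the kind of normalized schedule being described (illustrative only, not the repository's actual adap_sche):

```python
import math
import torch

def normalized_arccos_schedule(step, n_tokens=256):
    """Spread n_tokens over `step` decoding iterations so that the per-step
    token budgets follow an arccos profile and sum exactly to n_tokens."""
    r = torch.linspace(1, 0, step)
    val_to_mask = torch.arccos(r) / (math.pi * 0.5)      # arccos profile in [0, 1]
    sche = (val_to_mask / val_to_mask.sum()) * n_tokens  # normalize so the budgets sum to n_tokens
    sche = sche.round().long()
    sche[sche == 0] = 1                                  # reveal at least one token per step
    sche[-1] += n_tokens - sche.sum()                    # absorb rounding error in the last step
    return sche

print(normalized_arccos_schedule(8))  # eight integers summing to 256
```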

Questions about the mask scheduling function

Hi, thanks for your great open-source work! I have two questions about the mask scheduler initialization when running the inference.

In your code, the mask scheduling function is set to arccos: val_to_mask = torch.arccos(r) / (math.pi * 0.5). However, in the original MaskGIT paper, it is set to the cos function. I’m not sure about the purpose of this difference in the code.
When you initialize the input of the mask scheduling function, you choose r = torch.linspace(1, 0, step). But the paper claims that the input of the function should be 0/T, 1/T…(T-1)/T, which is different from torch.linspace(1, 0, step)=0/(step-1), 1/(step-1)...1. I’m not sure if I have misunderstood the paper.

train vqgan

I want to train on my own dataset. Do I need to retrain the VQGAN? If so, I see that the VQGAN training seems to be missing the discriminator. How do I train the VQGAN?

questions about two stage training

Hey @llvictorll and team,

Really appreciate you reproducing this work and open-sourcing it! It's really helpful for the community. I want to further understand the training and fine-tuning strategy mentioned in Sec. 2 of the tech report. Does that mean the first training stage is for 256x256 and the second fine-tuning stage is for 512x512?

It would be very helpful if you can kindly explain it more.

Training loss jumping up when resuming training.

Hey there,

when I load the model and optimizer state dict from a checkpoint and try to resume training, the training loss suddenly spikes up, removing a lot of the progress the previous training run brought. After a while it goes down again, but the training process is set back by a large margin.

Would you, by chance, know what causes this behavior?

Thanks a lot in advance!

Regarding training a mask model with my own data, could you please provide guidance on the steps involved

Thank you for your great work. I have some questions I would like to ask you, if you don't mind.
data_folder="/datasets_local/ImageNet/"
vit_folder="./pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
vqgan_folder="./pretrained_maskgit/VQGAN/"
writer_log="./logs/"
num_worker=16
bsize=64

Single GPU

python main.py --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume

Multiple GPUs single node

torchrun --standalone --nnodes=1 --nproc_per_node=gpu main.py --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume
If I want to train the masked transformer with custom data, what changes do I need to make to this code? I've already trained my own VQGAN.

reproducibility

Hi @llvictorll, thanks for your nice reproduction. When I evaluated the checkpoints provided, with the following command

torchrun --standalone --nnodes=1 --nproc_per_node=1 main.py --bsize 128 --data-folder imagenet --vit-folder pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth --vqgan-folder pretrained_maskgit/VQGAN/ --writer-log logs --num_workers 16 --img-size 256 --epoch 301 --resume --test-only

Size of model autoencoder: 72.142M                                                                                                                       
Acquired codebook size: 1024                                                                                                                             
load ckpt from: pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth                                                                                      
Size of model vit: 174.161M                                                                                                                              
Evaluation with hyper-parameter ->                                                                                                                       
scheduler: arccos, number of step: 8, softmax temperature: 1.0, cfg weight: 3, gumbel temperature: 4.5                                                   
{'Eval/fid_conditional': 7.655000113633889, 'Eval/inception_score_conditional': 228.72691345214844, 'Eval/precision_conditional': 0.8194600000000002, 'Eval/recall_conditional': 0.5016600000000001, 'Eval/density_conditional': 1.2358733333333334, 'Eval/coverage_conditional': 0.8560800000000001}    

The FID result is worse than the one you reported (6.80), as shown above. Could you please help figure out where this gap comes from? Thanks.

Question about training loss

Hi, I’m struggling to reproduce the work. However, when I start training following the process in this repository, the loss decreases rapidly and it seems to be approaching convergence. Despite this, the model fails to reconstruct images. Does this make sense?

Target tokens for loss computation

Hi, I'd like to ask a question about the loss computation.
This repository (and the original repository?) computes the cross-entropy loss over the entire set of ground-truth tokens.
This implies that the model also learns to predict 'known (unmasked)' tokens, which are relatively easy to estimate.
As a result, the training may exhibit a strong bias towards the known tokens.

loss = self.criterion(pred.reshape(-1, 1024 + 1), code.view(-1)) / self.args.grad_cum

Intuitively, in this case, the model would at first ignore the loss on the 'masked' tokens, and the loss on the known tokens would decrease drastically at the beginning of training.

I think another option is to mask the known positions in the target tokens, which forces the model to predict only the unknown (masked) tokens (the same approach is taken in the following repository):
https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L680
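
For concreteness, here is a minimal sketch of the masked-targets variant described above (hypothetical code using ignore_index, not taken from either repository):

```python
import torch
import torch.nn as nn

# Hypothetical shapes: B samples, N = 16x16 token positions, 1024 codes + 1 mask token.
B, N, V = 4, 256, 1024 + 1
pred = torch.randn(B, N, V)                # transformer logits
code = torch.randint(0, 1024, (B, N))      # ground-truth VQGAN tokens
mask = torch.rand(B, N) < 0.5              # True where the input token was masked

# Only supervise the masked positions: known tokens are replaced by ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
target = code.masked_fill(~mask, -100)
loss = criterion(pred.reshape(-1, V), target.view(-1))
```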

I'd like to know if you have any insights on the following two points regarding this.

  1. Which approach tends to yield better learning outcomes: masking the target or not masking the target?
  2. Could you share the loss curve through epochs during training so that we can confirm if our training is going well?

Thank you for considering my request!
Best regards,
Yukara
