
tauldr's Introduction

A Continuous Time Framework for Discrete Denoising Models

Paper Link

Install

conda env create --file tldr_env.yml

Sampling

To reproduce the FID score reported in the paper, first download the pre-trained model from https://www.dropbox.com/scl/fo/zmwsav82kgqtc0tzgpj3l/h?dl=0&rlkey=k6d2bp73k4ifavcg9ldjhgu0s

Then go to config/eval/cifar10.py and change the model_location field to point to the CIFAR10 PyTorch checkpoint. Also change the model_config_location field to point to the config file included with the checkpoint.
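
The edit amounts to two path fields; a sketch with placeholder paths (the filenames follow the checkpoints/config layout shown in the ELBO section below, and the exact assignment style depends on how the config object is built):

    # config/eval/cifar10.py (sketch; substitute your own paths)
    model_location = "/path/to/download/checkpoints/ckpt_0001999999.pt"
    model_config_location = "/path/to/download/config/config_001.yaml"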

We provide a sampling script to sample 50000 images from the model. Go to scripts/sample.py, and change save_samples_path to a location where samples can be saved. Then run

python scripts/sample.py

Then, to compute the FID score, you will also need to download the CIFAR10 dataset. Once you have the CIFAR10 PNGs in a separate folder, you can compute the FID score using

python -m pytorch_fid --device cuda:0 path/to/tauLDR_samples path/to/cifar10pngs
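
If you do not already have CIFAR10 as individual PNG files, a short script along these lines can export them with torchvision (an illustrative sketch, not part of the repo; train=True is used here, so swap the split to match whatever protocol you are comparing against):

    import os
    from torchvision.datasets import CIFAR10

    out_dir = "path/to/cifar10pngs"  # placeholder output folder
    os.makedirs(out_dir, exist_ok=True)

    # download=True fetches the dataset on first use; with no transform,
    # each item is a (PIL.Image, label) pair.
    dataset = CIFAR10(root="path/to/cifar10_data", train=True, download=True)
    for idx, (img, _) in enumerate(dataset):
        img.save(os.path.join(out_dir, f"{idx:05d}.png"))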

ELBO evaluation

To reproduce the ELBO value reported in the paper for the CIFAR10 model, obtain the checkpoint as above.

Then go to config/eval/cifar10_elbo.py and change config.experiment_dir to point to a directory that has the following structure

experiment_dir
    - checkpoints
        - ckpt_0001999999.pt
    - config
        - config_001.yaml

using the files downloaded from the Dropbox link. Also change config.checkpoint_path to point to /path/to/experiment_dir/checkpoints/ckpt_0001999999.pt, and config.cifar10_path to a location where the CIFAR10 dataset can be downloaded (it will be downloaded there automatically).
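
Put together, the three edits look something like this (a sketch with placeholder paths):

    # config/eval/cifar10_elbo.py (sketch; substitute your own paths)
    config.experiment_dir = "/path/to/experiment_dir"
    config.checkpoint_path = "/path/to/experiment_dir/checkpoints/ckpt_0001999999.pt"
    config.cifar10_path = "/path/to/cifar10_data"  # CIFAR10 is auto-downloaded here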

Finally, run the following command

python elbo_evaluation.py

This will save the ELBO value in an eval folder within experiment_dir. The value is written to the neg_elbo file; the first number (a 0) can be ignored, and the second number is the negative ELBO averaged over the pixels. It should be around 3.59.
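
To read the result programmatically, a small sketch (assuming the two numbers are whitespace-separated, which may differ in practice):

    # Sketch: parse the neg_elbo file written to experiment_dir/eval/.
    with open("/path/to/experiment_dir/eval/neg_elbo") as f:
        fields = f.read().split()

    neg_elbo = float(fields[1])  # fields[0] is the leading 0 to ignore
    print(neg_elbo)              # should be around 3.59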

Notebooks

To generate CIFAR10 samples, open the notebooks/image.ipynb notebook. In the config/eval/cifar10.py config file, set the paths at the top to a folder where CIFAR10 can be downloaded, and to the model checkpoint and config downloaded from the Dropbox link.

To generate piano samples, open the notebooks/piano.ipynb notebook. In the config/eval/piano.py config file, set the paths at the top to the dataset downloaded from the Dropbox link, as well as to the model weights and config file.

The sampling settings can be set in the config files, switching between standard tau-leaping and tau-leaping with predictor-corrector steps.
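
As a rough illustration, the switch amounts to a few fields in the eval config (the field names below are indicative only; check the config files for the real ones):

    # Sketch of sampler settings in an eval config; field names are illustrative.
    config.sampler.name = "TauLeaping"  # plain tau-leaping
    # For predictor-corrector sampling, select the PC sampler and its knobs:
    # config.sampler.name = "PCTauLeaping"
    # config.sampler.num_corrector_steps = 10
    # config.sampler.corrector_step_size_multiplier = 1.5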

Training

CIFAR10

The CIFAR10 model can be trained using

python train.py cifar10

Paths to store the output and to download the CIFAR10 dataset should be set in the training config, config/train/cifar10.py. To train the model over multiple GPUs, use

python dist_train.py cifar10

with settings found in the config/train/cifar10_distributed.py config file.
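
As a sketch of what to look for there (the field names below are hypothetical; check the file for the real ones), the distributed config mirrors the single-GPU one plus the number of processes to launch:

    # config/train/cifar10_distributed.py (hypothetical field names)
    config.num_gpus = 4                      # processes/GPUs to spawn
    config.save_location = "/path/to/output"
    config.data_location = "/path/to/cifar10_data"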

Piano

The piano model can be trained using

python train.py piano

Paths to store the output and to the dataset downloaded from the Dropbox link should be set in config/train/piano.py.

Audio Samples

Below are four pairs of audio samples. In each pair, the first clip is a music sequence generated by the model conditioned on the first two bars (~2.5 s) of the piece, and the second is the ground-truth song from the test dataset.

Pair a

a.mp4
true_a.mp4

Pair b

b.mp4
true_b.mp4

Pair c

c.mp4
true_c.mp4

Pair d

d.mp4
true_d.mp4


tauldr's Issues

reproduce the training

Hi @andrew-cr

I tried to reproduce the training on CIFAR-10 and obtained the following losses.

It seems a little weird to me that the neg_elbo term is so large compared to the NLL terms. Am I wrong? FYI, I changed the batch size to 40 to reduce memory consumption. Do you also obtain similar values in the initial training stage?

NLL: 8.108, neg_elbo: 12571671.000

Best,
Zhangzhi

Clarification Needed on +1 Addition in Detailed Balance Condition

Hello,

I am currently exploring the implementation details of the tauLDR model and have a question regarding the detailed balance condition as implemented in the code. Specifically, I am referring to the line in models.py:

rate[i, j] = rate[j, i] * np.exp(- ( (j+1)**2 - (i+1)**2 + S*(i+1) - S*(j+1) ) / (2 * self.Q_sigma**2) )

My question pertains to the addition of +1 in each term of the numerator within the exponential function. Upon reviewing Appendix E of the associated paper, I did not find a mention or requirement for adding 1 to the indices i and j in the detailed balance condition.

Could you please provide some insight into the rationale behind this addition? Is it related to a zero-indexing adjustment for Python, or is there another theoretical justification or correction applied here that might not be immediately apparent from the paper's description?

Understanding the reasoning behind this discrepancy will greatly aid in my comprehension of the model's implementation and ensure alignment with the theoretical foundations laid out in the paper.

Thank you for your time and for your contributions to this fascinating project.
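
One way to see what the +1 does: with the 0-indexed Python indices i, j relabelled to states i+1, j+1 in 1..S, the quoted exponent is exactly the detailed-balance ratio for a Gaussian stationary distribution centred at S/2. A minimal numerical check of this reading (an illustration, not code from models.py):

    import numpy as np

    # Check detailed balance pi[i] * rate[i, j] == pi[j] * rate[j, i] for a
    # Gaussian stationary distribution over states 1..S centred at S/2.
    S, Q_sigma = 8, 2.0
    i, j = 2, 5    # arbitrary 0-indexed Python states
    rate_ji = 0.7  # some given rate[j, i]

    # Stationary weight of 0-indexed state k, i.e. of relabelled state k+1.
    pi = lambda k: np.exp(-((k + 1) - S / 2) ** 2 / (2 * Q_sigma**2))

    # The update quoted from models.py:
    rate_ij = rate_ji * np.exp(
        -((j + 1) ** 2 - (i + 1) ** 2 + S * (i + 1) - S * (j + 1))
        / (2 * Q_sigma**2)
    )

    assert np.isclose(pi(i) * rate_ij, pi(j) * rate_ji)

Equivalently, in 0-indexed states the shift just moves the centre of the stationary Gaussian, so it is consistent with an indexing convention rather than a change to the theory.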

How do you choose corrector_step_size_multiplier?

Hi,

Experimenting with the PC scheme myself, I was wondering whether there is any rationale behind the choice of the corrector_step_size_multiplier parameter.

Do you have any tips for tuning it, or have you simply swept over different values for each experiment?

Best regards,

Antoine

Question about reported ELBO value

Hi Andrew!

Thank you so much for this great work! We are trying to better understand the continuous-time discrete diffusion framework and the evaluations reported in the paper. In Table 1, an ELBO of -3.59 bits per dimension is reported on the CIFAR10 test set. However, as mentioned in #4, the neg_elbo on CIFAR is on the scale of 1e7 due to many dropped constants, and even after scaling by the image dimensions 32 x 32 x 3 it is not at the right scale. The nll term does look to be on the right scale. We are wondering whether the ELBO reported in the paper is actually just the nll term averaged over the test set (and hence over all time steps), or whether there is an implementation of the actual ELBO somewhere?

Thank you so much!

Best,
Bear
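
For readers checking scales themselves, the standard conversion from nats per image to bits per dimension on CIFAR10 looks like this (the per-image value below is hypothetical, chosen to land near the paper's 3.59):

    import math

    # Convert a hypothetical per-image negative ELBO in nats to bits per dimension.
    D = 32 * 32 * 3          # CIFAR10 pixels x channels
    neg_elbo_nats = 7644.0   # hypothetical per-image value in nats
    bits_per_dim = neg_elbo_nats / (D * math.log(2))
    print(bits_per_dim)      # ~3.59 for this value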

LICENSE

Thank you very much for the great work and for being an advocate of open source! Would you please provide a license?

training on a different high dimensional discrete dataset

Hi,
Thank you for the great work!

Are there any guidelines for training/testing on a different dataset, aside from the CIFAR and piano datasets?

Training on a different dataset would most likely involve changing the data path in the config file; are there any other adjustments to keep in mind while making these changes?

Thank you!

Dimension of transition choice

Hi,

Great work and nice implementation!

I have a question about the way you choose the dimension of transition during training (code below):

    # Gather, for every (batch, dimension) pair, the row of outgoing rates
    # from the current state of that dimension.
    rate_vals_square = rate[
        torch.arange(B, device=device).repeat_interleave(D),
        x_t.long().flatten(),
        :
    ] # (B*D, S)

    # Zero the diagonal entries so self-transitions carry no rate.
    rate_vals_square[
        torch.arange(B*D, device=device),
        x_t.long().flatten()
    ] = 0.0

    rate_vals_square = rate_vals_square.view(B, D, S)

    # Total outgoing rate of each dimension.
    rate_vals_square_dimsum = torch.sum(rate_vals_square, dim=2).view(B, D)

    # Categorical normalizes the unnormalized per-dimension rates.
    square_dimcat = torch.distributions.categorical.Categorical(
        rate_vals_square_dimsum
    )

    square_dims = square_dimcat.sample() # (B,) taking values in [0, D)

What I understand is that you sample the dimension of transition from the distribution given by the total outgoing rates of the components of x. Even though this seems intuitive, I don't see how it is justified from a theoretical perspective, and I can't find a clear explanation in your paper.

With that in mind, could you please elaborate on how this implementation relates to your model?

Thanks,

Antoine
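
A standard way to make sense of this choice, under the usual continuous-time Markov chain picture (an illustration, not code from the repo): if each dimension carries an independent exponential clock whose rate is its total outgoing rate, the probability that dimension d fires first is proportional to that rate, which is exactly the categorical sampled above.

    import torch

    # Race of exponential clocks: P(dimension d fires first) = rate_d / sum(rates).
    torch.manual_seed(0)
    rates = torch.tensor([0.5, 1.5, 3.0])  # total outgoing rate per dimension
    samples = torch.distributions.Exponential(rates).sample((100_000,))  # (100000, 3)
    first = samples.argmin(dim=1)          # which clock rings first in each trial
    empirical = torch.bincount(first).float() / 100_000
    print(empirical)  # approx rates / rates.sum() = [0.10, 0.30, 0.60]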
