
imagen-pytorch's Introduction

Imagen - Pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.

Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.

It appears neither CLIP nor prior network is needed after all. And so research continues.

AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher

Please join us on Discord if you are interested in helping out with the replication effort with the LAION community

Shoutouts

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • 🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them

  • Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper

  • Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training

  • Alex for einops, an indispensable tool for tensor manipulation

  • Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version

  • Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion

  • Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging

  • Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets

  • Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results

  • MalumaDev for proposing the use of a pixel shuffle upsampler to fix checkerboard artifacts

  • Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix

  • BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time

  • Bingbing for identifying a bug with sampling and the order of normalizing and noising the low resolution conditioning image

  • Kay for contributing one line command training of Imagen!

  • Hadrien Reynaud for testing out text-to-video on a medical dataset, sharing his results, and identifying issues!

Install

$ pip install imagen-pytorch

Usage

import torch
from imagen_pytorch import Unet, Imagen

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
    loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 3, 256, 256)

For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask; a sketch of how to do so appears after the next code block.)

The number of textual captions must match the batch size of the images if you go this route.

# mock images and text (get a lot of this)

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, texts = texts, unet_number = i)
    loss.backward()
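
As noted above, for larger scale training you will want to precompute the textual embeddings. A minimal sketch, assuming the t5_encode_text helper in imagen_pytorch.t5 returns a (batch, sequence length, dim) embedding tensor (check that module for the exact signature and for how to also obtain the mask):

import torch
from imagen_pytorch.t5 import t5_encode_text

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet'
]

# encode once with T5 and cache to disk, so training never has to re-run the text encoder
# (hypothetical call - consult imagen_pytorch/t5.py for the exact signature)
text_embeds = t5_encode_text(texts)

torch.save(text_embeds, './text_embeds.pt')

# later, at training time

text_embeds = torch.load('./text_embeds.pt')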

With the ImagenTrainer wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,            # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4          # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)
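
As the comment in the example above hints, the unets are trained one at a time. A hedged sketch of moving on to the super-resolution unet, using only the trainer methods shown in this readme (the checkpoint path is illustrative):

# after training unet 1 for a while, save the full trainer state

trainer.save('./checkpoint-unet1.pt')

# later (possibly in a fresh process), reload and continue with unet 2

trainer.load('./checkpoint-unet1.pt')

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 2,
    max_batch_size = 4
)

trainer.update(unet_number = 2)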

You can also train Imagen without text (unconditional image generation) as follows

import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer

# unets for unconditional imagen

unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = False,
    use_linear_attn = True
)

unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    condition_on_text = False,   # this must be set to False for unconditional Imagen
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)

trainer = ImagenTrainer(imagen).cuda()

# now get a ton of images and feed it through the Imagen trainer

training_images = torch.randn(4, 3, 256, 256).cuda()

# train each unet separately
# in this example, only training on unet number 1

loss = trainer(training_images, unet_number = 1)
trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)

Or train only super-resoluting unets

import torch
from imagen_pytorch import Unet, NullUnet, Imagen

# unet for imagen

unet1 = NullUnet()  # add a placeholder "null" unet for the base unet

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 250,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images

lowres_images = torch.randn(3, 3, 64, 64).cuda()  # starting un-resoluted images

images = imagen.sample(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles'
    ],
    start_at_unet_number = 2,              # start at unet number 2
    start_image_or_video = lowres_images,  # pass in low resolution images to be resoluted
    cond_scale = 3.)

images.shape # (3, 3, 256, 256)

At any time you can save and load the trainer and all associated states with the save and load methods. It is recommended you use these methods instead of manually saving with a state_dict call, as the trainer does some device memory management under the hood.

ex.

trainer.save('./path/to/checkpoint.pt')

trainer.load('./path/to/checkpoint.pt')

trainer.steps # (2,) step number for each of the unets, in this case 2
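
A common pattern is to checkpoint periodically inside the training loop. A minimal sketch using only the calls above (the save interval and path are illustrative):

for step in range(100000):
    loss = trainer(images, text_embeds = text_embeds, unet_number = 1, max_batch_size = 4)
    trainer.update(unet_number = 1)

    # persist the full trainer state every 1000 steps
    if step % 1000 == 0:
        trainer.save('./checkpoint.pt')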

Dataloader

You can also rely on the ImagenTrainer to automatically train off DataLoader instances. You simply have to craft your DataLoader to return either images (for the unconditional case) or tuples of (images, text_embeds) for text-guided generation.

ex. unconditional training

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True # whether to split the validation dataset from the training
).cuda()

# instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training

dataset = Dataset('/path/to/training/images', image_size = 128)

trainer.add_train_dataset(dataset, batch_size = 16)

# working training loop

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
        images[0].save(f'./sample-{i // 100}.png')
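
For the text-conditioned case mentioned above, the dataset just needs to yield (image, text_embed) tuples. A minimal sketch of such a dataset, assuming a text-conditioned Imagen wrapped in an ImagenTrainer and precomputed embeddings (the class name and tensor shapes are illustrative):

import torch
from torch.utils.data import Dataset

class TextImageDataset(Dataset):
    def __init__(self, images, text_embeds):
        # images: (n, 3, h, w), text_embeds: (n, seq_len, dim)
        assert images.shape[0] == text_embeds.shape[0]
        self.images = images
        self.text_embeds = text_embeds

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        return self.images[idx], self.text_embeds[idx]

dataset = TextImageDataset(torch.randn(100, 3, 128, 128), torch.randn(100, 256, 768))

trainer.add_train_dataset(dataset, batch_size = 16)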

Multi GPU

Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps.

First you need to invoke accelerate config in the same directory as your training script (say it is named train.py)

$ accelerate config

Next, instead of calling python train.py as you would for single GPU, you would use the accelerate CLI as so

$ accelerate launch train.py

That's it!

Command-line

Imagen can also be used via CLI directly.

Configuration

ex.

$ imagen config

or

$ imagen config --path ./configs/config.json

In the config you are able to change settings for the trainer, dataset and the imagen config.

The Imagen config parameters can be found here

The Elucidated Imagen config parameters can be found here

The Imagen Trainer config parameters can be found here

For the dataset parameters all dataloader parameters can be used.

Training

This command allows you to train or resume training your model

ex.

$ imagen train

or

$ imagen train --unet 2 --epoches 10

You can pass the following arguments to the training command.

  • --config specify the config file to use for training [default: ./imagen_config.json]
  • --unet the index of the unet to train [default: 1]
  • --epoches how many epochs to train for [default: 50]

Sampling

Be aware that when sampling, your checkpoint should have all unets trained in order to get a usable result.

ex.

$ imagen sample --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder"
# image is saved to ./a_squirrel_raiding_the_birdfeeder.png

You can pass the following arguments to the sample command (see the example after the list).

  • --model specify the model file to use for sampling
  • --cond_scale conditioning scale (classifier free guidance) in decoder
  • --load_ema load EMA version of unets if available
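
For example, combining the flags above (the path is illustrative, and the exact flag syntax may differ slightly - check the CLI help):

$ imagen sample --model ./checkpoint.pt --cond_scale 5 "a cloud in the shape of a roman gladiator"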

In order to use a saved checkpoint with this feature, you must either instantiate your Imagen instance using the config classes, ImagenConfig and ElucidatedImagenConfig, or create a checkpoint via the CLI directly.

For proper training, you'll likely want to set up config-driven training anyway.

ex.

import torch
from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer

# in this example, using elucidated imagen

imagen = ElucidatedImagenConfig(
    unets = [
        dict(dim = 32, dim_mults = (1, 2, 4, 8)),
        dict(dim = 32, dim_mults = (1, 2, 4, 8))
    ],
    image_sizes = (64, 128),
    cond_drop_prob = 0.5,
    num_sample_steps = 32
).create()

trainer = ImagenTrainer(imagen)

# do your training ...

# then save it

trainer.save('./checkpoint.pt')

# you should see a message informing you that ./checkpoint.pt is commandable from the terminal

It really should be as simple as that

You can also pass this checkpoint file around, and anyone can continue finetuning on their own data

from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer

imagen = load_imagen_from_checkpoint('./checkpoint.pt')

trainer = ImagenTrainer(imagen)

# continue training / fine-tuning
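
The fine-tuning loop itself can reuse the trainer calls shown earlier; a sketch with mock data (in practice you would feed your own dataset, and the tensor shapes depend on your config):

import torch

images = torch.randn(4, 3, 128, 128)

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
]

loss = trainer(images, texts = texts, unet_number = 1, max_batch_size = 4)
trainer.update(unet_number = 1)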

Inpainting

Inpainting follows the formulation laid out by the recent RePaint paper. Simply pass in inpaint_images and inpaint_masks to the sample function on either Imagen or ElucidatedImagen

inpaint_images = torch.randn(4, 3, 512, 512).cuda()      # (batch, channels, height, width)
inpaint_masks = torch.ones((4, 512, 512)).bool().cuda()  # (batch, height, width)

inpainted_images = trainer.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.)

inpainted_images # (4, 3, 512, 512)

For video, similarly pass your videos to the inpaint_videos keyword on .sample. The inpainting mask can either be the same across all frames (batch, height, width) or different per frame (batch, frames, height, width)

inpaint_videos = torch.randn(4, 3, 8, 512, 512).cuda()   # (batch, channels, frames, height, width)
inpaint_masks = torch.ones((4, 8, 512, 512)).bool().cuda()  # (batch, frames, height, width)

inpainted_videos = trainer.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_videos = inpaint_videos, inpaint_masks = inpaint_masks, cond_scale = 5.)

inpainted_videos # (4, 3, 8, 512, 512)

Experimental

Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of Imagen, the ElucidatedImagen, so that one can use the new elucidated DDPM for text-guided cascading generation.

Simply import ElucidatedImagen, and then instantiate the instance as you did before. The hyperparameters are different from the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade.

Ex.

from imagen_pytorch import ElucidatedImagen

# instantiate your unets ...

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    cond_drop_prob = 0.1,
    num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
    sigma_min = 0.002,           # min noise level
    sigma_max = (80, 160),       # max noise level, @crowsonkb recommends double the max noise level for upsampler
    sigma_data = 0.5,            # standard deviation of data distribution
    rho = 7,                     # controls the sampling schedule
    P_mean = -1.2,               # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                 # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# rest is the same as above

Text to Video

This repository will also start accumulating new research around text guided video synthesis. For starters, it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models

Update: verified working by Hadrien Reynaud!

Ex.

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer

unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

# elucidated imagen, which contains the unets above (base unet and super resoluting ones)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (2, 1),        # in this example, the first unet would receive the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# mock videos (get a lot of this) and text encodings from large T5

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
]

videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)

# feed images into imagen, training each unet in the cascade
# for this example, only training unet 1

trainer = ImagenTrainer(imagen)

# you can also ignore time when training on video initially, shown to improve results in video-ddpm paper. eventually will make the 3d unet trainable with either images or video. research shows it is essential (with current data regimes) to train first on text-to-image. probably won't be true in another decade. all big data becomes small data

trainer(videos, texts = texts, unet_number = 1, ignore_time = False)
trainer.update(unet_number = 1)

videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames

videos.shape # (4, 3, 20, 32, 32)

You can also train on text-image pairs first. The Unet3D will automatically convert them to single-framed videos and learn without the temporal components (by automatically setting ignore_time = True), whether it be 1d convolutions or causal attention across time.

This is the current approach taken by all the big artificial intelligence labs (Brain, MetaAI, Bytedance)
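
A hedged sketch of that image-only pretraining stage, assuming the trainer accepts 4-dimensional image batches when the unets are Unet3D (as described above, each image is treated as a single-frame video):

# mock single images (no time dimension) with matching captions

images = torch.randn(4, 3, 32, 32).cuda()

# the Unet3D treats these as single-frame videos and skips the temporal components

trainer(images, texts = texts, unet_number = 1)
trainer.update(unet_number = 1)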

FAQ

  • Why are my generated images not aligning well with the text?

Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale to the conditioning (text in this case) of greater than 1.0.

Researcher Netruk44 has reported 5-10 to be optimal, but anything greater than 10 breaks.

trainer.sample(texts = [
    'a cloud in the shape of a roman gladiator'
], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
  • Are there any pretrained models yet?

Not at the moment, but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at LAION (the Discord link is in the readme above) and start collaborating.

  • Will this technology take my job?

All the more reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute and augmenting it with your own curated dataset.

  • What am I allowed to do with this repository?

Anything! It is MIT licensed. In other words, you can freely copy / paste for your own research, remixed for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you.

Cool Applications!

Related Works

Todo

  • use huggingface transformers for T5-small text embeddings

  • add dynamic thresholding

  • add dynamic thresholding to the DALLE2 and video-diffusion repositories as well

  • allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)

  • add the lowres noise level with the pseudocode in appendix, and figure out what is this sweep they do at inference time

  • port over some training code from DALLE2

  • need to be able to use a different noise schedule per unet (cosine was used for base, but linear for SR)

  • just make one master-configurable unet

  • complete resnet block (biggan inspired? but with groupnorm) - complete self attention

  • complete conditioning embedding block (and make it completely configurable, whether it be attention, film etc)

  • consider using perceiver-resampler from https://github.com/lucidrains/flamingo-pytorch in place of attention pooling

  • add attention pooling option, in addition to cross attention and film

  • add optional cosine decay schedule with warmup, for each unet, to trainer

  • switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages - first figure out the linear noise schedule case from the variational ddpm paper https://openreview.net/forum?id=2LdBqxc1Yv

  • figure out log(snr) for alpha cosine noise schedule.

  • suppress the transformers warning because only T5encoder is used

  • allow setting for using linear attention on layers where full attention cannot be used

  • force unets in continuous time case to use non-fouriered conditions (just pass the log(snr) through an MLP with optional layernorms), as that is what i have working locally

  • removed learned variance

  • add p2 loss weighting for continuous time

  • make sure cascading ddpm can be trained without text condition, and make sure both continuous and discrete time gaussian diffusion works

  • use primer's depthwise convs on the qkv projections in linear attention (or use token shifting before projections) - also use new dropout proposed by bayesformer, as it seems to work well with linear attention

  • explore skip layer excitation in unet decoder

  • accelerate integration

  • build out CLI tool and one-line generation of image

  • knock out any issues that arose from accelerate

  • add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865

  • build a simple checkpointing system, backed by a folder

  • add skip connection from outputs of all upsample blocks, used in unet squared paper and some previous unet works

  • add fsspec, recommended by Romain @rom1504, for cloud / local file system agnostic persistence of checkpoints

  • test out persistence in gcs with https://github.com/fsspec/gcsfs

  • extend to video generation, using axial time attention as in Ho's video ddpm paper

  • allow elucidated imagen to generalize to any shape

  • allow for imagen to generalize to any shape

  • add dynamic positional bias for the best type of length extrapolation across video time

  • move video frames to sample function, as we will be attempting time extrapolation

  • attention bias to null key / values should be a learned scalar of head dimension

  • add self-conditioning from bit diffusion paper, already coded up at ddpm-pytorch

  • add v-parameterization (https://arxiv.org/abs/2202.00512) from imagen video paper, the only thing new

  • incorporate all learnings from make-a-video (https://makeavideo.studio/)

  • build out CLI tool for training, resuming training off config file

  • allow for temporal interpolation at specific stages

  • make sure temporal interpolation works with inpainting

  • make sure one can customize all interpolation modes (some researchers are finding better results with trilinear)

  • imagen-video : allow for conditioning on preceding (and possibly future) frames of videos. ignore time should not be allowed in that scenario

  • make sure to automatically take care of temporal down/upsampling for conditioning video frames, but allow for an option to turn it off

  • make sure inpainting works with video

  • make sure the inpainting mask for video can be customized per frame

  • add flash attention

  • reread cogvideo and figure out how frame rate conditioning could be used

  • bring in attention expertise for self attention layers in unet3d

  • consider bringing in NUWA's 3d convolutional attention

  • consider transformer-xl memories in the temporal attention blocks

  • consider perceiver-ar approach to attending to past time

  • frame dropouts during attention for achieving both regularizing effect as well as shortened training time

  • investigate frank wood's claims https://github.com/lucidrains/flexible-diffusion-modeling-videos-pytorch and either add the hierarchical sampling technique, or let people know about its deficiencies

  • offer challenging moving mnist (with distractor objects) as a one-line trainable baseline for researchers to branch off of for text to video

  • preencoding of text to memmapped embeddings

  • be able to create dataloader iterators based on the old epoch style, also configure shuffling etc

  • be able to also pass in arguments (instead of requiring forward to be all keyword args on model)

  • bring in reversible blocks from revnets for 3d unet, to lessen memory burden

  • add ability to only train super-resolution network

  • read dpm-solver see if it is applicable to continuous time gaussian diffusion

  • allow for conditioning video frames with arbitrary absolute times (calculate RPE during temporal attention)

  • accommodate dream booth fine tuning

  • add textual inversion

  • cleanup self conditioning to be extracted at imagen instantiation

  • make sure eventual dreambooth works with imagen-video

  • add framerate conditioning for video diffusion

  • make sure one can simultaneously condition on video frames as a prompt, as well as some conditioning image across all frames

  • test and add distillation technique from consistency models

Citations

@inproceedings{Saharia2022PhotorealisticTD,
    title   = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
    author  = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
    year    = {2022}
}
@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Sankararaman2022BayesFormerTW,
    title   = {BayesFormer: Transformer with Uncertainty Estimation},
    author  = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
    year    = {2022}
}
@article{So2021PrimerSF,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2109.08668}
}
@misc{cao2020global,
    title   = {Global Context Networks},
    author  = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year    = {2020},
    eprint  = {2012.13375},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Karras2022ElucidatingTD,
    title   = {Elucidating the Design Space of Diffusion-Based Generative Models},
    author  = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.00364}
}
@inproceedings{NEURIPS2020_4c5bcfec,
    author      = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
    booktitle   = {Advances in Neural Information Processing Systems},
    editor      = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
    pages       = {6840--6851},
    publisher   = {Curran Associates, Inc.},
    title       = {Denoising Diffusion Probabilistic Models},
    url         = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
    volume      = {33},
    year        = {2020}
}
@article{Lugmayr2022RePaintIU,
    title   = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
    author  = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2201.09865}
}
@misc{ho2022video,
    title   = {Video Diffusion Models},
    author  = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
    year    = {2022},
    eprint  = {2204.03458},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{chen2022analog,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
    year    = {2022},
    eprint  = {2208.04202},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{Singer2022,
    author  = {Uriel Singer},
    url     = {https://makeavideo.studio/Make-A-Video.pdf}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@article{Ho2022ImagenVH,
    title   = {Imagen Video: High Definition Video Generation with Diffusion Models},
    author  = {Jonathan Ho and William Chan and Chitwan Saharia and Jay Whang and Ruiqi Gao and Alexey A. Gritsenko and Diederik P. Kingma and Ben Poole and Mohammad Norouzi and David J. Fleet and Tim Salimans},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.02303}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@article{Zhang2021TokenST,
    title   = {Token Shift Transformer for Video Classification},
    author  = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
    journal = {Proceedings of the 29th ACM International Conference on Multimedia},
    year    = {2021}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}

imagen-pytorch's People

Contributors

animebing, birch-san, deepglugs, ezhang7423, gauenk, gowdygamble, haukened, jacobwjs, jorgemcgomes, lucidrains, netruk44, nodja, pacocp, progamergov, ryanrussell, semitrivial, thefusion21, vfragoso, wheest, windsorwho


imagen-pytorch's Issues

In sample scripts, add (commented out?) code to actually display the images

When I run the first sample script [with all .cuda()'s stripped out, since I don't have a CUDA GPU; that shouldn't stop the model from working though, right?], the following shows me a bunch of apparently random noise instead of "a whale breaching from afar":

from PIL import Image
Image.fromarray(images[0].numpy(), mode='RGB').show()

Thanks for your patience with us newbies :)

Training with the COCO dataset resulted in noise

I followed the colab notebook to train the model using the MS-COCO 2014 dataset, and it resulted in noise. I first trained it for 2 epochs for each unet as in the sample; the result was noise. Then I trained it for 15 epochs for each unet, and the result was even noisier output.

Result after 2 epochs when prompted with "house": (image omitted)

Result after 15 epochs when prompted with "house": (image omitted)

Trained on windows PC, RTX 2080 Ti

distributed training

Hi,

Does this repo support any distributed training framework like Pytorch's DDP or Huggingface's Accelerator like DALLE-Pytorch?

Pretrained imagen-pytorch

Hi! Thanks for creating this project!

I was wondering if there's a pretrained version of this architecture so that I can readily use it for generation without having to worry about finding a dataset and the compute to train the whole thing.

Thank you!
David

Noise in output

Using the latest master, I'm noticing a big improvement over 0.0.60. The output from the upscaling unet isn't nearly as "swirly", but I am noticing red or green bits of noise on the output images:

(Top row unet1, second row unet2; image omitted)

I added a torch.clamp(0, 1.0) after the image is created in Imagen.sample(), but that didn't seem to help. Any ideas where the noise is coming from?

CUDA out of memory

RuntimeError: CUDA out of memory. Tried to allocate 258.00 MiB (GPU 0; 2.00 GiB total capacity; 1.01 GiB already allocated; 55.91 MiB free; 1.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Anyone know how to fix?

Choice of T5-Large

I was reading the paper, and in Figure A.6 (p. 23) they compare many different text encoder models (figure reproduced from https://arxiv.org/pdf/2205.11487.pdf).

Although T5-Base isn't mentioned anywhere else in the paper, according to this data it appears that T5-Base offers performance similar to T5-Large (at least with the 300M parameter diffusion model they use for this comparison).
Given the large performance difference between Small and Base, and the fact that T5-Base is ~3x smaller than Large but appears to perform similarly, wouldn't Base be the most sensible starting point?

I've noticed that this repo currently only offers Small and Large as options.

Can't see generated image in colab

Hi,

First, great work!
I've tried to create a colab; I copy-pasted your code.
Everything runs well, but it seems it doesn't display the generated image at the end.
Am I missing something?

thanks


Training on SVHN Dataset

Hello, thanks for coming up with the code. I was trying to train on the SVHN dataset (http://ufldl.stanford.edu/housenumbers/), with images of numbers and a txt file containing the number itself. Even after 50 epochs of training, the result I get is just random noise. So my questions are:

  1. Is it possible to generate samples of the SVHN dataset through this kind of training?
  2. If yes, do you think I need to add some additional information to this model?

Thanks

About the implementation of Unet?

Hi there, thank you for releasing your great work.

As stated in the paper, the 64x64 base unet has 2B parameters, while your implementation of BaseUnet64x64 has just 1.3B parameters. I am not sure if this is due to other hyperparameters or the network implementation.

Do you have any idea about this? Thanks.

CUDA out of memory with `max_batch_size=1` using unconditional image-to-image

Based on the README usage instructions, except with max_batch_size=1 running on Windows:

import torch
from imagen_pytorch import Imagen, ImagenTrainer, SRUnet256, Unet

# unets for unconditional imagen

unet1 = Unet(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    use_linear_attn=True,
)

unet2 = SRUnet256(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    condition_on_text=False,  # this must be set to False for unconditional Imagen
    unets=(unet1, unet2),
    image_sizes=(64, 128),
    timesteps=1000,
)

trainer = ImagenTrainer(imagen).cuda()

# now get a ton of images and feed it through the Imagen trainer

training_images = torch.randn(4, 3, 256, 256).cuda()

# train each unet in concert, or separately (recommended) to completion

for u in (1, 2):
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
    trainer.update(unet_number=u)

# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

images = trainer.sample(batch_size=16)  # (16, 3, 128, 128)

The OOM error occurs during the SRUnet (set a breakpoint and checked)

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 6.00 GiB total capacity; 4.26 GiB already allocated; 0 bytes free; 4.31 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\imagen_pytorch\trainer.py", line 508, in forward
    self.scale(loss, unet_number = unet_number).backward()
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\imagen_pytorch\trainer.py", line 98, in inner
    out = fn(model, *args, **kwargs)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sterg\Documents\GitHub\sparks-baird\xtal2png\scripts\imagen_pytorch_example.py", line 41, in <module>
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 268, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,

I'm using an NVIDIA GeForce RTX 2060:

GPU Architecture: Turing
RTX-OPS: 37T
Boost Clock: 1680 MHz
Frame Buffer: 6GB GDDR6
Memory Speed: 14 Gbps

See also #12

Pretrained models

Will some pretrained models / an open dataset be available to start training from?

Prerequisites section in README

I think it's better to have a prerequisites section for software, and minimum hardware requirements in README.

It'll take a lot of time for people - who are new to both Python & AI - to install the required modules (python2, python3, einops, CUDA, ...) to be able to run the example code.

A random output

When I follow the usage instructions, I get an output image that looks like noise. Is something wrong in my code?

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

def transform_convert(img_tensor, transform):
    """
    param img_tensor: tensor
    param transforms: torchvision.transforms
    """
    if 'Normalize' in str(transform):
        normal_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform.transforms))
        mean = torch.tensor(normal_transform[0].mean, dtype=img_tensor.dtype, device=img_tensor.device)
        std = torch.tensor(normal_transform[0].std, dtype=img_tensor.dtype, device=img_tensor.device)
        img_tensor.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_tensor = img_tensor.transpose(0, 2).transpose(0, 1)  # C x H x W  ---> H x W x C

    if 'ToTensor' in str(transform) or img_tensor.max() < 1:
        img_tensor = img_tensor.detach().numpy() * 255

    if isinstance(img_tensor, torch.Tensor):
        img_tensor = img_tensor.numpy()

    if img_tensor.shape[2] == 3:
        img = Image.fromarray(img_tensor.astype('uint8')).convert('RGB')
    elif img_tensor.shape[2] == 1:
        img = Image.fromarray(img_tensor.astype('uint8')).squeeze()
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_tensor.shape[2]))

    return img
 
ToTensor_transform = transforms.Compose([transforms.ToTensor()])
img = transform_convert(images[0].cpu(), ToTensor_transform)

plt.imshow(img)
plt.savefig('./test_out.png')

Clarity about losses

Hi Guys,

Thanks for the amazing work.
How many losses do we have overall?

  1. MSE( BaseModel(x_noisy), x)
  2. MSE( SuperResModel1(x_noisy), x) # 64 ----> 256
  3. MSE( SuperResModel2(x_noisy), x) # 256 ----> 1024

So is it mean(L1 + L2 + L3)?

Or do we update based on each loss separately? I am a little confused.

Random noise outputted when sampling after training

Hi!

First of all, thank you very much for the cool implementation, there are not many diffusion model implementations out there.

I have been trying the model with a biomedical dataset without conditioning on anything. However, while the model trains appropriately (both unet losses go down) and, if I visualize the training predictions, the model is able to predict x_start correctly, the sampling method only outputs random noise.
I have been debugging the code, and in the p_sample_loop function the model returns random noise images all the time, even though the p_sample function seems to work perfectly.

Has anyone encountered this issue?

Thanks!

output tensors contain only noise - usage documentation incomplete?

I ran the code as specified in the "usage" part of the documentation.

Since the final result seems to be stored in the images variable, I tried to plot it like this:

img = images[1].permute(1,2,0)
plt.imshow(img.cpu().numpy())

However, the result was only noise. Also, the resolution was 256 squared instead of 1024 squared. Is there something missing from the demo?

Edit: oh, it's not a pretrained model. Sorry for the confusion. I closed the issue.

Source image + text prompt

Hi!

Thank you for the work on this!

As I am new and still learning, I was wondering if it's possible to generate images given a source image and a text prompt? I.e. upload an empty picture of my room and a prompt for "show me what alternative interior designs could look like"

Saving Images

I'm not certain I am saving the tensor to images correctly.

All I seem to get is noise - albeit it does have some structure to it. I tried 10,000 time steps and still just noise.

from torchvision.utils import save_image
img1 = images[0]  # torch.Size([3, 256, 256])
save_image(img1, 'img1.png')

I’ve tried other ways such as:

from torchvision import transforms
im = transforms.ToPILImage()(images[0]).convert("RGB")
im.save('img1.png')

But I get results like this regardless (image of noise omitted).

Poor sampling quality in upscaler Unets

I've seen good quality results with the upscaling Unets in your DALLE2 repo but have been having trouble getting similar ones with the Imagen ones over the same training period.

After reviewing the code and the Imagen paper I wonder if this is the problem:

lowres_noise_times = noise_scheduler.get_times(batch_size, lowres_sample_noise_level, device = device)

return torch.full((batch_size,), int(self.num_timesteps * noise_level), device = device, dtype = torch.long)

I think that this should be (1.0 - lowres_sample_noise_level)? I see that you do your augmentation by sampling at a specific timestep based on the overall number of timesteps, but the default of 0.2 would sample at time 200 - which is actually closer to 0.8 augmentation if you were using a linear scale.

As a workaround I am trying to pass 0.8 into my sample function but I'm not sure this is enough to fully address the issue. Maybe it just takes longer to train since the training operates on a full aug level of 1.0 to 0.0 like the paper does, but I'll keep the unet training for now.

Usage guide

Dear people,

this implementation looks very interesting and I managed to run it without errors. But now I'm asking myself how to use it. Can anyone point me in the right direction?

Greetings,
Chrizzo

Installation Question

Hi there! Love the project just needing a little help with installation of a package.
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME ModuleNotFoundError: No module named 'imagen_pytorch.t5'; 'imagen_pytorch' is not a package

Sorry if I'm being stupid or if this is the wrong place to ask.

transformers T5Large version

I am seeing a warning about "Some weights of the model checkpoint at t5-large were not used when initializing T5EncoderModel". I am using the latest transformers version on PyPI (4.19.2)... can this be ignored?

Torch object has no attribute pi, simple find and replace fixes it

>>> import imagen_pytorch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/repos/MeshGen/lucidrains_imagen/imagen-pytorch/imagen_pytorch/__init__.py", line 1, in <module>
    from imagen_pytorch.imagen_pytorch import Imagen, Unet
  File "~/repos/MeshGen/lucidrains_imagen/imagen-pytorch/imagen_pytorch/imagen_pytorch.py", line 293, in <module>
    def alpha_cosine_log_snr(t, s: float = 0.008):
  File "~/.local/lib/python3.8/site-packages/torch/jit/_script.py", line 1143, in script
    fn = torch._C._jit_script_compile(
RuntimeError: 
object has no attribute pi:
  File "~/repos/MeshGen/lucidrains_imagen/imagen-pytorch/imagen_pytorch/imagen_pytorch.py", line 294
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** -2) - 1)
                                               ~~~~~~~~ <--- HERE

Opening imagen_pytorch.py and doing a simple find and replace of torch.pi to math.pi fixed it on my end.

Hope it's not an internal issue on my end, but I figured I'd share.

Question about progress

Is there a ballpark date to when we can start generating images using this?

I'd love to know! Sorry if there is something posted that I totally missed that explains this!

:)

text_embeds size?

In the readme line 59, text embeds appear to be shape (ntexts, 256, 768). But on line 146 they are shape (ntexts,256,1024). And the output of t5_encode_text appears to be (ntexts,ncharacters,768).

Which is correct? It seems like I can just tile along dimension 1 to pad out to 256, but what is the correct value for dimension 2?

Colab notebook, also nan issue

I just made a simple colab training notebook. It seems to work, but it has issues. I think, at this moment, there is no reason to do full training while the code is not finished yet.
https://colab.research.google.com/drive/1zVFcWU7REDmQXKs5gBNAu7kdPEAgG5wy?usp=sharing

Also, when I played with batch size (if I implemented it right for training, of course), I noticed that a batch size of more than 1 causes 'nan' losses during training. It can be avoided with torch.nn.utils.clip_grad_norm_, I suppose, or by changing the optimizer to AdamW.
One more bug: trainer.sample can sample only once; after that there is a "RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same", which can be fixed by reinitialization.

Dataset of only 60 images from LAION, just to test:
https://drive.google.com/file/d/1hiZvegBcunQAR5NrpXXAxnjVMSa7IN4z/view?usp=sharing

And one last thing: with losses of about 0.3, generated images are still noise.

Pre-trained model parameters?

I understand that currently the only way is to train my own model with appropriate datasets. Alternatively, will pre-trained model parameters matching your classes (e.g. on HuggingFace) be released soon?

And thanks for the great work!

How to use?

Could you please tell us how to use this project in detail in the readme.md? I could not run the code correctly.

why perceiver attention?

it seems like the imagen paper doesn't make any references to Perceiver IO - why do you think that's what their "attention pool" mechanism was referring to vs eg a 1d version of something like AttentionPool2D in CLIP?

Documentation

Is there some related documentation for this code? Or a guide to what is what? Maybe something more general you can point me to; lots of these models look similar, but if you have never worked with this you have no idea what's going on. :D

How many timesteps are needed at this moment to get the result?

I trained the model to a loss of about 0.06 on the training set, and here are some samples with the prompt "Milky Way in the shape of a dog" (images omitted): 100, 1000, and 10000 timesteps without EMA, and 100, 1000, and 10000 timesteps with EMA.

I know the losses are too high for good results, but I expected at least something, and the 10k-step samples look a bit more complete. The time spent is also too long: 5 minutes for 1k steps and 50 minutes for 10k steps on Kaggle's T4.
Also, I trained this model without EMA, disabling or enabling it when initializing the trainer to load the checkpoint for inference, like

trainer = ImagenTrainer(imagen, use_ema = False)
trainer.load('./checkpoint/checkpoint.pt')

My GPU quota ended today, and I can't continue training with EMA to see how it affects sampling with and without EMA. Will it really improve training speed (I need to wait about 15 epochs before losses decrease by 0.001) and produce results that are not just noise with EMA on?
Link to this checkpoint https://drive.google.com/file/d/1pq2OVRJuA2szQc9WMd8M7mkOtbtUCkv4/view?usp=sharing

learned variance

Does anyone have an opinion on whether learned variance was used in Imagen? It would greatly simplify the project if I could just remove it.

Unet not respecting text conditioning

I've done a few test runs with your Imagen and DALLE2 models on some smaller scale models (~192 dim) on a 10M dataset across 30 epochs.

I've found that the DALLE2 model guided by CLIP image and text embeds has better alignment than the Imagen model, and that the Imagen model never seems to really get any sort of alignment, almost like it's stuck doing unconditional generation.

Looking through the code I wonder if it's something to do with how text embeds are handled in general and that the DALLE2 model just gets away with using the Image embeddings.

I wonder if there might be a problem with this:

text_tokens = torch.where(

The null text embed is filled with random values per your initialization:

self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))

...and these values persist with the rest of the model after construction time.

The model seems to be allowed to look at the entire space of the embedding, however:

c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

There's no way for the model to understand that the random data is different from the rest of the embedding as far as I can tell since the masks are only used to decide whether the random data is used instead of the actual token embedding data.

I'm going to try a run without the token_mask, padding to max length to avoid picking up the null_text_embed, but for pad-to-longest I'm not sure this is working as intended.
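For reference, a minimal, self-contained sketch of how a mask typically selects between real token embeddings and a learned null embedding for classifier-free guidance; all names and shapes here are illustrative, not the repository's exact code.

import torch
from torch import nn

batch, seq_len, cond_dim = 4, 256, 512

text_embeds = torch.randn(batch, seq_len, cond_dim)            # real text encoder embeddings
text_mask   = torch.ones(batch, seq_len, dtype = torch.bool)   # True where a real token exists
null_text_embed = nn.Parameter(torch.randn(1, seq_len, cond_dim))

# drop the text condition for ~10% of samples (classifier-free guidance training)
cond_drop_prob = 0.1
keep = torch.rand(batch, 1, 1) > cond_drop_prob

# padding positions, or whole samples that are dropped, fall back to the null embedding
use_real = text_mask.unsqueeze(-1) & keep
text_tokens = torch.where(use_real, text_embeds, null_text_embed)

Under this scheme the null embedding only ever appears in masked-out positions or fully dropped samples, so in principle the network can learn to treat it as "no text" rather than as informative content.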

how to save the image

Dear Author,
Thanks for your amazing work. I have a question about how to save the output as an image. I find the output is between 0 and 1, so I use the code below to save it. Looking forward to your answer, thanks.

images = imagen.sample(texts = [
    'A woman cutting a large white sheet cake.',
    'A woman wearing a hair net cutting a large sheet cake.',
    'A woman wearing a net on her head cutting a cake.'
], cond_scale = 2.)

import cv2
import numpy as np

num_imgs = len(images)
for i in range(num_imgs):
    gen_img = images[i].permute(1, 2, 0).cpu().numpy()                    # CHW -> HWC, values in [0, 1]
    gen_img = cv2.cvtColor(np.uint8(gen_img * 255), cv2.COLOR_RGB2BGR)    # scale to [0, 255]; OpenCV expects BGR
    cv2.imwrite('%d.jpg' % i, gen_img)
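If the sampled tensors are already in [0, 1], torchvision's save_image may be a simpler alternative that also sidesteps OpenCV's BGR channel order (a sketch, assuming images is the tensor returned by imagen.sample above):

from torchvision.utils import save_image

for i, img in enumerate(images):
    save_image(img, f'{i}.jpg')        # expects a CHW tensor with values in [0, 1]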

Implement about `Conditional Augmentation`

Hi Phil,
I noticed that Imagen applies noise conditioning augmentation to both super-res models.

we corrupt the low-resolution image with the augmentation (corresponding to aug_level), and condition the diffusion model on aug_level. During training, aug_level is chosen randomly, while during inference, we sweep over its different values to find the best sample quality. In our case, we use Gaussian noise as a form of augmentation, and apply variance preserving Gaussian noise augmentation resembling the forward process used in diffusion models (Appendix A). The augmentation level is specified using aug_level ∈ [0, 1].

Based on my understanding, in the training phase they first apply forward diffusion (with noise scale aug_level) to the low-res image x_lr, and then feed x_lr, z_t (the noisy high-res latent at step t), t, and aug_level to the Unet and optimize the loss. In the sampling phase they do the same thing, like the following pseudo-code:
[pseudo-code screenshot omitted]

My question is: how should aug_level be fed into the Unet? The noise-augmentation conditioning idea was first proposed in Cascaded Diffusion Models for High Fidelity Image Generation (CDM) (https://arxiv.org/abs/2106.15282), where they suggest using another time embedding for the augmentation level (referred to as time s).
[screenshot from the CDM paper omitted]

Recent papers rarely mention this trick apart from Imagen. I just wonder what you think about it, and whether it is necessary to add it to the current SuperResUnet model (a rough sketch follows below).

Best,
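For reference, a minimal, self-contained sketch of what CDM-style noise conditioning augmentation could look like; the cosine schedule, tensor shapes, and the idea of routing aug_level through a second time-style embedding are placeholders for illustration, not this repository's actual schedule or API.

import torch

def noise_augment(x_lr, aug_level):
    # variance-preserving, forward-diffusion-style corruption of the low-res conditioning
    # image, with aug_level in [0, 1] acting like a continuous "time"; a simple cosine
    # schedule stands in for the real alpha-bar schedule
    alpha_bar = (torch.cos(aug_level * torch.pi / 2) ** 2).view(-1, 1, 1, 1)
    noise = torch.randn_like(x_lr)
    return alpha_bar.sqrt() * x_lr + (1 - alpha_bar).sqrt() * noise

# training: pick a random aug_level per sample, corrupt the low-res image, and pass
# aug_level to the super-res unet through a second time-style embedding
x_lr = torch.randn(4, 3, 64, 64)
aug_level = torch.rand(4)
x_lr_noised = noise_augment(x_lr, aug_level)
# at inference, aug_level would be fixed (or swept over) rather than sampled randomly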

Manual device placement and distributed frameworks

Hi just digging through the code, nice work once again.

I notice lots of manual device placement, which is fairly well abstracted, but I'm not sure how it plays with the likes of PyTorch Lightning or Hugging Face's Accelerate.

Have there been (from anyone) any tests doing distributed training using one of the above frameworks (not DDP)?

Looking for feedback before getting started if anyone has already headed down that path :)
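For reference, a minimal, self-contained sketch of the usual Hugging Face Accelerate pattern, with a toy model and data rather than this repository's classes; whether it composes cleanly with the manual device placement here is exactly the open question.

import torch
from torch import nn
from accelerate import Accelerator

# toy stand-ins; in practice these would be the Imagen unet, its optimizer, and a real dataloader
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
dataloader = torch.utils.data.DataLoader(torch.randn(32, 8), batch_size = 4)

accelerator = Accelerator()
# let Accelerate handle device placement and (optionally) distributed wrapping
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    loss = model(batch).pow(2).mean()
    accelerator.backward(loss)      # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()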

TypeError: __init__() got multiple values for argument 'self'

Traceback (most recent call last):
  File "/import/home/w/.pycharm_helpers/pydev/pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/import/home/w/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/import/home/w/diffusion/imagen-pytorch/example_1.py", line 35, in <module>
    cond_drop_prob=0.5
  File "/import/home/w/diffusion/imagen-pytorch/imagen_pytorch/imagen_pytorch.py", line 1249, in __init__
    channels_out=unet_channels_out
  File "/import/home/w/diffusion/imagen-pytorch/imagen_pytorch/imagen_pytorch.py", line 998, in cast_model_parameters
    return self.__class__(**{**self._locals, **updated_kwargs})
TypeError: __init__() got multiple values for argument 'self'

This error occurs when running in debug mode.
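For context, the traceback suggests that self._locals still contains the 'self' entry when it is splatted back into self.__class__(**...), so __init__ receives 'self' both positionally and as a keyword. A minimal standalone reproduction with an illustrative class (not the repository's code):

class Widget:
    def __init__(self, dim, depth = 2):
        # locals() captured here still includes 'self'
        self._locals = locals()
        self._locals.pop('self', None)   # without this pop, the re-construction below raises
                                         # "TypeError: __init__() got multiple values for argument 'self'"
        self.dim, self.depth = dim, depth

    def cast(self, **updated_kwargs):
        return self.__class__(**{**self._locals, **updated_kwargs})

w = Widget(32).cast(depth = 4)           # works once 'self' has been removed from the saved locals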

Possible Breaking Bug

assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

I think this should be image instead of images, since that is the variable name used elsewhere within this code block. The undefined-variable error went away when I changed the variable to image, although conceptually I think all instances of image in this code block should be renamed to images instead.

Can we finetune on a trained model?

I understand that there is no publicly available trained model yet;
however, once there is a model trained on the LAION dataset,
will we be able to fine-tune it on our own specific text-image pair dataset?
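For reference, a minimal sketch of what fine-tuning from a released checkpoint might look like, assuming the ImagenTrainer call/update API and an imagen instance built as usual; the checkpoint path and the data here are placeholders.

import torch
from imagen_pytorch import ImagenTrainer

trainer = ImagenTrainer(imagen)
trainer.load('./pretrained/imagen-laion.pt')     # hypothetical checkpoint path

# placeholder fine-tuning data; swap in your own text-image pairs
my_images = torch.randn(4, 3, 256, 256).cuda()
my_texts  = ['an ultrasound scan of a healthy heart'] * 4

loss = trainer(my_images, texts = my_texts, unet_number = 1)
trainer.update(unet_number = 1)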

Training of the cascading DDPM

Hi,

Thank you for your excellent implementation.

Just a small but important question about the training details. To train the cascaded Imagen model, do I need to train the low-resolution model until it converges and then use it to train the high-resolution model, or should I train them simultaneously on each mini-batch?

To make my question clearer: by the first way of training, I mean first training the text-to-64x64 low-resolution model on the whole dataset for enough epochs until it converges, and then training the high-resolution model on top of that well-trained model.

By simultaneous training, I mean that for each mini-batch I first backpropagate the loss of the low-resolution model, then the loss of the high-resolution model, and repeat these steps for every mini-batch of the dataset.
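For reference, since each super-res stage is trained conditioned on a downsampled ground-truth image rather than on the base unet's samples, the stages can be trained independently and in either order. A minimal sketch of the sequential strategy, assuming the ImagenTrainer call/update API with a unet_number argument and placeholder dataloaders; alternating unet_number per mini-batch would give the simultaneous variant.

from imagen_pytorch import ImagenTrainer

trainer = ImagenTrainer(imagen)                       # `imagen` holding (unet1, unet2)

# stage 1: train only the base 64x64 unet until it converges
for images, text_embeds in base_dataloader:           # placeholder dataloader
    loss = trainer(images, text_embeds = text_embeds, unet_number = 1)
    trainer.update(unet_number = 1)

# stage 2: train only the super-res unet; only the chosen unet is run and updated
for images, text_embeds in sr_dataloader:             # placeholder dataloader
    loss = trainer(images, text_embeds = text_embeds, unet_number = 2)
    trainer.update(unet_number = 2)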

Error when setting `return_all_unet_outputs=True`

Hi,

Thanks for your implementation! I tried setting return_all_unet_outputs=True with the following call but got an error:

sample_images = trainer.sample(
                    batch_size = 4, max_batch_size = 4,
                    return_all_unet_outputs = True,
                )

Error message:

  File "/home/ubuntu/hychiang/imagen/uncond_image_gen.py", line 156, in <module>
    run()
  File "/home/ubuntu/hychiang/imagen/uncond_image_gen.py", line 142, in run
    sample_images = trainer.sample(
  File "/home/ubuntu/anaconda3/envs/imagen-pytorch-master/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/hychiang/imagen-pytorch-master/imagen_pytorch/trainer.py", line 98, in inner
    out = fn(model, *args, **kwargs)
  File "/home/ubuntu/hychiang/imagen-pytorch-master/imagen_pytorch/trainer.py", line 269, in inner
    return torch.cat(outputs, dim = 0)
TypeError: expected Tensor as element 0 in argument 0, but got list
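A possible workaround to try, as a sketch only: sample from the wrapped model directly so the trainer's torch.cat over sub-batches is skipped. This assumes the trainer exposes the wrapped model as trainer.imagen and that Imagen.sample accepts the same flag.

outputs = trainer.imagen.sample(
    batch_size = 4,
    return_all_unet_outputs = True,
)
# `outputs` should be a list with one tensor per unet in the cascade
base_out, superres_out = outputs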
