lucidrains / muse-maskgit-pytorch
Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
License: MIT License
Hi.
For example, I know that in the first script I need to change folder = '/path/to/images' by inserting the path to my own folder of images.
Beyond that, I do not know what else might be necessary.
I am also not clear whether my images need to be a specific shape, size, etc.
Lastly, once everything above is resolved, will I be running all 4 scripts found under Usage consecutively?
(For what it's worth, when I run the first script with only folder changed, I get the following output:)
680 training samples found at /content/gdrive/My Drive/Github/Generating-Synthetic-Handwritten-Historical-Documents/PyTorch-CycleGAN/datasets/text2illuminated/train/B
training with dataset of 646 samples and validating with randomly splitted 34 samples
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%
528M/528M [00:03<00:00, 229MB/s]
0: vae loss: 0.9188588485121727 - discr loss: 10.198074519634247
0: saving to results
0: saving model to results
1: vae loss: -0.18518024682998657 - discr loss: 11.391329526901245
2: vae loss: -2.3277476727962494 - discr loss: 12.582469940185547
---------------------------------------------------------------------------
OSError Traceback (most recent call last)
<ipython-input-6-4b98abcfd671> in <module>
19 ).cuda()
20
---> 21 trainer.train()
16 frames
/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in train(self, log_fn)
364
365 while self.steps < self.num_train_steps:
--> 366 logs = self.train_step()
367 log_fn(logs)
368
/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in train_step(self)
265
266 for _ in range(self.grad_accum_every):
--> 267 img = next(self.dl_iter)
268 img = img.to(device)
269
/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in cycle(dl)
38 def cycle(dl):
39 while True:
---> 40 for data in dl:
41 yield data
42
/usr/local/lib/python3.8/dist-packages/accelerate/data_loader.py in __iter__(self)
381 if self.device is not None:
382 current_batch = send_to_device(current_batch, self.device)
--> 383 next_batch = next(dataloader_iter)
384 yield current_batch
385 current_batch = next_batch
/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py in __next__(self)
626 # TODO(https://github.com/pytorch/pytorch/issues/76750)
627 self._reset() # type: ignore[call-arg]
--> 628 data = self._next_data()
629 self._num_yielded += 1
630 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
669 def _next_data(self):
670 index = self._next_index() # may raise StopIteration
--> 671 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
672 if self._pin_memory:
673 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
56 data = self.dataset.__getitems__(possibly_batched_index)
57 else:
---> 58 data = [self.dataset[idx] for idx in possibly_batched_index]
59 else:
60 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
56 data = self.dataset.__getitems__(possibly_batched_index)
57 else:
---> 58 data = [self.dataset[idx] for idx in possibly_batched_index]
59 else:
60 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
293 if isinstance(idx, list):
294 return self.dataset[[self.indices[i] for i in idx]]
--> 295 return self.dataset[self.indices[idx]]
296
297 def __len__(self):
/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in __getitem__(self, index)
92 path = self.paths[index]
93 img = Image.open(path)
---> 94 return self.transform(img)
95
96 # main trainer class
/usr/local/lib/python3.8/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
93 def __call__(self, img):
94 for t in self.transforms:
---> 95 img = t(img)
96 return img
97
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.8/dist-packages/torchvision/transforms/transforms.py in forward(self, img)
344 PIL Image or Tensor: Rescaled image.
345 """
--> 346 return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
347
348 def __repr__(self) -> str:
/usr/local/lib/python3.8/dist-packages/torchvision/transforms/functional.py in resize(img, size, interpolation, max_size, antialias)
472 warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
473 pil_interpolation = pil_modes_mapping[interpolation]
--> 474 return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
475
476 return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
/usr/local/lib/python3.8/dist-packages/torchvision/transforms/functional_pil.py in resize(img, size, interpolation)
250 raise TypeError(f"Got inappropriate size arg: {size}")
251
--> 252 return img.resize(tuple(size[::-1]), interpolation)
253
254
/usr/local/lib/python3.8/dist-packages/PIL/Image.py in resize(self, size, resample, box, reducing_gap)
1884 return im.convert(self.mode)
1885
-> 1886 self.load()
1887
1888 if reducing_gap is not None and resample != NEAREST:
/usr/local/lib/python3.8/dist-packages/PIL/ImageFile.py in load(self)
243 break
244 else:
--> 245 raise OSError(
246 "image file is truncated "
247 "(%d bytes not processed)" % len(b)
OSError: image file is truncated (15 bytes not processed)
Thank you.
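The last frame is PIL's "image file is truncated" error. This suggests that at least one file in the training folder is incomplete, for example from an interrupted download or copy. Below is a minimal sketch that lists such files so they can be removed or re-exported. It assumes the dataset picks up .jpg/.jpeg/.png files, and find_broken_images is a hypothetical helper, not part of this repo:

```python
from pathlib import Path

from PIL import Image

# Assumption: the extensions the trainer's image dataset looks for.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def find_broken_images(folder):
    """Return sorted paths of image files under `folder` that PIL cannot fully decode."""
    broken = []
    for path in sorted(Path(folder).rglob("*")):
        if path.suffix.lower() not in IMAGE_EXTS:
            continue
        try:
            with Image.open(path) as img:
                # Image.open only parses the header; load() decodes the pixel data,
                # which is where "image file is truncated" is raised.
                img.load()
        except (OSError, SyntaxError):
            broken.append(path)
    return broken
```

Usage: print(find_broken_images('/path/to/images')). Another option is to set PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True before training. PIL then decodes what it can instead of raising, at the cost of silently training on partially decoded images.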
Hi,
Thanks a lot for the implementation.
While running the base MaskGIT training according to the usage example:
loss = base_maskgit(
images,
texts = texts
)
I noticed that the training cross-entropy loss drops to very small values quite fast. I suspect this is related to the following line of code:
where ids, rather than the masked x, is passed to the transformer's forward function.
Am I missing something? Thanks :-)
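For reference, here is a framework-free sketch of how MaskGIT-style training inputs are usually built. A random subset of token ids is replaced by a mask id, the transformer receives that masked sequence, and the loss is taken only on the masked positions. If the unmasked ids reached the transformer, it could simply copy its input, which would be consistent with the loss collapsing as described. mask_tokens and IGNORE_INDEX are hypothetical names. In PyTorch the same idea is typically written with masked_fill and the ignore_index argument of F.cross_entropy.

```python
import math
import random

IGNORE_INDEX = -1  # hypothetical label value that the loss skips


def mask_tokens(ids, mask_ratio, mask_id, rng=random):
    """Replace a `mask_ratio` fraction of `ids` with `mask_id`.

    Returns (inputs, labels): `inputs` is what the transformer should see, and
    `labels` keeps the original token only where it was masked, so the loss is
    computed on tokens the model could not see.
    """
    n = len(ids)
    num_masked = min(n, max(1, math.ceil(n * mask_ratio)))
    masked = set(rng.sample(range(n), num_masked))
    inputs = [mask_id if i in masked else tok for i, tok in enumerate(ids)]
    labels = [tok if i in masked else IGNORE_INDEX for i, tok in enumerate(ids)]
    return inputs, labels
```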
The problems listed below make distributed training impossible:
1) self.vae.discr
2) ema_vae is not wrapped by accelerate.prepare (while self.valid_dl is wrapped)
I had a question about matching the model configurations to the paper.
In the paper, f = 8, so the super-res model's latent map has a 64*64 resolution for an image size of 512*512.
So I tried a VAE with 3 layers to get a 64*64 resolution. However, I think this may differ from the paper's implementation details, where the VAE consists of 4 layers and 256 dims with a fine-tuned decoder, or 2 layers and 128 dims without fine-tuning the decoder.
How can I configure the VAE in code to match the paper's?
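The arithmetic behind the question can be sketched as follows. It assumes that each VAE layer is one stride-2 downsampling stage, which is the assumption the question makes. Under that assumption the downsampling factor is f = 2 ** layers, so f = 8 needs 3 layers and a 512*512 image gives a 64*64 latent map. If the paper's "layers" count something else, such as residual blocks per stage, this mapping does not hold. Both helpers are hypothetical:

```python
def downsample_layers_for(f):
    """Stride-2 stages needed for a downsampling factor f (must be a power of two)."""
    if f < 1 or f & (f - 1):
        raise ValueError("f must be a power of two")
    return f.bit_length() - 1


def latent_size(image_size, layers):
    """Side length of the latent map after `layers` stride-2 downsampling stages."""
    f = 2 ** layers
    if image_size % f:
        raise ValueError(f"image size {image_size} is not divisible by f = {f}")
    return image_size // f
```

For example, latent_size(512, downsample_layers_for(8)) gives 64.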
Since the top_k threshold is 0.9, will gumbel sampling always take the index with the maximum probability? I guess temperature has no effect here?
So, is the sampling here argmax only?
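Here is a sketch of the filtering-plus-sampling step being asked about. It rests on an assumption worth checking against the repo's code: that top_k keeps ceil((1 - thres) * num_tokens) logits, a convention used in several of lucidrains' repos. Under that convention, thres = 0.9 keeps the top 10% of the codebook (for example, 26 of 256 tokens), not just the maximum. The Gumbel noise and the temperature therefore still affect which token is drawn. The result reduces to argmax only when k is 1 or the temperature approaches 0. The function names are hypothetical:

```python
import math
import random


def top_k_count(num_tokens, thres=0.9):
    """Number of logits kept by a thres-style top-k filter: ceil((1 - thres) * num_tokens)."""
    return math.ceil((1 - thres) * num_tokens)


def gumbel_top_k_sample(logits, thres=0.9, temperature=1.0, rng=random):
    """Keep the top-k logits, then sample one index with the Gumbel-max trick."""
    k = top_k_count(len(logits), thres)
    kept = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

    def perturbed(i):
        u = rng.random()  # uniform in [0, 1)
        gumbel = -math.log(-math.log(u + 1e-20) + 1e-20)  # standard Gumbel noise
        return logits[i] / max(temperature, 1e-10) + gumbel

    # argmax of the perturbed scores is a sample from softmax(logits / temperature)
    # restricted to the kept tokens
    return max(kept, key=perturbed)
```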
Hi there, it would be nice to have TensorBoard support: log the loss and learning rate so we can track them in TensorBoard and get a better idea of how the model changes over time. It would also help us keep track of how many steps we have trained for, since it's really easy to lose track of the global step count unless everything is done in a single session. A global step counter that persists across sessions would be nice. This could be done with TensorBoard, or even more easily with a JSON file holding the training session's info that the next session can reuse. That would also make it easier to resume a session: right now we can load the VAE and continue from it, but for the trainer it is as if we were training from scratch every time.
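The JSON idea could be sketched as below. These are hypothetical helpers and a hypothetical file layout, not an existing feature of the trainer. Load the state at startup, increment global_step inside the training loop, and save it whenever a checkpoint is saved, so the step count survives across sessions. Writing to a temporary file and then renaming it keeps a crash mid-write from corrupting the previous state.

```python
import json
from pathlib import Path


def load_training_state(path):
    """Return the saved state, or a fresh one if no state file exists yet."""
    path = Path(path)
    if path.exists():
        return json.loads(path.read_text())
    return {"global_step": 0}


def save_training_state(path, state):
    """Write the state to a temp file first, then rename it over the old file."""
    path = Path(path)
    tmp = path.with_suffix(path.suffix + ".tmp")
    tmp.write_text(json.dumps(state, indent=2))
    tmp.replace(path)
```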
In Swapping Autoencoder [Taesung Park et al., 2020], the authors study the effect of patch size on FID and LPIPS: the larger the patch size, the more the geometry changes. It seems your model may produce even more surprising results, since it tackles a text-to-image task.
Hey @lucidrains,
Impressive that you already started a MUSE PyTorch implementation here! At Hugging Face, we also started thinking/designing a MUSE reproduction.
We are quite excited about an open-reproduction effort whose outcome is a checkpoint, competitive with MUSE, that is available to everybody. In our opinion, the advantages of MUSE over existing models are:
Can we help you in any way? For example, by starting to curate a dataset?
cc @patil-suraj
Hello!
I've read the MUSE paper, and I think the results are amazing!
So I want to try this model.
Can you release official pretrained weights for the VAE and MaskGIT?
Hello, I am interested in training a MaskGIT model without any text conditioning or the other parts specific to Muse. Is that possible?
pip install muse-maskgit-pytorch
Any suggestions on how to install this?
I tried to install it, but it says "command not found".
Hi @lucidrains, I would like to ask whether the VQGAN implemented in this repo can compress a 256 * 256 image to 16 * 16 while maintaining good reconstruction quality.
I haven't included a GAN in my VQVAE code, and compressing 256 * 256 images to 16 * 16 and then reconstructing them doesn't yield particularly good results. I'm wondering whether the code in this library can do better. (I've tried running the official code for the first stage of VQGAN, but the results don't seem as impressive as those presented in their paper.)
Additionally, have you tried replacing PatchGAN with StyleGAN in your VQGAN? As MaskGIT/MAGVIT mention, the StyleGAN discriminator might be more stable.
Thank you.
Hi,
Thanks for implementing and open-sourcing the code for this T2I model.
I ran the first snippet of the code where the objective was to train a VQGanVAE model.
After training the VQGanVAE model for 50K iterations, I trained the MaskGIT module. Because I was running into memory issues, I passed fewer images and texts to the MaskGIT training than to the first module's training.
Nevertheless, I passed 10 images and their corresponding texts to train the super-resolution MaskGIT and saved the images. The following are a few of the images I am getting.
My question is whether this is the correct process. Do I need to train on more images to get images that correspond to the text?
Thanks!
AssertionError while trying to train with base_maskgit_trainer() after finishing training with vae_trainer(). I trained the VAE for about 123k iterations. When I moved on to the base MaskGIT trainer, the following error occurred:
123000: vae loss: 1.0881760120391846 - discr loss: 1.993680715560913
training complete
Resuming VAE from: other/vae.23000.base.pt
Traceback (most recent call last):
File "muse_train.py", line 352, in <module>
base_maskgit_trainer()
File "muse_train.py", line 156, in base_maskgit_trainer
vae.load(args.resume_from.replace('.pt' , '.base.pt')) # you will want to load the exponentially moving averaged VAE
File "/home/eason/PyTorch/muse-pytorch/muse_maskgit_pytorch/vqgan_vae.py", line 408, in load
assert path.exists()
AssertionError
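The assertion fires because the path handed to vae.load does not exist. The log reports step 123000, while the script resumes from other/vae.23000.base.pt. The path may therefore be stale, or the '.pt' to '.base.pt' substitution may not match how the checkpoints were named when saved. Below is a hypothetical helper that picks the existing checkpoint with the highest step. It assumes names of the form vae.<step>.pt or vae.<step>.base.pt:

```python
import re
from pathlib import Path


def latest_checkpoint(folder, prefix="vae", suffix=".pt"):
    """Return the existing `<prefix>.<step><suffix>` file with the highest step, or None."""
    pattern = re.compile(rf"{re.escape(prefix)}\.(\d+){re.escape(suffix)}")
    best_step, best_path = -1, None
    for path in Path(folder).iterdir():
        match = pattern.fullmatch(path.name)
        # compare steps as integers, not strings
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

Comparing steps as integers matters: sorted as strings, vae.9000.pt would come after vae.123000.pt. The returned path (if not None) can then be passed to vae.load instead of a hard-coded one.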
Hi, any chance you'll upload any trained weights?
Really want to try this out. It looks amazing!
Would it be possible to implement negative prompting as described in section 2.7 of the paper? In the paper, the final tokens are calculated as
The only difference is that Muse uses masked tokens, while diffusion models use noised tokens.
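The issue text does not include the section 2.7 equation. This sketch therefore shows only the standard classifier-free guidance combination, with the negative prompt's logits taking the place of the unconditional logits: guided = neg + s * (cond - neg). This equals (1 + t) * cond - t * neg for s = 1 + t. Plain lists stand in for per-position logit tensors, and guided_logits is a hypothetical name:

```python
def guided_logits(cond_logits, neg_logits, guidance_scale):
    """Classifier-free guidance using negative-prompt logits in place of unconditional ones.

    Computes neg + scale * (cond - neg) elementwise, i.e. (1 + t) * cond - t * neg
    with t = scale - 1. A scale of 1 returns the conditional logits unchanged.
    """
    return [n + guidance_scale * (c - n) for c, n in zip(cond_logits, neg_logits)]
```

At each decoding step, the transformer would be run once with the prompt and once with the negative prompt, and tokens sampled from the guided logits.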
I installed the code and its dependencies. When I run from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer, it fails in muse_maskgit_pytorch when importing beartype:
/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/torch/onnx/_internal/_beartype.py:30: UserWarning: unhashable type: 'list'
warnings.warn(f"{e}")
Traceback (most recent call last):
File "/scratch/code/muse-maskgit-pytorch/./fni_main.py", line 3, in <module>
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
File "/scratch/code/muse-maskgit-pytorch/muse_maskgit_pytorch/__init__.py", line 2, in <module>
from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
File "/scratch/code/muse-maskgit-pytorch/muse_maskgit_pytorch/muse_maskgit_pytorch.py", line 16, in <module>
from beartype import beartype
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/__init__.py", line 57, in <module>
from beartype._decor.decormain import (
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_decor/decormain.py", line 23, in <module>
from beartype._conf.confcls import (
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_conf/confcls.py", line 23, in <module>
from beartype._cave._cavemap import NoneTypeOr
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_cave/_cavemap.py", line 33, in <module>
from beartype._util.hint.nonpep.utilnonpeptest import (
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/hint/nonpep/utilnonpeptest.py", line 21, in <module>
from beartype._util.cache.utilcachecall import callable_cached
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/cache/utilcachecall.py", line 32, in <module>
from beartype._util.func.arg.utilfuncargtest import (
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/func/arg/utilfuncargtest.py", line 17, in <module>
from beartype._util.func.utilfunccodeobj import get_func_codeobj
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/func/utilfunccodeobj.py", line 21, in <module>
from beartype._data.datatyping import (
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_data/datatyping.py", line 129, in <module>
BeartypeReturn = Union[BeartypeableT, BeartypeConfedDecorator]
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 243, in inner
return func(*args, **kwds)
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 316, in __getitem__
return self._getitem(self, parameters)
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 421, in Union
parameters = _remove_dups_flatten(parameters)
File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 215, in _remove_dups_flatten
all_params = set(params)
TypeError: unhashable type: 'list'
I googled "beartype unhashable type import error", but I haven't yet been able to find anything about this error. Has anyone else hit this problem?
Thank you for the quick implementation of the Muse code. A bit of documentation or a quick how-to would be highly appreciated.
In particular, a how-to on this:
How do I pass text embeddings instead of text?
Just like what can be done in https://github.com/lucidrains/imagen-pytorch
Thank You!
Hi,
I'm trying to test the example in the README, but I'm getting TypeError: 'NoneType' object is not callable when calling loss = base_maskgit(images, texts = texts) (I didn't change anything). I'm probably missing something obvious here. Could you tell me what I'm doing wrong?
Thanks
I'm somewhat confused about what the mask schedule is.
In the original MaskGIT paper, they only said the cosine schedule worked best, without showing any pseudocode. I always assumed it was simply https://github.com/lucidrains/muse-pytorch/blob/main/muse_pytorch/muse_pytorch.py#L277
However, in this paper, they have a whole section discussing arccos and its density function. For those mathematically inclined, is it simply https://github.com/lucidrains/muse-pytorch/blob/main/muse_pytorch/muse_pytorch.py#L280 ?
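Here is one way to see how the two relate, assuming the Muse paper's density is p(r) = 2 / (pi * sqrt(1 - r^2)) on [0, 1]. If u ~ Uniform[0, 1] and r = cos(pi/2 * u), then P(r <= x) = P(u >= (2/pi) * arccos(x)) = 1 - (2/pi) * arccos(x), and the derivative of that CDF is exactly the density above. So sampling the mask ratio as the cosine of a uniform variable and sampling from the arccos distribution describe the same distribution. The sketch checks this empirically; both function names are hypothetical:

```python
import math
import random


def sample_mask_ratio(rng=random):
    """Cosine schedule: mask ratio r = cos(pi/2 * u) with u ~ Uniform[0, 1)."""
    return math.cos(math.pi / 2 * rng.random())


def arccos_cdf(x):
    """CDF of the density p(r) = 2 / (pi * sqrt(1 - r^2)) on [0, 1]."""
    return 1 - 2 / math.pi * math.acos(x)
```

For example, arccos_cdf(0.5) is 1/3: about a third of training examples get a mask ratio of at most 50%, and the rest get more.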
Hello, after running the code below, I get "TypeError: isinstance() arg 2 must be a type or tuple of types". What did I do wrong?
`
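import os

import cv2
import numpy as np

# `train` is the MNIST training set (28x28 grayscale images), loaded beforehand
os.makedirs('./mnist', exist_ok = True)  # cv2.imwrite silently fails if the folder is missing
for i in range(len(train)):
    im = np.array(train[i])
    # upscale to 32x32 and add a trailing channel axis
    im = cv2.resize(im, dsize=(32, 32), interpolation=cv2.INTER_CUBIC)[:, :, None]
    # repeat the grayscale channel to get a 3-channel image
    im = np.concatenate((im, im, im), axis=2)
    name = './mnist/im_' + str(i) + '.png'
    cv2.imwrite(name, im)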
import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer
vae = VQGanVAE(
dim = 16,
vq_codebook_size = 32
)
trainer = VQGanVAETrainer(
vae = vae,
image_size = 32, # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
folder = './mnist',
batch_size = 2,
grad_accum_every = 8,
num_train_steps = 100000
)#.cuda()
Traceback (most recent call last):
File "C:\Users\Ext.Edmond_Jacoupeau\AppData\Local\Temp\ipykernel_3136\1437749784.py", line 11, in <module>
trainer = VQGanVAETrainer(
File "<@beartype(muse_maskgit_pytorch.trainers.VQGanVAETrainer.__init__) at 0x278d4d5e040>", line 48, in __init__
File "C:\Users\Ext.Edmond_Jacoupeau\Anaconda3\envs\py39\lib\site-packages\muse_maskgit_pytorch\trainers.py", line 149, in __init__
ddp_kwargs = find_and_pop(
File "C:\Users\Ext.Edmond_Jacoupeau\Anaconda3\envs\py39\lib\site-packages\muse_maskgit_pytorch\trainers.py", line 47, in find_and_pop
ind = find_index(arr, cond)
File "C:\Users\Ext.Edmond_Jacoupeau\Anaconda3\envs\py39\lib\site-packages\muse_maskgit_pytorch\trainers.py", line 42, in find_index
if cond(el):
TypeError: isinstance() arg 2 must be a type or tuple of types
`
Why is this used at line 93 of t5.py?
encoded_text = encoded_text.masked_fill(attn_mask[..., None], 0.)
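Tensor.masked_fill writes the fill value wherever the mask is True. The attn_mask[..., None] indexing adds a trailing axis so the per-token mask broadcasts across the embedding dimension. The usual purpose of a line like this is to zero the T5 outputs at padding positions, so that later code that ignores the mask does not pick them up. Whether the line as quoted does that depends on what attn_mask holds. If it follows the Hugging Face convention (True for real tokens), the mask would have to be inverted first (~attn_mask). Below is a framework-free sketch of the intended effect, with a hypothetical helper:

```python
def zero_padding(embeddings, attn_mask):
    """Zero the embedding vectors at padding positions.

    `embeddings` is a list of per-token vectors; `attn_mask` is a list of bools
    where True marks a real token (Hugging Face convention). The PyTorch
    equivalent under that convention is masked_fill(~attn_mask[..., None], 0.).
    """
    return [vec if keep else [0.0] * len(vec) for vec, keep in zip(embeddings, attn_mask)]
```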
Hi, in this repo, classifier-free guidance is handled by setting a context_mask based on the dropout probability. However, some rows of the context_mask can become all False, and I found that if we use flash_attn, F.scaled_dot_product_attention generates NaN values.
I found that you have handled something similar in the x-transformers repo regarding the issue. However, that fix seems to handle only causal attention and that particular row of the mask. In MaskGIT, attention is non-causal and context_mask can be False anywhere, so the current version of F.scaled_dot_product_attention will still generate NaN.
Therefore, I am wondering whether there is a better way to support flash_attn with classifier-free guidance.
Thank you.
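Here is a framework-free illustration of the failure mode and one common workaround. When every key in a row is masked, a softmax that subtracts the row maximum computes (-inf) - (-inf) = NaN, which matches the NaN reported from F.scaled_dot_product_attention. The workaround prepends a null key that is never masked. That key would carry a learned or zero value vector; this is an assumption here, not something the repo is stated to do. Each row then has at least one finite score, and a row whose text context was dropped simply attends to the null value. The null_score default of 0.0 stands in for the query's dot product with the null key; the function names are hypothetical:

```python
import math

NEG_INF = float("-inf")


def masked_softmax(scores, keep):
    """Softmax over `scores`, with positions where keep[i] is False set to -inf first."""
    filled = [s if k else NEG_INF for s, k in zip(scores, keep)]
    m = max(filled)  # subtract the max for stability, as attention kernels do
    exps = [math.exp(s - m) for s in filled]  # fully masked row: (-inf) - (-inf) = nan
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax_with_null(scores, keep, null_score=0.0):
    """Same, but with an always-attendable null key prepended, so no row is fully masked."""
    return masked_softmax([null_score] + list(scores), [True] + list(keep))
```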
I implemented support for flash attention via xformers:
Sygil-Dev/muse-maskgit-pytorch@main...Birch-san:muse-maskgit-pytorch:sdp-attn
It gets the same result, confirmed via allclose(). I had to make the absolute tolerance a bit more forgiving to pass that check, but it's still a very small tolerance.
I also implemented support for torch.nn.functional.scaled_dot_product_attention. It'll use flash attention when mask is None, but I guess your mask is usually defined.
Even without flash attention, scaled_dot_product_attention should still be faster than the einsum() * scale approach, because (IIRC) its math fallback is based on baddbmm, which fuses the scale factor into the matmul.
In stable-diffusion inference, for example, we measured end-to-end image generation via baddbmm to be ~11% faster than einsum() * scale on CUDA (and 18% faster on Mac).
@lucidrains, is this a contribution you would be interested in receiving as a PR?
Can you provide a script that uses distributed computing, or an accelerate config?
Hi, thank you for getting this out so quickly. I come from the Stable Diffusion world. Muse looks really interesting, but I don't know how to run it. Could you write an easy-to-understand tutorial in the README?
How do I train the MaskGIT transformers (base model)?
The trainers.py file only trains the VQGAN-VAE model, and the README code only shows how to get the loss. Is there no trainer for MaskGIT? How can I train it?
I'm trying to run the demo on my RTX 3080, but it runs out of VRAM.
Hi
I am trying to train the transformer from scratch. I notice the loss quickly gets stuck at the same value (8.x in my case). I have tried lowering and raising the learning rate, with no luck. After ~15000k iters of training, the results are not satisfying either.
I have tried different masking ratios, such as ~34% instead of ~64%. The loss does go down to ~1.3, but the output is still far from satisfying.
So, has anyone trained a working transformer?