
muse-maskgit-pytorch's People

Contributors

gothos, lucidrains, pranoyr

muse-maskgit-pytorch's Issues

Usage questions

Hi.

  1. Would you mind pointing out any changes that need to be made to the scripts given under "Usage" in the README.md?

For example, I know that in the first script I need to change folder = '/path/to/images' to point at my own folder of images. Beyond that, I do not know what else might be necessary.

  2. I am also not clear on whether my images need to be a specific shape, size, etc.

  3. Lastly, once everything above is resolved, should I run all 4 scripts found under Usage consecutively?

(For what it's worth, when I run the first script with only folder changed, I get the following output:)

680 training samples found at /content/gdrive/My Drive/Github/Generating-Synthetic-Handwritten-Historical-Documents/PyTorch-CycleGAN/datasets/text2illuminated/train/B
training with dataset of 646 samples and validating with randomly splitted 34 samples
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100% 528M/528M [00:03<00:00, 229MB/s]
0: vae loss: 0.9188588485121727 - discr loss: 10.198074519634247
0: saving to results
0: saving model to results
1: vae loss: -0.18518024682998657 - discr loss: 11.391329526901245
2: vae loss: -2.3277476727962494 - discr loss: 12.582469940185547
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
<ipython-input-6-4b98abcfd671> in <module>
     19 ).cuda()
     20 
---> 21 trainer.train()

16 frames
/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in train(self, log_fn)
    364 
    365         while self.steps < self.num_train_steps:
--> 366             logs = self.train_step()
    367             log_fn(logs)
    368 

/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in train_step(self)
    265 
    266         for _ in range(self.grad_accum_every):
--> 267             img = next(self.dl_iter)
    268             img = img.to(device)
    269 

/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in cycle(dl)
     38 def cycle(dl):
     39     while True:
---> 40         for data in dl:
     41             yield data
     42 

/usr/local/lib/python3.8/dist-packages/accelerate/data_loader.py in __iter__(self)
    381                 if self.device is not None:
    382                     current_batch = send_to_device(current_batch, self.device)
--> 383                 next_batch = next(dataloader_iter)
    384                 yield current_batch
    385                 current_batch = next_batch

/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    626                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627                 self._reset()  # type: ignore[call-arg]
--> 628             data = self._next_data()
    629             self._num_yielded += 1
    630             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    669     def _next_data(self):
    670         index = self._next_index()  # may raise StopIteration
--> 671         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672         if self._pin_memory:
    673             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     56                 data = self.dataset.__getitems__(possibly_batched_index)
     57             else:
---> 58                 data = [self.dataset[idx] for idx in possibly_batched_index]
     59         else:
     60             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
    293         if isinstance(idx, list):
    294             return self.dataset[[self.indices[i] for i in idx]]
--> 295         return self.dataset[self.indices[idx]]
    296 
    297     def __len__(self):

/content/gdrive/MyDrive/Github/muse-maskgit-pytorch/muse_maskgit_pytorch/trainers.py in __getitem__(self, index)
     92         path = self.paths[index]
     93         img = Image.open(path)
---> 94         return self.transform(img)
     95 
     96 # main trainer class

/usr/local/lib/python3.8/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     93     def __call__(self, img):
     94         for t in self.transforms:
---> 95             img = t(img)
     96         return img
     97 

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/torchvision/transforms/transforms.py in forward(self, img)
    344             PIL Image or Tensor: Rescaled image.
    345         """
--> 346         return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
    347 
    348     def __repr__(self) -> str:

/usr/local/lib/python3.8/dist-packages/torchvision/transforms/functional.py in resize(img, size, interpolation, max_size, antialias)
    472             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    473         pil_interpolation = pil_modes_mapping[interpolation]
--> 474         return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
    475 
    476     return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)

/usr/local/lib/python3.8/dist-packages/torchvision/transforms/functional_pil.py in resize(img, size, interpolation)
    250         raise TypeError(f"Got inappropriate size arg: {size}")
    251 
--> 252     return img.resize(tuple(size[::-1]), interpolation)
    253 
    254 

/usr/local/lib/python3.8/dist-packages/PIL/Image.py in resize(self, size, resample, box, reducing_gap)
   1884             return im.convert(self.mode)
   1885 
-> 1886         self.load()
   1887 
   1888         if reducing_gap is not None and resample != NEAREST:

/usr/local/lib/python3.8/dist-packages/PIL/ImageFile.py in load(self)
    243                                     break
    244                                 else:
--> 245                                     raise OSError(
    246                                         "image file is truncated "
    247                                         "(%d bytes not processed)" % len(b)

OSError: image file is truncated (15 bytes not processed)
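
A commonly used workaround for this error (not specific to this repo) is to tell PIL to tolerate truncated files; the cleaner fix is to re-export or remove the corrupt image:

# Let PIL pad and load partially-written image files instead of raising
# "OSError: image file is truncated". Apply before the dataloader touches images.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True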

MaskGit forward() - "ids" is fed to transformer rather than "x"

Hi,
Thanks a lot for the implementation.

While running the base MaskGIT training according to the usage example:

loss = base_maskgit(
    images,
    texts = texts
)

I noticed that the training cross-entropy loss goes down to very small values pretty fast. I suspect this is related to the following line of code:

https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L513

where ids is passed to the transformer forward function rather than the masked x.

Am I missing something? thanks :-)
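
For illustration, a minimal sketch of what masked-token training is normally expected to look like (names and shapes here are illustrative, not the repo's exact code):

import torch

# The transformer must see the *masked* sequence x, not the ground-truth ids;
# otherwise the targets leak into the input and cross-entropy collapses.
ids = torch.randint(0, 1024, (2, 256))   # hypothetical ground-truth token ids
mask = torch.rand(ids.shape) < 0.5       # positions to mask out
mask_id = 1024                           # dedicated [MASK] token id
x = torch.where(mask, torch.full_like(ids, mask_id), ids)

# logits = transformer(x)    # expected: input is the masked x
# logits = transformer(ids)  # suspected bug: input is the answer itself
# loss = F.cross_entropy(logits[mask], ids[mask])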

VQGanVAETrainer does not support distributed training.

The problems listed below make distributed training impossible.

  • Some modules are called directly (such as self.vae.discr).
  • Unused parameters exist (for discriminator training).
  • ema_vae is not wrapped by accelerate.prepare (while self.valid_dl is wrapped).

vae enc_dec configuration questions

I had a question about the model configuration relative to the paper.

In the paper, f=8, so the super-res model's latent map has 64×64 resolution for a 512×512 image.

I was therefore trying a VAE with 3 layers to make sure the latents have 64×64 resolution. However, I think this may differ from the paper's implementation details, which describe a VAE of 4 layers and 256 dim with a fine-tuned decoder, or 2 layers and 128 dim without fine-tuning the decoder.

How can I set the VAE configuration in code to match the paper's?
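
For concreteness, a sketch of the intent, assuming the VAE's depth argument controls the number of 2x down/upsample stages (check vqgan_vae.py for the actual argument name; 'layers' here is an assumption):

from muse_maskgit_pytorch import VQGanVAE

# With one 2x downsample per layer, 3 layers gives f = 2**3 = 8,
# so a 512x512 input maps to a 64x64 latent grid as in the paper's super-res stage.
vae = VQGanVAE(
    dim = 128,               # illustrative; the paper describes 128 and 256 dim variants
    vq_codebook_size = 8192,
    layers = 3               # assumed name for the number of down/upsample stages
)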

Gumbel Sample

Since the top_k threshold is 0.9, will gumbel sampling always take the index with the max probability? I guess temperature has no effect here?

So, is the sampling here argmax only?
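
For reference, a minimal sketch of the Gumbel-max trick as it is usually written (illustrative, not a verbatim quote of the repo): the argmax of logits / T plus Gumbel noise is an exact sample from softmax(logits / T), so temperature matters whenever it is nonzero, and a top_k threshold only restricts sampling to the highest-scoring slice of the logits rather than forcing pure argmax.

import torch

def gumbel_noise(t):
    # standard Gumbel(0, 1) noise via the inverse-CDF trick: -log(-log(U))
    u = torch.empty_like(t).uniform_(0, 1).clamp(1e-20, 1 - 1e-7)
    return -torch.log(-torch.log(u))

def gumbel_sample(logits, temperature = 1.0, dim = -1):
    # exact sample from softmax(logits / temperature); degenerates to plain
    # argmax only as temperature -> 0
    return (logits / max(temperature, 1e-10) + gumbel_noise(logits)).argmax(dim = dim)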

T5-small model is too large

I have a GPU with 40GB of memory, but even this is not enough to load the t5-small model, which apparently requires more than 100GB based on the error message below.

[screenshot of the error message]

Is there any way to avoid this issue? Perhaps a smaller model, or another embedding model that can fit on my GPU?
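
For what it's worth, t5-small is only about 60M parameters (~240 MB of weights), so an error demanding 100+ GB usually points at something else, such as an unexpectedly huge batch or sequence length. A minimal sketch of encoding text with it directly via Hugging Face transformers:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5EncoderModel.from_pretrained('t5-small').cuda().eval()

tokens = tokenizer(['a photo of a dog'], return_tensors = 'pt', padding = True).to('cuda')
with torch.no_grad():
    embeddings = model(**tokens).last_hidden_state   # (batch, seq_len, 512) for t5-small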

Add support for logging and visualizing the loss and learning rate with tensorboard.

Hi there, it would be nice to have tensorboard support, with the loss and learning-rate values logged so we can use tensorboard to keep track of them and get a better idea of how the model changes over time. It would also help with knowing how many steps we have trained for, since it's really easy to lose track of the global step count unless you do everything in a single session. A global step counter that persists across sessions would be nice; this could be done with tensorboard, or even more simply with a JSON file holding the training-session info so we can reuse it in the next session. That would also make resuming a session easier: right now we can load the VAE and continue from it, but for the trainer it is as if we were training from scratch every time.
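
Something along these lines would cover both requests (a sketch only, not wired into this repo's trainer; the paths and the training_steps() generator are placeholders):

import json, os
from torch.utils.tensorboard import SummaryWriter

# Log loss/learning rate to tensorboard and persist a global step counter in a
# small JSON file so a resumed session continues the same counter.
state_path = 'results/state.json'
global_step = 0
if os.path.exists(state_path):
    with open(state_path) as f:
        global_step = json.load(f)['global_step']

writer = SummaryWriter(log_dir = 'results/tensorboard')

for loss, lr in training_steps():   # placeholder for the actual training loop
    writer.add_scalar('loss', loss, global_step)
    writer.add_scalar('lr', lr, global_step)
    global_step += 1
    with open(state_path, 'w') as f:
        json.dump({'global_step': global_step}, f)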

About Patch Size

In Swapping Autoencoder [Taesung Park et al., 2020], they study the effect of patch size with respect to FID and LPIPS: the larger the patch size, the more the geometry changes. It seems like your model may produce even more surprising results, since you are doing a text-to-image task.

[From Hugging Face] Can we help in any way?

Hey @lucidrains,

Impressive that you already started a MUSE PyTorch implementation here! At Hugging Face, we have also started thinking about/designing a MUSE reproduction.

We are quite excited about an open reproduction effort whose outcome is a checkpoint competitive with MUSE and available to everybody. In our opinion, the advantages of MUSE over existing models are:

  • Should be faster than Stable Diffusion, Imagen & co.
  • Inpainting works out of the box
  • No need for complex schedulers

Can we help you in any way? Starting to curate a dataset, etc.?

cc @patil-suraj

Can you release the official trained model?

Hello!
I've read the MUSE paper, and I think the results are amazing!
So I want to try this amazing model.
Can you release the official pretrained weights for the VAE and MaskGIT?

How to resolve this error?

pip install muse-maskgit-pytorch

Any suggestions on how to install this?
I was trying to install it, but it shows "command not found".

About the VQGAN

Hi @lucidrains, I would like to ask whether the VQGAN you implement in this repo can compress a 256×256 image to 16×16 while maintaining good reconstruction quality.

I haven't included a GAN in my VQ-VAE code, and compressing 256×256 images to 16×16 and then reconstructing doesn't yield particularly good results. I'm wondering if the code in this library can achieve better performance. (I've attempted to run the official code for the first phase of VQGAN, but the results seem not as impressive as presented in their paper.)

Additionally, have you tried replacing PatchGAN with StyleGAN in your VQGAN? As MaskGIT/MAGVIT mentioned, the StyleGAN discriminator might be more stable.

Thank you.

Confusion regarding outputs from first and second module

Hi,

Thanks for implementing and open-sourcing the code for this T2I model.

I ran the first snippet of the code where the objective was to train a VQGanVAE model.

After training the VQGanVAE model for 50K iterations, I trained the MaskGIT module, although I passed fewer images and texts into the MaskGIT training than into the first module's training, since I was running into memory issues.

Nevertheless, I passed 10 images and the corresponding texts to train the super-resolution GIT and saved the images. The following are a few of the images that I am getting:

[sample outputs: maskgit_0, maskgit_1, maskgit_2]

My question is whether this is the correct process to follow. Do I need to train on more images to get images that correspond to the text?

Thanks!

AssertionError while trying to train with base_maskgit_trainer() after finished training with vae_trainer()

I finished training the VAE for about 123k iterations; then, when I moved on to the base MaskGIT trainer, the following error occurred:

123000: vae loss: 1.0881760120391846 - discr loss: 1.993680715560913
training complete
Resuming VAE from:  other/vae.23000.base.pt
Traceback (most recent call last):
  File "muse_train.py", line 352, in <module>
    base_maskgit_trainer()
  File "muse_train.py", line 156, in base_maskgit_trainer
    vae.load(args.resume_from.replace('.pt' , '.base.pt'))    # you will want to load the exponentially moving averaged VAE
  File "/home/eason/PyTorch/muse-pytorch/muse_maskgit_pytorch/vqgan_vae.py", line 408, in load
    assert path.exists()
AssertionError

Trained weights

Hi, any chance you'll upload any trained weights?
I really want to try this out. It looks amazing!

Negative prompting from the Muse paper?

Would it be possible to implement negative prompting as described in section 2.7 of the paper? There, the final logits are calculated as $\ell_g = (1+t)\ell_c - t\ell_u$, and negative prompting is done "by replacing the unconditional logit $\ell_u$ with a logit conditioned on a negative prompt."
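
In code form the change would be small; a sketch of the formula, where logits_neg comes from conditioning on the negative prompt instead of the empty prompt:

def guided_logits(logits_cond, logits_neg, t):
    # l_g = (1 + t) * l_c - t * l_u, with the unconditional logits l_u replaced
    # by logits conditioned on a negative prompt (negative prompting per Muse §2.7)
    return (1 + t) * logits_cond - t * logits_neg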

beartype error: unhashable type: 'list'

I installed the code and its dependencies. When I do from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer, it fails in muse_maskgit_pytorch when importing beartype:

/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/torch/onnx/_internal/_beartype.py:30: UserWarning: unhashable type: 'list'
  warnings.warn(f"{e}")
Traceback (most recent call last):
  File "/scratch/code/muse-maskgit-pytorch/./fni_main.py", line 3, in <module>
    from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
  File "/scratch/code/muse-maskgit-pytorch/muse_maskgit_pytorch/__init__.py", line 2, in <module>
    from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
  File "/scratch/code/muse-maskgit-pytorch/muse_maskgit_pytorch/muse_maskgit_pytorch.py", line 16, in <module>
    from beartype import beartype
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/__init__.py", line 57, in <module>
    from beartype._decor.decormain import (
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_decor/decormain.py", line 23, in <module>
    from beartype._conf.confcls import (
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_conf/confcls.py", line 23, in <module>
    from beartype._cave._cavemap import NoneTypeOr
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_cave/_cavemap.py", line 33, in <module>
    from beartype._util.hint.nonpep.utilnonpeptest import (
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/hint/nonpep/utilnonpeptest.py", line 21, in <module>
    from beartype._util.cache.utilcachecall import callable_cached
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/cache/utilcachecall.py", line 32, in <module>
    from beartype._util.func.arg.utilfuncargtest import (
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/func/arg/utilfuncargtest.py", line 17, in <module>
    from beartype._util.func.utilfunccodeobj import get_func_codeobj
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_util/func/utilfunccodeobj.py", line 21, in <module>
    from beartype._data.datatyping import (
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/site-packages/beartype/_data/datatyping.py", line 129, in <module>
    BeartypeReturn = Union[BeartypeableT, BeartypeConfedDecorator]
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 243, in inner
    return func(*args, **kwds)
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 316, in __getitem__
    return self._getitem(self, parameters)
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 421, in Union
    parameters = _remove_dups_flatten(parameters)
  File "/home/forrest/anaconda3/envs/mar2023_maskgit/lib/python3.9/typing.py", line 215, in _remove_dups_flatten
    all_params = set(params)
TypeError: unhashable type: 'list'

I googled "beartype unhashable type import error," but I haven't yet been able to find anything about this error. Does anyone else hit this problem?

  • I'm using Python 3.9. Is that an ok version to use?
  • I am using beartype 0.12.0, which is what was automatically installed when I installed the muse-maskgit-pytorch code.

Request for basic documentation

Thank you for the quick implementation of the Muse code. A bit of documentation / a quick how-to would be highly appreciated.

Particularly a how-to on:

  • How to train the model(s)
  • How to do inference / generate images
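
For the inference half, the README's final snippet is the reference; schematically it looks like this (a sketch from memory of the README, so argument names may differ):

from muse_maskgit_pytorch import Muse

# wrap the trained base and super-res MaskGit instances, then generate from text
muse = Muse(base = base_maskgit, superres = superres_maskgit)
images = muse(['a whale breaching from afar'])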

How to make your example work

Hi,
I'm trying to test the example in the README, but I'm getting a TypeError: 'NoneType' object is not callable when calling loss = base_maskgit(images, texts = texts) (I didn't change anything). I'm probably missing something obvious here; could you tell me what I'm doing wrong?
Thanks

mask schedule

I'm sort of confused about what the mask schedule is.

In the original MaskGIT paper, they only said the cosine schedule was best, without showing any pseudocode. I always assumed it was simply https://github.com/lucidrains/muse-pytorch/blob/main/muse_pytorch/muse_pytorch.py#L277

However, in this paper they have a whole section talking about arccos and its density function. For those mathematically inclined: is it simply https://github.com/lucidrains/muse-pytorch/blob/main/muse_pytorch/muse_pytorch.py#L280 ?
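
For what it's worth, the two turn out to be the same thing: if u ~ Uniform(0, 1) and r = cos(pi/2 * u), then P(r <= x) = 1 - (2/pi) * arccos(x), so r has density p(r) = (2/pi) / sqrt(1 - r^2), which is exactly the truncated arccos distribution the Muse paper describes. A quick sketch:

import math, torch

# sampling the cosine schedule at a uniform random time reproduces the
# arccos density p(r) = (2/pi) / sqrt(1 - r^2) on [0, 1]
u = torch.rand(100000)
mask_ratio = torch.cos(u * math.pi / 2)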

TypeError: isinstance() arg 2 must be a type or tuple of types

Hello, after running the code below, it says "TypeError: isinstance() arg 2 must be a type or tuple of types". What did I do wrong?

for i in range(len(train)):
    im = np.array(train[i])
    im = cv2.resize(im, dsize=(32, 32), interpolation=cv2.INTER_CUBIC)[:, :, None]
    im = np.concatenate((im, im, im), axis=2)
    name = './mnist/im_' + str(i) + '.png'
    cv2.imwrite(name, im)

import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 16,
    vq_codebook_size = 32
)

# train on folder of images, as many images as possible

trainer = VQGanVAETrainer(
    vae = vae,
    image_size = 32,    # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
    folder = './mnist',
    batch_size = 2,
    grad_accum_every = 8,
    num_train_steps = 100000
)#.cuda()

Traceback (most recent call last):

  File "C:\Users\Ext.Edmond_Jacoupeau\AppData\Local\Temp\ipykernel_3136\1437749784.py", line 11, in <module>
    trainer = VQGanVAETrainer(

  File "<@beartype(muse_maskgit_pytorch.trainers.VQGanVAETrainer.__init__) at 0x278d4d5e040>", line 48, in __init__

  File "C:\Users\Ext.Edmond_Jacoupeau\Anaconda3\envs\py39\lib\site-packages\muse_maskgit_pytorch\trainers.py", line 149, in __init__
    ddp_kwargs = find_and_pop(

  File "C:\Users\Ext.Edmond_Jacoupeau\Anaconda3\envs\py39\lib\site-packages\muse_maskgit_pytorch\trainers.py", line 47, in find_and_pop
    ind = find_index(arr, cond)

  File "C:\Users\Ext.Edmond_Jacoupeau\Anaconda3\envs\py39\lib\site-packages\muse_maskgit_pytorch\trainers.py", line 42, in find_index
    if cond(el):

TypeError: isinstance() arg 2 must be a type or tuple of types

T5 bug?

Why use this at line 93 in t5.py?

encoded_text = encoded_text.masked_fill(attn_mask[..., None], 0.)
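
If attn_mask follows the usual convention (True at real tokens, False at padding), one would expect the inverted mask here, so that the padding positions are zeroed rather than the valid ones. A hedged sketch of the presumably intended line:

# assumption: attn_mask is True at valid tokens, so zero out the padding instead
encoded_text = encoded_text.masked_fill(~attn_mask[..., None], 0.)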

Better way to handle Classifier-free guidance with flash_attn

Hi, in this repo, the classifier-free guidance is handled by setting a context_mask based on the dropout probability.

https://github.com/lucidrains/muse-maskgit-pytorch/blob/849bbd87f975255943349690a17a40f0103521d9/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L306C1-L310C47

However, some rows of the context_mask can become all False, and I found that if we use flash_attn, F.scaled_dot_product_attention will generate NaN values.

I found that you handled something similar in the x-transformers repo:

https://github.com/lucidrains/x-transformers/blob/4f9775ba62ea65ba46cc496f13adad827605537c/x_transformers/attend.py#L201

But it seems that can only handle causal attention and that particular row of the mask. If we are working on MaskGIT, where attention is non-causal and the context_mask can be False anywhere, the current version of F.scaled_dot_product_attention will still generate NaNs.

Therefore I am wondering whether there is a better way to support the flash_attn with classifier-free guidance.

Thank you.
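
One possible workaround (a sketch, not a claim about the best design): detect fully-masked rows, temporarily let them attend everywhere so the kernel stays finite, then zero those rows' outputs, which is the natural meaning of an all-masked row:

import torch
import torch.nn.functional as F

def safe_sdp_attention(q, k, v, mask):
    # mask: boolean, broadcastable to (batch, heads, q_len, k_len), True = attend.
    # Rows that are entirely False make softmax produce NaN, so unmask them for
    # the kernel call and zero the corresponding outputs afterwards.
    row_all_masked = ~mask.any(dim = -1, keepdim = True)
    safe_mask = mask | row_all_masked
    out = F.scaled_dot_product_attention(q, k, v, attn_mask = safe_mask)
    return out.masked_fill(row_all_masked, 0.)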

Would you like flash attention?

I implemented support for flash attention via xformers:
Sygil-Dev/muse-maskgit-pytorch@main...Birch-san:muse-maskgit-pytorch:sdp-attn

It gets the same result, confirmed via allclose() (I had to make the absolute tolerance a bit more forgiving to pass that check, but it's still a very small tolerance).

I also implemented support for torch.nn.functional.scaled_dot_product_attention. It will use flash attention when the mask is None, but I guess your mask is usually defined.
Even without flash attention, scaled_dot_product_attention should still be faster than the einsum() * scale approach, because (IIRC) its math fallback is based on baddbmm, which fuses the scale factor into the matmul.
In stable-diffusion inference, for example, we measured end-to-end image generation via baddbmm to be ~11% faster than einsum() * scale on CUDA (and 18% faster on Mac).

@lucidrains is this a contribution you would be interested in receiving as a PR?
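
For concreteness, the two formulations being compared above (a sketch with illustrative shapes; with beta=0, baddbmm ignores its input tensor and fuses the alpha scale into the matmul):

import torch

b, n, d = 4, 256, 64
scale = d ** -0.5
q, k = torch.randn(b, n, d), torch.randn(b, n, d)

# einsum then scale: the scaling is a separate elementwise pass
sim_einsum = torch.einsum('b i d, b j d -> b i j', q, k) * scale

# baddbmm with beta=0: the scale (alpha) is fused into the matmul itself
sim_baddbmm = torch.baddbmm(
    torch.empty(b, n, n, dtype = q.dtype, device = q.device),
    q, k.transpose(1, 2),
    beta = 0, alpha = scale
)

assert torch.allclose(sim_einsum, sim_baddbmm, atol = 1e-5)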

How to run this?

Hi, thank you for getting this out so quickly. I came from the Stable Diffusion world; Muse looks really interesting, but I don't know how to run it. Could you write an easy-to-understand tutorial in the README?

How to train the maskgit transformers (base model)?

The trainers.py file only covers training the VQGAN-VAE model, and in the README the code is just shown computing the loss. Is there no trainer for MaskGIT? How can I train it?
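
Until a dedicated trainer lands, a minimal hand-rolled loop over the README's loss API would look something like this (a sketch: base_maskgit is assumed to be constructed as in the README, and dataloader is an assumed iterable of (images, texts) pairs):

import torch
from torch.optim import Adam

optimizer = Adam(base_maskgit.parameters(), lr = 3e-4)

for step, (images, texts) in enumerate(dataloader):
    loss = base_maskgit(images.cuda(), texts = texts)   # README's loss API
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step % 100 == 0:
        print(f'{step}: maskgit loss: {loss.item():.4f}')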

Maskgit transformer is hard to train

Hi,
I am trying to train the transformer from scratch. I notice the loss quickly gets stuck at the same value (in my case around 8.x). I have tried lowering/raising the learning rate, but no luck. After ~15000k iters of training the results are not satisfying either.
I have tried different masking ratios, such as ~34% instead of ~64%; the loss does go down to ~1.3, but the outpainting result is still far from satisfying.

So, has anyone trained a working transformer?
