
meshgpt-pytorch's Introduction

MeshGPT - Pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch

Will also add text conditioning, for eventual text-to-3d asset

Please join us on Discord if you are interested in collaborating with others to replicate this work

Appreciation

  • StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

  • Einops for making my life easy

  • Marcus for the initial code review (pointing out some missing derived features) as well as running the first successful end-to-end experiments

  • Marcus for the first successful training of a collection of shapes conditioned on labels

  • Quexi Ma for finding numerous bugs with automatic eos handling

  • Yingtian for finding a bug with the gaussian blurring of the positions for spatial label smoothing

  • Marcus yet again for running the experiments to validate that it is possible to extend the system from triangles to quads

Install

$ pip install meshgpt-pytorch

Usage

import torch

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# autoencoder

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128
)

# mock inputs

vertices = torch.randn((2, 121, 3))            # (batch, num vertices, coor (3))
faces = torch.randint(0, 121, (2, 64, 3))      # (batch, num faces, vertices (3))

# make sure faces are padded with `-1` for variable-length meshes (see the padding sketch after this example)

# forward in the faces

loss = autoencoder(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training...
# you can pass in the raw face data above to train a transformer to model this sequence of face vertices

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768
)

loss = transformer(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets

faces_coordinates, face_mask = transformer.generate()

# (batch, num faces, vertices (3), coordinates (3)), (batch, num faces)
# now post process for the generated 3d asset
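
For variable-length meshes in one batch, the faces tensor needs that `-1` padding up to the longest mesh. A minimal sketch of doing this with `pad_sequence` (the toy sizes are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

# two meshes with different numbers of faces
faces_a = torch.randint(0, 121, (60, 3))
faces_b = torch.randint(0, 121, (64, 3))

# pad to the longest mesh with -1 so the autoencoder can mask the padding out
faces = pad_sequence([faces_a, faces_b], batch_first = True, padding_value = -1)
# faces.shape == (2, 64, 3), with faces[0, 60:] filled with -1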

For text-conditioned 3d shape synthesis, simply set condition_on_text = True on your MeshTransformer, and then pass in your list of descriptions as the texts keyword argument

ex.

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768,
    condition_on_text = True
)


loss = transformer(
    vertices = vertices,
    faces = faces,
    texts = ['a high chair', 'a small teapot'],
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets conditioned on text

faces_coordinates, face_mask = transformer.generate(texts = ['a long table'])

If you want to tokenize meshes, for use in your multimodal transformer, simply invoke .tokenize on your autoencoder (or the same method on the autoencoder trainer instance for the exponentially smoothed model)

mesh_token_ids = autoencoder.tokenize(
    vertices = vertices,
    faces = faces
)

# (batch, num face vertices, residual quantized layer)

Todo

  • autoencoder

    • encoder sageconv with torch geometric
    • proper scatter mean accounting for padding when averaging the vertices, and RVQ the vertices before gathering back for the decoder
    • complete decoder and reconstruction loss + commitment loss
    • handle variable-length faces
    • add option to use residual LFQ, latest quantization development that scales code utilization
    • xcit linear attention in encoder and decoder
    • figure out how to auto-derive face_edges directly from faces and vertices
    • embed any derived values (area, angles, etc) from the vertices before sage convs
    • add an extra graph conv stage in the encoder, where vertices are enriched with their connected vertex neighbors, before aggregating into faces. make optional
    • allow the encoder to add noise to the vertices, so the autoencoder does a bit of denoising. consider conditioning the decoder on the noise level, if it varies
  • transformer

    • properly mask out eos logit during generation
    • make sure it trains
      • take care of sos token automatically
      • take care of eos token automatically if sequence length or mask is passed in
    • handle variable-length faces
      • on forwards
      • on generation, do all eos logic + substitute everything after eos with pad id
    • generation + cache kv
  • trainer wrapper with hf accelerate

    • autoencoder - take care of ema
    • transformer
  • text conditioning using own CFG library

    • complete preliminary text conditioning
    • make sure CFG library can support passing in arguments to the two separate calls when cond scaling (as well as aggregating their outputs)
    • polish up the magic dataset decorator and see if it can be moved to CFG library
  • hierarchical transformers (using the RQ transformer)

  • fix caching in simple gateloop layer in other repo

  • local attention

  • fix kv caching for two-staged hierarchical transformer - 7x faster now, and faster than original non-hierarchical transformer

  • fix caching for gateloop layers

  • allow for customization of model dimensions of fine vs coarse attention network

  • figure out if autoencoder is really necessary - it is necessary, ablations are in the paper

    • when mesh discretizer is passed in, one can inject inter-face attention with the relative distance
    • additional embeddings (angles, area, normal), can also be appended before coarse transformer attention
  • make transformer efficient

    • reversible networks
  • speculative decoding option

  • spend a day on documentation

Citations

@inproceedings{Siddiqui2023MeshGPTGT,
    title   = {MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers},
    author  = {Yawar Siddiqui and Antonio Alliegro and Alexey Artemov and Tatiana Tommasi and Daniele Sirigatti and Vladislav Rosov and Angela Dai and Matthias Nie{\ss}ner},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265457242}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation}, 
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Lee2022AutoregressiveIG,
    title   = {Autoregressive Image Generation using Residual Quantization},
    author  = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11513-11522},
    url     = {https://api.semanticscholar.org/CorpusID:247244535}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

meshgpt-pytorch's People

Contributors

kurokabe, lucidrains, marcusloppe, qixuema


meshgpt-pytorch's Issues

Missing features for graph embedding

Hi,

First time posting on a project's issue page, so apologies if I make any mistakes.
I've read through the paper many times and I think that you are not embedding all the features mentioned in the paper. I believe the feature count (F) is:
9 (coordinates) + 1 (area) + 3 (angles) + 3 (normal) = 16

So since F = 16, the input for the graph encoder should be 16 x 196 and the output 16 x 576 per face.
I figured I'd post now since you have progressed the project quite a bit and will probably be testing it soon.

I'm not great at tensor programming, so I just asked ChatGPT to modify the encoder using the details from the paper. The code is probably incorrect since I don't 100% understand the tensor operations you are doing, but at least it can provide some inspiration or a boilerplate example.

@beartype
def encode(
    self,
    *,
    vertices:         TensorType['b', 'nv', 3, int],
    faces:            TensorType['b', 'nf', 3, int],
    face_edges:       TensorType['b', 'e', 2, int],
    face_mask:        TensorType['b', 'nf', bool],
    face_edges_mask:  TensorType['b', 'e', bool],
    return_face_coordinates = False
):
    # ... [existing code up to face_embed definition] ...

    # Calculate additional face attributes
    # Using vertices and faces to calculate the area, angles, and normal for each face
    face_vertices = vertices.gather(1, faces) # Gather vertices for each face
    sides = face_vertices[:, :, [1, 2, 0], :] - face_vertices[:, :, [0, 1, 2], :]
    side_lengths = sides.norm(dim=-1)

    # Calculate area (using Heron's formula for simplicity)
    s = side_lengths.sum(dim=-1) / 2
    area = torch.sqrt(s * (s - side_lengths[:, :, 0]) * (s - side_lengths[:, :, 1]) * (s - side_lengths[:, :, 2]))
    area = area.unsqueeze(-1) # Reshape for concatenation

    # Calculate angles (using cosine rule)
    angles = torch.acos((side_lengths[:, :, [1, 2, 0]] ** 2 + side_lengths[:, :, [2, 0, 1]] ** 2 - side_lengths[:, :, [0, 1, 2]] ** 2) / (2 * side_lengths[:, :, [1, 2, 0]] * side_lengths[:, :, [2, 0, 1]]))
    angles = angles.flatten(start_dim=-2) # Flatten angles for concatenation

    # Calculate normals (using cross product)
    normals = torch.cross(sides[:, :, 0, :], sides[:, :, 1, :], dim=-1)
    normals = normals / normals.norm(dim=-1, keepdim=True) # Normalize

    # Concatenate additional features
    face_additional_features = torch.cat([area, angles, normals], dim=-1)
    face_additional_features = rearrange(face_additional_features, 'b nf d -> b nf (d)')

    # Concatenate with existing face embeddings
    face_embed = torch.cat([face_embed, face_additional_features], dim=-1)

    # ... [rest of the existing code] ...

    return face_embed, face_coords

when `transformer.generate(prompt=None)`, empty codes are passed to the decoder. Error!!

Thank you very much for your work!!!

When I try to generate with the trained model, if I don't add the prompt, the codes are generated one by one starting from empty.
codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))
When empty is input into the model decoder, an error will be reported.
Only when face_codes.size(1) is not 0 is there no error. I would like to ask how to solve it.

I tried passing in a prompt and it generated successfully.

transformer.generate(prompt=None)
face_codes.size()= [1,0,512]

Error:
CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
File "/ssd1/meshgpt-pytorch/meshgpt_pytorch/meshgpt_pytorch.py", line 1436, in forward_on_codes
attended_face_codes, coarse_cache = self.decoder(
File "/ssd1/meshgpt-pytorch/meshgpt_pytorch/meshgpt_pytorch.py", line 1189, in generate
output = self.forward_on_codes(
File "/ssd1/meshgpt-pytorch/generate_samples_v1.py", line 109, in
face_coords, face_mask = transformer.generate(temperature=r, texts=texts_list)
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Transformer - High VRAM, context length

Hello again, this issue is for next year 😃

When training the transformer, I used the following config:

transformer = MeshTransformer(
    autoencoder,
    dim = 128,
    attn_depth = 4,
    attn_dim_head = 8,
    attn_heads = 4,
    coarse_pre_gateloop_depth = 1,#6,
    fine_pre_gateloop_depth= 0,#4, 
    max_seq_len = max_seq,
    gateloop_use_heinsen = False,
    condition_on_text = True
)

This resulted in a transformer with 22M parameters.
I then tried to train it on a 6206-face mesh, which is 37236 tokens (6206 * 6).
When I feed it the face codes (1, 6206, 128) it uses about 11 GB of VRAM, and by the end of the forward pass about 20 GB.
If I use a transformer that is 188M (256 dim) it uses 50 GB of VRAM.

My suggestion is to implement sliding-window attention / local attention, since most long-context LLMs use it and it seems to work; a rough sketch of the idea is given after the code below.

Or create an embedding of the tokens and concatenate it with the text conditioner embedding, so the cross attention can be aware of previous tokens as well.

Also take a look at whether Grouped-Query Attention is beneficial :)

attended_face_codes, coarse_cache = self.decoder(
                face_codes,
                cache = coarse_cache,
                return_hiddens = True,
                **attn_context_kwargs
            )
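
As a rough illustration of the sliding-window idea in plain PyTorch (not part of this repo; the banded causal mask is the whole trick, and real memory savings would also require restricting the KV cache to the window):

import torch
import torch.nn.functional as F

def sliding_window_attention(q, k, v, window = 512):
    # q, k, v: (batch, heads, seq, dim_head)
    seq = q.shape[-2]
    idx = torch.arange(seq, device = q.device)
    # causal band: query i may only attend to keys j with i - window < j <= i
    mask = (idx[:, None] >= idx[None, :]) & (idx[:, None] - idx[None, :] < window)
    return F.scaled_dot_product_attention(q, k, v, attn_mask = mask)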

Commit Loss doesn't go down

Thanks for your work!

I tried to use the table dataset in ShapeNetV2 to test the code, and the code works fine. Here is my issue:

During the training phase, if I test the code using a small dataset, namely 100 objs, the commit loss converges quickly.
But when I use the full table dataset, the commit loss gets higher and higher each epoch.
When I apply several augmentation strategies, things get worse.

I tried both Residual VQ and Residual LFQ; neither works.

I guess the autoencoder might have difficulty quantizing so many different faces, but I am not familiar with autoencoders or quantization.

Has anyone else run into this issue, or is anyone familiar with residual VQs? Thanks

Transformer - token_embed outputs nan values

This issue occurs if you have too high a learning rate (1e-2) at a low loss (0.3), though it also occurred when I had 1e-3 as the lr at 0.01 loss.
edit: Using flash attention it goes from 5.0 loss to nan in the 5th epoch using a 1e-4 lr.

After the codes are masked and token_embed is called, it outputs nan values.
Not sure if this issue is a pytorch, meshgpt-pytorch or user error :)

codes = codes.masked_fill(codes == self.pad_id, 0)
codes = self.token_embed(codes)
codes  after  masked_fill  torch.Size([2, 912]) tensor([[11965,   608, 11350,  ...,     0,     0,     0],
        [15507, 13398,  5400,  ...,  8247, 13231,  5280]], device='cuda:0') 

codes token_embed after  torch.Size([2, 912, 512]) tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)

Mesh intra face vertex id ordering convention

Hi,
Thank you for your great implementation of the meshGPT paper.
I have a question related to the section 3.1 : "For sequence ordering, Polygen [43] suggests a convention where faces are ordered based on their lowest vertex index, followed by the next lowest, and so forth. Vertices are sorted in z-y-x order (z representing the vertical axis), progressing from lowest to highest. Within each face, indices are cyclically permuted to place the lowest index first."

@MarcusLoppe , I used your vertex and face ordering function (available in your repo notebook) for my own data preparation.
I have the feeling the vertex ordering is performed as in the paper. However, your face ordering does not seem to be.

Let's take the following toy example face with given vertices IDs: [3, 2, 8]
On the one hand, according to the paper, they cyclically permute the IDs to place the lowest index first, it gives [3, 2, 8] -> [8, 3, 2] -> [2, 8, 3] after two cyclic permutations or even better [3, 2, 8] -> [2, 8, 3] with one cyclic permutation in the other direction.
On the other hand, you permute the IDs like this: [3, 2, 8] -> [2, 3, 8], which cannot be produced by a cyclic permutation.

I think it can cause further issues, as you can see in the attached image of my own data. Since your permutation may not be cyclic, it can cause the triangle normals to be inverted. The triangle angle values may also be different if they are oriented, but I did not check. And since these features are computed inside the model and fed to the GCNN encoder, I would say we are not forwarding the right mesh to the model.
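
For reference, a cyclic rotation that puts the lowest index first preserves the winding order (and therefore the normals), whereas a plain sort can flip it. A minimal sketch, not code from the repo or the notebook:

def cyclic_permute_lowest_first(face):
    # rotate so the lowest vertex index comes first, keeping the winding order
    k = face.index(min(face))
    return face[k:] + face[:k]

print(cyclic_permute_lowest_first([3, 2, 8]))  # [2, 8, 3] - same orientation
print(sorted([3, 2, 8]))                       # [2, 3, 8] - may flip the normal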

To conclude, this is not an issue with the model itself but rather with the data preparation for training, and how to better stick to the existing paper.

What do you think?
Thanks!
face_ordering_meshgpt

Trainer fails to load

When loading a model using the current version, the optimizer fails to load since it expects a dict with an optimizer key.

So rename:
self.optimizer.load_state_dict(pkg['optimizer'])
To:
self.optimizer.load_state_dict(pkg)

    def load(self, path):
        path = Path(path)
        assert path.exists()

        pkg = torch.load(str(path))

        if version.parse(__version__) != version.parse(pkg['version']):
            self.print(f'loading saved mesh transformer at version {pkg["version"]}, but current package version is {__version__}')

        self.model.load_state_dict(pkg['model'])
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step.copy_(pkg['step'])

How to evaluate the codes returning from the autoencoder?

Hi, I'm trying to return a mesh I can see in Blender.

This is what I have, but it's incomplete. No idea what I'm doing.

Appreciate the work you're doing. Reply when you can.

from gltf_dataset.gltf_dataset import GLTFDataset
import wandb, torch
from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshAutoencoderTrainer,
)
from meshgpt_pytorch.data import derive_face_edges_from_faces

from einops import rearrange, reduce

from meshgpt_pytorch.meshgpt_pytorch import undiscretize

wandb.init(
    project="meshgpt-pytorch"
)

dataset = GLTFDataset('gltf_dataset/blockmesh_test/blockmesh')

checkpoint_path = 'checkpoints/mesh-autoencoder.ckpt.1.pt'

autoencoder = MeshAutoencoder.init_and_load_from(checkpoint_path)

autoencoder.eval()

sample_data = dataset.__getitem__(0)
vertices = sample_data[0].unsqueeze(0) 
faces = sample_data[1].unsqueeze(0) 

with torch.no_grad():
    pad_value = -1
    face_edges = derive_face_edges_from_faces(faces, pad_value)

    num_faces, num_face_edges, device = faces.shape[1], face_edges.shape[1], faces.device
    face_mask = reduce(faces != pad_value, 'b nf c -> b nf', 'all')
    face_edges_mask = reduce(face_edges != pad_value, 'b e ij -> b e', 'all')

    encoded = autoencoder.encode(
        vertices = vertices,
        faces = faces,
        face_edges = face_edges,
        face_edges_mask = face_edges_mask,
        face_mask = face_mask,
    )
    rvq_sample_codebook_temp = 1
    quantized, codes, commit_loss = autoencoder.quantize(
        face_embed = encoded,
        faces = faces,
        face_mask = face_mask,
        rvq_sample_codebook_temp = rvq_sample_codebook_temp
    )

    decode = autoencoder.decode(
        quantized,
        face_mask = face_mask
    )

    pred_face_coords = autoencoder.to_coor_logits(decode)

    pred_face_coords = rearrange(pred_face_coords, 'b ... c -> b c (...)')

    continuous_coors = undiscretize(
                pred_face_coords,
                num_discrete = 128,
                continuous_range = (-1., 1.)
            )
    # Todo
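    # One possible way to finish this off - a sketch only. It assumes to_coor_logits
    # puts the 128 discrete bins on the last dimension (so an argmax is needed here,
    # instead of the rearrange used for the training loss), and the .obj export is
    # purely illustrative.
    coord_logits = autoencoder.to_coor_logits(decode)       # bins assumed on the last dim
    pred_bins = coord_logits.argmax(dim = -1)                # most likely bin per coordinate

    recon_faces = undiscretize(
        pred_bins,
        num_discrete = 128,
        continuous_range = (-1., 1.)
    )                                                        # (b, nf, 3 vertices, 3 coords)

    # dump the first mesh as an .obj for Blender - every face gets its own three
    # vertices, which is good enough for visual inspection
    with open('recon.obj', 'w') as f:
        tris = recon_faces[0][face_mask[0]]                  # drop padded faces
        for tri in tris.tolist():
            for x, y, z in tri:
                f.write(f'v {x} {y} {z}\n')
        for i in range(tris.shape[0]):
            f.write(f'f {i * 3 + 1} {i * 3 + 2} {i * 3 + 3}\n')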

Training on my own data

Hi @MarcusLoppe, thanks for your pulled repo, but I was facing a little trouble in the notebook. As I am using Colab instead of Kaggle, could you guide me a little here: tables = load_json("/kaggle/input/shapenet/data.json",2)
Can you also help me with what I should do if I have my own data? Much thanks.

How to embed additional information like uvs / bones

add an extra graph conv stage in the encoder, where vertices are enriched with their connected vertex neighbors, before aggregating into faces. make optional

Will this allow me to embed uvs or bone + bone weights? We have a huge problem with not enough training power, but I was curious.

  1. uvs are per vertex. Typically [0, 1] but can be a real number

  2. Bone weights and bone indices are per vertex too.

Weights are values from 0 to 1.

Each vertex's weights must sum to 1.0 across all 4, 8 .. 32 .. N influences. (Typically 4 or 8.)

Bone indices can be represented as an entire bone group, like a separate mesh. For example, a character's hips section is represented by the hips bone, with each part of that mesh having a secondary value in [0, 1].

  3. Same as the previous idea, but with a token representing a category. For example, a mesh triangle face labelled 42 would mean it belongs to the "Head" vertex group.
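
If per-vertex attributes like these were ever to be fed in, one straightforward representation (purely hypothetical - nothing like this is supported by the current MeshAutoencoder) is to concatenate them with the vertex positions into a wider per-vertex feature that an extra graph-conv stage could consume:

import torch

num_vertices, influences, num_bones = 121, 4, 32

positions    = torch.randn(num_vertices, 3)                                    # xyz
uvs          = torch.rand(num_vertices, 2)                                     # per-vertex uv in [0, 1]
bone_weights = torch.softmax(torch.randn(num_vertices, influences), dim = -1)  # sums to 1 per vertex
bone_ids     = torch.randint(0, num_bones, (num_vertices, influences))         # which bones influence each vertex

# scatter each weight into a fixed-size per-bone slot, then concat everything
bone_feat = torch.zeros(num_vertices, num_bones).scatter_add_(1, bone_ids, bone_weights)
vertex_features = torch.cat([positions, uvs, bone_feat], dim = -1)             # (121, 3 + 2 + 32)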

Post Process 3d Assets

Hi, thank you for the wonderful code. I want to understand how I can post-process the following, as you mentioned in Usage. And how do I run the code?

# after much training of transformer, you can now sample novel 3d assets

faces_coordinates = transformer.generate()

# (batch, num faces, vertices (3), coordinates (3))
# now post process for the generated 3d asset

Classifier-Free Guidance, cond_drop_prob=1.0, attn_mask=False: Error!!!

When using text conditions, even if the parameter text_condition_cond_drop_prob is set to 0.25 when initializing the MeshTransformer, it is easy to overlook the cond_drop_prob parameter in MeshTransformer.forward_on_codes(cond_drop_prob = 0.).

_, maybe_dropped_text_embeds = self.conditioner( text_embeds = text_embeds, cond_drop_prob = cond_drop_prob )

Accidentally retaining the default parameter cond_drop_prob = 0. means that text conditions are not properly dropped out during training, which is not conducive to the subsequent use of Classifier-Free Guidance.
It is recommended that the author set the default value of cond_drop_prob to None. @lucidrains

However, most importantly, when I set MeshTransformer.forward_on_codes(cond_drop_prob = 1.0), the mask for text_embedding is all False, meaning that during the calculation of cross-attention attn_mask is all False, which causes the output to be NaN.

How can the above issues be resolved?

import torch.nn.functional as F

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask = mask,
    dropout_p = self.dropout if self.training else 0.,
    is_causal = causal
)

It seems that attn_mask cannot be False at all positions!

Isn't MeshGPT a VQ-VAE?

Hi guys, thanks for all the work!

I have a conceptual question. Isn't MeshGPT a VQ-VAE? If not, what are the differences?

Thanks!

Large memory requirement

Hi,

The memory requirement for the transformer is quite large; in the paper they trained with at most 800 faces, and for that you'll need 19 GB of VRAM (napkin math).
This is quite impractical, and if you want to generate a large 3D model with, say, 4000 faces, you'll need 95 GB of VRAM (napkin math).

Each triangle uses 6 tokens, so for a 240-face model the transformer needs a sequence length of 1440 tokens.
Training the autoencoder (batch size 4) reaches 2.6 GB of VRAM, and training the transformer (batch size 1) reaches 8.3 GB, so training the transformer with 1440 tokens uses 5.7 GB.
1440 / 5.7 = 252, so 1 GB of VRAM = 252 tokens = 42 triangles.

My idea is to chunk the model's faces into sequences of e.g. 200 faces and have a continue token to indicate that the shape isn't complete. Each token chunk could maybe contain 50% of the tokens of the previous chunk and 50% new tokens (a rough sketch of the chunking is given below).
Using the text condition should also help the model understand what kind of shape it should generate.

This way the model has low VRAM usage at the cost of longer training times, but it resolves the issue of training/inference on large 3D models, since you can chunk the sequences into manageable sizes; you can train and generate complex models with low memory usage, since generation will continue until it hits the EOS token.

Another idea is to also add an embedding (e.g. create a token summarizer) for all the tokens before the previous chunk, so you have a chunk of tokens plus an embedding of the earlier tokens, which gives the model context / a summary of the overall shape.

Since you can embed an 8k text sequence as a 768-dim vector, you could probably do the same thing here and be able to train/generate large 3D models with low memory usage.
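
A minimal sketch of the overlapping-chunk idea (the function name, chunk size and 50% overlap are illustrative; nothing like this exists in the library):

def chunk_tokens(tokens, chunk_size = 1200, overlap = 600):
    # split a long token sequence into overlapping windows, so each chunk
    # repeats the tail of the previous one and keeps some local context
    chunks, start = [], 0
    while start < len(tokens):
        chunks.append(tokens[start : start + chunk_size])
        start += chunk_size - overlap
    return chunks

# e.g. a 4000-face mesh (24,000 tokens) is then trained on as many short
# overlapping windows instead of one huge sequence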

custom_collate - only adding tensor to output

The custom_collate function will not append a datum if it's not a tensor.
This means that all datasets that contain texts will have their text values dropped.

From:

    for datum in zip(*data):
        if is_tensor(first(datum)):
            padded = pad_sequence(datum, batch_first = True, padding_value = pad_id)
            output.append(padded)
        else:
            datum = list(datum) 

To:

    for datum in zip(*data):
        if is_tensor(first(datum)):
            padded = pad_sequence(datum, batch_first = True, padding_value = pad_id)
            output.append(padded)
        else:
            datum = list(datum) 
            output.append(datum )

`gather` function fails after padding faces with `-1` for variable length meshes

Hello,

I've encountered an issue with the gather function after padding operation. As suggested in the usage notes:

"# make sure faces and face_edges are padded with -1 for variable lengthed meshes. otherwise, you will need to explicitly pass in face_len as well as face_edges_len."

Following this, I padded the faces tensor with -1 to account for variable length meshes. However, this seems to have caused the gather function to fail. The padding was done by setting the values of the last n faces in faces entirely to -1.

Is the padding process causing the gather operation to fail? Or is there a specific way the padding should be performed in this context? Below is the snippet of my code where the padding is applied:

import torch
import random

t = torch.randint(0, 121, (2, 64, 2))
for i in range(t.shape[0]):
    n = random.randint(1, t.shape[1])    
    t[i, -n:, :] = -1
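
For what it's worth, one common pattern (and roughly what this library does internally) is to derive a boolean mask from the -1 padding and replace the pad entries with a valid dummy index before any gather, then ignore the masked positions afterwards. A rough sketch of that idea, not the actual library code:

import torch

pad_id = -1
faces = torch.randint(0, 121, (2, 64, 3))
faces[:, 50:] = pad_id                              # pad the tail of each mesh

face_mask = (faces != pad_id).all(dim = -1)         # (batch, num faces)
safe_faces = faces.masked_fill(faces == pad_id, 0)  # 0 is a valid index, so gather won't blow up

# ... gather with safe_faces, then zero out / drop results where ~face_mask ...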

Packed Sequences

I have an internal version of a similar architecture to meshgpt (with downsampling via a graph seanet/encodec-style encoder/decoder for shorter sequences and a few other changes; I will open source it when I have completed weights).

I was just wondering why you are not using packed sequences for the vertex/face inputs to your autoencoder. That way it's a bit nicer to handle a dynamic number of vertices and faces in a batch.

gateloop_use_heinsen=True on MeshTransformer results in NaN loss

The loss quickly became NaN when training on ShapeNet after filtering for meshes < 800 faces after decimation (resulting in ~15k different 3D models) with condition_on_text=True, so I thought it was similar to #44, but even after training without text, adjusting the learning rate, adding warmup, using a larger batch size... I still had this problem.

I found with detect_anomaly that the NaN comes from gateloop_transformer > heinsen_associative_scan on the log backward.

Previously I was able to train successfully on 5 ShapeNet categories, with 10 meshes x 256 transformations each = 12,800 3D models, but that was with meshgpt-pytorch version 0.4.2. Maybe updating to version 0.5.5 also updated gateloop-transformer, which introduced this bug. Anyway, setting gateloop_use_heinsen=False seems to have solved the problem for me.

Just to get started

Hello there, I'm a designer with basic coding skills but totally unfamiliar with machine learning. I really want to experiment with it. I have installed the package and am a little confused about where to start, like what data assets I should collect and how to train on them. Much appreciated!

Bad performance using local attention

I noticed that the local attention layers aren't that great.
Below is a test using 6 unique 3D chair meshes duplicated 200 times = 1200 examples per epoch.
I've seen this pattern before, but I wanted to test it on some more complex 3D meshes before posting an issue.

bild

Another topic, since the discord link isn't working:
I did some testing with text using 10 chair meshes; I applied augmentation so each chair got 3 variations.
Then I duplicated each variation 500 times, so the total dataset size is 3000.
I encoded the meshes using text as well, but just using the same word 'chair'; still, this proves that text-conditioned generation works.

After 22 minutes training the encoder (0.24 loss) and then 2.5 hrs training the transformer (0.0048 loss) I got the result below.
I trained the transformer on different learning rates, but in total there were 30 epochs, i.e. 30 x 3000 = 90,000 steps.

Training with text is about 2x slower, so it might be good to see if any improvements can be made on that front.

I provided the text "chair" and looped the generation over different temperature values from 0 to 1.0 with 0.1 as the step.

bild
bild

Data augmentation strategies

In #6

For each mesh I generate augments_per_item (like 200), then I use it to index into the dataset.

Using a seed I augment using this strategy.

What do you think?

import random
import numpy as np
from scipy.spatial.transform import Rotation as R

scale = random.uniform(0.8, 1.2)  # Uniform scaling
rotation = R.from_euler('y', random.uniform(-180, 180), degrees=True)  # Random rotation around y-axis
translation = np.array([random.uniform(-0.5, 0.5) for _ in [0, 2]])  # Random translation in x and z directions

The goal is for a chair item to be rotated, moved or scaled, but upright.

Edited:

The idea is to have a chair be displaced, but under gravity, so it keeps its lowest vertex position.
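
A sketch of that displacement-under-gravity idea, assuming y is the vertical axis as in the rotation above (illustrative only, not code from the repo or the notebook):

import random
import numpy as np
from scipy.spatial.transform import Rotation as R

def augment(vertices, seed):
    # vertices: (n, 3) float array with y as the vertical axis
    rng = random.Random(seed)
    scale = rng.uniform(0.8, 1.2)
    rotation = R.from_euler('y', rng.uniform(-180, 180), degrees = True)

    out = rotation.apply(vertices * scale)
    out[:, [0, 2]] += [rng.uniform(-0.5, 0.5), rng.uniform(-0.5, 0.5)]

    # "gravity": shift y so the augmented mesh keeps the original lowest vertex height
    out[:, 1] += vertices[:, 1].min() - out[:, 1].min()
    return out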

Is there a pretrained model, and if not, how do I train the model?

Hello @lucidrains, @MarcusLoppe,

We are trying to use it for a quick university project and we are not sure how to train the model using ShapeNet or similar datasets. Could you help us with this? Our goal is to be able to provide text to a Flask server which uses your transformer to generate a model for us, which is then returned.

Is there any pretrained version that can be used with this?
How exactly does one use MeshTransformerTrainer and MeshAutoencoderTrainer in combination with ShapeNet?
etc...
As it seems you have already trained the model on a broad set of categories from the ShapeNet dataset, could you provide us with the state file, so that we can load a trained model into our version?

Regards

is gaussian_blur_1d correct?

I think we should view the discrete coordinate dimension as the "spatial" dimension in conv1d.
Replacing

def gaussian_blur_1d(
    t: Tensor,
    *,
    sigma: float = 1.
) -> Tensor:

    _, channels, _, device, dtype = *t.shape, t.device, t.dtype

    width = int(ceil(sigma * 5))
    width += (width + 1) % 2
    half_width = width // 2

    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)

    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
    gaussian = l1norm(gaussian)

    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
    return F.conv1d(t, kernel, padding = half_width, groups = channels)

with

def gaussian_blur_1d(
    t: Tensor,
    *,
    sigma: float = 1.
) -> Tensor:

    _, _, channels, device, dtype = *t.shape, t.device, t.dtype # change made here

    width = int(ceil(sigma * 5))
    width += (width + 1) % 2
    half_width = width // 2

    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)

    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
    gaussian = l1norm(gaussian)

    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
    return F.conv1d(t.permute(0, 2, 1), kernel, padding = half_width, groups = channels).permute(0, 2, 1) # change made here

makes the blurring work.

I made a breakpoint to check the value of target_one_hot (ignore the peak location change).
Before:

(Pdb) target_one_hot[0, :, 1]
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0404,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0404, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.9192, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000], device='cuda:0')

After:

(Pdb) target_one_hot[0, :, 1]
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0404, 0.9192, 0.0404, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000], device='cuda:0')

Residual quantization - VRAM bottleneck

This might be an issue related to vector-quantize-pytorch, but I'll bring it up here.

During some testing I discovered that the vector quantizer took up most of the VRAM. I'm not sure if this is a bug or simply what the quantizer requires.

After calling encode() it used up about 500 MB and decode() uses about 300 MB. After calling quantize() it used about 4,116 MB; if this could be resolved, the VRAM issues might be resolved as well.
I saw no significant difference using ResidualLFQ or ResidualVQ.

I tested with a big 3D model with 5956 vertices and 3k faces at batch size 1; the issue remains if you use a smaller 3D model.
I set up a bunch of debug prints to monitor VRAM; here are the results after calling autoencoder.forward()


Autoencoder - forward start: 3064.0 MB
After encode() call : 3564.0 MB

Calling quantize():
ResidualLFQ - Forward: : 3564.0 MB
ResidualLFQ - for quantizer_index, layer in enumerate(self.layers)

LFQ layer 1:
Start VRAM : 3564.0 MB
distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
4312.0 MB distance : torch.Size([1, 5956, 1, 16384]) original_input: torch.Size([1, 5956, 1, 14]) self.codebook: torch.Size([16384, 14])

prob = (-distance * inv_temperature).softmax(dim = -1)
4686.0 MB prob: torch.Size([1, 5956, 1, 16384])

per_sample_entropy = entropy(prob).mean()
5808.0 MB per_sample_entropy: 1.1416213512420654

entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
5856.0 MB entropy_aux_loss: -0.8769185543060303

LFQ layer 2:
Start VRAM : 5808.0 MB
distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
6184.0 MB distance : torch.Size([1, 5956, 1, 16384]) original_input: torch.Size([1, 5956, 1, 14]) self.codebook: torch.Size([16384, 14])

prob = (-distance * inv_temperature).softmax(dim = -1)
6558.0 MB prob: torch.Size([1, 5956, 1, 16384])

per_sample_entropy = entropy(prob).mean()
7680.0 MB per_sample_entropy: 0.0

entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
7728.0 MB entropy_aux_loss: -1.3798797130584717

After calling quantize(): 7680.0 MB
After calling decoded() : 7937.0 MB

Step #2:
Autoencoder - forward start: 9095.0 MB


bild
I don't think MeshDiscretizer can be used as a solution, since I think they already tried it in the paper and got quite bad results.

Support for quads

Would it be within the scope of the project to also support quads for meshes? That would add tremendous utility.

Typo Error: `dim` Parameter in `MeshAutoencoder`

Hi Phil,

I've noticed a potential typo. A few days ago you removed the dim parameter from the MeshAutoencoder, but I still see it being used here. Is this a typo?

for _ in range(local_attn_encoder_depth):
    self.encoder_local_attn_blocks.append(nn.ModuleList([
        LocalMHA(dim = dim, **attn_kwargs, **local_attn_kwargs),
        nn.Sequential(RMSNorm(curr_dim), FeedForward(curr_dim, glu = True, dropout = ff_dropout))
    ]))

Best regards,
Xueqi

derive_face_edges_from_faces high ram usage

So I'm trying to see what prevents training on high poly-count meshes.
I tried with 5k & 16k face-count meshes; the results are below.

Using batch size 1 at 5k faces, the memory usage went up 1.5 GB; when I switched to batch size 4 it went up to 4 GB, and when loading it on the GPU it increased by 6 GB.

The face_edges object that is returned has an actual usage of 1386.40 MB, so 3.4 GB is junk (at batch size 4 with 5k faces). I tried calling gc.collect() but saw no change.

I've tried to optimize the derive_face_edges_from_faces function but haven't had much luck; currently it converts a batch of 1 in 0.43 sec, so there is some headroom for making it slower but more memory efficient.

Making it slower might affect the transformer since it needs to call it each step.
Currently this looks like a big memory issue and I hope someone more capable can resolve it.
I'll try to see if there are other bottlenecks.

Metric                           5k Faces - 1 Batch Size    5k Faces - 4 Batch Size
Initial RAM Usage (MB)           653.93                     1803.65
all_edges Usage (MB)             346.60                     346.60
face_masks Usage (MB)            0.00                       0.02
face_edges_masks Usage (MB)      21.66                      86.65
shared_vertices Usage (MB)       194.96                     779.85
Before loop (MB)                 1221.18                    2929.93
face_edges after loop (MB)       346.60                     1386.40
face_edges Usage (MB)            346.60                     1386.40
After loop (MB)                  2158.66                    5825.12
torch.Size                       [1, 22714756, 2]           [4, 22714756, 2]

15k face - 1 batch size:

Metric                    Value (MB)
Initial RAM Usage         689.57
all_edges Usage           4234.15
face_masks Usage          0.02
face_edges_masks Usage    264.63
shared_vertices Usage     2381.71
Before loop               7546.62
face_edges after loop     0.94
face_edges Usage          0.94
After loop                10195.26
    all_edges = torch.stack(torch.meshgrid(
        torch.arange(max_num_faces, device = device),
        torch.arange(max_num_faces, device = device),
    indexing = 'ij'), dim = -1)

    face_masks = reduce(faces != pad_id, 'b nf c -> b nf', 'all')
    
    face_edges_masks = rearrange(face_masks, 'b i -> b i 1') & rearrange(face_masks, 'b j -> b 1 j')
    shared_vertices = rearrange(faces, 'b i c -> b i 1 c 1') == rearrange(faces, 'b j c -> b 1 j 1 c')  
     
    print(f"all_edges Usage: {all_edges.element_size() * all_edges.numel() / (1024 ** 2):.2f} MB")
    print(f"face_masks Usage: {face_masks.element_size() * face_masks.numel() / (1024 ** 2):.2f} MB")
    print(f"face_edges_masks Usage: {face_edges_masks.element_size() * face_edges_masks.numel() / (1024 ** 2):.2f} MB")
    print(f"shared_vertices Usage: {shared_vertices.element_size() * shared_vertices.numel() / (1024 ** 2):.2f} MB")
    
    print(f"Before loop: {get_ram_usage():.2f} MB") 
    
    for face, face_edge_mask in zip(faces, face_edges_masks):
                   ...............
  
    print(f"face_edges after loop: {sum(tensor.element_size() * tensor.numel() for tensor in face_edges) / (1024 ** 2):.2f} MB")
 
    face_edges = pad_sequence(face_edges, padding_value = pad_id, batch_first = True)
    print(f"face_edges Usage: {face_edges.element_size() * face_edges.numel() / (1024 ** 2):.2f} MB")
 
    print(f"After loop:: {get_ram_usage():.2f} MB")
    if is_one_face:
        face_edges = rearrange(face_edges, '1 e ij -> e ij')

    return face_edges
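
One direction that trades speed for peak memory is to compare faces against each other in chunks, instead of materializing the full (num_faces, num_faces, 3, 3) comparison tensor at once. A rough sketch for a single unpadded mesh, assuming (as the shared-vertex counting above suggests) that two faces are connected when they share at least two vertices, i.e. an edge:

import torch

def derive_face_edges_chunked(faces, chunk_size = 1024):
    # faces: (num_faces, 3) long tensor for one mesh, no padding
    num_faces = faces.shape[0]
    edges = []

    for start in range(0, num_faces, chunk_size):
        block = faces[start : start + chunk_size]                        # (c, 3)
        shared = block[:, None, :, None] == faces[None, :, None, :]      # (c, num_faces, 3, 3)
        num_shared = shared.any(dim = -1).sum(dim = -1)                  # (c, num_faces)
        i, j = (num_shared >= 2).nonzero(as_tuple = True)
        i = i + start
        keep = i != j                                                    # drop self-pairs
        edges.append(torch.stack((i[keep], j[keep]), dim = -1))

    return torch.cat(edges)                                              # (num_edges, 2)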

Simple training script for toy data?

Hi there,
I wonder if it's possible to have some script reproducing the same toy example from an older paper. I tried to run the training, but the best thing I came up with is this:
image
I also constantly run into NaN as reported here. Thanks for any help!

Recipe for overfitting a single obj

Hi Happy New Year :) and sorry to ask again.

I failed to overfit a single obj (the autoencoder recon loss is always around 0.7).

Currently I am using 100 sampled tables from ShapeNetV2 as the dataset without any augmentation, and I follow the ipynb written by @MarcusLoppe. So may I kindly ask for the "data recipe" if I want to overfit a single table as a first step?

Thanks a lot.

use_heinsen

Thanks for your great work! But why do I get this?
image

Use mesh autoencoder as is for chatml finetuning

Hi! I want to thank you for your project.

Since you were able to get autoencoding working, I have a question. I was trying to get this encoded using your code. This is a mock-up of the syntax, but it's supposed to be a 3D simplex (tetrahedron) translated by 1 meter.

Meshgpt

If I am able to get the meshes in this form, we can take "off the shelf" training services like https://together.ai and start processing models from https://huggingface.co/datasets/allenai/objaverse-xl with translation, rotation (e.g. for a basketball) and scale (affine transforms) to make a 3D mesh foundation model.

Can you provide some insight how to do this?

[
    {
        "role": "system",
        "content": "This assistant can understand 3D models using the meshgpt-pytorch Unicode plane 15 codebook for 16384 triangles and the .ply 3d format."
    },
    {
        "role": "user",
        "content": "Please convert this to the meshgpt-pytorch codebook. ```ply\nformat ascii 1.0\nelement vertex 4\nproperty float x\nproperty float y\nproperty float z\nelement face 4\nproperty list uchar int vertex_index\nend_header\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n3 0 2 3\n3 0 1 3\n3 1 2 3```"
    },
    {
        "role": "assistant",
        "content": "๔€€€๔€€๔€€‚๔€€ƒ"
    },
    {
        "role": "user",
        "content": "Please translate this mesh by 1 meter"
    },
    {
        "role": "assistant",
        "content": "๔€€„๔€€…๔€€†๔€€‡"
    },
    {
        "role": "user",
        "content": "Can you convert it back to .ply format?"
    },
    {
        "role": "assistant",
        "content": "Sure, here is your translated mesh in .ply format: ```ply\nformat ascii 1.0\nelement vertex 4\nproperty float x\nproperty float y\nproperty float z\nelement face 4\nproperty list uchar int vertex_index\nend_header\n1 0 0\n2 0 0\n1 1 0\n1 0 1\n3 0 1 2\n3 0 2 3\n3 0 1 3\n3 1 2 3```"
    }
]

We can also map the codebook to .ply in chatml to be more obvious to the large language model, since .ply support is very common.

ply
format ascii 1.0
element vertex 4
property float x
property float y
property float z
element face 4
property list uchar int vertex_index
end_header
0 0 0
1 0 0
0 1 0
0 0 1
3 0 1 2
3 0 2 3
3 0 1 3
3 1 2 3

See chatml description in https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B and

.ply sample

The MeshTransformer does not generate coherent results

I have trained the MeshTransformer on 200 different meshes from the chair category of ShapeNet, after decimation and filtering for meshes with fewer than 400 vertices and faces. The MeshTransformer reached a loss very close to 0.
image
But when I call the generate method of the MeshTransformer, I get very bad results.
From left to right: ground truth, autoencoder output, then MeshTransformer-generated meshes with temperatures of 0, 0.1, 0.7 and 1. This was done with meshgpt-pytorch version 0.3.3.
image
Note: the MeshTransformer was not conditioned on text or anything, so the output is not supposed to look exactly like the sofa, but it barely looks like a chair. We can guess the backrest and the legs, but that's it.

Initially I thought that there might have been an error with the KV cache, so here are the results with cache_kv=False:
image

And this one with meshgpt-pytorch version 0.2.11
image

When I trained on a single chair with a version before 0.2.11, the generate method was able to create a coherent chair (from left to right: ground truth, autoencoder output, meshtransformer.generate())

comparisons

Why are the generated results so bad even though the transformer loss was very low?

I have uploaded the autoencoder and meshtransformer checkpoint (on version 0.3.3) as well as 10 data samples there: https://file.io/nNsfTyHX4aFB

Also, a quick question: why rewrite the transformer from scratch and not use the HuggingFace GPT2 transformer?

Handling EOS Token in Max Sequence Length Scenarios

Hi Phil,

Thank you for your valuable contributions to this project!

I'm encountering a problem with handling long sequences here, where the maximum sequence length is set to 2048. The issue arises when processing a batch of code samples: if the iteration reaches the maximum sequence length without an eos_token_id having been appended at the end of all samples, the subsequent code block that depends on this is skipped.

mask = is_eos_codes.float().cumsum(dim = -1) >= 1
codes = codes.masked_fill(mask, self.pad_id)
break

This results in the codes, inputted into later stages of the process, still containing an eos_token_id. This could potentially lead to errors in operations such as gather that follow.

Best regards,
Xueqi

SageConv & ResNet sizes

In the paper they implement a SageConv and ResNet-34, and at the end of the paper they show the model architecture with the different in/out sizes.
All the SageConvs & ResNets are the same size here, so is there any reason why this wasn't mirrored?

There are some reasons why it's best to mirror the sizes:

  • In the paper they explain the reason for the 192-dim codebook size: the output of the SageConv is 576, and 576 / 192 = 3.
    Currently the output of the SageConv is the same as the dim (e.g. 512), which then goes into project_dim_codebook (Linear) which has an out size of 576. Using 512 as the SageConv output size (512 / 192 = 2.67) might cause some ineffectiveness, since the SageConv might be able to correlate better than the Linear layer.

bild

  • Optimization: the people who wrote the paper have probably experimented with different sizes and found what worked best.

  • The input for the embedding seems to be F x 196; would this mean that each face gets a 196-dim tensor? I'm confused about this since currently project_in has the feature values: Linear(in_features=832, out_features=512, bias=True)
    There are 16 features and 196 / 16 = 12.25, which seems very low. Maybe you can figure out what this means :)

bild

Another thing the implementation might be missing: according to the paper, they sort the vertices in z-y-x order.
Then they sort the faces by their lowest vertex index.
I think this functionality belongs in the dataset class, but anyway, I just wanted to highlight it. A rough sketch of that preprocessing is below.
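
The sketch assumes the convention described in the paper (vertex z-y-x sort, face indices cyclically rotated so the lowest comes first, faces sorted by lowest vertex index); it is illustrative and not code from this repo:

import numpy as np

def sort_mesh(vertices, faces):
    # vertices: (nv, 3) array with columns (x, y, z); faces: (nf, 3) vertex indices

    # sort vertices by z, then y, then x (np.lexsort takes the primary key last)
    order = np.lexsort((vertices[:, 0], vertices[:, 1], vertices[:, 2]))
    vertices = vertices[order]

    # remap face indices to the new vertex order
    remap = np.empty_like(order)
    remap[order] = np.arange(len(order))
    faces = remap[faces]

    # cyclically rotate each face so its lowest index comes first (preserves winding)
    shift = faces.argmin(axis = 1)
    faces = np.stack([np.roll(f, -s) for f, s in zip(faces, shift)])

    # order faces by their lowest vertex index, then the next ones
    faces = faces[np.lexsort((faces[:, 2], faces[:, 1], faces[:, 0]))]
    return vertices, faces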

Add pr for ArgumentParser CLI

Since my other proposals didn't mention what it would take to be merged, I will ask before opening prs to avoid wasted effort.

I made a change to meshgpt-pytorch https://github.com/V-Sekai-fire/meshgpt-pytorch/blob/main/run.py to use ArgumentParser so it's easier to use.

python3 run.py --dataset_directory ../cats_dataset
python3 run.py 
usage: run.py [-h] --dataset_directory DATASET_DIRECTORY [--data_augment DATA_AUGMENT] [--autoencoder_learning_rate AUTOENCODER_LEARNING_RATE] [--transformer_learning_rate TRANSFORMER_LEARNING_RATE] [--autoencoder_train AUTOENCODER_TRAIN] [--transformer_train TRANSFORMER_TRAIN] [--batch_size BATCH_SIZE] [--grad_accum_every GRAD_ACCUM_EVERY] [--checkpoint_every CHECKPOINT_EVERY] [--dim DIM] [--encoder_depth ENCODER_DEPTH] [--decoder_depth DECODER_DEPTH]
              [--num_discrete_coors NUM_DISCRETE_COORS] [--inference_only] [--autoencoder_path AUTOENCODER_PATH] [--transformer_path TRANSFORMER_PATH] [--num_quantizers NUM_QUANTIZERS] [--test_mode]
run.py: error: the following arguments are required: --dataset_directory

It can also be made into an executable via https://pyinstaller.org/en/stable/

Mesh to Mesh training

Similar to image-to-image, it should be possible to use this to make a model decimation tool: a more complicated mesh in, fewer triangles out.

Question

Hi everyone,

I had some questions about this method:

  • The transformer seems to learn about 3D meshes as a sequence, so the order of that sequence must matter a lot. From what I understood, all meshes are ordered from bottom to top (along the z axis); is that correct? Then if we pass the middle part of the mesh as a prompt, will this not work for completion? Also, does the encoding depend on the order of the sequence?
  • I built a toy dataset of 1000 shapes with 100 faces each (no text labels). I expected things to train quite efficiently there, but while the autoencoder seemed to train OK, even after >24 hours of training the transformer still has not converged to a reasonable loss. I feel like I may be missing something.

I hope people can help :)
Have a great day,

How to preprocess ShapeNet models?

Hi, thanks for your great work. But it's still not clear how to preprocess the 3D models. Based on my understanding of the paper, the number of triangle faces is reduced to 800 with a tool called "Blender". However, "blender" is not in the dependency list of this repo. Could you kindly share an example demonstrating how your team processed the original mesh data?

By the way, the computational cost of the autoencoder is a lot larger than that of the GPT in my experiments. My guess is that the graph CNN doesn't scale well when the number of nodes grows large (i.e. too many faces in the 3D model). Blurring the model is one solution, but is it possible to break one giant 3D model into pieces and treat each piece as an independent model?
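
For the decimation step, here is a sketch of what this typically looks like in Blender's Python API (bpy). The 800-face target comes from the paper; the rest is an assumption about how one might script it, not something shipped with this repo:

import bpy

def decimate_active_object(target_faces = 800):
    obj = bpy.context.active_object
    current_faces = len(obj.data.polygons)
    if current_faces <= target_faces:
        return

    # add a Decimate (collapse) modifier that reduces to roughly the target face count
    mod = obj.modifiers.new(name = 'Decimate', type = 'DECIMATE')
    mod.ratio = target_faces / current_faces
    bpy.ops.object.modifier_apply(modifier = mod.name)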

Sliding window for transformer

Hi,
I was wondering if it would be possible to implement a sliding-window decoder for the transformer?
When increasing the max sequence length, the training time goes up dramatically, and I think that using a sliding-window decoder would greatly help with training and inference speed.

I've tried using LocalAttention but I'm not sure how to properly implement it, since it takes q, k and v as inputs.

I know @lucidrains has already spent all their allotted time and more on this project, so if I could be given some tips I could try to implement it myself.

How to view the generated mesh?

Hi, thank you for your code!!
I want to view the mesh and load the face data into Blender. However, the outcome is puzzling: many faces are overlapping each other.
image
Is there anything I need to do after I get faces_coordinates?

Issue with Multiple `eos_token_id` in Code Sequences

Hi, Phil,

I've noticed a potential issue in the code where multiple eos_token_id values are being added to the code sequences, instead of just one at the end of each row. It seems to insert four eos_token_id values. Could you please clarify if this is the expected behavior?

batch_arange = torch.arange(batch, device = device)
batch_arange = rearrange(batch_arange, '... -> ... 1')

codes[batch_arange, code_lens] = self.eos_token_id

Best regards,
Xueqi

Commit loss is negative

image

When I trained on several objects for several epochs, the commit loss started to become negative. The overall loss keeps going down, but neither the recon loss nor the reconstruction result improves.

I wonder whether the commit loss being negative is normal or not, and what it implies.

`face_mask` rearrange in Autoencoder Forward Pass

Hi Phil,

Happy New Year! 🎉🎉🎉

I encountered a minor issue in the code when I set return_recon_faces to True. During the execution of the autoencoder's forward method, I noticed that face_mask is rearranged.

face_mask = rearrange(face_mask, 'b nf -> b nf 1 1')
recon_faces = recon_faces.masked_fill(~face_mask, float('nan'))

However, after applying masked_fill to recon_faces, it seems that the shape of face_mask isn't restored. This is causing an error when trying to perform a repeat operation on face_mask later in the process.

face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = 9)

Perhaps the shape of face_mask could be restored right after the masked_fill operation.

Best regards,
Xueqi Ma

transformer.generate()

Thanks for your great work! When I run transformer.generate(), it shows the error below. How can I solve it?
image

MeshTransformer.generate does not work with a prompt if kv cache is enabled

I wanted to experiment with how the MeshTransformer is able to complete a mesh given the initial codes, but there is a problem where I think the prompt codes are not correctly passed down the line. Here is a small debug snippet:

vertices = torch.randn(2, 100, 3)
faces = torch.randint(0, 100, (2, 100, 3))
# gpt = # Load MeshTransformer from checkpoint
codes = gpt.autoencoder.tokenize(vertices=vertices, faces=faces)
generated = gpt.generate(prompt=codes)

It gives the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[421], line 1
----> 1 generated = gpt.generate(prompt=codes)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\autoregressive_wrapper.py:27, in eval_decorator.<locals>.inner(self, *args, **kwargs)
     25 was_training = self.training
     26 self.eval()
---> 27 out = fn(self, *args, **kwargs)
     28 self.train(was_training)
     29 return out

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File <@beartype(meshgpt_pytorch.meshgpt_pytorch.MeshTransformer.generate) at 0x2239b61c9d0>:170, in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_2352326455808, __beartype_object_2350005736832, __beartype_object_2349955153872, __beartype_object_140723080033008, __beartype_getrandbits, *args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\meshgpt_pytorch\meshgpt_pytorch.py:1238, in MeshTransformer.generate(self, prompt, batch_size, filter_logits_fn, filter_kwargs, temperature, return_codes, texts, text_embeds, cond_scale, cache_kv, face_coords_to_file)
   1233 for i in tqdm(range(curr_length, self.max_seq_len)):
   1234     # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)
   1236     can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3)  # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes
-> 1238     output = self.forward_on_codes(
   1239         codes,
   1240         text_embeds = text_embeds,
   1241         return_loss = False,
   1242         return_cache = cache_kv,
   1243         append_eos = False,
   1244         cond_scale = cond_scale,
   1245         cfg_routed_kwargs = dict(
   1246             cache = cache
   1247         )
   1248     )
   1250     if cache_kv:
   1251         logits, cache = output

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py:152, in classifier_free_guidance.<locals>.inner(self, cond_scale, rescale_phi, cfg_routed_kwargs, *args, **kwargs)
    148 null_fn_kwargs = {k: v[1] for k, v in cfg_routed_kwargs.items()}
    150 # non-null forward
--> 152 outputs = fn_maybe_with_text(self, *args, **fn_kwargs, **kwargs_without_cond_dropout)
    154 if cond_scale == 1:
    155     return outputs

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py:130, in classifier_free_guidance.<locals>.inner.<locals>.fn_maybe_with_text(self, *args, **kwargs)
    127     if 'raw_text_cond' in fn_params:
    128         kwargs.update(raw_text_cond = raw_text_cond)
--> 130 return fn(self, *args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\meshgpt_pytorch\meshgpt_pytorch.py:1514, in MeshTransformer.forward_on_codes(self, codes, return_loss, return_cache, append_eos, cache, texts, text_embeds, cond_drop_prob)
   1511 if one_face:
   1512     fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
-> 1514 attended_vertex_codes, fine_cache = self.fine_decoder(
   1515     fine_vertex_codes,
   1516     cache = fine_cache,
   1517     return_hiddens = True
   1518 )
   1520 if not should_cache_fine:
   1521     fine_cache = None

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\x_transformers.py:1299, in AttentionLayers.forward(self, x, context, mask, context_mask, attn_mask, self_attn_kv_mask, mems, seq_start_pos, cache, cache_age, return_hiddens, rotary_pos_emb)
   1296     x = pre_norm(x)
   1298 if layer_type == 'a':
-> 1299     out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
   1300 elif layer_type == 'c':
   1301     out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\x_transformers.py:832, in Attention.forward(self, x, context, mask, context_mask, attn_mask, rel_pos, rotary_pos_emb, prev_attn, mem, return_intermediates, cache)
    829     mk, k = unpack(k, mem_packed_shape, 'b h * d')
    830     mv, v = unpack(v, mem_packed_shape, 'b h * d')
--> 832 k = torch.cat((ck, k), dim = -2)
    833 v = torch.cat((cv, v), dim = -2)
    835 if exists(mem):

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 202 but got size 2 for tensor number 1 in the list.

However, if I disable the kv cache with `generated = gpt.generate(prompt=codes, cache_kv=False)`, it works (albeit slowly).

With the cache enabled, in x_transformers > Attention > forward, `ck.shape = [202, 16, 6, 64]` while `k.shape = [2, 16, 1, 64]`, which causes the shape mismatch (the same shapes occur for `cv` and `v` right after).
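
To make the failure mode concrete, here is a standalone reproduction of just the `torch.cat` mismatch using the reported shapes (mock tensors, not the actual cache contents):

import torch

ck = torch.randn(202, 16, 6, 64)   # cached keys, batch dim 202 as reported
k = torch.randn(2, 16, 1, 64)      # new keys, batch dim 2

try:
    torch.cat((ck, k), dim = -2)
except RuntimeError as e:
    print(e)   # Sizes of tensors must match except in dimension 2 ...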
