
Comments (14)

lucidrains commented on May 29, 2024

can you show the full condensed script that produces the error?


fighting-Zhang commented on May 29, 2024

generate_samples_v1.py:

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

autoencoder = MeshAutoencoder.init_and_load('./exps/mesh-autoencoder.ckpt.90.pt')

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 12000, #8192,
    flash_attn = True,
    gateloop_use_heinsen = False, 
    condition_on_text = False
).cuda()
transformer.load('./checkpoints/mesh-transformer.ckpt.5.pt')

face_coords, face_mask = transformer.generate(temperature=0.5)

meshgpt_pytorch.py is version 0.6.7

MeshTransformer.generate() :

@eval_decorator
@torch.no_grad()
@beartype
def generate(
     self,
     prompt: Optional[Tensor] = None,
     batch_size: Optional[int] = None,
     filter_logits_fn: Callable = top_k,
     filter_kwargs: dict = dict(),
     temperature = 1.,
     return_codes = False,
     texts: Optional[List[str]] = None,
     text_embeds: Optional[Tensor] = None,
     cond_scale = 1.,
     cache_kv = True,
     max_seq_len = None,
     face_coords_to_file: Optional[Callable[[Tensor], Any]] = None
):
     max_seq_len = default(max_seq_len, self.max_seq_len)

     if exists(prompt):
         assert not exists(batch_size)

         prompt = rearrange(prompt, 'b ... -> b (...)')
         assert prompt.shape[-1] <= self.max_seq_len

         batch_size = prompt.shape[0]

     if self.condition_on_text:
         assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True'
         if exists(texts):
             text_embeds = self.embed_texts(texts)

         batch_size = default(batch_size, text_embeds.shape[0])

     batch_size = default(batch_size, 1)

     codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))
     
     curr_length = codes.shape[-1]

     cache = (None, None)

     for i in tqdm(range(curr_length, max_seq_len)):
         # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)

         can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3)  # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes

         output = self.forward_on_codes(
             codes,
             text_embeds = text_embeds,
             return_loss = False,
             return_cache = cache_kv,
             append_eos = False,
             cond_scale = cond_scale,
             cfg_routed_kwargs = dict(
                 cache = cache
             )
         )

         if cache_kv:
             logits, cache = output

             if cond_scale == 1.:
                 cache = (cache, None)
         else:
             logits = output

         logits = logits[:, -1]

         if not can_eos:
             logits[:, -1] = -torch.finfo(logits.dtype).max

         filtered_logits = filter_logits_fn(logits, **filter_kwargs)

         if temperature == 0.:
             sample = filtered_logits.argmax(dim = -1)
         else:
             probs = F.softmax(filtered_logits / temperature, dim = -1)
             sample = torch.multinomial(probs, 1)
         
         codes, _ = pack([codes, sample], 'b *')

         # check for all rows to have [eos] to terminate

         is_eos_codes = (codes == self.eos_token_id)

         if is_eos_codes.any(dim = -1).all():
             break

     # mask out to padding anything after the first eos

     mask = is_eos_codes.float().cumsum(dim = -1) >= 1
     codes = codes.masked_fill(mask, self.pad_id)

     # remove a potential extra token from eos, if we broke out of the loop early

     code_len = codes.shape[-1]
     round_down_code_len = code_len // self.num_quantizers * self.num_quantizers
     codes = codes[:, :round_down_code_len]

     # early return of raw residual quantizer codes

     if return_codes:
         codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
         return codes

     self.autoencoder.eval()
     face_coords, face_mask = self.autoencoder.decode_from_codes_to_faces(codes)

     if not exists(face_coords_to_file):
         return face_coords, face_mask

     files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)]
     return files

MeshTransformer.forward_on_codes() :

    @classifier_free_guidance
    def forward_on_codes(
        self,
        codes = None,
        return_loss = True,
        return_cache = False,
        append_eos = True,
        cache = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = 0.
    ):
        # handle text conditions

        attn_context_kwargs = dict()

        if self.condition_on_text:
            assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True'

            if exists(texts):
                text_embeds = self.conditioner.embed_texts(texts)

            if exists(codes):
                assert text_embeds.shape[0] == codes.shape[0], 'batch size of texts or text embeddings is not equal to the batch size of the mesh codes'

            _, maybe_dropped_text_embeds = self.conditioner(
                text_embeds = text_embeds,
                cond_drop_prob = cond_drop_prob
            )

            attn_context_kwargs = dict(
                context = maybe_dropped_text_embeds.embed,
                context_mask = maybe_dropped_text_embeds.mask
            )

        # take care of codes that may be flattened

        if codes.ndim > 2:
            codes = rearrange(codes, 'b ... -> b (...)')

        # get some variable

        batch, seq_len, device = *codes.shape, codes.device

        assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'

        # auto append eos token

        if append_eos:
            assert exists(codes)

            code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1)

            codes = F.pad(codes, (0, 1), value = 0)

            batch_arange = torch.arange(batch, device = device)

            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')

            codes[batch_arange, code_lens] = self.eos_token_id

        # if returning loss, save the labels for cross entropy

        if return_loss:
            assert seq_len > 0
            codes, labels = codes[:, :-1], codes

        # token embed (each residual VQ id)

        codes = codes.masked_fill(codes == self.pad_id, 0)
        codes = self.token_embed(codes)

        # codebook embed + absolute positions

        seq_arange = torch.arange(codes.shape[-2], device = device)

        codes = codes + self.abs_pos_emb(seq_arange)

        # embedding for quantizer level

        code_len = codes.shape[1]

        level_embed = repeat(self.quantize_level_embed, 'q d -> (r q) d', r = ceil(code_len / self.num_quantizers))
        codes = codes + level_embed[:code_len]

        # embedding for each vertex

        vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (3 * self.num_quantizers)), q = self.num_quantizers)
        codes = codes + vertex_embed[:code_len]

        # create a token per face, by summarizing the 3 vertices
        # this is similar in design to the RQ transformer from Lee et al. https://arxiv.org/abs/2203.01941

        num_tokens_per_face = self.num_quantizers * 3

        curr_vertex_pos = code_len % num_tokens_per_face # the current intra-face vertex-code position id, needed for caching at the fine decoder stage

        code_len_is_multiple_of_face = divisible_by(code_len, num_tokens_per_face)

        next_multiple_code_len = ceil(code_len / num_tokens_per_face) * num_tokens_per_face

        codes = pad_to_length(codes, next_multiple_code_len, dim = -2)

        # grouped codes will be used for the second stage

        grouped_codes = rearrange(codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face)

        # create the coarse tokens for the first attention network

        face_codes = grouped_codes if code_len_is_multiple_of_face else grouped_codes[:, :-1]

        face_codes = rearrange(face_codes, 'b nf n d -> b nf (n d)')
        face_codes = self.to_face_tokens(face_codes)

        face_codes_len = face_codes.shape[-2]

        # cache logic

        (
            cached_attended_face_codes,
            coarse_cache,
            fine_cache,
            coarse_gateloop_cache,
            fine_gateloop_cache
        ) = cache if exists(cache) else ((None,) * 5)

        if exists(cache):
            cached_face_codes_len = cached_attended_face_codes.shape[-2]
            need_call_first_transformer = face_codes_len > cached_face_codes_len
        else:
            need_call_first_transformer = True

        should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)

        # attention on face codes (coarse)

        if need_call_first_transformer:
            if exists(self.coarse_gateloop_block):
                face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)

            attended_face_codes, coarse_cache = self.decoder(
                face_codes,
                cache = coarse_cache,
                return_hiddens = True,
                **attn_context_kwargs
            )

            attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2)
        else:
            attended_face_codes = cached_attended_face_codes

        # maybe project from coarse to fine dimension for hierarchical transformers

        attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes)

        # auto prepend sos token

        sos = repeat(self.sos_token, 'd -> b d', b = batch)

        attended_face_codes_with_sos, _ = pack([sos, attended_face_codes], 'b * d')

        grouped_codes = pad_to_length(grouped_codes, attended_face_codes_with_sos.shape[-2], dim = 1)
        fine_vertex_codes, _ = pack([attended_face_codes_with_sos, grouped_codes], 'b n * d')

        fine_vertex_codes = fine_vertex_codes[..., :-1, :]

        # gateloop layers

        if exists(self.fine_gateloop_block):
            fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> b (nf n) d')
            orig_length = fine_vertex_codes.shape[-2]
            fine_vertex_codes = fine_vertex_codes[:, :(code_len + 1)]

            fine_vertex_codes, fine_gateloop_cache = self.fine_gateloop_block(fine_vertex_codes, cache = fine_gateloop_cache)

            fine_vertex_codes = pad_to_length(fine_vertex_codes, orig_length, dim = -2)
            fine_vertex_codes = rearrange(fine_vertex_codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face)

        # fine attention - 2nd stage

        if exists(cache):
            fine_vertex_codes = fine_vertex_codes[:, -1:]

            if exists(fine_cache):
                for attn_intermediate in fine_cache.attn_intermediates:
                    ck, cv = attn_intermediate.cached_kv
                    ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
                    ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
                    attn_intermediate.cached_kv = (ck, cv)

        one_face = fine_vertex_codes.shape[1] == 1

        fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> (b nf) n d')

        if one_face:
            fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]

        attended_vertex_codes, fine_cache = self.fine_decoder(
            fine_vertex_codes,
            cache = fine_cache,
            return_hiddens = True
        )

        if not should_cache_fine:
            fine_cache = None

        if not one_face:
            # reconstitute original sequence

            embed = rearrange(attended_vertex_codes, '(b nf) n d -> b (nf n) d', b = batch)
            embed = embed[:, :(code_len + 1)]
        else:
            embed = attended_vertex_codes

        # logits

        logits = self.to_logits(embed)

        if not return_loss:
            if not return_cache:
                return logits

            next_cache = (
                attended_face_codes,
                coarse_cache,
                fine_cache,
                coarse_gateloop_cache,
                fine_gateloop_cache
            )

            return logits, next_cache

        # loss

        ce_loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = self.pad_id
        )

        return ce_loss


lucidrains commented on May 29, 2024

@fighting-Zhang it works fine for me

can you update to 1.0 and retry?


fighting-Zhang commented on May 29, 2024

My problem mainly occurs when an empty tensor is passed into self.decoder.
In the code below, face_codes.size() = [1, 0, 512]. What are your dimensions?

        if need_call_first_transformer:
            if exists(self.coarse_gateloop_block):
                face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)

            attended_face_codes, coarse_cache = self.decoder(
                face_codes,
                cache = coarse_cache,
                return_hiddens = True,
                **attn_context_kwargs
            )

            attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2)
        else:
            attended_face_codes = cached_attended_face_codes
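
As a standalone check, the same CUDA error can sometimes be provoked directly. This is only a sketch: it assumes the failure comes from the zero-length sequence reaching the fused scaled_dot_product_attention kernel, and the head/dim sizes below are illustrative rather than taken from meshgpt-pytorch.

import torch
import torch.nn.functional as F

# Hypothesis: on the first generation step the coarse decoder sees a sequence of
# length 0 (the [1, 0, 512] shape above), and some GPU / driver / PyTorch
# combinations reject a zero-length flash attention launch with
# "invalid configuration argument".
if torch.cuda.is_available():
    q = torch.randn(1, 8, 0, 64, device = 'cuda')  # (batch, heads, seq_len = 0, dim_head) -- illustrative sizes
    k = torch.randn(1, 8, 0, 64, device = 'cuda')
    v = torch.randn(1, 8, 0, 64, device = 'cuda')

    out = F.scaled_dot_product_attention(q, k, v)  # if this raises the same error, the empty sequence is the likely culprit
    print(out.shape)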


lucidrains commented on May 29, 2024

@fighting-Zhang what version of x-transformers are you using?


fighting-Zhang commented on May 29, 2024

1.27.3


lucidrains commented on May 29, 2024

@fighting-Zhang does the very first example in the readme run for you?


fighting-Zhang commented on May 29, 2024

wow, amazing!
The very first example works fine.


fighting-Zhang commented on May 29, 2024

But when I put the data and model on CUDA, I got the above error.


lucidrains commented on May 29, 2024

wow, amazing! The very first example works fine.

well, the very first example is also promptless. so i don't think that's the issue


fighting-Zhang commented on May 29, 2024

Thank you for your patient answers.
I will continue to look for a way to solve the CUDA error.


MarcusLoppe commented on May 29, 2024

@fighting-Zhang
Is the Autoencoder also on the GPU?


fighting-Zhang commented on May 29, 2024

@MarcusLoppe yes
generate_sample_v0.py :

import torch

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# autoencoder

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128
).cuda()

# mock inputs

vertices = torch.randn((2, 121, 3)).cuda()            # (batch, num vertices, coor (3))
faces = torch.randint(0, 121, (2, 64, 3)).cuda()      # (batch, num faces, vertices (3))

# make sure faces are padded with `-1` for variable-length meshes

# forward in the faces

loss = autoencoder(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training...
# you can pass in the raw face data above to train a transformer to model this sequence of face vertices

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768
).cuda()

loss = transformer(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets

faces_coordinates, face_mask = transformer.generate()

Error message :

Traceback (most recent call last):
File "/code/mesh-auto/generate_sample_v0.py", line 48, in
faces_coordinates, face_mask = transformer.generate()
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/autoregressive_wrapper.py", line 27, in inner
out = fn(self, *args, **kwargs)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "<@beartype(meshgpt_pytorch.meshgpt_pytorch.MeshTransformer.generate) at 0x7f05a2f83760>", line 170, in generate
File "/code/mesh-auto/meshgpt_pytorch/meshgpt_pytorch.py", line 1186, in generate
output = self.forward_on_codes(
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 153, in inner
outputs = fn_maybe_with_text(self, *args, **fn_kwargs, **kwargs_without_cond_dropout)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 131, in fn_maybe_with_text
return fn(self, *args, **kwargs)
File "/code/mesh-auto/meshgpt_pytorch/meshgpt_pytorch.py", line 1413, in forward_on_codes
attended_face_codes, coarse_cache = self.decoder(
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/x_transformers.py", line 1336, in forward
out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/x_transformers.py", line 944, in forward
out, intermediates = self.attend(
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/attend.py", line 274, in forward
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/attend.py", line 214, in flash_attn
out = F.scaled_dot_product_attention(
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
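
If the flash kernel is indeed what rejects the zero-length launch, one hedged workaround (an assumption, not a confirmed fix) is to disable flash attention: either construct MeshTransformer with flash_attn = False, or force the math SDPA backend around the generate call with PyTorch's sdp_kernel context manager.

import torch

# Workaround sketch, assuming the flash SDPA kernel is the failing path.
# torch.backends.cuda.sdp_kernel is the PyTorch 2.x context manager
# (newer releases prefer torch.nn.attention.sdpa_kernel).
with torch.backends.cuda.sdp_kernel(enable_flash = False, enable_mem_efficient = False, enable_math = True):
    faces_coordinates, face_mask = transformer.generate()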


MarcusLoppe commented on May 29, 2024

My problem mainly occurs when an empty tensor is passed into self.decoder. In the code below, face_codes.size() = [1, 0, 512]. What are your dimensions?

@fighting-Zhang
I've checked and I also get the same shape, but I've run the example you provided and it works for me.

I'm guessing you'll need to reinstall meshgpt with all the dependencies, or the GPU you're using isn't compatible.
So give the reinstall a go; otherwise, can you say what GPU you are using, along with the PyTorch & CUDA versions?
