
Comments (11)

jlamprou commented on July 27, 2024

@Beomi I ran some tests to check the validity of segmenting inside the training loop. At every batch I concatenated the logits and labels across segments and measured accuracy over the total sequence length to check that it improves during training, and once the learnable beta had seen some data we got the same accuracy as normal SDPA attention. Check the implementation in my repo.
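A minimal sketch of that kind of check (the helper name, tensor shapes, and ignore_index convention are assumptions, not the actual code from the repo):

```python
import torch

def segment_accuracy(all_logits, all_labels, ignore_index=-100):
    # all_logits: list of (batch, seg_len, vocab) tensors, one per segment
    # all_labels: list of (batch, seg_len) tensors, one per segment
    logits = torch.cat(all_logits, dim=1)
    labels = torch.cat(all_labels, dim=1)
    preds = logits.argmax(dim=-1)
    mask = labels != ignore_index
    return (preds[mask] == labels[mask]).float().mean().item()
```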


Beomi commented on July 27, 2024

Exactly! I'm currently implementing it part by part, so yes, it's not implemented yet :)


jlamprou commented on July 27, 2024

I've been working on an implementation since the paper release, and the only part I'm having problems with is the segment-wise processing. Using your implementation, I segment the input and feed each segment to self.attn with a for loop, but on the second segment I get a shape mismatch at memory_output = memory_output / norm_term_expanded: at dimension 3, memory_output has size head_dim while norm_term_expanded has size head_dim * num_heads. So I don't know whether my segmentation logic is wrong or whether your Infini-attention implementation doesn't account for the accumulation of segments. I would be grateful for any tips.
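For reference, a minimal shape sketch of how the retrieval division broadcasts when everything stays per-head (names and shapes are illustrative, not the repo's code):

```python
import torch

# Illustrative shapes only: the division broadcasts cleanly when the normalizer
# is kept per head, i.e. (B, H, S, 1), rather than folded into H * head_dim.
B, H, S, D = 1, 8, 2048, 64
memory_output = torch.randn(B, H, S, D)            # (batch, num_heads, seg_len, head_dim)
norm_term_expanded = torch.rand(B, H, S, 1) + 1e-6  # per-head normalizer
memory_output = memory_output / norm_term_expanded  # broadcasts over head_dim
```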


zzr-idam commented on July 27, 2024

Although we attempted the segmentation approach, inference with the model is very slow. Any suggestions?


jlamprou commented on July 27, 2024

@Beomi I think the segmentation is not supposed to happen inside the attention but before passing the inputs to the whole transformer block. Passing the whole input to the attention still requires loading the whole sequence into VRAM. If we assume a HuggingFace-like model, I'd say it belongs either in the decoder layer class or in the model class.
[attached screenshot]
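A minimal sketch of what segmenting at the model level (rather than inside the attention) might look like, assuming a HuggingFace-style stack; the memory_states keyword and the return signature are made up for illustration:

```python
import torch

def forward_segmented(model, input_ids, segment_len=2048):
    # Split the full input into segments and run each one through the whole
    # decoder stack, carrying the compressive memory forward between segments.
    outputs, memory_states = [], None
    for start in range(0, input_ids.size(1), segment_len):
        segment = input_ids[:, start:start + segment_len]
        hidden, memory_states = model(segment, memory_states=memory_states)
        outputs.append(hidden)
    return torch.cat(outputs, dim=1)
```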


Beomi commented on July 27, 2024

@jlamprou You're right, maybe I have to edit the decoder layer class. Thanks for the note.


jlamprou commented on July 27, 2024

@Beomi More likely the model class: per my understanding, we segment the input and feed each segment to the decoder layers. What I'm not so sure about is how we manage the compressed memory during the backward pass. I don't think we need gradients for the compressed memory, so we should probably not assign self.norm_term and self.memory directly in the memory update, but instead create new variables and then assign them to self.memory and self.norm_term with either detach() or torch.no_grad().
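A rough sketch of that idea, using the simple linear-attention accumulation (attribute names and tensor shapes are assumptions):

```python
import torch

def update_memory(self, sigma_k, value):
    # sigma_k, value: (batch, num_heads, seg_len, head_dim)
    new_memory = self.memory + torch.matmul(sigma_k.transpose(-2, -1), value)
    new_norm_term = self.norm_term + sigma_k.sum(dim=-2)
    # Detach so the backward pass does not unroll through earlier segments.
    self.memory = new_memory.detach()
    self.norm_term = new_norm_term.detach()
```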


Beomi commented on July 27, 2024

@jlamprou I've been reconsidering your point that "passing the entire input to the attention mechanism still requires loading the whole sequence into VRAM." However, I believe that regardless of the method chosen, we end up loading all of the input into VRAM eventually, and that cost is O(N) (where N is the input length). The key issue, though, is that the paper aims to reduce memory usage in the quadratic component, which is normally the dominant cost of self-attention. Thus, even when we segment within the attention loop, the input sequence itself still has to reside somewhere, either on the CPU or the GPU. Therefore, worrying about the linearly growing memory isn't as crucial, since the attention's own VRAM usage is fixed per segment. What do you think?
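To put rough numbers on the quadratic component (illustrative only; this ignores batch size, head count, and layer count):

```python
# fp32 attention-score memory for the full sequence vs. a single 2048-token segment.
full_seq, seg_len = 32768, 2048
print(full_seq ** 2 * 4 / 2 ** 30)  # ~4.0 GiB for one (N x N) score matrix
print(seg_len ** 2 * 4 / 2 ** 30)   # ~0.016 GiB for one (S x S) score matrix
```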


Beomi commented on July 27, 2024

@zzr-idam In the published paper, they mention: "in this work we parameterize the memory with an associative matrix / cast the memory update and retrieval process as linear attention mechanism / we adopt the update rule and retrieval mechanism by Katharopoulos et al. (2020) mainly due to its simplicity and competitive performance", so they might have used Katharopoulos et al. (2020) (https://arxiv.org/pdf/2006.16236.pdf).

This repo does not implement that paper's method yet, and I think that's the reason for the slow inference.
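For context, a minimal sketch of the Katharopoulos et al. (2020)-style update and retrieval that the paper describes, using the ELU+1 feature map; the tensor layout is an assumption and this is not the repo's code:

```python
import torch
import torch.nn.functional as F

def linear_attention_step(q, k, v, memory, norm):
    # q, k, v: (B, H, S, D); memory: (B, H, D, D); norm: (B, H, D)
    phi_q = F.elu(q) + 1.0
    phi_k = F.elu(k) + 1.0
    # Retrieval from the memory accumulated over previous segments.
    out = torch.matmul(phi_q, memory) / torch.matmul(phi_q, norm.unsqueeze(-1)).clamp(min=1e-6)
    # Update with the current segment (simple accumulation, no delta rule).
    memory = memory + torch.matmul(phi_k.transpose(-2, -1), v)
    norm = norm + phi_k.sum(dim=-2)
    return out, memory, norm
```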


jlamprou commented on July 27, 2024

@Beomi I'm testing both ways of implementing the segmentation right now. You are right: based on my tests, the VRAM usage difference is small, with segmentation inside the attention consuming only about 1GB extra but giving better throughput, so it's probably best to keep the current implementation. The actually weird thing is that classic SDPA attention (as-is from the original HuggingFace implementation) consumes the same amount of VRAM too, with no segmenting or anything... We should probably take a look at Memformers - Pytorch, which implements a recurrent trainer. Maybe the segmentation shouldn't happen in the model at all, but in the training loop? The paper states: "We set the Infini-attention segment length N to 2048 for all attention layers and the input sequence length to 32768 for training. This allows the Infini-attention to unroll over 16 steps w.r.t its compressive memory states.", which could mean training steps.
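A rough sketch of what segmenting in the training loop could look like (the model signature, the memory_states handling, and the per-segment backward are all assumptions):

```python
import torch

def train_on_long_sequence(model, optimizer, input_ids, labels, segment_len=2048):
    # Segment a 32768-token sequence into 16 chunks of 2048 and unroll the
    # compressive memory across them inside the training loop.
    optimizer.zero_grad()
    memory_states, num_segments = None, input_ids.size(1) // segment_len
    total_loss = 0.0
    for i in range(num_segments):
        sl = slice(i * segment_len, (i + 1) * segment_len)
        loss, memory_states = model(input_ids[:, sl], labels=labels[:, sl],
                                    memory_states=memory_states)
        # Accumulate gradients segment by segment; assumes memory_states are
        # detached between segments so each backward stays within one segment.
        (loss / num_segments).backward()
        total_loss += loss.item()
    optimizer.step()
    return total_loss / num_segments
```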


Beomi commented on July 27, 2024

@jlamprou Hi, I think it's time to consider opening up both options and letting end-users select whichever is beneficial for them.

As you said before, the attention itself accounts for only a small fraction of the memory; other parts of the input processing, such as the MLP layers or even the embeddings, increase VRAM usage, which makes it hard to use a bigger block size.

In my experience (with the training code), the VRAM required was almost the same as the original implementation, so maybe your implementation direction would be more helpful in terms of VRAM usage.

Or maybe there would be room to make it work like an adapter (PEFT style)? What do you think?
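If the adapter route is explored, a rough PEFT-style setup could be as simple as freezing the base weights and training only the newly added Infini-attention parameters; the name keywords below are hypothetical:

```python
def mark_infini_params_trainable(model, keywords=("gate", "beta")):
    # Freeze everything except parameters whose names contain one of the
    # (hypothetical) keywords identifying the added Infini-attention modules.
    for name, param in model.named_parameters():
        param.requires_grad = any(k in name for k in keywords)
```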

