Comments (11)
@Beomi I ran some tests to check the validity of segmenting in the training loop. I tested the accuracy at every batch, using the concatenation of the per-segment logits and labels, to check whether the accuracy over the total sequence length improves during training; once the learnable beta got some data, we reached the same accuracy rate as normal SDPA attention. Check the implementation on my repo
from infinitransformer.
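A minimal sketch of the kind of check described above: concatenating per-segment logits and labels to score token accuracy over the full sequence. The helper name and shapes are illustrative, not taken from the repo:

```python
import torch

def sequence_accuracy(segment_logits, segment_labels):
    """Token-level accuracy over the full sequence, computed by
    concatenating per-segment logits/labels along the time axis.
    (Hypothetical helper; shapes: (batch, seg_len, vocab) / (batch, seg_len).)"""
    logits = torch.cat(segment_logits, dim=1)   # (batch, total_len, vocab)
    labels = torch.cat(segment_labels, dim=1)   # (batch, total_len)
    preds = logits.argmax(dim=-1)               # (batch, total_len)
    return (preds == labels).float().mean().item()
```

Comparing this number per batch against a plain-SDPA baseline is one way to confirm the segmented path is actually learning.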
exactly! I'm currently implementing part-by-part, so yes, it's not implemented yet :)
I've been working on an implementation since the paper's release, and the only part I'm having problems with is the segment-wise part. Using your implementation, I segment the input and feed each segment to self.attn with a for loop, but on the second segment I get a shape mismatch on memory_output = memory_output / norm_term_expanded: at dimension 3, memory_output has size head_dim but norm_term_expanded has size head_dim * num_heads. So I don't know whether my segmentation logic is wrong or whether your Infini-attention implementation doesn't account for the accumulation of segments. I would be grateful for any tips.
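For what it's worth, the mismatch described sounds like a broadcasting issue: the accumulated norm term is one scalar per (batch, head, position), so it should broadcast over head_dim rather than be expanded to head_dim * num_heads in the last dimension. A small sketch under assumed (batch, num_heads, seq_len, head_dim) shapes:

```python
import torch

batch, num_heads, seq_len, head_dim = 2, 4, 8, 16

memory_output = torch.randn(batch, num_heads, seq_len, head_dim)
# norm term accumulated per head: one scalar per (batch, head), here (B, H, 1, 1)
norm_term = torch.rand(batch, num_heads, 1, 1) + 1.0

# Broadcasting handles the rest: every head_dim element at a given
# (batch, head, position) is divided by the same scalar. Expanding the
# norm term to head_dim * num_heads in the last dim is what would
# trigger the dimension-3 mismatch described above.
normalized = memory_output / norm_term
assert normalized.shape == (batch, num_heads, seq_len, head_dim)
```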
Although we attempted the segmentation approach, inference with the model is very slow. Any suggestions?
@Beomi I think the segmentation is not supposed to happen inside the attention but before passing the inputs to the whole transformer block. Passing the whole input to the attention still requires loading the whole sequence into VRAM. Assuming a Hugging Face-style model, I'd say it belongs either in the decoder layer class or the model class.
@jlamprou You're right, maybe I have to edit the decoder layer class. Thanks for the note.
@Beomi More likely the model class: per my understanding, we segment the input and feed each segment to the decoder layers. What I'm not so sure about is how we manage the compressed memory during the backward pass. I don't think we need gradients for the compressed memory, so we should probably not directly assign self.norm_term and self.memory in the memory update, but instead create new variables and then assign them to self.memory and self.norm_term with either detach() or torch.no_grad().
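A minimal sketch of the detach pattern suggested above, assuming `memory` and `norm_term` are plain tensors and the deltas come from the current segment's forward pass (names are illustrative, not the repo's actual attributes):

```python
import torch

def update_memory(memory, norm_term, delta_memory, delta_norm):
    """Compute the new compressive-memory state into local variables,
    then return detached copies so autograd does not keep the graph of
    every previous segment alive across the unroll."""
    new_memory = memory + delta_memory
    new_norm = norm_term + delta_norm
    # .detach() cuts the cross-segment graph; gradients still flow
    # within the current segment's forward/backward.
    return new_memory.detach(), new_norm.detach()
```

The caller would then do `self.memory, self.norm_term = update_memory(...)` after each segment.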
@jlamprou I've been reconsidering your point that "passing the entire input to the attention mechanism still requires loading the whole sequence into VRAM." However, I believe that regardless of the method chosen, we end up loading all the input into VRAM eventually, and this is O(N) (where N is the input length). The key issue, though, is that the paper aims to reduce memory usage in the quadratic component, which comes from self-attention. Thus, even when we segment within the attention loop, the input sequence itself may be large, but it has to reside somewhere, either on the CPU or the GPU. Therefore, worrying about the linear incremental memory usage isn't as crucial, since the VRAM usage of the attention itself stays fixed. What do you think?
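A back-of-the-envelope illustration of the point about the quadratic component: the attention score matrix for the full 32768-token sequence versus a single 2048-token segment (hypothetical helper; fp16 bytes, counting only the score tensor):

```python
def attn_scores_bytes(seq_len, num_heads=8, batch=1, bytes_per_el=2):
    # Bytes held by one (batch, num_heads, seq_len, seq_len) score tensor.
    return batch * num_heads * seq_len * seq_len * bytes_per_el

full = attn_scores_bytes(32768)        # scores for the whole 32k sequence
per_segment = attn_scores_bytes(2048)  # only one 2k segment alive at a time

# The quadratic term shrinks by (32768 / 2048)^2 = 256x, while the O(N)
# activations (embeddings, MLP, residual stream) remain wherever they are.
```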
@zzr-idam In the published paper, they mention: "in this work we parameterize the memory with an associative matrix / cast the memory update and retrieval process as linear attention mechanism / we adopt the update rule and retrieval mechanism by Katharopoulos et al. (2020) mainly due to its simplicity and competitive performance", so they likely used Katharopoulos et al. (2020) (https://arxiv.org/pdf/2006.16236.pdf).
This repo has not implemented that paper's method yet, and I think that's the reason for the slow inference.
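For reference, the Katharopoulos-style update and retrieval that the paper cites can be sketched as follows. The shapes and the ELU+1 feature map are assumptions based on that paper, not this repo's code:

```python
import torch
import torch.nn.functional as F

def linear_attention_update(k, v, memory, z):
    """Memory update: M <- M + sigma(K)^T V, z <- z + sum_t sigma(K_t).
    Assumed shapes: k, v (heads, seg_len, head_dim);
    memory (heads, head_dim, head_dim); z (heads, head_dim, 1)."""
    sigma_k = F.elu(k) + 1.0
    memory = memory + sigma_k.transpose(-2, -1) @ v
    z = z + sigma_k.sum(dim=-2, keepdim=True).transpose(-2, -1)
    return memory, z

def linear_attention_retrieve(q, memory, z):
    """Memory retrieval: A_mem = sigma(Q) M / (sigma(Q) z),
    with a small epsilon to avoid division by zero."""
    sigma_q = F.elu(q) + 1.0
    return (sigma_q @ memory) / (sigma_q @ z + 1e-8)
```

Both operations are linear in segment length, which is what would make inference fast once this path replaces the current one.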
@Beomi I'm testing both ways of implementing the segmentation right now. You are right: based on my tests, the VRAM usage difference is small, with segmentation inside the attention consuming only about 1 GB extra but with better throughput, so it's probably best to keep the current implementation. The actually weird thing is that classic SDPA attention (as in the original Hugging Face implementation) consumes the same amount of VRAM too, with no segmenting at all... We should probably take a look at Memformers - Pytorch, which implements a recurrent trainer. Maybe the segmentation shouldn't happen in the model at all, but in the training loop? The paper states: "We set the Infini-attention segment length N to 2048 for all attention layers and the input sequence length to 32768 for training. This allows the Infini-attention to unroll over 16 steps w.r.t its compressive memory states," which could mean training steps.
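If the unrolling does happen in the training loop, it might look roughly like this. The `model(seg_ids, seg_labels, memory_state)` interface, returning a loss and the updated compressive-memory state, is hypothetical:

```python
import torch

def train_step(model, optimizer, input_ids, labels, segment_len=2048):
    """One optimizer step over a long sequence, unrolled segment by
    segment (e.g. 32768 tokens / 2048 per segment = 16 unroll steps)."""
    optimizer.zero_grad()
    memory_state = None
    total_loss = 0.0
    for start in range(0, input_ids.size(1), segment_len):
        seg_ids = input_ids[:, start:start + segment_len]
        seg_labels = labels[:, start:start + segment_len]
        loss, memory_state = model(seg_ids, seg_labels, memory_state)
        loss.backward()  # per-segment backward keeps each graph small
        total_loss += loss.item()
    optimizer.step()     # single update after the full unroll
    return total_loss
```

With a detached memory state (as discussed above), each `backward()` only spans one segment, which matches the recurrent-trainer style used by Memformers.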
@jlamprou Hi, I think it's time to consider exposing both options to end users and letting them select whichever way is beneficial.
As you said before, the attention itself accounts for only a fraction of the memory; other input processing, such as the MLP layers or even the embeddings, increases VRAM usage, which makes it hard to use a bigger block size.
In my experience (training code), the VRAM required was almost the same as the original implementation, so maybe your implementation direction would be more helpful in terms of VRAM usage.
Or maybe there is room to make it work like an adapter (PEFT style)? What do you think?
Related Issues (20)
- Discord server for this?
- Code not running on GPU HOT 6
- config no attn_implementation = "eager" HOT 4
- question about norm_term_broadcastable HOT 5
- load model failed HOT 4
- Suggest to use the constant memory gradient computation in Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
- Model generating random sequence HOT 8
- Limitations of the method HOT 2
- Memory should be per layer
- Memory does not use PE
- Inference code (with Segments)
- Are there any trained InfinityTransformer weights available?
- Segment and block size error HOT 1
- mem and norm_term is nan? HOT 15
- What is the min GPU memory required to fine-tune the model?
- About memory missing location information HOT 5
- BitLinear
- Model loses information very quickly HOT 2
- Issue while running test_train.small.gemma.infini.py HOT 2
- Support Zero-3? HOT 1