Comments (10)
Hi Enrico, it seems checkpointing / rematerialization is not included. Checkpointing is required for memory saving because otherwise the total memory cost of all blocks accumulates to the same memory cost as the vanilla transformer in the forward and backward passes. The corresponding code is https://github.com/lhao499/blockwise-parallel-transformer/blob/4a668d5436adc5263df6786255c2bd684749aae2/bpt/blocks/blockwise_parallel.py#L446
You can check out PyTorch checkpoint for this purpose: https://pytorch.org/docs/stable/checkpoint.html.
It's also worth optimizing the for loop with torch.compile / torch.jit.trace; the corresponding Jax code is https://github.com/lhao499/blockwise-parallel-transformer/blob/4a668d5436adc5263df6786255c2bd684749aae2/bpt/blocks/blockwise_parallel.py#L453
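To illustrate the idea, a minimal sketch (block_fn and the shapes are stand-ins, not code from BPT, and it assumes a recent PyTorch where torch.compile can trace through checkpointed regions):

import torch
from torch.utils.checkpoint import checkpoint

def block_fn(x):
    # Stand-in for one block's computation (e.g. one feed-forward chunk).
    return torch.relu(x) * 2.0

def blockwise_forward(blocks):
    # Checkpointing frees each block's intermediate activations after the
    # forward pass and recomputes them in backward, so peak memory holds
    # one block's activations instead of all of them.
    outs = [checkpoint(block_fn, b, use_reentrant=False) for b in blocks]
    return torch.stack(outs)

# The Python loop itself can then be handed to the compiler.
compiled_forward = torch.compile(blockwise_forward)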
Hi Hao,
Thank you for the insight. I was not aware of the torch checkpointing utility and was looking into an alternative to remat.
I will try using the torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True, **kwargs) function as follows:
num_q, _, _, _ = inputs.shape
res = []
for i in range(num_q):
    res.append(checkpoint(self.lm_head, inputs[i]))
res = torch.stack(res, dim=0)
Or
def lm_head(cell, hidden_states):
    outputs = cell(hidden_states)
    return outputs

outputs = checkpoint(lm_head, cell, hidden_states)
And then looping:
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    # f may return y=None when only the carry matters (e.g. a loss scan);
    # skip stacking in that case instead of calling torch.stack on Nones.
    if ys and ys[0] is None:
        return carry, None
    return carry, torch.stack(ys)
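As a quick sanity check of the helper (toy values, purely illustrative):

xs = torch.arange(4.0)  # four "blocks"
carry, ys = scan(lambda c, x: (c + x, 2 * x), torch.tensor(0.0), xs)
# carry is the running sum (6.0); ys stacks the per-step outputs: [0., 2., 4., 6.]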
I will have to further evaluate the design and update the above functions to include checkpointing, as well as apply torch.jit/torch.compile to the final model.
I will add the rest of the code to this issue as I continue to work through it but wanted to get these core blockwise functions rewritten first.
Thank you,
Enrico
An updated FeedForward example could look like:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from einops import rearrange

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    # f may return y=None when only the carry matters; skip stacking then.
    if ys and ys[0] is None:
        return carry, None
    return carry, torch.stack(ys)
class FFN(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()
        self.resid_dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(dim)

    def forward(self, hidden_states):
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.resid_dropout(hidden_states)
        return hidden_states

    def checkpoint_fn(self, carry, x):
        return None, checkpoint(self.forward, x)
def blockwise_compute_ffn(cell, inputs, chunk_size):
    inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=chunk_size)
    inputs = rearrange(inputs, 'b n c d -> n b c d')
    _, res = scan(cell.checkpoint_fn, None, inputs, inputs.shape[0])
    res = rearrange(res, 'n b c d -> b (n c) d')
    return res
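For reference, a usage sketch with assumed sizes (batch 2, sequence length 1024, model dim 512):

ffn = FFN(dim=512, hidden_dim=2048, dropout=0.1)
x = torch.randn(2, 1024, 512)                      # (batch, seq_len, dim)
y = blockwise_compute_ffn(ffn, x, chunk_size=256)  # same shape as x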
Cool. Just some minor suggestions: first, verify that dropout uses the same random seed for all blocks within a sequence. Second, it would be worth exploring ways to optimize the loop; there is a relevant discussion in the PyTorch repo: pytorch/pytorch#50688.
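One possible way to pin the mask per sequence (a sketch of the approach, not the BPT implementation; seeded_dropout is a hypothetical helper):

import torch
from torch.nn import functional as F

def seeded_dropout(x, p, seed, training=True):
    # Re-seed a forked RNG before each block's dropout so every block in
    # the same sequence draws an identical mask, then restore global state.
    devices = [x.device] if x.is_cuda else []
    with torch.random.fork_rng(devices=devices):
        torch.manual_seed(seed)
        return F.dropout(x, p=p, training=training)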
I will add a manual seed to the FFN to ensure the dropout mask is the same across blocks. I read through the issue and am going to look into optimizing the scan loop; I will likely see if I can get further input from the Triton team or even message Tri Dao.
Here is the updated Blockwise LM Head:
class Blockwise_LM_Head(nn.Module):
    def __init__(self, vocab_size, dim, chunk_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.chunk_size = chunk_size
        # The head projects the hidden dimension to the vocabulary, so
        # in_features is the model dim rather than the chunk size.
        self.lm_head = nn.Linear(dim, vocab_size, bias=True)

    def checkpoint_fn(self, carry, x):
        return None, checkpoint(self.lm_head, x)

    def forward(self, inputs):
        inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=self.chunk_size)
        inputs = rearrange(inputs, 'b n c d -> n b c d')
        _, res = scan(self.checkpoint_fn, None, inputs, inputs.shape[0])
        res = rearrange(res, 'n b c d -> b (n c) d')
        return res
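A quick shape check under assumed sizes (vocab 32000, dim 512, chunk 256):

head = Blockwise_LM_Head(vocab_size=32000, dim=512, chunk_size=256)
hidden = torch.randn(2, 1024, 512)  # (batch, seq_len, dim)
logits = head(hidden)               # (batch, seq_len, vocab_size)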
and Blockwise Cross Entropy Loss:
def cross_entropy_loss_and_accuracy(logits, tokens, valid):
    valid_text_length = torch.clamp(torch.sum(valid, dim=-1), min=1e-10)
    token_log_prob = torch.log_softmax(logits, dim=-1)
    token_log_prob = torch.gather(token_log_prob, -1, tokens.unsqueeze(-1)).squeeze(-1)
    token_log_prob = torch.where(valid > 0.0, token_log_prob, torch.tensor(0.0, device=logits.device))
    correct = torch.where(
        valid > 0.0,
        logits.argmax(dim=-1) == tokens,
        torch.tensor(False, device=logits.device)
    )
    return token_log_prob, correct, valid_text_length

def _loss_and_accuracy(carry, args):
    loss, accuracy, num = carry
    logits, tokens, valid = args
    token_log_prob, correct, valid_text_length = checkpoint(cross_entropy_loss_and_accuracy, logits, tokens, valid)
    loss = loss + torch.sum(token_log_prob, dim=-1) / valid_text_length
    accuracy = accuracy + torch.sum(correct.float(), dim=-1) / valid_text_length
    num = num + 1
    return (loss, accuracy, num), None

def blockwise_cross_entropy(logits, tokens, valid=None, chunk_size=None):
    if valid is None:
        valid = torch.ones(tokens.shape[:2], device=tokens.device)
    valid = valid.float()
    logits = logits.view(-1, logits.shape[-1])
    tokens = tokens.view(-1)
    valid = valid.view(-1)
    num_chunk = logits.shape[0] // chunk_size
    logits = rearrange(logits, '(n c) d -> n c d', c=chunk_size)
    tokens = rearrange(tokens, '(n c) -> n c', c=chunk_size)
    valid = rearrange(valid, '(n c) -> n c', c=chunk_size)
    # zip the chunked tensors so scan steps over (logits, tokens, valid) triples
    # rather than iterating over the tuple of three whole tensors.
    (loss, accuracy, num), _ = scan(
        _loss_and_accuracy,
        (torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device), 0),
        list(zip(logits, tokens, valid)), length=num_chunk
    )
    loss = -loss / num
    accuracy = accuracy / num
    return loss, accuracy
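And a quick check with assumed sizes (vocab 32000, chunk 256, so 8 chunks):

logits = torch.randn(2, 1024, 32000, requires_grad=True)
tokens = torch.randint(0, 32000, (2, 1024))
loss, accuracy = blockwise_cross_entropy(logits, tokens, chunk_size=256)
# loss and accuracy come back as scalars averaged over the chunks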
Should the checkpointing be applied to the entire _loss_and_accuracy function or just cross_entropy_loss_and_accuracy?
Thank you,
Enrico
Sounds good. The Triton team and Tri should know more about PyTorch and CUDA; I am mostly using Jax on TPUs.
Should the checkpointing be applied to the entire _loss_and_accuracy function or just cross_entropy_loss_and_accuracy?
Both options should be fine, but the current one might be slightly better because fewer ops are needed for recomputation.
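Schematically (an illustration of the trade-off, not code from the repo):

# Option A (current): only the per-chunk loss math is recomputed in backward.
token_log_prob, correct, n = checkpoint(cross_entropy_loss_and_accuracy, logits, tokens, valid)
# Option B would wrap all of _loss_and_accuracy, so the running-sum updates
# on the carry would be recomputed as well: a few extra (cheap) ops per chunk.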
Please feel free to reopen when needed.
Btw PyTorch core recently merged in a native rearrange: pytorch/pytorch#92675
Just FYI, we have added llamabpt, which is BPT applied to LLaMA.
@lhao499 If I want to use llamabpt for fine-tuning, can I simply use the llamabpt model definition, load weights from Hugging Face directly, and fine-tune the model using something like PyTorch Lightning? Or will this not work because the model uses Jax?