lhao499 / ringattention
Transformers with Arbitrarily Large Context
License: Apache License 2.0
I couldn't find the dataset at './local/owt/openwebtext_train.jsonl'. Could you provide instructions for obtaining it, or a download address?
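The training command quoted further down this page hints at the expected file shape: `--train_dataset.type=json` together with `--train_dataset.text_processor.fields=text`. A minimal sketch for producing a compatible file from your own corpus, assuming (not confirmed here) that the json dataset reads JSON Lines, one object per line, and takes each document from the `text` field. `write_jsonl` is a hypothetical helper, not part of the repo.

```python
import json
from pathlib import Path
from typing import Iterable


def write_jsonl(texts: Iterable[str], path: str) -> int:
    """Write one {"text": ...} JSON object per line and return the line count.

    Assumption: the loader expects JSON Lines with the document body in the
    "text" field (matching --train_dataset.text_processor.fields=text).
    """
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with out.open("w", encoding="utf-8") as f:
        for text in texts:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            count += 1
    return count
```

For example, `write_jsonl(documents, "./local/owt/openwebtext_train.jsonl")` would place the file at the path the issue mentions; the contents of `documents` are up to you.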
Hi there, I am working on long-context models. Would it be possible to release the pretrained models?
The original paper states: "The computation for a query block is given by:"
If we want to add a normalization layer at the input of the FFN, it will break the end-to-end fusion, am I right? Based on my understanding, normalizations like RMSNorm require access to the whole tensor to calculate some statistics first.
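Whether fusion breaks depends on which axis the statistics are taken over. In the usual definition of RMSNorm (and LayerNorm), the statistics are computed per token over the hidden dimension, not across the sequence. If that standard definition is the one in use, normalizing a sequence chunk by chunk gives the same result as normalizing it all at once. A small pure-Python check of that claim, with the learned scale omitted for brevity; `rms_norm` and `blockwise_rms_norm` are illustrative names, not functions from the repo.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMSNorm without the learned scale.

    x is a list of tokens, each a list of hidden values; every token is
    normalized by the root mean square of its own hidden vector only.
    """
    out = []
    for token in x:
        rms = math.sqrt(sum(v * v for v in token) / len(token) + eps)
        out.append([v / rms for v in token])
    return out


def blockwise_rms_norm(x, chunk_size, eps=1e-6):
    """Apply rms_norm one sequence chunk at a time, as a blockwise FFN would."""
    out = []
    for start in range(0, len(x), chunk_size):
        out.extend(rms_norm(x[start:start + chunk_size], eps))
    return out


seq = [[1.0, 2.0, 3.0], [0.5, -1.0, 4.0], [2.0, 2.0, 2.0], [-3.0, 0.0, 1.0]]
full = rms_norm(seq)
chunked = blockwise_rms_norm(seq, chunk_size=2)
print(full == chunked)  # True: no statistic crosses a chunk boundary
```

A normalization whose statistics span the sequence axis would not have this property; the check above only covers the per-token case.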
Hi, I saw your research and found it amazing, particularly its memory efficiency.
MosaicML implemented Flash Attention with Triton, fusing the attention computation into the forward and backward passes for memory and performance efficiency:
https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py
Specifically, I am wondering whether you have benchmarked against this type of Flash Attention, since it is now considered state of the art in terms of speed.
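The trick that Flash Attention and blockwise attention both rely on is that softmax attention can be computed over key/value blocks incrementally, keeping only a running maximum and a running sum (an "online softmax"), so the full row of attention scores is never stored. A minimal single-query sketch in pure Python (no masking, batching, or kernel fusion), checked against the naive formula. It illustrates the arithmetic only and says nothing about the relative speed of the two implementations; the function names are illustrative.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d)) over all keys, applied to values."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]


def blockwise_attention(q, keys, values, block_size):
    """Visit key/value blocks one at a time, keeping a running max and sum."""
    d = len(q)
    dim_v = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim_v  # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        # (exp(-inf) == 0.0 on the first block, when nothing is accumulated).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Any block size gives the same output as `naive_attention` up to floating-point rounding, which is why these methods are exact rather than approximate.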
I'm trying to run Ring Attention on a machine with 6 A100 GPUs, and I'm finding that when I try to set the sequence parallelism dimension to anything other than a power of 2, the process crashes with a JAX partitioning error.
I'd be grateful for any insight into whether or not I'm doing something wrong in the way I'm invoking the training script, and for any advice on how to work around this issue.
Consider the following script for invoking llamabpt.train:
#########################################################
### Configuration 1 (runs successfully) ###
#########################################################
export CUDA_VISIBLE_DEVICES="0,1,2,3"
SEQ_PAR_DIM=4
MAX_SEQ_LEN=131072
#########################################################
#########################################################
### Configuration 2 (CRASHES with partitioning error) ###
#########################################################
# export CUDA_VISIBLE_DEVICES="0,1,2"
# SEQ_PAR_DIM=3
# MAX_SEQ_LEN=98304
#########################################################
python3 -m llamabpt.train \
--mesh_dim="1,1,1,${SEQ_PAR_DIM}" \
--dtype=bf16 \
--load_llama_config=1b \
--update_llama_config="{'max_sequence_length': ${MAX_SEQ_LEN}, 'scan_attention': True, 'scan_query_chunk_size': 2048, 'scan_key_chunk_size': 4096, 'remat_attention': 'nothing_saveable', 'scan_mlp': True, 'scan_mlp_chunk_size': 2048, 'remat_mlp': 'nothing_saveable', 'remat_block': 'nothing_saveable', 'scan_layers': True, 'attention_type': 'ring_blockwise', 'param_scan_axis': 0, 'mesh_dim': '1,1,1,${SEQ_PAR_DIM}'}" \
--total_steps=2 \
--log_freq=1 \
--save_model_freq=0 \
--save_milestone_freq=1000 \
--tokenizer.vocab_file="${TOKENIZER_PATH}" \
--optimizer.type=adamw \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=1.5e-4 \
--optimizer.adamw_optimizer.end_lr=1.5e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=1 \
--optimizer.adamw_optimizer.lr_decay_steps=10 \
--train_dataset.type=json \
--train_dataset.text_processor.fields=text \
--train_dataset.json_dataset.path="${TRAIN_DATA_PATH}" \
--train_dataset.json_dataset.seq_length=${MAX_SEQ_LEN} \
--train_dataset.json_dataset.batch_size=1 \
--train_dataset.json_dataset.tokenizer_processes=16
For Configuration 1, where the sequence parallelism dimension is 4, the training script runs as expected without errors.
However, when I uncomment Configuration 2, where the sequence parallelism dimension is 3, the training script crashes with the following error:
ValueError: One of pjit outputs with pytree key path .params['params']['lm_head']['kernel'] was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 3}, spec=PartitionSpec(('fsdp', 'sp'), 'tp')), which implies that the global size of its dimension 0 should be divisible by 3, but it is equal to 2048 (full shape: (2048, 32000))
The error occurs during the first call to sharded_init_fn.
I would expect Configuration 2 to run successfully, because the total sequence length (98304) is a multiple of the sequence parallelism dimension (3).
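The error, though, is raised for a parameter rather than for the sequence: the PartitionSpec(('fsdp', 'sp'), 'tp') in the message shards dimension 0 of lm_head/kernel (size 2048) over the fsdp and sp mesh axes together. Since 2048 = 2^11, only powers of two divide it, which lines up with the pattern reported below. A small check of that arithmetic using the shapes from the error message; the helper names are hypothetical, and this covers only the one parameter named in the error (other parameters may impose similar constraints).

```python
def lm_head_sharding_ok(sp, fsdp=1, tp=1, hidden=2048, vocab=32000):
    """lm_head/kernel is (hidden, vocab) with PartitionSpec(('fsdp', 'sp'), 'tp'):
    dim 0 is split over fsdp * sp devices and dim 1 over tp devices."""
    return hidden % (fsdp * sp) == 0 and vocab % tp == 0


def sequence_ok(sp, seq_len):
    """The condition checked in the issue: sequence length divisible by sp."""
    return seq_len % sp == 0


for sp in (2, 3, 4, 6):
    print(sp, sequence_ok(sp, 98304), lm_head_sharding_ok(sp))
# 2 True True
# 3 True False
# 4 True True
# 6 True False
```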
Generalizing to more sequence parallelism dimensions, I find that:
- setting SEQ_PAR_DIM to either 2 or 4 runs successfully;
- setting SEQ_PAR_DIM to either 3 or 6 crashes with a partitioning error.
Hi,
I tried to run your script on Cloud TPU v4-64, but it failed with the following error:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 59.79M of 16.00M vmem. Exceeded vmem capacity by 43.79M.
I tried mesh dims of 1,1,1,32 and 1,1,4,8; both failed.
Any suggestions as to what caused the error? Thanks.
Thanks very much for sharing your work! I wonder whether BPT is compatible with sequence parallelism. In the related work section, I find the statement "This creates an orthogonal relationship between our method and sequence parallelism, allowing for straightforward combination". However, this combination is not that straightforward to me; please see my comments below.
To my understanding, BPT breaks a long sequence into small chunks and feeds these chunks sequentially through attention + layernorm + MLP. Recent LLM training pipelines usually apply sequence parallelism (https://arxiv.org/pdf/2205.05198.pdf) to the layernorm part, so that layernorm is computed in parallel across GPUs. This seems to break the connection between attention and the MLP when using BPT. I wonder whether it is possible to seamlessly combine BPT with sequence parallelism when layernorm is taken into account.
Thanks for providing the repo. I have a question regarding fine-tuning as mentioned in the paper (Section 5.4).
As the README.md suggests, --load_checkpoint='params::/path/output' is used for fine-tuning from an HF model converted with the hf2jax.py script. However, when scan_layers=True, it appears that the layer names (keys) from /path/output do not match those in shard_fns when loading the HF weights. For example, ('transformer', 'h', 'scan_decoder', 'attention', 'wq', 'kernel') from shard_fns does not match the key ('transformer', 'h', '0', 'attention', 'wq', 'kernel') unpacked from /path/output.
This eventually raises the exception KeyError: ('transformer', 'h', '0', 'attention', 'wq', 'kernel') during load_checkpoint.
Have I missed anything in the fine-tuning configuration, or is there a workaround for this?
Thank you!
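One possible workaround, sketched under assumptions: convert the unrolled checkpoint into the scanned layout before loading it. The sketch assumes the checkpoint can be flattened into a dict keyed by tuples like the ones above (for example with `flax.traverse_util.flatten_dict`), that a scanned model stores layer i at index i of each parameter's leading axis (matching 'param_scan_axis': 0 in the config used elsewhere on this page), and that no other keys need renaming. `stack_layers_for_scan` is a hypothetical helper, not part of the repo.

```python
from typing import Callable, Dict, List, Tuple

Key = Tuple[str, ...]


def stack_layers_for_scan(
    flat_params: Dict[Key, object],
    num_layers: int,
    stack_fn: Callable[[List[object]], object],
    layer_prefix: Key = ("transformer", "h"),
    scan_name: str = "scan_decoder",
) -> Dict[Key, object]:
    """Map ('transformer', 'h', '<i>', ...) keys to ('transformer', 'h', 'scan_decoder', ...)
    and stack the per-layer values with stack_fn (e.g. jax.numpy.stack, which
    stacks along a new leading axis)."""
    n = len(layer_prefix)
    per_layer: Dict[Key, Dict[int, object]] = {}
    out: Dict[Key, object] = {}
    for key, value in flat_params.items():
        if key[:n] == layer_prefix and len(key) > n and key[n].isdigit():
            rest = key[n + 1:]
            per_layer.setdefault(rest, {})[int(key[n])] = value
        else:
            out[key] = value  # embeddings, final norm, lm_head, ...
    for rest, layers in per_layer.items():
        if sorted(layers) != list(range(num_layers)):
            raise KeyError(f"missing layers for {rest}: found {sorted(layers)}")
        out[layer_prefix + (scan_name,) + rest] = stack_fn(
            [layers[i] for i in range(num_layers)]
        )
    return out
```

The result would then need to be unflattened (for example with `flax.traverse_util.unflatten_dict`) and checked against the shapes the scanned model expects before use.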
In the project requirements, the jax version is pinned to 0.4.13. However, Pallas was only added in version 0.4.16 (google/jax@d872812).
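A small guard that makes this mismatch fail fast at startup, assuming the 0.4.16 threshold cited in this issue is the relevant one. `parse_version` and `jax_has_pallas` are hypothetical helpers; `importlib.metadata.version` is the standard-library way to read an installed package's version.

```python
import re
from importlib.metadata import version

# Minimum jax version with Pallas, as cited in this issue.
PALLAS_MIN_JAX = (0, 4, 16)


def parse_version(v: str) -> tuple:
    """Turn '0.4.16' (or '0.4.16rc1', '0.4.16.dev2023') into (0, 4, 16)."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(g) for g in m.groups())


def jax_has_pallas(v=None) -> bool:
    """Check a version string, or the installed jax if none is given."""
    if v is None:
        v = version("jax")  # raises PackageNotFoundError if jax is absent
    return parse_version(v) >= PALLAS_MIN_JAX
```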
Hi @lhao499,
I am working through a rewrite of your great work in PyTorch. Please excuse my ignorance, as my understanding of JAX is lacking in certain aspects. Below are the functions and classes for the Blockwise Compute Feedforward Network, Blockwise Cross Entropy Loss, and Blockwise LM Head:
Blockwise Compute Feedforward Network
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.checkpoint import checkpoint


class FFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        dropout,
        act=nn.GELU
    ):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim, bias=True)
        self.fc_out = nn.Linear(hidden_dim, dim, bias=True)
        self.act = act()
        self.resid_dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(dim)

    def forward(self, hidden_states):
        # Pre-norm MLP: LayerNorm -> Linear -> activation -> Linear -> dropout.
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.resid_dropout(hidden_states)
        return hidden_states


def blockwise_compute_ffn(cell, inputs, chunk_size):
    # (b, n*c, d) -> (n, b, c, d): split the sequence into n chunks of length c.
    inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=chunk_size)
    inputs = rearrange(inputs, 'b n c d -> n b c d')
    num_q, _, _, _ = inputs.shape
    # Loop over chunks ("scan") and checkpoint each call so its activations
    # are recomputed in the backward pass instead of stored ("remat").
    res = []
    for i in range(num_q):
        res.append(checkpoint(cell, inputs[i], use_reentrant=False))
    res = torch.stack(res, dim=0)
    res = rearrange(res, 'n b c d -> b (n c) d')
    return res
Blockwise Cross Entropy Loss
def blockwise_cross_entropy(logits, tokens, valid=None, chunk_size=None):
    if valid is None:
        valid = torch.ones(tokens.shape[:2], device=logits.device)
    valid = valid.float()
    # Flatten batch and sequence: logits (b*s, vocab), tokens and valid (b*s,).
    logits = logits.reshape(-1, logits.shape[-1])
    tokens = tokens.reshape(-1)
    valid = valid.reshape(-1)

    def _cross_entropy_loss_and_accuracy(logits, tokens, valid):
        # Clamp so that an all-padding chunk does not divide by zero.
        valid_text_length = valid.sum(dim=-1).clamp(min=1e-10)
        token_log_prob = F.log_softmax(logits, dim=-1)
        # Log-probability assigned to each target token.
        token_log_prob = token_log_prob.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
        token_log_prob = torch.where(valid > 0.0, token_log_prob, torch.zeros_like(token_log_prob))
        correct = (torch.argmax(logits, dim=-1) == tokens) & (valid > 0.0)
        return token_log_prob, correct.float(), valid_text_length

    num_chunk = logits.shape[0] // chunk_size
    logits = rearrange(logits, '(n c) d -> n c d', c=chunk_size)
    tokens = rearrange(tokens, '(n c) -> n c', c=chunk_size)
    valid = rearrange(valid, '(n c) -> n c', c=chunk_size)
    loss, accuracy, num = 0.0, 0.0, 0
    for i in range(num_chunk):
        token_log_prob, correct, valid_text_length = _cross_entropy_loss_and_accuracy(logits[i], tokens[i], valid[i])
        # Each chunk contributes its mean over its own valid tokens.
        loss = loss + token_log_prob.sum() / valid_text_length
        accuracy = accuracy + correct.sum() / valid_text_length
        num = num + 1
    loss = -loss / num
    accuracy = accuracy / num
    return loss, accuracy
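One behavior worth knowing when comparing this port against a non-blockwise loss: it normalizes the loss per chunk (by that chunk's valid_text_length) and then averages over chunks. That equals a global per-token mean only when every chunk has the same number of valid tokens; with uneven padding the two differ. A pure-Python illustration of the two reductions (function names are illustrative, losses are positive negative log-likelihoods):

```python
def mean_of_chunk_means(token_losses, valid, chunk_size):
    """Mirror the reduction above: average each chunk over its valid tokens,
    then average the per-chunk results."""
    chunk_means = []
    for s in range(0, len(token_losses), chunk_size):
        losses = token_losses[s:s + chunk_size]
        mask = valid[s:s + chunk_size]
        n_valid = max(sum(mask), 1e-10)
        chunk_means.append(sum(x * v for x, v in zip(losses, mask)) / n_valid)
    return sum(chunk_means) / len(chunk_means)


def global_token_mean(token_losses, valid):
    """Average over all valid tokens at once."""
    return sum(x * v for x, v in zip(token_losses, valid)) / max(sum(valid), 1e-10)


losses = [1.0, 1.0, 1.0, 1.0, 4.0, 0.0, 0.0, 0.0]
valid = [1, 1, 1, 1, 1, 0, 0, 0]  # the second chunk is mostly padding
print(mean_of_chunk_means(losses, valid, chunk_size=4))  # 2.5
print(global_token_mean(losses, valid))  # 1.6
```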
Blockwise LM Head
class Blockwise_LM_Head(nn.Module):
    def __init__(self, hidden_dim, vocab_size, chunk_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.chunk_size = chunk_size
        # The projection maps the hidden size d (not chunk_size) to the vocabulary.
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=True)

    def forward(self, inputs):
        inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=self.chunk_size)
        inputs = rearrange(inputs, 'b n c d -> n b c d')
        num_q, _, _, _ = inputs.shape
        # Project one chunk at a time.
        res = []
        for i in range(num_q):
            res.append(self.lm_head(inputs[i]))
        # Note: stacking materializes the full (b, n*c, vocab) logits, so the
        # memory saving only carries over if the loss is computed inside the loop.
        res = torch.stack(res, dim=0)
        res = rearrange(res, 'n b c d -> b (n c) d')
        return res
When you have time, please let me know if anything stands out as incorrect and needs to be resolved.
Thank you for previously clarifying that Flash Attention could be used as a direct replacement, and for your help,
Enrico
Hi Hao,
First off, big thank you for the huge amount of work that has gone into open sourcing the implementation of your research, it is highly appreciated!
While going through the repo and trying to deeply understand the method, I discovered some issues with the test script.
for attention_type in attention_types:
    llama_config_copy = copy.deepcopy(llama_config)
    llama_config_copy.update(dict(attention_type=attention_type))
    # Note: if attention_type is a string, comparing it with a list is always
    # False, so this branch never runs; `== 'standard'` appears to be intended.
    if attention_type == ['standard']:
        llama_config_copy.update(dict(scan_attention=False, scan_mlp=False, scan_layers=False, remat_attention='', remat_mlp='', mesh_dim='1,-1,2,1'))
        llama_config_copy.update(dict(attention_type=attention_type))
    elif attention_type == 'ring_blockwise':
        llama_config_copy.update(dict(scan_attention=True, scan_mlp=True, scan_layers=True, mesh_dim='1,1,2,-1'))
        llama_config_copy.update(dict(attention_type=attention_type))
        llama_config_copy.update(dict(scan_query_chunk_size=1024, scan_key_chunk_size=1024, scan_mlp_chunk_size=1024))
    model = FlaxLLaMAForCausalLMModule(
        llama_config_copy, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )
    models.append(model)
model = models[0]
It appears that it isn't possible to change mesh_dim inside the loop, as the mesh is defined once at the start of the test and used as a context manager for the whole test. So I think we can't switch between the two mesh configurations during the test.
It doesn't look like the grads being returned are a FrozenDict, so the unfreeze at line 163 is not needed (I think it's fine that they're not frozen in this case).
After applying my naive updates to compare Standard with Ring, I am now seeing a larger diff in the logits and grads than expected.
standard
logits: 0.0 1.6717689 1.6717689
grads: 0.0 0.11031877 0.11031877
ring_blockwise
logits: 0.0044222176 1.6717689 1.6717689
grads: 6.278977e-05 0.11030923 0.11031877
Is this similar to your own results, or should the results be more closely aligned with standard attention? My understanding is that blockwise ring attention is numerically equivalent. Could you please confirm whether my configs are correct for comparing these methods? There is a good chance I have made a mistake somewhere. For reference, I am running on a TPU v4-8, so I only have 4 devices.
I would like to confirm whether you agree with these observations, or whether I have just done something silly when applying my changes. If these are in fact issues that have crept in, I am happy to submit a fix.
Cheers,
Donal
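When deciding whether a difference like the one reported above is "larger than expected", it can help to report it against a tolerance rather than as a raw number. A minimal sketch using the same elementwise criterion as `numpy.allclose` (|a - b| <= atol + rtol * |b|, with numpy's defaults rtol=1e-5 and atol=1e-8) on flattened values; which tolerances are appropriate for a bf16 or f32 run on TPU is a separate judgment this does not settle. `compare` is an illustrative name.

```python
def compare(actual, expected, rtol=1e-5, atol=1e-8):
    """Return (max_abs_diff, all_close) for two equal-length flat sequences,
    using the numpy.allclose criterion |a - b| <= atol + rtol * |b|."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    close = all(d <= atol + rtol * abs(e) for d, e in zip(diffs, expected))
    return max(diffs, default=0.0), close
```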
First, great work! I read the paper and had a few questions.
- s = 6c, but where does this 6 come from? Is this related to the 6bch for the blocks' memory?
- 12bch (instead of 6bch): is this because each value is bfloat16?
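For the second question, the conversion from value count to bytes is mechanical: bfloat16 uses 2 bytes per value, so if 6bch counts values, the same buffer occupies 12bch bytes (and 24bch bytes in float32). Whether the paper's figure counts values or bytes is not something this page settles. A small helper with purely illustrative sizes (b, c, h below are arbitrary choices):

```python
BYTES_PER_BF16 = 2  # bfloat16 is a 16-bit format


def buffer_bytes(coeff, b, c, h, bytes_per_value=BYTES_PER_BF16):
    """Bytes needed for a buffer of coeff * b * c * h values."""
    return coeff * b * c * h * bytes_per_value


# Illustrative sizes only: batch 1, chunk 2048, hidden 4096.
n = buffer_bytes(6, b=1, c=2048, h=4096)
print(n, n / 2**20)  # 100663296 96.0
```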