
ringattention's People

Contributors

lhao499, selimonder

ringattention's Issues

train_dataset download

I couldn't find the dataset at './local/owt/openwebtext_train.jsonl'. Can you provide relevant instructions or a download address?

Pretrained models?

Hi there, I am working on a long-context model. Would it be possible to get the pretrained models?

[Question] Add a normalization layer between Attention and FFN?

In the original paper, it states: "The computation for a query block is given by:"

[equation image not preserved]

If we want to add a normalization layer to the input of FFN, then it will break the end-to-end fusion, am I right?

Based on my understanding, normalizations like RMSNorm require access to the whole tensor to calculate some statistics first.
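
For reference, here is a minimal NumPy sketch. It assumes the LLaMA-style RMSNorm, which normalizes over the last (hidden) axis with a learned scale. The helper `rms_norm` and the shapes are hypothetical. The sketch shows that the RMSNorm statistic is computed per token, so normalizing sequence chunks independently gives the same result as normalizing the full tensor:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Root-mean-square over the hidden (last) axis only: one statistic per token.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

rng = np.random.default_rng(0)
batch, seq_len, hidden, chunk = 2, 16, 8, 4
x = rng.standard_normal((batch, seq_len, hidden))
weight = rng.standard_normal(hidden)

full = rms_norm(x, weight)
# Normalize each sequence chunk independently, then stitch the chunks back together.
chunked = np.concatenate(
    [rms_norm(x[:, i:i + chunk], weight) for i in range(0, seq_len, chunk)], axis=1
)
print(np.allclose(full, chunked))  # True
```

Under this assumption, only a normalization that reduces over the sequence axis would need the whole sequence before the FFN.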

Question: Has this been tested against the Triton Flash Attention version?

Hi, I saw your research and found it amazing, particularly for the memory efficiency.

Mosaic implemented Flash Attention in Triton, fusing the attention computation into the forward and backward passes for memory and performance efficiency:

https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py

Specifically, I am wondering whether you have benchmarked against this type of flash attention, since it is now considered SOTA in terms of speed.

JAX partitioning error when attempting to run with sequence parallelism factor not a power of 2

I'm trying to run Ring Attention on a machine with 6 A100 GPUs, and I'm finding that when I try to set the sequence parallelism dimension to anything other than a power of 2, the process crashes with a JAX partitioning error.

I'd be grateful for any insight into whether or not I'm doing something wrong in the way I'm invoking the training script, and for any advice on how to work around this issue.

Steps to Reproduce

Consider the following script for invoking llamabpt.train:

#########################################################
### Configuration 1 (runs successfully)               ###
#########################################################
export CUDA_VISIBLE_DEVICES="0,1,2,3"
SEQ_PAR_DIM=4
MAX_SEQ_LEN=131072
#########################################################

#########################################################
### Configuration 2 (CRASHES with partitioning error) ###
#########################################################
# export CUDA_VISIBLE_DEVICES="0,1,2"
# SEQ_PAR_DIM=3
# MAX_SEQ_LEN=98304
#########################################################

python3 -m llamabpt.train \
  --mesh_dim="1,1,1,${SEQ_PAR_DIM}" \
  --dtype=bf16 \
  --load_llama_config=1b \
  --update_llama_config="{'max_sequence_length': ${MAX_SEQ_LEN}, 'scan_attention': True, 'scan_query_chunk_size': 2048, 'scan_key_chunk_size': 4096, 'remat_attention': 'nothing_saveable', 'scan_mlp': True, 'scan_mlp_chunk_size': 2048, 'remat_mlp': 'nothing_saveable', 'remat_block': 'nothing_saveable', 'scan_layers': True, 'attention_type': 'ring_blockwise', 'param_scan_axis': 0, 'mesh_dim': '1,1,1,${SEQ_PAR_DIM}'}" \
  --total_steps=2 \
  --log_freq=1 \
  --save_model_freq=0 \
  --save_milestone_freq=1000 \
  --tokenizer.vocab_file="${TOKENIZER_PATH}" \
  --optimizer.type=adamw \
  --optimizer.adamw_optimizer.weight_decay=0.1 \
  --optimizer.adamw_optimizer.lr=1.5e-4 \
  --optimizer.adamw_optimizer.end_lr=1.5e-5 \
  --optimizer.adamw_optimizer.lr_warmup_steps=1 \
  --optimizer.adamw_optimizer.lr_decay_steps=10 \
  --train_dataset.type=json \
  --train_dataset.text_processor.fields=text \
  --train_dataset.json_dataset.path="${TRAIN_DATA_PATH}" \
  --train_dataset.json_dataset.seq_length=${MAX_SEQ_LEN} \
  --train_dataset.json_dataset.batch_size=1 \
  --train_dataset.json_dataset.tokenizer_processes=16

For Configuration 1, where the sequence parallelism dimension is 4, the training script runs as expected without errors.

However, when I uncomment Configuration 2, where the sequence parallelism dimension is 3, the training script crashes with the following error:

ValueError: One of pjit outputs with pytree key path .params['params']['lm_head']['kernel'] was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 3}, spec=PartitionSpec(('fsdp', 'sp'), 'tp')), which implies that the global size of its dimension 0 should be divisible by 3, but it is equal to 2048 (full shape: (2048, 32000))

The error occurs during the first call to sharded_init_fn.

I would expect Configuration 2 to run successfully, because the total sequence length (98304) is a multiple of the sequence parallelism dimension (3).

Generalizing to more sequence parallelism dimensions, I find that:

  • Setting SEQ_PAR_DIM to either 2 or 4 runs successfully.
  • Setting SEQ_PAR_DIM to either 3 or 6 crashes with a partitioning error.
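
The error message itself points to where the constraint comes from. The `lm_head` kernel of shape (2048, 32000) is sharded with `PartitionSpec(('fsdp', 'sp'), 'tp')`, so its dimension 0 must be divisible by fsdp × sp, independently of the sequence length. Below is a plain-Python sketch of that check. The helper `sharding_ok` is hypothetical: it models only the divisibility rule stated in the error, not JAX itself.

```python
import math

def sharding_ok(shape, spec, mesh):
    """Return True if every dimension of `shape` is divisible by the product of
    the mesh axis sizes it is partitioned over (axes given as str or tuple)."""
    for dim_size, axes in zip(shape, spec):
        if axes is None:
            continue  # replicated dimension
        if isinstance(axes, str):
            axes = (axes,)
        shards = math.prod(mesh[a] for a in axes)
        if dim_size % shards != 0:
            return False
    return True

lm_head_shape = (2048, 32000)
spec = (("fsdp", "sp"), "tp")
for sp in (2, 3, 4, 6):
    mesh = {"dp": 1, "fsdp": 1, "tp": 1, "sp": sp}
    print(sp, sharding_ok(lm_head_shape, spec, mesh))
# 2 True
# 3 False
# 4 True
# 6 False
```

This reproduces the pattern reported above: 2 and 4 succeed, while 3 and 6 fail. That suggests that, with this sharding spec, the parameter dimensions (here the 2048 hidden size) must also be divisible by the sp size, not just MAX_SEQ_LEN.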

vmem OOM on TPU

Hi,

I tried to run your script on a Cloud TPU v4-64, but it failed with the following error:

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 59.79M of 16.00M vmem. Exceeded vmem capacity by 43.79M.

I tried mesh dims of 1,1,1,32 and 1,1,4,8; both failed.

Any suggestions as to what caused the error? Thanks.

How to combine BPT with sequence parallel?

Thanks very much for sharing your work! I wonder if BPT is compatible with sequence parallelism. From the related work section, I find the statement "This creates an orthogonal relationship between our method and sequence parallelism, allowing for straightforward combination". However, this combination is not that straightforward to me. Please see my comments below.

To my understanding, BPT breaks a long sequence into small chunks and feeds these chunks sequentially to attn+layernorm+MLP. Recent LLM training pipelines usually apply sequence parallelism (https://arxiv.org/pdf/2205.05198.pdf) to the layernorm part, so that layernorm is computed in parallel across GPUs. This seems to break the connection between attention and MLP when using BPT. I wonder whether it is possible to seamlessly combine BPT with sequence parallelism when layernorm is considered.
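
One way to see why the combination can work is the NumPy sketch below. It assumes LayerNorm normalizes each token over the hidden axis, as in standard Transformers, and uses a ReLU MLP as a stand-in. It does not model tensor parallelism, and the function names, shard count, and chunk size are hypothetical. Each sequence shard runs LayerNorm + MLP blockwise over its own tokens, and concatenating the shards reproduces the unsharded result, because neither op mixes information across positions:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Per-token normalization over the hidden axis.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def mlp(x, w_in, w_out):
    # Position-wise feedforward; ReLU is a stand-in for GELU/SwiGLU.
    return np.maximum(x @ w_in, 0.0) @ w_out

def ln_mlp_blockwise(x, w_in, w_out, chunk):
    # BPT-style: process the local sequence in chunks so only one chunk's
    # hidden activations are live at a time.
    outs = [mlp(layer_norm(x[i:i + chunk]), w_in, w_out)
            for i in range(0, x.shape[0], chunk)]
    return np.concatenate(outs, axis=0)

rng = np.random.default_rng(0)
seq_len, d, d_ff, n_shards, chunk = 32, 8, 32, 4, 4
x = rng.standard_normal((seq_len, d))
w_in = rng.standard_normal((d, d_ff))
w_out = rng.standard_normal((d_ff, d))

reference = mlp(layer_norm(x), w_in, w_out)
# Each "device" holds a contiguous sequence shard and runs blockwise LN + MLP.
shards = np.split(x, n_shards, axis=0)
sharded = np.concatenate([ln_mlp_blockwise(s, w_in, w_out, chunk) for s in shards], axis=0)
print(np.allclose(reference, sharded))  # True
```

In this picture, attention is the only sublayer that needs tokens from other shards.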

fine-tuning model mismatch - KeyError

Thanks for providing the repo. I have a question regarding fine-tuning as mentioned in the paper (Section 5.4).

As the README.md suggests, --load_checkpoint='params::/path/output' is used for fine-tuning from an HF model converted with the hf2jax.py script. However, when scan_layers=True, it appears that the layer names (keys) from /path/output do not match those in shard_fns when loading the HF weights. For example,
('transformer', 'h', 'scan_decoder', 'attention', 'wq', 'kernel') from shard_fns does not match the key ('transformer', 'h', '0', 'attention', 'wq', 'kernel') unpacked from /path/output.

This eventually raises KeyError: ('transformer', 'h', '0', 'attention', 'wq', 'kernel') during load_checkpoint.

Have I missed anything in the fine-tuning configuration, or is there a workaround for this?

Thank you!
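
A possible workaround, sketched under assumptions: with scan_layers=True and 'param_scan_axis': 0 (as in the configuration used elsewhere on this page), the scanned module presumably expects each weight stacked along a new leading layer axis under the single key 'scan_decoder'. The hypothetical helper below converts a flat dict with per-layer keys ('transformer', 'h', '0', ...) into that layout. It has not been checked against the repository's actual checkpoint format.

```python
import numpy as np

def stack_layers_for_scan(flat_params, scan_name="scan_decoder", axis=0):
    """Hypothetical converter: merge ('transformer', 'h', '<i>', *rest) entries
    into ('transformer', 'h', scan_name, *rest), stacking layers along `axis`."""
    per_layer = {}  # rest-of-key -> {layer_index: array}
    out = {}
    for key, value in flat_params.items():
        if len(key) > 3 and key[:2] == ("transformer", "h") and key[2].isdigit():
            per_layer.setdefault(key[3:], {})[int(key[2])] = value
        else:
            out[key] = value  # embeddings, final norm, lm_head, ...
    for rest, layers in per_layer.items():
        ordered = [layers[i] for i in sorted(layers)]  # numeric layer order
        out[("transformer", "h", scan_name) + rest] = np.stack(ordered, axis=axis)
    return out

# Toy checkpoint: two layers of a single 4x4 weight, plus an unscanned lm_head.
flat = {
    ("transformer", "h", str(i), "attention", "wq", "kernel"): np.full((4, 4), float(i))
    for i in range(2)
}
flat[("lm_head", "kernel")] = np.zeros((4, 10))
converted = stack_layers_for_scan(flat)
print(converted[("transformer", "h", "scan_decoder", "attention", "wq", "kernel")].shape)
# (2, 4, 4)
```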

PyTorch Implementation

Hi @lhao499,

I am working through a rewrite of your great work in PyTorch. Please excuse my ignorance, as my understanding of JAX is lacking in certain aspects. Below are the functions and classes for the Blockwise Compute Feedforward Network, Blockwise Cross Entropy Loss, and Blockwise LM Head:

Blockwise Compute Feedforward Network

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class FFN(nn.Module):
    def __init__(
        self, 
        dim, 
        hidden_dim, 
        dropout, 
        act=nn.GELU
    ):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim, bias=True)
        self.fc_out = nn.Linear(hidden_dim, dim, bias=True)
        self.act = act()
        self.resid_dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(dim)

    def forward(self, hidden_states):
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.resid_dropout(hidden_states)
        return hidden_states

def blockwise_compute_ffn(cell, inputs, chunk_size):
    inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=chunk_size)
    inputs = rearrange(inputs, 'b n c d -> n b c d')
    num_q, _, _, _ = inputs.shape
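    # Run each chunk sequentially (the JAX version uses remat + scan here;
    # wrap `cell` with torch.utils.checkpoint.checkpoint to mimic remat)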
    res = []
    for i in range(num_q):
        res.append(cell(inputs[i]))
    res = torch.stack(res, dim=0)
    res = rearrange(res, 'n b c d -> b (n c) d')
    return res

Blockwise Cross Entropy Loss

def blockwise_cross_entropy(logits, tokens, valid=None, chunk_size=None):
    if valid is None:
        valid = torch.ones(tokens.shape[:2], device=logits.device)
    valid = valid.float()
    logits = logits.view(-1, logits.shape[-1])
    tokens = tokens.view(-1)
    valid = valid.view(-1)

    def _cross_entropy_loss_and_accuracy(logits, tokens, valid):
        valid_text_length = valid.sum(dim=-1).clamp_min(1e-10)

        token_log_prob = F.log_softmax(logits, dim=-1)
        # Pick the log-probability of each target token (like take_along_axis in JAX)
        token_log_prob = token_log_prob.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
        token_log_prob = torch.where(valid > 0.0, token_log_prob, torch.zeros_like(token_log_prob))
        correct = torch.where(
            valid > 0.0,
            torch.argmax(logits, dim=-1) == tokens,
            torch.tensor(False).to(logits.device)
        )
        return token_log_prob, correct.float(), valid_text_length

    num_chunk = logits.shape[0] // chunk_size
    logits = rearrange(logits, '(n c) d -> n c d', c=chunk_size)
    tokens = rearrange(tokens, '(n c) -> n c', c=chunk_size)
    valid = rearrange(valid, '(n c) -> n c', c=chunk_size)

    loss, accuracy, num = 0.0, 0.0, 0
    for i in range(num_chunk):
        token_log_prob, correct, valid_text_length = _cross_entropy_loss_and_accuracy(logits[i], tokens[i], valid[i])
        loss += token_log_prob.sum() / valid_text_length
        accuracy += correct.sum() / valid_text_length
        num = num + 1

    loss = - loss / num
    accuracy = accuracy / num
    return loss, accuracy
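
One thing worth double-checking in the loss above is how it averages. Normalizing each chunk by its own valid-token count and then averaging over chunks equals the token-level mean only when every chunk has the same number of valid tokens. This describes the arithmetic only, not which normalization the JAX code intends. A small pure-Python illustration with made-up per-token losses:

```python
def chunk_mean_then_average(losses, valid, chunk):
    # Mean over valid tokens inside each chunk, then unweighted mean over chunks.
    total, n = 0.0, 0
    for i in range(0, len(losses), chunk):
        l, v = losses[i:i + chunk], valid[i:i + chunk]
        count = max(sum(v), 1e-10)
        total += sum(x * m for x, m in zip(l, v)) / count
        n += 1
    return total / n

def token_mean(losses, valid):
    # Mean over all valid tokens in the whole sequence.
    return sum(x * m for x, m in zip(losses, valid)) / max(sum(valid), 1e-10)

losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
all_valid = [1] * 8
padded = [1, 1, 1, 1, 1, 0, 0, 0]  # second chunk is mostly padding

print(chunk_mean_then_average(losses, all_valid, 4), token_mean(losses, all_valid))  # 4.5 4.5
print(chunk_mean_then_average(losses, padded, 4), token_mean(losses, padded))  # 3.75 3.0
```

With no padding the two agree, so this only matters when `valid` masks out tokens unevenly across chunks.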

Blockwise LM Head

class Blockwise_LM_Head(nn.Module):
    def __init__(self, dim, vocab_size, chunk_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.chunk_size = chunk_size
        # Project from the model dimension (not the chunk size) to the vocabulary
        self.lm_head = nn.Linear(dim, vocab_size, bias=True)

    def forward(self, inputs):
        inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=self.chunk_size)
        inputs = rearrange(inputs, 'b n c d -> n b c d')
        num_q, _, _, _ = inputs.shape
        # Run each chunk sequentially (the JAX version uses remat + scan here)
        res = []
        for i in range(num_q):
            res.append(self.lm_head(inputs[i]))
        res = torch.stack(res, dim=0)
        res = rearrange(res, 'n b c d -> b (n c) d')
        return res

When you have time, please let me know if anything stands out as incorrect or needs to be resolved.

Thank you for previously clarifying that Flash Attention could be used as a direct replacement, and for your help,

Enrico

Test Script Issues

Hi Hao,

First off, a big thank you for the huge amount of work that has gone into open-sourcing the implementation of your research; it is highly appreciated!

While going through the repo and trying to deeply understand the method I discovered that there are some issues with the test script.

  1. The test script does not appear to be running different attention methods and only ever compares against the default setting. My initial impression from the code was that setting the 'attention_label' would update the config and run the attention mechanism associated with that label (i.e. standard, ring blockwise, etc.). However, after further inspection it seems that this no longer does anything: the method always runs based on what has been defined in the base config via the scan_attention, scan_mlp, scan_layers and mesh_dim arguments. In order to actually compare methods, you have to update the config at each iteration:
for attention_type in attention_types:
    llama_config_copy = copy.deepcopy(llama_config)
    llama_config_copy.update(dict(attention_type=attention_type))
    if attention_type == 'standard':
        llama_config_copy.update(dict(scan_attention=False, scan_mlp=False, scan_layers=False, remat_attention='', remat_mlp='', mesh_dim='1,-1,2,1'))
        llama_config_copy.update(dict(attention_type=attention_type))
    elif attention_type == 'ring_blockwise':
        llama_config_copy.update(dict(scan_attention=True, scan_mlp=True, scan_layers=True, mesh_dim='1,1,2,-1'))
        llama_config_copy.update(dict(attention_type=attention_type))
        llama_config_copy.update(dict(scan_query_chunk_size=1024, scan_key_chunk_size=1024, scan_mlp_chunk_size=1024))
    model = FlaxLLaMAForCausalLMModule(
        llama_config_copy, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )
    models.append(model)
model = models[0]
  2. It appears that it isn't possible to change the mesh_dim, since it is defined once at the start of the test and used as a context manager for the whole test. So I think we can't switch between ring and blockwise during the test.

  3. It doesn't look like the grads being returned are a 'FrozenDict', so the unfreeze at line 163 is not needed (I think it's fine that they are not frozen in this case).

  4. After applying my naive updates to compare Standard with Ring, I am now seeing a larger difference in the logits and grads than expected.

standard
logits: 0.0 1.6717689 1.6717689
grads: 0.0 0.11031877 0.11031877

ring_blockwise
logits: 0.0044222176 1.6717689 1.6717689
grads: 6.278977e-05 0.11030923 0.11031877

Is this similar to your own results, or should the results be more closely aligned with Standard Attention? My understanding is that Blockwise Ring Attention is numerically equivalent. Could you please confirm whether my configs are correct for comparing these methods? There is a good chance I have made a mistake somewhere. For reference, I am running on a TPU v4-8, so I only have 4 devices.
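
To calibrate what "numerically equivalent" should look like, below is a minimal NumPy sketch of blockwise attention with an online softmax (a running max and running sum). This rescaling trick is shared by Flash Attention and blockwise attention. The sketch is not the repository's implementation; the block size, shapes, and float32 dtype are assumptions. In exact arithmetic it matches standard attention, so in float32 the printed gap reflects rounding only:

```python
import math
import numpy as np

def standard_attention(q, k, v):
    # Full softmax(Q K^T / sqrt(d)) V, materializing all scores at once.
    scores = q @ k.T / math.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def blockwise_attention(q, k, v, block):
    # Visit key/value blocks one at a time (as they would arrive around a ring),
    # keeping a running max m, running denominator l, and unnormalized output acc.
    scale = 1.0 / math.sqrt(q.shape[-1])
    m = np.full((q.shape[0], 1), -np.inf, dtype=q.dtype)
    l = np.zeros((q.shape[0], 1), dtype=q.dtype)
    acc = np.zeros((q.shape[0], v.shape[-1]), dtype=q.dtype)
    for j in range(0, k.shape[0], block):
        s = (q @ k[j:j + block].T) * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)  # rescales sums computed with the old max
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v[j:j + block]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
n, d, block = 64, 16, 16
q, k, v = (rng.standard_normal((n, d)).astype(np.float32) for _ in range(3))
diff = np.abs(standard_attention(q, k, v) - blockwise_attention(q, k, v, block)).max()
print(diff)
```

A gap orders of magnitude larger than this rounding-level difference may indicate a configuration difference rather than rounding alone, although bf16 compute and other parts of the model can also contribute.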

I would like to confirm whether you agree with these observations, or whether I have just done something silly when applying my changes. If these are in fact issues that have crept in, I am happy to submit a fix 😃

Cheers,

Donal

Questions about the paper

First, great work! I read the paper and had a few questions.

  • On p. 5, the paper says that the minimal sequence length is s = 6c, but where does this 6 come from? Is it related to the 6bch for the blocks' memory?
  • About the memory requirement: if I understand correctly, the total memory for 6 blocks might be 12bch (instead of 6bch), because each element is stored in bfloat16?
  • Could the interconnect bandwidth listed for TPUs be wrong? According to the table in https://cloud.google.com/blog/products/ai-machine-learning/introducing-cloud-tpu-v5p-and-ai-hypercomputer?hl=en, ICI bandwidth per chip is 2,400 Gbps. My understanding is that this is the total over 6 links (forming a 3D torus), so each link is 400 Gbps, or 50 GB/s. Let me know if this interpretation is wrong.
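
The unit conversions in the last two bullets can be checked directly. The sketch below takes the 2,400 Gbps per-chip figure and the 6-link interpretation from the bullet as given. It also uses 8 bits per byte and 2 bytes per bfloat16 element. The block dimensions are hypothetical placeholders.

```python
# Per-link bandwidth, assuming the per-chip ICI figure is the total over 6 links.
ici_per_chip_gbps = 2400
links_per_chip = 6  # 3D torus: +/- x, y, z
per_link_gbps = ici_per_chip_gbps / links_per_chip
per_link_gb_per_s = per_link_gbps / 8  # 8 bits per byte
print(per_link_gbps, per_link_gb_per_s)  # 400.0 50.0

# Bytes for num_blocks blocks of b*c*h elements each.
def block_memory_bytes(b, c, h, num_blocks=6, bytes_per_elem=2):
    return num_blocks * b * c * h * bytes_per_elem  # bfloat16 = 2 bytes

b, c, h = 1, 4096, 4096  # hypothetical batch, block size, hidden size
print(block_memory_bytes(b, c, h) == 12 * b * c * h)  # True
```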
