
Comments (10)

lhao499 commented on May 25, 2024

Hi Enrico, it seems checkpointing / rematerialization is not included. Checkpointing is required for the memory saving, because otherwise the total memory cost of all blocks accumulates to the same memory cost as the vanilla transformer in the forward and backward passes. The corresponding code is https://github.com/lhao499/blockwise-parallel-transformer/blob/4a668d5436adc5263df6786255c2bd684749aae2/bpt/blocks/blockwise_parallel.py#L446
You can check out the PyTorch checkpoint utility for this purpose: https://pytorch.org/docs/stable/checkpoint.html.

It's also worth optimizing the for loop with torch.compile / torch.jit.trace; the corresponding Jax code is https://github.com/lhao499/blockwise-parallel-transformer/blob/4a668d5436adc5263df6786255c2bd684749aae2/bpt/blocks/blockwise_parallel.py#L453
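Roughly, in PyTorch both ideas could look something like this (a minimal sketch; blockwise_ffn is a hypothetical per-block function, and torch.compile behavior around checkpointing may depend on the PyTorch version):

import torch
from torch.utils.checkpoint import checkpoint

def blockwise_ffn(ffn, inputs):
    # inputs: (num_blocks, batch, block_len, dim); checkpoint each block so
    # its activations are recomputed during backward instead of being stored
    outputs = []
    for block in inputs:
        outputs.append(checkpoint(ffn, block, use_reentrant=False))
    return torch.stack(outputs, dim=0)

# optionally let the compiler trace/optimize the Python loop
blockwise_ffn_compiled = torch.compile(blockwise_ffn)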


conceptofmind commented on May 25, 2024

> Hi Enrico, it seems checkpointing / rematerialization is not included. Checkpointing is required for the memory saving, because otherwise the total memory cost of all blocks accumulates to the same memory cost as the vanilla transformer in the forward and backward passes. The corresponding code is
>
> https://github.com/lhao499/blockwise-parallel-transformer/blob/4a668d5436adc5263df6786255c2bd684749aae2/bpt/blocks/blockwise_parallel.py#L446
>
> You can check out the PyTorch checkpoint utility for this purpose: https://pytorch.org/docs/stable/checkpoint.html.
> It's also worth optimizing the for loop with torch.compile / torch.jit.trace; the corresponding Jax code is
>
> https://github.com/lhao499/blockwise-parallel-transformer/blob/4a668d5436adc5263df6786255c2bd684749aae2/bpt/blocks/blockwise_parallel.py#L453

Hi Hao,

Thank you for the insight. I was not aware of the torch checkpointing utility and was looking into an alternative to remat.

I will try using the torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True, **kwargs) function as follows:

        num_q, _, _, _ = inputs.shape  # inputs: (num_blocks, batch, block_len, dim)
        res = []
        for i in range(num_q):
            # recompute each block's lm_head activations during the backward pass
            res.append(checkpoint(self.lm_head, inputs[i]))
        res = torch.stack(res, dim=0)

Or

def lm_head(cell, hidden_states):
    outputs = cell(hidden_states)
    return outputs
outputs = checkpoint(lm_head, cell, hidden_states)

And then looping:

def scan(f, init, xs, length=None):
    # PyTorch analogue of jax.lax.scan: threads a carry through f over the
    # elements of xs and stacks the per-step outputs (if f produces any)
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    if all(y is None for y in ys):
        return carry, None
    return carry, torch.stack(ys)
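For example, a toy running-sum check of the calling convention:

xs = torch.arange(6.0).reshape(3, 2)
carry, ys = scan(lambda c, x: (c + x, c + x), torch.zeros(2), xs)
# carry == tensor([6., 9.]), ys has shape (3, 2)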

I will have to further evaluate the design and update the above functions to include checkpointing, as well as use torch.jit/compile on the final model.

I will add the rest of the code to this issue as I continue to work through it, but I wanted to get these core blockwise functions rewritten first.

Thank you,

Enrico


conceptofmind commented on May 25, 2024

An updated FeedForward example could look like:

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from einops import rearrange

def scan(f, init, xs, length=None):
    # PyTorch analogue of jax.lax.scan: threads a carry through f over the
    # elements of xs and stacks the per-step outputs (if f produces any)
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    if all(y is None for y in ys):
        return carry, None
    return carry, torch.stack(ys)

class FFN(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()
        self.resid_dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(dim)

    def forward(self, hidden_states):
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.resid_dropout(hidden_states)
        return hidden_states

    def checkpoint_fn(self, carry, x):
        # carry is unused here; checkpoint the block so its activations are
        # recomputed during the backward pass instead of being stored
        return None, checkpoint(self.forward, x)

def blockwise_compute_ffn(cell, inputs, chunk_size):
    # split the sequence into chunks and apply the FFN one chunk at a time
    inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=chunk_size)
    inputs = rearrange(inputs, 'b n c d -> n b c d')
    _, res = scan(cell.checkpoint_fn, None, inputs, inputs.shape[0])
    res = rearrange(res, 'n b c d -> b (n c) d')
    return res
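A quick shape check with hypothetical sizes:

ffn = FFN(dim=512, hidden_dim=2048, dropout=0.1)
x = torch.randn(2, 1024, 512, requires_grad=True)  # (batch, seq_len, dim)
out = blockwise_compute_ffn(ffn, x, chunk_size=256)
print(out.shape)  # torch.Size([2, 1024, 512])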


lhao499 commented on May 25, 2024

Cool. Just some minor suggestions: first, verify that dropout uses the same random seed for all blocks within a sequence. Additionally, it would be beneficial to explore ways to optimize the loop; there is a relevant discussion in the PyTorch repo: pytorch/pytorch#50688.


conceptofmind commented on May 25, 2024

I will add a manual seed to the FFN to ensure the dropout is consistent across blocks. I read through the issue and am going to look into optimizing the scan loop; I will likely see if I can get further input from the Triton team or even message Tri Dao.
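One option I am considering for the seeding (just a sketch, seeded_checkpoint_fn is a hypothetical helper and the exact RNG handling still needs to be verified): fork the global RNG and reseed it with the same per-forward seed before every chunk, so each block draws from the same dropout stream.

def seeded_checkpoint_fn(module, seed, x):
    # fork the RNG so reseeding here does not disturb the rest of the program,
    # then reseed so every chunk within this forward sees the same dropout stream
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        out = checkpoint(module, x)
    return out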

Here is the updated Blockwise LM Head:

class Blockwise_LM_Head(nn.Module):
    def __init__(self, dim, vocab_size, chunk_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.chunk_size = chunk_size
        # the head projects from the model dimension to the vocabulary;
        # chunk_size only controls how the sequence is split
        self.lm_head = nn.Linear(
            dim,
            vocab_size,
            bias=True
        )

    def checkpoint_fn(self, carry, x):
        return None, checkpoint(self.lm_head, x)

    def forward(self, inputs):
        inputs = rearrange(inputs, 'b (n c) d -> b n c d', c=self.chunk_size)
        inputs = rearrange(inputs, 'b n c d -> n b c d')
        _, res = scan(self.checkpoint_fn, None, inputs, inputs.shape[0])
        res = rearrange(res, 'n b c d -> b (n c) d')
        return res
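A quick shape check, again with hypothetical sizes:

lm_head = Blockwise_LM_Head(dim=512, vocab_size=32000, chunk_size=256)
h = torch.randn(2, 1024, 512, requires_grad=True)  # (batch, seq_len, dim)
logits = lm_head(h)
print(logits.shape)  # torch.Size([2, 1024, 32000])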

and Blockwise Cross Entropy Loss:

def cross_entropy_loss_and_accuracy(logits, tokens, valid):
    # number of valid (non-masked) tokens in the chunk, clamped to avoid division by zero
    valid_text_length = torch.clamp(torch.sum(valid, dim=-1), min=1e-10)

    token_log_prob = torch.log_softmax(logits, dim=-1)
    token_log_prob = torch.gather(token_log_prob, -1, tokens.unsqueeze(-1)).squeeze(-1)

    token_log_prob = torch.where(valid > 0.0, token_log_prob, torch.tensor(0.0, device=logits.device))
    correct = torch.where(
        valid > 0.0,
        logits.argmax(dim=-1) == tokens,
        torch.tensor(False, device=logits.device)
    )
    return token_log_prob, correct, valid_text_length

def _loss_and_accuracy(carry, args):
    loss, accuracy, num = carry
    logits, tokens, valid = args
    # checkpoint the per-chunk loss math so its activations are recomputed in the backward pass
    token_log_prob, correct, valid_text_length = checkpoint(cross_entropy_loss_and_accuracy, logits, tokens, valid)
    loss = loss + torch.sum(token_log_prob, dim=-1) / valid_text_length
    accuracy = accuracy + torch.sum(correct.float(), dim=-1) / valid_text_length
    num = num + 1
    return (loss, accuracy, num), None

def blockwise_cross_entropy(logits, tokens, valid=None, chunk_size=None):
    if valid is None:
        valid = torch.ones(tokens.shape[:2], device=tokens.device)
    valid = valid.float()
    logits = logits.view(-1, logits.shape[-1])
    tokens = tokens.view(-1,)
    valid = valid.view(-1,)

    num_chunk = logits.shape[0] // chunk_size
    logits = rearrange(logits, '(n c) d -> n c d', c=chunk_size)
    tokens = rearrange(tokens, '(n c) -> n c', c=chunk_size)
    valid = rearrange(valid, '(n c) -> n c', c=chunk_size)
    # zip the per-chunk tensors so each scan step receives one (logits, tokens, valid) triple
    (loss, accuracy, num), _ = scan(
        _loss_and_accuracy,
        (torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device), 0),
        list(zip(logits, tokens, valid)), length=num_chunk
    )
    # average the per-chunk means over the number of chunks
    loss = -loss / num
    accuracy = accuracy / num
    return loss, accuracy
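A quick end-to-end check with hypothetical sizes:

logits = torch.randn(2, 1024, 32000, requires_grad=True)
tokens = torch.randint(0, 32000, (2, 1024))
loss, accuracy = blockwise_cross_entropy(logits, tokens, chunk_size=256)
loss.backward()
print(loss.item(), accuracy.item())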

Should the checkpointing be applied to the entire _loss_and_accuracy function or just cross_entropy_loss_and_accuracy?

Thank you,

Enrico


lhao499 commented on May 25, 2024

Sounds good. The Triton team and Tri should know more about PyTorch and CUDA; I am mostly using Jax on TPUs.

> Should the checkpointing be applied to the entire _loss_and_accuracy function or just cross_entropy_loss_and_accuracy?

Both options should be fine, but the current one might be slightly better because fewer ops are needed for recomputation.


lhao499 commented on May 25, 2024

Please feel free to reopen when needed.


vadimkantorov commented on May 25, 2024

Btw, PyTorch core recently merged in a native rearrange: pytorch/pytorch#92675


lhao499 commented on May 25, 2024

Just FYI, we have added llamabpt, which is BPT applied to LLaMA.


zpx01 commented on May 25, 2024

@lhao499 If I want to use llamabpt for fine-tuning, can I simply use the llamabpt model definition, load the weights directly from Hugging Face, and fine-tune the model with something like PyTorch Lightning? Or will this not work because the model uses Jax?


