Comments (8)

lvwerra commented on June 24, 2024

Thanks @vwxyzjn for the clarification of the nomenclature! I think the hyperparameters you are citing are for the initialization of the policy before the PPO training.

For the PPO training they mention:

The batch size for each iteration is 512, with a minibatch size of 64. In other words, each batch is randomly split into 8 minibatches and is trained on for only a single inner epoch (Schulman et al., 2017).

So indeed a minibatch size > 1 is used. I think we can address that quite easily with #100 if we use the attention mask to mask out the appropriate parts of the input. cc @younesbelkada
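
To make the quoted scheme concrete, here is a minimal sketch of that random minibatch split (a batch of 512 shuffled into 8 minibatches of 64 for one inner epoch); variable names are illustrative, not TRL's actual implementation:

```python
import numpy as np

# A rollout batch of 512 is shuffled and split into 8 minibatches of 64,
# each visited once (a single inner epoch), as quoted from the paper.
batch_size, minibatch_size = 512, 64
indices = np.random.permutation(batch_size)
minibatches = [
    indices[start:start + minibatch_size]
    for start in range(0, batch_size, minibatch_size)
]
assert len(minibatches) == batch_size // minibatch_size == 8
```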

vwxyzjn commented on June 24, 2024

Does this mean this is building multi-envs to collect rollouts?

I think multi-envs in this case are kind of like multiple instances of conversations :)

The batch size for each iteration is 512,

Ah, my mistake. Thanks for the info 🙏

So indeed a mini-bs>1 is used. I think we can address that quite easily with #100 if we use the attention mask to mask out the appropriate parts of the input. cc @younesbelkada

Sorry, I am probably missing something... What parts of the input should we mask out related to the minibatch size? It sounds like a minibatch of size 64 would mean 64 independent prompts as obs, 64 responses as actions, and 64 scalar rewards. We are trying to mask out the future tokens in each of these 64 prompts, right?

lvwerra commented on June 24, 2024

@vwxyzjn mostly a practical thing: when we batch 64 sequences of unequal length together, we need to pad the tensors. In transformers the tensors then usually come with an attention mask telling you where the padding is: we can use it to know where each prompt/response starts and ends, and to ignore the padded positions.
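
For illustration, here is a minimal sketch of how a transformers tokenizer produces that attention mask when padding a batch (the model choice and prompts here are arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

batch = tokenizer(
    ["a short prompt", "a much longer prompt that needs several more tokens"],
    padding=True,
    return_tensors="pt",
)
# batch["input_ids"]:      (2, max_len) token ids, padded to equal length
# batch["attention_mask"]: 1 for real tokens, 0 for padding; a per-token PPO
# loss can be multiplied by this mask so padded positions contribute nothing
```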

lvwerra commented on June 24, 2024

I think it is quite common to optimize PPO with small batch sizes, but maybe @natolambert or @edbeeching know better whether we should change this?

natolambert commented on June 24, 2024

Ah, I need to dig through my John Schulman RLHF media tour notes. I vaguely remember the concept coming up. I'm really not sure.

lvwerra commented on June 24, 2024

In principle, now that data parallelism via accelerate is supported, you effectively train with a batch size equal to the number of GPUs used.
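
A minimal sketch of that arithmetic, assuming one sequence per device (variable names are illustrative):

```python
from accelerate import Accelerator

# Under data parallelism, each process (GPU) handles its own slice of the
# batch, so the effective batch size scales with the number of processes.
accelerator = Accelerator()
per_device_batch_size = 1  # one sequence per GPU, as described above
effective_batch_size = per_device_batch_size * accelerator.num_processes
```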

vwxyzjn commented on June 24, 2024

Hey @lvwerra, cool library! PPO can deal with both large and small batch sizes depending on the task. The batch_size equals num_envs * num_steps, where num_envs is the number of environments in RL (maybe the number of conversations in RLHF), and num_steps is the number of steps each env takes (maybe the number of responses the same model generates in RLHF within the same conversation).

In IsaacGym / Brax, it's common to use a large num_envs and small num_steps. E.g., num_envs=4096 and num_steps=5, corresponding to batch_size=20480. In Atari, it's common to use a smaller num_envs and larger num_steps. E.g., num_envs=8 and num_steps=128, corresponding to batch_size=1024.

The InstructGPT paper uses batch_size = 32 ("We use a batch size of 32 for 1.3B and 6B models and 8 for the 175B model.", Appendix C3), so I am imagining it uses num_steps=1 (which also fits nicely with their bandit environment setting) with 32 prompts as obs, 32 responses as actions, and 32 scalar rewards.
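
A small sketch of the batch-size arithmetic described above, using the numbers from this thread:

```python
def ppo_batch_size(num_envs: int, num_steps: int) -> int:
    # batch_size = num_envs * num_steps
    return num_envs * num_steps

assert ppo_batch_size(4096, 5) == 20480  # IsaacGym / Brax style
assert ppo_batch_size(8, 128) == 1024    # Atari style
assert ppo_batch_size(32, 1) == 32       # InstructGPT-style bandit setting
```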

xesdiny commented on June 24, 2024

@vwxyzjn
batch_size = num_envs * num_steps
Does this mean this is building multi-envs to collect rollouts?
Assuming the policy_model is an LLM, the forward passes built during step() are used to backpropagate the PPO-clip or PPO-ptx loss. With the current code implementation, this should run out of memory. I have been thinking about this approach: would you consider using ZeRO-Offload to handle the tensors generated by the rollout?
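
For reference, a hedged sketch of what such a ZeRO-Offload setup could look like as a DeepSpeed config dict; the stage and batch values are illustrative, not a tested recipe for this repo:

```python
# Illustrative DeepSpeed ZeRO config: stage-2 partitioning with optimizer
# state offloaded to CPU to relieve GPU memory pressure during training.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
}
```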
