Comments (8)
Thanks @vwxyzjn for the clarification of the nomenclature! I think the hyperparameters you are citing are for the initialization of the policy before the PPO training.
For the PPO training they mention:
> The batch size for each iteration is 512, with a minibatch size of 64. In other words, each batch is randomly split into 8 minibatches and is trained on for only a single inner epoch (Schulman et al., 2017).
So indeed a mini-bs>1 is used. I think we can address that quite easily with #100 if we use the attention mask to mask out the appropriate parts of the input. cc @younesbelkada
from trl.
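The splitting scheme quoted above is simple to sketch. This is an illustrative stand-alone snippet (not trl's actual code): a rollout batch of 512 samples is shuffled and split into 8 minibatches of 64, each visited exactly once, i.e. a single inner epoch.

```python
import random

batch_size = 512
minibatch_size = 64

# Randomly split the rollout batch into minibatches.
indices = list(range(batch_size))
random.shuffle(indices)
minibatches = [indices[i:i + minibatch_size]
               for i in range(0, batch_size, minibatch_size)]

assert len(minibatches) == 8          # 512 / 64 = 8 minibatches
for mb in minibatches:                # a single inner epoch: one pass over the data
    pass  # here one would compute the PPO loss on batch[mb] and take a gradient step
```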
Does this mean this is building multi-envs to collect rollouts?
I think multi-envs in this case is kind of like multiple instances of conversations :)
> The batch size for each iteration is 512,
Ah, my mistake. Thanks for the info 🙏
> So indeed a mini-bs>1 is used. I think we can address that quite easily with #100 if we use the attention mask to mask out the appropriate parts of the input. cc @younesbelkada
Sorry, I am probably missing something... What parts of the input should we mask out related to the minibatch size? It sounds like a minibatch of size 64 would mean 64 independent prompts as obs, 64 responses as actions, and 64 scalar rewards. We are trying to mask out the future tokens in each of these 64 prompts, right?
@vwxyzjn mostly a practical thing: when we batch 64 sequences together which can have unequal length, we need to pad the tensors. In transformers the tensors then usually come with an attention mask telling you where the paddings are. We can use it to know where each prompt/response starts and ends, and which padded positions to ignore.
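To make the padding point concrete, here is a minimal sketch (plain Python standing in for tensor ops, not trl's actual code) of how an attention mask keeps padded positions from contributing to a per-token loss: real tokens are marked 1, padding 0, and the loss is averaged over real tokens only.

```python
# Two padded sequences of length 5; the 9.9 entries sit at padded positions
# and must not affect the loss.
per_token_loss = [
    [0.5, 1.2, 0.3, 9.9, 9.9],   # sequence 1: 3 real tokens, 2 padding
    [0.4, 0.8, 1.1, 0.2, 0.6],   # sequence 2: no padding
]
attention_mask = [
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
]

# Zero out padded positions, then average over real tokens only.
masked = [[l * m for l, m in zip(row_l, row_m)]
          for row_l, row_m in zip(per_token_loss, attention_mask)]
num_real = sum(sum(row) for row in attention_mask)
loss = sum(sum(row) for row in masked) / num_real
```

The same pattern applies with `torch` tensors: multiply the per-token loss by the mask and divide by `attention_mask.sum()`.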
I think it is quite common to optimize PPO with small batch sizes, but maybe @natolambert or @edbeeching know more about whether we should change this?
Ah, I need to dig through my John Schulman RLHF media tour notes. I vaguely remember the concept coming up. I'm really not sure.
In principle, now that data parallelism via accelerate is supported, you effectively train with batches whose size equals the number of GPUs that are used.
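The arithmetic behind that remark can be written down directly. This is a back-of-envelope sketch (the helper name is illustrative, not an accelerate API): under data parallelism each process collects its own rollout, so the effective batch scales with the number of GPUs.

```python
def effective_batch_size(per_device_batch: int, num_gpus: int) -> int:
    """Effective rollout batch under data parallelism: each GPU
    contributes its own per-device batch every step."""
    return per_device_batch * num_gpus

# A per-device batch of 1 on 8 GPUs behaves like a batch of 8.
assert effective_batch_size(1, 8) == 8
```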
Hey @lvwerra cool library! PPO can deal with both large and small batch sizes depending on the task. The batch_size equals num_envs * num_steps, where num_envs is the number of envs in RL (maybe the number of conversations in RLHF), and num_steps is the number of steps each env takes (maybe the number of responses the same model generates in RLHF in the same sequence of conversation).
In IsaacGym / Brax, it's common to use a large num_envs and a small num_steps, e.g., num_envs=4096 and num_steps=5, corresponding to batch_size=20480. In Atari, it's common to use a smaller num_envs and a larger num_steps, e.g., num_envs=8 and num_steps=128, corresponding to batch_size=1024.
The InstructGPT paper uses batch_size = 32 ("We use a batch size of 32 for 1.3B and 6B models and 8 for the 175B model.", Appendix C.3), so I am imagining it's using num_steps=1 (which also correlates nicely with their bandit environment setting) and 32 prompts as obs, 32 responses as actions, and 32 scalar rewards.
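The batch-size arithmetic above is easy to check numerically. A tiny sketch (the helper is illustrative, not a trl or CleanRL API) covering the three regimes mentioned:

```python
def ppo_batch_size(num_envs: int, num_steps: int) -> int:
    """PPO rollout batch size: envs stepped in parallel times steps per env."""
    return num_envs * num_steps

assert ppo_batch_size(4096, 5) == 20480   # IsaacGym / Brax style: many envs, few steps
assert ppo_batch_size(8, 128) == 1024     # Atari style: few envs, many steps
assert ppo_batch_size(32, 1) == 32        # InstructGPT-style one-step bandit setting
```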
@vwxyzjn
batch_size = num_envs * num_steps
Does this mean this is building multi-envs to collect rollouts?
Assuming that the policy model is an LLM, the forward pass in step() is used to compute the ppo-clip or ppo-ptx loss for the backward pass. With the current code implementation, out-of-memory errors are likely to appear. I have been thinking about this approach: have you considered using ZeRO-Offload to handle the intermediate tensors generated by the rollout?
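For context on the ZeRO-Offload suggestion, here is a hypothetical DeepSpeed configuration fragment. The keys follow DeepSpeed's documented config schema (ZeRO stage 2 with optimizer-state offload to CPU), but the values are illustrative, not a tested trl setup.

```python
# Hypothetical DeepSpeed config dict; values are placeholders for illustration.
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                                   # partition optimizer states + gradients
        "offload_optimizer": {                        # ZeRO-Offload: optimizer on CPU
            "device": "cpu",
            "pin_memory": True,
        },
    },
}
```

This would typically be passed to `deepspeed.initialize(..., config=ds_config)`; whether it resolves the rollout-tensor memory pressure in this specific code path is an open question.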