Comments (5)
Hi @deepanwayx, thanks for your interest in the library. Let's see if I can answer your questions:
1. Calculation of KL-divergence
I think both of your questions here can be answered by looking at the equation for the KL divergence:

$$\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{x \sim p}\left[\log p(x) - \log q(x)\right]$$

which can be approximated for discrete values by the following formula:

$$\mathrm{KL}(p \,\|\, q) \approx \sum_{x} p(x)\left(\log p(x) - \log q(x)\right)$$
This is a weighted sum of the term in the first equation, where each term is weighted by the probability p(x). Since we sample the tokens from p(x), we already take that weighting into account implicitly: tokens that are unlikely are rarely selected, while tokens with high probability are selected more often. If we average over all elements in the sequence, we achieve the same weighting as by weighting each possible token with its probability. In that case the step you propose would be redundant.
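To make this concrete, here is a small numerical sketch (illustrative only, not code from the library) showing that averaging `log p(x) - log q(x)` over tokens sampled from `p` matches the explicitly weighted sum:

```python
# Compare the exact KL divergence against its sampling-based estimate.
import torch

p = torch.tensor([0.7, 0.2, 0.1])   # "active" policy distribution
q = torch.tensor([0.5, 0.3, 0.2])   # reference policy distribution

# Exact: each term weighted explicitly by p(x)
exact_kl = (p * (p.log() - q.log())).sum()

# Sampled: drawing x ~ p applies the p(x) weighting implicitly
samples = torch.multinomial(p, num_samples=100_000, replacement=True)
sampled_kl = (p.log()[samples] - q.log()[samples]).mean()

print(exact_kl.item(), sampled_kl.item())  # the two values agree closely
```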
2. About the `ratio`
One important detail to mention here is that the PPO optimisation runs for several steps for every batch of data. For this reason the model changes after each optimisation step. Therefore, `old_logprobs` stays the same while `logprobs` changes after each optimisation step.
Now, the `ratio` is an essential part of the PPO algorithm. The idea is that after the optimisation step you calculate the `ratio` to see whether the chosen action gets a higher or lower probability than during the rollout. That value multiplied by the advantage yields the unclipped objective function (which is used in TRPO). The idea is that you want to increase the policy's probability of the actions with a high advantage, and vice versa. PPO uses a clipped version of this objective for better stability. For more detail I highly recommend the excellent original paper!
I haven't thought about the effects of dropout. I suspect the changes from the optimised model are larger than the fluctuations from dropout, but feel free to experiment with it and create a PR if it yields a training improvement.
Remarks
Finally, I want to mention that my main contribution in this project was translating OpenAI's TensorFlow code to PyTorch and making it compatible with the Hugging Face library. The details above were already implemented in the original code and these are merely my explanations. See the original code and paper for more details. For the PPO implementation, check out the `train_policy.py` script.
Thanks for your detailed explanations. I think it makes a lot more sense now. I will check out the original PPO paper for more details.
Hi @lvwerra,

About the difference between `logprobs` and `old_logprobs`: you mentioned in #10 that

> So the reason we run PPO for 4 epochs is that we want to make the most of the data we gathered. If generating samples was cheap we could only train for one epoch.

and in this issue you said that `logprobs` and `old_logprobs` will differ after one epoch, which means that I can't set `ppo_epoch` to 1.

Quite confused about that.
You can set `ppo_epoch` to 1 and only the `logprobs` will change, which makes sense since the model changes after each `ppo_epoch`, thus the predictions are not the same. Why would that be a problem?
> You can set `ppo_epoch` to 1 and only the `logprobs` will change, which makes sense since the model changes after each `ppo_epoch`, thus the predictions are not the same. Why would that be a problem?
In the first epoch `logprobs` is the same as `old_logprobs` (if we disregard the dropout effect), so I think that @yanghoonkim's comment makes sense, right? I.e., if `ratio` is essential as you pointed out, `ppo_epoch` must be bigger than 1 for `ratio` to ever be different from 1.
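For what it's worth, a tiny sketch (assuming dropout is disabled, with made-up numbers) shows that in the first inner epoch `ratio` is exactly 1, yet the gradient is still non-zero because `ratio` depends on the current `logprobs`. So `ppo_epoch = 1` still trains; it just never activates the clipping:

```python
# First optimisation step: logprobs == old_logprobs, so ratio == 1,
# but the loss still produces a useful gradient.
import torch

old_logprobs = torch.tensor([-1.2, -0.7, -2.3])       # recorded at rollout, fixed
logprobs = old_logprobs.clone().requires_grad_(True)  # first epoch: identical values
advantages = torch.tensor([0.5, -0.1, 1.0])           # placeholder advantages

ratio = torch.exp(logprobs - old_logprobs)            # exp(0) == 1 everywhere
loss = -(advantages * ratio).mean()
loss.backward()

print(ratio)          # tensor([1., 1., 1.], grad_fn=...)
print(logprobs.grad)  # non-zero: -advantages / 3
```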