Comments (10)
Hi @jpanaro, so first of all, the reward is only given once the sequence generation is complete, which is why the score/reward is only added to the last token. You are right that it is then discounted and added to the previous tokens as well. You can find the equations in the original PPO paper, equations (11) and (12). To simplify calculations, this is done from back to front, starting with the last token.
It only makes sense to add the reward to the last token, since, for example, the BLEU score is only valid for the complete sequence. This is similar to Atari games, where you only get a reward after you complete a level, and the advantage equation is then used to discount the reward to previous actions. How strong the discount is depends on the value of lambda.
Does that answer your question?
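For illustration, the back-to-front discounting can be sketched roughly as below. This is a minimal standalone version of generalized advantage estimation; the function and variable names are illustrative, not trl's exact internals, and `gamma`/`lam` are assumed defaults.

```python
import torch

def discounted_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Generalized advantage estimation, computed back to front.

    `rewards` is zero everywhere except the final token, which carries
    the sequence-level score (e.g. BLEU); `values` are per-token value
    estimates. Illustrative sketch, not the exact trl implementation.
    """
    T = rewards.size(0)
    advantages = torch.zeros(T)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t < T - 1 else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae  # discount backwards
        advantages[t] = last_gae
    return advantages

# Sequence of 5 tokens, reward only on the last one:
rewards = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])
values = torch.zeros(5)
adv = discounted_advantages(rewards, values)
# earlier tokens receive an exponentially discounted share of the reward
```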
from trl.
That makes a lot more sense, thanks! I am just trying to resolve my negative KL-divergence problem, which is causing my model to slowly diverge and produce garbage.
Currently:
- I have tried performing top-k/top-p filtering on the logits, which has somewhat helped, but I am limited by the fact that my model's decoder produces all outputs and hidden states at once, so I cannot perform the filtering as the sequence unrolls, only afterwards, which I feel limits its effectiveness.
- I have also tried zeroing out all of the logits following the first EOS token, but this led to performance identical to the top-k/top-p filtering. I have also recently discovered that this might be pointless, since when the logprobs are calculated, the indices that used to be all zeros are now filled with nonzero values.
- Lastly, regardless of the actual ground-truth length, my model produces logprobs up to the max length (29), so my final idea is to find the first EOS, add the reward at that index, and then cut all following indices or set them to zero.
Any thoughts on these methods, or any solutions I may have missed? I would greatly appreciate your input!
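For the last idea, a minimal sketch of finding the first EOS and masking everything after it might look like this. The `eos_id`, tensor shapes, and names are assumptions for illustration, not JoeyNMT or trl specifics.

```python
import torch

def mask_after_eos(token_ids, logprobs, eos_id):
    """Zero out logprobs after the first EOS so padding beyond the true
    sequence length cannot influence the PPO update. Illustrative only."""
    # token_ids, logprobs: (batch, seq_len)
    is_eos = token_ids.eq(eos_id)
    # index of the first EOS per sequence; seq_len if none is present
    first_eos = torch.where(
        is_eos.any(dim=1),
        is_eos.int().argmax(dim=1),
        torch.full((token_ids.size(0),), token_ids.size(1), dtype=torch.long),
    )
    positions = torch.arange(token_ids.size(1)).unsqueeze(0)
    keep = positions <= first_eos.unsqueeze(1)  # keep up to and incl. EOS
    return logprobs.masked_fill(~keep, 0.0), first_eos

token_ids = torch.tensor([[5, 7, 2, 9, 9]])  # EOS id 2 appears at index 2
logprobs = torch.ones(1, 5)
masked, first_eos = mask_after_eos(token_ids, logprobs, eos_id=2)
```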
I had issues with negative KL-divergence twice! Both times it was related to the generation function: the model found ways to exploit some of its functionality, such as the padding tokens, or the fact that if min_length is not yet reached, the logprob of the EOS token is set to zero. The model can achieve negative KL-divergence by assigning astronomically small logprobs to the tokens that the generation function sets "manually". I tried to summarize this at the end of this notebook.
My suggestion would be to use greedy decoding and then modify the reward in your code (e.g. adding an extra term to BLEU) if the EOS token appears too early. I hope this helps!
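A toy illustration of the exploit, assuming the per-token KL penalty is computed as `logp_model - logp_ref` on the generated tokens:

```python
import torch

# If the generation code forces a token (e.g. suppresses EOS before
# min_length is reached), the tuned model can assign that forced token an
# astronomically small probability. The per-token KL penalty term
# logp_model(token) - logp_ref(token) then goes very negative, which
# translates into "free" positive reward, even though the true
# KL(model || ref) is always >= 0.
logp_ref = torch.log(torch.tensor(0.30))    # reference model's logprob
logp_model = torch.log(torch.tensor(1e-9))  # tuned model; token forced anyway
kl_term = logp_model - logp_ref
# kl_term is hugely negative for this single forced token
```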
As a remark: if you sample properly from your tuned model, you should never get negative KL-divergence. Negative values indicate that something is wrong in the way you generate the sequences.
Yeah, I saw when I was reading through the notebooks how those issues cropped up. Since my model uses the JoeyNMT library rather than the Hugging Face library, I am sure there are some differences in generation, so I guess I will have to find them myself.
I will give that a go! I think I can penalize the early generation of the EOS token, as well as the secondary problem of the model producing too many periods before the EOS token.
Regarding the last remark: when you say generate sequences, do you mean how the model actually produces the decoder hidden states that compose the logits, or do you mean things like how you produce the logprobs, or how you sample from those logprobs (i.e. greedy vs. categorical vs. multinomial)?
So the output of the model is logits, and with a softmax function you can transform these into probabilities. If you just sample from these probabilities, you should get positive KL-divergence. What the generation function (at least in the transformers library) does is apply extra tricks, like overwriting some probabilities with 0. This leaves the model a backdoor to exploit: by setting the probabilities of these overwritten outputs very low, it achieves negative KL-divergence, which means positive rewards for the model. To avoid this I wrote a custom generation function, respond_to_batch, to have full control.
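A minimal sketch of sampling directly from the softmax output, with no manual overwriting of probabilities (illustrative shapes and names; this is not the actual respond_to_batch implementation):

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits):
    """Draw the next token from the model's own distribution.

    Because every sampled token comes from the model's softmax, the
    Monte Carlo KL estimate stays non-negative in expectation.
    """
    probs = F.softmax(logits, dim=-1)               # (batch, vocab)
    return torch.multinomial(probs, num_samples=1)  # (batch, 1)

logits = torch.randn(2, 10)  # dummy logits for a vocab of 10
next_token = sample_next_token(logits)
```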
I managed to fix the negative KL-divergence problem. It turns out it was just a sampling alignment issue stemming from the greedy decoding.
Unfortunately, my new problem is that my reward seems to decrease as the quality of the sentences decreases. This leads to a direct hit to my BLEU score, which means the reward can only go down as the fine-tuning continues.
I think one of the issues is that the model I am starting with already has a majority of its training BLEU-4 scores in the 90-100 range, so improvement on the training set is very difficult, despite the BLEU-4 scores for the validation set being ~19 at most. I tried mapping all my BLEU scores in the 80-100 range onto -4.5 to 4.5, with scores below 80 set to -4.0, to mimic the range of your positive-sentiment scores. I'm fairly certain that using the raw scores in the loss calculation would blow the rewards out of proportion, and since they are naturally non-negative, I didn't think they would penalize the model enough for lower BLEU scores.
I thought this would rectify the decreasing-score issue, but unfortunately the reward_mean still sinks below 0 immediately after the first epoch (if it did not start there) and then bounces between -0.05 and -0.15. I think part of the problem is my spiky KL value, which now ranges from 1.5 to 16, and my lowish average score, which stabilizes at a little under 0.5 after about 10 epochs. Have you dealt with a positive KL but negative rewards like this?
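The score mapping described above can be sketched as follows; the breakpoints are the ones quoted in the comment, and the helper name is hypothetical:

```python
def bleu_to_reward(bleu):
    """Map BLEU in [80, 100] linearly onto [-4.5, 4.5]; clip below 80.

    Sketch of the reward shaping described in the comment; whether this
    particular shaping helps the PPO training is an open question.
    """
    if bleu < 80.0:
        return -4.0
    # linear map: 80 -> -4.5, 100 -> 4.5
    return -4.5 + (bleu - 80.0) * (9.0 / 20.0)
```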
Unfortunately, I have not encountered that specific problem. Within the PPO trainer the advantages are whitened, meaning the mean is set to zero and the standard deviation to one. The difference between training and validation scores seems to indicate that you might be overfitting the training set. Have you tried reducing that? You could also try decreasing the KL factor and see if it improves things. Since you are directly measuring text quality with BLEU, it might not be so important to constrain the language model.
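As a standalone sketch of what whitening means here, assuming the usual zero-mean, unit-standard-deviation normalization within a batch (names illustrative):

```python
import torch

def whiten(x, eps=1e-8):
    """Normalize a batch of values to zero mean and unit std."""
    return (x - x.mean()) / (x.std() + eps)

# Large raw scores come out on the same scale as small ones:
scores = torch.tensor([92.0, 88.0, 100.0, 95.0])
w = whiten(scores)
# w has mean ~0 and std ~1 regardless of the original magnitude
```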
Ah, ok, that makes sense. So when the rewards are whitened, it does not matter how large they are; they will be distributed across the entire advantage equally? I'm just worried the large raw scores (92, 88, 100) will wash everything else out if I don't "dampen" them.
Overfitting seems to be a major problem, possibly the main contributor to the lack of performance gain. I think I will cut the model's initial training to fewer epochs and give PPO a chance to explore more solutions.
This might be worth experimenting with, seeing as I want the model to explore a little more anyway. Thanks for the tips!
It will not be distributed but scaled down, such that the distribution within a batch is normalised. I think this should get rid of the characteristic scale of your scores in the PPO training, but you might want to check whether this is really true. If you use Weights & Biases, you can monitor all scores from the dashboard; you might also be able to see it in the loss scale and distribution.
In any case good luck!