
Comments (5)

lvwerra commented on June 24, 2024

Hi @deepanwayx, thanks for your interest in the library. Let's see if I can answer your questions:

1. Calculation of KL-divergence

I think both of your questions here can be answered by looking at the equation for the KL divergence:

$$D_{\mathrm{KL}}(P \,\|\, Q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx$$

which can be approximated for discrete values by the following formula:

$$D_{\mathrm{KL}}(P \,\|\, Q) \approx \sum_{x} p(x) \log \frac{p(x)}{q(x)}$$

This is a weighted sum of the log-ratio term from the first equation: each term is weighted by the probability p(x). Since we sample the tokens from p(x), we already take that weighting into account implicitly: tokens that are unlikely are rarely selected, while tokens with high probability are selected often. If we average the log-ratios over all elements in the sequence, we achieve the same weighting as weighting each possible token by its probability. In that case the step you propose would be redundant.
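
To make this concrete, here is a minimal sketch (toy distributions and names of my own choosing, not code from the library) comparing the exact weighted sum with the sampled estimate:

```python
import torch

torch.manual_seed(0)

# Two toy categorical distributions over a three-token vocabulary.
p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.4, 0.4, 0.2])

# Exact discrete KL: each log-ratio term weighted by p(x).
kl_exact = torch.sum(p * (p.log() - q.log()))

# Monte Carlo estimate: sample tokens from p and average the log-ratios.
# The weighting by p(x) happens implicitly through the sampling frequency.
samples = torch.multinomial(p, num_samples=100_000, replacement=True)
kl_sampled = (p.log()[samples] - q.log()[samples]).mean()

print(f"exact: {kl_exact.item():.4f}, sampled: {kl_sampled.item():.4f}")
```

With enough samples the two values agree, which is why averaging the log-ratios of sampled tokens is enough and no explicit reweighting is needed.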

2. About the ratio

One important detail to mention here is that the PPO optimisation runs for several steps on every batch of data, and the model changes after each optimisation step. Therefore, old_logprobs (recorded during rollout) stays the same while logprobs changes after each optimisation step.

Now, the ratio is an essential part of the PPO algorithm. The idea is that after an optimisation step you calculate the ratio to see whether the chosen action gets a higher or lower probability than it did during rollout. That value multiplied by the advantage yields the unclipped objective function (which is used in TRPO). The idea is that you want to increase the policy's probability of actions with a high advantage and decrease it for actions with a low advantage. PPO uses a clipped version of this objective for better stability. For more detail I highly recommend the excellent original paper!
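
As a sketch of what this looks like in PyTorch (illustrative names, not the library's exact implementation; see train_policy.py for the original):

```python
import torch

def ppo_clipped_loss(logprobs, old_logprobs, advantages, clip_range=0.2):
    """Clipped PPO policy loss for one optimisation step.

    old_logprobs are recorded once during rollout and stay fixed;
    logprobs are recomputed from the current model at every step.
    """
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    # PPO maximises the minimum of the two terms; the loss is its negative.
    return -torch.min(unclipped, clipped).mean()
```

Because old_logprobs stays fixed across the inner PPO epochs while logprobs is recomputed, the ratio drifts away from 1 after the first optimisation step on each batch.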

I haven't thought about the effects of dropout. I suspect the changes from optimising the model are larger than the fluctuations from dropout. But feel free to experiment with it and create a PR if it yields a training improvement.

Remarks

Finally, I want to mention that my main contribution to this project was translating OpenAI's TensorFlow code to PyTorch and making it compatible with the Hugging Face library. The details above were already implemented in the original code; these are merely my explanations. See the original code and paper for more details. For the PPO implementation, check out the train_policy.py script.


deepanwayx commented on June 24, 2024

Thanks for your detailed explanations. I think it makes a lot more sense now. I will check out the original PPO paper for more details.


yanghoonkim commented on June 24, 2024

Hi @lvwerra
About the difference between logprobs and old_logprobs: you mentioned in #10 that

> So the reason we run PPO for 4 epochs is that we want to make the most of the data we gathered. If generating samples was cheap, we could train for only one epoch.

and in this issue you said that logprobs and old_logprobs will be different after one epoch, which seems to mean that I can't set ppo_epoch to 1.

I'm quite confused about that.


lvwerra commented on June 24, 2024

You can set ppo_epoch to 1, and only the logprobs will change, which makes sense since the model changes after each ppo_epoch, so the predictions are not the same. Why would that be a problem?
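
To see why this is not a problem, note that even when the ratio is exactly 1, the gradient is not zero, because old_logprobs is detached from the computation graph. A toy autograd sketch (my own illustration, not library code):

```python
import torch

# A toy scalar parameter standing in for the policy weights.
theta = torch.tensor(0.0, requires_grad=True)
logprob = -1.0 + theta          # current log-probability of the action
old_logprob = logprob.detach()  # rollout value, cut from the graph

ratio = torch.exp(logprob - old_logprob)  # exactly 1.0 on the first pass
advantage = 2.0
loss = -(ratio * advantage)
loss.backward()

print(ratio.item())       # 1.0
print(theta.grad.item())  # -2.0: a non-zero gradient even at ratio == 1
```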


JoaoLages commented on June 24, 2024

> You can set ppo_epoch to 1, and only the logprobs will change, which makes sense since the model changes after each ppo_epoch, so the predictions are not the same. Why would that be a problem?

In the first epoch, logprobs is the same as old_logprobs (if we disregard the dropout effect), so I think @yanghoonkim's comment makes sense, right? I.e., if the ratio is essential as you pointed out, ppo_epoch must be bigger than 1 for the ratio to ever be different from 1.

