Comments (10)

lvwerra avatar lvwerra commented on July 17, 2024 1

I am closing this issue for now. If you have further questions, just contact me.

lvwerra avatar lvwerra commented on July 17, 2024

Hi @trisongz

Glad you find the library useful. Let's see if I understand your objective correctly:

  • You have a dataset with protein sequences and you would like GPT-2 to generate realistic sequences.
  • You have trained BERT to classify whether two subsequences are compatible.

Now your question is how to set up the PPO training step. Before running PPO, I would fine-tune (or train from scratch) GPT-2 on your dataset with the language modeling objective. Check out the training script from Hugging Face.
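A minimal sketch of what that fine-tuning could look like with the Trainer API (file names and hyperparameters here are placeholders, not from this thread; if you train from scratch with a custom tokenizer, the model and tokenizer loading would differ):

from transformers import (GPT2LMHeadModel, GPT2TokenizerFast, TextDataset,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Causal language modeling objective (mlm=False); one sequence per line in the file.
dataset = TextDataset(tokenizer=tokenizer, file_path="sequences.txt", block_size=128)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="gpt2-sequences", num_train_epochs=3,
                           per_device_train_batch_size=8),
    data_collator=collator,
    train_dataset=dataset,
)
trainer.train()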

Then I would probably start by using the first subsequence (18 characters) as the query and let GPT-2 respond for another 18 characters. Note, however, that GPT-2 uses BPE encodings rather than character-level encodings, so the actual number of tokens might differ. Then I would pass the query/response pairs to BERT for prediction and use its output as the reward (I used the unnormalised logits, but you could also try the class predictions 0/1).
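As a quick illustration of the BPE point (the example string is made up, not from your dataset):

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
query = "GTGG ACCA TATG GCC"         # 18 characters
print(len(tokenizer.encode(query)))  # number of BPE tokens GPT-2 actually sees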

Regarding the other PPO parameters, I didn't change them much from the original implementation, except the batch size for memory reasons. I would start there and adjust them later if it does not work. You also want to keep an eye on the KL-divergence (logged as objective/kl) to make sure the output distribution stays close to your initial data distribution.

trisongz avatar trisongz commented on July 17, 2024

Hi @lvwerra, thanks for the advice!

Yes, you're correct. I had actually started training GPT-2 from scratch with a custom tokenizer on the dataset prior to seeing this comment so I'm glad I am on the right track.

I also switched over to using RoBERTa as the classifier to test, which is currently at

'mcc': 0.9997714069569512,
'tp': 308736,
'tn': 164108,
'fp': 49,
'fn': 0,
'acc': 0.9998963824797575,
'eval_loss': 0.00044921892891853013

after 50k steps, although I am concerned that may be a result of not shuffling the csv data prior to training the model, as I wrote the csv file sequentially from the raw dataset. Is there an easy way you would suggest to shuffle the csv file prior to the training step? I used your extremely helpful train_test_split function for the eval and train data.

For this specific task, since it is sequence-based, do you think a masked LM would perform better at generation than GPT-2, given that, unlike human-written text, there are likely sequence pairs that repeat?

So far, here is what I currently have:

BERT/RoBERTa Classifier:

Dataset structure

GTGG ACCA TATG GCCA, ACCA TATG GCCA TAAT, 1
ATCA GGAA GGCA AGAG, AAGT ACAC ATCA GGAA, 0

------------------------------
The Predictions below should result in 1
GTGG ACCA TATG GCCA -> ACCA TATG GCCA TAAT: [1]
GCCA TAAT CAAA AAGT -> TAAT CAAA AAGT ACAC: [1]
------------------------------
The Predictions below should result in 0
ATCA GGAA GGCA AGAG -> AAGT ACAC ATCA GGAA: [0]
CAAA AAGT ACAC ATCA -> GCCA TAAT CAAA AAGT: [0]

For GPT-2 LM:

A line-by-line text file of only true (1) sequences:

GTGG ACCA TATG GCCA ACCA TATG GCCA TAAT
GCCA TAAT CAAA AAGT TAAT CAAA AAGT ACAC

Does this look correct so far?

Thank you for the tips!

lvwerra avatar lvwerra commented on July 17, 2024

'acc': 0.9998

That seems suspiciously high: either your task is trivial or there is some leakage in your dataset. Maybe entries exist more than once and therefore end up in both the train and test splits. train_test_split should already shuffle the dataset. I would definitely investigate that further.
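For example, a quick sketch to shuffle the csv and check for exact duplicates before splitting (the file name is a placeholder):

import pandas as pd

df = pd.read_csv("data.csv")                     # placeholder for your csv
print("duplicate rows:", df.duplicated().sum())  # exact duplicates would leak across splits
df = df.sample(frac=1, random_state=42).reset_index(drop=True)  # shuffle rows
df.to_csv("data_shuffled.csv", index=False)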

I don't have much experience with such sequences, so I don't know if MLM would work better. Whether training GPT-2 from scratch makes sense probably depends on the size of the dataset and the resources you have available. I guess you could try the simple, pretrained approach first and, if that does not work out, consider moving to MLM or training GPT-2 from scratch.

For the GPT-2 LM that looks fine to me. You could also consider adding the EOS token at the end of each line (see here for a snippet of how I processed the IMDB dataset).
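Roughly like this (just a sketch; the file names are placeholders, and with a custom tokenizer you would append its own EOS token instead of GPT-2's):

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
with open("sequences.txt") as f_in, open("sequences_eos.txt", "w") as f_out:
    for line in f_in:
        f_out.write(line.rstrip("\n") + " " + tokenizer.eos_token + "\n")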

Good luck.

trisongz avatar trisongz commented on July 17, 2024

I'm at part 4 now, where I'm running the RL environment and looking through your comments. I also updated GPT-2 to train with an EOS token. I messed up a few things originally, but I think I'm on the right track. Since I created a custom tokenizer for GPT-2, each sequence of 4 letters is 1 token for I/O.

Currently I have both txt_in_len and txt_out_len set to 4, to match what BERT expects to see for sequence-pair classification.
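As a quick sanity check that each 4-letter block really maps to one token (a sketch, with a placeholder path for the custom tokenizer):

from transformers import GPT2TokenizerFast

custom_tokenizer = GPT2TokenizerFast.from_pretrained("path/to/custom-tokenizer")  # placeholder path
query = "GTGG ACCA TATG GCCA"         # four 4-letter blocks
ids = custom_tokenizer.encode(query)
assert len(ids) == 4, f"expected 4 tokens, got {len(ids)}"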

However, after the scores came back I realized that I hadn't updated the reward mechanism to 0/1, so the scores are a mess. (This was prior to updating the txt lengths properly to 4x4.)

image

Could you point me to how I could switch the rewards over to the classifier output? I was looking around here:

def compute_rewards(self, scores, logprobs, ref_logprobs):
    """Compute per token rewards from scores and KL-penalty."""
    kl = logprobs - ref_logprobs
    non_score_reward = -self.kl_ctl.value * kl
    rewards = non_score_reward.clone().detach()
    rewards[:, -1] += scores
    return rewards, non_score_reward, self.kl_ctl.value

But I wasn't entirely sure.

lvwerra avatar lvwerra commented on July 17, 2024

You should normalise the scores before running PPOTrainer.step. The outputs you get from the BERT model are logits, so you would need to apply a softmax to the outputs and then take the class with the max probability:

import torch
import torch.nn.functional as F
probs = F.softmax(bert_outputs, dim=-1)  # turn the logits into probabilities
max_id = torch.argmax(probs, dim=-1)     # index of the most probable class

max_id corresponds to the output index with the largest probability. If position 0 in your outputs corresponds to "not entailed" and position 1 to "entailed", that should be what you are looking for.

trisongz avatar trisongz commented on July 17, 2024

I'm still relatively new to Torch, so I apologize for silly questions.

Would it be here that you add that step before appending it to rewards?

    #### tokenize text for sentiment analysis
    t = time.time()
    texts = [q + r for q,r in zip(game_data['query'], game_data['response'])]
    sentiment_inputs, attention_masks = build_bert_batch_from_txt(texts, sentiment_tokenizer, device)    
    timing['time/build_input_sentiment'] = time.time()-t

    #### get sentiment score
    t = time.time()
    rewards = []
    for i in range(int(config['batch_size']/fbs)):
        res = sentiment_model.forward(sentiment_inputs[i*fbs:(i+1)*fbs],
                                      attention_masks[i*fbs:(i+1)*fbs])[0][:, 1].detach()
        
        probs = F.softmax(res, dim=-1)
        max_id = torch.argmax(probs, dim=-1)
        rewards.append(max_id)
        #rewards.append(res)
    
    rewards = torch.cat(rewards)
    timing['time/get_sentiment_preds'] = time.time()-t

    #### Run PPO training 
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    timing['time/optimization'] = time.time()-t

lvwerra avatar lvwerra commented on July 17, 2024

That looks about right. You just need to remove the logit slicing in this line:

res = sentiment_model.forward(sentiment_inputs[i*fbs:(i+1)*fbs],
                                      attention_masks[i*fbs:(i+1)*fbs])[0].detach()

The [:, 1] slices out the logit for the positive sentiment in my example. Since you want to create discrete rewards, you will need both the positive and the negative logits for the softmax.

trisongz avatar trisongz commented on July 17, 2024

I think I'm getting closer. I had to do one additional step and transform max_id to max_id.float(). However, the outputs are all showing rewards as 0.0 so far, so I wanted to confirm.

Result of the res step:


        [ 5.5575, -6.0447],
        [ 5.5397, -6.0370],
        [ 5.5577, -6.0430],
        [ 5.5556, -6.0427],
        [ 5.5585, -6.0432],
        [ 5.5494, -6.0396],
        [ 5.5576, -6.0438],
        [ 5.5544, -6.0420],
        [ 5.5584, -6.0439],
        [ 5.5490, -6.0390],
        [ 5.5601, -6.0438],
        [ 5.5527, -6.0437],
        [ 5.5541, -6.0416],
        [ 5.5583, -6.0435],
        [ 5.5514, -6.0416],
        [ 5.5590, -6.0440],
        [ 5.5556, -6.0430],
        [ 5.5468, -6.0402],
        [ 5.5564, -6.0439],
        [ 5.5545, -6.0405],
        [ 5.5537, -6.0446],
        [ 5.5563, -6.0434],
        [ 5.5566, -6.0431],
        [ 5.5564, -6.0429],
        [ 5.5527, -6.0419],
        [ 5.5535, -6.0425],
        [ 5.5531, -6.0433],
        [ 5.5546, -6.0427],
        [ 5.5518, -6.0417],
        [ 5.5573, -6.0431],
        [ 5.5567, -6.0428]], device='cuda:0')

Result of probs:

tensor([[9.9999e-01, 9.2180e-06],
        [9.9999e-01, 9.1457e-06],
        [9.9999e-01, 9.3818e-06],
        [9.9999e-01, 9.1595e-06],
        [9.9999e-01, 9.1816e-06],
        [9.9999e-01, 9.1508e-06],
        [9.9999e-01, 9.2678e-06],
        [9.9999e-01, 9.1529e-06],
        [9.9999e-01, 9.1989e-06],
        [9.9999e-01, 9.1447e-06],
        [9.9999e-01, 9.2768e-06],
        [9.9999e-01, 9.1304e-06],
        [9.9999e-01, 9.1988e-06],
        [9.9999e-01, 9.2063e-06],
        [9.9999e-01, 9.1496e-06],
        [9.9999e-01, 9.2305e-06],
        [9.9999e-01, 9.1391e-06],
        [9.9999e-01, 9.1788e-06],
        [9.9999e-01, 9.2861e-06],
        [9.9999e-01, 9.1639e-06],
        [9.9999e-01, 9.2114e-06],
        [9.9999e-01, 9.1817e-06],
        [9.9999e-01, 9.1681e-06],
        [9.9999e-01, 9.1684e-06],
        [9.9999e-01, 9.1717e-06],
        [9.9999e-01, 9.2159e-06],
        [9.9999e-01, 9.2032e-06],
        [9.9999e-01, 9.1985e-06],
        [9.9999e-01, 9.1905e-06],
        [9.9999e-01, 9.2262e-06],
        [9.9999e-01, 9.1625e-06],
        [9.9999e-01, 9.1699e-06]], device='cuda:0')

Result of max_id (non-float):

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

Result of max_id.float():

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')

As a sanity check, I ran the post-training step to see the results, modifying the rewards to match the above.

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q,r in zip(game_data['query'], game_data['response (before)'])]
sentiment_inputs, attention_masks = build_bert_batch_from_txt(texts, sentiment_tokenizer, device)    
#rewards = sentiment_model.forward(sentiment_inputs, attention_masks)[0][:, 1].detach()
res = sentiment_model.forward(sentiment_inputs, attention_masks)[0].detach()
probs = F.softmax(res, dim=-1)
max_id = torch.argmax(probs, dim=-1)
max_id = max_id.float()
rewards = max_id
game_data['rewards (before)'] = rewards.cpu().numpy()

texts = [q + r for q,r in zip(game_data['query'], game_data['response (after)'])]
sentiment_inputs, attention_masks = build_bert_batch_from_txt(texts, sentiment_tokenizer, device)    
#rewards = sentiment_model.forward(sentiment_inputs, attention_masks)[0][:, 1].detach()
res = sentiment_model.forward(sentiment_inputs, attention_masks)[0].detach()
probs = F.softmax(res, dim=-1)
max_id = torch.argmax(probs, dim=-1)
max_id = max_id.float()
rewards = max_id
game_data['rewards (after)'] = rewards.cpu().numpy()

image

Does this look right to you so far? I'm also not sure whether the classifier is issuing 0 as a result of not seeing all 8 tokens, as it's trained on 4/4 sequence pairs.

When I run

text_a = 'AGAC CACT GTGG ACCA'
text_b = 'CACT GTGG ACCA TATG'
output = sentiment_model.forward(sentiment_tokenizer.encode([text_a, text_b], return_tensors="pt"))
output
output[0][0, 1]

I get

tensor(0.2771, grad_fn=<SelectBackward>)

Whereas with

text = 'CACT GTGG ACCA TATG'
output = sentiment_model.forward(sentiment_tokenizer.encode(text, return_tensors="pt"))
output
output[0][0, 1]

It shows

tensor(-6.0448, grad_fn=<SelectBackward>)

lvwerra avatar lvwerra commented on July 17, 2024

Indeed, it seems like the LM is not generating good sequences at the beginning. There are several things you could try:

  • Further fine-tune GPT2 on the language modeling task
  • Play with the language generation (e.g. try changing the sampling temperature)
  • Use the logits as reward function (like in my example), since they provide a continuous reward signal; see the short sketch at the end of this comment. In your case the model only ever gets a reward when the probability for 1 is larger than that for 0, whereas with the raw logits it gets a reward signal even when it is only getting closer.
  • Try to simplify the task by reducing the number of generated characters. Maybe try 12 query characters vs. 4 response characters.

These are just some ideas off the top of my head. I am sure there could be other problems and solutions.
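For the logits-as-reward idea, a minimal sketch reusing the variable names from the training loop earlier in this thread:

logits = sentiment_model.forward(sentiment_inputs, attention_masks)[0].detach()
rewards = logits[:, 1]  # continuous reward: higher logit = classifier more confident the pair is compatible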
