unlikelihood_training's Introduction

Neural Text deGeneration with Unlikelihood Training

PyTorch implementation of the paper:

Neural Text Generation with Unlikelihood Training
Sean Welleck*, Ilia Kulikov*, Stephen Roller, Emily Dinan, Kyunghyun Cho, Jason Weston
*Equal contribution. The order was decided by a coin flip.

We present code for training models described in the paper, as well as pre-trained models. The code includes:

  • An implementation of unlikelihood training, fine-tuning, and evaluation for fairseq.
  • A script for fine-tuning a GPT-2 model from pytorch-transformers with the unlikelihood sequence loss.

Table of Contents

  • Setup
  • Training
  • Evaluation
  • Finetuning GPT-2

Please cite our work if you found the resources in this repository useful:

@misc{welleck2019neural,
    title={Neural Text Generation with Unlikelihood Training},
    author={Sean Welleck and Ilia Kulikov and Stephen Roller and Emily Dinan and Kyunghyun Cho and Jason Weston},
    year={2019},
    eprint={1908.04319},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Setup

Dependencies

The implementation is a custom fairseq module. Download and install fairseq:

git clone https://github.com/pytorch/fairseq.git
cd fairseq
git checkout 2b68e91f231a2b7997664e1418f30b808d889963
pip install --editable .

Install other dependencies:

pip install nltk
pip install pandas
pip install pytorch-transformers   # (optional); for GPT-2 fine-tuning
pip install tensorflow==1.14
pip install tensorboardX           # (optional); for tensorboard logs
pip install torch==1.4.0           # overwriting the latest version of pytorch, as installed by fairseq

'Installing' the unlikelihood module

Copy the custom directory in this repo into the fairseq repo that you downloaded above:

export FAIRSEQ_DIR=/path/to/fairseq
export UNLIKELIHOOD_DIR=/path/to/unlikelihood_training

cp -r $UNLIKELIHOOD_DIR/custom $FAIRSEQ_DIR/fairseq

Now ls $FAIRSEQ_DIR/fairseq should resemble:

binarizer.py
...
criterions
custom
data
...

Next Steps

We recommend performing the following steps from the fairseq repo's base directory:

cd $FAIRSEQ_DIR

Dataset

Download the binarized wikitext-103 dataset (160MB, install wget if needed):

wget https://dl.fbaipublicfiles.com/unlikelihood/wikitext-103_v0.tar.gz

Unpack the dataset (440MB):

tar xzvf wikitext-103_v0.tar.gz

This command unpacks the dataset into a data-bin folder in the current directory.

Create a checkpoint folder

mkdir checkpoint

Download pre-trained models

*This step is not necessary for training a model from scratch.

We provide all fairseq models used in the paper. Download the model archive (warning: large file, 16 GB):

wget https://dl.fbaipublicfiles.com/unlikelihood/checkpoints_v0.tar.gz

Unpack the model checkpoints from the archive:

tar xzvf checkpoints_v0.tar.gz

Training

*We tested these scripts using Tesla V100 32GB GPU(s) in both single-GPU and multi-GPU (8) settings. If you get OOM errors, try decreasing the batch size (--max-tokens, --tokens-per-sample). Otherwise, the hyper-parameters used here are similar to the example LM training code in fairseq.

The commands below assume you are in the $FAIRSEQ_DIR directory.

Baseline (MLE) model

python -u ./train.py --task language_modeling_with_generation ./data-bin/wikitext-103 \
    --user-dir ./fairseq/custom --arch transformer_lm_ul --max-tokens 1536 --tokens-per-sample 1536 \
    --fp16 --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 \
    --lr-scheduler cosine --lr-shrink 0.75 --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 \
    --optimizer nag --lr 0.0001 --clip-norm 0.1 --update-freq 3 --seed 1 --sample-break-mode none \
    --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --save-interval-updates 10000 \
    --keep-interval-updates 2 --no-progress-bar --log-interval 100 \
    --criterion cross_entropy_wcustom_metrics \
    --save-dir ./checkpoint/baseline_model \
    --tensorboard-logdir ./checkpoint/baseline_model

Train a token-level unlikelihood model

python -u ./train.py --task language_modeling_with_generation ./data-bin/wikitext-103 \
    --user-dir ./fairseq/custom --arch transformer_lm_ul --max-tokens 1536 --tokens-per-sample 1536 \
    --fp16 --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 \
    --lr-scheduler cosine --lr-shrink 0.75 --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 \
    --optimizer nag --lr 0.0001 --clip-norm 0.1 --update-freq 3 --seed 1 --sample-break-mode none \
    --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --save-interval-updates 10000 \
    --keep-interval-updates 2 --no-progress-bar --log-interval 100 \
    --criterion candidate_penalty_cross_entropy --rank-alpha 1.0 \
    --save-dir ./checkpoint/token_level_model \
    --tensorboard-logdir ./checkpoint/token_level_model

Sequence-level fine tuning

For sequence-level fine tuning you need an initial checkpoint (via --restore-file). You can use your own checkpoints, or a provided checkpoint as shown below.

Fine-tuning the baseline model

python -u ./train.py --task language_modeling_with_generation ./data-bin/wikitext-103 \
    --user-dir ./fairseq/custom --arch transformer_lm_ul --max-tokens 1536 --tokens-per-sample 1536 \
    --fp16 --max-update 1500 --max-lr 1.0e-2 --t-mult 2 --lr-period-updates 270000 \
    --lr-scheduler cosine --lr-shrink 0.75 --warmup-updates 0 --warmup-init-lr 1e-07 --min-lr 1e-09 \
    --optimizer nag --lr 0.0001 --clip-norm 0.1 --update-freq 3 --seed 1 --sample-break-mode none \
    --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --save-interval-updates 100 \
    --keep-interval-updates 2 --no-progress-bar --log-interval 10 \
    --rank-alpha 1.0 --sequence-level-train-rate 0.5 \
    --reset-lr-scheduler --reset-optimizer --reset-meters \
    --compute-metrics-interval 1 --restore-file ./public_checkpoints/mle_baseline/checkpoint_best.pt \
    --criterion cross_entropy_wcustom_metrics \
    --sequence-prefix-length 50 --sequence-completion-length 100 \
    --sequence-ngram-n 4 \
    --save-dir ./checkpoint/seq_level_on_baseline \
    --tensorboard-logdir ./checkpoint/seq_level_on_baseline

Fine-tuning the token-level unlikelihood model

python -u ./train.py --task language_modeling_with_generation ./data-bin/wikitext-103 \
    --user-dir ./fairseq/custom --arch transformer_lm_ul --max-tokens 1536 --tokens-per-sample 1536 \
    --fp16 --max-update 1500 --max-lr 1.0e-2 --t-mult 2 --lr-period-updates 270000 \
    --lr-scheduler cosine --lr-shrink 0.75 --warmup-updates 0 --warmup-init-lr 1e-07 --min-lr 1e-09 \
    --optimizer nag --lr 0.0001 --clip-norm 0.1 --update-freq 3 --seed 1 --sample-break-mode none \
    --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --save-interval-updates 100 \
    --keep-interval-updates 2 --no-progress-bar --log-interval 10 \
    --rank-alpha 1.0 --sequence-level-train-rate 0.5 \
    --reset-lr-scheduler --reset-optimizer --reset-meters \
    --compute-metrics-interval 1 --restore-file ./public_checkpoints/token_level_ul/checkpoint_best.pt \
    --criterion candidate_penalty_cross_entropy \
    --sequence-prefix-length 50 --sequence-completion-length 100 \
    --sequence-ngram-n 4 \
    --save-dir ./checkpoint/seq_level_on_token_level \
    --tensorboard-logdir ./checkpoint/seq_level_on_token_level

Evaluation

A single script (custom/evaluation.py) performs sequence-level and token-level evaluation. For the sequence-level evaluation, one can choose greedy search, beam search, top-k sampling, or top-p (nucleus) sampling.

Each evaluation run produces the following files (in the --save-path directory):

  • completions__{params}.txt: prefixes with corresponding completions
  • single_token_predictions__{params}.txt: next-token greedy predictions (i.e. given human context)
  • metrics__{params}.pkl: metrics extracted on the token-level (e.g. PPL, loss, acc, rep, etc.)
  • targets__{params}.txt: reference sequences
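
The seq-rep-n columns in the result tables later in this README measure how many n-grams of a completion are repeats. A minimal sketch, assuming the paper's definition seq-rep-n = 1 - |unique n-grams| / |n-grams|; the function name `seq_rep_n` and the toy completion are illustrative and are not part of this repository:

```python
def seq_rep_n(tokens, n):
    """Fraction of n-grams in `tokens` that duplicate an earlier n-gram."""
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        # Sequence shorter than n: no n-grams, so nothing can repeat.
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


# Toy completion (made up for illustration).
completion = ["the", "cat", "sat", "the", "cat", "sat", "the", "cat"]
print(seq_rep_n(completion, 1))  # 3 unique tokens out of 8 unigrams
print(seq_rep_n(completion, 4))  # 3 unique 4-grams out of 5
```

The reported numbers are presumably averaged over all completions of a split; the sketch scores a single sequence only.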

Example command to run evaluation using the pretrained baseline model:

python -u ./fairseq/custom/evaluation.py \
    --batch-size-single-prediction 1536 --batch-size-completion 48 \
    --data-prefix-length 50 --completion-length 100 \
    --save-path ./public_checkpoints/ --ckpt all \
    --model-path ./public_checkpoints/mle_baseline \
    --data-dir ./data-bin/wikitext-103 \
    --base-dir ./

Evaluation from the paper

We share evaluation outputs for models used in our paper. To download and unpack the outputs:

wget https://dl.fbaipublicfiles.com/unlikelihood/eval_public_v0.tar.gz
tar xzvf eval_public_v0.tar.gz

To post-process the evaluation output (requires pandas; install it with pip install pandas):

python fairseq/custom/report_metrics.py \
    --eval-dir ./eval_public \
    --model-names mle_baseline token_level_ul seq_level_ul_mle seq_level_ul_token_level_ul

This yields the following output:

     model_name beam size beam block topk topp  split  seq-rep-1  seq-rep-4  uniq-seq     ppl    acc    rep   wrep   uniq
0  mle_baseline         1          0   50  0.0  valid      0.381      0.016     21396  24.592  0.401  0.619  0.346  11654
1  mle_baseline         1          0    1  0.0  valid      0.690      0.429     10629  24.592  0.401  0.619  0.346  11654
2  mle_baseline         1          0   50  0.0   test      0.382      0.016     22670  25.639  0.395  0.627  0.352  11849
3  mle_baseline         1          0    1  0.0   test      0.697      0.442     10845  25.639  0.395  0.627  0.352  11849
4  mle_baseline         1          0    1  0.9  valid      0.368      0.014     25574  24.592  0.401  0.619  0.346  11654
5  mle_baseline         1          0    1  0.9   test      0.370      0.016     27275  25.639  0.395  0.627  0.352  11849
6  mle_baseline        10          0    1  0.0  valid      0.726      0.495      9470  24.592  0.401  0.619  0.346  11654
7  mle_baseline        10          0    1  0.0   test      0.740      0.523      9530  25.639  0.395  0.627  0.352  11849
8  mle_baseline        10          4    1  0.0  valid      0.505      0.000     13350  24.592  0.401  0.619  0.346  11654
9  mle_baseline        10          4    1  0.0   test      0.511      0.000     14158  25.639  0.395  0.627  0.352  11849



MODEL: token_level_ul

       model_name beam size beam block topk topp  split  seq-rep-1  seq-rep-4  uniq-seq     ppl    acc    rep   wrep   uniq
0  token_level_ul         1          0   50  0.0  valid      0.303      0.007     22861  25.624  0.396  0.569  0.305  12462
1  token_level_ul         1          0    1  0.0  valid      0.584      0.274     12630  25.624  0.396  0.569  0.305  12462
2  token_level_ul         1          0   50  0.0   test      0.304      0.007     24476  26.910  0.390  0.577  0.311  12728
3  token_level_ul         1          0    1  0.0   test      0.586      0.283     13195  26.910  0.390  0.577  0.311  12728
4  token_level_ul         1          0    1  0.9  valid      0.279      0.005     28859  25.624  0.396  0.569  0.305  12462
5  token_level_ul         1          0    1  0.9   test      0.280      0.005     31325  26.910  0.390  0.577  0.311  12728
6  token_level_ul        10          0    1  0.0  valid      0.615      0.327     11225  25.624  0.396  0.569  0.305  12462
7  token_level_ul        10          0    1  0.0   test      0.619      0.336     11753  26.910  0.390  0.577  0.311  12728
8  token_level_ul        10          4    1  0.0  valid      0.433      0.000     14622  25.624  0.396  0.569  0.305  12462
9  token_level_ul        10          4    1  0.0   test      0.437      0.000     15386  26.910  0.390  0.577  0.311  12728



MODEL: seq_level_ul_mle

         model_name beam size beam block topk topp  split  seq-rep-1  seq-rep-4  uniq-seq     ppl    acc    rep   wrep   uniq
0  seq_level_ul_mle         1          0   50  0.0  valid      0.305  1.000e-03     23169  24.284  0.406  0.603  0.329  12355
1  seq_level_ul_mle         1          0   50  0.0   test      0.307  1.000e-03     24946  25.416  0.399  0.609  0.335  12779
2  seq_level_ul_mle         1          0    1  0.0  valid      0.507  1.306e-01     12663  24.284  0.406  0.603  0.329  12355
3  seq_level_ul_mle         1          0    1  0.0   test      0.514  1.369e-01     13144  25.416  0.399  0.609  0.335  12779
4  seq_level_ul_mle         1          0    1  0.9  valid      0.290  6.000e-04     31012  24.284  0.406  0.603  0.329  12355
5  seq_level_ul_mle         1          0    1  0.9   test      0.294  9.000e-04     33926  25.416  0.399  0.609  0.335  12779
6  seq_level_ul_mle        10          0    1  0.0  valid      0.374  1.830e-02     16817  24.284  0.406  0.603  0.329  12355
7  seq_level_ul_mle        10          0    1  0.0   test      0.376  1.910e-02     18352  25.416  0.399  0.609  0.335  12779
8  seq_level_ul_mle        10          4    1  0.0  valid      0.356  0.000e+00     16898  24.284  0.406  0.603  0.329  12355
9  seq_level_ul_mle        10          4    1  0.0   test      0.358  0.000e+00     18432  25.416  0.399  0.609  0.335  12779



MODEL: seq_level_ul_token_level_ul

                    model_name beam size beam block topk topp  split  seq-rep-1  seq-rep-4  uniq-seq     ppl    acc    rep   wrep   uniq
0  seq_level_ul_token_level_ul         1          0   50  0.0  valid      0.254  5.000e-04     24253  25.375  0.401  0.551  0.287  13375
1  seq_level_ul_token_level_ul         1          0   50  0.0   test      0.257  6.000e-04     25997  26.718  0.395  0.559  0.293  13759
2  seq_level_ul_token_level_ul         1          0    1  0.0  valid      0.428  5.190e-02     14845  25.375  0.401  0.551  0.287  13375
3  seq_level_ul_token_level_ul         1          0    1  0.0   test      0.438  5.850e-02     15428  26.718  0.395  0.559  0.293  13759
4  seq_level_ul_token_level_ul         1          0    1  0.9  valid      0.233  3.000e-04     32011  25.375  0.401  0.551  0.287  13375
5  seq_level_ul_token_level_ul         1          0    1  0.9   test      0.234  3.000e-04     34824  26.718  0.395  0.559  0.293  13759
6  seq_level_ul_token_level_ul        10          0    1  0.0  valid      0.335  1.310e-02     17562  25.375  0.401  0.551  0.287  13375
7  seq_level_ul_token_level_ul        10          0    1  0.0   test      0.338  1.350e-02     19151  26.718  0.395  0.559  0.293  13759
8  seq_level_ul_token_level_ul        10          4    1  0.0  valid      0.322  0.000e+00     17792  25.375  0.401  0.551  0.287  13375
9  seq_level_ul_token_level_ul        10          4    1  0.0   test      0.326  0.000e+00     19439  26.718  0.395  0.559  0.293  13759

Finetuning GPT-2

We also provide a script for sequence-level and maximum-likelihood fine-tuning of a GPT-2 model from the pytorch-transformers library.

Install (we used version 1.1.0):

pip install pytorch-transformers

We will again assume that you are in the fairseq base directory:

cd $FAIRSEQ_DIR

Download and unpack the BPE-tokenized WikiText:

wget https://dl.fbaipublicfiles.com/unlikelihood/wikitext-103-bpe_v0.tar.gz
tar -xzvf wikitext-103-bpe_v0.tar.gz
mv wikitext-103-bpe_v0 data-bin/

Sequence-level finetuning

python fairseq/custom/gpt2/run_gpt2.py  \
    --data-base ./data-bin/wikitext-103-bpe_v0 \
    --output-dir ./checkpoint/gpt2/seq_tune \
    --eval-split valid \
    --mode train

MLE-tuning

python fairseq/custom/gpt2/run_gpt2.py  \
    --data-base ./data-bin/wikitext-103-bpe_v0 \
    --output-dir ./checkpoint/gpt2/mle_tune \
    --eval-split valid \
    --train-n-steps 20000 \
    --validate-every 1000 \
    --sequence-tune-rate 0.0 \
    --mode train

Sequence-level finetuning after MLE-tuning

python fairseq/custom/gpt2/run_gpt2.py  \
    --data-base ./data-bin/wikitext-103-bpe_v0 \
    --output-dir ./checkpoint/gpt2/seq_mle_tune \
    --eval-split valid \
    --model-load-dir ./checkpoint/gpt2/mle_tune/best \
    --mode train

Evaluation

python fairseq/custom/gpt2/run_gpt2.py  \
    --data-base ./data-bin/wikitext-103-bpe_v0 \
    --output-dir ./checkpoint/gpt2/seq_mle_tune \
    --eval-split valid \
    --model-load-dir ./checkpoint/gpt2/seq_mle_tune \
    --mode eval-both

We used a single Tesla V100 32GB GPU.

License

unlikelihood_training is CC-BY-NC 4.0 licensed, as found in the LICENSE file.

unlikelihood_training's Issues

[Request] Output data set from model trained on unlikelihood objective

Our research group is studying security implications of large scale generative models and creating defenses to detect their outputs. We came across your paper and realize that such structural changes in neural architectures could make defenses difficult. Therefore, we would like to study how to create more robust defenses so that we can prevent bad actors from using your methodology to spread misinformation online.

If you could please provide an output data set from your generative model trained/fine-tuned on the unlikelihood objective, that would really help us out. We know you have released the script, but given the time constraint and resource limitation on our end, we are unable to fine-tune GPT-2.

Thank you!

When "padding_idx" != 1, "negative_targets" incorrect?

Thanks for sharing the code!

When running the script on my own data, I encountered the error "RuntimeError: Invalid index in scatter". I figured out the reason to be the following,

ctx_cands_ = ctx_cands_ * ctx_cands_.triu()

If padding_idx != 1 or -1 (say 10000), line 54 above will set the upper triangle of ctx_cands_ to 10000*10000, which results in the invalid index problem. Even if padding_idx*padding_idx is a valid index (say 10*10=100), it would still write the wrong index into negative_targets.

A possible solution could be adding

ctx_cands = ctx_cands.masked_fill(ctx_cands == (self.padding_idx**2), self.padding_idx)

after line 58 below

ctx_cands = ctx_cands.masked_fill(ctx_cands == target.unsqueeze(1), self.padding_idx)
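
The arithmetic behind this report can be traced without PyTorch. The sketch below emulates `tril`, `triu`, and the elementwise operations on nested Python lists. The helper names and the token ids are hypothetical, chosen only for illustration. Under these assumptions, the padded positions end up holding `padding_idx ** 2`, which equals `padding_idx` only for 0 and 1:

```python
def tril(m, k=0):
    """Keep entries with col - row <= k and zero the rest (like torch.tril)."""
    return [[v if j - i <= k else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def triu(m, k=0):
    """Keep entries with col - row >= k and zero the rest (like torch.triu)."""
    return [[v if j - i >= k else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def triangle(target, padding_idx):
    """Emulate the three 'triangle' statements of the loss on plain lists."""
    n = len(target)
    # target.unsqueeze(0).expand(n, n): every row is a copy of the targets.
    ctx_cands = [list(target) for _ in range(n)]
    # ctx_cands_ = ctx_cands.tril(-1) + padding_idx
    shifted = [[v + padding_idx for v in row] for row in tril(ctx_cands, -1)]
    # ctx_cands_ = ctx_cands_ * ctx_cands_.triu()
    product = [[a * b for a, b in zip(r1, r2)]
               for r1, r2 in zip(shifted, triu(shifted))]
    # ctx_cands = ctx_cands.tril(-1) + ctx_cands_
    return [[a + b for a, b in zip(r1, r2)]
            for r1, r2 in zip(tril(ctx_cands, -1), product)]


target = [5, 7, 9]  # made-up token ids
print(triangle(target, 1))   # [[1, 1, 1], [5, 1, 1], [5, 7, 1]]
print(triangle(target, 10))  # [[100, 100, 100], [5, 100, 100], [5, 7, 100]]
```

With `padding_idx = 10`, the value 100 would be used as a scatter index, which is out of range for any vocabulary of at most 100 entries.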

Negative candidates in the implementation are different than the ones proposed in the paper

According to the paper, the negative candidates for a timestep are previous context tokens excluding the target token for that timestep.

However, according to the implementation, the negative candidates for a timestep are all tokens in the vocabulary excluding the target token for that timestep.

  1. Do both approaches achieve the same result?
  2. The implementation seems to be less efficient since it has to take into account all tokens in the vocabulary instead of just the previous context tokens, right?

Thanks.
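
For reference, the per-timestep objective discussed in this issue can be written in a few lines of plain Python. This is a sketch under assumptions, not the repository's tensor code: it assumes the paper's loss, -log p(x_t) - alpha * sum over candidates c of log(1 - p(c)), with the candidates being the previous context tokens minus the current target. The function names and the toy distribution are hypothetical, and `alpha` is assumed to play the role of `--rank-alpha`.

```python
import math


def prev_context_candidates(tokens, t):
    """Tokens seen before position t, excluding the target at position t."""
    return set(tokens[:t]) - {tokens[t]}


def token_unlikelihood_loss(probs, target, candidates, alpha=1.0):
    """Likelihood term for the target plus penalties on the negative candidates."""
    mle_term = -math.log(probs[target])
    # Clamp 1 - p away from zero so the logarithm stays finite.
    ul_term = -sum(math.log(max(1.0 - probs[c], 1e-20)) for c in candidates)
    return mle_term + alpha * ul_term


tokens = [2, 0, 2, 3]           # made-up token ids over a vocabulary of size 4
probs = [0.1, 0.2, 0.3, 0.4]    # made-up next-token distribution at t = 3
cands = prev_context_candidates(tokens, 3)
print(sorted(cands))            # [0, 2]
print(token_unlikelihood_loss(probs, tokens[3], cands))
```

Only the tokens in `cands` contribute to the penalty. A multi-hot vector over the full vocabulary with ones at exactly those indices would give the same sum.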

gzip: stdin: unexpected end of file

When unpacking the file checkpoints_v0.tar.gz:
$ wget https://dl.fbaipublicfiles.com/unlikelihood/checkpoints_v0.tar.gz

I got the following error:
public_checkpoints/mle_baseline/
public_checkpoints/mle_baseline/checkpoint_best.pt

gzip: stdin: unexpected end of file
tar: Unexpected EOF in archive
tar: Unexpected EOF in archive
tar: Error is not recoverable: exiting now
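
An `unexpected end of file` message from gzip often points to a truncated download. The sketch below is one way to check an archive before unpacking it. It uses only the Python standard library; the function name `gzip_is_complete` is hypothetical, and the demo builds its own small files rather than touching the real checkpoint archive.

```python
import gzip
import os
import tempfile
import zlib


def gzip_is_complete(path, chunk_size=1 << 20):
    """Return True if the whole gzip stream at `path` decompresses cleanly."""
    try:
        with gzip.open(path, "rb") as f:
            # Read to the end so the trailer (CRC and length) is verified too.
            while f.read(chunk_size):
                pass
    except (EOFError, OSError, zlib.error):
        return False
    return True


# Demo on synthetic data: one intact file and one cut in half.
payload = gzip.compress(bytes(range(256)) * 100)
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.tar.gz")
    bad = os.path.join(tmp, "bad.tar.gz")
    with open(good, "wb") as f:
        f.write(payload)
    with open(bad, "wb") as f:
        f.write(payload[: len(payload) // 2])
    print(gzip_is_complete(good), gzip_is_complete(bad))  # True False
```

If the check fails for checkpoints_v0.tar.gz, re-running the download with `wget -c` on the same URL may resume it instead of starting over.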

Need help in understanding how the negative candidates are chosen

Hi, I am trying to understand the following code snippet.

if self.candidate_type == 'prev_context':
    # Make 'the triangle'.
    ctx_cands = target.unsqueeze(0).expand(target.size(0), target.size(0))
    ctx_cands_ = (ctx_cands.tril(-1) + self.padding_idx)
    ctx_cands_ = ctx_cands_ * ctx_cands_.triu()
    ctx_cands = ctx_cands.tril(-1) + ctx_cands_
    # Don't include the target for that timestep as a negative target.
    ctx_cands = ctx_cands.masked_fill(ctx_cands == target.unsqueeze(1), self.padding_idx)
    negative_targets = torch.zeros_like(lprobs).scatter_(1, ctx_cands, 1)

If my understanding is correct, ctx_cands is a square matrix where each dimension is of size batch_size x sequence_len after the following statement.

ctx_cands = target.unsqueeze(0).expand(target.size(0), target.size(0))

If I assume self.padding_idx=0, what is the point of the following two statements?

ctx_cands_ = (ctx_cands.tril(-1) + self.padding_idx) 
ctx_cands_ = ctx_cands_ * ctx_cands_.triu() 

Because after the above two statements, ctx_cands_ will be a zero tensor. Isn't it?

Can you please explain how the lines of code pick the previous context tokens as negative candidates?
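
One reading of the snippet is that row t of `ctx_cands` ends up holding the targets before position t, with every other slot set to the padding index, and that `scatter_` then turns each row into a multi-hot vector over the vocabulary. The plain-Python sketch below builds that result directly instead of through `tril` and `triu`. The function names, token ids, and vocabulary size are hypothetical. Under this reading, `padding_idx = 0` makes the two statements in question produce zeros, which already equal the padding value.

```python
def prev_context_candidates(target, padding_idx):
    """Row t lists target[:t] without target[t]; all other slots are padding."""
    rows = []
    for t in range(len(target)):
        row = [tok if j < t and tok != target[t] else padding_idx
               for j, tok in enumerate(target)]
        rows.append(row)
    return rows


def multi_hot(rows, vocab_size):
    """Set a 1 at every index listed in a row (the effect of scatter_ with value 1)."""
    out = [[0] * vocab_size for _ in rows]
    for t, row in enumerate(rows):
        for idx in row:
            # The padding index is marked as well, since scatter does not skip it.
            out[t][idx] = 1
    return out


target = [4, 6, 4, 8]   # made-up token ids
rows = prev_context_candidates(target, padding_idx=1)
print(rows)             # [[1, 1, 1, 1], [4, 1, 1, 1], [1, 6, 1, 1], [4, 6, 4, 1]]
print(multi_hot(rows, vocab_size=10)[3])  # ones at indices 1, 4 and 6
```

Row 2 shows the exclusion step: token 4 appeared earlier but is also the target at t = 2, so it is replaced by padding.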

Possible bug in the token-level unlikelihood training loss

Hello, thank you for your wonderful work!

After carefully analyzing the token-level unlikelihood training loss, I think the batch-version unlikelihood training loss is different from the one defined in the paper.

In your paper, the negative candidates should be the context of the current token.

But in your code, I notice that you simply flatten all the tokens in a batch (which may consist of N samples):
https://github.com/facebookresearch/unlikelihood_training/blob/main/custom/candidate_penalty_ce_loss.py#L55

If the batch size is 1, the code is consistent with the definition in the paper. But if the batch size is larger than 1, the negative candidates of a sample i > 0 contain not only the previous tokens in sample i but also all the tokens in the previous samples j < i. Thus, in this case, the set of negative candidates is much larger.

Am I right? Looking forward to your response.

Sincerely.

Tian Lan
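
The difference described in this issue can be illustrated with plain Python sets. The sketch assumes, as the issue does, that the targets of several samples are concatenated into one flat sequence before the candidates are collected. Whether real batches contain independent samples depends on the task configuration. All names and token ids below are made up.

```python
def candidates_flat(flat_target, pos):
    """Candidates when the whole flattened batch is treated as one context."""
    return set(flat_target[:pos]) - {flat_target[pos]}


def candidates_per_sample(samples, i, t):
    """Candidates restricted to the previous tokens of sample i."""
    return set(samples[i][:t]) - {samples[i][t]}


samples = [[11, 12, 13], [21, 22, 23]]            # two samples of length 3
flat = [tok for sample in samples for tok in sample]

# Position t = 1 of sample 1 is position 3 + 1 = 4 in the flattened batch.
print(sorted(candidates_flat(flat, 4)))                 # [11, 12, 13, 21]
print(sorted(candidates_per_sample(samples, 1, 1)))     # [21]
```

Under the flattening assumption, the tokens of sample 0 are penalized while predicting sample 1. A per-sample construction avoids this.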

Dimension check error

Hi guys,

I try to track the data transformation in your methods.

During training, the original input has a shape => [new_batch_size, prefix_length] after this line:

batch = batch_input_sequence_by_prefix_length(input_sequence, args.prefix_length)

which is no longer [1, original_sequence_length].

However, in the function top_k_top_p_filtering, there is an assertion:

assert logits.size(0) == 1 # batch size 1 for now - could be updated for more but the code would be less clear

and the code can only be executed when this requirement holds.

Now I am confused by this situation: why is the assertion batch_size == 1 required? Is this a flaw in the code?
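
As an illustration that top-k and top-p filtering do not inherently need a batch size of 1, the sketch below filters each row of a batch independently. It is written in plain Python on lists of logits, and it is not the repository's `top_k_top_p_filtering`. The function names are hypothetical and the logits are made up.

```python
import math

NEG_INF = float("-inf")


def filter_row(logits, top_k=0, top_p=0.0):
    """Mask logits outside the top-k set and/or outside the top-p nucleus."""
    out = list(logits)
    if top_k > 0:
        kth = sorted(out, reverse=True)[min(top_k, len(out)) - 1]
        # Ties with the k-th value are kept, so more than k entries may survive.
        out = [v if v >= kth else NEG_INF for v in out]
    if top_p > 0.0:
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        top = out[order[0]]
        weights = [math.exp(out[i] - top) for i in order]
        total = sum(weights)
        keep, cumulative = set(), 0.0
        for i, w in zip(order, weights):
            keep.add(i)  # the token that crosses the threshold is kept
            cumulative += w / total
            if cumulative > top_p:
                break
        out = [v if i in keep else NEG_INF for i, v in enumerate(out)]
    return out


def filter_batch(batch_logits, top_k=0, top_p=0.0):
    """Apply the same filter to every row of a batch."""
    return [filter_row(row, top_k, top_p) for row in batch_logits]


batch = [[2.0, 1.0, 0.0, -1.0], [0.0, 3.0, 1.0, 2.0]]
print(filter_batch(batch, top_k=2))
# [[2.0, 1.0, -inf, -inf], [-inf, 3.0, -inf, 2.0]]
```

A tensor implementation would vectorize the same per-row logic. The assertion's own comment suggests the restriction was kept for code clarity.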

Need help to understand sequence level unlikelihood loss function.

def ul_seq(model, batch, args):

    input_sequence = batch[0].cuda()
    batch = batch_input_sequence_by_prefix_length(input_sequence, args.prefix_length)
    completions, continuation_logits = sample_sequence(model, batch, args.prefix_length, args.continuation_length, args.top_k, args.top_p)
    pred_toks = completions[:, args.prefix_length:].contiguous()

    mask = ngram_repeat_mask(pred_toks, args.sequence_ngram_n).type_as(continuation_logits)

    lprobs = F.log_softmax(continuation_logits, dim=-1)
    pred_lprobs = lprobs.view(-1, lprobs.size(2)).gather(1, pred_toks.view(-1, 1))
    one_minus_probs = torch.clamp((1.0 - pred_lprobs.exp()), min=1e-20).view(pred_toks.size(0), pred_toks.size(1))
    loss = -torch.log(one_minus_probs) * mask
    loss = loss.sum()
    ntokens = pred_toks.numel()  # number of output tokens (tokens in completions)

What exactly do the variables completions, continuation_logits, and mask represent according to the paper? Would appreciate it a lot if someone can explain how the custom loss is calculated here.
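
One reading of the quoted function: `completions` holds each prefix followed by the tokens the model generated for it, `continuation_logits` holds the model's logits at every generated position, and `mask` marks the generated tokens that belong to a repeated n-gram. The loss then sums -log(1 - p) over the marked tokens, where p is the probability the model gave to the token it generated. The plain-Python sketch below follows that reading for a single sequence. It is an illustration under these assumptions, not the repository's `ngram_repeat_mask`, and the probabilities are made up.

```python
import math


def ngram_repeat_mask(tokens, n):
    """Return 1 for every position inside an n-gram that already occurred earlier."""
    mask = [0] * len(tokens)
    seen = set()
    for j in range(len(tokens) - n + 1):
        ngram = tuple(tokens[j:j + n])
        if ngram in seen:
            for k in range(j, j + n):
                mask[k] = 1
        seen.add(ngram)
    return mask


def seq_unlikelihood_loss(token_probs, mask):
    """Sum of -log(1 - p) over masked positions, with 1 - p clamped away from 0."""
    return sum(-math.log(max(1.0 - p, 1e-20)) * m
               for p, m in zip(token_probs, mask))


pred_toks = [1, 2, 3, 1, 2, 3, 4]      # made-up generated token ids
mask = ngram_repeat_mask(pred_toks, 2)
print(mask)                            # [0, 0, 0, 1, 1, 1, 0]

token_probs = [0.9] * len(pred_toks)   # made-up probabilities of the generated tokens
print(seq_unlikelihood_loss(token_probs, mask))  # three penalties of -log(0.1)
```

Tokens outside repeated n-grams contribute nothing, so a completion without repeats yields a loss of zero.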
