
Noise Contrastive Alignment of Language Models with Explicit Rewards


This repo contains training scripts used in

Noise Contrastive Alignment of Language Models with Explicit Rewards
Huayu Chen, Guande He, Lifan Yuan, Ganqu Cui, Hang Su, and Jun Zhu
Tsinghua

We enable aligning a pretrained language model with datasets annotated by explicit rewards, rather than only binary preferences, by introducing Noise Contrastive Alignment (Figure 1). The framework includes two general algorithms, NCA and InfoNCA, that handle both preference data and reward data. Notably, InfoNCA includes the DPO loss as a special case in the binary preference setting. Compared with DPO/InfoNCA, the main advantage of NCA is that it effectively prevents the likelihood of the chosen response from decreasing, a phenomenon commonly observed when applying the DPO/InfoNCA loss (Figure 2).
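The relationship between the three losses can be illustrated with a minimal, dependency-free sketch. It is not the code used in this repo: the function names are hypothetical, and the loss forms (a softmax over rewards with temperature `alpha` as the target, and `beta * (log pi - log mu)` as the implicit reward) are written from a reading of the paper, so treat them as assumptions. Inputs are per-response sequence log-probabilities under the policy and the reference model.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(xs):
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def implicit_rewards(policy_logps, ref_logps, beta):
    # Assumed implicit reward: beta * (log pi_theta(y|x) - log mu(y|x)).
    return [beta * (p - r) for p, r in zip(policy_logps, ref_logps)]

def infonca_loss(policy_logps, ref_logps, rewards, beta=0.01, alpha=1.0):
    # Cross-entropy between softmax(rewards / alpha) and softmax(implicit rewards).
    targets = softmax([r / alpha for r in rewards])
    log_q = log_softmax(implicit_rewards(policy_logps, ref_logps, beta))
    return -sum(t * lq for t, lq in zip(targets, log_q))

def nca_loss(policy_logps, ref_logps, rewards, beta=0.01, alpha=1.0):
    # Each response is scored on its own with a sigmoid instead of a softmax
    # across responses; the 1/K term pushes every implicit reward down.
    k = len(rewards)
    targets = softmax([r / alpha for r in rewards])
    r_theta = implicit_rewards(policy_logps, ref_logps, beta)
    return -sum(t * log_sigmoid(r) + log_sigmoid(-r) / k
                for t, r in zip(targets, r_theta))

def dpo_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp, beta=0.01):
    # Standard sigmoid DPO loss on the implicit reward margin.
    margin = beta * ((chosen_logp - ref_chosen_logp) - (rejected_logp - ref_rejected_logp))
    return -log_sigmoid(margin)

# Two responses with a large reward gap and a small alpha give a one-hot
# target, so the InfoNCA value coincides with the DPO value in this sketch.
pair_infonca = infonca_loss([-12.0, -15.0], [-13.0, -14.0], [10.0, 0.0], beta=0.1, alpha=0.01)
pair_dpo = dpo_loss(-12.0, -15.0, -13.0, -14.0, beta=0.1)
```

In this sketch the softmax in `infonca_loss` depends only on differences between implicit rewards, so the chosen likelihood can fall as long as the rejected one falls faster. `nca_loss` scores each response's implicit reward in absolute terms, which is consistent with the behavior described above.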

In this repo, we release:

  • The training scripts of NCA/InfoNCA for aligning the Mistral-7B model using the UltraFeedback dataset.
  • Pretrained model weights.

Update

  • [2024.06] Dataset and training code are released.
  • [2024.05] The pairwise preference version of NCA is now supported by the trl library.
  • [2024.04] The NCA algorithm helps power the Eurus-70B and Eurus-8*7B models, demonstrating significant advantages over the DPO algorithm in complex reasoning tasks. Eurus-70B outperformed GPT-3.5-Turbo in a comprehensive benchmark across 12 tests covering five different tasks.
  • [2024.03] Pretrained model weights are released.

Getting Started

Set up environments

cd alignment-handbook; pip install -e .

and

cd trl; pip install -e .

Train

Before running, determine how many training devices are available and adjust gradient_accumulation_steps to reach an appropriate global batch size. By default, we use 8 A40 GPUs and a global batch size of 32.
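The arithmetic behind this adjustment can be sketched as follows. The helper name is hypothetical, and the per-device batch size of 1 is an assumption; it is inferred from the defaults (8 devices, 4 accumulation steps in the commands below, global batch size 32) rather than stated in this README.

```python
def gradient_accumulation_steps(global_batch_size, num_devices, per_device_batch_size=1):
    """Return the accumulation steps needed to reach the target global batch size.

    Assumes: global = num_devices * per_device_batch_size * accumulation_steps.
    per_device_batch_size=1 is an assumption, not a value taken from the repo configs.
    """
    per_step = num_devices * per_device_batch_size
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_devices} devices x {per_device_batch_size} samples per device"
        )
    return global_batch_size // per_step

default_steps = gradient_accumulation_steps(32, num_devices=8)   # 4, as in the commands below
four_gpu_steps = gradient_accumulation_steps(32, num_devices=4)  # 8
```

When running on fewer GPUs, pass the resulting value via --gradient_accumulation_steps and set --num_processes to the device count.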

For aligning with reward datasets, run the following, setting --loss_type to either NCA or InfoNCA:

NCCL_P2P_DISABLE=1 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file alignment-handbook/recipes/accelerate_configs/multi_gpu.yaml --num_processes=8 --main_process_port=7000 run_reward.py yamls/reward_qlora.yaml --gradient_accumulation_steps=4 --beta=0.01 --loss_type=[NCA/InfoNCA] --output_dir=data/test_run

For aligning with preference datasets (e.g., Binarized UltraFeedback), run the following, setting --loss_type to either NCA or DPO:

NCCL_P2P_DISABLE=1 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file alignment-handbook/recipes/accelerate_configs/multi_gpu.yaml --num_processes=8 --main_process_port=7000 run_preference.py yamls/preference_qlora.yaml --gradient_accumulation_steps=4 --beta=0.01 --loss_type=[NCA/DPO] --output_dir=data/test_run

Evaluation

See the alignment-handbook instructions for evaluating models on MT-bench and AlpacaEval.

License

MIT

BibTeX

@article{chen2024noise,
  title={Noise contrastive alignment of language models with explicit rewards},
  author={Chen, Huayu and He, Guande and Yuan, Lifan and Cui, Ganqu and Su, Hang and Zhu, Jun},
  journal={arXiv preprint arXiv:2402.05369},
  year={2024}
}

noise-contrastive-alignment's Issues

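On the derivation of DPO as a special case of InfoNCA

The $\alpha$ in this derivation is not simply a reward temperature; it comes from the coefficient of the KL penalty. Setting this value to 0 in the derivation would remove the KL penalty entirely, which would completely change the meaning compared with DPO, wouldn't it?
This is a question I have about this part of the paper. I hope the authors can answer it. Thank you!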
