mindrlhf's Introduction

Introduction

OpenAI's ChatGPT has demonstrated astonishing natural language processing capabilities, opening the door to general-purpose artificial intelligence. Its exceptional performance is closely tied to the Reinforcement Learning from Human Feedback (RLHF) algorithm. Its predecessor, InstructGPT, used RLHF to collect human feedback and generate content that better aligns with human cognition and values, compensating for potential cognitive biases in large models.

MindSpore RLHF (MindRLHF) is built on MindSpore and uses the framework's capabilities for large-model parallel training, inference, and deployment to help customers quickly train and deploy RLHF pipelines for models with billions or trillions of parameters.

The MindRLHF learning process consists of three stages:

  • Stage 1: Supervised fine-tuning.
  • Stage 2: Reward model training.
  • Stage 3: Reinforcement learning training.
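As an illustration of Stage 2, reward models in the InstructGPT line of work are commonly trained with a pairwise ranking loss: the model should score the response a human preferred above the one they rejected. The sketch below is a minimal pure-Python version of that loss. It is an assumption about the general technique, not MindRLHF's actual implementation, and the function names are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_ranking_loss(chosen_rewards, rejected_rewards):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over all comparison pairs.

    Hypothetical helper for illustration; each list holds one scalar reward
    per human comparison, in matching order.
    """
    if len(chosen_rewards) != len(rejected_rewards) or not chosen_rewards:
        raise ValueError("need two non-empty reward lists of equal length")
    losses = [
        softplus(-(chosen - rejected))  # equals -log(sigmoid(chosen - rejected))
        for chosen, rejected in zip(chosen_rewards, rejected_rewards)
    ]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # The loss shrinks as the preferred response is scored further above the rejected one.
    print(pairwise_ranking_loss([0.0], [0.0]))
    print(pairwise_ranking_loss([5.0], [0.0]))
```

When both responses receive the same reward the loss equals log 2, and it approaches zero as the margin between the chosen and the rejected reward grows.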

MindRLHF integrates the rich model library of MindFormers, providing fine-tuning processes for base models such as Pangu-Alpha (2.6B, 13B) and GPT-2.

Fully inheriting the parallel interface of MindSpore, MindRLHF can easily deploy models to the training cluster with just one click, enabling training and inference of large models.

To improve inference performance, MindRLHF integrates incremental inference, also known as K-V cache or state reuse, which can improve inference performance by more than 30% compared to full inference.
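The idea behind incremental inference can be sketched with a toy single-head attention layer in pure Python. Full inference re-projects the keys and values of the whole prefix at every generation step, while the incremental path projects only the newest token and reuses the cached results. This is a simplified illustration under stated assumptions (no query projection, no batching, hypothetical function names), not MindRLHF's implementation, and it makes no claim about the actual speedup.

```python
import math


def project(vector, matrix):
    """Multiply a row vector by a matrix stored as a list of rows."""
    return [sum(x * row[j] for x, row in zip(vector, matrix))
            for j in range(len(matrix[0]))]


def attend(query, keys, values):
    """Scaled dot-product attention of one query over all keys and values."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for a stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def full_inference(tokens, w_k, w_v):
    """Recompute keys and values for the whole prefix at every step."""
    outputs, projections = [], 0
    for step in range(1, len(tokens) + 1):
        prefix = tokens[:step]
        keys = [project(t, w_k) for t in prefix]
        values = [project(t, w_v) for t in prefix]
        projections += 2 * step
        outputs.append(attend(prefix[-1], keys, values))
    return outputs, projections


def incremental_inference(tokens, w_k, w_v):
    """Project only the newest token; reuse cached keys and values."""
    outputs, projections = [], 0
    key_cache, value_cache = [], []
    for token in tokens:
        key_cache.append(project(token, w_k))
        value_cache.append(project(token, w_v))
        projections += 2
        outputs.append(attend(token, key_cache, value_cache))
    return outputs, projections


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
    w_k = [[0.5, 0.1], [-0.2, 0.3]]
    w_v = [[1.0, 0.0], [0.0, 1.0]]
    print(full_inference(tokens, w_k, w_v)[1], incremental_inference(tokens, w_k, w_v)[1])
```

Both paths produce the same outputs. For a sequence of n tokens the full path performs n(n+1) key and value projections, while the cached path performs 2n, at the cost of keeping the cache in memory.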

The MindRLHF architecture diagram is as follows:

[Figure: MindRLHF framework]

Installation

Current version 0.3.0 can be used directly.

MindRLHF has the following requirements:

| Requirement | Version |
|-------------|---------|
| MindSpore   | r2.2    |
| MindFormers | r1.0    |
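A quick way to check an environment against these requirements is to compare the installed package versions with the required release branches. The sketch below assumes that the pip distribution names are `mindspore` and `mindformers` and that a branch such as r2.2 corresponds to installed versions of the form 2.2.x. Both are assumptions, and the helper names are hypothetical.

```python
from importlib import metadata

# Branch names come from the requirements above; the distribution names are assumptions.
REQUIRED_BRANCHES = {"mindspore": "r2.2", "mindformers": "r1.0"}


def matches_branch(installed_version, branch):
    """Return True if a version such as "2.2.10" belongs to a branch such as "r2.2"."""
    wanted = branch.lstrip("r").split(".")
    return installed_version.split(".")[:len(wanted)] == wanted


def check_requirements(required=REQUIRED_BRANCHES):
    """Map each package to "ok", "mismatch (found X)", or "not installed"."""
    report = {}
    for package, branch in required.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = "not installed"
            continue
        if matches_branch(installed, branch):
            report[package] = "ok"
        else:
            report[package] = f"mismatch (found {installed})"
    return report


if __name__ == "__main__":
    for package, status in check_requirements().items():
        print(f"{package}: {status}")
```

The comparison is done component by component, so a version such as 2.20.0 is not mistaken for the 2.2 branch.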

Supported Models

Current version of MindRLHF: 0.3.0

The current version integrates the Pangu-alpha (13B), GPT2, and Baichuan2 (7B/13B) models, and users can explore these models. In the future, we will provide more models such as LLAMA, BLOOM, and GLM to help users quickly implement their own applications. The specific supported list is shown below:

Table 1: The models and scales supported in MindRLHF

| Models   | Pangu-alpha | GPT2 | Baichuan2 |
|----------|-------------|------|-----------|
| Scales   | 2.6B/13B    | 124M | 7B/13B    |
| Parallel | Y           | Y    | Y         |
| Device   | NPU         | NPU  | NPU       |

The support of models for different training stages is shown in the following table:

Table 2: The models and stages supported in MindRLHF

| Stages | Pangu-alpha | GPT2 | Baichuan2 |
|--------|-------------|------|-----------|
| SFT    | Y           | Y    | Y         |
| RM     | Y           | Y    | Y         |
| RLHF   | Y           | Y    | Y         |

In the future, we will integrate more models such as LLAMA, GLM, BLOOM, etc.

Get Started

  • Reward model training: a GPT2-based reward model training tutorial is available in 'examples'.

  • RLHF fine-tuning: here is an example for RLHF fine-tuning in MindRLHF:

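# `D` is assumed here to be MindSpore's communication module. PPOTrainer and the
# helpers init_configs, init_network_and_optimizer, and init_ppo_dataset come from
# MindRLHF; `args` holds the parsed command-line arguments of the training script.
import mindspore.communication.management as D

# Build the PPO config and the configs of the SFT (policy), reference, critic, and reward models.
ppo_config, sft_model_config, ref_model_config, critic_model_config, rm_model_config = init_configs(args)
trainer = PPOTrainer(ppo_config=ppo_config, sft_model_config=sft_model_config,
                     ref_model_config=ref_model_config, critic_model_config=critic_model_config,
                     rm_model_config=rm_model_config)
# Wrap the PPO network together with its optimizer for training.
ppo_with_grad = init_network_and_optimizer(trainer)
rank_id = D.get_rank()
for epoch in range(ppo_config.epochs):
    # Sampling: generate rollouts with the current policy and score them.
    trainer.make_experience(num_rollouts=ppo_config.num_rollouts)
    dataset = init_ppo_dataset(trainer)
    # Train on the collected experience; data sink mode accelerates training.
    trainer.train(ppo_with_grad, dataset, epoch)
    # Save one checkpoint per device rank and epoch.
    trainer.save_checkpoint(rank_id, epoch)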
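For readers unfamiliar with what a PPO training step optimizes, the following sketch shows the standard clipped surrogate loss from the PPO algorithm in pure Python. It is an illustration of the general technique only: the function name, the per-token list inputs, and the default `clip_range` are assumptions, and the loss used inside `PPOTrainer` may differ (for example by adding value-function or KL terms).

```python
import math


def clipped_ppo_loss(new_logprobs, old_logprobs, advantages, clip_range=0.2):
    """Mean clipped surrogate policy loss over a batch of sampled tokens.

    Hypothetical helper for illustration. `old_logprobs` are the log-probabilities
    recorded while sampling, `new_logprobs` those of the policy being updated.
    """
    if not (len(new_logprobs) == len(old_logprobs) == len(advantages)):
        raise ValueError("all inputs must have the same length")
    losses = []
    for new_lp, old_lp, advantage in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # probability ratio pi_new / pi_old
        clipped = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
        # Take the more pessimistic of the clipped and unclipped objectives,
        # then negate it because optimizers minimize.
        losses.append(-min(ratio * advantage, clipped * advantage))
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # The policy doubled the probability of a token with positive advantage;
    # clipping caps the credit it receives at a ratio of 1 + clip_range.
    print(clipped_ppo_loss([math.log(2.0)], [0.0], [1.0]))
```

Clipping limits how far a single update can move the policy away from the one that generated the experience, which is why the loop above collects fresh rollouts at the start of every epoch.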

Contribution

Welcome to the community. You can refer to the MindSpore contribution requirements on the Contributor Wiki.

License

Apache 2.0 License.

mindrlhf's People

Contributors

chessqian, mashirochen, okbaguo, kfertakis, kingcong
