Resolving Interference When Merging Models (NeurIPS 2023)

Setup

  1. Create a virtual environment and activate it.
python3 -m venv env
source env/bin/activate
  2. Install the dependencies.
python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
  3. Download the Story Cloze dataset and update its path in the StoryClozeReader class in data/dataset_readers.py.

  4. Set the path to the directory where the finetuned models are stored in utils/merge_utils.py.

We have released the IA3 checkpoints here!

Train

Train T5 Models

python src/training.py -c configs/t5_base.json -k train_batch_size=8 gradient_accumulation_factor=1 project_name=training experiment_name=test train_dataset=rte train_dataset_mixture=None num_batches=2
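The -k (and, in the commands below, --kwargs) arguments pass key=value pairs that appear to override fields of the JSON config given with -c. The snippet below is a hypothetical illustration of that kind of override, not the repository's actual argument parser: apply_overrides is an invented name, and it assumes each value is decoded as JSON when possible and otherwise kept as a plain string.

```python
import json


def apply_overrides(config, overrides):
    """Return a copy of `config` with each "key=value" string applied.

    HYPOTHETICAL: values are parsed as JSON when possible ("8" -> 8,
    "true" -> True); anything else, such as "rte", stays a string.
    """
    updated = dict(config)  # leave the caller's config untouched
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            value = raw
        updated[key] = value
    return updated
```

Under this assumption a value such as None would remain the string "None"; how the real scripts treat it is not covered here.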

Evaluation

Evaluate IA3 across multiple prompts and report the median.

path_to_checkpoint=/path/to/checkpoint  # replace with the path to your checkpoint
eval_split=validation
dataset=rte

python ./src/inference.py -c configs/ia3_base.json --multiple_prompts -i ${dataset} --kwargs checkpoint_to_directly_load_model=${path_to_checkpoint} split=${eval_split} project_name=ia3 experiment_name=${dataset}
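With --multiple_prompts, IA3 is evaluated under several prompt templates and the median is reported. A minimal sketch of that final aggregation step, assuming the per-prompt accuracies for one dataset have already been collected into a dict; median_over_prompts is a hypothetical helper, not part of the repository.

```python
from statistics import median


def median_over_prompts(scores_by_prompt):
    """Collapse per-prompt scores for one dataset into a single median score.

    scores_by_prompt: dict mapping a prompt-template name to its accuracy
    (a hypothetical format chosen for this illustration).
    """
    if not scores_by_prompt:
        raise ValueError("need at least one prompt score")
    return median(scores_by_prompt.values())
```

The median is less sensitive than the mean to a single prompt template that scores unusually well or badly.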

Evaluate T5-Large.

path_to_checkpoint=/path/to/checkpoint  # replace with the path to your checkpoint
eval_split=validation
dataset=rte

python ./src/inference.py -c configs/t5_large.json -i ${dataset} --kwargs checkpoint_to_directly_load_model=${path_to_checkpoint} split=${eval_split} project_name=t5-large experiment_name=${dataset}

Merging Models

T5-Large

Basic Averaging

eval_split=validation

python ./src/ties_merging.py -c configs/t5_large.json -i t5_mixture -m t5_mixture -f basic_mean --kwargs split=${eval_split} project_name=t5-large experiment_name=mean
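Basic averaging (-f basic_mean) merges the finetuned models by averaging each parameter across them. The sketch below shows that idea on plain Python lists; the dict-of-lists parameter format and the basic_mean function are illustrative assumptions, whereas the real script operates on model checkpoints.

```python
def basic_mean(state_dicts):
    """Average several finetuned models parameter by parameter.

    state_dicts: list of {parameter_name: [floats]} dicts sharing the same keys
    and shapes (a simplified stand-in for real checkpoints).
    """
    if not state_dicts:
        raise ValueError("need at least one model")
    n = len(state_dicts)
    merged = {}
    for name in state_dicts[0]:
        # Group the j-th entry of this parameter across all models, then average.
        columns = zip(*(sd[name] for sd in state_dicts))
        merged[name] = [sum(vals) / n for vals in columns]
    return merged
```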

Task Vectors

eval_split=validation
eval_function=task-vector_linear+0.1+1.01+0.1

python ./src/ties_merging.py -c configs/t5_large.json -i t5_mixture -m t5_mixture -f ${eval_function} --kwargs split=${eval_split} project_name=t5-large experiment_name=task_vectors

This performs merging for different values of lambda, trying all lambda values between 0 and 1 in increments of 0.1.
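In task arithmetic, each task vector is a finetuned model minus the pretrained model, and the merge adds their scaled sum back to the pretrained weights: theta = theta_pre + lambda * sum_t tau_t. The sketch below illustrates that formula together with a lambda sweep. The function names and list-based parameters are hypothetical, and reading linear+0.1+1.01+0.1 as start 0.1, exclusive stop 1.01, step 0.1 is an assumption inferred from the values, not taken from the repository's code.

```python
def lambda_grid(spec):
    """Expand a "linear+start+stop+step" string into a list of lambda values.

    ASSUMPTION: stop is exclusive, so "linear+0.1+1.01+0.1" yields 0.1, ..., 1.0.
    """
    kind, start, stop, step = spec.split("+")
    if kind != "linear":
        raise ValueError(f"unsupported schedule: {kind!r}")
    start, stop, step = float(start), float(stop), float(step)
    if step <= 0:
        raise ValueError("step must be positive")
    values, i = [], 0
    while start + i * step < stop:
        # Round to hide floating-point drift such as 0.30000000000000004.
        values.append(round(start + i * step, 10))
        i += 1
    return values


def task_vector_merge(pretrained, finetuned_models, lam):
    """theta = theta_pre + lam * sum_t (theta_t - theta_pre), entry by entry."""
    return {
        name: [
            b + lam * sum(ft[name][j] - b for ft in finetuned_models)
            for j, b in enumerate(base)
        ]
        for name, base in pretrained.items()
    }
```

To sweep, strip the method prefix first, e.g. lambda_grid("task-vector_linear+0.1+1.01+0.1".split("_", 1)[1]), then merge once per value and evaluate each candidate on the chosen split.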

TIES MERGING

eval_split=validation
redundant=topk20
elect=mass
agg=dis-mean
scale=linear+0.8+2.51+0.1

python ./src/ties_merging.py -c configs/t5_large.json -i t5_mixture -m t5_mixture -f ${redundant}_${elect}_${agg}_${scale} --kwargs split=${eval_split} project_name=t5-large experiment_name=ties
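TIES-Merging combines task vectors in three steps (trim, elect sign, disjoint merge) before scaling. With the variables above, the method string is topk20_mass_dis-mean_linear+0.8+2.51+0.1. We read topk20 as keeping the top 20% of each task vector's entries by magnitude, mass as electing each parameter's sign from the sum of the trimmed values, dis-mean as averaging only the values that agree with that sign, and the final part as a lambda sweep from 0.8 to 2.5. The sketch below is a plain-Python illustration under those readings, not the repository's implementation; ties_merge and the flat-list parameter format are hypothetical.

```python
def ties_merge(pretrained, finetuned_models, k=0.2, lam=1.0):
    """Illustrative TIES-Merging on dicts of flat float lists.

    pretrained: {name: [floats]}; finetuned_models: list of dicts with the same keys.
    k: fraction of entries kept per task vector (0.2 for "topk20", by assumption).
    lam: scaling factor applied to the merged task vector.
    """
    names = list(pretrained)
    # Task vectors tau_t = theta_t - theta_pre, flattened over all parameters.
    taus = [
        [ft[n][j] - pretrained[n][j] for n in names for j in range(len(pretrained[n]))]
        for ft in finetuned_models
    ]
    size = len(taus[0])
    keep = max(1, int(round(k * size)))

    # 1) Trim: zero out all but the top-k fraction of entries by magnitude.
    trimmed = []
    for tau in taus:
        kept = set(sorted(range(size), key=lambda i: abs(tau[i]), reverse=True)[:keep])
        trimmed.append([v if i in kept else 0.0 for i, v in enumerate(tau)])

    merged_tau = []
    for i in range(size):
        vals = [t[i] for t in trimmed]
        # 2) Elect: the sign of the total mass (sum) of the trimmed values.
        total = sum(vals)
        sign = (total > 0) - (total < 0)
        # 3) Disjoint mean: average only nonzero values that share the elected sign.
        agree = [v for v in vals if v * sign > 0]
        merged_tau.append(sum(agree) / len(agree) if agree else 0.0)

    # 4) Scale and add back: theta = theta_pre + lam * tau_merged.
    merged, pos = {}, 0
    for n in names:
        width = len(pretrained[n])
        merged[n] = [b + lam * merged_tau[pos + j] for j, b in enumerate(pretrained[n])]
        pos += width
    return merged
```

Where two task vectors disagree in sign and both survive trimming, for example -2.0 and 1.5, plain averaging would give -0.25, while this sketch keeps only the elected (negative) value, -2.0, before scaling.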

IA3

Basic Averaging

eval_split=validation

python ./src/ties_merging.py -c configs/ia3_base.json -i T0_held_out -m T0_held_out -f basic_mean --multiple_prompts --kwargs pretrained_model=bigscience/T0_3B split=${eval_split} project_name=ia3 experiment_name=mean

Task Vectors

eval_split=validation
eval_function=task-vector_linear+0.1+1.01+0.1

python ./src/ties_merging.py -c configs/ia3_base.json -i T0_held_out -m T0_held_out -f ${eval_function} --multiple_prompts --kwargs pretrained_model=bigscience/T0_3B split=${eval_split} project_name=ia3 experiment_name=task_vectors

TIES MERGING

eval_split=validation
redundant=topk20
elect=mass
agg=dis-mean
scale=linear+0.8+2.51+0.1

python ./src/ties_merging.py -c configs/ia3_base.json -i T0_held_out -m T0_held_out -f ${redundant}_${elect}_${agg}_${scale} --multiple_prompts --kwargs pretrained_model=bigscience/T0_3B split=${eval_split} project_name=ia3 experiment_name=ties

Reference

Please cite our paper if you use our models in your work:

@inproceedings{
      yadav2023tiesmerging,
      title={{TIES}-Merging: Resolving Interference When Merging Models},
      author={Prateek Yadav and Derek Tam and Leshem Choshen and Colin Raffel and Mohit Bansal},
      booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
      year={2023},
      url={https://openreview.net/forum?id=xtaX3WyCj1}
}

Contributors

prateeky2806, kim-jake
