Parallel Context Windows (PCW)

Omri Notes

This repo is a rough fork of the AI21 repo, updated to work with Llama 2. The goal is to test different variations of PCW (with top-k, etc.) on long-text problems.

The original README follows.

This repo contains the code for reproducing the classification experiments from AI21 Labs' paper Parallel Context Windows for Large Language Models.
The code was tested with Python 3.10 on CPU, single-GPU, and multi-GPU runs. Currently, it supports the GPT2 and LLaMA model families.

Setup

To install the required libraries, run:

pip install -r requirements.txt

To use a PyTorch build that matches your CUDA version, install it before running the command above.

Evaluation

Because the paper's results were based on an earlier implementation of PCW rather than on HuggingFace Transformers, the results produced by this code may differ slightly from those in the paper. To reproduce results similar to those shown in the appendix for GPT2-XL on a specific dataset (for example, SST2), run:

python run_evaluation.py \
--dataset sst2 \
--model gpt2-xl \
--n-windows 1 \
--n-windows 3 \
--subsample-test-set 250 \
--n-runs 30 \
--output-dir $OUTPUT_DIR

In this run, PCW's performance is evaluated on a subsample (250 samples) of the full test set. The experiment is repeated 30 times (each time with a different random sample of training examples) for each number of windows (in this case, one and three). By default, the script uses as many examples per window as possible. Note that using a single window is equivalent to the regular ICL setting. This run should therefore give results similar to those shown in Table 5 for SST2 with GPT2-XL.
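To make the "single window equals regular ICL" remark concrete, here is a minimal NumPy sketch of the PCW attention pattern as the paper describes it. It is not this repo's implementation, which works through HuggingFace models and caching. Each context window attends causally only to itself and reuses the same position ids. The task tokens attend to every context token and causally to each other. The function name pcw_mask_and_positions is hypothetical. Starting the task positions right after the longest window is an assumption for windows of unequal length.

```python
import numpy as np

def pcw_mask_and_positions(window_lengths, task_length):
    """Hypothetical helper: build a PCW-style attention mask and position ids.

    mask[i, j] is True when token i may attend to token j.
    Token order: all window tokens (window by window), then the task tokens.
    """
    total_ctx = sum(window_lengths)
    n = total_ctx + task_length
    mask = np.zeros((n, n), dtype=bool)
    positions = np.zeros(n, dtype=int)

    start = 0
    for length in window_lengths:
        end = start + length
        # Causal attention restricted to this window only (no cross-window attention).
        mask[start:end, start:end] = np.tril(np.ones((length, length), dtype=bool))
        # Every window reuses the same positions 0..length-1.
        positions[start:end] = np.arange(length)
        start = end

    # Task tokens see every context token in every window ...
    mask[total_ctx:, :total_ctx] = True
    # ... and attend causally among themselves.
    mask[total_ctx:, total_ctx:] = np.tril(np.ones((task_length, task_length), dtype=bool))
    # Assumption: task positions continue right after the longest window.
    positions[total_ctx:] = max(window_lengths, default=0) + np.arange(task_length)
    return mask, positions
```

With a single window, for example pcw_mask_and_positions([3], 2), the mask is the ordinary lower-triangular causal mask and the positions are simply 0..4. That is the standard ICL setup. With two or more windows, tokens in later windows cannot see earlier windows, and all windows share positions.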

The evaluation output is a NumPy file (shape [2, 30]) in $OUTPUT_DIR containing the mean accuracy for each repetition and each number of windows. You can read the file directly with np.load, or use the functions in utils.py to load and plot the results. See --help for further instructions.
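If you prefer np.load over utils.py, a small summary like the one below may help. This is a sketch. It assumes that rows correspond to the --n-windows values in the order they were given and that columns correspond to repetitions, which is consistent with the [2, 30] shape. The helper name summarize and the file name are hypothetical, and the array saved in the example is placeholder data, not real results.

```python
import os
import tempfile

import numpy as np

def summarize(results, n_windows_list):
    """Hypothetical helper: map each window count to (mean, std) over repetitions.

    Assumes results[i, r] is the accuracy for n_windows_list[i] in repetition r.
    Uses the population standard deviation (NumPy's default ddof=0).
    """
    results = np.asarray(results, dtype=float)
    if results.shape[0] != len(n_windows_list):
        raise ValueError("one row per --n-windows value is expected")
    return {n: (float(row.mean()), float(row.std())) for n, row in zip(n_windows_list, results)}

if __name__ == "__main__":
    # Placeholder data only (NOT real results): 2 window settings x 30 runs.
    fake = np.vstack([np.full(30, 0.80), np.full(30, 0.85)])
    path = os.path.join(tempfile.mkdtemp(), "results.npy")  # hypothetical file name
    np.save(path, fake)
    print(summarize(np.load(path), [1, 3]))
```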

PCW Usage examples

The evaluation code performs only classification tasks. The code snippet below shows how PCW can be used for both classification and generation:

import numpy as np

from model_loaders import load_pcw_wrapper
from logits_processor import RestrictiveTokensLogitsProcessor

from utils import encode_labels

wrapper = load_pcw_wrapper('gpt2-large', n_windows=2)

# use PCW with few shot for classification example:
labels_input_ids = np.array(encode_labels(wrapper.tokenizer, ['positive', 'negative']))
# using RestrictiveTokensLogitsProcessor forces the output to be one of the labels:
logit_processor = RestrictiveTokensLogitsProcessor(labels_input_ids, eos_token_id=wrapper.tokenizer.eos_token_id)
output = wrapper.pcw_generate(contexts=["Review: Great movie! Sentiment: positive\n",
                                        "Review: Horrible film Sentiment: negative\n"],
                              task_text="Review: I liked it Sentiment:",
                              restrictive_logit_preprocessor=logit_processor,
                              temperature=0,
                              max_new_tokens=1)
print(output.strip())
# use PCW for generation:
output = wrapper.pcw_generate(contexts=["Review: Great movie!\n", "Review: Horrible film\n"],
                              task_text="Review:",
                              temperature=1,
                              do_sample=True,
                              max_new_tokens=16)
print(output)
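The classification example depends on RestrictiveTokensLogitsProcessor to force the output to be one of the labels. The core idea of such a processor can be sketched as masking every vocabulary entry except the allowed label tokens before the next token is chosen. This is not the repo's class. It is a hypothetical NumPy helper that covers a single decoding step in which each label is one token. Multi-token labels, and how the repo's class uses eos_token_id, are not modeled here.

```python
import numpy as np

def restrict_to_tokens(logits, allowed_ids):
    """Hypothetical helper: keep logits only for allowed token ids.

    logits: array of shape [vocab_size] or [batch, vocab_size].
    Every disallowed position is set to -inf, so argmax or sampling
    can only ever pick one of the allowed ids.
    """
    logits = np.asarray(logits, dtype=float)
    masked = np.full_like(logits, -np.inf)
    allowed_ids = np.asarray(allowed_ids, dtype=int)
    masked[..., allowed_ids] = logits[..., allowed_ids]
    return masked

# Greedy choice restricted to ids 0 and 3: id 1 has the highest raw logit,
# but it is not allowed, so id 3 (logit 3.0 > 2.0) wins.
print(int(np.argmax(restrict_to_tokens([2.0, 5.0, 1.0, 3.0], [0, 3]))))
```

With temperature=0 and max_new_tokens=1, as in the classification call above, this amounts to picking the label token with the highest score.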

Citation

If you find our paper or code helpful, please consider citing our paper:

@misc{ratner2023parallel,
      title={Parallel Context Windows for Large Language Models}, 
      author={Nir Ratner and Yoav Levine and Yonatan Belinkov and Ori Ram and Inbal Magar and Omri Abend and Ehud Karpas and Amnon Shashua and Kevin Leyton-Brown and Yoav Shoham},
      year={2023},
      eprint={2212.10947},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Contributors

omri123, nirrai21
