
DataInf: Efficiently Estimating Data Influence in LoRA-tuned LLMs and Diffusion Models

We provide a codebase for "DataInf: Efficiently Estimating Data Influence in LoRA-tuned LLMs and Diffusion Models" accepted at ICLR 2024. DataInf is an efficient influence approximation method that is practical for large-scale generative AI models such as LLMs or stable diffusion models. DataInf leverages an easy-to-compute closed-form expression, outperforming existing influence computation algorithms in terms of computational and memory efficiency.
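
For readers who want the formula behind the closed-form claim, here is our transcription of the approximation, written with per-layer LoRA gradients. The notation is our own assumption: ∇_{θ_l}ℓ_i is the gradient of the training loss of example i with respect to the LoRA parameters of layer l (dimension d_l), v_l is the validation-loss gradient averaged over the validation set, n is the number of training points, and λ_l > 0 is a damping term. The approximation replaces the inverse of an average with an average of inverses, after which each rank-one term is inverted exactly with the Sherman-Morrison formula. Please treat the paper as authoritative for the exact definitions, including how λ_l is chosen.

```latex
% Influence of training point k on the validation loss, summed over LoRA layers l
\mathcal{I}(x_k) = -\sum_{l=1}^{L} v_l^\top \left(G_l + \lambda_l I_{d_l}\right)^{-1} \nabla_{\theta_l}\ell_k,
\qquad G_l = \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta_l}\ell_i \, \nabla_{\theta_l}\ell_i^\top

% Average of inverses instead of inverse of the average; Sherman-Morrison per term
\left(G_l + \lambda_l I_{d_l}\right)^{-1}
\approx \frac{1}{n}\sum_{i=1}^{n}\left(\nabla_{\theta_l}\ell_i \, \nabla_{\theta_l}\ell_i^\top + \lambda_l I_{d_l}\right)^{-1}
= \frac{1}{n\lambda_l}\sum_{i=1}^{n}\left(I_{d_l} - \frac{\nabla_{\theta_l}\ell_i \, \nabla_{\theta_l}\ell_i^\top}{\lambda_l + \nabla_{\theta_l}\ell_i^\top \nabla_{\theta_l}\ell_i}\right)

% Closed form: only inner products between gradients are required
\mathcal{I}_{\mathrm{DataInf}}(x_k) = \sum_{l=1}^{L} \frac{1}{n\lambda_l}\sum_{i=1}^{n}
\left(\frac{\left(v_l^\top \nabla_{\theta_l}\ell_i\right)\left(\nabla_{\theta_l}\ell_i^\top \nabla_{\theta_l}\ell_k\right)}{\lambda_l + \nabla_{\theta_l}\ell_i^\top \nabla_{\theta_l}\ell_i} - v_l^\top \nabla_{\theta_l}\ell_k\right)
```

Under the usual influence-function convention, a large positive value means that upweighting x_k would increase the validation loss, i.e. x_k looks harmful.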

Quick start

(Task 1) Mislabeled data detection

An easy-to-start Jupyter notebook notebokes/Mislabeled_Data_Detection-RoBERTa-MRPC.ipynb demonstrates how to compute influence function values and how to use them to detect mislabeled data points.

  • We use the RoBERTa-large model with LoRA, a parameter-efficient fine-tuning technique, to significantly reduce the number of trainable parameters.
  • We consider a noisy version of the GLUE-MRPC dataset, in which mislabeled data points are synthetically generated by flipping the labels of a randomly selected 20% of the data points.
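
A minimal sketch of the noise-injection and detection setup described above, assuming binary labels (MRPC labels are 0/1) and assuming that larger influence values mark more suspicious points. The helper names flip_labels and detection_recall are hypothetical; they are not the API of dataloader.py or of the notebook.

```python
import random
from typing import List, Sequence, Set, Tuple


def flip_labels(
    labels: Sequence[int], noise_ratio: float = 0.2, seed: int = 0
) -> Tuple[List[int], Set[int]]:
    """Flip binary labels (0 <-> 1) for a randomly chosen fraction of examples.

    Returns the noisy labels and the set of flipped indices, which serves as
    ground truth when evaluating detection.
    """
    rng = random.Random(seed)
    n_noisy = int(round(noise_ratio * len(labels)))
    noisy_idx = set(rng.sample(range(len(labels)), n_noisy))
    noisy = [1 - y if i in noisy_idx else y for i, y in enumerate(labels)]
    return noisy, noisy_idx


def detection_recall(scores: Sequence[float], noisy_idx: Set[int], budget: int) -> float:
    """Fraction of flipped examples found among the `budget` highest-scoring ones."""
    if not noisy_idx:
        return 0.0
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    found = sum(1 for i in ranked[:budget] if i in noisy_idx)
    return found / len(noisy_idx)


# Example: flip 20% of 10 labels, then score them with a perfect detector.
labels = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0]
noisy, flipped = flip_labels(labels, noise_ratio=0.2, seed=1)
scores = [1.0 if i in flipped else 0.0 for i in range(len(labels))]
print(detection_recall(scores, flipped, budget=len(flipped)))  # 1.0
```

In practice the scores would be the computed influence values, and the budget would typically be set to the expected number of noisy points.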

(Task 2) Influential data identification

A Jupyter notebook notebokes/Influential_Data_Identification-Llama2-Math-Reason.ipynb demonstrates how to efficiently compute influence function values and how to use them to identify the most influential data points. We use the llama-2-13b-chat model. The workflow has the following steps.

  • Step 1 Dataset generation: generate the math_problem (with reasoning) dataset with the following bash command. The dataset is stored in the datasets folder.
python3 src/generate_sentence-math_datasets.py

The script also generates the sentence_transformation and math_problem (without reasoning) datasets.

  • Step 2 Fine-tune a model: fine-tune a llama-2-13b-chat model on the math_problem (with reasoning) dataset. We use src/sft_trainer.py, which is built on HuggingFace's SFTTrainer. A sample command is as follows.
python /YOUR-DATAINF-PATH/DataInf/src/sft_trainer.py \
    --model_name /YOUR-LLAMA-PATH/llama/models_hf/llama-2-13b-chat \
    --dataset_name /YOUR-DATAINF-PATH/DataInf/datasets/math_with_reason_train.hf \
    --output_dir /YOUR-DATAINF-PATH/DataInf/models/math_with_reason_13bf \
    --dataset_text_field text \
    --load_in_8bit \
    --use_peft
  • Step 3 Influence computation: compute the gradients and the influence function values.
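
Once influence values are available (one row per test point, one column per training point), identifying the most influential training data reduces to a ranking. This sketch ranks by absolute influence and measures how often the top-ranked training point shares the test point's class. Both the ranking rule and the class-match check are assumptions for illustration, not necessarily what the notebook does, and the function names and class labels are hypothetical.

```python
from typing import List, Sequence


def top_k_influential(scores: Sequence[float], k: int = 5) -> List[int]:
    """Indices of the k training points with the largest absolute influence."""
    return sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)[:k]


def same_class_rate(
    score_matrix: Sequence[Sequence[float]],
    train_classes: Sequence[str],
    test_classes: Sequence[str],
) -> float:
    """Share of test points whose most influential training point has the same class."""
    hits = 0
    for scores, test_class in zip(score_matrix, test_classes):
        best = top_k_influential(scores, k=1)[0]
        hits += int(train_classes[best] == test_class)
    return hits / len(test_classes)


# Hypothetical scores: rows are test points, columns are training points.
score_matrix = [[0.1, -3.0, 2.0], [5.0, 0.2, -0.4]]
train_classes = ["task_A", "task_B", "task_A"]
test_classes = ["task_B", "task_A"]
print(top_k_influential(score_matrix[0], k=2))  # [1, 2]
print(same_class_rate(score_matrix, train_classes, test_classes))  # 1.0
```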

The core Python files

  • dataloader.py constructs the tokenizers and generates the noisy datasets.

  • lora_model.py includes LoRA modules.

  • influence.py includes influence computation algorithms.

  • generate_sentence-math_datasets.py generates the sentence_transformation and the math problem datasets.
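
A minimal pure-Python sketch of the DataInf closed-form approximation, written from the paper's formula rather than copied from influence.py. The function datainf_influence, the dict-of-lists gradient layout, and the per-layer damping values are hypothetical; a real implementation would operate on framework tensors. For each layer, the sketch forms the approximate inverse-Hessian-vector product (1/λ)(v - (1/n) Σ_i c_i g_i) once, with c_i = (v·g_i)/(λ + ||g_i||²), and then takes one inner product per training example, so the cost grows linearly with the number of training points.

```python
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def datainf_influence(
    train_grads: Dict[str, List[Vector]],  # layer -> per-example training gradients g_i
    val_grads: Dict[str, Vector],          # layer -> averaged validation gradient v
    lambdas: Dict[str, float],             # layer -> damping lambda (> 0)
) -> List[float]:
    """Approximate influence of each training example, summed over layers."""
    n = len(next(iter(train_grads.values())))
    scores = [0.0] * n
    for layer, grads in train_grads.items():
        v, lam = val_grads[layer], lambdas[layer]
        # Sherman-Morrison coefficient per example: (v . g_i) / (lam + ||g_i||^2).
        coeffs = [dot(v, g) / (lam + dot(g, g)) for g in grads]
        # Approximate inverse-Hessian-vector product, formed once per layer:
        # (1 / lam) * (v - (1 / n) * sum_i coeffs_i * g_i)
        ihvp = [
            (v[j] - sum(c * g[j] for c, g in zip(coeffs, grads)) / n) / lam
            for j in range(len(v))
        ]
        # Influence of example k is -(ihvp . g_k); accumulate across layers.
        for k, g_k in enumerate(grads):
            scores[k] -= dot(ihvp, g_k)
    return scores


# Toy example with one layer and two training examples.
train = {"layer0": [[1.0, 2.0], [0.5, -1.0]]}
val = {"layer0": [0.3, 0.7]}
print(datainf_influence(train, val, {"layer0": 0.5}))
```

With a single training example the Sherman-Morrison step is exact, so the sketch then equals -v^T(g g^T + λI)^{-1} g, which makes a convenient sanity check.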

CLI tool for mislabeled data detection tasks

We also provide a CLI tool. The following command computes the influence function values on the GLUE-QNLI dataset, using the RoBERTa-large model with LoRA rank 8.

python3 launcher.py run --exp_id='config_qnli4' --run-id=0 --runpath='./'
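
A back-of-the-envelope count of how much LoRA at rank 8 shrinks the trainable parameter set. A LoRA adapter on a d_out × d_in weight adds B (d_out × r) and A (r × d_in), i.e. r(d_in + d_out) trainable parameters. RoBERTa-large has 24 layers with hidden size 1024; which weight matrices receive adapters is an assumption here (query and value projections, a common LoRA default), so the totals are illustrative rather than the exact count for this codebase.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters of one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


HIDDEN = 1024          # RoBERTa-large hidden size
N_LAYERS = 24          # RoBERTa-large encoder layers
RANK = 8               # LoRA rank used in the command above
ADAPTED_PER_LAYER = 2  # assumption: query and value projections only

lora_total = N_LAYERS * ADAPTED_PER_LAYER * lora_params(HIDDEN, HIDDEN, RANK)
full_total = N_LAYERS * ADAPTED_PER_LAYER * HIDDEN * HIDDEN
print(lora_total, full_total, lora_total / full_total)  # 786432 50331648 0.015625
```

Under these assumptions the adapters train about 1.6% as many parameters as the matrices they adapt.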

Cite Us

If you find the library or the paper useful, please cite us!

@article{kwon2023datainf,
  title={Datainf: Efficiently estimating data influence in lora-tuned llms and diffusion models},
  author={Kwon, Yongchan and Wu, Eric and Wu, Kevin and Zou, James},
  journal={arXiv preprint arXiv:2310.00902},
  year={2023}
}
