
DiffAug: Differentiable Data Augmentation for Contrastive Sentence Representation Learning

  • This repo contains the code and pre-trained model checkpoints for our EMNLP 2022 paper.
  • Our code is based on SimCSE.

Overview

We propose a method that produces high-quality positives for contrastive sentence representation learning. A pivotal ingredient of our approach is the use of a prefix attached to a pre-trained language model, which allows for differentiable data augmentation during contrastive learning. Our method can be summarized in two steps: supervised prefix-tuning followed by joint contrastive fine-tuning with unlabeled or labeled examples. The following figure gives an overview of the proposed two-stage training strategy.
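To make the prefix mechanism concrete, here is a generic sketch of prefix-tuning in plain Python: a small set of trainable vectors is prepended to the token embeddings (and the attention mask is extended to match) before they enter the frozen language model. All names, shapes, and values below are illustrative assumptions, not the actual implementation in this repo.

```python
import random

DIM, PREFIX_LEN = 8, 4

# Trainable prefix: PREFIX_LEN vectors of size DIM (randomly initialized).
# In prefix-tuning, only these vectors receive gradient updates while the
# pre-trained language model stays frozen.
random.seed(0)
prefix = [[random.gauss(0.0, 0.02) for _ in range(DIM)] for _ in range(PREFIX_LEN)]

def attach_prefix(token_embeddings, attention_mask):
    """Prepend the prefix vectors and extend the attention mask to match."""
    embeddings = [row[:] for row in prefix] + token_embeddings
    mask = [1] * PREFIX_LEN + attention_mask
    return embeddings, mask

# A toy "sentence" of 3 token embeddings:
tokens = [[0.0] * DIM for _ in range(3)]
emb, mask = attach_prefix(tokens, [1, 1, 1])
```

Because the prefix consists of continuous vectors rather than discrete tokens, gradients from the contrastive loss can flow through it, which is what makes the augmentation differentiable.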

Install dependencies

First, install PyTorch from the official website. All our experiments were conducted with PyTorch v1.8.1 and CUDA v10.1, so you can use the following command to install the same PyTorch version:

pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

Then run the following script to install the remaining dependencies:

pip install -r requirements.txt

Prepare training and evaluation datasets

We use the same training and evaluation datasets as SimCSE. Therefore, we adopt their scripts for downloading the datasets.

To download the unlabeled Wikipedia dataset, please run

cd data/
bash download_wiki.sh

To download the labeled NLI dataset, please run

cd data/
bash download_nli.sh

To download the evaluation datasets, please run

cd SentEval/data/downstream/
bash download_dataset.sh

Following previous works, we use SentEval to evaluate our model.
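SentEval reports STS performance as the correlation between the model's predicted cosine similarities and the human-annotated gold scores. For reference, here is a small self-contained sketch of Spearman's rank correlation, the metric commonly reported for STS; this is an illustration, not SentEval's own code.

```python
import math

def average_ranks(xs):
    """Rank values from 1, assigning tied values the average of their ranks."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    ranks = [0.0] * len(xs)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and xs[order[j + 1]] == xs[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks

def spearman(xs, ys):
    """Spearman correlation = Pearson correlation of the rank vectors."""
    rx, ry = average_ranks(xs), average_ranks(ys)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = math.sqrt(sum((a - mx) ** 2 for a in rx))
    sy = math.sqrt(sum((b - my) ** 2 for b in ry))
    return cov / (sx * sy)
```

For example, `spearman([0.9, 0.1, 0.5], [5.0, 1.0, 3.0])` returns `1.0`, since the predicted similarities rank the sentence pairs in the same order as the gold scores.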

Training

We provide two example scripts for reproducing our results under the semi-supervised and supervised settings, respectively.

To train a semi-supervised model, please run

bash run_semi_sup_bert.sh

To train a supervised model, please run

bash run_sup_bert.sh
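Both scripts optimize an InfoNCE-style contrastive objective with in-batch negatives, as in SimCSE: each sentence should be closest to its own positive and far from every other positive in the batch. Below is a minimal, dependency-free sketch of that loss; the temperature value and toy vectors are illustrative, and the actual training logic lives in the scripts above.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def info_nce(anchors, positives, temperature=0.05):
    """Mean cross-entropy of each anchor over all in-batch positives,
    where the matching positive (same index) is the correct class."""
    losses = []
    for i, a in enumerate(anchors):
        logits = [cosine(a, p) / temperature for p in positives]
        m = max(logits)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        losses.append(log_z - logits[i])  # -log softmax at index i
    return sum(losses) / len(losses)
```

When every anchor matches its own positive exactly (e.g. `info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])`), the loss is near zero; mismatching the positives drives it up.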

Evaluation

To evaluate the model, please run:

python evaluation.py \
    --model_name_or_path <model_checkpoint_path> \
    --mode <dev|test>

The results should appear in the following format:

*** E.g. Supervised model evaluation results ***
+-------+-------+-------+-------+-------+--------+---------+-------+
| STS12 | STS13 | STS14 | STS15 | STS16 |  STSB  |  SICKR  |  Avg. |
+-------+-------+-------+-------+-------+--------+---------+-------+
| 77.40 | 85.24 | 80.50 | 86.85 | 82.59 | 84.12  |  80.29  | 82.43 |
+-------+-------+-------+-------+-------+--------+---------+-------+

Well-trained model checkpoints

We provide two well-trained model checkpoints.

Here is an example of how to evaluate them on the STS tasks:

python evaluation.py \
    --model_name_or_path Tianduo/diffaug-semisup-bert-base-uncased \
    --mode test

Citation

Please cite our paper if it is helpful to your work:

@inproceedings{wang2022diffaug,
   title={Differentiable Data Augmentation for Contrastive Sentence Representation Learning},
   author={Wang, Tianduo and Lu, Wei},
   booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
   year={2022}
}

diffaug's People

Contributors

hsqzzpf, tianduowang

diffaug's Issues

When will this paper be accessible?

Hi, I'm Gordon Lee.
Thanks for your excellent work on sentence representation learning (SRL).
I would like to know when this paper will be accessible.
By the way, I am collecting a list of SRL papers together with an unsupervised STS leaderboard.
If you are interested, you can access the following repo:
https://github.com/Doragd/Awesome-Sentence-Embedding/

I will add your paper to our paper list soon :P

Question about the lack of unsupervised representation learning experiments

Hi Tianduo,
I really appreciate your work on developing learnable data augmentation for sentence representation learning. Your proposed method, DiffAug, has shown strong performance in the semi-supervised and supervised settings.

However, I was wondering how DiffAug performs in the unsupervised setting.

  • If you have already tried it, does DiffAug still outperform SimCSE?
  • If not, what do you think about first training the prefix with unsupervised contrastive learning (freezing the language model), and then jointly training the language model and the prefix?
