
PromptEM: Prompt-tuning for Low-resource Generalized Entity Matching

PromptEM is a novel low-resource GEM (Generalized Entity Matching) solution powered by prompt-tuning and self-training. To address the gap between pre-training and fine-tuning, we cast GEM as a cloze-style task and design GEM-specific prompt-tuning, which can stimulate the rich knowledge distributed in LMs. To select high-quality pseudo-labels, we develop a lightweight uncertainty-aware self-training method that boosts performance. To further reduce the cost of self-training, we dynamically prune useless training data using the proposed MC-EL2N, making the self-training process more lightweight and efficient.
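The cloze-style formulation above can be sketched in a few lines. The serialization format and template wording here are illustrative assumptions, not PromptEM's exact code:

```python
# A minimal sketch of casting entity matching as a cloze task.
# serialize() and the template wording are illustrative, not PromptEM's code.

def serialize(entity: dict) -> str:
    """Flatten an entity record into 'attr: value' text."""
    return " ".join(f"{k}: {v}" for k, v in entity.items())

def cloze_prompt(left: dict, right: dict, mask_token: str = "<mask>") -> str:
    """Build a cloze-style input; the LM fills the mask with a label word
    (e.g. 'yes'/'no') that a verbalizer maps back to matched/unmatched."""
    return f"{serialize(left)} and {serialize(right)} are {mask_token} the same entity."

left = {"title": "iPhone 13", "brand": "Apple"}
right = {"name": "Apple iPhone 13 128GB"}
print(cloze_prompt(left, right))
```

Because the masked position is scored with the LM's own masked-language-modeling head, the model's pre-training objective is reused directly instead of training a new classification head from scratch.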

For more technical details, see PromptEM: Prompt-tuning for Low-resource Generalized Entity Matching.

[Figure] An illustration of fine-tuning vs. prompt-tuning.

Datasets

We use eight real-world benchmark datasets with different structures from Machamp and Geo-ER.

Quick Start

To train and evaluate with PromptEM:

python main.py [<args>] [-h | --help]

e.g.

python main.py -d=rel-heter -k=0.1 -st -dd=8 -ur=0.05 -er=0.05

The meaning of the flags:

  • --model_name_or_path: the name or local path of the pre-trained language model. e.g. roberta-base
  • --data_name: the name of the dataset. options: [rel-heter, rel-text, semi-heter, semi-homo, semi-rel, semi-text-c, semi-text-w, geo-heter, all]
  • --k: the proportion of training data used. e.g. 0.1
  • --num_iter: the number of iterations. e.g. 1
  • --template_no: the number of templates used in PromptEM. options: [0,1,2,3]
  • --self_training: the flag to enable self-training of PromptEM.
  • --dynamic_dataset: the frequency of dynamic dataset pruning. e.g. 8 (prune once every 8 epochs)
  • --pseudo_label_method: the method of generating pseudo labels. e.g. uncertainty
  • --mc_dropout_pass: the number of MC-Dropout passes. e.g. 10
  • --uncertainty_ratio: the proportion of the unlabeled samples for generating pseudo labels. e.g. 0.05
  • --el2n_ratio: the proportion of the labeled samples for dynamic dataset pruning. e.g. 0.1
  • --text_summarize: the flag to enable text summarization in entity serialization.
  • --add_token: the flag to add special tokens in entity serialization.
  • --max_length: the maximum length (in number of tokens) for the inputs to the transformer model. e.g. 512
  • --teacher_epochs: the number of epochs for training the teacher model. e.g. 20
  • --student_epochs: the number of epochs for training the student model. e.g. 30
  • --batch_size: batch size. e.g. 32
  • --lr: learning rate. e.g. 1e-5
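
The uncertainty-related flags (--mc_dropout_pass, --uncertainty_ratio) can be understood through a minimal sketch of uncertainty-aware pseudo-label selection. The function names and the use of variance across stochastic passes as the uncertainty score are illustrative assumptions, not the repository's exact implementation:

```python
# Hedged sketch of uncertainty-aware pseudo-labeling via MC-Dropout.
# Each unlabeled pair has T match probabilities, one per stochastic forward
# pass with dropout kept active (the --mc_dropout_pass flag).
import statistics

def mc_dropout_uncertainty(probs_per_pass):
    """Return (mean probability, variance) across the T passes; the variance
    serves as the uncertainty score here."""
    mean = statistics.fmean(probs_per_pass)
    var = statistics.pvariance(probs_per_pass)
    return mean, var

def select_pseudo_labels(unlabeled, uncertainty_ratio=0.05):
    """Keep the most certain fraction of unlabeled pairs (--uncertainty_ratio)
    and label each by thresholding its mean probability at 0.5."""
    scored = []
    for pair_id, probs in unlabeled.items():
        mean, var = mc_dropout_uncertainty(probs)
        scored.append((var, pair_id, int(mean >= 0.5)))
    scored.sort()  # lowest variance = most certain first
    k = max(1, int(len(scored) * uncertainty_ratio))
    return [(pid, label) for _, pid, label in scored[:k]]

demo = {
    "pair-1": [0.95, 0.96, 0.94],  # consistent passes -> low uncertainty
    "pair-2": [0.10, 0.90, 0.50],  # disagreeing passes -> high uncertainty
}
print(select_pseudo_labels(demo, uncertainty_ratio=0.5))  # [('pair-1', 1)]
```

The key idea is that pairs on which the stochastic passes agree are trusted as pseudo-labels for the student model, while high-variance pairs are left unlabeled.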

Reproduction

All the experiments are conducted on an Ubuntu Server with an Intel Xeon Silver 4216 CPU and an NVIDIA A100 GPU.

Initialize the environment

conda create -n promptem python=3.7
conda activate promptem
pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install transformers==4.16.2
pip install scikit-learn==1.0.2

Note that you do not need to install OpenPrompt by pip manually.

Note that the best hyper-parameters can be sensitive to your server environment and package versions. If your environment differs from ours, we highly recommend re-running the hyper-parameter search in your own environment.

We provide an example search script in search.sh.
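
A hyper-parameter sweep in the spirit of search.sh can be sketched as follows; the grid values below are examples, not the provided script's exact settings:

```python
# Illustrative grid sweep over a few PromptEM flags; swap 'print' for
# subprocess.run(cmd.split()) to actually launch the runs.
import itertools

GRID = {
    "-k": [0.05, 0.1],
    "-ur": [0.05, 0.1],
    "--lr": ["1e-5", "2e-5"],
}

def sweep_commands(data_name="rel-heter"):
    """Yield one training command per point in the grid."""
    keys = list(GRID)
    for values in itertools.product(*GRID.values()):
        flags = " ".join(f"{k}={v}" for k, v in zip(keys, values))
        yield f"python main.py -d={data_name} -st {flags}"

for cmd in sweep_commands():
    print(cmd)
```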

Download the PLM [Optional]

We use RoBERTa-base as the backbone structure of our model in all the experiments.

You can download the pre-trained checkpoint from the Hugging Face Hub manually.

Reproduce the results

You can train the model using the best hyper-parameters we provide in low_configs.json.

We also provide the corresponding logs in logs.

See Quick Start for more details of training parameters.

Contributors

peng-fei-wang, zeng1998


Known Issues

classes in dataset and verbalizer are different

I notice that in the dataset, 0 means unmatched and 1 means matched, but in the verbalizer, index 0 means matched and index 1 means unmatched. Did you notice this difference?

In prompt.py:

classes = [
    "yes",
    "no",
]
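
A minimal way to avoid this kind of mismatch is to make the dataset-to-verbalizer mapping explicit rather than relying on index order. This snippet is an illustrative sketch, not the repository's code:

```python
# Explicit label mapping: dataset convention is 0 = unmatched, 1 = matched,
# while the verbalizer's class list puts "yes" (matched) at index 0.
classes = ["yes", "no"]

DATASET_TO_VERBALIZER = {1: 0, 0: 1}  # matched -> "yes", unmatched -> "no"
VERBALIZER_TO_DATASET = {v: k for k, v in DATASET_TO_VERBALIZER.items()}

def to_verbalizer_index(dataset_label: int) -> int:
    """Map a dataset label to the verbalizer class index."""
    return DATASET_TO_VERBALIZER[dataset_label]

def to_dataset_label(verbalizer_index: int) -> int:
    """Map a predicted verbalizer class index back to a dataset label."""
    return VERBALIZER_TO_DATASET[verbalizer_index]
```

Keeping both directions in one place makes the convention auditable and prevents silent label flips when either side changes.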
