
DART's Introduction

DART

Implementation of the ICLR 2022 paper Differentiable Prompt Makes Pre-trained Language Models Better Few-shot Learners.

  • ❗NOTE: The code has been reorganized, and we also provide a paper list at PromptKG.

Environment

  • [email protected]
  • Use pip install -r requirements.txt to install dependencies.
  • A wandb account is required if you want to search for the best hyper-parameter combinations.

Data source

  • 16-shot GLUE dataset from LM-BFF.
  • The generated data consists of 5 random splits (13/21/42/87/100) for each task, each with 16 samples.
    • The generation process follows LM-BFF here.
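The loader expects each split under data/k-shot/&lt;task&gt;/16-&lt;split-id&gt;/ (this path shape is inferred from the repository's error messages; the helper below is an illustrative sketch, not repo code):

```python
# Illustrative sketch of the expected data layout; the directory pattern
# data/k-shot/<task>/16-<split>/train.csv is inferred from the loader's
# error output. This helper is not part of the repository.
SPLITS = (13, 21, 42, 87, 100)

def split_dir(task: str, split: int) -> str:
    """Return the directory holding one 16-shot split of a task."""
    if split not in SPLITS:
        raise ValueError(f"unknown split id: {split}")
    return f"data/k-shot/{task}/16-{split}"
```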

How to run

  • To train / test on a data split from a single task with specific parameters, use run.py.
    • For customized training & evaluation, you can adapt the sample configuration file config/sample.yml.
$ python run.py -h  
usage: run.py [-h] [--config CONFIG] [--do_train] [--do_test]

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG, -c CONFIG
                        Configuration file storing all parameters
  --do_train
  --do_test
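The interface shown in the help text above can be reproduced in a few lines of argparse (a sketch mirroring the documented flags, not the repo's actual parser):

```python
import argparse

# Sketch of run.py's command-line interface, reconstructed from its help text.
parser = argparse.ArgumentParser(prog="run.py")
parser.add_argument("--config", "-c", help="Configuration file storing all parameters")
parser.add_argument("--do_train", action="store_true")
parser.add_argument("--do_test", action="store_true")

# e.g. the typical train-and-evaluate invocation:
args = parser.parse_args(["-c", "config/sample.yml", "--do_train", "--do_test"])
```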
  • To search for optimal hyper-parameters for each task and reproduce our results, please use sweep.py:
    • Please refer to the WandB documentation for more details.
    • ❗NOTE: following LM-BFF, we search for the optimal set of hyper-parameters separately on each data split.
$ python sweep.py -h
usage: sweep.py [-h] [--project_name PROJECT_NAME] --task_name TASK_NAME
                [--data_split {13,21,42,87,100}]
                [--pretrain_model PRETRAIN_MODEL] [--pet_method {pet,diffpet}]
                [--random_seed RANDOM_SEED] [--max_run MAX_RUN]

optional arguments:
  -h, --help            show this help message and exit
  --project_name PROJECT_NAME
                        project name for sweep
  --task_name TASK_NAME
  --data_split {13,21,42,87,100}
                        few-shot split-id for GLUE dataset
  --pretrain_model PRETRAIN_MODEL
                        name or path for pretrained model
  --pet_method {pet,diffpet}
                        prompt encoding method
  --random_seed RANDOM_SEED
                        random seed for training
  --max_run MAX_RUN     maximum tries for sweep
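Under the hood, a wandb sweep is driven by a configuration dictionary. Below is a minimal sketch of what sweep.py might assemble; the hyper-parameter names and ranges are illustrative assumptions, not the repo's actual search space:

```python
def build_sweep_config(task_name, data_split, pet_method="diffpet", max_run=50):
    # Illustrative sweep config; parameter names and ranges are assumptions.
    return {
        "name": f"{task_name}-{data_split}-{pet_method}",
        "method": "random",  # random search, capped by max_run trials
        "metric": {"name": "eval_accuracy", "goal": "maximize"},
        "parameters": {
            "learning_rate": {"values": [1e-5, 2e-5, 5e-5]},
            "batch_size": {"values": [4, 8, 16]},
            "data_split": {"value": data_split},
        },
    }
```

Such a dictionary would then be registered with wandb.sweep(...) and executed by agents for at most max_run trials.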

How to Cite

@inproceedings{zhang2022differentiable,
  title={Differentiable Prompt Makes Pre-trained Language Models Better Few-shot Learners},
  author={Ningyu Zhang and Luoqiu Li and Xiang Chen and Shumin Deng and Zhen Bi and Chuanqi Tan and Fei Huang and Huajun Chen},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=ek9a0qIafW}
}

DART's People

Contributors

dependabot[bot] · riroaki · zxlzr


DART's Issues

Event extraction.

Hi, I notice that you report the few-shot results on ACE-2005 and I'm very interested in your implementation. Do you have any plan to share the code for event extraction? It would be very helpful.

confused about the usage of BLOCK_FLAG

I'm trying to read the code, and 'BLOCK_FLAG' really confuses me. I think it is used to compute the length of the prompt, but I'm not sure the values in 'BLOCK_FLAG' are right. Take BoolQPVP as an example: apart from 'passage', 'question', and 'self.mask', which are not part of the template, the remaining words in PATTERN should be counted. So I think BLOCK_FLAG should be

BLOCK_FLAG = [0, 1, 1, 1, 0, 1, 0 ,1] 

instead of

BLOCK_FLAG = [0, 0, 1, 0, 0, 0, 0, 0]

Is there something I misunderstood?

class BoolQPVP(PVP):

    VERBALIZER = {
        "False": ["No"],
        "True": ["Yes"]
    }
    """
    VERBALIZER_B = {
        "False": ["false"],
        "True": ["true"]
    }
    """

    PATTERN = ['passage', '.', 'the', ' Question: ',
               'question', '? Answer: ', 'self.mask', '.']

    BLOCK_FLAG = [0, 0, 1, 0, 0, 0, 0, 0]

    def get_parts(self, example: InputExample) -> FilledPattern:
        passage = self.shortenable(example.text_a)
        question = self.shortenable(example.text_b)

        # searched patterns in fully-supervised learning
        # string_list_a = [passage, '.', 'the', 'Question:', question, '?', 'the', 'Answer:', self.mask]
        # string_list_a = [passage, '.', 'the', question, '?', 'the', self.mask]
        # string_list_a = [passage, 'the', question, '?', 'the', self.mask]

        # few-shot
        if self.pattern_id == 1:

            string_list_a = [passage, '.', 'the', ' Question: ',
                             question, '? Answer: ', self.mask, '.']
            string_list_b = []
            block_flag_a = self.BLOCK_FLAG
            block_flag_b = []
            assert len(string_list_a) == len(block_flag_a)
            assert len(string_list_b) == len(block_flag_b)
            return string_list_a, string_list_b, block_flag_a, block_flag_b

        else:
            raise ValueError("unknown pattern_id.")
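For reference, the alignment the question hinges on can be made explicit with a tiny consistency check. This is my own reading of the flags, assuming 1 marks a template token whose embedding is trainable while inputs and the mask are fixed; it is not repo code:

```python
# Toy consistency check for PATTERN/BLOCK_FLAG alignment (not repo code).
# Assumption: a flag of 1 marks a template token with a trainable embedding,
# while inputs ('passage', 'question') and 'self.mask' are fixed (0).
PATTERN = ['passage', '.', 'the', ' Question: ',
           'question', '? Answer: ', 'self.mask', '.']
BLOCK_FLAG = [0, 0, 1, 0, 0, 0, 0, 0]

assert len(PATTERN) == len(BLOCK_FLAG)
trainable = [tok for tok, flag in zip(PATTERN, BLOCK_FLAG) if flag]
```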

where to find the data

Hi Dear Author,

When I try to run inference.py directly with python inference.py, it gives me the following error:

Traceback (most recent call last):
  File "/data/co_project/DART/inference.py", line 53, in <module>
    train_data = load_examples(task_name, data_dir, TRAIN_SET, num_examples=-1)
  File "/data/co_project/DART/data_utils/processors.py", line 882, in load_examples
    examples = processor.get_train_examples(data_dir)
  File "/data/co_project/DART/data_utils/processors.py", line 700, in get_train_examples
    return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
  File "/data/co_project/DART/data_utils/processors.py", line 722, in _create_examples
    with open(path, encoding='utf8') as f:
FileNotFoundError: [Errno 2] No such file or directory: 'data/k-shot/mr/16-13/train.csv'

It looks like the data is not pre-downloaded. May I ask where to download the data and how to put it in the correct place? Thanks!

Some questions in model.py

Hello, I really like your paper and work, but I ran into some problems while trying to understand your code.
In model.py:
def get_loss(self, batch, full_vocab=True, logits_key='pet_logits'):
    # Compute cross-entropy loss for prompt verbalizers
    assert logits_key in batch, 'logits should be pre-computed and stored in batch dict'
    masked_logits = batch[logits_key][batch['pet_flags'] == -1]
    labels = batch['pet_labels']
    if not full_vocab:
        masked_logits = masked_logits[:, self.label_ids]
        labels = batch['label_ids']
    return self.loss_fn(masked_logits, labels)

The size of masked_logits is [batch_size, hidden_size] (I used batch_size=4 and a BERT model, so the size is [4, 768]), and the size of labels is [batch_size]. However, when computing the cross-entropy loss, it always raises an index-out-of-range error. Is there a missing layer that should transform masked_logits from [4, 768] to [4, num_classes]?
Meanwhile, I also don't understand the meaning of full_vocab: when should it be set to False, and when to True?
Looking forward to your reply!
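To make the two modes of get_loss concrete, here is a self-contained toy version of the computation. This is my reading of the code, with a hand-rolled cross-entropy; the vocabulary size and token ids are made up:

```python
import math

def cross_entropy(logits, target):
    # softmax cross-entropy for a single example (log-sum-exp for stability)
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

vocab_size = 10          # toy vocabulary
label_ids = [3, 7]       # verbalizer token ids for the two classes (made up)
logits = [0.0] * vocab_size
logits[7] = 5.0          # the model strongly predicts token 7 at the [MASK]

# full_vocab=True: logits span the whole vocabulary and the target is the
# verbalizer *token id*, so labels must lie in [0, vocab_size).
loss_full = cross_entropy(logits, 7)

# full_vocab=False: logits are first narrowed to the verbalizer columns and
# the target becomes the *class index*, so labels must lie in [0, num_classes).
narrowed = [logits[i] for i in label_ids]
loss_narrow = cross_entropy(narrowed, 1)
```

Note that a [4, 768] shape for masked_logits suggests raw hidden states rather than vocabulary logits, i.e. the masked-LM head (which maps hidden_size to vocab_size) may not have been applied yet; that would explain the index-out-of-range error.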

Multi-token verbalizers

The current implementation requires all verbalizers to be single tokens, i.e. each word must be part of the vocabulary of the model being used. This is a significant obstacle for practical applications. I noticed there is a flag called force_single_token in get_verbalization_ids, somehow suggesting that verbalizers with more than one token should be an option. I have tried to modify this, but then I get other errors further downstream, and I must admit my PyTorch skills are not quite up to making the necessary modifications to the code. Any hints about how to work around this would be much appreciated.
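One common workaround in the prompting literature (e.g. PET-style multi-mask handling) is to score a multi-token verbalizer by aggregating its sub-tokens' scores at the mask position. The helper below is purely illustrative and independent of this repo:

```python
def verbalizer_scores(mask_logits, verbalizer_token_ids):
    # Average the single-mask logits of each verbalizer's sub-tokens.
    # A more faithful approach inserts one [MASK] per sub-token; this
    # cheap approximation is only a first cut.
    return [sum(mask_logits[i] for i in ids) / len(ids)
            for ids in verbalizer_token_ids]

# toy example: class 0 = one token (id 2), class 1 = two tokens (ids 5 and 6)
logits = [0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 4.0]
scores = verbalizer_scores(logits, [[2], [5, 6]])
```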

Question about unused tokens

Hi,

Thank you for releasing your code. After reading it, I have a question about unused tokens.

According to your paper, DART maps the template and label as {h1, ..., hm, ..., hm+n}, where the hi are trainable parameters that are replaced with unused tokens (e.g., [unused1]) or special tokens in the vocabulary. In my opinion, they would most likely be replaced with special tokens, because it is difficult to tell which tokens were not used during pre-training.

However, in your code, they are replaced with the last few tokens in the vocabulary. So I'd like to know: how do you ensure that the last few tokens in the vocabulary were not used? Does this mean I can pick tokens randomly from the vocabulary to replace the hi?
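The two placeholder strategies the question contrasts can be sketched as follows (an illustrative helper, not repo code). The key requirement is only that the chosen ids never occur in real input, so their embeddings are trained from scratch without interference from the pre-trained prior:

```python
def pick_prompt_ids(vocab_size, num_prompt_tokens, strategy="tail"):
    # Illustrative: choose token ids to repurpose as trainable prompt slots.
    if strategy == "tail":    # last few ids in the vocabulary, as this repo does
        return list(range(vocab_size - num_prompt_tokens, vocab_size))
    if strategy == "unused":  # BERT-style [unused1].. ids start at 1 (assumption)
        return list(range(1, 1 + num_prompt_tokens))
    raise ValueError(f"unknown strategy: {strategy}")
```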

Is there a way to store/view predictions, in order to analyze example results?

Hi,

Great work!
I'm currently looking at the paper and running experiments on it.
In the code, is there a way to store or view predictions, in order to analyze example results?

I want to see what the shortcomings of conventional techniques are and how DART has overcome them, and lastly, whether there are still things left for DART to solve or opportunities to improve it.

Thank you
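The thread has no canonical answer, but a generic pattern for persisting predictions for later error analysis looks like this (not repo code; any evaluation loop could call it):

```python
import json

def dump_predictions(path, texts, preds, golds):
    # Write one JSON record per example so mistakes can be filtered later.
    with open(path, "w", encoding="utf8") as f:
        for text, pred, gold in zip(texts, preds, golds):
            f.write(json.dumps({"text": text, "pred": pred,
                                "gold": gold, "correct": pred == gold}) + "\n")
```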

Questions about symbols

Hi, I have two questions about the symbols in this paper.

  1. Eq. (4) shows that h_i (0 <= i <= j) are trainable parameters, so j is the number of trainable embeddings in the template.
    However, Eq. (6) shows that m is the number of trainable embeddings in the template.
    From Eq. (4) and the template, we can see that j is the length of the template and m is the length of the sentence after it is filled in.
    So what is the number of trainable embeddings? I think it's j.

  2. To avoid optimizing any external parameters, {h_1, ..., h_m, ..., h_{m+n}} is replaced with unused tokens .... What does n mean? Does it denote the number of classes?

Thanks.
