

Disentangled Neural Ranking


This is the official repo for our paper Disentangled Modeling of Domain and Relevance for Adaptable Dense Retrieval. Disentangled Neural Ranking is a novel paradigm that supports effective and flexible domain adaptation for neural ranking models including Dense Retrieval, uniCOIL, SPLADE, ColBERT, and BERT re-ranker.

Features

  • One command for unsupervised and effective domain adaptation.
  • One command for effective few-shot domain adaptation.
  • Various ranking architectures, including Dense Retrieval, uniCOIL, SPLADE, ColBERT, and BERT re-ranker.
  • Two source-domain finetuning methods, contrastive finetuning and distillation.
  • Huggingface-style training and inference, supporting multi-GPU training, mixed precision, etc.


Quick Tour

Neural Ranking models are vulnerable to domain shift: the trained models may even perform worse than traditional retrieval methods like BM25 in out-of-domain scenarios.

In this work, we propose Disentangled Neural Ranking (DNR) to support effective and flexible domain adaptation. DNR consists of a Relevance Estimation Module (REM) for modeling domain-invariant matching patterns and several Domain Adaption Modules (DAMs) for modeling domain-specific features of multiple target corpora. DNR enables a flexible training paradigm in which REM is trained with supervision once and DAMs are trained with unsupervised data.

[Figure: Neural Ranking vs. Disentangled Neural Ranking]

The idea of DNR dates back to classic retrieval models from the pre-neural era. BM25 uses the same formula for estimating relevance scores across domains but measures word importance with corpus-specific IDF values. Such disentanglement does not exist in vanilla neural ranking models, where the abilities of relevance estimation and domain modeling are jointly learned during training and entangled within the model parameters.

Here are two examples of applying disentangled modeling for domain adaptation. In the figures, the y-axis shows the relative improvement over BM25 and the x-axis shows different out-of-domain test sets. The ranking performance of Dense Retrieval (DR) and Disentangled Dense Retrieval (DDR) is shown below.

[Figures: NDCG@10 and Recall@1000 relative improvements over BM25 for DR vs. DDR]

The ranking performance of ColBERT and Disentangled ColBERT (D-ColBERT) is shown below.

[Figures: NDCG@10 and Recall@1000 relative improvements over BM25 for ColBERT vs. D-ColBERT]

Disentangled modeling brings amazing out-of-domain performance gains! More details are available in our paper.

Installation

This repo is developed with PyTorch and Faiss. Both should be installed manually because they require platform-specific configuration. In our development, we run the following commands for installation.

# XX.X is a placeholder for cudatoolkit version. It should be specified according to your environment
conda install pytorch torchvision torchaudio cudatoolkit=XX.X -c pytorch 
conda install -c conda-forge faiss-gpu
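
To quickly verify that both libraries work before proceeding, you can run a small (optional) check:

import faiss
import torch

# Verify that PyTorch sees a GPU and that Faiss can build a simple index.
print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
index = faiss.IndexFlatIP(128)
print("faiss index created, ntotal =", index.ntotal)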

After these steps, you can install this repo from source:

git clone https://github.com/jingtaozhan/disentangled-retriever
cd disentangled-retriever
pip install .

For development, use

pip install --editable .

Released Models

We release about 50 models to facilitate reproducibility and reuse. You do not have to download them manually; they are downloaded automatically at runtime.

  • Relevance Estimation Modules for Dense Retrieval
  • Relevance Estimation Modules for uniCOIL
  • Relevance Estimation Modules for SPLADE
  • Relevance Estimation Modules for ColBERT
  • Relevance Estimation Modules for BERT re-ranker
  • Domain Adaption Modules for various datasets

Besides Disentangled Neural Ranking models, we also release the vanilla/traditional neural ranking models, which are baselines in our paper.

Vanilla Neural Ranking checkpoints:

  • Vanilla Dense Retrieval
  • Vanilla uniCOIL
  • Vanilla SPLADE
  • Vanilla ColBERT
  • Vanilla BERT re-ranker
*Note: Our code also supports training and evaluating vanilla neural ranking models!*

Example usage:

Here is an example of using disentangled dense retrieval for ranking. The REM is generic across domains, while the DAM is trained on the target corpus to mitigate domain shift. The two modules are assembled during inference.

from transformers import AutoConfig, AutoTokenizer
from disentangled_retriever.dense.modeling import AutoDenseModel

# This is the Relevance Estimation Module (REM) contrastively trained on MS MARCO
# It can be used in various English domains.
REM_URL = "https://huggingface.co/jingtao/REM-bert_base-dense-contrast-msmarco/resolve/main/lora192-pa4.zip"
## For example, we will apply the model to TREC-Covid dataset. 
# Here is the Domain Adaption Module for this dataset.
DAM_NAME = "jingtao/DAM-bert_base-mlm-msmarco-trec_covid"

## Load the modules
config = AutoConfig.from_pretrained(DAM_NAME)
config.similarity_metric, config.pooling = "ip", "average"
tokenizer = AutoTokenizer.from_pretrained(DAM_NAME, config=config)
model = AutoDenseModel.from_pretrained(DAM_NAME, config=config)
adapter_name = model.load_adapter(REM_URL)
model.set_active_adapters(adapter_name)
model.merge_lora(adapter_name)

## Let's try to compute the similarities
queries  = ["When will the COVID-19 pandemic end?", "What are the impacts of COVID-19 pandemic to society?"]
passages = ["It will end soon.", "It makes us care for each other."]
query_embeds = model(**tokenizer(queries, return_tensors="pt", padding=True, truncation=True, max_length=512))
passage_embeds = model(**tokenizer(passages, return_tensors="pt", padding=True, truncation=True, max_length=512))

print(query_embeds @ passage_embeds.T)

Results are:

tensor([[107.6821, 101.4270],
        [103.7373, 105.0448]], grad_fn=<MmBackward0>)
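
If you need a ranked list rather than raw scores, you can simply sort the similarity matrix; this continues the snippet above and uses no additional repo API:

import torch

# Rank passages for each query by descending similarity.
scores = query_embeds @ passage_embeds.T
ranking = torch.argsort(scores, dim=1, descending=True)
print(ranking)  # per-query passage indices, most relevant first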

Preparing datasets

We will use various datasets to show how disentangled modeling facilitates flexible domain adaptation. Before the demonstrations, please download and preprocess the corresponding datasets. Here we provide detailed instructions for doing so.

Zero-shot Domain Adaptation

Suppose you already have a REM module (trained by yourself or provided by us) and you need to adapt the model to an unseen domain. To do this, just train a Domain Adaption Module (DAM) to mitigate the domain shift.

The training process is completely unsupervised and only requires the target-domain corpus. Each line of the corpus file should be `id doc`, with the two fields separated by a tab. You can then train a DAM with only one command:

python -m torch.distributed.launch --nproc_per_node 4 \
    -m disentangled_retriever.adapt.run_adapt_with_mlm \
    --corpus_path ... ... ...
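
Before launching the command, it can help to sanity-check the corpus file. Below is a minimal sketch, assuming the two-column tab-separated format described above; `corpus.tsv` is a placeholder path:

# Quick format check for the target-domain corpus file: id <tab> doc per line.
corpus_path = "corpus.tsv"  # placeholder path

with open(corpus_path, encoding="utf-8") as f:
    for line_no, line in enumerate(f, 1):
        fields = line.rstrip("\n").split("\t")
        assert len(fields) == 2, f"Line {line_no}: expected 'id<tab>doc', got {len(fields)} fields"
        doc_id, doc = fields
        assert doc_id and doc, f"Line {line_no}: empty id or doc"
print("corpus file looks well formatted")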

The trained DAM can be combined with different REMs to form well-performing neural ranking models. We provide many REMs (see this section) that correspond to different ranking methods or are trained with different losses. The trained DAM can be combined with any of them to obtain an effective ranking model, e.g., a Dense Retrieval model or a ColBERT model. For example, to obtain a dense retrieval model, use the following command for inference:

python -m torch.distributed.launch --nproc_per_node 4 \
    -m disentangled_retriever.dense.evaluate.run_eval \
    --backbone_name_or_path [path-to-the-trained-DAM] \
    --adapter_name_or_path [path-to-the-dense-retrieval-rem] \
    --corpus_path ... --query_path ... ... ...

If you want to acquire a ColBERT model, use the following command for inference:

python -m torch.distributed.launch --nproc_per_node 4 \
    -m disentangled_retriever.colbert.evaluate.run_eval \
    --backbone_name_or_path [path-to-the-trained-DAM] \
    --adapter_name_or_path [path-to-the-colbert-rem] \
    --corpus_path ... --query_path ... ... ...
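
If you prefer to assemble a locally trained DAM with a released dense-retrieval REM directly in Python (for ad-hoc scoring rather than full evaluation), the pattern from the example-usage snippet above applies. This is a sketch that assumes the DAM output directory contains its config and tokenizer files; the paths are placeholders:

from transformers import AutoConfig, AutoTokenizer
from disentangled_retriever.dense.modeling import AutoDenseModel

dam_path = "path/to/the/trained/DAM"  # placeholder: your locally trained DAM
rem_path = "https://huggingface.co/jingtao/REM-bert_base-dense-contrast-msmarco/resolve/main/lora192-pa4.zip"

config = AutoConfig.from_pretrained(dam_path)
config.similarity_metric, config.pooling = "ip", "average"
tokenizer = AutoTokenizer.from_pretrained(dam_path, config=config)
model = AutoDenseModel.from_pretrained(dam_path, config=config)

# Attach the REM adapter to the DAM backbone, exactly as in the example above.
adapter_name = model.load_adapter(rem_path)
model.set_active_adapters(adapter_name)
model.merge_lora(adapter_name)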

We give two adaptation examples. Each trains a separate DAM in the target domain and re-uses our released REMs.

Please try these examples before using our methods on your own datasets.

Few-shot Domain Adaptation

Coming soon.

Learning Generic Relevance Estimation Ability

We have already released a number of Relevance Estimation Modules (REMs) for various ranking methods, and you can directly adopt these public checkpoints. But if you have private labeled data and want to train a Relevance Estimation Module (REM) on it, here we provide instructions on how to do so.

To directly use this codebase for training, you need to convert your data to the following format:

  • corpus.tsv: corpus file. Each line is `docid doc`, separated by a tab.
  • query.train: training queries. Each line is `qid query`, separated by a tab.
  • qrels.train: annotations. Each line is `qid 0 docid rel_level`, separated by tabs.
  • [Optional] hard-negative file for contrastive training: each line is `qid neg_docid1 neg_docid2 ...`. qid and neg_docids are separated by a tab; neg_docids are separated by spaces.
  • [Optional] soft labels for knowledge distillation: a pickle file containing a dict {qid: {docid: score}}. It should contain the soft labels of positive pairs and of several negative pairs.

If you still have questions about the data formatting, you can check how we convert MS MARCO.
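
For reference, here is a minimal sketch of how the two optional files could be produced; the qids, docids, scores, and file names below are made-up examples used only to illustrate the formats described above:

import pickle

# Hypothetical example ids and scores, only to illustrate the file formats.
hard_negatives = {"q1": ["d7", "d12", "d30"], "q2": ["d4", "d9"]}
soft_labels = {"q1": {"d3": 12.7, "d7": 4.1, "d12": 3.8}, "q2": {"d5": 11.2, "d4": 2.9}}

# Hard-negative file: qid <tab> neg_docid1 neg_docid2 ... (docids separated by spaces).
with open("hard_negatives.txt", "w", encoding="utf-8") as f:
    for qid, neg_docids in hard_negatives.items():
        f.write(f"{qid}\t{' '.join(neg_docids)}\n")

# Soft-label file: a pickle of {qid: {docid: score}}.
with open("soft_labels.pkl", "wb") as f:
    pickle.dump(soft_labels, f)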

With formatted supervised data, you can now train a REM module. We use a disentangled finetuning trick: first train a DAM module to capture domain-specific features, then train the REM module to learn domain-invariant matching patterns.

Here we provide instructions about training REMs for different ranking methods.

  • Train REM for Dense Retrieval: on English MS MARCO | on Chinese Dureader
  • Train REM for uniCOIL: [on English] [on Chinese] (coming soon)
  • Train REM for SPLADE: [on English] [on Chinese] (coming soon)
  • Train REM for ColBERT: [on English] [on Chinese] (coming soon)
  • Train REM for BERT re-ranker: [on English] [on Chinese] (coming soon)

Reproducing Results with Released Checkpoints

We provide commands for reproducing the various results in our paper.

Training Vanilla Neural Ranking Models

This codebase supports not only Disentangled Neural Ranking but also vanilla neural ranking models. You can easily reproduce state-of-the-art Dense Retrieval, uniCOIL, SPLADE, ColBERT, and BERT re-ranker results with it. The instructions are provided below.

Citation

If you find our work useful, please consider citing us :)

@article{zhan2022disentangled,
  title={Disentangled Modeling of Domain and Relevance for Adaptable Dense Retrieval},
  author={Zhan, Jingtao and Ai, Qingyao and Liu, Yiqun and Mao, Jiaxin and Xie, Xiaohui and Zhang, Min and Ma, Shaoping},
  journal={arXiv preprint arXiv:2208.05753},
  year={2022}
}


disentangled-retriever's Issues

run_contrast.py: AttributeError: 'BertModel' object has no attribute 'add_adapter'

Hi,

I was trying to train my own REM by following the instructions.

output_dir="./data/dense-mlm/english-marco/train_rem/rem-with-hf-dam/contrast"

python -m torch.distributed.launch --nproc_per_node 4 \ 
    -m disentangled_retriever.dense.finetune.run_contrast \
    --lora_rank 192 --parallel_reduction_factor 4 --new_adapter_name msmarco \
    --pooling average \
    --similarity_metric ip \
    --qrel_path ./data/datasets/msmarco-passage/qrels.train \
    --query_path ./data/datasets/msmarco-passage/query.train \
    --corpus_path ./data/datasets/msmarco-passage/corpus.tsv \
    --negative ./data/datasets/msmarco-passage/msmarco-hard-negatives.tsv \
    --output_dir $output_dir \
    --model_name_or_path jingtao/DAM-bert_base-mlm-msmarco \
    --logging_steps 100 \
    --max_query_len 24 \
    --max_doc_len 128 \
    --per_device_train_batch_size 32 \
    --inv_temperature 1 \
    --gradient_accumulation_steps 1 \
    --fp16 \
    --neg_per_query 3 \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --dataloader_drop_last \
    --overwrite_output_dir \
    --dataloader_num_workers 0 \
    --weight_decay 0 \
    --lr_scheduler_type "constant" \
    --save_strategy "epoch" \
    --optim adamw_torch

However, I then get AttributeError: 'BertModel' object has no attribute 'add_adapter'.

    def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False, set_active: bool = False):
        """
        Adds a new adapter module of the specified type to the model.

        Args:
            adapter_name (str): The name of the adapter module to be added.
            config (str or dict, optional): The adapter configuration, can be either:

                - the string identifier of a pre-defined configuration dictionary
                - a configuration dictionary specifying the full config
                - if not given, the default configuration for this adapter type will be used
            overwrite_ok (bool, optional):
                Overwrite an adapter with the same name if it exists. By default (False), an exception is thrown.
            set_active (bool, optional):
                Set the adapter to be the active one. By default (False), the adapter is added but not activated.

        If self.base_model is self, must inherit from a class that implements this method, to preclude infinite
        recursion
        """
        if self.base_model is self:
            super().add_adapter(adapter_name, config, overwrite_ok=overwrite_ok, set_active=set_active)
        else:
            # error thrown here on the following line
            self.base_model.add_adapter(adapter_name, config, overwrite_ok=overwrite_ok, set_active=set_active) 

Error Stack

[WARNING|modeling_utils.py:3180] 2023-11-04 16:40:20,978 >> Some weights of the model checkpoint at jingtao/DAM-bert_base-mlm-msmarco were not used when initializing BertDense: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertDense from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertDense from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[WARNING|modeling_utils.py:3192] 2023-11-04 16:40:20,978 >> Some weights of BertDense were not initialized from the model checkpoint at jingtao/DAM-bert_base-mlm-msmarco and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[INFO|modeling_utils.py:2839] 2023-11-04 16:40:21,154 >> Generation config file not found, using a generation config created from the model config.
11/04/2023 16:40:21-INFO-adapter_arg- Add a lora adapter and only train the adapter
11/04/2023 16:40:21-INFO-adapter_arg- Add a parallel adapter and only train the adapter
Traceback (most recent call last):
  File "C:\Users\ymurong\Documents\Github\Domain-Adapation-French-Legal-Retrieval\scripts\disentangled-retriever\run_contrast.py", line 203, in <module>
    main()
  File "C:\Users\ymurong\Documents\Github\Domain-Adapation-French-Legal-Retrieval\scripts\disentangled-retriever\run_contrast.py", line 145, in main
    model.add_adapter(model_args.new_adapter_name, config=adapter_config)
  File "C:\Users\ymurong\Documents\Github\Domain-Adapation-French-Legal-Retrieval\venv\lib\site-packages\transformers\adapters\model_mixin.py", line 1077, in add_adapter
    self.base_model.add_adapter(adapter_name, config, overwrite_ok=overwrite_ok, set_active=set_active)
  File "C:\Users\ymurong\Documents\Github\Domain-Adapation-French-Legal-Retrieval\venv\lib\site-packages\torch\nn\modules\module.py", line 1269, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'BertModel' object has no attribute 'add_adapter'

Is there anything that I could be doing wrong?

One small modification that I made was to change the import statement, as there is no BertAdapterModel in transformers in my case. Maybe that could be the reason? I am currently using transformers 4.33.3 with adapter-transformers==3.2.1, running Python 3.10.

from transformers import BertAdapterModel

to

from transformers.adapters import BertAdapterModel

Thank you for your help!

Error when running the official example

When running the official example, the model.merge_lora step raises an error saying this attribute does not exist. Has a related package version been updated?
AttributeError: 'BertDense' object has no attribute 'merge_lora'

Design choice w.r.t. the DAM MLM training

Hi Jingtao:

I wonder what value of the max_seq_len hyperparameter you use when doing MLM on the target-domain dataset?

In the Hugging Face example they use block_size=128; I am just curious why you did not use their method directly.

Also, I am interested in the effect of leaving out [CLS] and [SEP] when preparing the MLM dataloader: does it make a big difference if you do not remove these two tokens?
