
KGI (Knowledge Graph Induction) for slot filling

This is the code for our KILT leaderboard submissions to the T-REx and zsRE slot filling tasks. It includes code for training a DPR model and then continuing its training as part of a RAG model.

The KGI model is described in: Robust Retrieval Augmented Generation for Zero-shot Slot Filling (EMNLP 2021).

Available from Hugging Face as:

| Dataset | Type | Model Name | Tokenizer Name |
|---|---|---|---|
| T-REx | DPR (ctx) | michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-trex | facebook/dpr-ctx_encoder-multiset-base |
| T-REx | RAG | michaelrglass/rag-token-nq-kgi0-trex | facebook/rag-token-nq |
| zsRE | DPR (ctx) | michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-zsre | facebook/dpr-ctx_encoder-multiset-base |
| zsRE | RAG | michaelrglass/rag-token-nq-kgi0-zsre | facebook/rag-token-nq |
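
For reference, these checkpoints can be loaded with the Hugging Face transformers library. The sketch below loads the T-REx pair; generation additionally needs a RagRetriever over the encoded passage index built in the steps that follow.

# Minimal sketch: load the released T-REx checkpoints with Hugging Face transformers.
from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
                          RagTokenForGeneration, RagTokenizer)

ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_encoder = DPRContextEncoder.from_pretrained("michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-trex")

rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
rag_model = RagTokenForGeneration.from_pretrained("michaelrglass/rag-token-nq-kgi0-trex")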

Process to reproduce

Download the KILT data and knowledge source
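
At the time of writing, the KILT knowledge source and the slot filling train/dev files are hosted under http://dl.fbaipublicfiles.com/KILT/ (see the facebookresearch/KILT repository for the authoritative list). A minimal download sketch, assuming those URLs:

# Sketch: fetch the KILT knowledge source and the T-REx / zsRE train and dev files.
# URLs are assumed from the facebookresearch/KILT repository; adjust if they have moved.
import urllib.request

KILT_BASE = "http://dl.fbaipublicfiles.com/KILT/"
for name in ["kilt_knowledgesource.json",
             "trex-train-kilt.jsonl", "trex-dev-kilt.jsonl",
             "structured_zeroshot-train-kilt.jsonl", "structured_zeroshot-dev-kilt.jsonl"]:
    urllib.request.urlretrieve(KILT_BASE + name, name)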

Segment the KILT Knowledge Source into passages:

python slot_filling/kilt_passage_corpus.py \
--kilt_corpus kilt_knowledgesource.json --output_dir kilt_passages --passage_ids passage_ids.txt

Generate the first phase of the DPR training data

python dpr/dpr_kilt_slot_filling_dataset.py \
--kilt_data structured_zeroshot-train-kilt.jsonl \
--passage_ids passage_ids.txt \
--output_file zsRE_train_positive_pids.jsonl

python dpr/dpr_kilt_slot_filling_dataset.py \
--kilt_data trex-train-kilt.jsonl \
--passage_ids passage_ids.txt \
--output_file trex_train_positive_pids.jsonl

Download and build Anserini. You will need to have Maven and a Java JDK.

git clone https://github.com/castorini/anserini.git
cd anserini
# check out the 0.4.1 version that dprBM25.jar is built for
git checkout 3a60106fdc83473d147218d78ae7dca7c3b6d47c
export JAVA_HOME=/path/to/your/jdk   # set to your JDK directory
mvn clean package appassembler:assemble

Put the title/text into the training instances, with hard negatives from BM25 (dataset = trex, zsRE):

python dpr/anserini_prep.py \
--input kilt_passages \
--output anserini_passages

sh anserini/target/appassembler/bin/IndexCollection -collection JsonCollection \
-generator LuceneDocumentGenerator -threads 40 -input anserini_passages \
-index anserini_passage_index -storePositions -storeDocvectors -storeRawDocs

export CLASSPATH=jar/dprBM25.jar:anserini/target/anserini-0.4.1-SNAPSHOT-fatjar.jar
java com.ibm.research.ai.pretraining.retrieval.DPRTrainingData \
-passageIndex anserini_passage_index \
-positivePidData ${dataset}_train_positive_pids.jsonl \
-trainingData ${dataset}_dpr_training_data.jsonl
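
The DPRTrainingData class looks up the title/text for each positive passage id and adds BM25 hard negatives retrieved from the Anserini index. The idea is roughly the following pyserini sketch (illustrative only, not the jar's implementation; pyserini is not needed for the pipeline, and the k / n_negatives values are arbitrary):

# Illustrative sketch of BM25 hard-negative mining (not what the jar actually does).
from pyserini.search import SimpleSearcher

searcher = SimpleSearcher('anserini_passage_index')

def bm25_hard_negatives(query, positive_pids, k=50, n_negatives=1):
    # keep the top-ranked BM25 hits that are not known positives
    hits = searcher.search(query, k=k)
    negatives = [h.docid for h in hits if h.docid not in positive_pids]
    return negatives[:n_negatives]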

Train DPR

# multi-gpu is not well supported
export CUDA_VISIBLE_DEVICES=0

python dpr/biencoder_trainer.py \
--train_dir zsRE_dpr_training_data.jsonl \
--output_dir models/DPR/zsRE \
--num_train_epochs 2 \
--num_instances 131610 \
--encoder_gpu_train_limit 32 \
--full_train_batch_size 128 \
--max_grad_norm 1.0 --learning_rate 5e-5

python dpr/biencoder_trainer.py \
--train_dir trex_dpr_training_data.jsonl \
--output_dir models/DPR/trex \
--num_train_epochs 2 \
--num_instances 2207953 \
--encoder_gpu_train_limit 32 \
--full_train_batch_size 128 \
--max_grad_norm 1.0 --learning_rate 5e-5

Put the trained DPR query encoder into the NQ RAG model (dataset = trex, zsRE)

python dpr/prepare_rag_model.py \
--save_dir models/RAG/${dataset}_dpr_rag_init  \
--qry_encoder_path models/DPR/${dataset}/qry_encoder
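
Roughly, prepare_rag_model.py initializes facebook/rag-token-nq and swaps the trained DPR query encoder in as its question encoder. A minimal sketch of that idea (illustrative only, assuming the query encoder was saved in Hugging Face format; see dpr/prepare_rag_model.py for the actual logic):

# Illustrative sketch: put the trained DPR query encoder into the NQ RAG model.
# Assumes models/DPR/trex/qry_encoder is loadable as a Hugging Face DPRQuestionEncoder.
from transformers import DPRQuestionEncoder, RagTokenForGeneration

rag = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
qry_encoder = DPRQuestionEncoder.from_pretrained("models/DPR/trex/qry_encoder")
rag.rag.question_encoder.load_state_dict(qry_encoder.state_dict())
rag.save_pretrained("models/RAG/trex_dpr_rag_init")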

Encode the passages (dataset = trex, zsRE). The two commands each encode one half of the corpus (--embed 1of2 and --embed 2of2), so they can run in parallel:

python dpr/index_simple_corpus.py \
--embed 1of2 \
--dpr_ctx_encoder_path models/DPR/${dataset}/ctx_encoder \
--corpus kilt_passages  \
--output_dir kilt_passages_${dataset}

python dpr/index_simple_corpus.py \
--embed 2of2 \
--dpr_ctx_encoder_path models/DPR/${dataset}/ctx_encoder \
--corpus kilt_passages \
--output_dir kilt_passages_${dataset}
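
Each passage's title and text are embedded with the trained DPR context encoder. A minimal sketch of what that encoding looks like (illustrative only, assuming the ctx_encoder directory is in Hugging Face format; the example title/text are made up):

# Illustrative sketch of passage encoding (not index_simple_corpus.py itself).
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
encoder = DPRContextEncoder.from_pretrained("models/DPR/trex/ctx_encoder").eval()

titles = ["Example title"]
texts = ["Example passage text from the KILT knowledge source."]
with torch.no_grad():
    inputs = tokenizer(titles, texts, truncation=True, padding=True, return_tensors="pt")
    embeddings = encoder(**inputs).pooler_output  # shape: (num_passages, 768)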

Index the passage vectors (dataset = trex, zsRE)

python dpr/faiss_index.py \
--corpus_dir kilt_passages_${dataset} \
--scalar_quantizer 8 \
--output_file kilt_passages_${dataset}/index.faiss
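
The --scalar_quantizer 8 option requests an 8-bit scalar-quantized index over the 768-dimensional DPR vectors. A minimal FAISS sketch of that index type (illustrative only; the script's exact index construction may differ, and the random vectors below stand in for the encoded passages):

# Illustrative sketch of an 8-bit scalar-quantized FAISS index (not faiss_index.py itself).
import faiss
import numpy as np

d = 768  # DPR embedding dimension
vectors = np.random.rand(10000, d).astype('float32')  # stand-in for the encoded passages

index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_INNER_PRODUCT)
index.train(vectors)  # scalar quantizers need a training pass to fit value ranges
index.add(vectors)
faiss.write_index(index, "kilt_passages_trex/index.faiss")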

Train RAG

python dataloader/file_splitter.py \
--input trex-train-kilt.jsonl \
--outdirs trex_training \
--file_counts 64

python slot_filling/rag_client_server_train.py \
  --kilt_data trex_training \
  --output models/RAG/trex_dpr_rag \
  --corpus_endpoint kilt_passages_trex \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/trex_dpr_rag_init \
  --num_instances 500000 --warmup_instances 10000  --num_train_epochs 1 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64


python slot_filling/rag_client_server_train.py \
  --kilt_data structured_zeroshot-train-kilt.jsonl \
  --output models/RAG/zsRE_dpr_rag \
  --corpus_endpoint kilt_passages_zsRE \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/zsRE_dpr_rag_init \
  --num_instances 147909  --warmup_instances 10000 --num_train_epochs 1 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64

Apply RAG (dev_file = trex-dev-kilt.jsonl, structured_zeroshot-dev-kilt.jsonl)

python slot_filling/rag_client_server_apply.py \
  --kilt_data ${dev_file} \
  --corpus_endpoint kilt_passages_${dataset} \
  --output predictions/${dataset}_dev.jsonl \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/${dataset}_dpr_rag

python eval/convert_for_kilt_eval.py \
--apply_file predictions/${dataset}_dev.jsonl \
--eval_file predictions/${dataset}_dev_kilt_format.jsonl

Run official evaluation script

# install KILT evaluation scripts
git clone https://github.com/facebookresearch/KILT.git
cd KILT
conda create -n kilt37 -y python=3.7 && conda activate kilt37
pip install -r requirements.txt
export PYTHONPATH=`pwd`

# run evaluation
python kilt/eval_downstream.py predictions/${dataset}_dev_kilt_format.jsonl ${dev_file}

Publications

Re2G (NAACL 2022)

@inproceedings{glass-etal-2022-re2g,
   title = "{R}e2{G}: Retrieve, Rerank, Generate",
   author = "Glass, Michael  and
     Rossiello, Gaetano  and
     Chowdhury, Md Faisal Mahbub  and
     Naik, Ankita  and
     Cai, Pengshan  and
     Gliozzo, Alfio",
   booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
   month = jul,
   year = "2022",
   address = "Seattle, United States",
   publisher = "Association for Computational Linguistics",
   url = "https://aclanthology.org/2022.naacl-main.194",
   pages = "2701--2715",
}

KGI (EMNLP 2021)

@inproceedings{glass-etal-2021-robust,
   title = "Robust Retrieval Augmented Generation for Zero-shot Slot Filling",
   author = "Glass, Michael  and
     Rossiello, Gaetano  and
     Chowdhury, Md Faisal Mahbub  and
     Gliozzo, Alfio",
   booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
   month = nov,
   year = "2021",
   address = "Online and Punta Cana, Dominican Republic",
   publisher = "Association for Computational Linguistics",
   url = "https://aclanthology.org/2021.emnlp-main.148",
   doi = "10.18653/v1/2021.emnlp-main.148",
   pages = "1939--1949",
}


kgi-slot-filling's Issues

Reproduce the results on KILT Wizard of Wikipedia

Hello, I got similar results on the development set after running kgi_train.py.

| R-Prec | Recall@5 | ROUGE-L | F1 | KILT-ROUGE-L | KILT-F1 |
|---|---|---|---|---|---|
| 0.502947 | 0.690242 | 0.160686 | 0.182609 | 0.095659 | 0.108410 |

But after I run reranker_train.py and rerank_apply.py, I get results that seem worse than those in Table 2 of the Re2G paper.

| | R-Prec | Recall@5 |
|---|---|---|
| My experiments | 47.38 | 72.04 |
| Re2G paper | 55.50 | 74.98 |

I think Re2G is solid work. Could you please give me some advice on how to reproduce the results?

I ran the following command to train the reranker:

python reranker/reranker_train.py \
  --model_type bert --model_name_or_path nboost/pt-bert-base-uncased-msmarco --do_lower_case \
  --positive_pids ${dataset}/train_positive_pids.jsonl \
  --initial_retrieval  predictions/dpr_bm25/wow_train.jsonl  \
  --num_train_epochs 2 \
  --output_dir models/reranker_stage1

Error when running the jar

java com.ibm.research.ai.pretraining.retrieval.DPRTrainingData \
> -passageIndex anserini_passage_index \
> -positivePidData trex_train_positive_pids.jsonl \
> -trainingData trex_dpr_training_data.jsonl

Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/lucene/search/Query
at com.ibm.research.ai.pretraining.retrieval.DPRTrainingData.main(DPRTrainingData.java:99)
Caused by: java.lang.ClassNotFoundException: org.apache.lucene.search.Query
at java.base/jdk.internal.loader.BuiltinClassLoader.loadClass(BuiltinClassLoader.java:582)
at java.base/jdk.internal.loader.ClassLoaders$AppClassLoader.loadClass(ClassLoaders.java:178)
at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:521)
... 1 more

Hi, I have prepared the passage index and the T-REx data, but when I run the jar to build the training data, I get the error above.
Are there any problems with my setup?

But when I run a search in pyserini (the Python version of Anserini), I do get results.
Can you share what the jar does so I can reproduce it in Python?
Thanks a lot!

from pyserini.search import SimpleSearcher

searcher = SimpleSearcher('anserini_passages_index')
hits = searcher.search('document')

for i in range(len(hits)):
    print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')

sub-batch in re2g model.

Hello, I am really impressed by the Re2G: Retrieve, Rerank, Generate paper. I'm grateful for this good paper.

I have some questions about the paper and code.

In the appendix of the Re2G paper, Table 4 (Re2G hyperparameters), the batch size for DPR and generation is fixed at 128. However, the batch size for the reranker is 32.

Why is the batch size different for the reranker?

Also, in the Re2G code, there is sub-batch forwarding for the generation model. With hypers.retrieve_batch_factor == 8 in re2g_hypers.py, there is an assertion like the one below:

assert ret_batch_size // get_batch_size == hypers.retrieve_batch_factor

which means the batch sizes for retrieval (DPR) and generation are not the same.

Therefore, is there a specific reason for using sub-batches for generation?

Thank you for reading my question.

Could not generate training data

Hi @gaetangate
I'm at the final step, generating the .jsonl training files.
When I run this command:
export CLASSPATH=jar/dprBM25.jar:./anserini/target/anserini-0.4.1-SNAPSHOT-fatjar.jar
java com.ibm.research.ai.pretraining.retrieval.DPRTrainingData -passageIndex anserini_passage_index -positivePidData zsRE_train_positive_pids.jsonl -trainingData zsRE_dpr_training_data.jsonl

I face this error:
Error: Could not find or load main class com.ibm.research.ai.pretraining.retrieval.DPRTrainingData Caused by: java.lang.ClassNotFoundException: com.ibm.research.ai.pretraining.retrieval.DPRTrainingData
Do you have any idea about com.ibm.research.ai.pretraining.retrieval.DPRTrainingData?

Reproduce the results on the TriviaQA dataset

When reproducing Re2G on the TriviaQA dataset, I couldn't reproduce the results of the generation model in the second stage. In the second stage, the generation model only uses the passages retrieved by the trained DPR, very similar to the KGI paper.
I used the provided command for training:

python generation/kgi_train.py \
  --kilt_data ${dataset}_training \
  --output models/RAG/${dataset}_dpr_rag \
  --corpus_endpoint kilt_passages_${dataset} \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/${dataset}_dpr_rag_init \
  --warmup_fraction 0.05  --num_train_epochs 2 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64

I got the following performance, compared with the KGI_0 results copied from the paper:

| | R-Prec | Recall@5 | Accuracy | F1 | KILT-AC | KILT-F1 |
|---|---|---|---|---|---|---|
| My reproduction | 57.82 | 62.13 | 37.51 | 57.35 | 26.14 | 38.52 |
| KGI_0 (paper) | 60.49 | 63.54 | 60.99 | 66.55 | 42.85 | 46.08 |

The retrieval metrics (R-Prec, Recall@5) are close to the KGI model's, but the generation metrics (Accuracy, F1, KILT-AC, KILT-F1) are far worse.

Package import error

nohup: ignoring input
Traceback (most recent call last):
File "dpr/biencoder_trainer.py", line 1, in
from torch_util.transformer_optimize import TransformerOptimize
ModuleNotFoundError: No module named 'torch_util'

Maven build failure when building Anserini

Downloaded from central: https://repo.maven.apache.org/maven2/org/codehaus/plexus/plexus-utils/3.0.24/plexus-utils-3.0.24.jar (247 kB at 146 kB/s)
Downloaded from central: https://repo.maven.apache.org/maven2/org/codehaus/plexus/plexus-interactivity-api/1.0-alpha-6/plexus-interactivity-api-1.0-alpha-6.jar (12 kB at 6.9 kB/s)
[INFO] ------------------------------------------------------------------------
[INFO] BUILD FAILURE
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 08:41 min
[INFO] Finished at: 2023-12-21T19:21:11+08:00
[INFO] Final Memory: 60M/388M
[INFO] ------------------------------------------------------------------------
[ERROR] Failed to execute goal org.apache.maven.plugins:maven-javadoc-plugin:3.0.1:jar (attach-javadocs) on project anserini: MavenReportException: Error while generating Javadoc: Unable to find javadoc command: The environment variable JAVA_HOME is not correctly set. -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
[ERROR] Re-run Maven using the -X switch to enable full debug logging.
[ERROR]
[ERROR] For more information about the errors and possible solutions, please read the following articles:
[ERROR] [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/MojoExecutionException
