dexml's Introduction

DEXML

Codebase for learning dual-encoder models for (extreme) multi-label classification tasks.

Dual-encoders for Extreme Multi-label Classification
Nilesh Gupta, Devvrit Khatri, Ankit S. Rawat, Srinadh Bhojanapalli, Prateek Jain, Inderjit S. Dhillon
ICLR 2024

Highlights

  • Multi-label retrieval losses DecoupledSoftmax and SoftTopk (replacement for InfoNCE (Softmax) loss in multi-label and top-k retrieval settings)
  • Distributed dual-encoder training using gradient caching (allows for a large pool of labels in loss computation without getting OOM)
  • State-of-the-art dual-encoder models for extreme multi-label classification benchmarks
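
To make the first highlight concrete, here is a minimal sketch of how a decoupled softmax could differ from the plain multi-label softmax (InfoNCE) loss for a single query. It follows one reading of the paper, in which each positive label is normalized only against itself and the negatives, and it is not taken from this repository's implementation. The function names, the absence of a temperature, and the toy scores are all assumptions made for illustration.

```python
import math


def _logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a non-empty list."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def softmax_loss(scores, positives):
    """Multi-label softmax: every positive is normalized against ALL labels,
    so positives compete with each other for probability mass."""
    log_z = _logsumexp(scores)
    return sum(log_z - scores[p] for p in positives)


def decoupled_softmax_loss(scores, positives):
    """Assumed decoupled variant: each positive is normalized against itself
    plus the negatives only; other positives are left out of the denominator."""
    pos = set(positives)
    negatives = [s for j, s in enumerate(scores) if j not in pos]
    return sum(_logsumexp([scores[p]] + negatives) - scores[p] for p in pos)


# Toy query with two well-scored positives (labels 0 and 1) and two negatives.
scores = [5.0, 5.0, 0.0, 0.0]
coupled = softmax_loss(scores, [0, 1])
decoupled = decoupled_softmax_loss(scores, [0, 1])
print(f"softmax: {coupled:.4f}  decoupled: {decoupled:.4f}")
```

With a single positive the two losses coincide. With several positives the plain softmax stays bounded away from zero even when all positives are ranked above the negatives, which is the motivation usually given for decoupling.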

Notebook Demo

See the dexml.ipynb notebook, or try it in Colab.

Download pretrained models

Dataset                P@1    P@5    HF Model Page
LF-AmazonTitles-1.3M   58.40  45.46  https://huggingface.co/quicktensor/dexml_lf-amazontitles-1.3m
LF-Wikipedia-500K      85.78  50.53  https://huggingface.co/quicktensor/dexml_lf-amazontitles-131k
LF-AmazonTitles-131K   42.52  20.64  https://huggingface.co/quicktensor/dexml_lf-amazontitles-131k
EURLex-4K              86.78  60.19  https://huggingface.co/quicktensor/dexml_eurlex-4k

Training DEXML

Preparing Data

The codebase assumes the following data structure:

Datasets/
└── EURLex-4K # Dataset name
    ├── raw
    │   ├── trn_X.txt # train input file, ith line is the text input for ith train data point
    │   ├── tst_X.txt # test input file, ith line is the text input for ith test data point
    │   └── Y.txt # label input file, ith line is the text input for ith label in the dataset
    ├── Y.trn.npz # train relevance matrix (stored in scipy sparse npz format), num_train x num_labels
    └── Y.tst.npz # test relevance matrix (stored in scipy sparse npz format), num_test x num_labels
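
As a rough illustration of this layout, the sketch below writes a tiny dataset in the same shape using scipy.sparse.save_npz. The dataset name ToyDataset, the texts, and the relevance entries are all made up, and the assumption that the training code reads these files with scipy.sparse.load_npz is inferred from the "scipy sparse npz format" comment above rather than checked against the code.

```python
import tempfile
from pathlib import Path

import numpy as np
import scipy.sparse as sp


def write_toy_dataset(root, name="ToyDataset"):
    """Write a tiny dataset in the directory layout described above."""
    base = Path(root) / name
    raw = base / "raw"
    raw.mkdir(parents=True, exist_ok=True)

    train_texts = ["first train document", "second train document", "third train document"]
    test_texts = ["first test document", "second test document"]
    label_texts = ["label a", "label b", "label c", "label d"]

    # One text per line: the ith line belongs to the ith data point / label.
    (raw / "trn_X.txt").write_text("\n".join(train_texts) + "\n", encoding="utf-8")
    (raw / "tst_X.txt").write_text("\n".join(test_texts) + "\n", encoding="utf-8")
    (raw / "Y.txt").write_text("\n".join(label_texts) + "\n", encoding="utf-8")

    # Relevance matrices: rows are data points, columns are labels.
    trn_rows, trn_cols = [0, 0, 1, 2], [0, 2, 1, 3]
    y_trn = sp.csr_matrix(
        (np.ones(len(trn_rows), dtype=np.float32), (trn_rows, trn_cols)),
        shape=(len(train_texts), len(label_texts)),
    )
    tst_rows, tst_cols = [0, 1], [1, 3]
    y_tst = sp.csr_matrix(
        (np.ones(len(tst_rows), dtype=np.float32), (tst_rows, tst_cols)),
        shape=(len(test_texts), len(label_texts)),
    )
    sp.save_npz(str(base / "Y.trn.npz"), y_trn)
    sp.save_npz(str(base / "Y.tst.npz"), y_tst)
    return base


with tempfile.TemporaryDirectory() as tmp:
    base = write_toy_dataset(tmp)
    trn_shape = sp.load_npz(str(base / "Y.trn.npz")).shape
    tst_shape = sp.load_npz(str(base / "Y.tst.npz")).shape
    raw_files = sorted(p.name for p in (base / "raw").iterdir())
    print(trn_shape, tst_shape, raw_files)
```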

Before training or testing, the default code expects you to convert the input text to tokenized input indices for BERT (or any other text transformer). You can do that by running:

dataset="EURLex-4K"
python utils/tokenization_utils.py --data-path Datasets/${dataset}/raw/Y.txt --tf-max-len 128 --tf-token-type bert-base-uncased
python utils/tokenization_utils.py --data-path Datasets/${dataset}/raw/trn_X.txt --tf-max-len 128 --tf-token-type bert-base-uncased
python utils/tokenization_utils.py --data-path Datasets/${dataset}/raw/tst_X.txt --tf-max-len 128 --tf-token-type bert-base-uncased

For some extreme classification benchmark datasets, such as LF-AmazonTitles-131K and LF-AmazonTitles-1.3M, you additionally need the test-time label filter file (Datasets/${dataset}/filter_labels_test.txt) to get the right results. Please see the note on these filter files for more details.

Training commands

The training code expects all hyperparameters and runtime arguments to be specified in a YAML config file. See configs/dual_encoder.yaml for a brief description of every parameter (most of them can stay the same across experiments), and configs/EURLex-4K/dist-de-all_decoupled-softmax.yaml for the important hyperparameters that you may want to change between experiments.

# Single GPU
dataset="EURLex-4K"
python train.py configs/${dataset}/dist-de-all_decoupled-softmax.yaml

# Multi GPU
num_gpus=4
accelerate launch --config_file configs/accelerate.yaml --num_processes ${num_gpus} train.py configs/${dataset}/dist-de-all_decoupled-softmax.yaml

Cite

@InProceedings{DEXML,
  author    = "Gupta, N. and Khatri, D. and Rawat, A-S. and Bhojanapalli, S. and Jain, P. and Dhillon, I.",
  title     = "Dual-encoders for Extreme Multi-label Classification",
  booktitle = "International Conference on Learning Representations",
  month     = "May",
  year      = "2024"
}

dexml's People

Contributors

nilesh2797, thekop69, anirudhb11

dexml's Issues

Reproduce Table 1 of ICLR24 Paper

Hi @nilesh2797 ,

Thanks for the great work on DEXML; the results are very promising.

Could you please provide more details on how to reproduce DEXML for the datasets shown in Table 1 of your ICLR paper?
For example, the input data format, the exact hyper-parameters, the command-line usage, etc.

Much appreciated!

Support for MPS

Hi @nilesh2797 - As you might have noticed from my recent PR, I've been playing with the notebook in this repo to better understand DE in XMC. I've been running it on my MacBook M1 Pro and have had to make some changes to support MPS. If you are interested, I'd be happy to contribute these changes back to your repo via another PR. The changes themselves are tiny, but a key one is that MPS does not support float64, which is the dtype of the label embeddings. This doesn't affect CUDA, though I guess some floating-point precision differences will be noticed when comparing a run on Metal against an NVIDIA run.

Clarifications on the use of the trained model

Hi @nilesh2797 ,

Thanks for the great work on the DEXML model!

I have a few questions about using the code. I have trained the model on my dataset. From the logs during training, the model looks promising.

Q1: A folder with my model name appears in Results. There is a file val_metrics.tsv. Is that where the metrics from training the model on the train dataset are stored? Or was the test dataset used for validation?

Q2: The second question comes from the first one - how can I correctly use the model.pt file from Results? I tried to get results using the "Run DEXML" block example from dexml.ipynb, but nothing worked.

Thanks in advance!

Questions about training hyper-parameters of `DEXML`

@nilesh2797 ,

Thanks for your prompt update to README.md and for providing more detailed instructions on how to run the DEXML codebase. Very much appreciated! I am now able to start the training job on LF-AmazonTitles-131K.

Some follow-up questions regarding the hyper-parameters of DEXML.

Q1: The definition of batch_size in Table 7: does batch_size=1024 refer to batch_size_per_gpu, or to the global_batch_size (i.e., batch_size_per_gpu times the number of GPUs)?

  • Take LF-AmazonTitles-131K as an example. From Table 9, you use 4 A100 GPUs, and from Table 7, the batch_size=1024. Does that mean the global_batch_size=1024x4=4096 ?

Q2: When training DEXML on LF-AmazonTitles-131K, I saw some evaluation metrics in the logs (see below). Are those metrics BEFORE or AFTER the "Reciprocal-pair Removal"?

INFO - root - 04-Apr-24 00:25:58 : Mean loss after epoch 0/100: 7.8386E+00
INFO - root - 04-Apr-24 00:26:19 :
P@1     P@3     P@5     nDCG@1  nDCG@3  nDCG@5  MRR@10  PSP@1   PSP@3   PSP@5   R@10    R@50    R@100   loss
15.93   14.35   11.04   15.93   20.73   22.83   25.69   15.15   23.04   27.33   33.77   44.56   48.68   7.8386E+00

INFO - root - 04-Apr-24 01:02:34 : Mean loss after epoch 1/100: 6.1878E+00
INFO - root - 04-Apr-24 01:38:50 : Mean loss after epoch 2/100: 5.5933E+00
INFO - root - 04-Apr-24 02:15:06 : Mean loss after epoch 3/100: 5.1147E+00
INFO - root - 04-Apr-24 02:51:22 : Mean loss after epoch 4/100: 4.7053E+00
INFO - root - 04-Apr-24 03:28:01 : Mean loss after epoch 5/100: 4.3784E+00
INFO - root - 04-Apr-24 03:28:24 :
P@1     P@3     P@5     nDCG@1  nDCG@3  nDCG@5  MRR@10  PSP@1   PSP@3   PSP@5   R@10    R@50    R@100   loss
21.73   19.36   14.92   21.73   28.18   31.06   33.44   18.57   28.83   34.56   45.27   57.12   61.36   4.3784E+00
...
INFO - root - 04-Apr-24 13:10:47 : Mean loss after epoch 21/100: 2.0527E+00
INFO - root - 04-Apr-24 13:47:04 : Mean loss after epoch 22/100: 1.9880E+00
INFO - root - 04-Apr-24 14:23:22 : Mean loss after epoch 23/100: 1.9112E+00
INFO - root - 04-Apr-24 14:59:41 : Mean loss after epoch 24/100: 1.8519E+00
INFO - root - 04-Apr-24 15:36:23 : Mean loss after epoch 25/100: 1.7886E+00
INFO - root - 04-Apr-24 15:36:46 :
P@1     P@3     P@5     nDCG@1  nDCG@3  nDCG@5  MRR@10  PSP@1   PSP@3   PSP@5   R@10    R@50    R@100   loss
24.84   22.81   17.57   24.84   32.75   36.03   37.73   20.22   33.16   39.92   51.48   62.55   66.34   1.7886E+00

Q3: Are the training times reported in Table 9 measured for the full 100 epochs, or did you stop early (e.g., pick the best model at epoch=50 and report the running time for just those 50 epochs instead of the full 100)?

Q4: Does Datasets/<dataset>/filter_labels_train.txt or Datasets/<dataset>/filter_labels_test.txt affect the training dynamics? In other words, does DEXML training end up with the same model parameters with and without those files, so that only the evaluation metrics differ?
