
MDFE2 - Multi-Domain Feature Extraction v2

This repository is based on the MDFE repository and additionally includes an unofficial PyTorch implementation of the UnED dataset and benchmark.

I. Setup

Here, we provide a step-by-step guide to set up the environment and install the dependencies on a UNIX-based system, such as Ubuntu, using conda as the package manager. If conda is not available, alternatives such as venv can be used to create the virtual environment.

1. Create a virtual environment

conda create -n env_mdfe2 python=3.8
conda activate env_mdfe2

2. Clone the repository

git clone git@github.com:morrisfl/mdfe2.git

3. Install pytorch

Depending on your system and compute requirements, you may need to change the command below. See pytorch.org for more details.

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
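To confirm that the installed PyTorch build can see a GPU, a short Python check such as the following can be run inside the activated environment. This is a minimal sketch; `torch_status` is a hypothetical helper, not part of this repository, and it only uses the standard `torch.__version__`, `torch.version.cuda` and `torch.cuda.is_available()` attributes.

```python
import importlib.util


def torch_status() -> dict:
    """Report whether PyTorch is importable and whether CUDA is usable."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch  # imported lazily so the check itself never fails

    return {
        "installed": True,
        "version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_version": torch.version.cuda,  # None for CPU-only builds
    }


if __name__ == "__main__":
    print(torch_status())
```

If `cuda_available` is False on a GPU machine, the installed build most likely does not match the local CUDA setup, and the install command above needs adjusting.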

4. Install the repository with all dependencies

cd mdfe2
python -m pip install .

If you want to make changes to the code, you can install the repository in editable mode:

python -m pip install -e .

II. Data Preparation

This repository includes the following datasets for training, evaluation, and testing:

| Dataset | Training | Evaluation | Testing |
|---|---|---|---|
| M4D-35k | X | - | - |
| UnED | X | X | X |

1. M4D-35k

The M4D-35k dataset is a custom-curated multi-domain training dataset created for resource-efficient training of universal image embeddings. The curation process involved dataset selection and data sampling (to optimize the data size), both chosen to maximize performance on the evaluation dataset of the Google Universal Image Embedding Challenge. M4D-35k consists of 35k classes and 328k images sourced from four different datasets:

| Domain | Dataset | # classes | # images |
|---|---|---|---|
| Packaged goods | Products-10k | 9.5k | 141.5k |
| Landmarks | Google Landmarks v2 (subset) | 10.0k | 79.2k |
| Apparel & Accessories | DeepFashion (Consumer to Shop) | 14.3k | 100.4k |
| Cars | Stanford Cars (refined) | 1.0k | 7.3k |
| Multi-Domain | M4D-35k | 34.8k | 328.4k |
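The per-source counts in the table add up to the multi-domain totals, which a few lines of Python can confirm. The values are copied from the table (in thousands, already rounded), so this only checks the table's own arithmetic and the share of images each domain contributes.

```python
# Per-source statistics from the table above: (classes, images), in thousands.
SOURCES = {
    "Products-10k": (9.5, 141.5),
    "Google Landmarks v2 (subset)": (10.0, 79.2),
    "DeepFashion (Consumer to Shop)": (14.3, 100.4),
    "Stanford Cars (refined)": (1.0, 7.3),
}

# Round to one decimal to undo floating-point summation noise.
total_classes = round(sum(c for c, _ in SOURCES.values()), 1)
total_images = round(sum(i for _, i in SOURCES.values()), 1)
print(f"M4D-35k: {total_classes}k classes, {total_images}k images")

# Share of images contributed by each source dataset.
for name, (_, images) in SOURCES.items():
    print(f"{name}: {images / total_images:.1%} of images")
```

Packaged goods and apparel together account for roughly three quarters of the images, while the refined cars subset contributes only a small fraction.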

Notably, the Stanford Cars dataset was refined by increasing its class granularity: instead of classifying cars only by their model, the class labels were extended to include the car color. More information about the refinement process can be found here.
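The effect of the refinement on class granularity can be illustrated by treating each (model, color) pair as its own class. The sketch below is hypothetical: `refine_labels` and the model and color names are placeholders, and it does not reproduce the actual refinement procedure, only how combining the two attributes splits a model-level class into several finer classes.

```python
from typing import Dict, List, Tuple

Key = Tuple[str, str]  # (car model, color)


def refine_labels(samples: List[Key]) -> Tuple[List[int], Dict[Key, int]]:
    """Assign one class id per distinct (model, color) pair, in first-seen order."""
    label_of: Dict[Key, int] = {}
    labels: List[int] = []
    for key in samples:
        if key not in label_of:
            label_of[key] = len(label_of)  # next unused class id
        labels.append(label_of[key])
    return labels, label_of


# Placeholder model/color names, not actual Stanford Cars annotations.
samples = [("model_A", "red"), ("model_A", "black"), ("model_A", "red"), ("model_B", "black")]
labels, mapping = refine_labels(samples)
print(labels)        # [0, 1, 0, 2]
print(len(mapping))  # 3 refined classes (only 2 car models)
```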

The annotations of the M4D-35k dataset can be found in data/m4d-35k_train.csv. Download the source datasets that make up M4D-35k and place them, together with the annotation file, in a <data_dir> of your choice. The directory structure should look like this:

<data_dir>/
├── m4d-35k_train.csv
├── products-10k/
│   └── train
├── google_landmark_recognition_2021/
│   └── train
├── deepfashion/
│   └── train
└── stanford_cars/
    └── cars_train
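Before starting a training run, it can help to verify that <data_dir> matches the tree above. The following is a minimal sketch using only the standard library; `missing_entries` is a hypothetical helper, and the expected paths are taken directly from the directory structure shown, so they need updating if your layout differs.

```python
import sys
from pathlib import Path
from typing import List

# Expected entries below <data_dir>, taken from the directory tree above.
EXPECTED = [
    "m4d-35k_train.csv",
    "products-10k/train",
    "google_landmark_recognition_2021/train",
    "deepfashion/train",
    "stanford_cars/cars_train",
]


def missing_entries(data_dir: str) -> List[str]:
    """Return the expected paths that do not exist under data_dir."""
    root = Path(data_dir)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python check_data_dir.py <data_dir>
    missing = missing_entries(sys.argv[1])
    if missing:
        print("Missing:", *missing, sep="\n  ")
        sys.exit(1)
    print("Directory layout looks complete.")
```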

2. UnED

Please refer to the official repository of the UnED dataset for information on how to download the dataset.

III. Image embedding model

The image embedding model consists of a visual-semantic foundation model as backbone with an attached projection layer, as shown in the figure below. The embeddings are learned using a margin-based softmax loss function.

[Figure: Image embedding model]
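The data flow from backbone features through the neck to the final embedding can be sketched in plain Python. A random vector stands in for the foundation-model output, and the two functions illustrate the proj_layer and pooling options of MODEL.NECK.type. Both are simplified assumptions (a bias-free linear map and non-overlapping average pooling), not the repository's modules, and the 768/64 dimensions are arbitrary example values. The final L2 normalization reflects that margin-based softmax heads compare embeddings by cosine similarity.

```python
import math
import random
from typing import List

Vector = List[float]


def proj_layer(features: Vector, weight: List[Vector]) -> Vector:
    """Linear projection: one output value per row of `weight` (embedding_dim x feat_dim)."""
    return [sum(w * f for w, f in zip(row, features)) for row in weight]


def pooling(features: Vector, embedding_dim: int) -> Vector:
    """Average consecutive, equally sized groups of features down to embedding_dim values.

    Assumes len(features) is divisible by embedding_dim.
    """
    group = len(features) // embedding_dim
    return [sum(features[i * group:(i + 1) * group]) / group for i in range(embedding_dim)]


def l2_normalize(v: Vector) -> Vector:
    """Scale v to unit length so that dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


# Stand-in for a backbone output: a random 768-d feature vector.
random.seed(0)
feat_dim, embedding_dim = 768, 64
features = [random.gauss(0.0, 1.0) for _ in range(feat_dim)]
# Randomly initialized projection weights (learned during training in practice).
weight = [[random.gauss(0.0, feat_dim ** -0.5) for _ in range(feat_dim)]
          for _ in range(embedding_dim)]

emb_proj = l2_normalize(proj_layer(features, weight))
emb_pool = l2_normalize(pooling(features, embedding_dim))
print(len(emb_proj), len(emb_pool))  # 64 64
```

The projection layer adds learnable parameters, whereas pooling reduces the dimension without any, which is the main trade-off between the two neck types.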

1. Foundation model

Different foundation models can be used, as shown in the table below.

| Foundation model | Encoder architecture | type | model_name | weights |
|---|---|---|---|---|
| OpenCLIP | ViT | clip | see OpenCLIP | see OpenCLIP |
| OpenCLIP | ConvNeXt | clip_convnext | see OpenCLIP | see OpenCLIP |
| CLIPA | ViT | clipav2 | see OpenCLIP | see OpenCLIP |
| EVA-CLIP | ViT | eva02 | see timm | - |
| MetaCLIP | ViT | meta-clip | see OpenCLIP | see OpenCLIP |
| SigLIP | ViT | siglip | see timm | - |
| DINOv2 | ViT | dinov2 | see timm | - |
| SAM | ViT | sam | see timm | - |

To adjust the architecture of the image embedding model, the following main parameters can be changed in the configuration file:

  • MODEL.embedding_dim: the dimension of the image embedding.
  • MODEL.BACKBONE.type: the type of the visual-semantic foundation model, supported types are those listed in the table above.
  • MODEL.BACKBONE.model_name: the name of the visual-semantic foundation model, specified by OpenCLIP or timm.
  • MODEL.BACKBONE.weights: the weights of the visual-semantic foundation model, only required for OpenCLIP models (corresponds to the pretrained parameter in OpenCLIP).
  • MODEL.NECK.type: the method used to reduce the backbone output to the specified MODEL.embedding_dim; supported types are proj_layer and pooling.
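These keys form a nested MODEL section. The sketch below mirrors them in a plain Python dict and checks them against the supported options from the foundation model table. The dict layout and `validate_model_cfg` are hypothetical, since the actual configuration object is defined in default_cfg.py and may differ. "ViT-B-16" with "openai" is a standard OpenCLIP model name and pretrained tag, used only as example values, and embedding_dim = 64 is an arbitrary example.

```python
# Supported backbone types and the library providing model_name (from the table above).
BACKBONE_SOURCE = {
    "clip": "OpenCLIP",
    "clip_convnext": "OpenCLIP",
    "clipav2": "OpenCLIP",
    "meta-clip": "OpenCLIP",
    "eva02": "timm",
    "siglip": "timm",
    "dinov2": "timm",
    "sam": "timm",
}
NECK_TYPES = {"proj_layer", "pooling"}


def validate_model_cfg(cfg: dict) -> list:
    """Return a list of problems found in a MODEL config section (empty if valid)."""
    errors = []
    backbone = cfg.get("BACKBONE", {})
    source = BACKBONE_SOURCE.get(backbone.get("type"))
    if source is None:
        errors.append(f"unsupported BACKBONE.type: {backbone.get('type')!r}")
    elif source == "OpenCLIP" and not backbone.get("weights"):
        # weights corresponds to OpenCLIP's `pretrained` argument
        errors.append("BACKBONE.weights is required for OpenCLIP models")
    neck_type = cfg.get("NECK", {}).get("type")
    if neck_type not in NECK_TYPES:
        errors.append(f"unsupported NECK.type: {neck_type!r}")
    if not isinstance(cfg.get("embedding_dim"), int) or cfg["embedding_dim"] <= 0:
        errors.append("embedding_dim must be a positive integer")
    return errors


# Illustrative MODEL section; the values are examples, not recommended settings.
model_cfg = {
    "embedding_dim": 64,
    "BACKBONE": {"type": "clip", "model_name": "ViT-B-16", "weights": "openai"},
    "NECK": {"type": "proj_layer"},
}
print(validate_model_cfg(model_cfg))  # []
```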

2. Margin-based softmax loss

The margin-based softmax loss function enhances the discriminative power of the learned embeddings. The supported loss functions are shown in the figure above. The following main parameters can be changed in the configuration file:

  • MODEL.HEAD.name: the name of the margin-based softmax loss, supported names are ArcFace, DynM-ArcFace, AdaCos, LiArcFace, CurricularFace, and AdaFace.
  • MODEL.HEAD.k: the number of (sub-)centers per class for the margin-based softmax loss.
  • MODEL.HEAD.s: the scaling factor for the margin-based softmax loss.
  • MODEL.HEAD.m: the margin for the margin-based softmax loss.

For the Sub-Center ArcFace loss, use ArcFace and set the MODEL.HEAD.k parameter to the desired number of sub-centers.
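For reference, the standard (Sub-Center) ArcFace computation for a single embedding can be written in plain Python: the class similarity is the cosine to the best-matching of the class's k sub-centers, the margin m is added to the target angle, and all logits are scaled by s before cross-entropy. This is a sketch of the textbook formulation, not the repository's implementation; `subcenter_arcface_loss` is a hypothetical name, the defaults s = 30 and m = 0.3 are illustrative rather than the repository's defaults, and the usual safeguard for θ + m > π is omitted.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def subcenter_arcface_loss(x: Vector, centers: List[List[Vector]], target: int,
                           s: float = 30.0, m: float = 0.3) -> float:
    """Cross-entropy over sub-center ArcFace logits for one embedding.

    centers[j] holds the k sub-center vectors of class j (k = MODEL.HEAD.k);
    k = 1 reduces to plain ArcFace.
    """
    # Class similarity = similarity to the closest sub-center of that class.
    cos = [max(cosine(x, c) for c in class_centers) for class_centers in centers]
    # Additive angular margin on the target class only: cos(theta_y + m).
    theta = math.acos(max(-1.0, min(1.0, cos[target])))
    logits = [s * c for c in cos]
    logits[target] = s * math.cos(theta + m)
    # Numerically stable log-sum-exp for the softmax cross-entropy.
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum - logits[target]
```

With k > 1, a sample only needs to be close to one sub-center of its class, which makes the loss more tolerant of noisy or visually diverse classes.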

The CLAM loss (class distribution aware additive angular margin loss) is a custom-developed loss function. It incorporates multiple sub-centers per class and a dynamic margin that adjusts to the class distribution of the training data. To use CLAM, select DynM-ArcFace and set the MODEL.HEAD.k parameter to the desired number of sub-centers (k > 1). The dynamic margin additionally requires the class distribution, which is specified via DATASET.cls_dist_file in the configuration file. For the UnED training set, this distribution is provided in data/.
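The exact CLAM margin schedule is not spelled out here. As an assumption, the sketch below uses the power-law form known from dynamic-margin ArcFace in Google Landmark Recognition competition solutions: a class with n training images receives a score n^(-λ), which is then rescaled so that the most frequent class gets m_min and the rarest gets m_max. `dynamic_margins` is a hypothetical helper, and m_min, m_max and λ are illustrative values. The per-class counts are the kind of information a file referenced by DATASET.cls_dist_file would provide; its actual format is not described in this README.

```python
from typing import Dict


def dynamic_margins(class_counts: Dict[int, int], m_min: float = 0.2,
                    m_max: float = 0.5, lam: float = 0.25) -> Dict[int, float]:
    """Assign larger angular margins to rarer classes (power-law assumption)."""
    scores = {c: n ** (-lam) for c, n in class_counts.items()}
    lo, hi = min(scores.values()), max(scores.values())
    span = (hi - lo) or 1.0  # all classes equally frequent -> every class gets m_min
    return {c: m_min + (m_max - m_min) * (v - lo) / span for c, v in scores.items()}


counts = {0: 1000, 1: 100, 2: 10}  # hypothetical number of images per class
for cls, margin in dynamic_margins(counts).items():
    print(cls, round(margin, 3))
```

Rare classes receive a larger margin, which pushes their embeddings further apart and partly compensates for having fewer training examples.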

Further explanations of the configurable parameters can be found in default_cfg.py.

IV. Training

1. Training settings

The training settings can be changed in the configuration file found in configs/. The most important parameters are:

  • TRAIN.epoch_based: if True, the training is based on the number of epochs, otherwise on the number of iterations.
  • TRAIN.epochs: the number of epochs to train the model.
  • TRAIN.save_epoch: the frequency of saving the model checkpoint.
  • OPTIMIZER.name: the optimizer used for training, supported optimizers are Adam, AdamW and SGD.
  • OPTIMIZER.lr: the learning rate of the optimizer.
  • OPTIMIZER.weight_decay: the weight decay of the optimizer.
  • SCHEDULER.name: the learning rate scheduler used for training; currently, only cosine is supported.
  • SCHEDULER.epoch_based: if True, the scheduler is based on the number of epochs, otherwise on the number of iterations.
  • SCHEDULER.min_lr: the minimum learning rate of the scheduler.
  • SCHEDULER.warmup: whether to use one epoch of linear warmup.
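The interaction of SCHEDULER.warmup, SCHEDULER.min_lr and SCHEDULER.epoch_based can be illustrated with a small function computing the learning rate per step. This is a sketch of a common linear-warmup plus cosine-decay schedule under stated assumptions (warmup ramps linearly up to the base rate, then cosine decay reaches min_lr at the final step); `lr_at` is hypothetical, and the repository's scheduler may differ in details such as the warmup starting value. The run length and learning rates below are example numbers only.

```python
import math


def lr_at(step: int, total_steps: int, base_lr: float, min_lr: float,
          warmup_steps: int = 0) -> float:
    """Learning rate with optional linear warmup followed by cosine decay to min_lr.

    `step` counts epochs or iterations, depending on SCHEDULER.epoch_based.
    """
    if step < warmup_steps:
        # Linear ramp that reaches base_lr at the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Hypothetical iteration-based run: 10 epochs of 100 iterations,
# with one warmup epoch (SCHEDULER.warmup = True).
iters_per_epoch, epochs = 100, 10
total = iters_per_epoch * epochs
schedule = [lr_at(i, total, base_lr=1e-4, min_lr=1e-6, warmup_steps=iters_per_epoch)
            for i in range(total)]
print(f"start {schedule[0]:.2e}, after warmup {schedule[99]:.2e}, end {schedule[-1]:.2e}")
```

With an epoch-based scheduler, the same function applies with `total_steps` equal to the number of epochs and `warmup_steps = 1`.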

Further explanations of the configurable parameters can be found in default_cfg.py.

2. Training run

To start the training, run the following command:

python src/train.py configs/<config_file> <data_dir> \
    --output-dir results/ \
    --data_parallelism \
    --device cuda:0

The <config_file> corresponds to the configuration file in configs/ and <data_dir> to the directory where the datasets are stored. The --output-dir parameter specifies the directory where the training results are stored. The --data_parallelism parameter enables the use of multiple GPUs for training (available GPU IDs must be specified in the configuration file under TRAIN.gpu_ids). The --device parameter specifies the device to use for training.
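When launching several runs from a script, the documented command can be assembled programmatically. The helpers below are hypothetical (not part of the repository) and only use the flags shown above; by default the command is printed rather than executed, and `<config_file>` and `<data_dir>` remain placeholders to be filled in.

```python
import shlex
import subprocess
from typing import List, Optional


def train_command(config_file: str, data_dir: str, output_dir: str = "results/",
                  device: str = "cuda:0", data_parallel: bool = False) -> List[str]:
    """Build the argument list for src/train.py using the flags documented above."""
    cmd = ["python", "src/train.py", config_file, data_dir,
           "--output-dir", output_dir, "--device", device]
    if data_parallel:
        # GPU ids are read from TRAIN.gpu_ids in the configuration file.
        cmd.append("--data_parallelism")
    return cmd


def run(cmd: List[str], dry_run: bool = True) -> Optional[int]:
    """Print the shell-quoted command; execute it only when dry_run is False."""
    print(shlex.join(cmd))
    if dry_run:
        return None
    return subprocess.run(cmd, check=False).returncode


run(train_command("configs/<config_file>", "<data_dir>", data_parallel=True))
```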

V. Evaluation

During training, embedding models are evaluated on the UnED evaluation set each time a model checkpoint is saved. The evaluation results are stored in a CSV file in the output directory.

For evaluating the image embedding model on the UnED test set, run the following command:

python src/evaluate.py configs/<config_file> <data_dir> \
    --model_ckpt <path_to_model.pth> \
    --batch_size 512 \
    --metric euclidean \
    --n_trees 300 \
    --data_parallelism \
    --device cuda:0

The <config_file> corresponds to the configuration file used for training and <data_dir> to the directory where the datasets are stored. The --model_ckpt parameter specifies the path to the model checkpoint to evaluate. The --metric parameter specifies the distance metric used for the nearest neighbor search in the Annoy index. The --n_trees parameter specifies the number of trees used to build the Annoy index. The --data_parallelism parameter enables the use of multiple GPUs for evaluation (available GPU IDs must be specified in the configuration file under TRAIN.gpu_ids). The --device parameter specifies the device to use for evaluation.

Annoy Index

The evaluation script uses the Annoy library to build an index from the database embeddings of the evaluation set. This index is used to find the nearest neighbors of the query embeddings and to compute the recall@1 and mMP@5 metrics. Evaluation time and accuracy vary with the number of trees (n_trees) used for the Annoy index: more trees lead to a more accurate but slower search.
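The two metrics can be sketched in plain Python. An exact brute-force search stands in for the approximate Annoy lookup; with Annoy, one would instead add the database embeddings via `AnnoyIndex(dim, "euclidean").add_item(i, vector)`, call `build(n_trees)`, and query with `get_nns_by_vector(vector, 5)`. The mMP@5 computation follows the modified precision definition of the Google Universal Image Embedding Challenge, where each query's precision is normalized by min(5, number of relevant database images); treat this as an assumption about the exact variant used by the evaluation script. The toy 2-D data are made up for illustration.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def knn(query: Vector, database: List[Vector], k: int) -> List[int]:
    """Exact k nearest neighbors by Euclidean distance (stand-in for an Annoy query)."""
    order = sorted(range(len(database)), key=lambda i: math.dist(query, database[i]))
    return order[:k]


def recall_at_1(neighbors: List[List[int]], query_labels: list, db_labels: list) -> float:
    """Fraction of queries whose nearest database item has the same label."""
    hits = sum(db_labels[nn[0]] == q for nn, q in zip(neighbors, query_labels))
    return hits / len(query_labels)


def mmp_at_k(neighbors: List[List[int]], query_labels: list, db_labels: list,
             k: int = 5) -> float:
    """Modified mean precision@k: each query uses min(k, #relevant) as cutoff and denominator."""
    total = 0.0
    for nn, q in zip(neighbors, query_labels):
        n_rel = sum(label == q for label in db_labels)
        if n_rel == 0:
            raise ValueError(f"query label {q!r} has no match in the database")
        cutoff = min(k, n_rel)
        total += sum(db_labels[i] == q for i in nn[:cutoff]) / cutoff
    return total / len(query_labels)


# Toy 2-D "embeddings"; real embeddings come from the trained model.
database = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0], [10.0, 0.0]]
db_labels = ["a", "a", "b", "b", "c"]
queries = [[0.0, 0.4], [5.0, 5.4], [9.0, 0.0]]
query_labels = ["a", "b", "b"]

neighbors = [knn(q, database, k=5) for q in queries]
print(recall_at_1(neighbors, query_labels, db_labels))  # 0.6666666666666666 (2 of 3 queries)
print(mmp_at_k(neighbors, query_labels, db_labels))     # (1 + 1 + 0.5) / 3
```

Because Annoy is approximate, its neighbor lists can differ from the exact ones above; this is the source of the dependence on n_trees described in the previous paragraph.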

We evaluated the CLIP ViT-B/16 model on the UnED test set in a zero-shot setting using different numbers of trees for the Annoy index. With 300 trees, our results were closest to those of the official UnED benchmark. The results are shown in the table below.

[Figure: Annoy Index Evaluation]
