
PyTorch implementation of the paper "VLDeformer: Vision Language Decomposed Transformer for Fast Cross-modal Retrieval".

If this paper and code benefit your research, please cite:

@article{zhang2022vldeformer,
  title={VLDeformer: Vision--Language Decomposed Transformer for fast cross-modal retrieval},
  author={Zhang, Lisai and Wu, Hongfa and Chen, Qingcai and Deng, Yimeng and Siebert, Joanna and Li, Zhonghua and Han, Yunpeng and Kong, Dejiang and Cao, Zhao},
  journal={Knowledge-Based Systems},
  volume={252},
  pages={109316},
  year={2022},
  publisher={Elsevier}
}

Environment Setup

Install the Python packages following the VinVL instructions. To reproduce the performance reported in the paper, we recommend at least 6 V100 GPUs.

Dataset and Pre-processed Files

To reproduce the experiments, you need to download the COCO and Flickr30k datasets, plus SBU for the large version.

In addition, we use the image features extracted by VinVL, which are available on the VinVL download page. You can download these features directly for COCO and Flickr30k.

If you want to run the model on your own data, please refer to Scene Graph to extract the features, as specified in the VinVL repo.
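If you need to inspect or convert a feature file, here is a minimal Python sketch for decoding one row of features.tsv. It assumes a three-column layout (image key, number of boxes, base64-encoded little-endian float32 features) purely for illustration; the actual layout depends on how the features were extracted, so check it against your own file first. The function name decode_feature_row is hypothetical, and the default cap of 70 regions simply mirrors the --max_img_seq_length value in the training arguments.

```python
import base64
import struct
from typing import List, Tuple

# ASSUMED row layout (verify against your own features.tsv):
#   <image_key> \t <num_boxes> \t <base64 of num_boxes * dim float32 values>
def decode_feature_row(line: str, max_regions: int = 70) -> Tuple[str, List[List[float]]]:
    """Decode one TSV row into (image_key, list of per-region feature vectors)."""
    fields = line.rstrip("\n").split("\t")
    image_key, num_boxes = fields[0], int(fields[1])
    raw = base64.b64decode(fields[-1])
    n_floats = len(raw) // 4  # each float32 takes 4 bytes
    if num_boxes <= 0 or len(raw) % 4 != 0 or n_floats % num_boxes != 0:
        raise ValueError(f"{len(raw)} bytes cannot be split into {num_boxes} float32 rows")
    dim = n_floats // num_boxes
    flat = struct.unpack(f"<{n_floats}f", raw)  # "<" = little-endian float32
    regions = [list(flat[i * dim:(i + 1) * dim]) for i in range(num_boxes)]
    # Keep at most max_regions boxes (cf. --max_img_seq_length).
    return image_key, regions[:max_regions]


# Usage: build a fake row with 3 boxes of dimension 4, keep only the first 2 boxes.
values = [float(i) for i in range(12)]
payload = base64.b64encode(struct.pack("<12f", *values)).decode("ascii")
key, feats = decode_feature_row(f"img_42\t3\t{payload}\n", max_regions=2)
print(key, len(feats), feats[1])  # img_42 2 [4.0, 5.0, 6.0, 7.0]
```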

Pre-trained Model Checkpoints

The decomposition is applied to the pre-trained single-stream VinVL model, so you need to download that model first:

path/to/azcopy copy 'https://biglmdiag.blob.core.windows.net/vinvl/model_ckpts/TASK_NAME' coco_ir --recursive

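Replace TASK_NAME with the VinVL task whose checkpoint you need; the command above saves it to the local folder coco_ir. Afterwards, you can run our code to perform the decomposition.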
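To illustrate why decomposition speeds up retrieval, here is a minimal sketch that assumes the decomposed model maps each text and each image to one embedding vector independently. Retrieval then reduces to ranking by cosine similarity, and scoring N texts against M images needs N + M encoder forward passes instead of the N × M joint passes a single-stream model needs for every text-image pair. The embeddings and helper names below are hypothetical toy values, not outputs or functions of this repo.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def l2_normalize(v: Vector) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def cosine_matrix(text_embs: Sequence[Vector], image_embs: Sequence[Vector]) -> List[List[float]]:
    """sims[t][i] = cosine similarity between text t and image i."""
    texts = [l2_normalize(v) for v in text_embs]
    images = [l2_normalize(v) for v in image_embs]
    return [[sum(a * b for a, b in zip(t, i)) for i in images] for t in texts]


def rank_images(text_embs: Sequence[Vector], image_embs: Sequence[Vector]) -> List[List[int]]:
    """For each text, return image indices sorted from most to least similar."""
    return [
        sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        for row in cosine_matrix(text_embs, image_embs)
    ]


def encoder_passes(n_texts: int, n_images: int, decomposed: bool) -> int:
    """Forward passes needed to score every text-image pair."""
    return n_texts + n_images if decomposed else n_texts * n_images


if __name__ == "__main__":
    texts = [[1.0, 0.0], [0.0, 1.0]]      # toy text embeddings
    images = [[0.1, 0.9], [0.8, 0.2]]     # toy image embeddings
    print(rank_images(texts, images))     # [[1, 0], [0, 1]]
    print(encoder_passes(5000, 1000, decomposed=True))   # 6000
    print(encoder_passes(5000, 1000, decomposed=False))  # 5000000
```

The 5000 texts and 1000 images in the usage lines correspond to a 1k test split with 5 captions per image; image embeddings can also be computed once and cached, which the decomposed setup allows and the joint setup does not.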

You can also directly use our pre-trained checkpoints for Flickr30k and COCO.

Running

To train, run contrastive_learn.py with the following arguments (written as a VS Code launch configuration):

"program": "${workspaceFolder}/contrastive_learn.py",
"args": [
    "--model_name_or_path",
    "vinvl/coco_ir/base/checkpoint-1340000",    // Your path to the VinVL checkpoint
    "--data_dir",
    "/raid/data_modal/coco_vinvL/coco_ir/",     // Your path to the VinVL data
    "--img_feat_file",
    "/raid/data_modal/coco_vinvL/model_0060000/features.tsv", // Your path to the VinVL image feature TSV
    "--eval_img_keys_file",
    "test_img_keys_1k.tsv",         // Select the test file in ${data_dir}
    "--do_train",
    "--do_lower_case",
    "--evaluate_during_training",
    "--num_captions_per_img_val", "5",
    "--per_gpu_train_batch_size", "300",
    "--per_gpu_eval_batch_size", "300",
    "--learning_rate", "7.5e-06",
    "--warmup_steps", "200",
    "--scheduler", "cos",
    "--num_train_epochs", "200",
    "--save_steps", "400",
    "--add_od_labels",
    "--od_label_type",
    "vg",
    "--max_seq_length",
    "35",
    "--max_img_seq_length",
    "70",
    "--output_dir",
    "contrastive_checkpoint/output_infoNCE_coswarmup",
    "--logit_gpu", "2,3,4,5,6,7",      // Contrastive learning requires a large batch size
    "--contrastive_gpu", "0",
    "--temperature_t", "0.005",
    "--temperature_i", "0.005",
]
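The output directory name (output_infoNCE_coswarmup) and the --temperature_t / --temperature_i flags point to an InfoNCE-style contrastive objective. The following minimal sketch shows a symmetric InfoNCE loss over a batch similarity matrix, with a separate temperature for each retrieval direction. Pairing --temperature_t with the text-to-image direction and --temperature_i with the image-to-text direction is an assumption made for illustration, not a description of contrastive_learn.py. With a temperature as small as 0.005, cosine-similarity logits reach about ±200, so the sketch uses a max-shifted log-sum-exp to stay numerically stable.

```python
import math
from typing import List, Sequence


def _diag_log_probs(logits: Sequence[Sequence[float]]) -> List[float]:
    """log_softmax(row)[i] for each row i of a square logit matrix (stable log-sum-exp)."""
    out = []
    for i, row in enumerate(logits):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[i] - lse)
    return out


def symmetric_info_nce(sim: Sequence[Sequence[float]],
                       temperature_t: float = 0.005,
                       temperature_i: float = 0.005) -> float:
    """sim[t][i]: cosine similarity of text t and image i; matched pairs lie on the diagonal."""
    n = len(sim)
    # Text-to-image (ASSUMED to use temperature_t): each text must pick its own image.
    t2i = [[s / temperature_t for s in row] for row in sim]
    # Image-to-text (ASSUMED to use temperature_i): transpose, each image picks its own text.
    i2t = [[sim[t][i] / temperature_i for t in range(n)] for i in range(n)]
    loss_t2i = -sum(_diag_log_probs(t2i)) / n
    loss_i2t = -sum(_diag_log_probs(i2t)) / n
    return (loss_t2i + loss_i2t) / 2


# Usage: identical scores give log(batch size); well-separated pairs give a small loss.
print(symmetric_info_nce([[0.0] * 3 for _ in range(3)]))           # about 1.0986 (log 3)
print(symmetric_info_nce([[0.9, 0.1], [0.2, 0.8]], 0.1, 0.1))      # small positive value
```

Larger in-batch similarity matrices give each positive pair more negatives, which is consistent with the note in the arguments that contrastive learning requires a large batch size.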

For testing, you can use the following arguments:

"program": "${workspaceFolder}/contrastive_learn.py",
"args": [
    "--do_test",
    "--do_eval",
    "--test_split",
    "test",
    "--num_captions_per_img_val",
    "5",
    "--eval_img_keys_file",
    "test_img_keys_1k.tsv",
    "--per_gpu_eval_batch_size",
    "400",
    "--img_feat_file",
    "/raid/data_modal/coco_vinvL/model_0060000/features.tsv",
    "--eval_model_dir",
    "contrastive_checkpoint/twodataset_flickr/checkpoint-93-7600",
    "--logit_gpu", "0",
    "--contrastive_gpu", "0",
    "--max_seq_length", "70",
]
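Cross-modal retrieval is commonly reported as Recall@K in both directions. As an illustration of how the 1k test split with --num_captions_per_img_val 5 can be scored, the sketch below computes image-to-text and text-to-image Recall@K from an image-by-caption similarity matrix. It assumes captions are ordered so that caption c belongs to image c // captions_per_img; the function names are hypothetical, and the evaluation in contrastive_learn.py may differ in details such as tie handling.

```python
from typing import Sequence


def i2t_recall_at_k(sim: Sequence[Sequence[float]], captions_per_img: int = 5, k: int = 1) -> float:
    """Fraction of images with at least one of their own captions in the top-k.

    sim[i][c] is the score of image i vs caption c (ASSUMED: caption c belongs to image c // captions_per_img).
    """
    hits = 0
    for i, row in enumerate(sim):
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if any(c // captions_per_img == i for c in top):
            hits += 1
    return hits / len(sim)


def t2i_recall_at_k(sim: Sequence[Sequence[float]], captions_per_img: int = 5, k: int = 1) -> float:
    """Fraction of captions whose own image is ranked in the top-k images."""
    n_images, n_caps = len(sim), len(sim[0])
    hits = 0
    for c in range(n_caps):
        col = [sim[i][c] for i in range(n_images)]
        top = sorted(range(n_images), key=lambda i: col[i], reverse=True)[:k]
        if c // captions_per_img in top:
            hits += 1
    return hits / n_caps


# Usage: 2 images with 2 captions each (captions 0-1 -> image 0, captions 2-3 -> image 1).
sim = [[0.9, 0.8, 0.1, 0.95],
       [0.2, 0.3, 0.7, 0.6]]
print(i2t_recall_at_k(sim, captions_per_img=2, k=1))  # 0.5
print(i2t_recall_at_k(sim, captions_per_img=2, k=2))  # 1.0
print(t2i_recall_at_k(sim, captions_per_img=2, k=1))  # 0.75
```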

Acknowledgement

This repo is modified from VinVL; we thank the authors for sharing their project.
