BagFormer: Better Cross-Modal Retrieval via bag-wise interaction

This is the PyTorch implementation of the BagFormer paper. The code has been tested with Python 3.8 and PyTorch 1.13. To install the dependencies, create a virtual environment and run:

pip install -r requirements.txt

Pre-trained checkpoints:

Number of image-text pairs    BagFormer
108M                          Download

Finetuned checkpoints:

Task                          BagFormer
Image-Text Retrieval (MUGE)   Download

Image-Text Retrieval:

  1. Download the MUGE Multimodal Retrieval dataset from the original website and unzip it into the data directory, or change the dataset path in configs/config_muge.yaml.
  2. To evaluate the finetuned BagFormer model on MUGE, run:
python3 train_muge.py \
--checkpoint path-to-finetuned-checkpoint \
--interaction bagwise \
--output_dir path-to-output \
--evaluate
  3. To finetune the pre-trained checkpoint on MUGE, run:
python3 train_muge.py \
--checkpoint path-to-pretrain-checkpoint \
--interaction bagwise \
--output_dir path-to-output 
  4. To compare bag-wise interaction with the cls_token and tokenwise interaction baselines, run:
# cls_token baseline, which is the BagFormer w/o late interaction model in the paper
python3 train_muge.py \
--checkpoint path-to-pretrain-checkpoint \
--interaction cls_token \
--output_dir path-to-output

# tokenwise baseline, which is the BagFormer w/o bagging layer model in the paper
python3 train_muge.py \
--checkpoint path-to-pretrain-checkpoint \
--interaction tokenwise \
--output_dir path-to-output
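The three --interaction modes differ in how text and image token features are scored against each other. The sketch below contrasts a single-vector [CLS] score with a token-level late-interaction score. It is an illustration under stated assumptions, not the repository's tokenwise_similarity_martix: we assume a ColBERT-style "max-sim" rule (each token on the query side keeps its best match on the other side, and those maxima are summed), and the function names cls_token_similarity and late_interaction_similarity are hypothetical. Judging from the example later in this README, bag-wise interaction applies this kind of token-level scoring to bag features rather than to individual token features.

```python
import torch
import torch.nn.functional as F


def cls_token_similarity(text_feats: torch.Tensor, image_feats: torch.Tensor) -> torch.Tensor:
    """Single-vector baseline: cosine similarity of the first ([CLS]) tokens only.

    text_feats: [num_texts, text_len, dim]; image_feats: [num_images, img_len, dim]
    Returns: [num_images, num_texts]
    """
    t = F.normalize(text_feats[:, 0], dim=-1)
    v = F.normalize(image_feats[:, 0], dim=-1)
    return v @ t.T


def late_interaction_similarity(text_feats, image_feats, text_mask=None):
    """Assumed ColBERT-style max-sim late interaction (hypothetical, not the repo's code).

    text_mask: optional [num_texts, text_len], 1 for real tokens/bags, 0 for padding.
    Returns (sim_i2t [num_images, num_texts], sim_t2i [num_texts, num_images]).
    """
    t = F.normalize(text_feats, dim=-1)
    v = F.normalize(image_feats, dim=-1)
    # all token-pair cosine similarities: [num_images, num_texts, img_len, text_len]
    sim = torch.einsum("ikd,jld->ijkl", v, t)
    if text_mask is not None:
        keep = text_mask.bool()[None, :, None, :]
        # padded text positions can never be selected as a best match
        sim = sim.masked_fill(~keep, float("-inf"))
    # image -> text: every image token keeps its best text token; sum over image tokens
    sim_i2t = sim.max(dim=-1).values.sum(dim=-1)
    # text -> image: every text token keeps its best image token; sum over text tokens
    per_text_token = sim.max(dim=-2).values  # [num_images, num_texts, text_len]
    if text_mask is not None:
        # padded text tokens contribute nothing to the sum
        per_text_token = per_text_token.masked_fill(~text_mask.bool()[None], 0.0)
    sim_t2i = per_text_token.sum(dim=-1).T
    return sim_i2t, sim_t2i
```

Because each direction sums maxima over a different set of tokens, sim_i2t and sim_t2i are generally not transposes of each other, which is consistent with the asymmetric scores printed in the README's example. The exact normalization and scaling used by the repository may differ from this sketch.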

Calculate bag-wise similarity

import torch
import torch.nn.functional as F
from PIL import Image
from ruamel import yaml
from transformers import BertTokenizer

from models.loss import tokenwise_similarity_martix
from models.model_helper import EmbeddingBagHelperAutomaton
from models.model_retrieval_bagwise import BagFormer
from MUGE_helper.dataset import get_test_transform

device = "cuda" if torch.cuda.is_available() else "cpu"
text_encoder = "bert-base-chinese"
max_seq_len = 25
with open("configs/config_muge.yaml", "r") as f:
    # legacy ruamel.yaml API; recent ruamel.yaml releases (0.18+) may need an older pin
    config = yaml.load(f, Loader=yaml.Loader)
test_transform = get_test_transform(config)

tokenizer = BertTokenizer.from_pretrained(text_encoder)

model = BagFormer(
    config=config,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)

checkpoint = torch.load(
    "path-to-checkpoint", map_location="cpu"
)
model.load_state_dict(checkpoint["model"], strict=False)
model = model.to(device)

embedding_bag_helper = EmbeddingBagHelperAutomaton(
    tokenizer, config["entity_dict_path"], masked_token=["[CLS]", "[PAD]"]
)

product_image = test_transform(Image.open("rumble_roller.jpeg"))
image = product_image.unsqueeze(0).to(device)

product_title = ["rumble roller", "nike zoomx vista"]
# pad (and truncate) every title to exactly max_seq_len tokens
text = tokenizer(
    product_title, padding="max_length", truncation=True, max_length=max_seq_len
)

embed_bag_offset, attn_mask = embedding_bag_helper.process(text, return_mask=True)
embed_bag_offset = torch.LongTensor(embed_bag_offset).to(device)
embed_bag_attn_mask = torch.LongTensor(attn_mask).to(device)
text = text.convert_to_tensors("pt").to(device)

with torch.no_grad():
    # encode image and text
    image_features = model.visual_encoder(image)
    text_features = model.text_encoder(
        text.input_ids, attention_mask=text.attention_mask, mode="text"
    ).last_hidden_state
    # get text bag feature
    batch_size, seq_len, text_width = text_features.shape
    embedding_input = torch.arange(batch_size * seq_len, device=device)
    embedbag_feats = F.embedding_bag(
        embedding_input,
        text_features.view(-1, text_width),
        embed_bag_offset,
        mode="sum",
    ).view(batch_size, -1, text_width)
    embedbag_feats = F.normalize(embedbag_feats, dim=-1)
    # pad to same length
    embedbag_seq_len = embedbag_feats.shape[1]
    embedbag_feats = F.pad(
        embedbag_feats,
        pad=(0, 0, 0, max_seq_len - embedbag_seq_len, 0, 0),
        mode="constant",
        value=0,
    )
    # calc bagwise similarity matrix
    sim_i2t, sim_t2i = tokenwise_similarity_martix(embedbag_feats, image_features)

print("image feature shape:", image_features.shape)  
# prints: torch.Size([1, 257, 768])
print("text feature shape:", embedbag_feats.shape)  
# prints: torch.Size([2, 25, 768])
print("img2text sim:", sim_i2t)  # prints: [[132.4761, 50.0424]]
print("text2img sim:", sim_t2i)  # prints: [[33.4206], [19.6727]]
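The bagging step in the example relies on F.embedding_bag in "sum" mode: the flattened token features act as the embedding table, torch.arange(batch_size * seq_len) looks each row up exactly once, and the offsets mark where each bag starts, so every bag is the sum of a contiguous run of token features. The sketch below isolates that step on toy numbers. The helper names bag_token_features and pad_bags are hypothetical, and we assume (without having checked EmbeddingBagHelperAutomaton) that its offsets follow PyTorch's convention of start indices into the flattened token rows.

```python
import torch
import torch.nn.functional as F


def bag_token_features(token_feats: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Sum contiguous runs of token features into bag features (hypothetical helper).

    token_feats: [batch, seq_len, dim]
    offsets: 1-D LongTensor of bag start indices into the batch*seq_len flattened rows;
             bag b covers rows offsets[b] .. offsets[b+1]-1, the last bag runs to the end.
    Returns: [num_bags, dim]
    """
    batch, seq_len, dim = token_feats.shape
    # look up every flattened token row exactly once, in order
    indices = torch.arange(batch * seq_len, device=token_feats.device)
    return F.embedding_bag(indices, token_feats.reshape(-1, dim), offsets, mode="sum")


def pad_bags(bag_feats: torch.Tensor, max_len: int) -> torch.Tensor:
    """Zero-pad [batch, num_bags, dim] along the bag axis up to max_len (hypothetical helper)."""
    return F.pad(bag_feats, pad=(0, 0, 0, max_len - bag_feats.shape[1]), value=0.0)


# toy input: one sample with 6 tokens of 2-dim features, rows [0,1], [2,3], ..., [10,11]
feats = torch.arange(12.0).view(1, 6, 2)
# bags: tokens {0}, {1, 2}, {3, 4, 5}
bags = bag_token_features(feats, torch.tensor([0, 1, 3]))
print(bags)  # sums: [0, 1], [6, 8], [24, 27]
padded = pad_bags(bags.view(1, 3, 2), max_len=5)
print(padded.shape)  # torch.Size([1, 5, 2])
```

Note that the README reshapes the bag features with .view(batch_size, -1, text_width), which only works if every sample in the batch yields the same number of bags.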

Citation

If you find our work useful, please consider citing BagFormer:

@article{hou2022bagformer,
  title={BagFormer: Better Cross-Modal Retrieval via bag-wise interaction},
  author={Hou, Haowen and Yan, Xiaopeng and Zhang, Yigeng and Lian, Fengzong and Kang, Zhanhui},
  journal={arXiv preprint arXiv:2212.14322},
  year={2022}
}
