multimodal-meta-learn's Introduction

Meta Learning to Bridge Vision and Language Models for Multimodal Few-Shot Learning

This is the official code repository for "Meta Learning to Bridge Vision and Language Models for Multimodal Few-Shot Learning", published at ICLR 2023.

[arXiv] [OpenReview]

Intro

Multimodal few-shot learning is challenging due to the large domain gap between the vision and language modalities. Existing methods try to communicate visual concepts as prompts to frozen language models, but they rely on hand-engineered task induction to reduce the hypothesis space. To address these limitations and make this process learnable, we propose a multimodal meta-learning approach.

Approach Overview

Our approach frames model training as observing a collection of multimodal few-shot tasks. We introduce a meta-mapper network that serves as a meta-learner, effectively bridging the gap between frozen large-scale vision and language models and leveraging their pre-existing learned capacity. Only the learnable parameters of the meta-mapper are updated, so the meta-mapper learns to accumulate shared meta-knowledge across these tasks.
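
As a rough illustration of "frozen models, learnable mapper, many small tasks", here is a toy sketch in plain Python. It is not the repository's implementation: fixed scalar weights stand in for the frozen vision encoder and language model, a two-parameter (w, b) "mapper" is the only thing updated, and tasks are hypothetical linear regression problems. The loop follows a first-order MAML scheme (adapt on a support set, then update the meta-parameters from the query loss at the adapted weights), chosen here for brevity; the learning rates and task distribution are assumptions.

```python
import random

# Hypothetical stand-ins: fixed weights play the role of the frozen vision
# encoder and frozen language model; only the mapper (w, b) is trained.
FROZEN_VISION_W = 2.0
FROZEN_LANG_W = 0.5


def predict(mapper, x):
    """Frozen vision -> learnable mapper -> frozen language head."""
    w, b = mapper
    visual_feature = FROZEN_VISION_W * x      # frozen, never updated
    prefix = w * visual_feature + b           # learnable mapper
    return FROZEN_LANG_W * prefix             # frozen, never updated


def loss_and_grad(mapper, batch):
    """Mean squared error and its gradient with respect to the mapper only."""
    n = len(batch)
    loss = grad_w = grad_b = 0.0
    for x, y in batch:
        err = predict(mapper, x) - y
        loss += err * err / n
        grad_w += 2 * err * FROZEN_LANG_W * FROZEN_VISION_W * x / n
        grad_b += 2 * err * FROZEN_LANG_W / n
    return loss, (grad_w, grad_b)


def sgd_step(params, grads, lr):
    return tuple(p - lr * g for p, g in zip(params, grads))


def sample_task(rng, n_support=5, n_query=10):
    """A toy task: fit y = a * x + c for task-specific a and c."""
    a, c = rng.uniform(0.5, 1.5), rng.uniform(-0.5, 0.5)

    def draw(k):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
        return [(x, a * x + c) for x in xs]

    return draw(n_support), draw(n_query)


def meta_train(steps=2000, inner_lr=0.1, outer_lr=0.01, inner_steps=1, seed=0):
    rng = random.Random(seed)
    meta_mapper = (0.0, 0.0)
    for _ in range(steps):
        support, query = sample_task(rng)
        fast = meta_mapper
        for _ in range(inner_steps):              # inner loop: adapt to this task
            _, grads = loss_and_grad(fast, support)
            fast = sgd_step(fast, grads, inner_lr)
        # First-order MAML: apply the query gradient (taken at the adapted
        # weights) directly to the meta-parameters.
        _, query_grads = loss_and_grad(fast, query)
        meta_mapper = sgd_step(meta_mapper, query_grads, outer_lr)
    return meta_mapper


if __name__ == "__main__":
    meta = meta_train()
    support, _ = sample_task(random.Random(1))
    before, grads = loss_and_grad(meta, support)
    after, _ = loss_and_grad(sgd_step(meta, grads, 0.1), support)
    print("meta-mapper:", meta, "support loss before/after adaptation:", before, after)
```

Note that the frozen weights never appear in any update; gradients are only taken with respect to the mapper, which mirrors the "only the meta-mapper is trained" design at toy scale.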

Getting Started

First, clone the project, then create the environment and install the dependencies:

git clone https://github.com/ivonajdenkoska/multimodal-meta-learn.git
cd multimodal-meta-learn
conda env create -f environment.yml
conda activate multimodal_meta_learn

Download the multimodal few-shot datasets from here and place them in your data folder, whose location is passed to the scripts via --data_path. Also download the COCO image captioning dataset from here.

Usage

To perform meta-training on the COCO captioning dataset, first run parse_coco.py to obtain the preprocessed COCO pickle file. To train the full model, run python main.py. In this script you can choose the episodic method to perform meta-training, or the non_episodic method to perform standard mini-batch training. To perform inference with trained models, run python main_inference.py.
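
To illustrate the difference between the two training modes, here is a hypothetical sketch of how an N-way, K-shot episode with Q query examples per class could be sampled, compared with plain shuffled mini-batches. The helper names sample_episode and mini_batches are illustrative only and are not taken from the repository; the data are toy (item, label) pairs.

```python
import random
from collections import defaultdict


def sample_episode(examples, n_way, k_shot, q_query, rng):
    """Build one N-way, K-shot few-shot task (hypothetical helper).

    examples: list of (item, label) pairs.
    Returns (support, query), each a list of (item, episode_label) pairs,
    where episode_label is re-indexed to 0..n_way-1 for this task only.
    """
    by_label = defaultdict(list)
    for item, label in examples:
        by_label[label].append(item)
    eligible = [lbl for lbl, items in by_label.items() if len(items) >= k_shot + q_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + q_query examples")
    classes = rng.sample(sorted(eligible), n_way)
    support, query = [], []
    for episode_label, lbl in enumerate(classes):
        # Draw disjoint support and query items for this class.
        picked = rng.sample(by_label[lbl], k_shot + q_query)
        support += [(item, episode_label) for item in picked[:k_shot]]
        query += [(item, episode_label) for item in picked[k_shot:]]
    return support, query


def mini_batches(examples, batch_size, rng):
    """Non-episodic alternative: plain shuffled mini-batches over all data."""
    shuffled = examples[:]
    rng.shuffle(shuffled)
    for start in range(0, len(shuffled), batch_size):
        yield shuffled[start:start + batch_size]


if __name__ == "__main__":
    rng = random.Random(0)
    data = [(f"img_{c}_{i}", c) for c in "abcde" for i in range(10)]
    support, query = sample_episode(data, n_way=2, k_shot=1, q_query=5, rng=rng)
    print(len(support), len(query))  # 2 10
```

The episodic sampler re-labels classes per task, so the model cannot memorise global class indices and has to rely on the support set; the mini-batch generator simply iterates over the whole dataset.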

Reference

If you find this code or the paper useful for your work, please cite:

@inproceedings{najdenkoska2023meta,
    title={Meta Learning to Bridge Vision and Language Models for Multimodal Few-Shot Learning},
    author={Ivona Najdenkoska and Xiantong Zhen and Marcel Worring},
    booktitle={The Eleventh International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=3oWo92cQyxL}
}

Acknowledgments

This repository uses HuggingFace and is based on the ClipCap and MAML code repositories.

multimodal-meta-learn's Issues

Epoch time

Hello, I have another question about your project.

How long does it take to train one epoch with the default settings? And how long does the whole training process take?

I am curious about the time efficiency of your model and I would like to know more details.

Thank you for your reply and your great work.

arguments

Dear author:
What are the specific arguments of main_train.py?

Release trained checkpoint

Thanks for the authors' impressive work! I have a few questions.

  1. Do you have any plans to release the trained checkpoint for your model?
  2. I tried to reproduce your results on the COCO dataset, but the accuracy was very low. Here is the log of my training process:
Total num of params: 1777664
shuffle DB: train, b:10000, 2-way, 1-shot, 5-query, 0-repeats, resize:224
shuffle DB: val, b:100, 2-way, 1-shot, 1-query, 0-repeats, resize:224
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 0         Losses: [27.3474, 26.4712, 25.9338, 25.349, 25.0009, 24.364]
Step: 0         Training acc: [0.005  0.0025 0.0025 0.005  0.01   0.015 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 100       Losses: [16.2629, 16.4309, 16.4842, 16.2964, 16.4591, 16.5697]
Step: 100       Training acc: [0.1175 0.12   0.1075 0.13   0.1325 0.115 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 200       Losses: [14.6574, 14.5848, 14.5446, 14.5029, 14.5325, 14.4986]
Step: 200       Training acc: [0.11   0.11   0.1075 0.115  0.1125 0.1   ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 300       Losses: [13.8562, 13.9223, 13.823, 13.672, 13.8453, 13.8192]
Step: 300       Training acc: [0.1    0.09   0.105  0.105  0.0925 0.095 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 400       Losses: [13.3061, 13.2029, 13.3253, 13.2227, 13.2487, 13.1655]
Step: 400       Training acc: [0.1125 0.1075 0.1    0.09   0.11   0.1175]

Model saved on path /multimodal-few-shot/models/
------ Meta-test 2-way, 1-shot (5-query) ------
Step: 400       Test acc: [0.118  0.122  0.1195 0.121  0.116  0.1135]

------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 500       Losses: [13.6204, 13.3988, 13.5718, 13.4468, 13.5434, 13.55]
Step: 500       Training acc: [0.12   0.1125 0.1325 0.11   0.115  0.135 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 600       Losses: [13.0284, 13.0951, 13.1156, 13.1911, 13.0417, 13.1139]                                                                                                                 
Step: 600       Training acc: [0.1425 0.14   0.125  0.1375 0.1275 0.135 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 700       Losses: [13.0562, 13.078, 13.227, 13.125, 12.9114, 13.0547]
Step: 700       Training acc: [0.16   0.1625 0.1625 0.145  0.1875 0.155 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 800       Losses: [13.4406, 13.5702, 13.4736, 13.4705, 13.4559, 13.3458]
Step: 800       Training acc: [0.1325 0.135  0.135  0.13   0.13   0.13  ]

Model saved on path /multimodal-few-shot/models/
------ Meta-test 2-way, 1-shot (5-query) ------
Step: 800       Test acc: [0.1495 0.141  0.1445 0.143  0.1475 0.14  ]

------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 900       Losses: [12.3796, 12.4293, 12.3599, 12.3973, 12.3356, 12.4367]
Step: 900       Training acc: [0.125  0.155  0.1275 0.14   0.1525 0.14  ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1000      Losses: [11.1597, 11.1914, 11.4077, 11.0616, 11.1694, 11.2945]
Step: 1000      Training acc: [0.1675 0.155  0.1475 0.17   0.16   0.1575]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1100      Losses: [12.062, 12.0198, 12.0565, 12.0145, 12.168, 12.0419]
Step: 1100      Training acc: [0.1575 0.1725 0.155  0.1325 0.1675 0.1375]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1200      Losses: [12.0617, 11.8894, 12.03, 12.0184, 12.064, 12.1158]
Step: 1200      Training acc: [0.1225 0.125  0.1325 0.13   0.1375 0.13  ]

Model saved on path /multimodal-few-shot/models/
------ Meta-test 2-way, 1-shot (5-query) ------
Step: 1200      Test acc: [0.1395 0.1505 0.154  0.134  0.1545 0.1575]

------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1300      Losses: [12.2494, 12.2317, 12.2031, 12.0318, 12.146, 12.1562]
Step: 1300      Training acc: [0.115  0.115  0.1375 0.1375 0.14   0.135 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1400      Losses: [11.9626, 11.9661, 11.8697, 11.9094, 11.8787, 11.8339]
Step: 1400      Training acc: [0.115  0.13   0.12   0.1225 0.135  0.13  ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1500      Losses: [13.2522, 12.9408, 12.9902, 12.9345, 12.8652, 12.9913]
Step: 1500      Training acc: [0.13   0.1525 0.125  0.1375 0.1275 0.1225]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1600      Losses: [12.1111, 12.2421, 12.2779, 12.2388, 12.1456, 12.2824]
Step: 1600      Training acc: [0.1325 0.14   0.16   0.1375 0.15   0.1425]

Model saved on path /multimodal-few-shot/models/
------ Meta-test 2-way, 1-shot (5-query) ------
Step: 1600      Test acc: [0.148  0.1615 0.1495 0.1495 0.141  0.1436]

------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1700      Losses: [11.8007, 11.9436, 11.9044, 11.8383, 11.9157, 11.9359]
Step: 1700      Training acc: [0.155  0.1625 0.17   0.165  0.1525 0.1525]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1800      Losses: [11.7353, 11.7092, 11.663, 11.8356, 11.6716, 11.6708]
Step: 1800      Training acc: [0.125  0.12   0.1275 0.12   0.115  0.1075]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 1900      Losses: [12.8347, 12.818, 12.811, 12.9872, 12.7438, 12.7335]
Step: 1900      Training acc: [0.1525 0.145  0.1725 0.1525 0.15   0.1525]

Model saved on path /multimodal-few-shot/models/
Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 2000      Losses: [11.3431, 11.4844, 11.6313, 11.4486, 11.308, 11.3376]
Step: 2000      Training acc: [0.1925 0.1525 0.1725 0.155  0.1725 0.165 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-test 2-way, 1-shot (5-query) ------
Step: 2000      Test acc: [0.156  0.143  0.1455 0.153  0.1515 0.1436]

------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 2100      Losses: [12.6283, 12.5109, 12.5109, 12.4421, 12.537, 12.7002]
Step: 2100      Training acc: [0.1475 0.13   0.1425 0.16   0.1225 0.125 ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 2200      Losses: [11.5735, 11.8118, 11.6213, 11.793, 11.7382, 11.7658]
Step: 2200      Training acc: [0.1525 0.16   0.16   0.135  0.1175 0.1375]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 2300      Losses: [12.602, 12.7457, 12.5895, 12.7572, 12.605, 12.6782]
Step: 2300      Training acc: [0.16   0.155  0.16   0.1575 0.165  0.15  ]

Model saved on path /multimodal-few-shot/models/
------ Meta-training 2-way, 1-shot (5-query) ATT-mapper 4-prefix tokens------
Step: 2400      Losses: [11.4712, 11.6158, 11.4204, 11.4886, 11.5966, 11.5541]
Step: 2400      Training acc: [0.1525 0.14   0.1475 0.1575 0.1525 0.13  ]

Model saved on path /multimodal-few-shot/models/
------ Meta-test 2-way, 1-shot (5-query) ------
Step: 2400      Test acc: [0.163  0.1555 0.1545 0.155  0.168  0.1495]

Is this normal or did I miss something?

Errors during the finetuning process

Hi, thanks for this interesting work. When I train the multimodal meta-learner with the provided code, I encounter an error at step 400 during the finetuning process.
This error happens at Line 164 in meta_trainer.py:
logits_q, pred_tokens = model(x_qry, y_qry, y_qry_mask, list(model.mapper_net.parameters()), is_finetuning=True)
because the forward function (at Line 47 in meta_learner.py) does not have an is_finetuning parameter:
def forward(self, image, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None, fast_weights=None, labels: Optional[torch.Tensor] = None, get_pred_tokens=True):
In addition, the arguments passed to meta.finetuning() (at Line 114 in main_train.py) do not match the parameters defined for finetuning() (at Line 145 in meta_trainer.py), which causes another error once I simply remove the is_finetuning parameter.
So, how can I fix these errors?
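
The first error is an ordinary Python signature mismatch: calling model(...) on a torch.nn.Module dispatches the arguments to forward(), and the quoted forward() accepts no is_finetuning keyword. The sketch below reproduces this with hypothetical stub classes (not the repository's code) and uses the standard library's inspect.signature(...).bind(...) to check a call against a signature without running it. Adding is_finetuning=False as a defaulted keyword is only one possible workaround; what the flag is supposed to change inside forward() is not stated here and would need to come from the authors.

```python
import inspect


def check_call(func, *args, **kwargs):
    """Return None if func would accept these arguments, else the TypeError message."""
    try:
        inspect.signature(func).bind(*args, **kwargs)
    except TypeError as exc:
        return str(exc)
    return None


class OriginalStub:
    # Same parameter list as the forward() quoted from meta_learner.py; the body is a stub.
    def forward(self, image, tokens, mask=None, fast_weights=None, labels=None,
                get_pred_tokens=True):
        return "ok"


class PatchedStub:
    # Hypothetical workaround: accept the flag with a default so existing calls keep working.
    def forward(self, image, tokens, mask=None, fast_weights=None, labels=None,
                get_pred_tokens=True, is_finetuning=False):
        # The stub only reports the flag; real fine-tuning behaviour is unknown here.
        return f"ok (is_finetuning={is_finetuning})"


if __name__ == "__main__":
    # Mirrors the shape of the call at Line 164 of meta_trainer.py with placeholder arguments.
    call_args = ("x_qry", "y_qry", "y_qry_mask", ["mapper params"])
    print(check_call(OriginalStub().forward, *call_args, is_finetuning=True))
    print(check_call(PatchedStub().forward, *call_args, is_finetuning=True))
```

The same check_call helper could be pointed at meta.finetuning with the arguments used at Line 114 of main_train.py to see which names or positions disagree, which may help narrow down the second error.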

Weights of the trained model

Thank you for the authors' good work.
Will you publish your trained model weights soon?
