4M: Massively Multimodal Masked Modeling

A framework for training any-to-any multimodal foundation models.
Scalable. Open-sourced. Across tens of modalities and tasks.

EPFL - Apple

Website | BibTeX | 🤗 Demo

Official implementation and pre-trained models for:

4M: Massively Multimodal Masked Modeling, NeurIPS 2023 (Spotlight)
David Mizrahi*, Roman Bachmann*, Oğuzhan Fatih Kar, Teresa Yeo, Mingfei Gao, Afshin Dehghan, Amir Zamir

4M-21: An Any-to-Any Vision Model for Tens of Tasks and Modalities, arXiv 2024
Roman Bachmann*, Oğuzhan Fatih Kar*, David Mizrahi*, Ali Garjani, Mingfei Gao, David Griffiths, Jiaming Hu, Afshin Dehghan, Amir Zamir


[Figure: 4M main figure]

4M is a framework for training "any-to-any" foundation models, using tokenization and masking to scale to many diverse modalities. Models trained using 4M can perform a wide range of vision tasks, transfer well to unseen tasks and modalities, and are flexible and steerable multimodal generative models. We are releasing code and models for "4M: Massively Multimodal Masked Modeling" (here denoted 4M-7), as well as "4M-21: An Any-to-Any Vision Model for Tens of Tasks and Modalities" (here denoted 4M-21).

Usage

Installation

  1. Clone this repository and navigate to the root directory:
git clone https://github.com/apple/ml-4m
cd ml-4m
  2. Create a new conda environment, then install the package and its dependencies:
conda create -n fourm python=3.9 -y
conda activate fourm
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
  3. Verify that CUDA is available in PyTorch by running the following in a Python shell:
# Run in Python shell
import torch
print(torch.cuda.is_available())  # Should return True

If CUDA is not available, consider re-installing PyTorch following the official installation instructions. Likewise, if you want to install xFormers (optional, for faster tokenizers), follow their README to ensure that the CUDA version is correct.
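To check the required and optional pieces mentioned above in one go, the sketch below reports whether PyTorch and xFormers are importable and, if PyTorch is present, whether CUDA is visible. It relies only on `importlib.util.find_spec` and `torch.cuda.is_available()`; the helper names `module_available` and `environment_report` are hypothetical and not part of the `fourm` package.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


def environment_report() -> dict:
    """Hypothetical helper: summarize the packages the installation steps rely on."""
    report = {
        "torch": module_available("torch"),
        "xformers": module_available("xformers"),  # optional, for faster tokenizers
        "cuda": False,
    }
    if report["torch"]:
        import torch  # imported lazily so the check also runs without PyTorch

        report["cuda"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for name, ok in environment_report().items():
        print(f"{name:10s} {'OK' if ok else 'missing'}")
```

If `cuda` is reported as missing while `torch` is present, the PyTorch build most likely lacks CUDA support, which is the case the paragraph above addresses.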

Getting started

We provide a demo wrapper to quickly get started with using 4M models for RGB-to-all or {caption, bounding boxes}-to-all generation tasks. For example, to generate all modalities from a given RGB input, call:

from fourm.demo_4M_sampler import Demo4MSampler, img_from_url
sampler = Demo4MSampler(fm='EPFL-VILAB/4M-21_XL').cuda()
img = img_from_url('https://storage.googleapis.com/four_m_site/images/demo_rgb.png') # 1x3x224x224 ImageNet-standardized PyTorch Tensor
preds = sampler({'rgb@224': img.cuda()}, seed=None) 
sampler.plot_modalities(preds, save_path=None)

You should expect to see an output like the following:

[Figure: 4M demo sampler output]

To perform caption-to-all generation, replace the sampler input with: preds = sampler({'caption': 'A lake house with a boat in front [S_1]'}). For a list of available 4M models, see the model zoo below, and see README_GENERATION.md for more instructions on generation.
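The demo comment notes that `img_from_url` returns an ImageNet-standardized tensor. As a minimal sketch of what that means, the code below applies the standard ImageNet per-channel statistics to a single RGB pixel and inverts the transform. We assume, without having checked, that these are the same values as `fourm.utils.IMAGENET_DEFAULT_MEAN` and `IMAGENET_DEFAULT_STD`; the function names are illustrative only.

```python
from typing import Sequence, Tuple

# Standard ImageNet per-channel statistics, RGB order.
# Assumption: these match fourm.utils.IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def standardize(rgb: Sequence[float]) -> Tuple[float, ...]:
    """Map one RGB pixel with values in [0, 1] to ImageNet-standardized values."""
    return tuple((c - m) / s for c, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def destandardize(z: Sequence[float]) -> Tuple[float, ...]:
    """Inverse of `standardize`: recover RGB values in [0, 1]."""
    return tuple(v * s + m for v, m, s in zip(z, IMAGENET_MEAN, IMAGENET_STD))
```

Applied per channel to a whole image, this is the same arithmetic that `torchvision.transforms.Normalize(mean, std)` performs, so a custom image should be scaled to [0, 1] and then normalized this way before being passed to the sampler.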

Data

See README_DATA.md for instructions on how to prepare aligned multimodal datasets.

Tokenization

See README_TOKENIZATION.md for instructions on how to train modality-specific tokenizers.

4M Training

See README_TRAINING.md for instructions on how to train 4M models.

Generation

See README_GENERATION.md for instructions on how to use 4M models for inference / generation. We also provide a generation notebook that contains examples for 4M inference, specifically performing conditional image generation and common vision tasks (i.e., RGB-to-all).

Model Zoo

We provide 4M and tokenizer checkpoints as safetensors, and also offer easy loading via Hugging Face Hub.

4M models

| Model | # Modalities | Datasets | # Params | Config | Weights |
|---|---|---|---|---|---|
| 4M-B | 7 | CC12M | 198M | Config | Checkpoint / HF Hub |
| 4M-B | 7 | COYO700M | 198M | Config | Checkpoint / HF Hub |
| 4M-B | 21 | CC12M+COYO700M+C4 | 198M | Config | Checkpoint / HF Hub |
| 4M-L | 7 | CC12M | 705M | Config | Checkpoint / HF Hub |
| 4M-L | 7 | COYO700M | 705M | Config | Checkpoint / HF Hub |
| 4M-L | 21 | CC12M+COYO700M+C4 | 705M | Config | Checkpoint / HF Hub |
| 4M-XL | 7 | CC12M | 2.8B | Config | Checkpoint / HF Hub |
| 4M-XL | 7 | COYO700M | 2.8B | Config | Checkpoint / HF Hub |
| 4M-XL | 21 | CC12M+COYO700M+C4 | 2.8B | Config | Checkpoint / HF Hub |

To load models from Hugging Face Hub:

from fourm.models.fm import FM

fm7b_cc12m  = FM.from_pretrained('EPFL-VILAB/4M-7_B_CC12M')
fm7b_coyo   = FM.from_pretrained('EPFL-VILAB/4M-7_B_COYO700M')
fm21b       = FM.from_pretrained('EPFL-VILAB/4M-21_B')

fm7l_cc12m  = FM.from_pretrained('EPFL-VILAB/4M-7_L_CC12M')
fm7l_coyo   = FM.from_pretrained('EPFL-VILAB/4M-7_L_COYO700M')
fm21l       = FM.from_pretrained('EPFL-VILAB/4M-21_L')

fm7xl_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7_XL_CC12M')
fm7xl_coyo  = FM.from_pretrained('EPFL-VILAB/4M-7_XL_COYO700M')
fm21xl      = FM.from_pretrained('EPFL-VILAB/4M-21_XL')

To load the checkpoints manually, first download the safetensors files from the above links and call:

from fourm.utils import load_safetensors
from fourm.models.fm import FM

ckpt, config = load_safetensors('/path/to/checkpoint.safetensors')
fm = FM(config=config)
fm.load_state_dict(ckpt)
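If you want to see what a downloaded checkpoint contains before instantiating a model, note that a .safetensors file starts with an 8-byte little-endian header length, followed by a JSON header that lists each tensor's dtype, shape and byte offsets, plus an optional `__metadata__` map of strings. The sketch below reads only that header with the standard library. `read_safetensors_header` is a hypothetical helper, not part of `fourm`, and it makes no assumption about how `load_safetensors` encodes the model config in the metadata.

```python
import json
import struct
import sys
from typing import Any, Dict, Tuple


def read_safetensors_header(path: str) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, str]]:
    """Return (tensor_info, metadata) from a .safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian uint64
        header = json.loads(f.read(header_len).decode("utf-8"))
    metadata = header.pop("__metadata__", None) or {}
    return header, metadata


if __name__ == "__main__" and len(sys.argv) > 1:
    tensors, meta = read_safetensors_header(sys.argv[1])
    print(f"{len(tensors)} tensors; metadata keys: {sorted(meta)}")
    for name, info in list(tensors.items())[:5]:
        print(name, info["dtype"], info["shape"])
```

Because only the header is parsed, this is cheap even for the multi-gigabyte XL checkpoints.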

4M text-to-image specialist models

These models were initialized from the standard 4M-7 CC12M models and then further trained on a modality mixture heavily biased towards text inputs. They can still perform all other tasks, but are better at text-to-image generation than the non-finetuned models.

| Model | # Modalities | Datasets | # Params | Config | Weights |
|---|---|---|---|---|---|
| 4M-T2I-B | 7 | CC12M | 198M | Config | Checkpoint / HF Hub |
| 4M-T2I-L | 7 | CC12M | 705M | Config | Checkpoint / HF Hub |
| 4M-T2I-XL | 7 | CC12M | 2.8B | Config | Checkpoint / HF Hub |

To load models from Hugging Face Hub:

from fourm.models.fm import FM

fm7b_t2i_cc12m  = FM.from_pretrained('EPFL-VILAB/4M-7-T2I_B_CC12M')
fm7l_t2i_cc12m  = FM.from_pretrained('EPFL-VILAB/4M-7-T2I_L_CC12M')
fm7xl_t2i_cc12m  = FM.from_pretrained('EPFL-VILAB/4M-7-T2I_XL_CC12M')

Loading manually from checkpoints is performed in the same way as above for the base 4M models.

4M super-resolution models

| Model | # Modalities | Datasets | # Params | Config | Weights |
|---|---|---|---|---|---|
| 4M-SR-L | 7 | CC12M | 198M | Config | Checkpoint / HF Hub |

To load models from Hugging Face Hub:

from fourm.models.fm import FM

fm7l_sr_cc12m  = FM.from_pretrained('EPFL-VILAB/4M-7-SR_L_CC12M')

Loading manually from checkpoints is performed in the same way as above for the base 4M models.
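The Hugging Face Hub IDs above follow a regular pattern. The hypothetical helper below (not part of `fourm`) rebuilds them from the model size, number of modalities, training dataset and specialist type. The pattern is inferred from the IDs listed in this README, and not every combination it can produce exists on the Hub; for example, only an L-sized super-resolution model is listed.

```python
from typing import Optional

ORG = "EPFL-VILAB"
SIZES = ("B", "L", "XL")


def fm_hub_id(size: str, num_modalities: int = 7, dataset: Optional[str] = None,
              specialist: Optional[str] = None) -> str:
    """Hypothetical helper: build a 4M Hub ID following the pattern of the IDs above."""
    if size not in SIZES:
        raise ValueError(f"unknown size {size!r}; expected one of {SIZES}")
    if num_modalities == 21:
        # 4M-21 models are listed without a dataset suffix.
        return f"{ORG}/4M-21_{size}"
    if num_modalities != 7:
        raise ValueError("the model zoo lists only 4M-7 and 4M-21 models")
    if dataset is None:
        raise ValueError("4M-7 IDs need a dataset, e.g. 'CC12M' or 'COYO700M'")
    # Specialists (e.g. 'T2I', 'SR') are appended to the '4M-7' prefix.
    prefix = "4M-7" if specialist is None else f"4M-7-{specialist}"
    return f"{ORG}/{prefix}_{size}_{dataset}"
```

For example, `FM.from_pretrained(fm_hub_id("XL", 7, "CC12M", specialist="T2I"))` would load the same checkpoint as `'EPFL-VILAB/4M-7-T2I_XL_CC12M'` above.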

Tokenizers

| Modality | Resolution | Number of tokens | Codebook size | Diffusion decoder | Weights |
|---|---|---|---|---|---|
| RGB | 224-448 | 196-784 | 16k | Yes | Checkpoint / HF Hub |
| Depth | 224-448 | 196-784 | 8k | Yes | Checkpoint / HF Hub |
| Normals | 224-448 | 196-784 | 8k | Yes | Checkpoint / HF Hub |
| Edges (Canny, SAM) | 224-512 | 196-1024 | 8k | Yes | Checkpoint / HF Hub |
| COCO semantic segmentation | 224-448 | 196-784 | 4k | No | Checkpoint / HF Hub |
| CLIP-B/16 | 224-448 | 196-784 | 8k | No | Checkpoint / HF Hub |
| DINOv2-B/14 | 224-448 | 256-1024 | 8k | No | Checkpoint / HF Hub |
| DINOv2-B/14 (global) | 224 | 16 | 8k | No | Checkpoint / HF Hub |
| ImageBind-H/14 | 224-448 | 256-1024 | 8k | No | Checkpoint / HF Hub |
| ImageBind-H/14 (global) | 224 | 16 | 8k | No | Checkpoint / HF Hub |
| SAM instances | - | 64 | 1k | No | Checkpoint / HF Hub |
| 3D Human poses | - | 8 | 1k | No | Checkpoint / HF Hub |
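The spatial token counts in the table are consistent with one token per 16×16 patch on a square grid (e.g. 224 / 16 = 14, so 14 × 14 = 196 tokens), or per 14×14 patch for the DINOv2-B/14 and ImageBind-H/14 tokenizers. This is inferred from the numbers only; the global variants, SAM instances and human poses use a fixed token count and are not grid-based. The sketch below reproduces the grid arithmetic; `grid_tokens` is an illustrative name.

```python
def grid_tokens(resolution: int, patch_size: int = 16) -> int:
    """Tokens for a square image tokenized on a (resolution / patch_size)^2 grid."""
    if resolution % patch_size:
        raise ValueError(f"{resolution} is not divisible by patch size {patch_size}")
    side = resolution // patch_size
    return side * side


if __name__ == "__main__":
    examples = [("RGB", 224, 16), ("RGB", 448, 16), ("Edges", 512, 16),
                ("DINOv2-B/14", 224, 14), ("DINOv2-B/14", 448, 14)]
    for name, res, patch in examples:
        print(f"{name:12s} {res}px -> {grid_tokens(res, patch)} tokens")
```

Since the token count grows quadratically with resolution, doubling the resolution from 224 to 448 quadruples the sequence length a 4M model has to process for that modality.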

To load models from Hugging Face Hub:

from fourm.vq.vqvae import VQVAE, DiVAE

# 4M-7 modalities
tok_rgb = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_rgb_16k_224-448')
tok_depth = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_depth_8k_224-448')
tok_normal = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_normal_8k_224-448')
tok_semseg = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_semseg_4k_224-448')
tok_clip = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448')

# 4M-21 modalities
tok_edge = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_edge_8k_224-512')
tok_dinov2 = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448')
tok_dinov2_global = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224')
tok_imagebind = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448')
tok_imagebind_global = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_ImageBind-H14-global_8k_16_224')
tok_sam_instance = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_sam-instance_1k_64')
tok_human_poses = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_human-poses_1k_8')

To load the checkpoints manually, first download the safetensors files from the above links and call:

from fourm.utils import load_safetensors
from fourm.vq.vqvae import VQVAE, DiVAE

ckpt, config = load_safetensors('/path/to/checkpoint.safetensors')
tok = VQVAE(config=config) # Or DiVAE for models with a diffusion decoder
tok.load_state_dict(ckpt)

License

The code in this repository is released under the Apache 2.0 license as found in the LICENSE file.

The model weights in this repository are released under the Sample Code license as found in the LICENSE_WEIGHTS file.

Citation

If you find this repository helpful, please consider citing our work:

@inproceedings{4m,
    title={{4M}: Massively Multimodal Masked Modeling},
    author={David Mizrahi and Roman Bachmann and O{\u{g}}uzhan Fatih Kar and Teresa Yeo and Mingfei Gao and Afshin Dehghan and Amir Zamir},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023},
}

@article{4m21,
    title={{4M-21}: An Any-to-Any Vision Model for Tens of Tasks and Modalities},
    author={Roman Bachmann and O{\u{g}}uzhan Fatih Kar and David Mizrahi and Ali Garjani and Mingfei Gao and David Griffiths and Jiaming Hu and Afshin Dehghan and Amir Zamir},
    journal={arXiv 2024},
    year={2024},
}

ml-4m's Issues

What are the minimum requirements to run inference?

Hi,

I am attempting to run the model on my machine; however, the code keeps dying due to lack of memory, even though my machine has enough memory to load all of the files. Is there any way to know the minimum requirements for using the model? Thanks!

Training details of RGB tokenizer

Thanks for your great work!

As you mentioned in your 4M paper, you used DiVAE to train the RGB tokenizer, first for 100 epochs on ImageNet-1K and then for an additional 15 epochs on the CC12M dataset. I followed the training settings given in your paper and used the checkpoint from epoch 100 of ImageNet-1K training (ckpt-100) as full_ckpt, but I encountered a NaN loss when I continued training on the CC12M dataset. Could you please provide some suggestions to resolve the problem?

Thanks in advance!

Is it possible to prompt 4M?

Many thanks for making this work publicly available.

My question is whether it is possible to prompt the available models, and if so, where I might find some examples of how to do this?

If not, do you plan to make this possible at some point?

Thanks in advance.

Depth tokenizer

Hi everyone, thanks for the nice work. I am considering using your pretrained depth tokenizer to extract precomputed (features) tokens for further training. I have some questions.

  1. I cloned ml-4m and installed the diffusers library. However, I get the error: AttributeError: module diffusers.models has no attribute unet_2d_blocks. Could you please specify the requirements for using your repo and which diffusers version you used?

  2. Also, how many tokens do we get from your pretrained checkpoint model?

  3. Is your uploaded pretrained depth tokenizer an encoder-decoder or encoder only model that would just give me the required tokens?

  4. What normalization did you use for the depth data?

Thanks a lot!

Example of generating image pixels from ImageBind modality

Thanks for your excellent work!

I would like to inquire if you could provide some examples or documentation on how to use 4m to generate images from ImageBind feature or tokens. Your guidance on this matter would be greatly appreciated.

Thank you for your time and assistance.

How to convert a trained FM .pth model file to safetensors format?

Thanks for providing the training source for the FM model. I notice that there is a script, fourm/vq/__init__.py, to parse the pre-trained tokenizers. However, there is no script that can convert a trained FM .pth model file (PyTorch model file) to the safetensors format.
How could we deal with this situation?

VRAM Requirements and Multi-GPU Inference Support

Hello, thank you for your impressive research.
Could you please provide information on the amount of GPU VRAM required for each model? Additionally, if a single GPU does not have sufficient VRAM, is it possible to distribute the inference across multiple GPUs?
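A rough lower bound for the memory question can be computed from the parameter counts in the model zoo: the weights alone need (number of parameters) × (bytes per parameter). The sketch below does that arithmetic. It ignores activations, the tokenizers (including their diffusion decoders) and framework overhead, all of which add to the total, and whether the released models behave well in half precision is not covered here. The names are illustrative, not part of `fourm`.

```python
# Parameter counts as listed in the model zoo tables above.
MODEL_PARAMS = {"4M-B": 198e6, "4M-L": 705e6, "4M-XL": 2.8e9}
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}


def weight_memory_gib(num_params: float, dtype: str = "fp32") -> float:
    """Memory needed just to hold the weights, in GiB (2**30 bytes).

    This is a lower bound: activations, tokenizers/decoders and framework
    overhead all come on top of it.
    """
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


if __name__ == "__main__":
    for name, n in MODEL_PARAMS.items():
        fp32 = weight_memory_gib(n, "fp32")
        fp16 = weight_memory_gib(n, "fp16")
        print(f"{name:6s} weights: {fp32:5.1f} GiB (fp32), {fp16:5.1f} GiB (fp16/bf16)")
```

For instance, the 4M-XL weights alone come to about 10.4 GiB in fp32, so measuring actual peak usage on the target GPU remains the reliable way to size hardware.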

Input masks for generation - Potential small bug.

It looks like there may be a small bug in the generation code:

eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()

The input masks for text are determined only by the EOS position of the first sample in the batch, but are then applied to all samples in the batch. Is this intentional? Generation seems to be commonly used with a batch size of one (as in the examples), so this may have fallen through the cracks. If it is intentional, I'd be curious about the reasoning; otherwise, I'm happy to make a PR.

Great stuff btw, thanks for open sourcing this!
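To illustrate the per-sample alternative the report suggests, the sketch below finds the first EOS position separately for each sequence in a batch and builds one mask row per sequence, instead of reusing a single index for the whole batch. It uses plain Python lists for clarity. The mask convention (True marks positions after that sequence's own EOS, i.e. positions to ignore) and the choice to keep the EOS token itself visible are assumptions, not the repository's documented behavior.

```python
from typing import List


def first_eos_indices(tokens: List[List[int]], eos_id: int) -> List[int]:
    """Index of the first EOS in each sequence; len(seq) if a sequence has no EOS."""
    return [seq.index(eos_id) if eos_id in seq else len(seq) for seq in tokens]


def per_sample_input_masks(tokens: List[List[int]], eos_id: int) -> List[List[bool]]:
    """One mask row per sequence: True for positions after that sequence's own EOS."""
    return [
        [pos > eos_idx for pos in range(len(seq))]
        for seq, eos_idx in zip(tokens, first_eos_indices(tokens, eos_id))
    ]
```

With a batch such as `[[5, 7, 2, 0, 0], [5, 7, 8, 9, 2]]` and `eos_id=2`, the EOS indices are 2 and 4, so only the first row has masked padding positions; a single shared index of 2 would wrongly mask the last two real tokens of the second row.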

Fine-tune using LoRA

Hi,

I would like to know if it is possible to fine-tune the model for the specific downstream task using LoRA?

I noticed that there is a file related to LoRA, fourm/models/lora_utils.py, but could not find how it is used. It would be highly appreciated if you could provide a tutorial on how to use LoRA for fine-tuning. Thank you!

[Errno 2] No such file or directory: './fourm/utils/hmr2_utils/model_cfg.pkl'

Human pose dependencies are not installed, hence poses will not be visualized. To visualize them (optional), you can do the following:

  1. Install via pip install timm yacs smplx pyrender pyopengl==3.1.4
    You may need to follow the pyrender install instructions: https://pyrender.readthedocs.io/en/latest/install/index.html
  2. Download SMPL data from https://smpl.is.tue.mpg.de/. See https://github.com/shubham-goel/4D-Humans/ for an example.
  3. Copy the required SMPL files (smpl_mean_params.npz, SMPL_to_J19.pkl, smpl/SMPL_NEUTRAL.pkl) to fourm/utils/hmr2_utils/data .

I followed all the steps above but still got the error in the title.
Where is model_cfg.pkl?
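One thing worth ruling out: the path in the error message starts with `./`, so it is resolved relative to the current working directory rather than to the `fourm` package, and running from a directory other than the repository root is one possible cause. The sketch below lists which of the expected files are missing under a given directory. The SMPL file names come from the instructions above and `model_cfg.pkl` from the error message; `missing_files` is a hypothetical helper, and where `model_cfg.pkl` is supposed to come from is not answered here.

```python
from pathlib import Path
from typing import Iterable, List

# Files the setup instructions ask to copy into fourm/utils/hmr2_utils/data.
SMPL_FILES = ("smpl_mean_params.npz", "SMPL_to_J19.pkl", "smpl/SMPL_NEUTRAL.pkl")


def missing_files(base_dir, relative_paths: Iterable[str]) -> List[str]:
    """Return the entries of `relative_paths` that are not regular files under `base_dir`."""
    base = Path(base_dir)
    return [p for p in relative_paths if not (base / p).is_file()]


if __name__ == "__main__":
    root = Path("fourm/utils/hmr2_utils")  # relative to the current working directory
    print("working directory:", Path.cwd())
    print("missing SMPL files:", missing_files(root / "data", SMPL_FILES))
    print("missing model_cfg.pkl:", missing_files(root, ["model_cfg.pkl"]))
```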

Typo for tokenizer_path arg

It seems like run_training_4m.py uses the arg text_tokenizer_path to define the path of the text tokenizer; however, the config files call this same variable tokenizer_path. I believe they were supposed to be the same.

Question on Token Masking in 4M Implementation

Thank you for open sourcing your amazing work.

I have a question regarding the token masking implementation: https://github.com/apple/ml-4m/blob/main/fourm/models/fm.py#L429

While I understand setting the tokens to 0, I'm curious about masking the positional embeddings. If we mask both tokens and positional embeddings to 0, how does the model distinguish between different tokens? Wouldn't this cause the model to treat these tokens identically? Would it make sense to add position embeddings after masking?

We can use causal attention to remedy this, but I'm wondering if I've misinterpreted the token masking process. Could you clarify this approach? Thank you!

How to use RGB DiVAE tokenizer?

I am trying to encode and decode RGB images using the trained DiVAE checkpoint:

from fourm.vq.vqvae import DiVAE
from fourm.utils import denormalize, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize


tok = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_rgb_16k_224-448')
normalize = Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
# encode
_, _, tokens = tok.encode(normalize(rgb_b3hw))
# decode
rgb_b3hw = denormalize(tok.decode_tokens(tokens))

For these input images:
[Input images]

I get these decoded images:
[Decoded images]

I tried it with and without RGB normalization; it did not make any significant difference to the quality of the reconstructions I get.

What am I doing wrong? How should one use the tokenizer?

Thank you,

CLIPScore moved in latest torchmetrics v1.4.0.post0

Firstly, thank you for releasing an amazing repo.

I see that pyproject.toml specifies torchmetrics>=1.3.1. Since this is not pinned to an exact version, creating a new environment installs the latest v1.4.0.post0, which breaks run_generation.py.

Changing the import from torchmetrics.multimodal import CLIPScore to be from torchmetrics.multimodal.clip_score import CLIPScore solves this issue.

I'm happy to raise a PR if that's helpful.
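A version-tolerant alternative to editing the import in place is to try the new module path first and fall back to the old one. The sketch below does this generically; `import_first` is a hypothetical helper, and the two torchmetrics paths are the ones given in this report (which torchmetrics versions expose which path is not verified here). Pinning torchmetrics to a tested version in pyproject.toml would be the other option.

```python
import importlib
from typing import Any, Sequence, Tuple


def import_first(candidates: Sequence[Tuple[str, str]]) -> Any:
    """Return the first attribute that imports successfully from (module, attribute) pairs."""
    errors = []
    for module_name, attr in candidates:
        try:
            return getattr(importlib.import_module(module_name), attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("none of the candidates could be imported:\n" + "\n".join(errors))


if __name__ == "__main__":
    try:
        CLIPScore = import_first([
            ("torchmetrics.multimodal.clip_score", "CLIPScore"),  # path reported to work
            ("torchmetrics.multimodal", "CLIPScore"),  # path currently used in the repo
        ])
        print("Using CLIPScore from", CLIPScore.__module__)
    except ImportError as exc:
        print(exc)
```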

Object Detection with Caption

First, thank you all for open sourcing this fantastic work.

I want to ask whether object detection with captions is feasible with this model, and if so, how I can use it.

Thank you in advance!
