
m2's Introduction

Monarch Mixer

Update January 11, 2024: We are excited to release new long-context M2-BERT models, with versions fine-tuned for embeddings! See the blog post and bert/EMBEDDINGS.md for more details!

Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture
Daniel Y. Fu, Simran Arora*, Jessica Grogan*, Isys Johnson*, Sabri Eyuboglu*, Armin W. Thomas*, Benjamin F. Spector, Michael Poli, Atri Rudra, and Christopher Ré.
arXiv | M2-BERT blog post

Long-Context Retrieval Models with Monarch Mixer
Jon Saad-Falcon, Dan Fu, Simran Arora. Blog post, Jan 11 2024.

Updates:

  • January 11, 2024: M2-BERT retrieval models are now available on Together API! Check out instructions below for running them!
  • January 11, 2024: New long-context M2-BERT models available (2k, 8k, and 32k), as well as retrieval versions for embeddings. Also releasing a preview of LoCo, a new benchmark for long-context retrieval! Check out our blog to read more, and try the models out here!
  • October 21, 2023: M2-BERT-large checkpoints are now up on HuggingFace (260M, 341M). The 260M model matches BERT-large in GLUE fine-tuning with 24% fewer parameters, and the 341M model outperforms BERT-large.
  • October 18, 2023: M2 paper is now up on arXiv, and will be presented at NeurIPS as an oral!

Base M2-BERT Checkpoints:

Long-Context and Retrieval M2-BERT Checkpoints:

Transformers have taken the world by storm! The architecture is composed of two core operations: Attention for mixing information across the input sequence, and MLPs for mixing information across the model dimension. Each operator scales quadratically: the complexity of Attention is quadratic in sequence length, and the complexity of an MLP is quadratic in model dimension. Ideally, we would have alternatives that scale more efficiently while preserving Transformer-level quality. Towards this goal, we've been developing Monarch Mixer (M2), a framework for training models that are sub-quadratic in both sequence length and model dimension.

M2 diagram

Our basic idea is to replace the major elements of a Transformer with Monarch matrices, a class of structured matrices that generalize the FFT and are sub-quadratic, hardware-efficient, and expressive. In Monarch Mixer, we use layers built from Monarch matrices both to mix across the sequence (replacing Attention) and to mix across the model dimension (replacing the dense MLP). This repo includes code and models for training Monarch Mixer architectures!
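To make "sub-quadratic and hardware-efficient" concrete, here is a minimal NumPy sketch of a Monarch matrix-vector product. It assumes the two-factor form M = P L Pᵀ R from the original Monarch matrices paper: L and R are block-diagonal with b blocks of size b × b (so N = b²), and P views a length-N vector as a b × b grid and transposes it (for a square grid P is its own inverse, so Pᵀ = P). The helper names are hypothetical and this is not the repository's kernel; it only shows why the cost drops from N² to 2·N^(3/2) multiply-adds, with each factor being a batched matmul.

```python
import numpy as np


def blockdiag_matvec(blocks: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Multiply a block-diagonal matrix, given as (nblocks, b, b) blocks, by a vector x."""
    nblocks, b, _ = blocks.shape
    # Block k only touches the k-th contiguous chunk of x, so this is one batched matmul.
    chunks = x.reshape(nblocks, b)
    return np.einsum("kij,kj->ki", blocks, chunks).reshape(-1)


def transpose_perm(x: np.ndarray, b: int) -> np.ndarray:
    """Permutation P: view x as a b x b grid and transpose it (P is its own inverse)."""
    return x.reshape(b, b).T.reshape(-1)


def monarch_matvec(L: np.ndarray, R: np.ndarray, x: np.ndarray) -> np.ndarray:
    """y = P L P R x, where L and R each hold b blocks of size b x b and N = b * b."""
    b = R.shape[0]
    y = blockdiag_matvec(R, x)   # mix within contiguous groups of b entries
    y = transpose_perm(y, b)     # regroup so each group takes one entry from every block
    y = blockdiag_matvec(L, y)   # mix within the new groups
    return transpose_perm(y, b)  # undo the regrouping


def dense_blockdiag(blocks: np.ndarray) -> np.ndarray:
    """Materialize the block-diagonal matrix; only used here to check the fast path."""
    nblocks, b, _ = blocks.shape
    out = np.zeros((nblocks * b, nblocks * b), dtype=blocks.dtype)
    for k in range(nblocks):
        out[k * b:(k + 1) * b, k * b:(k + 1) * b] = blocks[k]
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    b = 4
    N = b * b
    L = rng.standard_normal((b, b, b))
    R = rng.standard_normal((b, b, b))
    x = rng.standard_normal(N)
    # Column j of P is the permuted j-th basis vector.
    P = np.stack([transpose_perm(e, b) for e in np.eye(N)], axis=1)
    M = P @ dense_blockdiag(L) @ P @ dense_blockdiag(R)
    assert np.allclose(M @ x, monarch_matvec(L, R, x))
    print("dense multiply-adds:", N * N, "| Monarch multiply-adds:", 2 * b ** 3)
```

Each factor costs b · b² = N^(3/2) multiply-adds instead of the N² of a dense matrix, and both steps are plain batched matrix multiplies, which is what makes this kind of structure GEMM-friendly.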

Getting Started with Embeddings

M2-BERT embedding models are now available on the Together API. You can run them by signing up for an account and querying the API as follows (you can find your API key here):

import os
import requests

def generate_together_embeddings(text: str, model_api_string: str, api_key: str):
    url = "https://api.together.xyz/api/v1/embeddings"
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    session = requests.Session()
    response = session.post(
        url,
        headers=headers,
        json={
            "input": text,
            "model": model_api_string
        }
    )
    if response.status_code != 200:
        raise ValueError(f"Request failed with status code {response.status_code}: {response.text}")
    return response.json()['data'][0]['embedding']

print(generate_together_embeddings('Hello world', 'togethercomputer/m2-bert-80M-32k-retrieval', os.environ['TOGETHER_API_KEY'])[:10])

Check out bert/EMBEDDINGS.md for more on how to evaluate these models and run them locally!
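The API returns one plain list of floats per input text. A common way to use such embeddings for retrieval is to rank documents by cosine similarity to the query embedding. A minimal pure-Python sketch; the helper names are hypothetical and not part of this repo or the Together API:

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two embedding vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def rank_documents(query_emb: Sequence[float],
                   doc_embs: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """Return (document index, score) pairs, most similar first."""
    scores = [(i, cosine_similarity(query_emb, d)) for i, d in enumerate(doc_embs)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)
```

For example, embed the query and each document with generate_together_embeddings (one request per text) and pass the resulting lists to rank_documents. For large corpora a vectorized library would be faster; plain Python keeps the sketch dependency-free.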

Current Contents

January 9, 2024: long-context M2-BERT-base checkpoints are up on HuggingFace, as well as retrieval versions for embedding models!

October 21, 2023: M2-BERT-large checkpoints are now up on HuggingFace, and the paper is on arXiv!

July 24, 2023: We are excited to release Monarch Mixer BERT (M2-BERT), which has 25% fewer parameters/FLOPs than BERT and matches it in average quality on the GLUE benchmark. The BERT folder includes code for pretraining and fine-tuning BERT baselines and M2-BERT. We also release pretrained checkpoints at sequence length 128 for an 80M-parameter M2-BERT, which matches the average GLUE score of the 110M-parameter BERT-base-uncased model, and for a parameter-matched M2-BERT model.

Citation

If you use this codebase, or otherwise found our work valuable, you can cite us as follows:

@inproceedings{fu2023monarch,
  title={Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture},
  author={Fu, Daniel Y and Arora, Simran and Grogan, Jessica and Johnson, Isys and Eyuboglu, Sabri and Thomas, Armin W and Spector, Benjamin and Poli, Michael and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}

You can also cite our previous work that this repository builds on:

@article{poli2023hyena,
  title={Hyena Hierarchy: Towards Larger Convolutional Language Models},
  author={Poli, Michael and Massaroli, Stefano and Nguyen, Eric and Fu, Daniel Y and Dao, Tri and Baccus, Stephen and Bengio, Yoshua and Ermon, Stefano and R{\'e}, Christopher},
  journal={arXiv preprint arXiv:2302.10866},
  year={2023}
}
@article{fu2023simple,
  title={Simple Hardware-Efficient Long Convolutions for Sequence Modeling},
  author={Fu, Daniel Y. and Epstein, Elliot L. and Nguyen, Eric and Thomas, Armin W. and Zhang, Michael and Dao, Tri and Rudra, Atri and R{\'e}, Christopher},
  journal={International Conference on Machine Learning},
  year={2023}
}
@inproceedings{fu2023hungry,
  title={Hungry {H}ungry {H}ippos: Towards Language Modeling with State Space Models},
  author={Fu, Daniel Y. and Dao, Tri and Saab, Khaled K. and Thomas, Armin W.
  and Rudra, Atri and R{\'e}, Christopher},
  booktitle={International Conference on Learning Representations},
  year={2023}
}

m2's People

Contributors

danfu09, eltociear, jonsaadfalcon, mocuto, simran-arora


m2's Issues

Multilingual?

Are you planning to release a multilingual version of this model? Could I fine-tune the current m2_bert model on German data?

`python3 test_flash_mm.py` fails with an error

ERROR: CUDA RT call "cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size )" in line 695 of file mm/csrc/flashmm/mm_block_fwd_cuda.cu failed with invalid device function (98).
max diff for mm block: tensor(2.0590e-05, device='cuda:0', grad_fn=<SelectBackward0>)
average diff for mm block: tensor(2.9658e-06, device='cuda:0', grad_fn=<MeanBackward0>)
max diff: tensor(0.0003, device='cuda:0')
avg diff: tensor(7.4159e-05, device='cuda:0')

I can still run the trainer, and the loss goes down.

MonarchMixerLayer

Hello,

I've come across an algorithm in the paper that appears to be designed for the M2 layer, with the intention of replacing both the Attention and MLP layers (specifically the nn.Linear part of the latter).

However, upon examining the monarch_mixer_sequence_mixer.py script, I noticed that it uses Hyena filters, and I couldn't find any implementation of this M2 layer algorithm in the code.

I might be missing something, but I wanted to clarify if it's necessary to substitute the Hyena filters with the M2 layer.

Thank you for your assistance with this project!

P.S.: I'm currently working with image data.
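For context on how long filters relate to the M2 layer: the M2 paper notes that the convolution case, with the Monarch matrices fixed to DFT and inverse DFT matrices, admits FFT-based implementations. In that case the sequence mixer applies a long filter k to each channel as F⁻¹(F k ⊙ F x). A minimal NumPy sketch of just that convolution, assuming a real 1-D signal x and filter k of equal length N. That the Hyena filters in monarch_mixer_sequence_mixer.py are what supply k is our reading, not something confirmed here, and the gating and residual terms are omitted:

```python
import numpy as np


def circular_conv_direct(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Reference O(N^2) circular convolution: y[n] = sum_j k[j] * x[(n - j) mod N]."""
    N = len(x)
    return np.array([sum(k[j] * x[(n - j) % N] for j in range(N)) for n in range(N)])


def circular_conv_fft(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Same result in O(N log N): inverse DFT of DFT(k) * DFT(x), elementwise."""
    return np.fft.ifft(np.fft.fft(k) * np.fft.fft(x)).real


def linear_conv_fft(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Zero-pad to 2N so no wrap-around occurs; the first N outputs are a causal convolution."""
    N = len(x)
    y = np.fft.irfft(np.fft.rfft(x, n=2 * N) * np.fft.rfft(k, n=2 * N), n=2 * N)
    return y[:N]
```

Zero-padding to 2N is the usual way to avoid circular wrap-around when the filter is as long as the input.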

Getting CUDA error when trying to train with 8k context

Hi,

I am trying to pretrain with sequence length 8k on a custom dataset. However, I am getting the following error:

Traceback (most recent call last):
  File "/disk1/sandeep/m2bert/m2/bert/main.py", line 280, in <module>
    main(cfg)
  File "/disk1/sandeep/m2bert/m2/bert/main.py", line 187, in main
    train_loader = build_dataloader(
  File "/disk1/sandeep/m2bert/m2/bert/main.py", line 144, in build_dataloader
    return text_data_module.build_text_dataloader(cfg, tokenizer,
  File "/disk1/sandeep/m2bert/m2/bert/src/text_data.py", line 274, in build_text_dataloader
    dataset = StreamingTextDataset(
  File "/disk1/sandeep/m2bert/m2/bert/src/text_data.py", line 134, in __init__
    super().__init__(
  File "/disk1/sandeep/miniconda3/envs/m2_bert/lib/python3.10/site-packages/streaming/base/dataset.py", line 325, in __init__
    self._shm_prefix, self._locals_shm = get_shm_prefix(my_locals, world)
  File "/disk1/sandeep/miniconda3/envs/m2_bert/lib/python3.10/site-packages/streaming/base/shared.py", line 357, in get_shm_prefix
    dist.barrier()
  File "/disk1/sandeep/miniconda3/envs/m2_bert/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3145, in barrier
    work = default_pg.barrier(opts=opts)
RuntimeError: NCCL Error 1: unhandled cuda error

My batch sizes are:

global_train_batch_size: 7

# System
seed: 17
device_eval_batch_size: 1
#device_train_microbatch_size: 8
device_train_microbatch_size: auto
precision: amp_bf16

Please let me know what I am doing wrong.

Using (Absolute) Positional Embeddings with Hyena Operators

Hi @DanFu09
Hope you're well,

I was reading the source code and the config files, and I realized that use_positional_encodings is True (link). So the M2-BERT model applies absolute positional embeddings (link) before feeding the tokens to the Hyena operators.

I checked the original Hyena and HyenaDNA source code, and they don't use any positional embeddings in their models.
My question is: why did you use positional embeddings? Have you tried training without them? Did that worsen performance?

Unable to use 'convert_dataset.py' to load data

I get a "server disconnected" error when I run convert_dataset.py, even for the BookCorpus or Wikipedia datasets.
If I set stream=False in the code, I get the following error instead:

Downloading data: 100%|██████████| 41/41 [03:06<00:00, 4.55s/files]
Generating train split: 100%|██████████| 6458670/6458670 [01:06<00:00, 96995.20 examples/s]
Loading dataset shards: 100%|██████████| 41/41 [00:00<00:00, 1805.01it/s]
Traceback (most recent call last):
  File "/home/sandeep.pandey/m2/bert/src/convert_dataset.py", line 524, in <module>
    main(parse_args())
  File "/home/sandeep.pandey/m2/bert/src/convert_dataset.py", line 489, in main
    loader = build_dataloader(dataset=dataset, batch_size=512)
  File "/home/sandeep.pandey/m2/bert/src/convert_dataset.py", line 397, in build_dataloader
    num_workers = min(64, dataset.hf_dataset.n_shards)  # type: ignore
AttributeError: 'Dataset' object has no attribute 'n_shards'

Please help me resolve this; I am stuck reproducing the training pipeline.
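One possible workaround, sketched under assumptions: a streaming Hugging Face `IterableDataset` exposes `n_shards`, while the map-style `Dataset` obtained with stream=False does not, which is exactly what the AttributeError reports. The helper below is hypothetical and not part of convert_dataset.py. It falls back to a single worker for map-style datasets because we cannot assume how the script splits such a dataset across workers; if each worker iterated the full dataset, several workers would duplicate samples.

```python
from typing import Any


def choose_num_workers(hf_dataset: Any, max_workers: int = 64) -> int:
    """Pick a DataLoader worker count that also works for non-streaming datasets.

    Hypothetical helper: streaming datasets report n_shards, and the original
    script caps workers at min(64, n_shards). Map-style datasets have no
    n_shards, so we conservatively use one worker instead of crashing.
    """
    n_shards = getattr(hf_dataset, "n_shards", None)
    if n_shards is None:
        return 1
    return max(1, min(max_workers, n_shards))
```

It would stand in for the `min(64, dataset.hf_dataset.n_shards)` expression at line 397 in the traceback, e.g. `num_workers = choose_num_workers(dataset.hf_dataset)`.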

torch.bmm kernel fusion

@DanFu09 Thanks for open-sourcing the code!
I see that in your previous fly repo (https://github.com/HazyResearch/fly), you used cast_inputs=torch.float16 for BlockdiagButterflyMultiply, but changed it to bf16 here. Is there a specific reason (e.g., fp16 training not converging due to range issues)?
Also, I wonder if there are opportunities for fusing the two bmm operations into one kernel? It seems hard to find the exact kernel torch is calling though.

MNLI yaml config

Hi, thank you for the awesome work!

I was wondering if the yaml configs used for fine-tuning on the MNLI task could also be shared.

Thank you so much :)

Code for projecting pre-trained BERT weights into Monarch matrices

Hello, I would like to know if you have published the code to project the pre-trained weights of the BERT model into Monarch matrices. I cannot locate the code for this (I have also looked in the fly repo).
I can see the projection functions here, but I am interested in knowing how you use them specifically for BERT (or other transformers for NLP) to go from pre-trained weights to Monarch matrices. Thank you very much.

Embedding speed seems slow

Hello there.
I tried to use m2_bert_80M_2k to generate embeddings for text strings of around 500 tokens, following your example on HuggingFace (https://huggingface.co/togethercomputer/m2-bert-80M-2k-retrieval). However, the line outputs = model(**input_ids) took over 15 s on average, which is slower than expected. Could you please help me find the issue?

I also tested your example with 12 tokens. The forward pass is still slow: over 5 s with 12 tokens and padding="longest", and over 16 s with 12 tokens and padding="max_length" (= 2048).
Thanks in advance!

Why is there such a big difference in cosine similarity between embeddings of the same pair when using padding=max_length versus padding=true?

When I embedded a relevant text pair using the m2-bert-80M-32k-retrieval model, the cosine similarity was 0.7 with padding=max_length, but close to 0 with padding=true (which I used to save memory). This made semantic retrieval completely impossible with padding=true. The same happened with the 2k and 8k models. Why is this the case? Is padding=true completely unusable?

Missing Licence

Hey there,

I just wanted to try out your code. Would you be so kind as to add a licence?

Thanks!

Best

  • MichaelFeil

Can I use Monarch Mixer to replace a cross-attention layer?

The sequence mixer in the paper doesn't seem able to mix sequences of unequal lengths the way cross-attention does, because it uses elementwise multiplication. Is this a misunderstanding on my part, or is Monarch Mixer not a replacement for cross-attention?

What category does the M2 model belong to

Hello, thank you for your great work! The M2-BERT paper mentions that "Monarch Mixer is part of a new class of architectures called state-space models (SSMs), which include S4, Mamba, and BiGS".
Is Monarch Mixer and M2BERT a part of SSMs?
I consider M2BERT to be:
(1) replace attention with bidirectional gated convolutions with a residual convolution, and set the Monarch matrices to DFT and inverse DFT matrices to speed up DFT for conv;
(2) In the dimension mixer, replace the two dense matrices in the MLPs with learned block-diagonal matrices to speed up the MLP computation.

I wonder which part of it is related to SSM? I would be very grateful if you could help me with the answer : )
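Point (2) above can be made concrete with a parameter count. A block-diagonal layer with nblocks blocks keeps only the diagonal blocks of a d_in × d_out weight, so it stores 1/nblocks of the dense weights and needs 1/nblocks of the multiply-adds. A small sketch; the sizes used (768 → 3072, 4 blocks) are illustrative values, not necessarily M2-BERT's configuration:

```python
def dense_linear_params(d_in: int, d_out: int) -> int:
    """Weight count of a dense d_in x d_out projection (bias ignored)."""
    return d_in * d_out


def blockdiag_linear_params(d_in: int, d_out: int, nblocks: int) -> int:
    """Weight count when the projection is split into nblocks independent diagonal blocks."""
    if d_in % nblocks or d_out % nblocks:
        raise ValueError("d_in and d_out must be divisible by nblocks")
    return nblocks * (d_in // nblocks) * (d_out // nblocks)


if __name__ == "__main__":
    dense = dense_linear_params(768, 3072)
    block = blockdiag_linear_params(768, 3072, nblocks=4)
    print(dense, block, dense // block)  # 2359296 589824 4
```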

does M2 work with ONNX?

I know there are special long convolutions; I'm not sure whether ONNX supports them, or what would happen if I tried to export the model to ONNX.

precision on imagenet experiment

Hi,

For ImageNet, you mentioned in the paper that the Hyena code was used for the experiments, replacing the MLP blocks in Hyena ViT-B with block-diagonal matrices, similarly to M2-BERT. The config file sets trainer: precision: 16 for Hyena, so I wonder whether you used bf16 mixed precision for ImageNet (as in M2-BERT) when training on A100 GPUs, or plain 16-bit precision.

Bert-like implementation

Hello,

Amazing work!!!

I have a couple of questions regarding the bidirectional implementation of the model.

  1. Does MonarchMixerSequenceMixing use, by default, all the recommended settings from training the BERT-like models (obviously with bidirectional = True)? If not, could you share the settings used for M2 large?
  2. It seems like the input u is after a token embedding layer. Do you add positional embeddings?
  3. Is any sort of attention mask required?
  4. Is it really okay to say M2 outperforms BERT when trained on different data? I think C4 improves BERT base considerably if I remember correctly.

Best,
Logan

LoCo Benchmark - BM25 & Insights

Hey, thanks for sharing this very interesting work!

I was interested in the recent LoCo benchmark composed for long-context retrieval and found it useful to have results for a very simple lexical baseline method first to put the scores in the blog post into context. As this was not yet done in the blog post, I ran BM25 (via ElasticSearch) on all benchmark tasks based on your eval script. Full results, in comparison to the best-performing M2-BERT-32768 (80M), below (NDCG@10 for all).

BM25

| Retrieval Encoders | Tau Scrolls Summ. Screen | Tau Scrolls Gov. Report | Tau Scrolls QMSUM | QASPER - Title to Article | QASPER - Abstract to Article | Average |
|---|---|---|---|---|---|---|
| BM25 | 97.4 | 98.7 | 59.4 | 94.0 | 99.4 | 89.8 |
| M2-BERT-32768 (80M) | 98.6 | 98.5 | 69.5 | 97.4 | 98.7 | 92.5 |

BM25 seems to be very competitive on LoCo, coming close to the best model tested in the post's evaluation and outperforming all other tested embedding models. Thus, lexical overlap between queries and correct documents seems to be very high on the benchmark tasks.

QMSum Analysis

Looking a bit closer at the results, we can see that for 4 of 5 tasks, NDCG is well above 90, meaning that BM25 is nearly perfectly able to retrieve the correct documents. The only exception is QMSum, so I looked into its data a bit closer:
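For reading these scores: assuming each query has exactly one relevant document (as in the QMSum query-document pairs described below), binary-relevance NDCG@10 reduces to 1/log2(rank + 1) for the rank of the correct document, or 0 if it falls outside the top 10. A minimal sketch of the textbook definition; the function name is hypothetical, and this is not necessarily the exact code in the LoCo eval script:

```python
import math
from typing import Sequence, Set


def ndcg_at_k(ranked_ids: Sequence[str], relevant: Set[str], k: int = 10) -> float:
    """Binary-relevance NDCG@k for a single query."""
    # DCG: a relevant hit at 0-based position r contributes 1 / log2(r + 2).
    dcg = sum(1.0 / math.log2(r + 2)
              for r, doc_id in enumerate(ranked_ids[:k]) if doc_id in relevant)
    # Ideal DCG: all relevant documents packed into the top positions.
    idcg = sum(1.0 / math.log2(r + 2) for r in range(min(len(relevant), k)))
    return dcg / idcg if idcg > 0 else 0.0
```

With one relevant document, rank 1 scores 1.0 and rank 2 already drops to about 0.63, so an average above 90 requires the correct document to be ranked first for the large majority of queries.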

Originally, QMSum is a summarization dataset consisting of three kinds of text: a corpus of 232 long meeting transcripts, a set of 272 questions, and 272 query-based summaries of the transcripts. In the tau/scrolls format, queries and transcripts are joined together in the "input" field, whereas summaries are given in the "output" field. This gives 272 input-output pairs. LoCo simply uses "output" as the query and "input" as the document, giving 272 queries and 272 documents.

This means that in the LoCo document corpus for QMSum, multiple documents are based on the same long meeting transcript, paired with different questions. For example, the first four documents are:

Passage_0 -> What was agreed upon on sample transcripts? Professor E: So . OK . Doesn't look like it crashed . That's great ...
Passage_1 -> What was said on speech overlap? Professor E: So . OK . Doesn't look like it crashed . That's great ...
Passage_2 -> What’s the current status of recordings and transcriptions? Professor E: So . OK . Doesn't look like it crased. That's great ...
Passage_3 -> What was the future of data collection? Professor E: So . OK . Doesn't look like it crashed . That 's great ...

The truncated part is identical in all four, meaning that the overwhelming part of the documents (with 9748 words on average) is identical apart from the question stated in the first few words. For distinguishing between these groups of documents, only the first few words are therefore relevant.

As an ablation, I removed the questions at the start of all documents and "merged" the resulting identical documents into one and then ran BM25 again. This improves NDCG@10 to 78.7.
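The ablation above can be sketched as follows: strip the leading question from every document, merge documents whose remaining text is identical, and remap each query to its merged document. The question-stripping heuristic (cut at the first "?") and the merged_N id scheme are hypothetical choices for illustration, not the script actually used for the 78.7 result:

```python
from typing import Dict, Tuple


def strip_leading_question(doc: str) -> str:
    """Drop everything up to and including the first '?' (hypothetical heuristic)."""
    _, sep, tail = doc.partition("?")
    return tail.strip() if sep else doc.strip()


def merge_duplicate_documents(
    docs: Dict[str, str], qrels: Dict[str, str]
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """docs: doc_id -> text; qrels: query_id -> relevant doc_id.

    Returns the deduplicated corpus and the queries remapped to merged ids.
    """
    body_to_new_id: Dict[str, str] = {}
    old_to_new: Dict[str, str] = {}
    merged: Dict[str, str] = {}
    for doc_id, text in docs.items():
        body = strip_leading_question(text)
        if body not in body_to_new_id:
            new_id = f"merged_{len(body_to_new_id)}"
            body_to_new_id[body] = new_id
            merged[new_id] = body
        old_to_new[doc_id] = body_to_new_id[body]
    new_qrels = {q: old_to_new[d] for q, d in qrels.items()}
    return merged, new_qrels
```

After merging, several queries share one relevant document, so a retriever is no longer asked to separate near-identical transcripts by their first few words.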


Just wanted to share these quick insights into the LoCo benchmark, maybe this is useful to someone!

A question on square matrices

Hello there, I want to clarify whether the need for square matrices is strictly enforced. From the paper, I note that

"We turn to an expressive class of sub-quadratic structured matrices called Monarch matrices [12] (Figure 1 left) to propose Monarch Mixer (M2). Monarch matrices are a family of structured matrices that generalize the fast Fourier transform (FFT) and have been shown to capture a wide class of linear transforms including Hadamard transforms, Toeplitz matrices [30], AFDF matrices [55], and convolutions. They are parameterized as the products of block-diagonal matrices, called monarch factors, interleaved with permutation. Their compute scales sub-quadratically: setting the number of factors to p results in computational complexity of $O(pN^{(p+1)/p})$ in input length $N$ , allowing the complexity to interpolate between $O(N \log N )$ at $p = \log N$ and $O(N^{3/2})$ at $p = 2$.", as well as:

"The convolution case with Monarch matrices fixed to DFT and inverse DFT matrices also admits implementations based on FFT algorithms [9]."

Furthermore, Proposition 3.2 in the original Monarch paper asserts that $\mathcal{MM}^*$ can represent a convolution.

I therefore want to find out whether the Monarch Mixer operation requires block-diagonal matrices, since a (residual, gated) convolution does not intuitively correspond to a block-diagonal matrix.

Thank You!
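One way to see why a convolution need not itself be block-diagonal: the DFT matrix F is a product of structured factors, and a convolution is F⁻¹ diag(F k) F, again a product of such factors. The standard Cooley-Tukey "four-step" FFT makes this concrete for N = n1 · n2: a batch of small DFTs (block-diagonal, up to the reshape/transpose permutations), a diagonal of twiddle factors, a second batch of small DFTs, and a final permutation. A minimal NumPy sketch checked against np.fft.fft; it illustrates the textbook factorization only and is neither the repository's implementation nor the Monarch paper's proof:

```python
import numpy as np


def four_step_dft(x: np.ndarray, n1: int, n2: int) -> np.ndarray:
    """Length-N DFT (N = n1 * n2) built from small DFTs, a diagonal, and permutations."""
    N = n1 * n2
    if x.shape != (N,):
        raise ValueError("x must be a 1-D array of length n1 * n2")
    a = x.reshape(n1, n2)                 # a[i1, i2] = x[n2 * i1 + i2]
    b = np.fft.fft(a, axis=0)             # n2 independent length-n1 DFTs
    k1 = np.arange(n1)[:, None]
    i2 = np.arange(n2)[None, :]
    b = b * np.exp(-2j * np.pi * k1 * i2 / N)  # twiddle factors: a diagonal matrix
    c = np.fft.fft(b, axis=1)             # n1 independent length-n2 DFTs
    return c.T.reshape(-1)                # output index k1 + n1 * k2 (a permutation)
```

Each batched-DFT step is a block-diagonal matrix once the data is permuted into contiguous groups, which is the pattern Monarch factors generalize by learning the blocks instead of fixing them to DFTs.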

training data

Hi, thank you for your nice work! Did you train your M2-BERT-128 to 32K models (shown in the paper) on the LoCo V0 or LoCo V1 training set?
