
gemma-2b-10m's Introduction

Gemma 2B - 10M Context

Gemma 2B with recurrent local attention, extending the context length to up to 10M tokens. Our implementation uses less than 32GB of memory!

[Graphic of our implementation's context handling]

Features:

  • 10M sequence length on Gemma 2B.
  • Runs on less than 32GB of memory.
  • Native inference optimized for CUDA.
  • Recurrent local attention for O(N) memory.

Quick Start

Note: This is a very early checkpoint of the model (only 200 training steps). We plan to train on many more tokens!

Install the requirements:

pip install -r requirements.txt

Download the model from Hugging Face - Huggingface Model.

python main.py

Edit the inference code in main.py to use the prompt you want:

import torch
from transformers import AutoTokenizer

# GemmaForCausalLM (the modified model class) and generate (the segment-wise
# sampling helper) are defined in this repo's src/gemma.py and main.py.

model_path = "./models/gemma-2b-10m"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GemmaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16
)

prompt_text = "Summarize this harry potter book..."

with torch.no_grad():
    generated_text = generate(
        model, tokenizer, prompt_text, max_length=512, temperature=0.8
    )

    print(generated_text)
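
The generate helper used above is defined in main.py and runs the model over the text segment by segment, carrying the recurrent memory state forward between segments. Below is a rough sketch of what such a loop looks like; the variable names and sampling details are illustrative assumptions rather than the exact main.py code.

import torch

def generate(model, tokenizer, prompt_text, max_length=512, temperature=0.8):
    # Hypothetical reconstruction of the segment-wise sampling loop.
    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    generated_sequence = input_ids
    original_length = input_ids.size(1)
    memory, norm_term = None, None  # recurrent state carried across segments

    while generated_sequence.size(1) < original_length + max_length:
        # Only the most recent tokens pass through attention; older context
        # is summarized in the recurrent memory / norm_term state.
        input_segment = generated_sequence[:, -2048:]
        outputs = model(input_ids=input_segment.to(model.device),
                        memory=memory, norm_term=norm_term)
        memory, norm_term = outputs.memory, outputs.norm_term

        # Sample the next token from the temperature-scaled distribution.
        next_token_logits = outputs.logits[:, -1, :]
        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_sequence = torch.cat([generated_sequence, next_token.cpu()], dim=-1)

    return tokenizer.decode(generated_sequence[0], skip_special_tokens=True)

Because only a bounded slice of tokens is fed through attention at each step, peak memory stays roughly constant as the generated context grows.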

How does this work?

The largest memory bottleneck for LLMs is attention. In vanilla multi-head attention the attention matrix grows quadratically with sequence length (and the KV cache grows linearly), which limits how long a sequence you can process.

Our approach splits the attention into local attention blocks, as outlined by InfiniAttention. We then apply recurrence across those local blocks, which yields the final result: global attention over up to 10M tokens of context.

A lot of the inspiration for our ideas comes from the Transformer-XL paper.
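
To make this concrete, here is a minimal sketch of one recurrent local attention step in the InfiniAttention style: softmax attention runs only within the current segment, while older segments live in a fixed-size compressive memory that each new segment reads from and then updates. The shapes, names, and the fixed 50/50 mixing below are illustrative assumptions, not the repository's exact implementation.

import torch
import torch.nn.functional as F

def recurrent_local_attention(q, k, v, memory=None, norm_term=None):
    # q, k, v:   (batch, heads, seg_len, head_dim) for the current segment.
    # memory:    (batch, heads, head_dim, head_dim) compressive memory of past segments.
    # norm_term: (batch, heads, head_dim, 1) running normalization term.
    b, h, n, d = q.shape

    # 1) Ordinary causal softmax attention, restricted to this segment, so the
    #    attention matrix is seg_len x seg_len instead of total_len x total_len.
    local_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # 2) Read what earlier segments wrote into the memory (linear-attention style).
    sigma_q = F.elu(q) + 1.0
    if memory is None:
        memory = torch.zeros(b, h, d, d, dtype=q.dtype, device=q.device)
        norm_term = torch.zeros(b, h, d, 1, dtype=q.dtype, device=q.device)
    memory_out = (sigma_q @ memory) / (sigma_q @ norm_term + 1e-6)

    # 3) Write this segment's keys/values into the memory so later segments can
    #    use them without keeping this segment's KV cache around.
    sigma_k = F.elu(k) + 1.0
    new_memory = memory + sigma_k.transpose(-2, -1) @ v
    new_norm_term = norm_term + sigma_k.sum(dim=-2, keepdim=True).transpose(-2, -1)

    # 4) Combine local and memory read-outs. Infini-attention uses a learned gate
    #    here; a fixed 50/50 mix keeps the sketch short.
    out = 0.5 * local_out + 0.5 * memory_out
    return out, new_memory, new_norm_term

Because the memory is a fixed head_dim x head_dim matrix per head, the state carried across segments stays constant in size no matter how long the context grows, which keeps total memory roughly linear in sequence length rather than quadratic.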

More details

For more context about our motivations, implementation details, and the theory behind the work, check out our technical overview on Medium.

Credits

This was built by mustafaaljadery and akshgarg7, with a contribution from eltociear.


gemma-2b-10m's Issues

can't install flash_attn

My notebook:
Windows 11 Pro 23H2
Intel i7-8750H
GeForce GTX 1050Ti (Mobile)
32GB RAM (2666MHz)

pip install -r .\requirements.txt
Requirement already satisfied: torch in c:\users\anime\appdata\local\programs\python\python310\lib\site-packages (from -r .\requirements.txt (line 1)) (2.3.0)
Requirement already satisfied: transformers in c:\users\anime\appdata\local\programs\python\python310\lib\site-packages (from -r .\requirements.txt (line 2)) (4.40.2)
Requirement already satisfied: datasets in c:\users\anime\appdata\local\programs\python\python310\lib\site-packages (from -r .\requirements.txt (line 3)) (2.19.1)
Collecting flash_attn (from -r .\requirements.txt (line 4))
  Using cached flash_attn-2.5.8.tar.gz (2.5 MB)
  Preparing metadata (setup.py) ... error
  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [20 lines of output]
      fatal: not a git repository (or any of the parent directories): .git
      C:\Users\Anime\AppData\Local\Temp\pip-install-zqspt8qf\flash-attn_c20c0c86c12c4a6083c44ea61c202e13\setup.py:78: UserWarning: flash_attn was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.
        warnings.warn(
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "C:\Users\Anime\AppData\Local\Temp\pip-install-zqspt8qf\flash-attn_c20c0c86c12c4a6083c44ea61c202e13\setup.py", line 134, in <module>
          CUDAExtension(
        File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\cpp_extension.py", line 1077, in CUDAExtension
          library_dirs += library_paths(cuda=True)
        File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\cpp_extension.py", line 1211, in library_paths
          paths.append(_join_cuda_home(lib_dir))
        File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\cpp_extension.py", line 2419, in _join_cuda_home
          raise OSError('CUDA_HOME environment variable is not set. '
      OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.


      torch.__version__  = 2.3.0+cpu


      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

Host a demo on Huggingface Spaces with free A100s (ZeroGPU)

Congratulations on this super-exciting project! It would be awesome to top it up with a live Gradio demo on Huggingface Spaces. I think this could help with more community engagement and drive more visibility to the project. We at Huggingface also provide free GPU grants through the ZeroGPU program, which includes access to free A100s for developers. We would be glad to extend the grant to your application.

Some useful links to help you get started on Spaces:

Please let us know if you need further assistance or support in integrating your project with Spaces or other relevant Huggingface offerings.

Some Errors...

My notebook:
Windows 11 Pro 23H2
Intel i7-8750H
GeForce GTX 1050Ti (Mobile)
32GB RAM (2666MHz)

After I removed the mention of flash_attn in gemma.py, I got the following errors:
TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'
(and with other models also)

After adding *args and **kwargs to all the forward methods, another error appeared:
RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 3

Traceback (most recent call last):
   File "d:\Programming\Python\MyGemma2B\1.py", line 42, in <module>
     generated_text = generate(
   File "d:\Programming\Python\MyGemma2B\1.py", line 17, in generate
     outputs = model(input_ids=input_segment.to(model.device), memory=memory, norm_term=norm_term)
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
     return self._call_impl(*args, **kwargs)
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
     return forward_call(*args, **kwargs)
   File "d:\Programming\Python\MyGemma2B\gemma_modified.py", line 960, in forward
     outputs = self.model(
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
     return self._call_impl(*args, **kwargs)
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
     return forward_call(*args, **kwargs)
   File "d:\Programming\Python\MyGemma2B\gemma_modified.py", line 783, in forward
     layer_outputs = decoder_layer(
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
     return self._call_impl(*args, **kwargs)
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
     return forward_call(*args, **kwargs)
   File "d:\Programming\Python\MyGemma2B\gemma_modified.py", line 617, in forward
     _attended = self.self_attn(
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
     return self._call_impl(*args, **kwargs)
   File "C:\Users\Anime\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
     return forward_call(*args, **kwargs)
   File "d:\Programming\Python\MyGemma2B\gemma_modified.py", line 532, in forward
     attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 3

All errors occurred after "Loading checkpoint shards".

TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'

I made a Colab ( https://colab.research.google.com/drive/1Z3NdoT0WS8KXnSUS3_xxT39NBZD6eGcN?usp=sharing ) to test, and I ran into some issues: GemmaModel.forward() got an unexpected keyword argument 'cache_position'. I had to change some of main.py to get the model to load correctly. The model loads into system RAM, not onto the GPU; I don't know if that is the cause of the GemmaModel.forward() error.

I have some other questions: is the context length set in the generate function? Does the memory balloon as the context and hidden state grow? In config.json, "torch_dtype" is "float32"; is there a reason for this? In Google's gemma-2b, "torch_dtype" is "bfloat16".


TypeError Traceback (most recent call last)
in <cell line: 3>()
2
3 with torch.no_grad():
----> 4 generated_text = generate(
5 model, tokenizer, prompt_text, max_length=512, temperature=0.8
6 )

5 frames
in generate(model, tokenizer, prompt_text, max_length, temperature)
15 while generated_sequence.size(1) < original_length + max_length:
16 input_segment = generated_sequence[:, -2048:]
---> 17 outputs = model(input_ids=input_segment.to(model.device), memory=memory, norm_term=norm_term)
18 memory, norm_term = outputs.memory, outputs.norm_term
19 next_token_logits = outputs.logits[:, -1, :]

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:

/content/gemma-2B-10M/src/gemma.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, memory, norm_term, no_memory_update)
947 )
948
--> 949 outputs = self.model(
950 input_ids=input_ids,
951 attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:

TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'

How to run it?

I don't quite understand how to install and run it.
I downloaded this folder from GitHub and downloaded all 13 files from Hugging Face. What's next? Which folder should I put them in, and how do I run them?

Can I limit the context window to say like 100k?

Hi, it's really exciting to see a 10M context window, but I don't have 32GB of memory. Can I limit the context window to 100k to reduce the required memory so it fits in 16GB? I don't want to crash and have to restart my laptop, since I have so much work going on. Thanks.

LoRA fine tuning code ?

Hi

Can this be fine-tuned with LoRA without any additional script? Also, during fine-tuning, if we use a sequence length of 512 or 1k, will it affect inference at higher context lengths of, say, 16k or 32k?

Implementation of the PyTorch Gemma InfiniTransformer is copied without attribution

The code for the model provided in this repository seems to be a copy of the repository linked below:

https://github.com/Beomi/InfiniTransformer/blob/main/infini_gemma/modeling_infini_gemma.py

Specifically, GemmaInfiniAttention and GemmaModel seem to be a direct copy, with comments removed, gradient checkpointing removed, and specific sections of the code slightly altered (especially in the aforementioned classes).

The only actual differences, other than slight code rewrites (e.g., flipping conditional statements, shuffling variable definition positions, turning multi-line statements into a single long line), seem to be that you forgot to add the rotary embeddings in the GemmaInfiniAttention class and replaced instances of segment variables with hidden_states. All of the variable names are identical.

What do the authors of this repository have to say in response? There doesn't seem to be anything new, and there is no mention of the original authors of the unofficial implementation. Not a very good look considering the recent llama3v incident...
