Comments (4)
You should run print(model); then you'll see:
GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaInfiniAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)
from infinitransformer.
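For reference, a minimal sketch of loading the model so that the printout above appears; the 2048-dim hidden size and 18 layers match google/gemma-2b, but the exact model id and loading flags here are assumptions, not taken verbatim from the repo:

from transformers import AutoModelForCausalLM

# Assumed model id: the shapes above (2048 hidden, 18 layers) correspond
# to the 2B Gemma checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    attn_implementation="eager",  # needed so the overridden class is picked up
)
print(model)  # the (self_attn) entries should read GemmaInfiniAttention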
Using attn_implementation='eager' and overriding GEMMA_ATTENTION_CLASSES like this is not the optimal way and is confusing -- but since HF does not allow registering custom attention classes, I currently overrode the eager entry, replacing the original GemmaAttention with GemmaInfiniAttention:
GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaInfiniAttention,  # GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,
    "sdpa": GemmaSdpaAttention,
}
from infinitransformer.
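In case it helps, a hedged sketch of what that override looks like as a monkey patch; the body of GemmaInfiniAttention is elided because the actual implementation lives in this repo, and the module path assumes a transformers version that still ships GEMMA_ATTENTION_CLASSES:

from transformers.models.gemma import modeling_gemma

class GemmaInfiniAttention(modeling_gemma.GemmaAttention):
    # The repo replaces forward() with the Infini-attention
    # computation; elided here.
    pass

# Re-point the "eager" entry so GemmaModel instantiates the custom class
modeling_gemma.GEMMA_ATTENTION_CLASSES["eager"] = GemmaInfiniAttention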
oh BTW, the default value of attn_implementation is "sdpa" for HF Gemma.
from infinitransformer.
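So the override is silently bypassed unless eager attention is forced. One way to check which implementation was actually selected (the private attribute name is an assumption that holds for recent transformers versions):

# Prints "sdpa" by default; should print "eager" if the override is active
print(model.config._attn_implementation)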
thanks, I mistakenly used the transformers library from the miniconda environment, and now the problem has been solved
from infinitransformer.
Related Issues (20)
- Suggest to use the constant memory gradient computation in Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
- Model generating random sequence HOT 8
- Limitations of the method HOT 2
- Memory should be per layer
- Memory does not use PE
- Inference code (with Segments)
- Are there any trained InfinityTransformer weights available?
- Segment and block size error HOT 1
- mem and norm_term is nan! HOT 15
- What is the min GPU memory required to fine-tune the model?
- About memory missing location information HOT 5
- BitLinear
- Model loses information very quickly HOT 2
- Issue while running test_train.small.gemma.infini.py HOT 2
- question about activation function HOT 2
- Discord server for this?
- Code not running on GPU HOT 6
- question about norm_term_broadcastable HOT 5
- load model failed HOT 4