
bitnet's Introduction


BitNet

bitnet: PyTorch implementation of the linear methods and model from the paper "BitNet: Scaling 1-bit Transformers for Large Language Models"

Paper link: https://arxiv.org/abs/2310.11453

BitLinear = tensor -> layernorm -> Binarize -> abs max quantization -> dequant
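
For intuition, here is a minimal, self-contained sketch of that pipeline (an illustration of the idea only, not the library's code; the NaiveBitLinear name, absmax scale, and 8-bit range are assumptions):

import torch
import torch.nn.functional as F
from torch import nn

class NaiveBitLinear(nn.Module):
    # Illustrative only: layernorm -> binarize weights -> absmax-quantize activations -> matmul -> dequant
    def __init__(self, in_features: int, out_features: int, eps: float = 1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(in_features)
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)                                    # layernorm
        w = self.weight
        w_bin = torch.sign(w - w.mean())                    # binarize weights to +/-1 (STE omitted)
        gamma = x.abs().max().clamp(min=self.eps)           # absmax scale of the activations
        x_q = (x * 127.0 / gamma).round().clamp(-127, 127)  # quantize activations to the int8 range
        y = F.linear(x_q, w_bin)                            # matmul with 1-bit weights
        beta = w.abs().mean()                               # per-tensor weight scale
        return y * (gamma / 127.0) * beta                   # dequantize

x = torch.randn(10, 1000, 512)
print(NaiveBitLinear(512, 400)(x).shape)  # torch.Size([10, 1000, 400])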

"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer. " -- BitNet is really easy to implement just swap out the linears with the BitLinear modules!

NEWS

  • New Iteration 🔥 There is an all-new iteration from the paper "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits", and we're implementing it now. Join the Agora Discord and contribute! Join Here
  • New Optimizations The first BitLinear has been optimized, and we now have a Bit Attention module, BitMGQA, that uses BitLinear inside the attention mechanism. Multi-Grouped Query Attention is also widely recognized as the best attention variant for its fast decoding and long-context handling; thanks to Frank for his easy-to-use implementation!
  • BitLinear 1.5 Launch 🔥: The new BitLinear 1.5 is still in progress 🔥 Here is the file. There are still some bugs, such as with the dequantization algorithm, and we still need to replace the multiplication with element-wise addition; if you could help with this, that would be amazing.
  • NOTICE: A model needs to be trained from scratch to use BitLinear; just swapping the linear layers in an already-trained model isn't going to work. Finetune or train from scratch.

Appreciation

  • Dimitry and Nullonix, for analysis, code review, and revision
  • Vyom, for providing a 4080 to train on!

Installation

pip install bitnet

Usage:

BitLinear

  • Example of the BitLinear layer which is the main innovation of the paper!
import torch

from bitnet import BitLinear

# Input
x = torch.randn(10, 1000, 512)

# BitLinear layer
layer = BitLinear(512, 400)

# Output
y = layer(x)

print(y)

BitLinearNew

import torch
from bitnet import BitLinearNew

# Create a random tensor of shape (16, 1000, 512)
x = torch.randn(16, 1000, 512)

# Create an instance of the BitLinearNew class with input size 512 and output size 20
layer = BitLinearNew(
    512,
    20,
)

# Perform a forward pass through the BitLinearNew layer with input x
output = layer(x)

# Print the output tensor
print(output)
print(output.shape)

BitNetTransformer

  • Fully implemented Transformer, as described in the diagram, with MHA and BitFeedForward blocks
  • Can be utilized not just for text but also for images, and maybe even video or audio processing
  • Complete with residuals and skip connections for gradient flow
# Import the necessary libraries
import torch
from bitnet import BitNetTransformer

# Create a random tensor of integers
x = torch.randint(0, 20000, (1, 1024))

# Initialize the BitNetTransformer model
bitnet = BitNetTransformer(
    num_tokens=20000,  # Number of unique tokens in the input
    dim=1024,  # Dimension of the input and output embeddings
    depth=6,  # Number of transformer layers
    heads=8,  # Number of attention heads
    ff_mult=4,  # Multiplier for the hidden dimension in the feed-forward network
)

# Pass the tensor through the transformer model
logits = bitnet(x)

# Print the output logits
print(logits)

BitAttention

This attention module has been modified to use BitLinear instead of the default linear projections. It also uses Multi-Grouped Query Attention instead of regular multi-head attention for faster decoding and longer-context handling.
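
Before the usage example, here is a tiny sketch of the grouped-query idea itself (hypothetical shapes, plain attention without BitLinear, not BitMGQA's API): several query heads share each key/value head, which shrinks the KV cache and speeds up decoding.

import torch

b, seq, q_heads, kv_heads, head_dim = 1, 10, 8, 2, 64
q = torch.randn(b, q_heads, seq, head_dim)
k = torch.randn(b, kv_heads, seq, head_dim)
v = torch.randn(b, kv_heads, seq, head_dim)

group = q_heads // kv_heads            # 4 query heads per KV head
k = k.repeat_interleave(group, dim=1)  # each KV head serves `group` query heads
v = v.repeat_interleave(group, dim=1)
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
out = attn @ v
print(out.shape)  # torch.Size([1, 8, 10, 64])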

import torch
from bitnet import BitMGQA

# Create a random tensor of shape (1, 10, 512)
x = torch.randn(1, 10, 512)

# Create an instance of the BitMGQA model with input size 512, 8 attention heads, and 4 layers
gqa = BitMGQA(512, 8, 4)

# Pass the input tensor through the BitMGQA model and get the output and attention weights
out, _ = gqa(x, x, x, need_weights=True)

# Print the output tensor
print(out)

BitFeedForward

  • Feedforward as shown in the diagram with BitLinear and a GELU:
  • Linear -> GELU -> Linear
  • You can add dropout, layernorm, or other layers for a better FFN
import torch
from bitnet import BitFeedForward

# Create a random input tensor of shape (10, 512)
x = torch.randn(10, 512)

# Create an instance of the BitFeedForward class with the following parameters:
# - input_dim: 512
# - hidden_dim: 512
# - num_layers: 4
# - swish: True (use Swish activation function)
# - post_act_ln: True (apply Layer Normalization after each activation)
# - dropout: 0.1 (apply dropout with a probability of 0.1)
ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=True, dropout=0.1)

# Apply the BitFeedForward network to the input tensor x
y = ff(x)

# Print the shape of the output tensor y
print(y.shape)  # torch.Size([10, 512])

Inference

from bitnet import BitNetInference

bitnet = BitNetInference()
bitnet.load_model("../model_checkpoint.pth")  # Download model
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)

Huggingface Usage

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from bitnet import replace_linears_in_hf

# Load a model from Hugging Face's Transformers
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Replace Linear layers with BitLinear
replace_linears_in_hf(model)

# Example text to classify
text = "Replace this with your text"
inputs = tokenizer(
    text, return_tensors="pt", padding=True, truncation=True, max_length=512
)

# Perform inference
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    print(predictions)

# Process predictions
predicted_class_id = predictions.argmax().item()
print(f"Predicted class ID: {predicted_class_id}")

# Optionally, map the predicted class ID to a label, if you know the classification labels
# labels = ["Label 1", "Label 2", ...]  # Define your labels corresponding to the model's classes
# print(f"Predicted label: {labels[predicted_class_id]}")

Drop-in Replacement for PyTorch Models

import torch
from torch import nn
from bitnet import replace_linears_in_pytorch_model

# Define a simple model
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 30),
)

print("Before replacement:")
print(model)

# Replace nn.Linear with BitLinear
replace_linears_in_pytorch_model(model)

print("After replacement:")
print(model)

# Now you can use the model for training or inference
# For example, pass a random input through the model
input = torch.randn(1, 10)
output = model(input)

Optimized CUDA Kernel

python setup.py build_ext --inplace

import torch
import gemm_lowbit_ext  # This imports the compiled module

# Example usage
a = torch.randn(10, 20, dtype=torch.half, device='cuda')  # Example tensor
b = torch.randn(20, 30, dtype=torch.half, device='cuda')  # Example tensor
c = torch.empty(10, 30, dtype=torch.half, device='cuda')  # Output tensor

w_scale = 1.0  # Example scale factor
x_scale = 1.0  # Example scale factor

# Call the custom CUDA GEMM operation
gemm_lowbit_ext.gemm_lowbit(a, b, c, w_scale, x_scale)

print(c)  # View the result

BitLora

Implementation of BitLora!

import torch
from bitnet import BitLora

# Random input tensor of shape (1, 12, 200)
x = torch.randn(1, 12, 200)

# Create an instance of the BitLora model
model = BitLora(in_features=200, out_features=200, rank=4, lora_alpha=1)

# Perform the forward pass
out = model(x)

# Print the shape of the output tensor
print(out.shape)

BitMamba

import torch
from bitnet import BitMamba

# Create a tensor of shape (2, 10) with random integers between 0 and 99
x = torch.randint(0, 100, (2, 10))

# Create an instance of the BitMamba model with input size 512, hidden size 100, output size 10, and depth size 6
model = BitMamba(512, 100, 10, 6, return_tokens=True)

# Pass the input tensor through the model and get the output
output = model(x)

# Print the output tensor
print(output)

# Print the shape of the output tensor
print(output.shape)

BitMoE

import torch
from bitnet.bit_moe import BitMoE

# Create input tensor
x = torch.randn(2, 4, 8)

# Create BitMoE model with specified input and output dimensions
model = BitMoE(8, 4, 2)

# Forward pass through the model
output = model(x)

# Print the output
print(output)

License

MIT

Citation

@misc{wang2023bitnet,
    author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
    title = {BitNet: Scaling 1-bit Transformers for Large Language Models},
    year = {2023},
    eprint = {arXiv:2310.11453},
}

Todo

  • Double-check the BitLinear implementation and make sure it works exactly as in the paper
  • Implement a training script for BitNetTransformer
  • Train on enwik8; copy and paste code and data from Lucidrains' repos
  • Benchmark performance
  • Look into the straight-through estimator for non-differentiable backprop (see the sketch after this list)
  • Implement BitFeedForward
  • Clean up the codebase
  • Add unit tests for each module
  • Implement the new BitNet b1.58 from the paper
  • Implement BitNet b1.58 in CUDA
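
On the straight-through estimator item above: a common formulation binarizes in the forward pass while letting gradients flow through unchanged in the backward pass. A minimal sketch (a standard trick, not necessarily what this repo will adopt):

import torch

def ste_sign(w: torch.Tensor) -> torch.Tensor:
    # Forward value is sign(w); autograd sees the identity, so gradients pass straight through.
    return w + (torch.sign(w) - w).detach()

w = torch.randn(4, requires_grad=True)
ste_sign(w).sum().backward()
print(w.grad)  # tensor([1., 1., 1., 1.]) -- the identity's gradient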

bitnet's People

Contributors

dependabot[bot], erjanmx, jiangxg, kyegomez, nullonesix, ramonpeter, sunwood-ai-labs


bitnet's Issues

Expected BitLinear weight to be 1 or -1

Hello, I presume that according to the BitNet paper the weights should be -1 or 1. But:

import torch
from bitnet import BitLinearNew

# Create a random tensor of shape (2, 10, 10)
x = torch.randn(2, 10, 10)

# Create an instance of the BitLinearNew class with input size 10 and output size 20
layer = BitLinearNew(
    10,
    20,
)

# Perform a forward pass through the BitLinearNew layer with input x
output = layer(x)

print(layer.weight.dtype)
print(layer.weight)

Output

torch.float32
Parameter containing:
tensor([[ 0.1634,  0.2419, -0.0605,  0.1592,  0.2348, -0.1431, -0.1634,  0.0171,
         -0.1672, -0.1526],
        [-0.0848,  0.0079, -0.2014, -0.0492,  0.2833,  0.1290, -0.2156, -0.1515,
         -0.0473, -0.0839],
        [ 0.2230,  0.1434, -0.1410, -0.0626,  0.1189, -0.1652, -0.2978, -0.0287,
          0.1025,  0.2458],
        [-0.1670, -0.0222, -0.0272, -0.2312,  0.1880, -0.2040, -0.0305,  0.1009,
         -0.2247,  0.0124],
        [ 0.1351, -0.2926,  0.1891, -0.1614,  0.2894, -0.2931,  0.0802,  0.2884,
          0.0454, -0.1398],
        [-0.2954,  0.2651, -0.0062, -0.1592,  0.2138, -0.2038,  0.2965, -0.2545,
          0.0505, -0.0811],
        [-0.3062, -0.1191, -0.1521,  0.1021, -0.1865, -0.1102,  0.2120, -0.2865,
          0.1754,  0.1763],
        [ 0.1375, -0.2975,  0.0399, -0.1723, -0.0526, -0.2694,  0.1838, -0.1826,
          0.2806, -0.1438],
        [-0.3150,  0.2163,  0.1946, -0.0244,  0.0657, -0.1531, -0.0310,  0.0071,
          0.2590,  0.0985],
        [ 0.0402,  0.0704, -0.1441, -0.1929, -0.2450,  0.2408, -0.0750,  0.0238,
          0.3030,  0.0516],
        [ 0.1537, -0.2231, -0.0092, -0.1068,  0.3131,  0.0859, -0.1692, -0.2364,
          0.2257,  0.2601],
        [-0.0478, -0.2978, -0.2025, -0.2411, -0.3061, -0.2566,  0.0564, -0.0906,
          0.2113,  0.3118],
        [-0.1048,  0.2073, -0.2126, -0.1883,  0.0463, -0.1716, -0.3052,  0.0548,
          0.2079,  0.2587],
        [-0.1387,  0.1778, -0.1886,  0.1239,  0.0265, -0.0421, -0.1020,  0.2481,
         -0.0840,  0.1879],
        [ 0.0707, -0.0534,  0.0623,  0.0803,  0.3135,  0.2192, -0.1202,  0.3139,
          0.0781, -0.0512],
        [ 0.2812,  0.2515, -0.0371,  0.0248,  0.0231, -0.0437,  0.0875,  0.3085,
         -0.0482, -0.0092],
        [ 0.1735,  0.2584, -0.0900, -0.1616,  0.1253,  0.1352,  0.1841,  0.1416,
         -0.0686, -0.0269],
        [-0.3121, -0.1050,  0.0265,  0.0242,  0.1973,  0.1816, -0.0084,  0.2866,
          0.2559, -0.2523],
        [ 0.1272, -0.2361,  0.0847, -0.0724,  0.2531,  0.0948, -0.0765, -0.1252,
         -0.0459, -0.0133],
        [-0.0660,  0.0650,  0.2529, -0.1763, -0.1248, -0.1073, -0.2926,  0.1837,
          0.1265, -0.0090]], requires_grad=True)

Am I missing something?

[BUG] Bitnet Example Bug

Describe the bug
When running the "Example of the BitLinear layer" example from https://github.com/kyegomez/BitNet as of commit 171f4e5 (committed Sun Mar 24 19:48:59 2024 -0700), quoted here for reference:

import torch

from bitnet import BitLinear

# Input
x = torch.randn(10, 512)

# BitLinear layer
layer = BitLinear(512, 400)

# Output
y = layer(x)

print(y)

I get the following error

In [1]: import torch
   ...:
   ...: from bitnet import BitLinear
   ...:
   ...: # Input
   ...: x = torch.randn(10, 512)
   ...:
   ...: # BitLinear layer
   ...: layer = BitLinear(512, 400)
   ...:
   ...: # Output
   ...: y = layer(x)
   ...:
   ...: print(y)
2024-03-29 20:06:13.544245: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-29 20:06:13.564836: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-29 20:06:13.939526: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-03-29 20:06:14,366 - numexpr.utils - INFO - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-03-29 20:06:14,366 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
/home/sneilan/.gp/scratch/.venv/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/sneilan/.gp/scratch/.venv/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 12
      9 layer = BitLinear(512, 400)
     11 # Output
---> 12 y = layer(x)
     14 print(y)

File ~/.gp/scratch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.gp/scratch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.gp/scratch/BitNet/bitnet/bitlinear.py:53, in BitLinear.forward(self, x)
     42 def forward(self, x: Tensor) -> Tensor:
     43     """
     44     Forward pass of the BitLinear layer.
     45
   (...)
     51
     52     """
---> 53     b, s, d = x.shape
     54     w = self.weight
     55     x_norm = RMSNorm(d)(x)

ValueError: not enough values to unpack (expected 3, got 2)

To Reproduce

mkdir scratch
cd scratch
python3 -m venv .venv
source .venv/bin/activate
pip install bitnet
pip uninstall bitnet # to be able to clone repo but leave dependencies
git clone https://github.com/kyegomez/BitNet
cd BitNet
git checkout 171f4e5
ipython
(paste in following code)
import torch
from bitnet import BitLinear
x = torch.randn(10, 512)
layer = BitLinear(512, 400)
y = layer(x)

Expected behavior
I expect y to be printed.

Screenshots
n/a

Additional context
Running Python 3.10.12

Cuda Version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:08:53_PST_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0

Question about weight quantization methodology memory savings

Thanks for your quick implementation! I was reading through bitnet/bitbnet_b158.py and just had a short question.

In your implementation of quantize_weights you use the same procedure as outlined in the paper "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits", but it looks like the quantized weights are stored in float32 while the activation quantization is explicitly cast to int8. I could be missing something, but how are you saving on memory (other than the 8-bit activations, just like the paper) when the quantized weights are kept as float32?

   def quantize_weights(self, W):
        """
        Quantizes the weights using the absmean quantization function.

        Args:
            W (Tensor): The weight tensor to be quantized.

        Returns:
            Tensor: Quantized weight tensor.
        """
        gamma = torch.mean(torch.abs(W)) + self.eps
        W_scaled = W / gamma
        W_quantized = torch.sign(W_scaled) * torch.clamp(
            torch.abs(W_scaled).round(), max=1.0 # torch.float32 
        )
        return W_quantized
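
For what it's worth, actually realizing the memory savings would mean storing the ternary values in a narrow dtype plus a single scale, along these lines (a sketch mirroring the snippet above, not the repo's code):

import torch

W = torch.randn(20, 10)
gamma = W.abs().mean() + 1e-5                 # absmean scale, as in quantize_weights
W_scaled = W / gamma
W_quantized = torch.sign(W_scaled) * torch.clamp(W_scaled.abs().round(), max=1.0)
W_int8 = W_quantized.to(torch.int8)           # values in {-1, 0, 1}, one byte each
print(W_quantized.element_size(), W_int8.element_size())  # 4 1 -- 4x smaller, plus one float for gamma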

forward method in Class BitLinear

Hello, thanks for your implementation.
I was a bit confused while reading forward() in bitnet/bitlinear.py.
As the paper shows:
[screenshot of the paper's BitLinear equations]
I think the forward method should be:
[screenshot of a proposed forward implementation]
Did I misunderstand the process?

[BUG] RuntimeError: einsum(): the number of subscripts in the equation (4) does not match the number of dimensions (2) for operand 0 and no ellipsis was given

Describe the bug



Traceback (most recent call last):
  File "/Users/defalt/Desktop/Athena/research/BitNet/bitnet/main.py", line 189, in <module>
    y = layer(x)
        ^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/defalt/Desktop/Athena/research/BitNet/bitnet/main.py", line 177, in forward
    x = attn(q, k, v, mask=mask) + x
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/zeta/nn/attention/attend.py", line 287, in forward
    dots = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: einsum(): the number of subscripts in the equation (4) does not match the number of dimensions (2) for operand 0 and no ellipsis was given


The output of BitLinear is quite abnormal

Describe the bug
I print the mean and variance of the tensor y in example.py.
Its mean and variance are abnormal, as follows:

mean and var of BitLinear output:
-0.567935049533844
1149.9969482421875

To make sure, I printed the mean and variance of the outputs from Linear and BitLinear simultaneously.

mean and var of Linear output:
0.012186492793262005
0.33256232738494873
mean and var of BitLinear output:
0.9070871472358704
992.69384765625

I believe there are mistakes in the implementation of BitLinear in bitnet/bitlinear.py.

To Reproduce
Steps to reproduce the behavior:

  1. print the mean and variance of y in example.py
  2. insert output_linear = torch.nn.functional.linear(x, self.weight, self.bias) in bitnet/bitlinear.py line 129. Then print the mean and variance of output_linear

[BUG] Multi-head attention is a no-op for BitLinear


need a distributed training example

Thank you for your innovative work. Could you provide a distributed training example, so that we can quickly reproduce and verify the paper's results?

About 'replace_hf.py'

Hello @kyegomez

In the inference code of huggingface_example.py, it appears that replace_hf is executed, followed immediately by inference.
However, upon examining replace_hf.py, I noticed it converts linear layers to BitLinear layers and seems to declare new weights.
I'm curious whether additional code is needed to transfer the original weights to the BitLinear layers.

Maybe something like this?

def replace_linears_in_hf(
    model,
):
    """
    Replaces all instances of nn.Linear in the given model with BitLinear15b.

    Args:
        model (nn.Module): The model to modify.

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Replace the nn.Linear with a BitLinear matching in_features and out_features
            new_module = BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )

            with torch.no_grad():
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias
            setattr(model, name, new_module)
        else:
            # Recursively apply to child modules
            replace_linears_in_hf(module)

Thanks.

Parts of the BitLinear code doesn't match paper (before bit1.58)

Referencing this paper: https://arxiv.org/pdf/2310.11453.pdf
Code part: https://github.com/kyegomez/BitNet/blob/984ec72c2a45a88b739c85668690fe1abbdf3152/bitnet/bitlinear.py

In general, it seems that the code does not match the paper, mainly Equations (1), (4), and (11). It also seems to be missing the straight-through estimator. (Edit: the code also didn't replace BitLinear within the multi-head attention.)

I also found this other reference implementation, which seems to follow the paper's equations a bit more closely: https://github.com/Beomi/BitNet-Transformers
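
For readers cross-checking, the three steps those equations cover are, schematically (reconstructed from the paper's description; the notation may differ in detail):

\tilde{W} = \mathrm{Sign}(W - \alpha), \qquad \alpha = \tfrac{1}{nm} \sum_{ij} W_{ij}  % (1) weight binarization
\tilde{x} = \mathrm{Clip}\big(x \cdot \tfrac{Q_b}{\gamma},\, -Q_b + \epsilon,\, Q_b - \epsilon\big), \qquad \gamma = \lVert x \rVert_\infty  % (4) absmax activation quantization
y = \tilde{W}\tilde{x} \cdot \tfrac{\beta\gamma}{Q_b}, \qquad \beta = \tfrac{1}{nm} \lVert W \rVert_1  % (11) dequantization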

[Question] How did you implement 1-bit tensor?

Hi, I found this repository while planning to implement BitNet.

BitLinear uses 1-bit weights, but PyTorch's native dtypes do not support 1-bit tensors,
so I thought I would need to implement it via a custom CUDA kernel (bit packing and unpacking).

How did you implement the 1-bit tensor representation?
(I can't find any CUDA files yet.)
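
One common workaround, pending an answer: pack eight signs per byte in plain PyTorch and unpack before the matmul (a sketch with hypothetical helper names; a real speedup still needs a custom kernel):

import torch

def pack_signs(w: torch.Tensor) -> torch.Tensor:
    # Pack a +/-1 tensor into uint8, eight signs per byte.
    bits = (w.flatten() > 0).to(torch.uint8)  # map -1 -> 0, +1 -> 1
    pad = (-bits.numel()) % 8
    bits = torch.cat([bits, bits.new_zeros(pad)])
    place = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    return (bits.view(-1, 8) * place).sum(dim=1).to(torch.uint8)

def unpack_signs(packed: torch.Tensor, numel: int) -> torch.Tensor:
    # Inverse of pack_signs: recover the +/-1 values as float32.
    bits = (packed.to(torch.int64).unsqueeze(1) >> torch.arange(8)) & 1
    return bits.flatten()[:numel].float() * 2 - 1

w = torch.randn(100).sign()
w[w == 0] = 1.0
assert torch.equal(unpack_signs(pack_signs(w), 100), w)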

[BUG] I tried using BitLinear in nanoRWKV but got the error.

https://github.com/BlinkDL/nanoRWKV/blob/master/model.py

I tried using BitLinear in nanoRWKV but got the error:
PS D:\nanoRWKV-master> python train.py
tokens per iteration will be: 491,520
Initializing a new model from scratch
defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)
Traceback (most recent call last):
  File "D:\nanoRWKV-master\train.py", line 154, in <module>
    model = GPT(gptconf)
            ^^^^^^^^^^^^
  File "D:\nanoRWKV-master\model.py", line 210, in __init__
    h = nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\nanoRWKV-master\model.py", line 210, in <listcomp>
    h = nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
                       ^^^^^^^^^^^^^^^^
  File "D:\nanoRWKV-master\model.py", line 179, in __init__
    self.tmix = RWKV_TimeMix_x051a(config, layer_id)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\nanoRWKV-master\model.py", line 69, in __init__
    self.receptance = BitLinear(config.n_embd, config.n_embd, bias=config.bias)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: BitLinear.__init__() got an unexpected keyword argument 'bias'

[BUG] Loss drops, model still produces gibberish?

Describe the bug

After 5300 iterations, with loss near 2.7, is it still supposed to spit out near-gibberish?

To Reproduce

Running on CPU (MacBook Air M2), omitting the model.cuda() line.

Expected behaviour

Some kind of convergence on sentences that are at least english-ish?

Screenshots
[screenshot of training loss and sampled output]

Additional context

Maybe my expectations are just off and I should train way way more?

1.58-bit algorithm implementation recommendation

def ste(self, x):

I noticed that you're attempting to implement 1.58-bit quantization, but it seems you only quantize the values during the forward pass and then proceed with model inference, using the original values for the backward pass. In 4-bit quantization, we store two quantized values in one byte; the computation and gradients for the new data type are implemented in CUDA. You should consider this approach as well. Keep it up, I'm rooting for you.
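
For reference, the two-values-per-byte trick described above looks roughly like this in plain PyTorch (a sketch with a hypothetical helper; real int4 kernels do this on the GPU side):

import torch

def pack_nibbles(q: torch.Tensor) -> torch.Tensor:
    # Pack two int4-range values (-8..7) into each byte.
    q = (q.flatten() + 8).to(torch.uint8)  # shift to 0..15
    if q.numel() % 2:
        q = torch.cat([q, q.new_zeros(1)])
    return (q[0::2] << 4) | q[1::2]        # high nibble | low nibble

q = torch.randint(-8, 8, (10,))
print(pack_nibbles(q))  # 5 bytes holding 10 int4 values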

Where to download the BitNet model?

Describe the bug
Where can I download the BitNet model shown in this example code?

bitnet.load_model("../model_checkpoint.pth")  # Download model

To Reproduce
Execute example. No model loaded.

Expected behavior
Link to download suitable model

[BUG] NoneType in sequential module in bit_ffn

The self.ff nn.Sequential can contain None, which is not callable, when post_act_ln is False.

[suggestion]

    ff_layers = [project_in]
    if post_act_ln:
        ff_layers.append(nn.LayerNorm(inner_dim))
    ff_layers.append(nn.Dropout(dropout))
    ff_layers.append(BitLinear(inner_dim, dim_out, bias=not no_bias, *args, **kwargs))
    self.ff=nn.Sequential(*ff_layers)

[BUG] residual connection wrong?

In bit_transformer.py:

class Transformer(nn.Module):
    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        for attn, ffn in zip(self.layers, self.ffn_layers):
            # print(x.shape)
            x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
            x = x + x
            x = ffn(x) + x
        return x

Is the line x = x + x wrong? This does not seem to be a residual connection.
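
For comparison, a conventional residual wiring of the same loop (a sketch of what one might expect, not a confirmed fix) would be:

class Transformer(nn.Module):
    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        for attn, ffn in zip(self.layers, self.ffn_layers):
            attn_out, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
            x = attn_out + x  # residual around attention
            x = ffn(x) + x    # residual around the feed-forward block
        return x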

[BUG] Tensor size mismatch from train.py

Thank you for sharing this incredible work!

I speculate that it's an issue of library versions, but I'm receiving the following error when attempting to run unmodified train.py:
RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1

Changing the default SEQ_LEN = 1024 to 512 gives the following:
RuntimeError: The size of tensor a (513) must match the size of tensor b (512) at non-singleton dimension 1

While a sequence length of 511 says:
RuntimeError: The size of tensor a (511) must match the size of tensor b (512) at non-singleton dimension 1

Full error log:

Traceback (most recent call last):
  File "Data/Development/BitNet/train.py", line 86, in <module>
    loss = model(next(train_loader))
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "Data/Development/BitNet/bitnet/at.py", line 82, in forward
    logits = self.net(x_inp, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "Data/Development/BitNet/bitnet/transformer.py", line 52, in forward
    return self.to_logits(x)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".local/lib/python3.10/site-packages/zeta/nn/modules/rms_norm.py", line 35, in forward
    return normed * self.scale * self.gamma
RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1

Any help would be appreciated!

[BUG] Can't install with pipenv, pip

Describe the bug
Installation does not work with pipenv install bitnet, pipenv run pip install bitnet.

To Reproduce
Apple M2, MacOS: 14.3.1, Python(in pipenv environment): 3.12.1
Run pipenv install bitnet, pipenv run pip.

Expected behavior
The error occurs.

Screenshots

Courtesy Notice: Pipenv found itself running within a virtual environment, so it will automatically use that environment, instead of creating its own for any project. You can set PIPENV_IGNORE_VIRTUALENVS=1 to force pipenv to ignore that environment and create its own instead. You can set PIPENV_VERBOSITY=-1 to suppress this warning.
Collecting bitnet
  Downloading bitnet-0.0.8-py3-none-any.whl.metadata (4.3 kB)
Collecting einops (from bitnet)
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Collecting torch (from bitnet)
  Downloading torch-2.2.0-cp312-none-macosx_11_0_arm64.whl.metadata (25 kB)
Collecting zetascale (from bitnet)
  Downloading zetascale-2.1.1-py3-none-any.whl.metadata (20 kB)
Collecting filelock (from torch->bitnet)
  Downloading filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting typing-extensions>=4.8.0 (from torch->bitnet)
  Downloading typing_extensions-4.9.0-py3-none-any.whl.metadata (3.0 kB)
Collecting sympy (from torch->bitnet)
  Downloading sympy-1.12-py3-none-any.whl (5.7 MB)
     โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ” 5.7/5.7 MB 9.6 MB/s eta 0:00:00
Collecting networkx (from torch->bitnet)
  Downloading networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting jinja2 (from torch->bitnet)
  Downloading Jinja2-3.1.3-py3-none-any.whl.metadata (3.3 kB)
Collecting fsspec (from torch->bitnet)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Collecting accelerate==0.26.1 (from zetascale->bitnet)
  Downloading accelerate-0.26.1-py3-none-any.whl.metadata (18 kB)
Collecting argparse<2.0.0,>=1.4.0 (from zetascale->bitnet)
  Downloading argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Collecting beartype==0.17.0 (from zetascale->bitnet)
  Downloading beartype-0.17.0-py3-none-any.whl.metadata (29 kB)
Collecting bitsandbytes==0.42.0 (from zetascale->bitnet)
  Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)
Collecting colt5-attention==0.10.19 (from zetascale->bitnet)
  Downloading CoLT5_attention-0.10.19-py3-none-any.whl.metadata (738 bytes)
Collecting datasets (from zetascale->bitnet)
  Downloading datasets-2.17.0-py3-none-any.whl.metadata (20 kB)
Collecting einops-exts==0.0.4 (from zetascale->bitnet)
  Downloading einops_exts-0.0.4-py3-none-any.whl (3.9 kB)
Collecting jax (from zetascale->bitnet)
  Downloading jax-0.4.24-py3-none-any.whl.metadata (24 kB)
Collecting jaxlib (from zetascale->bitnet)
  Downloading jaxlib-0.4.24-cp312-cp312-macosx_11_0_arm64.whl.metadata (2.1 kB)
Collecting lion-pytorch==0.0.7 (from zetascale->bitnet)
  Downloading lion_pytorch-0.0.7-py3-none-any.whl (4.3 kB)
Collecting numexpr (from zetascale->bitnet)
  Downloading numexpr-2.9.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (7.9 kB)
Collecting pytest==7.4.2 (from zetascale->bitnet)
  Downloading pytest-7.4.2-py3-none-any.whl.metadata (7.9 kB)
Collecting rich==13.7.0 (from zetascale->bitnet)
  Downloading rich-13.7.0-py3-none-any.whl.metadata (18 kB)
Collecting scipy==1.9.3 (from zetascale->bitnet)
  Downloading scipy-1.9.3.tar.gz (42.1 MB)
     โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ” 42.1/42.1 MB 8.8 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... error
  error: subprocess-exited-with-error
  
  × Preparing metadata (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [17 lines of output]
      + meson setup /private/var/folders/sm/jc08s6093j5648np4b1t51zm0000gn/T/pip-install-6nkf0lnf/scipy_e72ed8b1dedc47308309bba88221c617 /private/var/folders/sm/jc08s6093j5648np4b1t51zm0000gn/T/pip-install-6nkf0lnf/scipy_e72ed8b1dedc47308309bba88221c617/.mesonpy-fnlud_2t -Dbuildtype=release -Db_ndebug=if-release -Db_vscrt=md --native-file=/private/var/folders/sm/jc08s6093j5648np4b1t51zm0000gn/T/pip-install-6nkf0lnf/scipy_e72ed8b1dedc47308309bba88221c617/.mesonpy-fnlud_2t/meson-python-native-file.ini
      The Meson build system
      Version: 1.3.1
      Source dir: /private/var/folders/sm/jc08s6093j5648np4b1t51zm0000gn/T/pip-install-6nkf0lnf/scipy_e72ed8b1dedc47308309bba88221c617
      Build dir: /private/var/folders/sm/jc08s6093j5648np4b1t51zm0000gn/T/pip-install-6nkf0lnf/scipy_e72ed8b1dedc47308309bba88221c617/.mesonpy-fnlud_2t
      Build type: native build
      Project name: SciPy
      Project version: 1.9.3
      
      ../meson.build:1:0: ERROR: Unknown compiler(s): [['cc'], ['gcc'], ['clang'], ['nvc'], ['pgcc'], ['icc'], ['icx']]
      The following exception(s) were encountered:
      Running `nvc --version` gave "[Errno 2] No such file or directory: 'nvc'"
      Running `pgcc --version` gave "[Errno 2] No such file or directory: 'pgcc'"
      Running `icc --version` gave "[Errno 2] No such file or directory: 'icc'"
      Running `icx --version` gave "[Errno 2] No such file or directory: 'icx'"
      
      A full log can be found at /private/var/folders/sm/jc08s6093j5648np4b1t51zm0000gn/T/pip-install-6nkf0lnf/scipy_e72ed8b1dedc47308309bba88221c617/.mesonpy-fnlud_2t/meson-logs/meson-log.txt
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

Additional context
Add any other context about the problem here.
