
llama2.java's Introduction

A Java port of Andrej Karpathy's llama2.c

Check out the successor of this project, Llama3.java: practical Llama 3 inference in a single Java file, with additional features, including a --chat mode.

This is a pure Java port of Andrej Karpathy's awesome llama2.c, a very simple implementation to run inference of models with a Llama2-like transformer-based LLM architecture.

Currently, there isn't anything really original here, but I'll continue polishing it while keeping it in sync with the original.
Besides its educational value, this project will be used to test and tune compiler optimizations on the JVM, particularly for the Graal compiler. This port initially used llama2.scala as a reference.

Build

Java 21+ is required, in particular for the MemorySegment mmap-ing feature.
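The mmap-ing requirement refers to mapping the checkpoint file directly into memory as a MemorySegment instead of reading it onto the Java heap. Below is a minimal sketch of that idea. It assumes the llama2.c checkpoint layout: seven little-endian int32 header fields followed by float32 weights. The class and method names are hypothetical and not taken from Llama2.java. It compiles as-is on Java 22+. On Java 21 the FFM API is still a preview, so it needs the same --enable-preview flags as the build commands.

```java
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;

class MmapSketch {
    // llama2.c checkpoints are little-endian; the *_UNALIGNED layouts skip alignment checks.
    static final ValueLayout.OfInt INT_LE =
            ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
    static final ValueLayout.OfFloat FLOAT_LE =
            ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);

    // Maps the whole file read-only. The segment stays valid until `arena` is closed,
    // even after the channel itself has been closed.
    static MemorySegment map(Path path, Arena arena) throws IOException {
        try (FileChannel channel = FileChannel.open(path, StandardOpenOption.READ)) {
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size(), arena);
        }
    }

    // Reads the first `count` int32 values, e.g. the 7-field checkpoint header.
    static int[] readInts(MemorySegment segment, int count) {
        int[] out = new int[count];
        for (int i = 0; i < count; i++) {
            out[i] = segment.get(INT_LE, (long) i * Integer.BYTES);
        }
        return out;
    }

    public static void main(String[] args) throws IOException {
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment file = map(Path.of(args[0]), arena);
            // dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len
            System.out.println(Arrays.toString(readInts(file, 7)));
            // The first float32 weight starts right after the 28-byte header.
            System.out.println(file.get(FLOAT_LE, 7L * Integer.BYTES));
        }
    }
}
```

For example, `java MmapSketch stories15M.bin` would print the seven header fields. Nothing is copied; the OS pages the weights in on demand.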

The code expects tokenizer.bin in the current directory. You can use TinyStories checkpoints, or get Llama 2 models by following the instructions in llama2.c.

wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin

To build and run manually:

javac --enable-preview -source 21 --add-modules=jdk.incubator.vector Llama2.java
java --enable-preview --add-modules=jdk.incubator.vector Llama2 stories15M.bin

Or run it directly with JBang:

jbang Llama2.java stories15M.bin
# With additional -D options and custom Java home.
JAVA_HOME=/path/to/java/home jbang -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 -Dllama2.VectorAPI=false Llama2.java stories15M.bin
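The -D options above are ordinary JVM system properties. Below is a minimal sketch of how a flag such as llama2.VectorAPI might be read. The helper name and the default of true are assumptions, not necessarily what Llama2.java does. The common-pool parallelism works differently: the JDK itself reads it when the ForkJoinPool common pool is created, which is why it has to be passed on the command line.

```java
import java.util.concurrent.ForkJoinPool;

class FlagsSketch {
    // Hypothetical helper: reads a boolean system property such as -Dllama2.VectorAPI=false,
    // falling back to `defaultValue` when the property is not set.
    static boolean booleanFlag(String name, boolean defaultValue) {
        String value = System.getProperty(name);
        return value == null ? defaultValue : Boolean.parseBoolean(value);
    }

    public static void main(String[] args) {
        // The default of true is an assumption for illustration.
        boolean useVectorApi = booleanFlag("llama2.VectorAPI", true);
        System.out.println("llama2.VectorAPI = " + useVectorApi);
        // Reports the common pool's target parallelism, configured through
        // -Djava.util.concurrent.ForkJoinPool.common.parallelism; parallel streams use this pool.
        System.out.println("common pool parallelism = " + ForkJoinPool.getCommonPoolParallelism());
    }
}
```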

A Makefile and a run.sh script are also provided:

make # optional, run.sh already runs make

JAVA_HOME=$GRAALVM_HOME \
JAVA_RUNTIME_OPTIONS=-Djava.util.concurrent.ForkJoinPool.common.parallelism=8 \
./run.sh stories15M.bin

Native image

A standalone native image can be created with GraalVM:

JAVA_HOME=$GRAALVM_HOME NATIVE_IMAGE_OPTIONS="-march=native" make native-image
./llama2 stories15M.bin

It can also be built with Profile-Guided Optimizations (PGO) on Oracle GraalVM:

JAVA_HOME=$GRAALVM_HOME \
NATIVE_IMAGE_OPTIONS="--pgo-instrument -march=native --initialize-at-build-time=Llama2 -Dllama2.VectorAPI=false" \
make native-image

# Profile run to generate default.iprof, with no parallelism to speed up profiling.
./llama2 -Djava.util.concurrent.ForkJoinPool.common.parallelism=0 stories15M.bin

# Build optimized image
JAVA_HOME=$GRAALVM_HOME \
NATIVE_IMAGE_OPTIONS="--pgo -march=native --initialize-at-build-time=Llama2 -Dllama2.VectorAPI=false" \
make native-image

# Should run ~2x faster than the regular image.
./llama2 stories15M.bin

Performance

Quick numbers on an AMD Ryzen 3950X 64GB, Arch Linux.
llama2.java executed on OpenJDK 20.0.2+9.
To make things fair with respect to vectorization, the Java version has a matmul implementation using the Vector API.
In these measurements the JVM is warmed up enough to reach peak tokens/s.
On GraalVM, note that the Graal compiler doesn't support the Vector API yet; to avoid unexpected performance degradation, run with -Dllama2.VectorAPI=false.
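Peak tokens/s is only reached once the JIT has compiled the hot loops, so the first tokens are slower. Below is a hypothetical harness that discards a warm-up phase before timing. `step` stands in for one forward pass and is not Llama2.java's API. The matmul in `main` is only a placeholder workload.

```java
import java.util.Arrays;
import java.util.function.IntConsumer;

class TokensPerSecondSketch {
    // Calls `step` once per token position. The first `warmupTokens` calls are not timed,
    // giving the JIT a chance to compile the hot code before measuring.
    static double measure(IntConsumer step, int warmupTokens, int measuredTokens) {
        for (int pos = 0; pos < warmupTokens; pos++) {
            step.accept(pos);
        }
        long start = System.nanoTime();
        for (int pos = 0; pos < measuredTokens; pos++) {
            step.accept(pos);
        }
        long elapsedNanos = Math.max(1, System.nanoTime() - start);
        return measuredTokens / (elapsedNanos / 1e9);
    }

    public static void main(String[] args) {
        int dim = 288;
        float[] x = new float[dim];
        float[] w = new float[dim * dim];
        float[] out = new float[dim];
        Arrays.fill(x, 1f);
        Arrays.fill(w, 0.01f);
        // Placeholder "token": one dim x dim matmul instead of a real forward pass.
        IntConsumer fakeToken = pos -> {
            for (int i = 0; i < dim; i++) {
                float acc = 0f;
                for (int j = 0; j < dim; j++) acc += w[i * dim + j] * x[j];
                out[i] = acc;
            }
        };
        System.out.printf("achieved tok/s: %f%n", measure(fakeToken, 1_000, 1_000));
    }
}
```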

Notes:
The numbers below were collected using aggressive gcc compiler flags; a build with e.g. regular gcc -O2 ... wouldn't be as fast.

Single-threaded

llama2.c compiled with gcc -Ofast -march=native run.c -lm -o run -march=native
llama2.java executed with -Djava.util.concurrent.ForkJoinPool.common.parallelism=0

| Model | Tokens per second | Speedup vs. llama2.c | Implementation |
|---|---|---|---|
| stories15M.bin | 363 | 1.0 | llama2.c |
| stories15M.bin | 237 | 0.65 | llama2.java |
| stories110M.bin | 51.71 | 1.0 | llama2.c |
| stories110M.bin | 42.20 | 0.81 | llama2.java |
| llama2_7B.bin | 0.92 | 1.0 | llama2.c |
| llama2_7B.bin | 0.88 | 0.95 | llama2.java |

Multi-threaded

llama2.c compiled with gcc -Ofast -fopenmp -march=native run.c -lm -o run -march=native
llama2.c executed with OMP_NUM_THREADS=8
llama2.java executed with -Djava.util.concurrent.ForkJoinPool.common.parallelism=8

| Model | Tokens per second | Speedup vs. llama2.c | Implementation |
|---|---|---|---|
| stories15M.bin | 1233 | 1.0 | llama2.c |
| stories15M.bin | 438 | 0.35 | llama2.java |
| stories110M.bin | 90 | 1.0 | llama2.c |
| stories110M.bin | 80 | 0.88 | llama2.java |
| llama2_7B.bin | 1.68 | 1.0 | llama2.c |
| llama2_7B.bin | 1.65 | 0.98 | llama2.java |

Notes:
On stories15M.bin, the C version shows a huge speedup, very likely due to a cache effect; this is considered an outlier. Running with 16 or 32 threads may actually cause a slowdown: in most cases, performance is U-shaped with respect to the number of threads. With that many threads, vectorization gives no advantage, since throughput is limited by memory bandwidth.
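The common pool's parallelism is fixed for the lifetime of a JVM. Comparing thread counts therefore normally means one run per -Djava.util.concurrent.ForkJoinPool.common.parallelism value. Within a single process, the sketch below instead runs a row-parallel matmul inside dedicated ForkJoinPools of different sizes. The matmul has the same W (d,n) @ x (n,) -> xout (d,) shape as the one quoted in the issues below. This relies on ForkJoinTask.fork() scheduling subtasks in the pool the current task runs in. Parallel streams do not formally promise that behavior, so treat the timings as indicative. The class name and sizes are illustrative assumptions.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

class ThreadSweepSketch {
    // Row-parallel matmul: W (d,n) @ x (n,) -> xout (d,).
    static void matmul(float[] xout, float[] x, float[] w, int n, int d) {
        IntStream.range(0, d).parallel().forEach(i -> {
            float val = 0f;
            for (int j = 0; j < n; j++) {
                val += w[i * n + j] * x[j];
            }
            xout[i] = val;
        });
    }

    // Runs the matmul `reps` times inside a pool of `threads` workers and returns elapsed seconds.
    static double timeWith(int threads, float[] xout, float[] x, float[] w, int n, int d, int reps)
            throws Exception {
        ForkJoinPool pool = new ForkJoinPool(threads);
        try {
            long start = System.nanoTime();
            pool.submit(() -> {
                for (int r = 0; r < reps; r++) matmul(xout, x, w, n, d);
            }).get();
            return (System.nanoTime() - start) / 1e9;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        int n = 4096, d = 4096;
        float[] x = new float[n], w = new float[n * d], xout = new float[d];
        Arrays.fill(x, 1f);
        Arrays.fill(w, 0.5f);
        for (int threads : new int[] {1, 2, 4, 8, 16, 32}) {
            timeWith(threads, xout, x, w, n, d, 5); // warm-up, not reported
            System.out.printf("%2d threads: %.3f s%n", threads, timeWith(threads, xout, x, w, n, d, 20));
        }
    }
}
```

If the workload is memory-bound, one would expect the timings to stop improving, or to get worse, beyond some thread count.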

Performance is already comparable to the original C code, bar vectorization, even though the Java code has not been optimized yet.

License

MIT

llama2.java's People

Contributors

dudiao, mukel, n-saw, the-alchemist


llama2.java's Issues

Llamafile comparison?

@mukel thank you for creating this project! I would like to discuss the following topics:

  1. Please enable the Discussions tab for posts like this, which are not real "issues"

  2. Do you plan on releasing Llama3 code?

  3. Do you plan to support quantized Llama models with the Java Vector API?

  4. Can you run a benchmark against llamafile, whose vectorized version (AVX, NEON) claims to be the performance king for inference?
    (I am deciding between using your project and wrapping the llamafile C code with the Java 22 Foreign Function & Memory API.)

  5. Do you plan to implement model training as well? If so, take a look at Andrej's llm.c repo.

Stories always output "<0x0A>"

% java --enable-preview --add-modules=jdk.incubator.vector Llama2 stories15M.bin
WARNING: Using incubator modules: jdk.incubator.vector
Config{dim=288, hidden_dim=768, n_layers=6, n_heads=6, n_kv_heads=6, vocab_size=32000, seq_len=256, shared_weights=true, head_size=48}
Once upon a time, there was a little girl named Lily. She loved to play in her backyard, but there was a big fence that surrounded her home. One day, Lily's mom asked her to clean the fence. Lily didn't want to, but she knew it was important to help her mom.<0x0A>As she was cleaning, Lily found an old key hidden in the grass. She picked it up and showed it to her mom. "Look, Mommy! I found this key! Can we use it to open this fence?" Lily asked. <0x0A>Her mom looked at the fence and said, "That's a fake fence, Lily. It's not used to playing in our backyard. But we can use it to build a fence around it."<0x0A>Lily was happy to hear that they could use the fake fence instead. She ran to get her toys and started playing with them. From that day on, Lily made sure to always check that the fake fence was the right thing to do.

achieved tok/s: 63.073076
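<0x0A> is the tokenizer's byte-fallback piece for a newline byte. llama2.c's decode step parses pieces of the form <0xHH> and emits the raw byte instead, so seeing them verbatim suggests that conversion is missing. Below is a hedged sketch of the conversion. The names are hypothetical and this is not the project's actual fix. It joins the raw bytes before decoding, so multi-byte UTF-8 characters split across several byte tokens come out intact.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;

class BytePieceSketch {
    // Byte-fallback pieces look like "<0x0A>": return the single raw byte they stand for.
    // Any other piece is returned as its own UTF-8 bytes.
    static byte[] pieceBytes(String piece) {
        if (piece.length() == 6 && piece.startsWith("<0x") && piece.endsWith(">")) {
            int hi = Character.digit(piece.charAt(3), 16);
            int lo = Character.digit(piece.charAt(4), 16);
            if (hi >= 0 && lo >= 0) {
                return new byte[] {(byte) (hi * 16 + lo)};
            }
        }
        return piece.getBytes(StandardCharsets.UTF_8);
    }

    // Concatenates raw bytes first and decodes once at the end.
    static String decode(List<String> pieces) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (String piece : pieces) {
            out.writeBytes(pieceBytes(piece));
        }
        return out.toString(StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        // Prints "mom." and "As she" on separate lines instead of "mom.<0x0A>As she".
        System.out.println(decode(List.of("mom.", "<0x0A>", "As", " she")));
    }
}
```

When streaming tokens one at a time, the same idea applies: write the bytes to System.out (a PrintStream) rather than printing each piece as a String.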

Using llama2 7b but it's super slow

The Performance section says that it runs at almost 1 token per second, but mine takes a few minutes for a single word, and the speed is not consistent. How do I make it faster?
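For a rough sense of the cost: generating each token touches nearly every weight once, so on large models throughput is bounded by how fast the weights can be streamed from memory. Below is a back-of-the-envelope count for the Llama 2 7B shape (dim=4096, hidden_dim=11008, 32 layers, 32 heads, vocab 32000, unshared classifier). It assumes the llama2.c float32 layout and ignores the header and the small RoPE tables. The helper is hypothetical.

```java
class WeightSizeSketch {
    // Counts float32 weights in a llama2.c-style checkpoint.
    static long parameterCount(int dim, int hiddenDim, int nLayers, int nHeads, int nKvHeads,
                               int vocabSize, boolean sharedWeights) {
        long kvDim = (long) dim * nKvHeads / nHeads;
        long perLayer = 2L * dim                  // rms_att_weight, rms_ffn_weight
                + 2L * dim * dim                  // wq, wo
                + 2L * dim * kvDim                // wk, wv
                + 3L * dim * hiddenDim;           // w1, w2, w3
        long total = (long) vocabSize * dim       // token embedding table
                + nLayers * perLayer
                + dim;                            // rms_final_weight
        if (!sharedWeights) {
            total += (long) vocabSize * dim;      // separate classifier weights (wcls)
        }
        return total;
    }

    public static void main(String[] args) {
        long params = parameterCount(4096, 11008, 32, 32, 32, 32000, false);
        System.out.printf("Llama 2 7B: %,d params, %.1f GB as float32%n", params, params * 4 / 1e9);
    }
}
```

That comes to about 6.74 billion parameters, roughly 27 GB as float32. With less free RAM than that, the OS has to page the memory-mapped file in from disk on every token. That would plausibly explain minutes per token and inconsistent timings. The Performance numbers were taken on a 64GB machine.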

NoSuchFileException: tokenizer.bin

I just checked out your project and downloaded stories15M (wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin). I compiled the Llama2 Java class:

javac --enable-preview -source 20 --add-modules=jdk.incubator.vector Llama2.java
warning: using incubating module(s): jdk.incubator.vector
Note: Llama2.java uses preview features of Java SE 20.
Note: Recompile with -Xlint:preview for details.
1 warning

When I execute the java class I get the following exception:

java --enable-preview --add-modules=jdk.incubator.vector Llama2 stories15M.bin  
WARNING: Using incubator modules: jdk.incubator.vector
Config{dim=288, hidden_dim=768, n_layers=6, n_heads=6, n_kv_heads=6, vocab_size=32000, seq_len=256, shared_weights=true, head_size=48}
Exception in thread "main" java.nio.file.NoSuchFileException: tokenizer.bin
        at java.base/sun.nio.fs.UnixException.translateToIOException(UnixException.java:92)
        at java.base/sun.nio.fs.UnixException.rethrowAsIOException(UnixException.java:106)
        at java.base/sun.nio.fs.UnixException.rethrowAsIOException(UnixException.java:111)
        at java.base/sun.nio.fs.UnixFileSystemProvider.newFileChannel(UnixFileSystemProvider.java:224)
        at java.base/java.nio.channels.FileChannel.open(FileChannel.java:308)
        at java.base/java.nio.channels.FileChannel.open(FileChannel.java:367)
        at Tokenizer.<init>(Llama2.java:208)
        at Llama2.main(Llama2.java:1006)

I'm using openjdk version 20:

openjdk version "20.0.2" 2023-07-18
OpenJDK Runtime Environment (build 20.0.2+9-78)
OpenJDK 64-Bit Server VM (build 20.0.2+9-78, mixed mode, sharing)
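The exception means tokenizer.bin is not in the directory java was started from. The code expects it there, and it is downloaded separately from the model (the first wget in the Build section). Below is a hypothetical pre-check that turns the stack trace into an actionable message. The class and method names are illustrative only.

```java
import java.nio.file.Files;
import java.nio.file.Path;

class TokenizerPathSketch {
    // Looks for tokenizer.bin in `workingDir` and fails early with a hint if it is missing.
    static Path resolveTokenizer(Path workingDir) {
        Path tokenizer = workingDir.resolve("tokenizer.bin");
        if (!Files.isRegularFile(tokenizer)) {
            throw new IllegalStateException("tokenizer.bin not found in " + workingDir.toAbsolutePath()
                    + "; download it with: wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin");
        }
        return tokenizer;
    }

    public static void main(String[] args) {
        // Path.of("") resolves against the current working directory.
        System.out.println(resolveTokenizer(Path.of("")).toAbsolutePath());
    }
}
```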

Possibly inefficient code

I'm trying to read through the code to understand it and noticed something that might be a mistake, or it might just be my misunderstanding of the code:

    static void matmul(float[] xout, float[] x, FloatBuffer w, int n, int d) {
        // W (d,n) @ x (n,) -> xout (d,)
        // by far the most amount of time is spent inside this little function
        MemorySegment wSegment = MemorySegment.ofBuffer(w);
        IntStream.range(0, d).parallel().forEach(i -> {
            float val = 0f;
            int j = 0;
            if (USE_VECTOR_API) {
                VectorSpecies<Float> species = FloatVector.SPECIES_256;
                FloatVector sum0 = FloatVector.zero(species);
                FloatVector sum1 = FloatVector.zero(species);
                FloatVector sum2 = FloatVector.zero(species);
                FloatVector sum3 = FloatVector.zero(species);
                int width = species.length();
                int upperBound = n - n % (4 * width);
                for (; j < upperBound; j += 4 * width) {
                    var wj0 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 0 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj1 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 1 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj2 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 2 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj3 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 3 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var xj0 = FloatVector.fromArray(species, x, j + 0 * width);
                    var xj1 = FloatVector.fromArray(species, x, j + 1 * width);
                    var xj2 = FloatVector.fromArray(species, x, j + 2 * width);
                    var xj3 = FloatVector.fromArray(species, x, j + 3 * width);
                    sum0 = wj0.fma(xj0, sum0);
                    sum1 = wj1.fma(xj1, sum1);
                    sum2 = wj2.fma(xj2, sum2);
                    sum3 = wj3.fma(xj3, sum3);
                }
                val = sum0.add(sum1).add(sum2).add(sum3).reduceLanes(VectorOperators.ADD);
            }

            // Graal's auto-vectorization.
            int upperBound = n & ~3;
            float[] sum = new float[4];
            for (; j < upperBound; j += sum.length) {
                sum[0] += w.get(i * n + j + 0) * x[j + 0];
                sum[1] += w.get(i * n + j + 1) * x[j + 1];
                sum[2] += w.get(i * n + j + 2) * x[j + 2];
                sum[3] += w.get(i * n + j + 3) * x[j + 3];
            }
            val += sum[0] + sum[1] + sum[2] + sum[3];

            for (; j < n; j++) {
                val += w.get(i * n + j) * x[j];
            }
            xout[i] = val;
        });
    }

First, there's a small inefficiency in the if (USE_VECTOR_API) { line. Since that is a constant, evaluating the if statement on every forEach call is wasteful. The JIT might optimize it away eventually, but I would still move it outside the lambda to keep the code efficient, as the JIT isn't magic.

The main thing that isn't clear to me: it seems the code does the operation twice when running under a regular JIT. Shouldn't the rest of the code be under an else statement?

This is how I think it should be, if I'm reading the code correctly. I haven't tested it, though, so I might be completely off here:

static void matmul(float[] xout, float[] x, FloatBuffer w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    MemorySegment wSegment = MemorySegment.ofBuffer(w);
    if (USE_VECTOR_API) {
        IntStream.range(0, d).parallel().forEach(i -> {
            int j = 0;
            VectorSpecies<Float> species = FloatVector.SPECIES_256;
            FloatVector sum0 = FloatVector.zero(species);
            FloatVector sum1 = FloatVector.zero(species);
            FloatVector sum2 = FloatVector.zero(species);
            FloatVector sum3 = FloatVector.zero(species);
            int width = species.length();
            int upperBound = n - n % (4 * width);
            for (; j < upperBound; j += 4 * width) {
                var wj0 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 0 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                var wj1 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 1 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                var wj2 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 2 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                var wj3 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 3 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                var xj0 = FloatVector.fromArray(species, x, j + 0 * width);
                var xj1 = FloatVector.fromArray(species, x, j + 1 * width);
                var xj2 = FloatVector.fromArray(species, x, j + 2 * width);
                var xj3 = FloatVector.fromArray(species, x, j + 3 * width);
                sum0 = wj0.fma(xj0, sum0);
                sum1 = wj1.fma(xj1, sum1);
                sum2 = wj2.fma(xj2, sum2);
                sum3 = wj3.fma(xj3, sum3);
            }
            xout[i] = sum0.add(sum1).add(sum2).add(sum3).reduceLanes(VectorOperators.ADD);
        });
    } else {
        // Graal's auto-vectorization.
        IntStream.range(0, d).parallel().forEach(i -> {
            int j = 0;
            float val = 0;
            int upperBound = n & ~3;
            float[] sum = new float[4];
            for (; j < upperBound; j += sum.length) {
                sum[0] += w.get(i * n + j) * x[j];
                sum[1] += w.get(i * n + j + 1) * x[j + 1];
                sum[2] += w.get(i * n + j + 2) * x[j + 2];
                sum[3] += w.get(i * n + j + 3) * x[j + 3];
            }
            val += sum[0] + sum[1] + sum[2] + sum[3];

            for (; j < n; j++) {
                val += w.get(i * n + j) * x[j];
            }
            xout[i] = val;
        });
    }
}
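In the quoted matmul, the vector loop and the scalar loops share the index j. When USE_VECTOR_API is true, the vector loop stops with j at n - n % (4 * width). The unrolled scalar loop then resumes from there and covers any remaining multiples of 4, and the last loop handles the final 0 to 3 elements. So no element is processed twice; the scalar code only handles the tail. The proposed split, by contrast, drops that tail on the vector path unless n is a multiple of 4 * width (32 floats with SPECIES_256). As for the if: USE_VECTOR_API is presumably a static final field, which HotSpot's JIT treats as a constant, so the untaken branch is normally folded away. Below is a scalar sketch of the main-loop-plus-tail pattern. It uses plain Java instead of the Vector API, and `block` plays the role of 4 * width. The names are hypothetical.

```java
import java.util.Arrays;

class TailLoopSketch {
    // Dot product of row i of W (d,n) with x (n,): a blocked main loop, then a scalar tail.
    static float rowDot(float[] w, float[] x, int i, int n, int block) {
        float val = 0f;
        int j = 0;
        int upperBound = n - n % block;           // end of the last whole block
        for (; j < upperBound; j += block) {
            for (int k = 0; k < block; k++) {
                val += w[i * n + j + k] * x[j + k];
            }
        }
        // j == upperBound here, so this loop only visits the n % block leftover elements.
        for (; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        return val;
    }

    // Same sum but stopping at upperBound, i.e. what a vector-only branch without a tail computes.
    static float rowDotWithoutTail(float[] w, float[] x, int i, int n, int block) {
        float val = 0f;
        int upperBound = n - n % block;
        for (int j = 0; j < upperBound; j++) {
            val += w[i * n + j] * x[j];
        }
        return val;
    }

    public static void main(String[] args) {
        int n = 40, block = 32;                   // 40 is not a multiple of 32
        float[] x = new float[n];
        float[] w = new float[n];
        Arrays.fill(x, 1f);
        Arrays.fill(w, 1f);
        System.out.println(rowDot(w, x, 0, n, block));            // 40.0: every element counted once
        System.out.println(rowDotWithoutTail(w, x, 0, n, block)); // 32.0: last 8 elements dropped
    }
}
```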

Thanks for the project. It's very interesting!

Exception in thread "main" java.lang.ArithmeticException: / by zero

I have downloaded the Llama 7B model and prepared it as described in llama.cpp:

# install Python dependencies
python3 -m pip install -r requirements.txt

# convert the 7B model to ggml FP16 format
python3 convert.py models/7B/

# [Optional] for models using BPE tokenizers
python convert.py models/7B/ --vocabtype bpe

# quantize the model to 4-bits (using q4_0 method)
./quantize ./models/7B/ggml-model-f16.gguf ./models/7B/ggml-model-q4_0.gguf q4_0

# update the gguf filetype to current if older version is unsupported by another application
./quantize ./models/7B/ggml-model-q4_0.gguf ./models/7B/ggml-model-q4_0-v2.gguf COPY

I then attempted to use the model by executing:

java --enable-preview --add-modules=jdk.incubator.vector Llama2 ./models/7B/ggml-model-q4_0-v2.gguf -n 128

but the execution fails with:

WARNING: Using incubator modules: jdk.incubator.vector
Exception in thread "main" java.lang.ArithmeticException: / by zero
        at Config.<init>(Llama2.java:52)
        at Transformer.<init>(Llama2.java:185)
        at Llama2.main(Llama2.java:1000)
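The trace points at Config.<init>, which likely computes head_size = dim / n_heads; the earlier log shows head_size=48 for dim=288 and n_heads=6. Llama2.java reads llama2.c's raw checkpoint format. A llama.cpp GGUF file instead starts with the ASCII magic GGUF rather than those seven int32 fields, so the parsed head count can plausibly come out as 0. Below is a hedged sketch of a header sanity check that would fail with a clearer message. The names are hypothetical and this is not Llama2.java's code.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

class HeaderCheckSketch {
    static final int HEADER_BYTES = 7 * Integer.BYTES;

    // Checks a llama2.c header: dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len
    // (little-endian int32), before anything divides by n_heads.
    static void validate(byte[] header) {
        if (header.length < HEADER_BYTES) {
            throw new IllegalArgumentException("file too short for a llama2.c header");
        }
        if (new String(header, 0, 4, StandardCharsets.US_ASCII).equals("GGUF")) {
            throw new IllegalArgumentException("GGUF file (llama.cpp format), not a llama2.c checkpoint");
        }
        ByteBuffer buf = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN);
        int dim = buf.getInt(), hiddenDim = buf.getInt(), nLayers = buf.getInt(), nHeads = buf.getInt();
        int nKvHeads = buf.getInt(), vocabSize = buf.getInt(), seqLen = buf.getInt();
        // llama2.c uses a negative vocab_size to flag unshared classifier weights, so only 0 is rejected.
        boolean plausible = dim > 0 && hiddenDim > 0 && nLayers > 0 && nHeads > 0 && nKvHeads > 0
                && seqLen > 0 && vocabSize != 0 && dim % nHeads == 0 && nHeads % nKvHeads == 0;
        if (!plausible) {
            throw new IllegalArgumentException("implausible llama2.c header: dim=" + dim + ", n_heads=" + nHeads);
        }
    }

    public static void main(String[] args) throws IOException {
        try (InputStream in = Files.newInputStream(Path.of(args[0]))) {
            validate(in.readNBytes(HEADER_BYTES));
            System.out.println("header looks like a llama2.c checkpoint");
        }
    }
}
```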
