
saxml's Introduction

Saxml (aka Sax)

Saxml is an experimental system that serves Paxml, JAX, and PyTorch models for inference.

A Sax cell (aka Sax cluster) consists of an admin server and a group of model servers. The admin server keeps track of model servers, assigns published models to model servers to serve, and helps clients locate model servers serving specific published models.

The example below walks through setting up a Sax cell and starting a TPU or GPU model server in the cell.

Install Sax

Install and set up the gcloud tool

Install the gcloud CLI and set the default account and project:

gcloud config set account <your-email-account>
gcloud config set project <your-project>

Create a Cloud Storage bucket to store Sax server states

Create a Cloud Storage bucket:

GSBUCKET=sax-data
gcloud storage buckets create gs://${GSBUCKET}
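
Cloud Storage bucket names share one global namespace, so a generic name such as sax-data may already be taken. As a minimal sketch, one option is to derive the name from your project ID. The helper name sax_bucket_name and the -sax-data suffix are illustrative assumptions, not part of Sax:

```shell
# Hypothetical helper: build a bucket name from a project ID.
# Bucket names must be lowercase, so the input is lowercased defensively.
sax_bucket_name() {
  printf '%s-sax-data\n' "$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')"
}

# Example (not run here):
#   GSBUCKET=$(sax_bucket_name "$(gcloud config get-value project)")
```

Whatever name you choose, use the same GSBUCKET value in every later command.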

Create a Compute Engine VM instance for the admin server

Create a Compute Engine VM instance:

gcloud compute instances create sax-admin \
  --zone=us-central1-b \
  --machine-type=e2-standard-8 \
  --boot-disk-size=200GB \
  --scopes=https://www.googleapis.com/auth/cloud-platform

Create a Cloud TPU VM instance for a TPU model server

Use this guide to enable the Cloud TPU API in a Google Cloud project.

Create a Cloud TPU VM instance:


gcloud compute tpus tpu-vm create sax-tpu \
  --zone=us-central2-b \
  --accelerator-type=v4-8 \
  --version=tpu-vm-v4-base \
  --scopes=https://www.googleapis.com/auth/cloud-platform

Create a Compute Engine VM instance for a GPU model server

Alternatively or in addition to the Cloud TPU VM instance, create a Compute Engine VM instance with GPUs:

gcloud compute instances create sax-gpu \
  --zone=us-central1-b \
  --machine-type=n1-standard-32 \
  --accelerator=count=4,type=nvidia-tesla-v100 \
  --maintenance-policy=TERMINATE \
  --boot-disk-size=200GB \
  --scopes=https://www.googleapis.com/auth/cloud-platform

Consider creating a VM instance using the "GPU-optimized Debian 10 with CUDA 11.0" image instead, so the Nvidia CUDA stack doesn't need to be manually installed as described below.

Start the Sax admin server

SSH to the Compute Engine VM instance:

gcloud compute ssh --zone=us-central1-b sax-admin

Inside the VM instance, clone the Sax repo and initialize the environment:

git clone https://github.com/google/saxml.git
cd saxml
saxml/tools/init_cloud_vm.sh

Configure the Sax admin server. This only needs to be done once:

bazel run saxml/bin:admin_config -- \
  --sax_cell=/sax/test \
  --sax_root=gs://${GSBUCKET}/sax-root \
  --fs_root=gs://${GSBUCKET}/sax-fs-root \
  --alsologtostderr

Start the Sax admin server:

bazel run saxml/bin:admin_server -- \
  --sax_cell=/sax/test \
  --sax_root=gs://${GSBUCKET}/sax-root \
  --port=10000 \
  --alsologtostderr
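
The commands above expand ${GSBUCKET} inside the VM, while the variable was defined earlier when the bucket was created. In a fresh SSH session it is likely unset, and the paths would expand to gs:///sax-root. A small guard along these lines can catch that early; the function name sax_roots is an illustrative assumption:

```shell
# Hypothetical guard: print the two storage roots used by admin_config,
# or fail with a message when GSBUCKET is unset or empty in this shell.
sax_roots() {
  if [ -z "${GSBUCKET:-}" ]; then
    echo "GSBUCKET is not set in this shell" >&2
    return 1
  fi
  printf 'gs://%s/sax-root\n' "$GSBUCKET"
  printf 'gs://%s/sax-fs-root\n' "$GSBUCKET"
}
```

Running sax_roots before the bazel commands prints the paths they will use, or an error if the variable still needs to be set.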

Start the Sax TPU model server

SSH to the Cloud TPU VM instance:

gcloud compute tpus tpu-vm ssh --zone=us-central2-b sax-tpu

Inside the VM instance, clone the Sax repo and initialize the environment:

git clone https://github.com/google/saxml.git
cd saxml
saxml/tools/init_cloud_vm.sh

Start the Sax model server:

SAX_ROOT=gs://${GSBUCKET}/sax-root \
bazel run saxml/server:server -- \
  --sax_cell=/sax/test \
  --port=10001 \
  --platform_chip=tpuv4 \
  --platform_topology=2x2x1 \
  --alsologtostderr

You should see a log message "Joined [admin server IP:port]" from the model server to indicate it has successfully joined the admin server.
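
If the server output is captured to a file (for example by appending 2>&1 | tee model_server.log to the command, which is an assumption about your setup), the check for the "Joined" message can be scripted. This is a rough sketch: the pattern is inferred from the log lines quoted in the issues further down this page, and has_joined_admin is a hypothetical name:

```shell
# Succeeds if the log file contains a line such as
#   location.go:155] Joined 10.128.0.71:10000
# The optional space before the colon tolerates terminal line-wrap artifacts.
has_joined_admin() {
  grep -Eq 'Joined [0-9.]+ ?:[0-9]+' "$1"
}

# Example (not run here):
#   has_joined_admin model_server.log && echo "model server joined the cell"
```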

Start the Sax GPU model server

SSH to the Compute Engine VM instance:

gcloud compute ssh --zone=us-central1-b sax-gpu

Install the Nvidia GPU driver, CUDA, and cuDNN. Note that Sax by default requires CUDA 11. To switch to CUDA 12, edit requirements-cuda.txt and replace jaxlib==0.4.7+cuda11.cudnn86 with jaxlib==0.4.7+cuda12.cudnn88.

Inside the VM instance, clone the Sax repo and initialize the environment:

git clone https://github.com/google/saxml.git
cd saxml
saxml/tools/init_cloud_vm.sh
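
If you chose CUDA 12 as noted above, the jaxlib replacement in requirements-cuda.txt can be scripted before the copy step that follows. A minimal sketch, using only the two version strings given earlier; the function name use_cuda12_jaxlib is illustrative:

```shell
# Rewrite the jaxlib pin in the given requirements file from the CUDA 11
# build to the CUDA 12 build. A .bak copy of the original file is kept.
use_cuda12_jaxlib() {
  sed -i.bak 's/jaxlib==0\.4\.7+cuda11\.cudnn86/jaxlib==0.4.7+cuda12.cudnn88/' "$1"
}

# Example (from the saxml repo root, not run here):
#   use_cuda12_jaxlib requirements-cuda.txt
```

If the pinned jaxlib version in your checkout differs from 0.4.7, the pattern will not match and the file is left unchanged.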

Enable the GPU-specific requirements.txt file:

cp requirements-cuda.txt requirements.txt

Start the Sax model server:

SAX_ROOT=gs://${GSBUCKET}/sax-root \
bazel run saxml/server:server -- \
  --sax_cell=/sax/test \
  --port=10001 \
  --platform_chip=v100 \
  --platform_topology=4 \
  --jax_platforms=cuda \
  --alsologtostderr

You should see a log message "Joined [admin server IP:port]" from the model server to indicate it has successfully joined the admin server.

Use Sax

Sax comes with a command-line tool called saxutil for easy usage:

# From the `saxml` repo root directory:
alias saxutil='bazel run saxml/bin:saxutil -- --sax_root=gs://${GSBUCKET}/sax-root'
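
Bash does not expand aliases in non-interactive shells by default, so the alias above will not work inside a script. A shell function is one possible alternative; this sketch wraps the same bazel command and assumes it is called from the saxml repo root with GSBUCKET set:

```shell
# Function form of the saxutil alias; extra arguments are forwarded unchanged.
saxutil() {
  bazel run saxml/bin:saxutil -- --sax_root="gs://${GSBUCKET}/sax-root" "$@"
}
```

Unlike the alias, the function expands ${GSBUCKET} each time it is called and preserves quoting of arguments that contain spaces.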

saxutil supports the following commands:

  • saxutil help: Show general help or help about a particular command.
  • saxutil ls: List all cells, all models in a cell, or a particular model.
  • saxutil publish: Publish a model.
  • saxutil unpublish: Unpublish a model.
  • saxutil update: Update a model.
  • saxutil lm.generate: Use a language model to generate suffixes from a prefix.
  • saxutil lm.score: Use a language model to score a prefix and suffix.
  • saxutil lm.embed: Use a language model to embed text into a vector.
  • saxutil vm.generate: Use a vision model to generate images from text.
  • saxutil vm.classify: Use a vision model to classify an image.
  • saxutil vm.embed: Use a vision model to embed an image into a vector.

As an example, Sax comes with a Pax language model servable on a Cloud TPU VM v4-8 instance. You can use it to verify Sax is correctly set up by publishing and using the model with a dummy checkpoint.

saxutil publish \
  /sax/test/lm2b \
  saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest \
  None \
  1

Check if the model is loaded by looking at the "selected replica address" column of this command's output:

saxutil ls /sax/test/lm2b
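
To read that column from a script, the table can be parsed with awk. The sketch below assumes the pipe-delimited layout shown in the saxutil ls output pasted in the issues further down this page (model name in the first cell, replica address in the fifth); the function name selected_replica is hypothetical, and the layout may differ between Sax versions:

```shell
# Read `saxutil ls` output on stdin and print the selected replica address
# for the named model. Prints an empty line when no replica is selected yet.
selected_replica() {
  awk -F'|' -v model="$1" '
    {
      name = $2
      gsub(/^[ \t]+|[ \t]+$/, "", name)   # trim padding around the model name
    }
    name == model {
      addr = $6
      gsub(/[ \t]/, "", addr)             # drop padding inside the address cell
      print addr
      exit
    }
  '
}

# Example (not run here):
#   saxutil ls /sax/test/lm2b | selected_replica lm2b
```

An empty result means the model has not been loaded on any model server yet.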

When the model is loaded, issue a query:

saxutil lm.generate /sax/test/lm2b "Q: Who is Harry Potter's mother? A: "

The result will be printed in the terminal.

To use a real checkpoint with the model, follow the Paxml tutorial to generate a checkpoint. The model can then be published in Sax like this:

saxutil publish \
  /sax/test/lm2b \
  saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2B \
  gs://${GSBUCKET}/checkpoints/checkpoint_00000000 \
  1

Use the same saxutil lm.generate command as above to query the model.

Use Sax to load a LLaMA 7B/13B/70B model:

First, get the LLaMA PyTorch checkpoint (pytorch_vars) from Meta, then run the following script to convert it to the Sax format:

python3 -m saxml.tools.convert_llama_ckpt --base llama_7b --pax pax_7b

For the 7B model, this script needs roughly 50-60 GB of memory. Larger models need more: the 70B model, for example, needs 500-600 GB.

The script loads and saves all weights in a single pass. To reduce peak memory, modify the convert() function to load and save the weights in multiple passes, handling only a subset of the weight variables in each pass.
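
Before starting a conversion, it may be worth checking that the machine has enough memory for the model size in question. The following is a Linux-only sketch that reads MemTotal from /proc/meminfo; the helper names are illustrative, and it compares installed memory in GiB, which is only an approximation of what the script can actually use:

```shell
# Print total system memory in whole GiB from a meminfo-style file
# (defaults to /proc/meminfo; MemTotal is reported in KiB).
mem_total_gib() {
  awk '/^MemTotal:/ { printf "%d\n", $2 / 1048576 }' "${1:-/proc/meminfo}"
}

# Succeeds if at least $1 GiB of memory is installed.
# The optional second argument is an alternative meminfo file.
has_memory_gib() {
  [ "$(mem_total_gib "${2:-/proc/meminfo}")" -ge "$1" ]
}

# Example (not run here):
#   has_memory_gib 60 || echo "probably too little memory for the 7B conversion" >&2
```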

After converting the checkpoint, the checkpoint folder should have the following structure:

checkpoint_00000000/
	metadata/
		metadata
	state/
		mdl_vars.params.lm*/
		...
		...
		step/

Create an empty file named commit_success.txt in each folder. These files tell Sax that the checkpoint is ready to use when it loads the model. The finished checkpoint should look like this:

checkpoint_00000000/
	commit_success.txt
	metadata/
		commit_success.txt
		metadata
	state/
		commit_success.txt
		mdl_vars.params.lm*/
		...
		...
		step/

Now the checkpoint is fully ready.
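
The marker files can be created with a short loop. This sketch assumes that metadata/ and state/ sit directly under the checkpoint folder and that it runs on a local copy of the checkpoint before upload; mark_checkpoint_ready is a hypothetical helper name:

```shell
# Create an empty commit_success.txt in the checkpoint folder and in its
# metadata/ and state/ subfolders. Fails if any of the folders is missing.
mark_checkpoint_ready() {
  for dir in "$1" "$1/metadata" "$1/state"; do
    if [ ! -d "$dir" ]; then
      echo "missing directory: $dir" >&2
      return 1
    fi
    touch "$dir/commit_success.txt"
  done
}

# Example (not run here):
#   mark_checkpoint_ready pax_7b/checkpoint_00000000
```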

Then start the Sax model server:

GPU server:

SAX_ROOT=gs://${GSBUCKET}/sax-root \
bazel run saxml/server:server -- \
  --sax_cell=/sax/test \
  --port=10001 \
  --platform_chip=a100 \
  --platform_topology=1 \
  --jax_platforms=cuda \
  --alsologtostderr

TPU server:

SAX_ROOT=gs://${GSBUCKET}/sax-root \
bazel run saxml/server:server -- \
  --sax_cell=/sax/test \
  --port=10001 \
  --platform_chip=tpuv4 \
  --platform_topology=2x2x1 \
  --alsologtostderr

Finally, move the converted checkpoint to your Google Cloud Storage bucket and publish the model:

7B model

saxutil publish \
  /sax/test/llama-7b \
  saxml.server.pax.lm.params.lm_cloud.LLaMA7BFP16 \
  gs://sax-data/pax-llama/7B \
  1

70B model

saxutil publish \
  /sax/test/llama-70b \
  saxml.server.pax.lm.params.lm_cloud.LLaMA70BFP16TPUv5e \
  gs://sax-data/pax-llama/70B \
  1

saxml's People

Contributors

aaroey, andrewluchen, andyly, ashishenoyp, bignamehyp, cangermueller, changlan, descrip, dhr, dryman, edloper, faizan-m, frederick0329, hawkinsp, jianlijianli, jiawenhao, jihwanlee-alphago, junwhanahn, laurentes, maxwillzq, rryan, saeta, shapor, superbobry, tfboyd, tink-expo, ukoxyz, voutcn, weinan1997, yashk2810


saxml's Issues

GPT-J model conversion failed from pytorch to paxml, throwing OOM error for TPUv3-8

Hi, I am trying to serve the GPT-J 6B model on a TPU v3-8 using the saxml framework.

The error occurs when I convert the model from the PyTorch format to the Pax format that Sax supports. This is the conversion script:

https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Google/code/gptj-99/convert_gptj_ckpt.py

The admin and model servers are running correctly; I confirmed that they communicate by running a sample test query.

The model pickle file is only 22.7 GB, so it should fit on the TPU cluster. Any ideas?

The environment:
pip3 install accelerate
pip3 install torch
pip3 install transformers
pip install paxml==1.1.0 (although I built it from its Git repo)

2024-01-03 05:23:41.411871: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
Loading the base model from EleutherAI/gpt-j-6b
transformer.wte.weight (50400, 4096)
transformer.h.0.ln_1.weight (4096,)
transformer.h.0.ln_1.bias (4096,)
transformer.h.0.attn.k_proj.weight (4096, 4096)
transformer.h.0.attn.v_proj.weight (4096, 4096)
transformer.h.0.attn.q_proj.weight (4096, 4096)
transformer.h.0.attn.out_proj.weight (4096, 4096)
transformer.h.0.mlp.fc_in.weight (16384, 4096)
transformer.h.0.mlp.fc_in.bias (16384,)
transformer.h.0.mlp.fc_out.weight (4096, 16384)
transformer.h.0.mlp.fc_out.bias (4096,)
transformer.h.1.ln_1.weight (4096,)
transformer.h.1.ln_1.bias (4096,)
transformer.h.1.attn.k_proj.weight (4096, 4096)
transformer.h.1.attn.v_proj.weight (4096, 4096)
transformer.h.1.attn.q_proj.weight (4096, 4096)
transformer.h.1.attn.out_proj.weight (4096, 4096)
transformer.h.1.mlp.fc_in.weight (16384, 4096)
transformer.h.1.mlp.fc_in.bias (16384,)
transformer.h.1.mlp.fc_out.weight (4096, 16384)
transformer.h.1.mlp.fc_out.bias (4096,)
transformer.h.2.ln_1.weight (4096,)
transformer.h.2.ln_1.bias (4096,)
transformer.h.2.attn.k_proj.weight (4096, 4096)
transformer.h.2.attn.v_proj.weight (4096, 4096)
transformer.h.2.attn.q_proj.weight (4096, 4096)
transformer.h.2.attn.out_proj.weight (4096, 4096)
transformer.h.2.mlp.fc_in.weight (16384, 4096)
transformer.h.2.mlp.fc_in.bias (16384,)
transformer.h.2.mlp.fc_out.weight (4096, 16384)
transformer.h.2.mlp.fc_out.bias (4096,)
transformer.h.3.ln_1.weight (4096,)
transformer.h.3.ln_1.bias (4096,)
transformer.h.3.attn.k_proj.weight (4096, 4096)
transformer.h.3.attn.v_proj.weight (4096, 4096)
transformer.h.3.attn.q_proj.weight (4096, 4096)
transformer.h.3.attn.out_proj.weight (4096, 4096)
transformer.h.3.mlp.fc_in.weight (16384, 4096)
transformer.h.3.mlp.fc_in.bias (16384,)
transformer.h.3.mlp.fc_out.weight (4096, 16384)
transformer.h.3.mlp.fc_out.bias (4096,)
transformer.h.4.ln_1.weight (4096,)
transformer.h.4.ln_1.bias (4096,)
transformer.h.4.attn.k_proj.weight (4096, 4096)
transformer.h.4.attn.v_proj.weight (4096, 4096)
transformer.h.4.attn.q_proj.weight (4096, 4096)
transformer.h.4.attn.out_proj.weight (4096, 4096)
transformer.h.4.mlp.fc_in.weight (16384, 4096)
transformer.h.4.mlp.fc_in.bias (16384,)
transformer.h.4.mlp.fc_out.weight (4096, 16384)
transformer.h.4.mlp.fc_out.bias (4096,)
transformer.h.5.ln_1.weight (4096,)
transformer.h.5.ln_1.bias (4096,)
transformer.h.5.attn.k_proj.weight (4096, 4096)
transformer.h.5.attn.v_proj.weight (4096, 4096)
transformer.h.5.attn.q_proj.weight (4096, 4096)
transformer.h.5.attn.out_proj.weight (4096, 4096)
transformer.h.5.mlp.fc_in.weight (16384, 4096)
transformer.h.5.mlp.fc_in.bias (16384,)
transformer.h.5.mlp.fc_out.weight (4096, 16384)
transformer.h.5.mlp.fc_out.bias (4096,)
transformer.h.6.ln_1.weight (4096,)
transformer.h.6.ln_1.bias (4096,)
transformer.h.6.attn.k_proj.weight (4096, 4096)
transformer.h.6.attn.v_proj.weight (4096, 4096)
transformer.h.6.attn.q_proj.weight (4096, 4096)
transformer.h.6.attn.out_proj.weight (4096, 4096)
transformer.h.6.mlp.fc_in.weight (16384, 4096)
transformer.h.6.mlp.fc_in.bias (16384,)
transformer.h.6.mlp.fc_out.weight (4096, 16384)
transformer.h.6.mlp.fc_out.bias (4096,)
transformer.h.7.ln_1.weight (4096,)
transformer.h.7.ln_1.bias (4096,)
transformer.h.7.attn.k_proj.weight (4096, 4096)
transformer.h.7.attn.v_proj.weight (4096, 4096)
transformer.h.7.attn.q_proj.weight (4096, 4096)
transformer.h.7.attn.out_proj.weight (4096, 4096)
transformer.h.7.mlp.fc_in.weight (16384, 4096)
transformer.h.7.mlp.fc_in.bias (16384,)
transformer.h.7.mlp.fc_out.weight (4096, 16384)
transformer.h.7.mlp.fc_out.bias (4096,)
transformer.h.8.ln_1.weight (4096,)
transformer.h.8.ln_1.bias (4096,)
transformer.h.8.attn.k_proj.weight (4096, 4096)
transformer.h.8.attn.v_proj.weight (4096, 4096)
transformer.h.8.attn.q_proj.weight (4096, 4096)
transformer.h.8.attn.out_proj.weight (4096, 4096)
transformer.h.8.mlp.fc_in.weight (16384, 4096)
transformer.h.8.mlp.fc_in.bias (16384,)
transformer.h.8.mlp.fc_out.weight (4096, 16384)
transformer.h.8.mlp.fc_out.bias (4096,)
transformer.h.9.ln_1.weight (4096,)
transformer.h.9.ln_1.bias (4096,)
transformer.h.9.attn.k_proj.weight (4096, 4096)
transformer.h.9.attn.v_proj.weight (4096, 4096)
transformer.h.9.attn.q_proj.weight (4096, 4096)
transformer.h.9.attn.out_proj.weight (4096, 4096)
transformer.h.9.mlp.fc_in.weight (16384, 4096)
transformer.h.9.mlp.fc_in.bias (16384,)
transformer.h.9.mlp.fc_out.weight (4096, 16384)
transformer.h.9.mlp.fc_out.bias (4096,)
transformer.h.10.ln_1.weight (4096,)
transformer.h.10.ln_1.bias (4096,)
transformer.h.10.attn.k_proj.weight (4096, 4096)
transformer.h.10.attn.v_proj.weight (4096, 4096)
transformer.h.10.attn.q_proj.weight (4096, 4096)
transformer.h.10.attn.out_proj.weight (4096, 4096)
transformer.h.10.mlp.fc_in.weight (16384, 4096)
transformer.h.10.mlp.fc_in.bias (16384,)
transformer.h.10.mlp.fc_out.weight (4096, 16384)
transformer.h.10.mlp.fc_out.bias (4096,)
transformer.h.11.ln_1.weight (4096,)
transformer.h.11.ln_1.bias (4096,)
transformer.h.11.attn.k_proj.weight (4096, 4096)
transformer.h.11.attn.v_proj.weight (4096, 4096)
transformer.h.11.attn.q_proj.weight (4096, 4096)
transformer.h.11.attn.out_proj.weight (4096, 4096)
transformer.h.11.mlp.fc_in.weight (16384, 4096)
transformer.h.11.mlp.fc_in.bias (16384,)
transformer.h.11.mlp.fc_out.weight (4096, 16384)
transformer.h.11.mlp.fc_out.bias (4096,)
transformer.h.12.ln_1.weight (4096,)
transformer.h.12.ln_1.bias (4096,)
transformer.h.12.attn.k_proj.weight (4096, 4096)
transformer.h.12.attn.v_proj.weight (4096, 4096)
transformer.h.12.attn.q_proj.weight (4096, 4096)
transformer.h.12.attn.out_proj.weight (4096, 4096)
transformer.h.12.mlp.fc_in.weight (16384, 4096)
transformer.h.12.mlp.fc_in.bias (16384,)
transformer.h.12.mlp.fc_out.weight (4096, 16384)
transformer.h.12.mlp.fc_out.bias (4096,)
transformer.h.13.ln_1.weight (4096,)
transformer.h.13.ln_1.bias (4096,)
transformer.h.13.attn.k_proj.weight (4096, 4096)
transformer.h.13.attn.v_proj.weight (4096, 4096)
transformer.h.13.attn.q_proj.weight (4096, 4096)
transformer.h.13.attn.out_proj.weight (4096, 4096)
transformer.h.13.mlp.fc_in.weight (16384, 4096)
transformer.h.13.mlp.fc_in.bias (16384,)
transformer.h.13.mlp.fc_out.weight (4096, 16384)
transformer.h.13.mlp.fc_out.bias (4096,)
transformer.h.14.ln_1.weight (4096,)
transformer.h.14.ln_1.bias (4096,)
transformer.h.14.attn.k_proj.weight (4096, 4096)
transformer.h.14.attn.v_proj.weight (4096, 4096)
transformer.h.14.attn.q_proj.weight (4096, 4096)
transformer.h.14.attn.out_proj.weight (4096, 4096)
transformer.h.14.mlp.fc_in.weight (16384, 4096)
transformer.h.14.mlp.fc_in.bias (16384,)
transformer.h.14.mlp.fc_out.weight (4096, 16384)
transformer.h.14.mlp.fc_out.bias (4096,)
transformer.h.15.ln_1.weight (4096,)
transformer.h.15.ln_1.bias (4096,)
transformer.h.15.attn.k_proj.weight (4096, 4096)
transformer.h.15.attn.v_proj.weight (4096, 4096)
transformer.h.15.attn.q_proj.weight (4096, 4096)
transformer.h.15.attn.out_proj.weight (4096, 4096)
transformer.h.15.mlp.fc_in.weight (16384, 4096)
transformer.h.15.mlp.fc_in.bias (16384,)
transformer.h.15.mlp.fc_out.weight (4096, 16384)
transformer.h.15.mlp.fc_out.bias (4096,)
transformer.h.16.ln_1.weight (4096,)
transformer.h.16.ln_1.bias (4096,)
transformer.h.16.attn.k_proj.weight (4096, 4096)
transformer.h.16.attn.v_proj.weight (4096, 4096)
transformer.h.16.attn.q_proj.weight (4096, 4096)
transformer.h.16.attn.out_proj.weight (4096, 4096)
transformer.h.16.mlp.fc_in.weight (16384, 4096)
transformer.h.16.mlp.fc_in.bias (16384,)
transformer.h.16.mlp.fc_out.weight (4096, 16384)
transformer.h.16.mlp.fc_out.bias (4096,)
transformer.h.17.ln_1.weight (4096,)
transformer.h.17.ln_1.bias (4096,)
transformer.h.17.attn.k_proj.weight (4096, 4096)
transformer.h.17.attn.v_proj.weight (4096, 4096)
transformer.h.17.attn.q_proj.weight (4096, 4096)
transformer.h.17.attn.out_proj.weight (4096, 4096)
transformer.h.17.mlp.fc_in.weight (16384, 4096)
transformer.h.17.mlp.fc_in.bias (16384,)
transformer.h.17.mlp.fc_out.weight (4096, 16384)
transformer.h.17.mlp.fc_out.bias (4096,)
transformer.h.18.ln_1.weight (4096,)
transformer.h.18.ln_1.bias (4096,)
transformer.h.18.attn.k_proj.weight (4096, 4096)
transformer.h.18.attn.v_proj.weight (4096, 4096)
transformer.h.18.attn.q_proj.weight (4096, 4096)
transformer.h.18.attn.out_proj.weight (4096, 4096)
transformer.h.18.mlp.fc_in.weight (16384, 4096)
transformer.h.18.mlp.fc_in.bias (16384,)
transformer.h.18.mlp.fc_out.weight (4096, 16384)
transformer.h.18.mlp.fc_out.bias (4096,)
transformer.h.19.ln_1.weight (4096,)
transformer.h.19.ln_1.bias (4096,)
transformer.h.19.attn.k_proj.weight (4096, 4096)
transformer.h.19.attn.v_proj.weight (4096, 4096)
transformer.h.19.attn.q_proj.weight (4096, 4096)
transformer.h.19.attn.out_proj.weight (4096, 4096)
transformer.h.19.mlp.fc_in.weight (16384, 4096)
transformer.h.19.mlp.fc_in.bias (16384,)
transformer.h.19.mlp.fc_out.weight (4096, 16384)
transformer.h.19.mlp.fc_out.bias (4096,)
transformer.h.20.ln_1.weight (4096,)
transformer.h.20.ln_1.bias (4096,)
transformer.h.20.attn.k_proj.weight (4096, 4096)
transformer.h.20.attn.v_proj.weight (4096, 4096)
transformer.h.20.attn.q_proj.weight (4096, 4096)
transformer.h.20.attn.out_proj.weight (4096, 4096)
transformer.h.20.mlp.fc_in.weight (16384, 4096)
transformer.h.20.mlp.fc_in.bias (16384,)
transformer.h.20.mlp.fc_out.weight (4096, 16384)
transformer.h.20.mlp.fc_out.bias (4096,)
transformer.h.21.ln_1.weight (4096,)
transformer.h.21.ln_1.bias (4096,)
transformer.h.21.attn.k_proj.weight (4096, 4096)
transformer.h.21.attn.v_proj.weight (4096, 4096)
transformer.h.21.attn.q_proj.weight (4096, 4096)
transformer.h.21.attn.out_proj.weight (4096, 4096)
transformer.h.21.mlp.fc_in.weight (16384, 4096)
transformer.h.21.mlp.fc_in.bias (16384,)
transformer.h.21.mlp.fc_out.weight (4096, 16384)
transformer.h.21.mlp.fc_out.bias (4096,)
transformer.h.22.ln_1.weight (4096,)
transformer.h.22.ln_1.bias (4096,)
transformer.h.22.attn.k_proj.weight (4096, 4096)
transformer.h.22.attn.v_proj.weight (4096, 4096)
transformer.h.22.attn.q_proj.weight (4096, 4096)
transformer.h.22.attn.out_proj.weight (4096, 4096)
transformer.h.22.mlp.fc_in.weight (16384, 4096)
transformer.h.22.mlp.fc_in.bias (16384,)
transformer.h.22.mlp.fc_out.weight (4096, 16384)
transformer.h.22.mlp.fc_out.bias (4096,)
transformer.h.23.ln_1.weight (4096,)
transformer.h.23.ln_1.bias (4096,)
transformer.h.23.attn.k_proj.weight (4096, 4096)
transformer.h.23.attn.v_proj.weight (4096, 4096)
transformer.h.23.attn.q_proj.weight (4096, 4096)
transformer.h.23.attn.out_proj.weight (4096, 4096)
transformer.h.23.mlp.fc_in.weight (16384, 4096)
transformer.h.23.mlp.fc_in.bias (16384,)
transformer.h.23.mlp.fc_out.weight (4096, 16384)
transformer.h.23.mlp.fc_out.bias (4096,)
transformer.h.24.ln_1.weight (4096,)
transformer.h.24.ln_1.bias (4096,)
transformer.h.24.attn.k_proj.weight (4096, 4096)
transformer.h.24.attn.v_proj.weight (4096, 4096)
transformer.h.24.attn.q_proj.weight (4096, 4096)
transformer.h.24.attn.out_proj.weight (4096, 4096)
transformer.h.24.mlp.fc_in.weight (16384, 4096)
transformer.h.24.mlp.fc_in.bias (16384,)
transformer.h.24.mlp.fc_out.weight (4096, 16384)
transformer.h.24.mlp.fc_out.bias (4096,)
transformer.h.25.ln_1.weight (4096,)
transformer.h.25.ln_1.bias (4096,)
transformer.h.25.attn.k_proj.weight (4096, 4096)
transformer.h.25.attn.v_proj.weight (4096, 4096)
transformer.h.25.attn.q_proj.weight (4096, 4096)
transformer.h.25.attn.out_proj.weight (4096, 4096)
transformer.h.25.mlp.fc_in.weight (16384, 4096)
transformer.h.25.mlp.fc_in.bias (16384,)
transformer.h.25.mlp.fc_out.weight (4096, 16384)
transformer.h.25.mlp.fc_out.bias (4096,)
transformer.h.26.ln_1.weight (4096,)
transformer.h.26.ln_1.bias (4096,)
transformer.h.26.attn.k_proj.weight (4096, 4096)
transformer.h.26.attn.v_proj.weight (4096, 4096)
transformer.h.26.attn.q_proj.weight (4096, 4096)
transformer.h.26.attn.out_proj.weight (4096, 4096)
transformer.h.26.mlp.fc_in.weight (16384, 4096)
transformer.h.26.mlp.fc_in.bias (16384,)
transformer.h.26.mlp.fc_out.weight (4096, 16384)
transformer.h.26.mlp.fc_out.bias (4096,)
transformer.h.27.ln_1.weight (4096,)
transformer.h.27.ln_1.bias (4096,)
transformer.h.27.attn.k_proj.weight (4096, 4096)
transformer.h.27.attn.v_proj.weight (4096, 4096)
transformer.h.27.attn.q_proj.weight (4096, 4096)
transformer.h.27.attn.out_proj.weight (4096, 4096)
transformer.h.27.mlp.fc_in.weight (16384, 4096)
transformer.h.27.mlp.fc_in.bias (16384,)
transformer.h.27.mlp.fc_out.weight (4096, 16384)
transformer.h.27.mlp.fc_out.bias (4096,)
transformer.ln_f.weight (4096,)
transformer.ln_f.bias (4096,)
lm_head.weight (50400, 4096)
lm_head.bias (50400,)
Saving the pax model to pax_6b
Traceback (most recent call last):
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in <module>
convert(args.base_model_path, args.pax_model_path)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
jax_states_gda = pjitted_identity(jax_states)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 2591, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 362, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 816, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1246, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.

Total hbm usage >= 23.06G:
reserved 530.00M
program 4.0K
arguments 22.54G

Output size 22.54G; shares 0B with arguments.

Program hbm requirement 4.0K:
global 4.0K

Largest program allocations in hbm:

  1. Size: 4.0K
    Shape: u32[8,128]{1,0}
    Unpadded size: 4.0K
    XLA label: constant literal
    Allocation type: global

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in <module>
convert(args.base_model_path, args.pax_model_path)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
jax_states_gda = pjitted_identity(jax_states)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.

Total hbm usage >= 23.06G:
reserved 530.00M
program 4.0K
arguments 22.54G

Output size 22.54G; shares 0B with arguments.

Program hbm requirement 4.0K:
global 4.0K

Largest program allocations in hbm:

  1. Size: 4.0K
    Shape: u32[8,128]{1,0}
    Unpadded size: 4.0K
    XLA label: constant literal
    Allocation type: global

    @zhihaoshan-google

Loading model and build time

Hi,

I tried following the README to run the LmCloudSpmd2BTest example on TPUv4 but couldn't load the model; this is the output of saxutil ls /sax/test/lm2b on admin:

INFO: Running command line: bazel-bin/saxml/bin/saxutil_/saxutil '--sax_root=gs://saxml-data/sax-root' ls /sax/test/lm2b
+-------+-------------------------------------------------------+-----------------+---------------+---------------------------+
| MODEL | MODEL PATH | CHECKPOINT PATH | # OF REPLICAS | (SELECTED) REPLICA ADDRESS |
+-------+-------------------------------------------------------+-----------------+---------------+---------------------------+
| lm2b | saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest | None | 0 | |
+-------+-------------------------------------------------------+-----------------+---------------+---------------------------+
+--------+-----+
| METHOD | ACL |
+--------+-----+
+--------+-----+

Here are the commands I used to start the admin and model server.
On admin:
bazel run saxml/bin:admin_config -- --sax_cell=/sax/test --sax_root=gs://saxml-data/sax-root --fs_root=gs://saxml-data/sax-fs-root --alsologtostderr
bazel run saxml/bin:admin_server -- --sax_cell=/sax/test --sax_root=gs://saxml-data/sax-root --port=10000 --alsologtostderr
saxutil publish /sax/test/lm2b saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest None 1

I0630 04:24:08.036908 19996 ipaddr.go:56] IPNet address 10.128.0.71
I0630 04:24:08.212039 19996 admin.go:305] Loaded config: fs_root: "gs://saxml-data/sax-fs-root"
I0630 04:24:08.248588 19996 addr.go:105] SetAddr /gcs/saxml-data/sax-root/sax/test/location.proto "10.128.0.71:10000"
I0630 04:24:08.298355 19996 admin.go:325] Updated config: fs_root: "gs://saxml-data/sax-fs-root"
I0630 04:24:08.455680 19996 mgr.go:781] Loaded manager state
I0630 04:24:08.455819 19996 mgr.go:784] Refreshing manager state every 10s
I0630 04:24:08.455895 19996 admin.go:350] Starting the server on port 10000
I0630 04:24:08.455957 19996 cloud.go:480] Starting the HTTP server on port 8080
I0630 14:22:11.800066 19996 state.go:456] Starting a queue that drains pending model server actions
I0630 14:22:11.800149 19996 state.go:473] Initializing state from model server 10.130.0.4:10001
I0630 14:22:11.810371 19996 state.go:479] Refreshing model server state every 10s
I0630 14:29:54.329640 19996 mgr.go:134] Published with overrides: map[]

On model server:
bazel run saxml/server:server -- --sax_cell=/sax/test --port=10001 --platform_chip=tpuv4 --platform_topology=2x2x1 --alsologtostderr

I0630 14:22:09.754312 139843449665280 model_service_base.py:852] Started joining SAX cell /sax/test
ERROR: logging before flag.Parse: I0630 14:22:11.754970 223228 location.go:141] Calling Join due to address update
ERROR: logging before flag.Parse: I0630 14:22:11.814963 223228 location.go:155] Joined 10.128.0.71:10000
ERROR: logging before flag.Parse: I0630 14:37:11.758835 223228 location.go:162] Calling Join at fixed interval
ERROR: logging before flag.Parse: I0630 14:37:11.814902 223228 addr.go:72] FetchAddr /gcs/saxml-data/sax-root/sax/test/location.proto "10.128.0.71:10000"
ERROR: logging before flag.Parse: I0630 14:37:11.843650 223228 location.go:172] Joined 10.128.0.71:10000

I also waited a while and ran saxutil ls /sax/test/lm2b again, but the "selected replica address" column is still empty. Any idea what might have gone wrong?

One other thing I noticed is that the build on the model server takes very long. The first run of bazel run saxml/server:server -- --sax_cell=/sax/test --port=10001 --platform_chip=tpuv4 --platform_topology=2x2x1 --alsologtostderr took about 4.5 hours to finish:

Target //saxml/server:server up-to-date:
bazel-bin/saxml/server/server.py
bazel-bin/saxml/server/server
INFO: Elapsed time: 16268.138s, Critical Path: 16222.45s
INFO: 5113 processes: 19 internal, 5091 linux-sandbox, 3 local.
INFO: Build completed successfully, 5113 total actions
INFO: Running command line: bazel-bin/saxml/server/server '--sax_cell=/sax/test' '--port=10001' '--platform_chip=tpuv4' '--platform_topology=2x2x1' --alsologtostderr

Succeeding ones only took a few seconds to complete. Is this expected behavior?

Thanks!

LLama2-7b model conversion fails

$ python3 -m convert_llama_ckpt --base-model-path /llama2-7b-hf/ --pax-model-path pax_7B/ --model-size 7b
Loading the base model from /llama2-7b-hf/
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/shivajid/convert_llama_ckpt.py", line 210, in <module>
convert(args.base_model_path, args.pax_model_path, args.model_size)
File "/home/shivajid/convert_llama_ckpt.py", line 96, in convert
'emb_var': np.concatenate([var['tok_embeddings.weight'].type(torch.float16).numpy() for var in pytorch_vars], axis=1)[:vocab,:]
ValueError: need at least one array to concatenate

Can you please help? I am using the LLama2 weights in --base-model-path
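The ValueError in the traceback means the list comprehension over pytorch_vars produced no arrays, so no checkpoint shards were loaded from --base-model-path. One possible cause, stated here as an assumption rather than a confirmed diagnosis, is that the converter looks for original-format shard files (for example *.pth) while the directory holds a Hugging Face export with differently named files. A minimal sketch of an early check is shown below; the helper name find_checkpoint_shards and the default "*.pth" pattern are hypothetical and not taken from convert_llama_ckpt.py.

```python
import pathlib


def find_checkpoint_shards(base_model_path, pattern="*.pth"):
    """Return sorted shard paths under base_model_path, failing early if none match.

    The "*.pth" default is an assumption about the expected file naming.
    """
    shards = sorted(pathlib.Path(base_model_path).glob(pattern))
    if not shards:
        # An empty shard list is what later surfaces as
        # "ValueError: need at least one array to concatenate".
        raise FileNotFoundError(
            f"no files matching {pattern!r} under {base_model_path}"
        )
    return shards
```

Calling find_checkpoint_shards("/llama2-7b-hf/") before the conversion would turn the opaque concatenate error into a message that names the directory and the file pattern that was searched.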

Can't build SAX using bazel 7.0 (installed by default)

When following the default setup steps in the README, saxml/tools/init_cloud_vm.sh installs the latest version of bazel, which is 7.0.0 at this point.

It looks like, due to some change in bazel, building saxml/server:server fails with: Error: 'apple_common' value has no field or method 'multi_arch_split'.
This is probably due to bazelbuild/bazel@a76763c.

I confirmed that saxml/server:server builds after downgrading bazel to 6.4.0.
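A small pre-build check can flag the mismatch before a long build starts. The sketch below parses the text printed by bazel --version (for example "bazel 7.0.0") and compares it against a threshold. Only 6.4.0 (works) and 7.0.0 (fails) are reported in this issue, so treating every version below 7.0.0 as good is an assumption, and the function names are hypothetical.

```python
import re


def parse_bazel_version(version_output):
    """Extract (major, minor, patch) from text such as 'bazel 6.4.0'."""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", version_output)
    if match is None:
        raise ValueError(f"no version number found in {version_output!r}")
    return tuple(int(part) for part in match.groups())


def is_known_good(version_output, first_bad=(7, 0, 0)):
    """Return True if the version is below the first version reported to fail.

    The first_bad default reflects this issue only; it is an assumption
    that all earlier versions build successfully.
    """
    return parse_bazel_version(version_output) < first_bad
```

In practice the input string would come from running bazel --version, for instance through subprocess.run with capture_output=True.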

Access to gs:sax-data bucket.

Several of the models in the lm_cloud params reference files in GCP buckets.
Link
I don't have permission to access these files. Should they be publicly available?
Link

Failing to start the admin server

I'm getting the following error when starting the admin server:

1110 03:19:42.630464 1622013 admin_server.go:54] Failed to start server: config.Load error: Get "https://storage.googleapis.com/storage/v1/b/sax-data2/o/sax-root%2Fsax%2Ftest%2FMETADATA?alt=json&prettyPrint=false&projection=full": context canceled
goroutine 1 [running]:
github.com/golang/glog.stacks(0x0)
external/com_github_golang_glog/glog.go:769 +0x89
github.com/golang/glog.(*loggingT).output(0x101ba60, 0x3, 0xc0003038f0, {0xcf565f?, 0xc0002dfef8?}, 0x1?, 0x0)
external/com_github_golang_glog/glog.go:720 +0x46d
github.com/golang/glog.(*loggingT).printf(0xc000081920?, 0xbfa6d8?, {0xb206b6, 0x1a}, {0xc0002dfef8, 0x1, 0x1})
external/com_github_golang_glog/glog.go:655 +0x10f
github.com/golang/glog.Fatalf(...)
external/com_github_golang_glog/glog.go:1148
main.main()
saxml/bin/admin_server.go:54 +0x273
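The URL in the error is percent-encoded, which makes it harder to see which object the admin server was trying to read. The sketch below decodes it into a bucket and object name using only the Python standard library. The helper name gcs_object_from_url is hypothetical, and it assumes the /storage/v1/b/<bucket>/o/<object> path layout visible in the log line.

```python
from urllib.parse import unquote, urlsplit


def gcs_object_from_url(url):
    """Split a Cloud Storage JSON API object URL into (bucket, object_name)."""
    # urlsplit keeps the path percent-encoded, so the object name stays in
    # one segment: ['', 'storage', 'v1', 'b', <bucket>, 'o', <object>]
    parts = urlsplit(url).path.split("/")
    bucket = parts[parts.index("b") + 1]
    object_name = unquote(parts[parts.index("o") + 1])
    return bucket, object_name


url = (
    "https://storage.googleapis.com/storage/v1/b/sax-data2/o/"
    "sax-root%2Fsax%2Ftest%2FMETADATA"
    "?alt=json&prettyPrint=false&projection=full"
)
print(gcs_object_from_url(url))
```

For the URL above this yields the bucket sax-data2 and the object sax-root/sax/test/METADATA, the cell metadata file under the configured sax root. The "context canceled" message by itself does not say why the request was aborted, so checking that the bucket exists and that the VM's credentials can read it is a reasonable first step rather than a confirmed fix.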

Python server performance

Thank you for making saxml open source! I'm one of the many people who think this project is very useful.

I am interested in how well the Python server works. I know that GPU kernels can be launched well with Python to make good use of GPUs. Could Python listen to network requests and handle them efficiently?

If not, should this Python server be used as a prototype for a C++ server that loads and runs AOT-compiled JAX programs (https://jax.readthedocs.io/en/latest/aot.html)? Thanks!

Can't run admin_config: build broken?

Hi, I was following the instructions in the readme and got blocked at the admin_config step.

gcastle@sax-admin:~/saxml$ git log -1
commit 7122f1e502fa5c6b012fa2dfa9ed4fa63c939f81 (HEAD -> main, origin/main, origin/HEAD)
Author: Daniel Freeman <[email protected]>
Date:   Mon Aug 14 16:16:37 2023 -0700

    adds support for persistent compilation cache
    
    PiperOrigin-RevId: 556942452
    Change-Id: I7342061514e57fdea41b2bcb1eb82cfecd0389c3

gcastle@sax-admin:~/saxml$ bazel run saxml/bin:admin_config --   --sax_cell=/sax/test   --sax_root=gs://${GSBUCKET}/sax-root   --fs_root=gs://${GSBUCKET}/sax-fs-root   --alsologtostderr
ERROR: /home/gcastle_google_com/saxml/saxml/protobuf/BUILD:387:14: //saxml/protobuf:multimodal_proto: no such attribute 'use_java_stubby_library' in 'proto_library' rule
ERROR: /home/gcastle_google_com/saxml/saxml/common/BUILD:136:11: Target '//saxml/protobuf:admin_go_proto_grpc' contains an error and its package is in error and referenced by '//saxml/common:config'
ERROR: Analysis of target '//saxml/bin:admin_config' failed; build aborted: 
INFO: Elapsed time: 0.185s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded, 0 targets configured)
ERROR: Build failed. Not running target
