Comments (10)
> I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but FP8 kv-cache only (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from checkpoints directly.

But we could have an FP16 model with the kv-cache scales saved in the safetensors files:

```
model.layers.0.attn.k_proj.weight
model.layers.0.attn.k_cache.act_scale
```

So we would not need to support the JSON case.
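A minimal sketch of how that lookup could work, assuming the (hypothetical) key naming above; a plain dict stands in for the tensors that `safetensors` would load from the checkpoint:

```python
# Sketch: look up per-layer kv-cache scaling factors from checkpoint tensors,
# falling back to 1.0 when the scale is absent (e.g., a plain FP16 checkpoint
# without kv-cache scales). Key names mirror the hypothetical layout above.

def kv_cache_scale(tensors: dict, layer: int, proj: str) -> float:
    # e.g. "model.layers.0.attn.k_cache.act_scale"
    key = f"model.layers.{layer}.attn.{proj}_cache.act_scale"
    return float(tensors.get(key, 1.0))

checkpoint = {
    "model.layers.0.attn.k_proj.weight": "...",  # FP16 weight tensor
    "model.layers.0.attn.k_cache.act_scale": 0.023,
}

assert kv_cache_scale(checkpoint, 0, "k") == 0.023  # scale found in checkpoint
assert kv_cache_scale(checkpoint, 0, "v") == 1.0    # absent -> default 1.0
```

The default of 1.0 matches the fallback discussed later in this thread for checkpoints that carry no scaling factors.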
from vllm.
Thanks a lot for putting together this RFC! This sounds like a solid plan to me.
Some more detailed comments:
- Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV cache are very similar to the scaling factors for activations, so it is natural to handle them the same way, and the safetensors / model checkpoints are the most natural place to store them.
- Have you considered not having a `vllm::Fp8KVCacheDataType::kAuto` datatype and instead resolving the kv-cache dtype to the right dtype in the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity and should just be able to use a concrete type, right? (unless I'm missing something)
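The suggested Python-side resolution could look like the following sketch; the function name and dtype strings are illustrative, not vLLM's actual API:

```python
# Sketch: resolve the user-facing kv-cache dtype string to a concrete dtype
# in Python, so the C++ kernels would only ever see concrete types.

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: str) -> str:
    if kv_cache_dtype == "auto":
        # "auto" means: use the model's own dtype for the kv-cache.
        return model_dtype
    if kv_cache_dtype in ("fp8_e4m3", "fp8_e5m2"):
        return kv_cache_dtype
    raise ValueError(f"unsupported kv-cache dtype: {kv_cache_dtype}")

assert resolve_kv_cache_dtype("auto", "float16") == "float16"
assert resolve_kv_cache_dtype("fp8_e5m2", "float16") == "fp8_e5m2"
```

As the discussion below notes, this works for picking the storage type, but the kernels still need a compile-time flag for the "FP8 disabled" path.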
from vllm.
Ah I see, in that case `kAuto` is a good name since it is the same as "auto" in Python. I didn't realize it required a special code path :)
from vllm.
> Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV cache are very similar to the scaling factors for activations, so it is natural to handle them the same way, and the safetensors / model checkpoints are the most natural place to store them.

I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but FP8 kv-cache only (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from checkpoints directly.

> Have you considered not having a `vllm::Fp8KVCacheDataType::kAuto` datatype and instead resolving the kv-cache dtype to the right dtype in the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity and should just be able to use a concrete type, right? (unless I'm missing something)

I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during kv-cache writes (`cache_kernel`) and reads (`attention_kernel`), so the quantization kernel has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to them. Additionally, since we now use `uint8_t` to store fp8 values, we cannot differentiate whether a function like `uint8_t convert(uint16_t& a) { ... }` is for E4M3 or E5M2. This also makes function overloading problematic. If we could use `c10::Float8_e4m3fn` as the kv-cache data type directly, we could just implement `c10::Float8_e4m3fn convert(scalar_t& a)` and `c10::Float8_e5m2 convert(scalar_t& a)` and let C++ find the right one.
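The ambiguity is easy to demonstrate outside C++ as well: the same raw byte decodes to different values under E4M3 and E5M2, so the format has to travel alongside the data. A pure-Python sketch (normal-range values only, no denormals or specials):

```python
# Sketch: the same uint8 bit pattern means different values in E4M3 vs E5M2,
# so a uint8 buffer cannot self-describe its fp8 format -- the format must be
# passed separately (as the kv_dt template parameter does in the kernels).

def fp8_bits_to_float(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    assert exp != 0, "denormals not handled in this sketch"
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

b = 0x48  # one byte, two meanings:
assert fp8_bits_to_float(b, exp_bits=4, man_bits=3, bias=7) == 4.0   # as E4M3
assert fp8_bits_to_float(b, exp_bits=5, man_bits=2, bias=15) == 8.0  # as E5M2
```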
from vllm.
I'm +1 to supporting activation scales in the FP16 checkpoint and not in JSON. This way fewer configurations need to be supported and everything is uniform :)
from vllm.
> I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during kv-cache writes (`cache_kernel`) and reads (`attention_kernel`), so the quantization kernel has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to them. Additionally, since we now use `uint8_t` to store fp8 values, we cannot differentiate whether a function like `uint8_t convert(uint16_t& a) { ... }` is for E4M3 or E5M2. This also makes function overloading problematic. If we could use `c10::Float8_e4m3fn` as the kv-cache data type directly, we could just implement `c10::Float8_e4m3fn convert(scalar_t& a)` and `c10::Float8_e5m2 convert(scalar_t& a)` and let C++ find the right one.

Sounds good! I get why the data type needs to be passed through; what I don't really get is why "auto" needs to be handled in C++ -- it seems to me it could be mapped to a concrete type in Python, and then C++ would only need to handle the concrete types. But I might be wrong, feel free to do what feels most natural while implementing this :)
from vllm.
Thanks for the valuable feedback!
- [Interface] I am actually supportive of embedding the kv-cache scaling factor in the model checkpoint (with 1.0 used if we cannot find it in the checkpoint). After all, the kv-cache scaling factors are always tightly coupled with the model itself. Let's try to align with other folks on this interface.
- [kAuto] I see your point. I agree that the term `kAuto` is misleading; in fact, it should be something like `kNoFp8`. The reason to keep this flag is mainly for the case when FP8 is not enabled (e.g., CPU or CUDA < 8.0). In this case we completely disable the FP8-related kernels, so we need a compile-time flag to take a special path. For example:
```cpp
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
  // Make sure we don't call "fp8::scaled_convert" when FP8 is disabled.
  key_cache[tgt_key_idx] = tgt_key;
  value_cache[tgt_value_idx] = tgt_value;
} else {
  key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
  value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}
```
Then we could have:
```cpp
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  switch (kv_dt) {
#ifdef ENABLE_FP8
    case Fp8KVCacheDataType::kFp8E4m3:
      return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
    case Fp8KVCacheDataType::kFp8E5m2:
      return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
#endif
    default:
      // Only this branch is left when FP8 is disabled, so always raise an error.
      assert(false);
  }
}
```
from vllm.
w.r.t. "We cannot enable e4m3 and e5m2 in the same build."

Even if a build supported both formats on the same newer hardware, we would most likely never need both to function simultaneously; that increases complexity with no use case, since only e4m3 would be used in forward and inference computations. Conversely, e5m2 is only practically useful on older hardware, as a storage type with acceptable accuracy from mantissa rounding; enabling e4m3 on older hardware brings no benefit, neither in cast cost nor in computation (no hardware support). Finally, a single build that is generic across GPU generations seems unnecessary.
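The trade-off above comes down to the two formats' numeric envelopes; a quick back-of-the-envelope check using the standard E4M3FN and E5M2 encodings:

```python
# E4M3 (e4m3fn): 4 exponent bits (bias 7), 3 mantissa bits; the all-ones
# exponent field is reused for NaN, so the max finite value has mantissa 0b110.
e4m3_max = (1 + 6 / 8) * 2.0 ** (15 - 7)   # 1.75 * 256

# E5M2: 5 exponent bits (bias 15), 2 mantissa bits; IEEE-style inf/NaN,
# so the max finite exponent field is 0b11110.
e5m2_max = (1 + 3 / 4) * 2.0 ** (30 - 15)  # 1.75 * 32768

assert e4m3_max == 448.0    # more mantissa -> better accuracy for compute
assert e5m2_max == 57344.0  # more range -> safer as a pure storage type
```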
from vllm.
w.r.t. "When running an FP8 model, we load the kv-cache scaling factor from the model checkpoint."

The serialized checkpoint should define both kinds of scaling factors: the stationary ones (for weights, at whatever granularity) and the updatable ones (activations, KV caches). For the latter, we also need to define the update process, with the quantizer flow included.
from vllm.
A wrapper function `convert` that converts `Tin` to `Tout` with a particular FP8 format. For example, when writing values to the kv-cache: `Tin=uint16_t`, `Tout=uint8_t`, `kv_dt=kFp8E4M3`.

Over time, it makes more sense to rule out `uint8_t` and move to the torch fp8 types; then `kv_dt` would be unnecessary.
from vllm.