
Comments (10)

robertgshaw2-neuralmagic commented on June 8, 2024

I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but only an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from checkpoints directly.

But we could have an fp16 model with scales for the kv cache saved in the safetensors files:

model.layers.0.attn.k_proj.weight
model.layers.0.attn.k_cache.act_scale

So we would not need to support the JSON case.


pcmoritz commented on June 8, 2024

Thanks a lot for putting together this RFC! This sounds like a solid plan to me.

Some more detailed comments:

  • Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV cache are very similar to the scaling factors for activations, so it is very natural to handle them the same way, and the safetensors / model checkpoints are the most natural place to store them.
  • Have you considered not having a vllm::Fp8KVCacheDataType::kAuto datatype and instead resolving the kv cache dtype to the right dtype in the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity and should just be able to use a concrete type, right? (unless I'm missing something)


pcmoritz commented on June 8, 2024

Ah I see, in that case, kAuto is a good name since it is the same as "auto" in Python. I didn't realize it required a special code path :)


comaniac commented on June 8, 2024

Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV cache are very similar to the scaling factors for activations, so it is very natural to handle them the same way, and the safetensors / model checkpoints are the most natural place to store them.

I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but only an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from checkpoints directly.

Have you considered not having a vllm::Fp8KVCacheDataType::kAuto datatype and instead resolving the kv cache dtype to the right dtype in the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity and should just be able to use a concrete type, right? (unless I'm missing something)

I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during the kv-cache write (cache_kernel) and read (attention_kernel), so the quantization has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to them. Additionally, since we currently use uint8_t to store fp8 values, we cannot tell whether a function like uint8_t convert(uint16_t& a) { ... } is for E4M3 or E5M2, which makes function overloading problematic. If we could use c10::Float8_e4m3fn directly as the kv-cache data type, we could just implement c10::Float8_e4m3fn convert(scalar_t& a) and c10::Float8_e5m2 convert(scalar_t& a) and let C++ find the right one.
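
To make this concrete, here is a minimal sketch of the distinction (illustrative names only, not the actual vLLM code; it assumes the PyTorch c10 FP8 headers are available):

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

// Conversions are selected by the cache element type Tout.
template <typename Tout, typename Tin>
Tout convert(const Tin& x);

// With typed FP8 cache elements, each format is a distinct Tout, so explicit
// specializations are enough to pick the right conversion:
template <>
c10::Float8_e4m3fn convert<c10::Float8_e4m3fn, float>(const float& x) {
  return c10::Float8_e4m3fn(x);
}

template <>
c10::Float8_e5m2 convert<c10::Float8_e5m2, float>(const float& x) {
  return c10::Float8_e5m2(x);
}

// With uint8_t storage, both formats would need the identical specialization
// uint8_t convert<uint8_t, float>(const float&), so the format has to be
// carried separately, e.g. as the Fp8KVCacheDataType template parameter.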


pcmoritz commented on June 8, 2024

I'm +1 to supporting activation scales in the FP16 checkpoint and not in JSON. This way fewer configurations need to be supported and everything is uniform :)


pcmoritz commented on June 8, 2024

I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during the kv-cache write (cache_kernel) and read (attention_kernel), so the quantization has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to them. Additionally, since we currently use uint8_t to store fp8 values, we cannot tell whether a function like uint8_t convert(uint16_t& a) { ... } is for E4M3 or E5M2, which makes function overloading problematic. If we could use c10::Float8_e4m3fn directly as the kv-cache data type, we could just implement c10::Float8_e4m3fn convert(scalar_t& a) and c10::Float8_e5m2 convert(scalar_t& a) and let C++ find the right one.

Sounds good! I get why the data type needs to be passed through; what I don't really get is why "auto" needs to be handled in C++ -- it seems to me it could be mapped to a concrete type in Python, and then C++ would only need to handle the concrete types. But I might be wrong, feel free to do what feels most natural while implementing this :)


comaniac commented on June 8, 2024

Thanks for the valuable feedback!

  • [Interface] I am actually supportive of embedding the kv-cache scaling factors in the model checkpoint (with 1.0 used if we cannot find them in the checkpoint). After all, the kv-cache scaling factors are always tightly coupled with the model itself. Let's try to align with other folks on this interface.
  • [kAuto] I see your point. I agree that the term kAuto is misleading; in fact, it should be something like kNoFp8. The reason to keep this flag is mainly for the case when FP8 is not enabled (e.g., CPU or CUDA<8.0). In that case we completely disable the FP8-related kernels, so we need a compile-time flag to take a special path. For example:
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
  // Make sure we don't call "fp8::scaled_convert" when FP8 is disabled.
  key_cache[tgt_key_idx] = tgt_key;
  value_cache[tgt_value_idx] = tgt_value;
} else {
  key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
  value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}

Then we could have:

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
  switch (kv_dt) {
#ifdef ENABLE_FP8
  case Fp8KVCacheDataType::kFp8E4m3:
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  case Fp8KVCacheDataType::kFp8E5m2:
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
#endif
  default: // Only this branch is left when FP8 is disabled, so always fail.
    assert(false);
  }
}


HaiShaw commented on June 8, 2024

w.r.t. We cannot enable e4m3 and e5m2 in the same build.
If we look to have a build with both formats supported on the same newer hardware, we most likely won't need them to function simultaneously; that only adds complexity with no use case, since only e4m3 would be used in forward and inference computations. On the contrary, e5m2 would only be practically feasible on older hardware, as a storage type with acceptable performance from mantissa rounding; having e4m3 enabled on older hardware isn't beneficial, neither for the cost of the cast nor for computation (no hardware support). Finally, having a single build that is generic across generations of GPUs seems unnecessary.


HaiShaw commented on June 8, 2024

w.r.t. When running FP8 model, we load kv-cache scaling factor from the model checkpoint.
We should have a serialized checkpoint with the various scaling factors defined, covering both the stationary scaling factors (for weights, at whatever granularity) and the updatable scaling factors (activations, KV caches). For the latter we need to define the update process, with the quantizer flow included.


HaiShaw commented on June 8, 2024

w.r.t. A wrapper function convert that converts Tin to Tout with a particular FP8 format. For example, when writing values to kv-cache, Tin=uint16_t, Tout=uint8_t, kv_dt=kFp8E4M3

Over time, it makes more sense to phase out uint8_t and move to the torch FP8 types; then kv_dt would be unnecessary.
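
As a rough sketch of what that could look like (hypothetical code under that assumption, not the current implementation, and assuming the source value is an ordinary arithmetic type rather than a raw uint16_t bit pattern), the cache element type itself would select the conversion:

#include <type_traits>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

// Hypothetical: cache elements use the real FP8 types, so the format is
// already encoded in Tout and no Fp8KVCacheDataType parameter is needed.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  if constexpr (std::is_same_v<Tout, c10::Float8_e4m3fn> ||
                std::is_same_v<Tout, c10::Float8_e5m2>) {
    // Quantize: the target FP8 format is implied by Tout.
    return Tout(static_cast<float>(x) / scale);
  } else {
    // FP8 disabled (the former kAuto path): store the value unconverted.
    return static_cast<Tout>(x);
  }
}

// The cache-write site would then shrink to one line per cache, e.g.:
//   key_cache[tgt_key_idx] = scaled_convert<cache_t, scalar_t>(tgt_key, kv_scale);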

