| | #pragma once |
| |
|
| | |
| | * Quantization utilities including: |
| | * Adjusted maximum values for qtypes. |
| | * Minimum scaling factors for qtypes. |
| | */ |
| |
|
| | #include <cmath> |
| | #include <torch/types.h> |
| |
|
| | #ifndef USE_ROCM |
| | #include <c10/util/Float8_e4m3fn.h> |
| | #define MAYBE_HOST_DEVICE C10_HOST_DEVICE |
| | #else |
| | #include <ATen/hip/HIPContext.h> |
| | #include <c10/util/Float8_e4m3fn.h> |
| | #include <c10/util/Float8_e4m3fnuz.h> |
| | |
| | #define MAYBE_HOST_DEVICE |
| | #endif |
| |
|
| | template <typename T, |
| | typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> || |
| | std::is_same_v<T, c10::Float8_e4m3fnuz> || |
| | std::is_same_v<T, int8_t>>> |
| | struct quant_type_max { |
| | static constexpr T val() { return std::numeric_limits<T>::max(); } |
| | }; |
| |
|
| | |
| | |
| | template <> |
| | struct quant_type_max<c10::Float8_e4m3fnuz> { |
| | static constexpr c10::Float8_e4m3fnuz val() { |
| | return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits()); |
| | } |
| | }; |
| |
|
| | template <typename T> |
| | MAYBE_HOST_DEVICE static constexpr T quant_type_max_v = |
| | quant_type_max<T>::val(); |
| |
|
| | template <typename T, |
| | typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> || |
| | std::is_same_v<T, c10::Float8_e4m3fnuz> || |
| | std::is_same_v<T, int8_t>>> |
| | struct min_scaling_factor { |
| | C10_DEVICE C10_ALWAYS_INLINE static float val() { |
| | return 1.0f / (quant_type_max_v<T> * 512.0f); |
| | } |
| | }; |
| |
|
| | template <> |
| | struct min_scaling_factor<int8_t> { |
| | C10_DEVICE C10_ALWAYS_INLINE static float val() { |
| | return std::numeric_limits<float>::epsilon(); |
| | } |
| | }; |