|
import argparse |
|
import os |
|
import torch |
|
import numpy as np |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
from typing import Dict, Tuple |
|
|
|
|
|
|
|
AVOID_KEY_NAMES = ["norm", "bias", "embed_tokens", "shared"] |
|
|
|
TARGET_FP8_DTYPE = torch.float8_e4m3fn |
|
|
|
COMPUTE_DTYPE = torch.float64 |
|
|
|
SCALE_DTYPE = torch.float64 |
|
|
|
|
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): |
|
mantissa_scaled = torch.where( |
|
normal_mask, |
|
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), |
|
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) |
|
) |
|
|
|
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator) |
|
return mantissa_scaled.floor() / (2**MANTISSA_BITS) |
|
|
|
|
|
def manual_stochastic_round_to_float8(x, dtype, generator=None): |
|
if dtype == torch.float8_e4m3fn: |
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7 |
|
elif dtype == torch.float8_e5m2: |
|
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15 |
|
else: |
|
raise ValueError("Unsupported dtype") |
|
|
|
x = x.half() |
|
sign = torch.sign(x) |
|
abs_x = x.abs() |
|
sign = torch.where(abs_x == 0, 0, sign) |
|
|
|
|
|
exponent = torch.clamp( |
|
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, |
|
0, 2**EXPONENT_BITS - 1 |
|
) |
|
|
|
|
|
normal_mask = ~(exponent == 0) |
|
|
|
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator) |
|
|
|
sign *= torch.where( |
|
normal_mask, |
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), |
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x |
|
) |
|
|
|
inf = torch.finfo(dtype) |
|
torch.clamp(sign, min=inf.min, max=inf.max, out=sign) |
|
return sign |
|
|
|
|
|
|
|
def stochastic_rounding(value, dtype=TARGET_FP8_DTYPE, seed=0): |
|
if dtype == torch.float32: |
|
return value.to(dtype=torch.float32) |
|
if dtype == torch.float16: |
|
return value.to(dtype=torch.float16) |
|
if dtype == torch.bfloat16: |
|
return value.to(dtype=torch.bfloat16) |
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: |
|
generator = torch.Generator(device=value.device) |
|
generator.manual_seed(seed) |
|
output = torch.empty_like(value, dtype=dtype) |
|
num_slices = max(1, (value.numel() / (1536 * 1536))) |
|
slice_size = max(1, round(value.shape[0] / num_slices)) |
|
for i in range(0, value.shape[0], slice_size): |
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)) |
|
|
|
return output |
|
|
|
return value.to(dtype=dtype) |
|
|
|
def get_fp8_constants(fp8_dtype: torch.dtype) -> Tuple[float, float, float]: |
|
"""Gets the min, max, and smallest positive normal value for a given FP8 dtype.""" |
|
finfo = torch.finfo(fp8_dtype) |
|
|
|
|
|
|
|
|
|
|
|
|
|
if fp8_dtype == torch.float8_e4m3fn: |
|
fp8_min_pos = 2**-9 |
|
elif fp8_dtype == torch.float8_e5m2: |
|
|
|
fp8_min_pos = 2**-16 |
|
else: |
|
|
|
fp8_min_pos = finfo.tiny * finfo.eps |
|
|
|
|
|
fp8_min_pos = float(fp8_min_pos) |
|
|
|
return float(finfo.min), float(finfo.max), fp8_min_pos |
|
|
|
|
|
FP8_MIN, FP8_MAX, FP8_MIN_POS = get_fp8_constants(TARGET_FP8_DTYPE) |
|
|
|
def convert_to_fp8_scaled(input_file: str, output_file: str, t5xxl: bool): |
|
""" |
|
Converts a safetensors file to a version with FP8 scaled weights using stochastic rounding. |
|
|
|
For each tensor ending with '.weight' (unless excluded): |
|
1. Calculates a scale factor based on the tensor's max absolute value. |
|
2. Scales the tensor to fit within the FP8 range [-FP8_MAX, FP8_MAX]. |
|
3. Clamps the scaled tensor. |
|
4. Applies stochastic rounding during quantization to TARGET_FP8_DTYPE. |
|
5. Stores the quantized tensor. |
|
6. Stores '.scale_weight' tensor: the factor to dequantize the weight (1.0 / scale_factor). |
|
7. Stores '.scale_input' tensor: the factor to dequantize the input (using 1.0 / scale_factor as proxy). |
|
""" |
|
print(f"Processing: {input_file}") |
|
print(f"Output will be saved to: {output_file}") |
|
print(f"Using FP8 format: {TARGET_FP8_DTYPE}") |
|
print(f"FP8 Range: [{FP8_MIN}, {FP8_MAX}], Min Pos Subnormal: {FP8_MIN_POS:.2e}") |
|
print(f"Using Stochastic Rounding: True") |
|
|
|
|
|
tensors: Dict[str, torch.Tensor] = {} |
|
try: |
|
with safe_open(input_file, framework="pt", device="cpu") as f: |
|
for key in f.keys(): |
|
|
|
tensors[key] = f.get_tensor(key).cpu() |
|
except Exception as e: |
|
print(f"Error loading '{input_file}': {e}") |
|
return |
|
|
|
|
|
new_tensors: Dict[str, torch.Tensor] = {} |
|
|
|
|
|
weight_keys = sorted([key for key in tensors.keys() if key.endswith('.weight')]) |
|
total_weights = len(weight_keys) |
|
skipped_count = 0 |
|
processed_count = 0 |
|
|
|
print(f"Found {total_weights} weight tensors to potentially process.") |
|
|
|
for i, key in enumerate(weight_keys): |
|
process_this_key = True |
|
if t5xxl: |
|
for avoid_name in AVOID_KEY_NAMES: |
|
if avoid_name in key: |
|
print(f"({i+1}/{total_weights}) Skipping excluded tensor: {key}") |
|
|
|
new_tensors[key] = tensors[key] |
|
process_this_key = False |
|
skipped_count += 1 |
|
break |
|
|
|
if not process_this_key: |
|
continue |
|
|
|
print(f"({i+1}/{total_weights}) Processing tensor: {key}") |
|
processed_count += 1 |
|
|
|
|
|
original_tensor = tensors[key].to(COMPUTE_DTYPE) |
|
|
|
if original_tensor.numel() == 0: |
|
print(f" - Skipping empty tensor: {key}") |
|
new_tensors[key] = tensors[key].to(TARGET_FP8_DTYPE) |
|
|
|
base_name = key[:-len('.weight')] |
|
scale_weight_key = f"{base_name}.scale_weight" |
|
dequant_scale = torch.tensor([1.0], dtype=SCALE_DTYPE) |
|
new_tensors[scale_weight_key] = dequant_scale.detach().clone() |
|
continue |
|
|
|
|
|
abs_max = torch.max(torch.abs(original_tensor)) |
|
|
|
if abs_max < 1e-12: |
|
print(f" - Tensor has near-zero max value ({abs_max.item():.2e}). Using scale factor 1.0.") |
|
scale_factor = torch.tensor(1.0, dtype=COMPUTE_DTYPE) |
|
scaled_tensor = original_tensor |
|
else: |
|
|
|
abs_max = abs_max.clamp(min=FP8_MIN_POS) |
|
scale_factor = (FP8_MAX - FP8_MIN_POS) / abs_max |
|
|
|
scaled_tensor = original_tensor.mul(scale_factor) |
|
|
|
|
|
|
|
clamped_tensor = torch.clamp(scaled_tensor, FP8_MIN, FP8_MAX) |
|
|
|
|
|
quantized_fp8_tensor = stochastic_rounding(clamped_tensor) |
|
|
|
|
|
new_tensors[key] = quantized_fp8_tensor |
|
|
|
|
|
dequant_scale = scale_factor.reciprocal() |
|
|
|
|
|
base_name = key[:-len('.weight')] |
|
scale_weight_key = f"{base_name}.scale_weight" |
|
|
|
|
|
|
|
new_tensors[scale_weight_key] = dequant_scale.detach().clone() |
|
|
|
|
|
print(f" - Abs Max : {abs_max.item():.5}") |
|
print(f" - Scale Factor : {scale_factor.item():.5}") |
|
print(f" - Dequant Scale : {dequant_scale.item():.5}") |
|
|
|
|
|
added_scale_keys = set() |
|
for key in new_tensors: |
|
if key.endswith(".scale_weight") or key.endswith(".scale_input"): |
|
added_scale_keys.add(key) |
|
|
|
original_keys = set(tensors.keys()) |
|
processed_weight_keys = set(k for k, v in new_tensors.items() if k.endswith(".weight")) |
|
|
|
for key, tensor in tensors.items(): |
|
|
|
is_weight = key.endswith(".weight") |
|
if key not in new_tensors: |
|
if not is_weight: |
|
|
|
new_tensors[key] = tensor |
|
print(f"(+) Adding original non-weight tensor: {key}") |
|
|
|
|
|
new_tensors["scaled_fp8"] = torch.empty((2), dtype=TARGET_FP8_DTYPE) if not t5xxl else torch.empty((0), dtype=TARGET_FP8_DTYPE) |
|
|
|
|
|
print("-" * 40) |
|
print(f"Saving {len(new_tensors)} tensors to {output_file}") |
|
try: |
|
|
|
os.makedirs(os.path.dirname(output_file), exist_ok=True) |
|
|
|
|
|
save_file(new_tensors, output_file) |
|
print("Conversion complete!") |
|
except Exception as e: |
|
print(f"Error saving file '{output_file}': {e}") |
|
return |
|
|
|
|
|
final_tensor_count = len(new_tensors) |
|
original_tensor_count = len(tensors) |
|
added_tensors_count = final_tensor_count - original_tensor_count |
|
added_scales_count = len(added_scale_keys) |
|
|
|
print("-" * 40) |
|
print(f"Summary:") |
|
print(f" - Original tensor count : {original_tensor_count}") |
|
print(f" - Weight tensors found : {total_weights}") |
|
print(f" - Weights processed : {processed_count}") |
|
print(f" - Weights skipped : {skipped_count}") |
|
print(f" - Added scale tensors : {added_scales_count}") |
|
print(f" - Added marker tensor : 1") |
|
print(f" - Final tensor count : {final_tensor_count}") |
|
print("-" * 40) |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser( |
|
description=f"Convert safetensors weights to Scaled {TARGET_FP8_DTYPE} format using stochastic rounding.", |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter |
|
) |
|
parser.add_argument( |
|
"--input", |
|
type=str, |
|
required=True, |
|
help="Input safetensors file path." |
|
) |
|
parser.add_argument( |
|
"--output", |
|
type=str, |
|
help="Output safetensors file path. If not provided, generated based on input name." |
|
) |
|
parser.add_argument( |
|
"--t5xxl", |
|
action='store_true', |
|
help=f"Exclude certain layers from quantization while quantizing T5XXL." |
|
) |
|
args = parser.parse_args() |
|
|
|
input_file = args.input |
|
output_file = args.output |
|
t5xxl = args.t5xxl |
|
|
|
if not os.path.exists(input_file): |
|
print(f"Error: Input file not found: {input_file}") |
|
return |
|
|
|
fp8_type_str = TARGET_FP8_DTYPE.__str__().split('.')[-1] |
|
|
|
if not output_file: |
|
|
|
base_name = os.path.splitext(input_file)[0] |
|
output_file = f"{base_name}_{fp8_type_str}_scaled_stochastic.safetensors" |
|
|
|
|
|
if os.path.abspath(input_file) == os.path.abspath(output_file): |
|
print("Error: Output file cannot be the same as the input file.") |
|
|
|
base, ext = os.path.splitext(output_file) |
|
output_file = f"{base}_converted{ext}" |
|
print(f"Suggestion: Use --output {output_file}") |
|
return |
|
|
|
convert_to_fp8_scaled(input_file, output_file, t5xxl) |
|
|
|
if __name__ == "__main__": |
|
main() |