# Chroma-fp8-scaled / convert_fp8_scaled_stochastic.py
# Uploaded by Clybius ("Upload convert_fp8_scaled_stochastic.py", commit 043e8b2, verified).
import argparse
import os
import torch
import numpy as np
from safetensors import safe_open
from safetensors.torch import save_file
from typing import Dict, Tuple
# --- Configuration ---
# Keys containing these substrings will not be quantized if --t5xxl is set
AVOID_KEY_NAMES = ["norm", "bias", "embed_tokens", "shared"] #T5XXL, may need to be changed for other TEs.
# Target FP8 format
TARGET_FP8_DTYPE = torch.float8_e4m3fn
# Intermediate dtype for calculations
COMPUTE_DTYPE = torch.float64 # Don't think more hurts here since we're working tensor by tensor.
# Dtype for storing scale factors
SCALE_DTYPE = torch.float64 # Might be overkill, float32 should do just fine, but since these are so tiny may as well :3
# --- End Configuration ---
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    """Stochastically round the mantissa of each element of ``abs_x``.

    Args:
        abs_x: Non-negative magnitudes to round.
        exponent: Biased exponent for each element; elements where
            ``normal_mask`` is False use the subnormal encoding instead.
        normal_mask: Boolean tensor, True where the element is a normal number.
        MANTISSA_BITS: Number of explicit mantissa bits of the target format.
        EXPONENT_BIAS: Exponent bias of the target format.
        generator: Optional ``torch.Generator`` supplying the uniform noise.

    Returns:
        ``floor(units + U[0, 1)) / 2**MANTISSA_BITS`` where ``units`` is the
        magnitude measured in mantissa steps. For normal elements this is the
        fractional part of the significand times ``2**MANTISSA_BITS``. For
        subnormal elements it is ``abs_x`` divided by the smallest subnormal
        step. Adding uniform noise before flooring rounds up with probability
        equal to the fractional part of ``units``.
    """
    steps = 2 ** MANTISSA_BITS
    # Normal numbers: the significand is 1.m, so strip the implicit leading 1.
    normal_units = (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * steps
    # Subnormal numbers: count multiples of the smallest subnormal value.
    subnormal_units = abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))
    units = torch.where(normal_mask, normal_units, subnormal_units)
    noise = torch.rand(
        units.size(),
        dtype=units.dtype,
        layout=units.layout,
        device=units.device,
        generator=generator,
    )
    return (units + noise).floor() / steps
def manual_stochastic_round_to_float8(x, dtype, generator=None):
    """Stochastically round ``x`` onto the value grid of an FP8 format.

    Each element is split into sign, biased exponent and mantissa for the
    target format. ``calc_mantissa`` then rounds the mantissa up or down at
    random, with probability proportional to the distance to each neighbour.
    This makes the rounding (approximately) unbiased in expectation, up to the
    final clamp to the format's finite range.

    Args:
        x: Input tensor of real values.
        dtype: Target format, ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``.
        generator: Optional ``torch.Generator`` for the rounding noise.

    Returns:
        A ``torch.float16`` tensor shaped like ``x`` whose values lie on the FP8
        grid and are clamped to ``torch.finfo(dtype)``'s [min, max]. Every such
        value is exactly representable in float16, so the final cast is lossless.

    Raises:
        ValueError: If ``dtype`` is not one of the two supported FP8 formats.
    """
    if dtype == torch.float8_e4m3fn:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
    elif dtype == torch.float8_e5m2:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
    else:
        raise ValueError("Unsupported dtype")

    # Work in at least float32. Casting straight to float16 (as this function
    # used to) applied a deterministic round-to-nearest before the stochastic
    # step, and made the uniform noise added in calc_mantissa be summed at
    # float16 precision; both skew the rounding. float64 input is kept as-is.
    # x.to() never mutates the caller's tensor; the results below are fresh.
    work_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
    x = x.to(work_dtype)

    sign = torch.sign(x)
    abs_x = x.abs()
    sign = torch.where(abs_x == 0, 0, sign)

    # Biased exponent, clamped to the encodable range. 0 marks subnormals
    # (log2(0) = -inf also clamps to 0).
    exponent = torch.clamp(
        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
        0,
        2**EXPONENT_BITS - 1,
    )
    normal_mask = exponent != 0

    # Rounded mantissa fraction (k / 2**MANTISSA_BITS). A separate tensor, not
    # an abs_x[:] assignment, so 0-dim inputs work too.
    mantissa = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)

    magnitude = torch.where(
        normal_mask,
        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
        (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa,
    )
    limits = torch.finfo(dtype)
    result = torch.clamp(sign * magnitude, min=limits.min, max=limits.max)
    # Keep the historical float16 return dtype; the cast is exact for FP8 values.
    return result.half()
def stochastic_rounding(value, dtype=TARGET_FP8_DTYPE, seed=0):
    """Convert ``value`` to ``dtype``, using stochastic rounding for FP8 targets.

    For ``torch.float8_e4m3fn`` / ``torch.float8_e5m2`` the tensor is rounded by
    ``manual_stochastic_round_to_float8``. The rounding uses a generator seeded
    with ``seed`` on ``value``'s device, so results are reproducible per seed.
    Tensors are processed in slices along dim 0 of roughly 1536*1536 elements
    each (at least one row) to bound the size of the temporaries. Every other
    dtype gets a plain ``Tensor.to`` cast.

    Args:
        value: Tensor to convert.
        dtype: Target dtype; defaults to ``TARGET_FP8_DTYPE``.
        seed: Seed for the rounding noise (FP8 targets only).

    Returns:
        A tensor of dtype ``dtype`` with the same shape as ``value``.
    """
    if dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
        # float32 / float16 / bfloat16 and anything else: ordinary cast.
        return value.to(dtype=dtype)

    generator = torch.Generator(device=value.device)
    generator.manual_seed(seed)
    output = torch.empty_like(value, dtype=dtype)

    if value.dim() == 0:
        # A 0-dim tensor has no leading axis to slice (value.shape[0] would
        # raise), so round it as a 1-element tensor and restore the shape.
        rounded = manual_stochastic_round_to_float8(value.reshape(1), dtype, generator=generator)
        output.copy_(rounded.reshape(()))
        return output

    num_slices = max(1, value.numel() / (1536 * 1536))
    slice_size = max(1, round(value.shape[0] / num_slices))
    for start in range(0, value.shape[0], slice_size):
        stop = start + slice_size
        output[start:stop].copy_(
            manual_stochastic_round_to_float8(value[start:stop], dtype, generator=generator)
        )
    return output
def get_fp8_constants(fp8_dtype: torch.dtype) -> Tuple[float, float, float]:
"""Gets the min, max, and smallest positive normal value for a given FP8 dtype."""
finfo = torch.finfo(fp8_dtype)
# Smallest positive normal value approximation (may vary based on exact FP8 spec interpretation)
# For E4M3FN: exponent bias 7, smallest normal exp is -6. 1.0 * 2^-6 = 1/64
# Smallest subnormal is 2^-9 for E4M3FN from the paper. Let's use subnormal min.
# Find the smallest positive value representable (subnormal)
# This is tricky as finfo.tiny is often the smallest *normal*.
# Let's hardcode based on E4M3FN spec (S=0, E=0000, M=001) -> 2^-9
if fp8_dtype == torch.float8_e4m3fn:
fp8_min_pos = 2**-9 # Smallest subnormal for E4M3FN
elif fp8_dtype == torch.float8_e5m2:
# E5M2: exponent bias 15, smallest normal exp -14. Smallest subnormal 2^-16
fp8_min_pos = 2**-16 # Smallest subnormal for E5M2
else:
# Fallback using finfo.tiny (likely smallest normal)
fp8_min_pos = finfo.tiny * finfo.eps # A guess if unknown type
# Ensure min_pos is a Python float for consistency
fp8_min_pos = float(fp8_min_pos)
return float(finfo.min), float(finfo.max), fp8_min_pos
# Range limits of TARGET_FP8_DTYPE, computed once at import time: the most
# negative finite value, the largest finite value, and the smallest positive
# (subnormal) value.
FP8_MIN, FP8_MAX, FP8_MIN_POS = get_fp8_constants(fp8_dtype=TARGET_FP8_DTYPE)
def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will hold ``path``, if it has one and it is missing."""
    parent = os.path.dirname(path)
    # os.path.dirname() of a bare filename is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _quantize_weight(
    original_tensor: torch.Tensor, seed: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Scale one non-empty COMPUTE_DTYPE weight into the FP8 range and round it.

    Returns ``(quantized, dequant_scale, abs_max, scale_factor)``:
    - ``quantized``: the TARGET_FP8_DTYPE tensor.
    - ``dequant_scale``: its 0-dim SCALE_DTYPE dequantization factor
      (``1 / scale_factor``).
    - ``abs_max`` and ``scale_factor``: the 0-dim tensors that were used,
      returned for logging.
    """
    abs_max = torch.max(torch.abs(original_tensor))
    if abs_max < 1e-12:  # all-zero or nearly so: leave the tensor unscaled
        print(f" - Tensor has near-zero max value ({abs_max.item():.2e}). Using scale factor 1.0.")
        scale_factor = torch.tensor(1.0, dtype=COMPUTE_DTYPE)
        scaled_tensor = original_tensor
    else:
        # NOTE: this clamps abs_max in the *unscaled* domain, which caps
        # scale_factor at (FP8_MAX - FP8_MIN_POS) / FP8_MIN_POS. Tensors with
        # abs_max below FP8_MIN_POS are therefore not stretched to the full FP8
        # range (it does bound the dequant scale below by ~FP8_MIN_POS / FP8_MAX).
        abs_max = abs_max.clamp(min=FP8_MIN_POS)
        # Map the largest magnitude to just below FP8_MAX.
        scale_factor = (FP8_MAX - FP8_MIN_POS) / abs_max
        scaled_tensor = original_tensor.mul(scale_factor)
    clamped_tensor = torch.clamp(scaled_tensor, FP8_MIN, FP8_MAX)
    quantized = stochastic_rounding(clamped_tensor, seed=seed)
    # Convert explicitly so SCALE_DTYPE is honoured (the reciprocal alone would
    # stay in COMPUTE_DTYPE).
    dequant_scale = scale_factor.reciprocal().to(SCALE_DTYPE)
    return quantized, dequant_scale, abs_max, scale_factor


def convert_to_fp8_scaled(input_file: str, output_file: str, t5xxl: bool):
    """Convert a safetensors file to scaled FP8 weights using stochastic rounding.

    Every tensor is loaded onto the CPU. Keys ending in ``.weight`` are handled
    as follows. When ``t5xxl`` is set, keys containing an ``AVOID_KEY_NAMES``
    substring are copied unchanged. Every other weight is processed in steps:

    1. Convert it to COMPUTE_DTYPE and take its max absolute value.
    2. Multiply it by a scale factor mapping that max to just below FP8_MAX
       (1.0 for near-zero tensors), then clamp it to [FP8_MIN, FP8_MAX].
    3. Stochastically round it to TARGET_FP8_DTYPE. The seed is derived from
       the key, so runs are reproducible but tensors get independent noise.
    4. Store ``<name>.scale_weight`` as the 0-dim SCALE_DTYPE dequantization
       factor ``1 / scale_factor``. No ``.scale_input`` tensor is written.

    Non-weight tensors are copied unchanged. A ``scaled_fp8`` marker tensor is
    added (2 elements, or 0 elements when ``t5xxl`` is set), and the result is
    saved to ``output_file``, creating its directory if needed. Load and save
    errors are printed and end the function early; they are not raised.

    Args:
        input_file: Path of the source safetensors file.
        output_file: Path of the safetensors file to write.
        t5xxl: Skip quantizing weights whose key matches AVOID_KEY_NAMES.
    """
    import zlib  # stable (non-randomized) hash for per-tensor rounding seeds

    print(f"Processing: {input_file}")
    print(f"Output will be saved to: {output_file}")
    print(f"Using FP8 format: {TARGET_FP8_DTYPE}")
    print(f"FP8 Range: [{FP8_MIN}, {FP8_MAX}], Min Pos Subnormal: {FP8_MIN_POS:.2e}")
    print(f"Using Stochastic Rounding: True")

    # Load the original model onto the CPU to avoid GPU OOM on large models.
    tensors: Dict[str, torch.Tensor] = {}
    try:
        with safe_open(input_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key).cpu()
    except Exception as e:
        print(f"Error loading '{input_file}': {e}")
        return

    new_tensors: Dict[str, torch.Tensor] = {}
    weight_keys = sorted(key for key in tensors if key.endswith('.weight'))
    total_weights = len(weight_keys)
    skipped_count = 0
    processed_count = 0
    print(f"Found {total_weights} weight tensors to potentially process.")

    for i, key in enumerate(weight_keys):
        if t5xxl and any(avoid_name in key for avoid_name in AVOID_KEY_NAMES):
            print(f"({i+1}/{total_weights}) Skipping excluded tensor: {key}")
            new_tensors[key] = tensors[key]  # keep the original tensor
            skipped_count += 1
            continue

        print(f"({i+1}/{total_weights}) Processing tensor: {key}")
        processed_count += 1
        base_name = key[:-len('.weight')]
        scale_weight_key = f"{base_name}.scale_weight"

        original_tensor = tensors[key].to(COMPUTE_DTYPE)
        if original_tensor.numel() == 0:
            print(f" - Skipping empty tensor: {key}")
            new_tensors[key] = tensors[key].to(TARGET_FP8_DTYPE)  # empty FP8 tensor
            # Unit scale, 0-dim like the scales of non-empty weights.
            new_tensors[scale_weight_key] = torch.tensor(1.0, dtype=SCALE_DTYPE)
            continue

        # Seed from the key rather than a constant: with one fixed seed, every
        # tensor of a given shape would get the identical noise field, so
        # rounding errors would be correlated across layers.
        seed = zlib.crc32(key.encode("utf-8"))
        quantized, dequant_scale, abs_max, scale_factor = _quantize_weight(original_tensor, seed)
        new_tensors[key] = quantized
        new_tensors[scale_weight_key] = dequant_scale

        print(f" - Abs Max : {abs_max.item():.5}")
        print(f" - Scale Factor : {scale_factor.item():.5}")
        print(f" - Dequant Scale : {dequant_scale.item():.5}")

    added_scale_keys = {
        key for key in new_tensors if key.endswith((".scale_weight", ".scale_input"))
    }

    # Every weight key is already in new_tensors (quantized or skipped), so only
    # non-weight tensors remain to be copied over.
    for key, tensor in tensors.items():
        if key not in new_tensors and not key.endswith(".weight"):
            new_tensors[key] = tensor
            print(f"(+) Adding original non-weight tensor: {key}")

    # FP8 marker key for loaders that look for it (e.g. ComfyUI).
    new_tensors["scaled_fp8"] = torch.empty((2), dtype=TARGET_FP8_DTYPE) if not t5xxl else torch.empty((0), dtype=TARGET_FP8_DTYPE)

    print("-" * 40)
    print(f"Saving {len(new_tensors)} tensors to {output_file}")
    try:
        _ensure_parent_dir(output_file)
        save_file(new_tensors, output_file)
        print("Conversion complete!")
    except Exception as e:
        print(f"Error saving file '{output_file}': {e}")
        return

    final_tensor_count = len(new_tensors)
    original_tensor_count = len(tensors)
    print("-" * 40)
    print(f"Summary:")
    print(f" - Original tensor count : {original_tensor_count}")
    print(f" - Weight tensors found : {total_weights}")
    print(f" - Weights processed : {processed_count}")
    print(f" - Weights skipped : {skipped_count}")
    print(f" - Added scale tensors : {len(added_scale_keys)}")  # one .scale_weight per processed weight
    print(f" - Added marker tensor : 1")
    print(f" - Final tensor count : {final_tensor_count}")
    print("-" * 40)
def main():
    """Command-line entry point: parse arguments, choose the output path, convert.

    Without ``--output``, the output path is derived from the input path as
    ``<input without extension>_<fp8 dtype name>_scaled_stochastic.safetensors``.
    The function prints an error and returns without converting in two cases:
    the input file does not exist, or the output path resolves to the input
    file (it then suggests an alternative name).
    """
    parser = argparse.ArgumentParser(
        description=f"Convert safetensors weights to Scaled {TARGET_FP8_DTYPE} format using stochastic rounding.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input", type=str, required=True, help="Input safetensors file path.")
    parser.add_argument(
        "--output",
        type=str,
        help="Output safetensors file path. If not provided, generated based on input name.",
    )
    parser.add_argument(
        "--t5xxl",
        action="store_true",
        help="Exclude certain layers from quantization while quantizing T5XXL.",
    )
    args = parser.parse_args()

    if not os.path.exists(args.input):
        print(f"Error: Input file not found: {args.input}")
        return

    output_file = args.output
    if not output_file:
        dtype_tag = str(TARGET_FP8_DTYPE).split('.')[-1]  # e.g. float8_e4m3fn
        stem = os.path.splitext(args.input)[0]
        output_file = f"{stem}_{dtype_tag}_scaled_stochastic.safetensors"

    # Refuse to overwrite the input file.
    if os.path.abspath(args.input) == os.path.abspath(output_file):
        print("Error: Output file cannot be the same as the input file.")
        stem, ext = os.path.splitext(output_file)
        suggested = f"{stem}_converted{ext}"
        print(f"Suggestion: Use --output {suggested}")
        return

    convert_to_fp8_scaled(args.input, output_file, args.t5xxl)


if __name__ == "__main__":
    main()