import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import io
from PIL import Image


class W8A16LinearLayer(nn.Module):
    """Linear layer that stores int8 weights with per-output-channel scales (W8A16)."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        # Placeholder int8 weights; torch.randint's upper bound is exclusive,
        # so 128 covers the full int8 range [-128, 127].
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )

        # One scale per output channel (one per row of the weight matrix).
        self.register_buffer("scales",
                             torch.randn(out_features, dtype=dtype))

        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features),
                                             dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        """
        Quantize floating-point weights to int8 precision.

        Args:
            weights: Tensor of weights to quantize (shape: out_features x in_features)

        Returns:
            (int8_weights, scales): the quantized weights and per-channel scales
            (both are also stored in the module's buffers)
        """
        w_fp32 = weights.clone().to(torch.float32)

        # Per-output-channel absmax scaling: the largest magnitude in each row maps to 127.
        scales = w_fp32.abs().max(dim=-1).values / 127
        # Guard against all-zero rows, which would otherwise cause division by zero.
        scales = scales.clamp(min=torch.finfo(torch.float32).eps)
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

        return int8_weights, scales
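
    # Illustrative example of the per-channel scheme above (comment only, not
    # executed; values are approximate): for the row [0.2, -0.4, 0.1] the scale
    # is 0.4 / 127 ≈ 0.00315, the row quantizes to roughly [64, -127, 32], and
    # it dequantizes back to about [0.2016, -0.4, 0.1008].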

    def forward(self, input):
        """
        Forward pass through the quantized linear layer.

        Args:
            input: Input tensor (shape: batch_size x seq_len x in_features)

        Returns:
            output: Output tensor after the linear transformation
        """
        # Cast the int8 weights up to the activation dtype, run the matmul,
        # then rescale each output channel by its quantization scale.
        casted_weights = self.int8_weights.to(input.dtype)
        output = F.linear(input, casted_weights) * self.scales

        if self.bias is not None:
            output = output + self.bias

        return output
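

# A minimal sketch of how a layer like this is typically used in practice:
# quantize the weights of an existing nn.Linear and swap the module in.
# This helper is illustrative only (it is not called anywhere in this app),
# and the name `linear_to_w8a16` is our own.
def linear_to_w8a16(linear: nn.Linear) -> W8A16LinearLayer:
    """Build a W8A16LinearLayer from an existing nn.Linear (illustrative sketch)."""
    q_layer = W8A16LinearLayer(
        linear.in_features,
        linear.out_features,
        bias=linear.bias is not None,
        dtype=linear.weight.dtype,
    )
    # Quantize the pretrained weights into the int8 buffer and per-channel scales.
    q_layer.quantize(linear.weight.data)
    if linear.bias is not None:
        q_layer.bias = linear.bias.data.reshape(1, -1).clone()
    return q_layer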


def plot_weight_matrix(weights, title="Weight Matrix"):
    """Create a heatmap visualization of weight matrices"""
    plt.figure(figsize=(10, 8))

    # Symmetric color range centered on zero.
    vmax = max(abs(weights.min().item()), abs(weights.max().item()))
    if vmax == 0:
        # TwoSlopeNorm requires vmin < vcenter < vmax, so guard against an all-zero matrix.
        vmax = 1.0
    vmin = -vmax
    norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

    # Convert to float32 first so bfloat16 tensors (which numpy cannot handle)
    # also plot correctly.
    plt.imshow(weights.detach().float().cpu().numpy(), cmap='RdBu_r', norm=norm)
    plt.colorbar(label='Weight Value')
    plt.title(title)

    # Render the figure to an in-memory PNG for Gradio.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)

    return Image.open(buf)


def plot_weight_distribution(weights, title="Weight Distribution"):
    """Create a histogram visualization of weight distributions"""
    plt.figure(figsize=(10, 6))

    # Convert to float32 first so bfloat16 tensors (which numpy cannot handle)
    # can also be plotted.
    flat_weights = weights.detach().flatten().float().cpu().numpy()

    plt.hist(flat_weights, bins=50, alpha=0.7, color='blue')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.grid(alpha=0.3)

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)

    return Image.open(buf)


def calculate_quantization_error(original, quantized, scales):
    """Calculate error metrics between original and dequantized weights"""
    # Dequantize: int8 weights times the per-channel scales.
    dequantized = quantized.float() * scales.float().unsqueeze(1)

    # Per-element rounding error is bounded by half a quantization step
    # (scale / 2) for each output channel.
    abs_error = (original.float() - dequantized).abs()
    max_error = abs_error.max().item()
    mean_error = abs_error.mean().item()

    return max_error, mean_error, dequantized


def initialize_model(in_features, out_features, with_bias, dtype_str):
    """Initialize a new quantized linear layer model"""
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[dtype_str]

    model = W8A16LinearLayer(in_features, out_features, bias=with_bias, dtype=dtype)

    # Random weights stand in for a real layer's parameters.
    random_weights = torch.randn((out_features, in_features), dtype=dtype)

    weights_vis = plot_weight_matrix(random_weights, "Original Weights")
    dist_vis = plot_weight_distribution(random_weights, "Original Weight Distribution")

    int8_weights, scales = model.quantize(random_weights)

    q_weights_vis = plot_weight_matrix(int8_weights, "Quantized Weights (INT8)")
    q_dist_vis = plot_weight_distribution(int8_weights, "Quantized Weight Distribution")

    max_error, mean_error, dequantized = calculate_quantization_error(
        random_weights, int8_weights, scales
    )

    deq_weights_vis = plot_weight_matrix(dequantized, "Dequantized Weights")

    error = (random_weights - dequantized).abs()
    error_vis = plot_weight_matrix(error, "Quantization Error (Absolute)")

    # Approximate storage per weight element: 1 byte of int8 plus the amortized
    # bytes of the per-channel scales, compared against the original dtype.
    quant_bytes_per_element = (
        int8_weights.element_size()
        + scales.element_size() * scales.numel() / int8_weights.numel()
    )
    memory_savings = 100 * (1 - quant_bytes_per_element / random_weights.element_size())

    model_info = f"""
## Model Configuration
- Input Features: {in_features}
- Output Features: {out_features}
- Bias: {"Yes" if with_bias else "No"}
- Data Type: {dtype_str}

## Quantization Stats
- Original Weights Shape: {random_weights.shape}
- Quantized Weights Shape: {int8_weights.shape}
- Scales Shape: {scales.shape}
- Maximum Quantization Error: {max_error:.6f}
- Mean Quantization Error: {mean_error:.6f}
- Memory Savings: {memory_savings:.2f}%
"""

    # Run a sample forward pass through the quantized layer.
    sample_input = torch.randn(1, in_features, dtype=dtype)
    sample_output = model(sample_input)

    io_info = f"""
## Sample Input/Output
- Input Shape: {sample_input.shape}
- Output Shape: {sample_output.shape}
- Output Range: [{sample_output.min().item():.4f}, {sample_output.max().item():.4f}]
"""

    return model_info, io_info, weights_vis, q_weights_vis, deq_weights_vis, dist_vis, q_dist_vis, error_vis


def quantize_custom_weights(in_features, out_features, with_bias, dtype_str, weight_pattern):
    """Quantize custom weights based on the selected pattern"""
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[dtype_str]

    model = W8A16LinearLayer(in_features, out_features, bias=with_bias, dtype=dtype)

    # Build a weight matrix with the requested structure.
    if weight_pattern == "random":
        custom_weights = torch.randn((out_features, in_features), dtype=dtype)
    elif weight_pattern == "eye":
        # Identity in the top-left block, zeros elsewhere.
        custom_weights = torch.zeros((out_features, in_features), dtype=dtype)
        min_dim = min(out_features, in_features)
        custom_weights[:min_dim, :min_dim] = torch.eye(min_dim, dtype=dtype)
    elif weight_pattern == "ones":
        custom_weights = torch.ones((out_features, in_features), dtype=dtype)
    elif weight_pattern == "alternating":
        # Checkerboard of +1/-1: entries where (row + col) is odd become -1.
        custom_weights = torch.ones((out_features, in_features), dtype=dtype)
        rows = torch.arange(out_features).unsqueeze(1)
        cols = torch.arange(in_features).unsqueeze(0)
        custom_weights[(rows + cols) % 2 == 1] = -1.0
    elif weight_pattern == "gradient":
        # Smooth ramp across both dimensions.
        x = torch.linspace(-1, 1, in_features)
        y = torch.linspace(-1, 1, out_features)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        custom_weights = (xx + yy).t().to(dtype)

    weights_vis = plot_weight_matrix(custom_weights, f"Original Weights ({weight_pattern})")
    dist_vis = plot_weight_distribution(custom_weights, "Original Weight Distribution")

    int8_weights, scales = model.quantize(custom_weights)

    q_weights_vis = plot_weight_matrix(int8_weights, "Quantized Weights (INT8)")
    q_dist_vis = plot_weight_distribution(int8_weights, "Quantized Weight Distribution")

    max_error, mean_error, dequantized = calculate_quantization_error(
        custom_weights, int8_weights, scales
    )

    deq_weights_vis = plot_weight_matrix(dequantized, "Dequantized Weights")

    error = (custom_weights - dequantized).abs()
    error_vis = plot_weight_matrix(error, "Quantization Error (Absolute)")

    # Same per-element storage estimate as in initialize_model.
    quant_bytes_per_element = (
        int8_weights.element_size()
        + scales.element_size() * scales.numel() / int8_weights.numel()
    )
    memory_savings = 100 * (1 - quant_bytes_per_element / custom_weights.element_size())

    quant_info = f"""
## Quantization Details
- Original Data Type: {dtype_str}
- Quantized Data Type: int8 (8-bit)
- Weight Pattern: {weight_pattern}

## Error Analysis
- Maximum Quantization Error: {max_error:.6f}
- Mean Quantization Error: {mean_error:.6f}
- Memory Savings: {memory_savings:.2f}%

## Tensor Shapes
- Original Weights: {custom_weights.shape}
- Quantized Weights: {int8_weights.shape}
- Quantization Scales: {scales.shape}
"""

    # Histogram of the per-channel quantization scales.
    plt.figure(figsize=(10, 6))
    plt.hist(scales.detach().float().cpu().numpy(), bins=30, alpha=0.7, color='green')
    plt.xlabel('Scale Value')
    plt.ylabel('Frequency')
    plt.title('Distribution of Quantization Scales')
    plt.grid(alpha=0.3)

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    scales_vis = Image.open(buf)

    return quant_info, weights_vis, q_weights_vis, deq_weights_vis, dist_vis, q_dist_vis, error_vis, scales_vis


with gr.Blocks(title="8-Bit Weight Quantizer") as demo:
    gr.Markdown("# PyTorch 8-Bit Weight Quantizer")
    gr.Markdown("""
This tool demonstrates quantization of neural network weights to INT8 precision.
It implements a custom `W8A16LinearLayer` that uses 8-bit weights with 16-bit activations.
""")

    with gr.Tabs():
        with gr.TabItem("Initialize Model"):
            with gr.Row():
                with gr.Column():
                    in_feat = gr.Slider(minimum=1, maximum=512, value=16, step=1, label="Input Features")
                    out_feat = gr.Slider(minimum=1, maximum=512, value=32, step=1, label="Output Features")
                    with_bias = gr.Checkbox(value=True, label="Include Bias")
                    dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], value="float32", label="Data Type")
                    init_btn = gr.Button("Initialize Model")

                with gr.Column():
                    model_info = gr.Markdown()
                    io_info = gr.Markdown()

            with gr.Row():
                orig_weights = gr.Image(label="Original Weights")
                quant_weights = gr.Image(label="Quantized Weights (INT8)")
                dequant_weights = gr.Image(label="Dequantized Weights")

            with gr.Row():
                orig_dist = gr.Image(label="Original Weight Distribution")
                quant_dist = gr.Image(label="Quantized Weight Distribution")
                error_vis = gr.Image(label="Quantization Error")

        with gr.TabItem("Custom Quantization"):
            with gr.Row():
                with gr.Column():
                    c_in_feat = gr.Slider(minimum=1, maximum=512, value=16, step=1, label="Input Features")
                    c_out_feat = gr.Slider(minimum=1, maximum=512, value=32, step=1, label="Output Features")
                    c_with_bias = gr.Checkbox(value=True, label="Include Bias")
                    c_dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], value="float32", label="Data Type")
                    weight_pattern = gr.Dropdown(
                        choices=["random", "eye", "ones", "alternating", "gradient"],
                        value="random",
                        label="Weight Pattern"
                    )
                    quantize_btn = gr.Button("Quantize Weights")

                with gr.Column():
                    quant_details = gr.Markdown()

            with gr.Row():
                c_orig_weights = gr.Image(label="Original Weights")
                c_quant_weights = gr.Image(label="Quantized Weights (INT8)")
                c_dequant_weights = gr.Image(label="Dequantized Weights")

            with gr.Row():
                c_orig_dist = gr.Image(label="Original Weight Distribution")
                c_quant_dist = gr.Image(label="Quantized Weight Distribution")
                c_error_vis = gr.Image(label="Quantization Error")

            with gr.Row():
                scales_dist = gr.Image(label="Quantization Scales Distribution")

        with gr.TabItem("About"):
            gr.Markdown("""
## 8-bit Quantizer Implementation

This implementation includes:

1. **W8A16LinearLayer** - A PyTorch module that uses INT8 weights and FP16/BF16/FP32 activations
2. **Quantization** - Converts FP32/FP16/BF16 weights to INT8 using per-output-channel scaling
3. **Visualization** - Shows the impact of quantization on weight distributions and errors

### How It Works:

1. For each output channel, find the maximum absolute weight value
2. Scale all weights in that channel so the maximum fits in the INT8 range (-128 to 127)
3. Round the scaled weights to integers and store them as INT8
4. During inference, multiply the INT8 weights by the scaling factors to recover approximate FP values (see the code sketch below)
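
A minimal sketch of these steps in code, mirroring what `W8A16LinearLayer.quantize` does
for a floating-point weight matrix `w`:

```python
scales = w.abs().max(dim=-1).values / 127                      # steps 1-2: per-channel scale
int8_w = torch.round(w / scales.unsqueeze(1)).to(torch.int8)   # step 3: round and store as INT8
w_approx = int8_w.float() * scales.unsqueeze(1)                # step 4: dequantize at inference
```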

The quantization process reduces memory usage by up to 75% compared to FP32 weights.

### References:

- This implementation is based on modern techniques used in LLM quantization
- Similar methods are used in libraries like bitsandbytes, AutoGPTQ, and GPTQ-for-LLaMa
""")

    # Wire the buttons to their callbacks (event listeners must be registered
    # inside the Blocks context).
    init_btn.click(
        initialize_model,
        inputs=[in_feat, out_feat, with_bias, dtype],
        outputs=[model_info, io_info, orig_weights, quant_weights, dequant_weights, orig_dist, quant_dist, error_vis]
    )

    quantize_btn.click(
        quantize_custom_weights,
        inputs=[c_in_feat, c_out_feat, c_with_bias, c_dtype, weight_pattern],
        outputs=[quant_details, c_orig_weights, c_quant_weights, c_dequant_weights, c_orig_dist, c_quant_dist, c_error_vis, scales_dist]
    )


if __name__ == "__main__":
    demo.launch()