#!/usr/bin/env python3
"""Test the saved compressed models"""
import torch
import torch.nn as nn
import os

print("="*70)
print(" "*10 + "TESTING SAVED COMPRESSED MODELS")
print("="*70)

# Test MLP model
print("\n1. Testing MLP models:")
print("-"*40)
# Load original
original_mlp = torch.load("compressed_models/mlp_original_fp32.pth")
print(f"✅ Loaded original MLP: {os.path.getsize('compressed_models/mlp_original_fp32.pth')/1024:.1f} KB")

# Load compressed
compressed_mlp = torch.load("compressed_models/mlp_compressed_int8.pth")
print(f"✅ Loaded compressed MLP: {os.path.getsize('compressed_models/mlp_compressed_int8.pth')/1024:.1f} KB")
# Recreate model and test
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
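# The architecture must match the one used when the checkpoint was saved,
# otherwise load_state_dict() will fail with missing/unexpected keys or shape errors.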
model.load_state_dict(original_mlp['model_state_dict'])

# Test inference
test_input = torch.randn(1, 784)
with torch.no_grad():
    output = model(test_input)
print(f" Original output shape: {output.shape}")
print(f" Prediction: {torch.argmax(output).item()}")
# For quantized model, we need to recreate and quantize
model_quant = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
model_quant.eval()
model_quant = torch.quantization.quantize_dynamic(model_quant, {nn.Linear}, dtype=torch.qint8)
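# Dynamic quantization must be applied before loading, so the module's state dict
# has the packed INT8 parameter keys that the saved quantized checkpoint expects.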
model_quant.load_state_dict(compressed_mlp['model_state_dict'])
with torch.no_grad():
    output_quant = model_quant(test_input)
print(f" Compressed output shape: {output_quant.shape}")
print(f" Prediction: {torch.argmax(output_quant).item()}")
print("\n✅ Both models work and produce valid outputs!") | |