Compressor / app.py
Nathan12's picture
update compressor
07cb9a1
import fasterai
from fasterai.sparse.all import *
from fasterai.prune.all import *
import torch
import gradio as gr
import os
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
class Quant():
    """Post-training static quantization using PyTorch FX graph mode.

    The qconfig mapping is chosen once, at construction time, from the
    default mapping for the requested backend.
    """

    def __init__(self, backend="x86"):
        """Store the default qconfig mapping for ``backend``."""
        self.qconfig = get_default_qconfig_mapping(backend)

    def quantize(self, model, example_inputs=None, calibration_batches=None):
        """Return a quantized copy of ``model``.

        Args:
            model: The ``torch.nn.Module`` to quantize. It is switched to
                eval mode before being prepared.
            example_inputs: Positional arguments for ``model.forward``,
                either as a tuple or as a single object (which is wrapped
                in a one-element tuple). Defaults to one random batched
                image of shape ``(1, 3, 224, 224)``.
            calibration_batches: Optional iterable of inputs that are fed
                through the prepared model (without gradients) so the
                observers can record activation statistics.

        Returns:
            The module produced by ``convert_fx``.
        """
        if example_inputs is None:
            # Vision models expect a leading batch dimension.
            example_inputs = (torch.randn(1, 3, 224, 224),)
        elif not isinstance(example_inputs, tuple):
            # prepare_fx expects a tuple of positional forward() arguments.
            example_inputs = (example_inputs,)

        model_prepared = prepare_fx(model.eval(), self.qconfig, example_inputs)

        # NOTE(review): when no calibration batches are given, the observers
        # never see data before conversion, so activation quantization
        # parameters are not derived from real statistics. That is enough for
        # measuring file size but presumably not for accuracy -- confirm
        # before using the result for inference.
        if calibration_batches is not None:
            with torch.no_grad():
                for batch in calibration_batches:
                    model_prepared(batch)

        return convert_fx(model_prepared)
def optimize_model(input_model, sparsity, context, criteria):
    """Sparsify, prune, quantize and serialize a pickled model.

    Args:
        input_model: Path to a file written with ``torch.save(model, ...)``.
        sparsity: Amount forwarded unchanged to ``Sparsifier.sparsify_model``
            and to ``Pruner``; the units are whatever fasterai expects.
        context: Sparsification/pruning context forwarded to fasterai
            (the caller in this file passes ``'local'``).
        criteria: Either a criteria object, or a string naming one that is
            resolved with ``eval`` (the caller in this file passes
            ``"large_final"``).

    Returns:
        The path ``"./comp_model.pth"``, where the TorchScript-serialized
        quantized model has been written (overwriting any previous file).
    """
    # SECURITY: weights_only=False unpickles arbitrary objects. Only load
    # files this application produced itself.
    model = torch.load(input_model, weights_only=False)
    model = model.eval().to('cpu')

    # SECURITY: eval() executes arbitrary code, so `criteria` strings must
    # never come from user input. Resolve once and reuse for both stages.
    criteria_obj = eval(criteria) if isinstance(criteria, str) else criteria

    # Stage 1: zero out whole filters according to the criteria.
    sp = Sparsifier(model, 'filter', context, criteria=criteria_obj)
    sp.sparsify_model(sparsity)
    sp._clean_buffers()

    # Stage 2: structurally prune the model.
    pr = Pruner(model, sparsity, context, criteria=criteria_obj)
    pr.prune_model()

    # Stage 3: quantize the pruned model.
    qu_model = Quant().quantize(model)

    # Serialize through TorchScript so the file can be measured on disk.
    comp_path = "./comp_model.pth"
    scripted = torch.jit.script(qu_model)
    torch.jit.save(scripted, comp_path)
    return comp_path
import matplotlib.pyplot as plt
import seaborn as sns
import io
import numpy as np
def get_model_size(model_path):
    """Return the size of the file at ``model_path`` in MB, rounded to 2 decimals.

    One MB is taken as 1024 * 1024 bytes.
    """
    bytes_per_mb = 1024 * 1024
    file_bytes = os.stat(model_path).st_size
    return round(file_bytes / bytes_per_mb, 2)
def create_size_comparison_plot(original_size, compressed_size):
    """Create a bar plot comparing the original and compressed model sizes.

    Args:
        original_size: Size of the original model in MB.
        compressed_size: Size of the compressed model in MB.

    Returns:
        The matplotlib figure holding the two-bar chart. The title reports
        the size reduction as a percentage of ``original_size`` (0.0% when
        ``original_size`` is zero). The figure is created through pyplot and
        is not closed here; the caller owns it.
    """
    # Seaborn style must be set before the axes are created to take effect.
    sns.set_style("darkgrid")

    # Higher DPI for better resolution; transparent figure and axes
    # backgrounds so the chart blends into the host page.
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    fig.patch.set_alpha(0.0)
    ax.patch.set_alpha(0.0)

    bars = ax.bar(
        ['Original', 'Compressed'],
        [original_size, compressed_size],
        color=['#FF6B00', '#FF9F1C'],
        alpha=0.8,
        width=0.6,
    )

    # Size label centred just above each bar (1% of the bar height higher).
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.,
            height + (height * 0.01),
            f'{height:.2f} MB',
            ha='center', va='bottom',
            fontsize=11,
            fontweight='bold',
            color='white',
        )

    # Percentage saved relative to the original; guard the empty-file case
    # so a zero-sized original does not raise ZeroDivisionError.
    if original_size:
        compression_ratio = ((original_size - compressed_size) / original_size) * 100
    else:
        compression_ratio = 0.0

    ax.set_title(
        f'Model Size Comparison\nCompression: {compression_ratio:.1f}%',
        fontsize=14,
        fontweight='bold',
        pad=20,
        color='white',
    )
    ax.set_xlabel(
        'Model Version',
        fontsize=12,
        fontweight='bold',
        labelpad=10,
        color='white',
    )
    ax.set_ylabel(
        'Size (MB)',
        fontsize=12,
        fontweight='bold',
        labelpad=10,
        color='white',
    )

    ax.grid(alpha=0.2, color='gray')

    # Remove the top and right spines.
    sns.despine(ax=ax)

    # 20% headroom above the tallest bar; fall back to a unit range when
    # both sizes are zero so the axis limits are not degenerate.
    max_value = max(original_size, compressed_size)
    y_top = max_value * 1.2 if max_value > 0 else 1.0
    ax.set_ylim(0, y_top)
    ax.set_yticks(np.linspace(0, y_top, 10))

    # White ticks, tick labels and spines for dark backgrounds.
    ax.tick_params(colors='white')
    for spine in ax.spines.values():
        spine.set_color('white')

    # One decimal place on the y-axis tick labels.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))

    # Prevent labels from being cut off.
    fig.tight_layout()
    return fig
def main_interface(model_name, sparsity, action=None):
    """Compress the selected torchvision model and plot the size change.

    Args:
        model_name: One of ``'ResNet18'``, ``'ResNet50'``, ``'MobileNetV2'``,
            ``'EfficientNet-B0'``, ``'VGG16'`` or ``'DenseNet121'``.
        sparsity: Compression level forwarded to ``optimize_model``.
        action: Unused. It is optional because the Gradio interface in this
            file supplies only two inputs; a required third parameter would
            make every call fail.

    Returns:
        The figure from ``create_size_comparison_plot`` comparing the size of
        the saved original model with the size of the compressed model.

    Raises:
        KeyError: If ``model_name`` is not one of the supported names.
    """
    import torchvision.models as models

    # Map names to constructors so that only the selected architecture is
    # instantiated. Called without arguments, the constructors build
    # randomly initialised models (no pretrained weights are downloaded).
    model_builders = {
        'ResNet18': models.resnet18,
        'ResNet50': models.resnet50,
        'MobileNetV2': models.mobilenet_v2,
        'EfficientNet-B0': models.efficientnet_b0,
        'VGG16': models.vgg16,
        'DenseNet121': models.densenet121,
    }
    model = model_builders[model_name]()

    # The model is written to disk so its size can be measured and so
    # optimize_model can load it by path.
    temp_path = "./temp_model.pth"
    try:
        torch.save(model, temp_path)
        original_size = get_model_size(temp_path)

        compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
        compressed_size = get_model_size(compressed_path)

        return create_size_comparison_plot(original_size, compressed_size)
    finally:
        # Clean up the temporary file, even if saving or compression failed.
        if os.path.exists(temp_path):
            os.remove(temp_path)
# Architectures offered in the UI; each name is a key that main_interface accepts.
available_models = [
    'ResNet18',
    'ResNet50',
    'MobileNetV2',
    'EfficientNet-B0',
    'VGG16',
    'DenseNet121',
]

# Input widgets, in the order their values are passed to main_interface.
_model_selector = gr.Dropdown(
    choices=available_models,
    label="Select Model",
    value='ResNet18',
)
_compression_slider = gr.Slider(
    label="Compression Level",
    minimum=0,
    maximum=100,
    value=50,
)

# The single output is the matplotlib figure returned by main_interface.
_size_plot = gr.Plot(label="Size Comparison")

iface = gr.Interface(
    fn=main_interface,
    inputs=[_model_selector, _compression_slider],
    outputs=[_size_plot],
)

# Start the web server as soon as the script runs.
iface.launch()