# Spaces: Running
import fasterai | |
from fasterai.sparse.all import * | |
from fasterai.prune.all import * | |
import torch | |
import gradio as gr | |
import os | |
from torch.ao.quantization import get_default_qconfig_mapping | |
import torch.ao.quantization.quantize_fx as quantize_fx | |
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx | |
class Quant:
    """Post-training quantization helper built on the FX graph-mode API.

    Args:
        backend: Backend name handed to ``get_default_qconfig_mapping`` to
            pick the default quantization configuration (``"x86"`` unless
            overridden).
    """

    def __init__(self, backend="x86"):
        self.qconfig = get_default_qconfig_mapping(backend)

    def quantize(self, model):
        """Prepare ``model`` with ``prepare_fx`` and convert it with ``convert_fx``.

        The model is put in eval mode before being prepared.

        NOTE(review): no calibration data is run between prepare and convert,
        so activation observers never see real inputs. That appears to be
        acceptable for the size comparison this file performs, but confirm
        before using the result for inference.

        Args:
            model: Module to quantize; assumed to take image batches of shape
                (N, 3, 224, 224) -- TODO confirm for models other than the
                torchvision ones used in this file.

        Returns:
            The object returned by ``convert_fx``.
        """
        # prepare_fx takes the example inputs as a tuple of positional
        # arguments, and the models expect a leading batch dimension.
        example_inputs = (torch.randn(1, 3, 224, 224),)
        model_prepared = prepare_fx(model.eval(), self.qconfig, example_inputs)
        return convert_fx(model_prepared)
def optimize_model(input_model, sparsity, context, criteria):
    """Sparsify, prune and quantize a saved model, then save it as TorchScript.

    Args:
        input_model: Path (or anything else ``torch.load`` accepts) of a
            model saved with ``torch.save``.
        sparsity: Amount of sparsity; forwarded unchanged to
            ``Sparsifier.sparsify_model`` and to ``Pruner``.
        context: Forwarded unchanged to ``Sparsifier`` and ``Pruner``.
        criteria: Either the name of a criteria object as a string (it is
            evaluated in this module's namespace) or the criteria object
            itself.

    Returns:
        Path of the saved TorchScript file, ``"./comp_model.pth"``.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # call this with files produced by this application.
    model = torch.load(input_model, weights_only=False)
    model = model.eval().to('cpu')

    # SECURITY: eval() executes whatever string it is given; never pass
    # user-controlled text here. Callers can hand in the criteria object
    # directly to bypass eval() entirely.
    criterion = eval(criteria) if isinstance(criteria, str) else criteria

    sp = Sparsifier(model, 'filter', context, criteria=criterion)
    sp.sparsify_model(sparsity)
    sp._clean_buffers()

    pr = Pruner(model, sparsity, context, criteria=criterion)
    pr.prune_model()

    qu_model = Quant().quantize(model)

    # The quantized model is scripted so it can be saved with torch.jit.save.
    comp_path = "./comp_model.pth"
    scripted = torch.jit.script(qu_model)
    torch.jit.save(scripted, comp_path)
    return comp_path
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import io | |
import numpy as np | |
def get_model_size(model_path):
    """Return the on-disk size of ``model_path`` in MB, rounded to 2 decimals.

    One MB is taken as 1024 * 1024 bytes.
    """
    bytes_per_mb = 1024 * 1024
    return round(os.path.getsize(model_path) / bytes_per_mb, 2)
def create_size_comparison_plot(original_size, compressed_size):
    """Create a bar plot comparing the original and compressed model sizes.

    The chart has a transparent background with white text, and its title
    reports the size reduction as a percentage of the original size.

    Args:
        original_size: Size of the original model in MB.
        compressed_size: Size of the compressed model in MB.

    Returns:
        The matplotlib figure containing the styled bar chart.
    """
    text_color = 'white'

    sns.set_style("darkgrid")

    # Higher DPI for better resolution; explicit fig/ax handles avoid relying
    # on pyplot's implicit "current figure" state.
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)

    # Transparent figure and axes backgrounds.
    fig.patch.set_alpha(0.0)
    ax.patch.set_alpha(0.0)

    bars = ax.bar(
        ['Original', 'Compressed'],
        [original_size, compressed_size],
        color=['#FF6B00', '#FF9F1C'],
        alpha=0.8,
        width=0.6,
    )

    # Size label just above each bar.
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.,
            height + (height * 0.01),
            f'{height:.2f} MB',
            ha='center',
            va='bottom',
            fontsize=11,
            fontweight='bold',
            color=text_color,
        )

    # Guard against a zero-sized original, which would divide by zero.
    if original_size:
        compression_ratio = ((original_size - compressed_size) / original_size) * 100
    else:
        compression_ratio = 0.0

    ax.set_title(
        f'Model Size Comparison\nCompression: {compression_ratio:.1f}%',
        fontsize=14,
        fontweight='bold',
        pad=20,
        color=text_color,
    )
    ax.set_xlabel(
        'Model Version',
        fontsize=12,
        fontweight='bold',
        labelpad=10,
        color=text_color,
    )
    ax.set_ylabel(
        'Size (MB)',
        fontsize=12,
        fontweight='bold',
        labelpad=10,
        color=text_color,
    )

    ax.grid(alpha=0.2, color='gray')

    # Remove the top and right spines.
    sns.despine(ax=ax)

    # Leave 20% headroom above the tallest bar; fall back to a unit range
    # when both sizes are zero so the axis limits are not degenerate.
    max_value = max(original_size, compressed_size)
    y_upper = max_value * 1.2 if max_value > 0 else 1.0
    ax.set_ylim(0, y_upper)
    ax.set_yticks(np.linspace(0, y_upper, 10))

    # White ticks (both axes) and spines to match the white text.
    ax.tick_params(colors=text_color)
    for spine in ax.spines.values():
        spine.set_color(text_color)

    # One decimal place on the y-axis tick labels.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))

    # Prevent labels from being cut off.
    fig.tight_layout()
    return fig
def main_interface(model_name, sparsity, action=None):
    """Compress the selected model and return a size-comparison figure.

    A freshly initialised (untrained) torchvision model is saved to a
    temporary file, passed through ``optimize_model``, and the sizes of the
    original and compressed files are plotted.

    Args:
        model_name: One of 'ResNet18', 'ResNet50', 'MobileNetV2',
            'EfficientNet-B0', 'VGG16' or 'DenseNet121'.
        sparsity: Compression level forwarded to ``optimize_model``.
        action: Unused. It defaults to ``None`` because the Gradio interface
            in this file supplies only two inputs.

    Returns:
        The figure produced by ``create_size_comparison_plot``.

    Raises:
        KeyError: If ``model_name`` is not a supported model.
    """
    import torchvision.models as models

    # Map names to constructors so only the requested architecture is built,
    # instead of instantiating all six models on every call.
    model_builders = {
        'ResNet18': models.resnet18,
        'ResNet50': models.resnet50,
        'MobileNetV2': models.mobilenet_v2,
        'EfficientNet-B0': models.efficientnet_b0,
        'VGG16': models.vgg16,
        'DenseNet121': models.densenet121,
    }
    # weights=None replaces the deprecated pretrained=False.
    model = model_builders[model_name](weights=None)

    temp_path = "./temp_model.pth"
    try:
        # Saving happens inside the try so a partially written file is also
        # cleaned up if torch.save fails.
        torch.save(model, temp_path)
        original_size = get_model_size(temp_path)
        compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
        compressed_size = get_model_size(compressed_path)
        return create_size_comparison_plot(original_size, compressed_size)
    finally:
        # Clean up the temporary file.
        if os.path.exists(temp_path):
            os.remove(temp_path)
# Model names offered in the dropdown.
available_models = [
    'ResNet18',
    'ResNet50',
    'MobileNetV2',
    'EfficientNet-B0',
    'VGG16',
    'DenseNet121',
]

# Input and output widgets, created in the order they appear in the UI.
model_selector = gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18')
compression_slider = gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50)
size_plot_output = gr.Plot(label="Size Comparison")

iface = gr.Interface(
    fn=main_interface,
    inputs=[model_selector, compression_slider],
    outputs=[size_plot_output],
)
iface.launch()