import fasterai
# The fasterai star imports presumably provide Sparsifier, Pruner and the
# criteria objects (e.g. large_final) used below. They stay above the explicit
# imports so that the explicit bindings take precedence on any name clash.
from fasterai.sparse.all import *
from fasterai.prune.all import *
import io
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class Quant:
    """Post-training static int8 quantization using torch.ao FX graph mode.

    Args:
        backend: backend name passed to ``get_default_qconfig_mapping``
            (default ``"x86"``).
    """

    def __init__(self, backend="x86"):
        self.qconfig = get_default_qconfig_mapping(backend)

    def quantize(self, model, example_inputs=None):
        """Trace, calibrate and convert *model* to its int8 form.

        Args:
            model: an FX-traceable ``nn.Module``; it is switched to eval mode.
            example_inputs: tuple of positional inputs used for tracing and
                for the calibration pass. Defaults to one random
                ``(1, 3, 224, 224)`` batch, which assumes an image classifier
                taking 224x224 RGB input.

        Returns:
            The module produced by ``convert_fx``.
        """
        if example_inputs is None:
            # prepare_fx expects a *tuple* of positional args, and the model
            # needs a batch dimension (the original passed a bare 3-D tensor).
            example_inputs = (torch.randn(1, 3, 224, 224),)
        model_prepared = prepare_fx(model.eval(), self.qconfig, example_inputs)
        # One forward pass so the inserted observers have seen data before
        # conversion. Random input is only a placeholder: use representative
        # images if the quantized model's accuracy matters.
        with torch.no_grad():
            model_prepared(*example_inputs)
        return convert_fx(model_prepared)


def optimize_model(input_model, sparsity, context, criteria, output_path="./comp_model.pth"):
    """Sparsify, prune and quantize a saved model, then save it as TorchScript.

    Args:
        input_model: path to a file written with ``torch.save(model)`` (a whole
            pickled module, not a state_dict).
        sparsity: target forwarded to both ``Sparsifier.sparsify_model`` and
            ``Pruner``. The UI supplies values in 0-100; TODO confirm that
            fasterai expects a percentage here.
        context: fasterai context string (this app passes ``'local'``).
        criteria: a fasterai criteria object, or the *name* of one (e.g.
            ``"large_final"``), which is resolved with ``eval``.
        output_path: destination of the scripted, quantized model.

    Returns:
        ``output_path``.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects and can run
    # code; only load files that this application produced itself.
    model = torch.load(input_model, weights_only=False)
    model = model.eval().to('cpu')

    if isinstance(criteria, str):
        # SECURITY: eval() on a string. Never route user-supplied text here.
        criteria = eval(criteria)

    sparsifier = Sparsifier(model, 'filter', context, criteria=criteria)
    sparsifier.sparsify_model(sparsity)
    # Private fasterai helper; presumably drops the mask buffers registered by
    # the Sparsifier. Verify against the fasterai version in use.
    sparsifier._clean_buffers()

    pruner = Pruner(model, sparsity, context, criteria=criteria)
    pruner.prune_model()

    quantized = Quant().quantize(model)

    scripted = torch.jit.script(quantized)
    torch.jit.save(scripted, output_path)
    return output_path


def get_model_size(model_path):
    """Return the size of the file at *model_path* in MB (1024**2 bytes), rounded to 2 decimals."""
    size_bytes = os.path.getsize(model_path)
    size_mb = size_bytes / (1024 * 1024)
    return round(size_mb, 2)


def create_size_comparison_plot(original_size, compressed_size):
    """Create a bar chart comparing the original and compressed model sizes.

    Args:
        original_size: size of the original model in MB.
        compressed_size: size of the compressed model in MB.

    Returns:
        A matplotlib ``Figure`` with a transparent background and white text.
        The figure is closed in pyplot before returning, so repeated calls do
        not accumulate open figures. The object itself can still be saved or
        rendered.
    """
    sns.set_style("darkgrid")

    # Higher DPI for a sharper image in the web UI.
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    fig.patch.set_alpha(0.0)
    ax.patch.set_alpha(0.0)

    bars = ax.bar(['Original', 'Compressed'], [original_size, compressed_size],
                  color=['#FF6B00', '#FF9F1C'], alpha=0.8, width=0.6)

    # Size label just above each bar.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2.0, height + height * 0.01,
                f'{height:.2f} MB', ha='center', va='bottom', fontsize=11,
                fontweight='bold', color='white')

    # Guard against division by zero for an empty original file.
    if original_size > 0:
        compression_ratio = ((original_size - compressed_size) / original_size) * 100
    else:
        compression_ratio = 0.0

    ax.set_title(f'Model Size Comparison\nCompression: {compression_ratio:.1f}%',
                 fontsize=14, fontweight='bold', pad=20, color='white')
    ax.set_xlabel('Model Version', fontsize=12, fontweight='bold', labelpad=10, color='white')
    ax.set_ylabel('Size (MB)', fontsize=12, fontweight='bold', labelpad=10, color='white')

    ax.grid(alpha=0.2, color='gray')
    sns.despine(ax=ax)

    # 20% headroom above the tallest bar; fall back to 1.0 so the axis range
    # is never degenerate when both sizes are 0.
    max_value = max(original_size, compressed_size)
    y_top = max_value * 1.2 if max_value > 0 else 1.0
    ax.set_ylim(0, y_top)
    ax.set_yticks(np.linspace(0, y_top, 10))
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda value, _pos: f'{value:.1f}'))

    ax.tick_params(colors='white')
    for spine in ax.spines.values():
        spine.set_color('white')

    fig.tight_layout()
    # Detach from pyplot's figure manager to avoid leaking one figure per request.
    plt.close(fig)
    return fig


# Display name -> torchvision.models builder function name.
_MODEL_FACTORIES = {
    'ResNet18': 'resnet18',
    'ResNet50': 'resnet50',
    'MobileNetV2': 'mobilenet_v2',
    'EfficientNet-B0': 'efficientnet_b0',
    'VGG16': 'vgg16',
    'DenseNet121': 'densenet121',
}


def main_interface(model_name, sparsity, action=None):
    """Compress an untrained torchvision model and plot its size before and after.

    Args:
        model_name: one of ``available_models``.
        sparsity: target sparsity forwarded to ``optimize_model``.
        action: unused. It defaults to ``None`` because the Gradio interface
            supplies only two inputs.

    Returns:
        The matplotlib figure from ``create_size_comparison_plot``.

    Raises:
        KeyError: if *model_name* is not a known model.
    """
    import torchvision.models as models

    # Build only the requested architecture (random weights) instead of
    # instantiating every model on each request.
    model = getattr(models, _MODEL_FACTORIES[model_name])(weights=None)

    # A private scratch directory prevents concurrent requests from clobbering
    # each other's files and removes both files afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        temp_path = os.path.join(workdir, "temp_model.pth")
        comp_path = os.path.join(workdir, "comp_model.pth")
        torch.save(model, temp_path)
        original_size = get_model_size(temp_path)
        compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final",
                                         output_path=comp_path)
        compressed_size = get_model_size(compressed_path)

    return create_size_comparison_plot(original_size, compressed_size)


available_models = list(_MODEL_FACTORIES)

iface = gr.Interface(
    fn=main_interface,
    inputs=[
        gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
        gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50),
    ],
    outputs=[
        gr.Plot(label="Size Comparison"),
    ],
)

# NOTE(review): launching at import time is kept for compatibility; consider
# guarding with `if __name__ == "__main__":` if this module is ever imported.
iface.launch()