File size: 5,619 Bytes
7ab4f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e167cea
7ab4f00
 
 
 
 
64cd77c
 
7ab4f00
 
 
 
 
 
 
 
 
 
 
6a77094
 
 
 
7ab4f00
6a77094
 
 
 
 
7ab4f00
6a77094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ab4f00
6a77094
 
 
 
13cf301
 
 
 
 
 
6a77094
 
 
 
 
 
 
 
 
 
 
e167cea
6a77094
 
 
07cb9a1
6a77094
 
 
 
7ab4f00
 
6a77094
7ab4f00
 
 
6a77094
 
 
 
 
 
 
7ab4f00
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import fasterai
from fasterai.sparse.all import *
from fasterai.prune.all import *
import torch
import gradio as gr
import os
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class Quant():
    """Post-training static quantization using torch.ao FX graph mode."""

    def __init__(self, backend="x86"):
        """Build the default qconfig mapping for ``backend`` (e.g. "x86")."""
        self.qconfig = get_default_qconfig_mapping(backend)

    def quantize(self, model, example_inputs=None, calibration_data=None):
        """Quantize ``model`` and return the converted model.

        Args:
            model: Module to quantize. It is switched to eval mode in place.
            example_inputs: Tuple of positional inputs used when preparing
                (tracing) the model. A bare tensor is wrapped into a 1-tuple.
                Defaults to one random ``(1, 3, 224, 224)`` batch -- this
                assumes an image model taking 3x224x224 NCHW input.
            calibration_data: Optional iterable of input tuples (or single
                tensors) run through the prepared model so the observers can
                record activation ranges. Defaults to ``[example_inputs]``.

        Returns:
            The quantized module produced by ``convert_fx``.
        """
        if example_inputs is None:
            # Include a batch dimension and pass a tuple, as prepare_fx expects.
            example_inputs = (torch.randn(1, 3, 224, 224),)
        elif isinstance(example_inputs, torch.Tensor):
            example_inputs = (example_inputs,)

        model_prepared = prepare_fx(model.eval(), self.qconfig, example_inputs)

        # Calibration: without at least one forward pass the observers have
        # seen no data, and convert_fx would fall back to default
        # scale/zero-point values for every activation.
        batches = [example_inputs] if calibration_data is None else calibration_data
        with torch.no_grad():
            for batch in batches:
                if isinstance(batch, torch.Tensor):
                    batch = (batch,)
                model_prepared(*batch)

        return convert_fx(model_prepared)


def _resolve_criteria(criteria):
    """Return a fasterai criteria object for ``criteria``.

    Non-string values are returned unchanged. A string is looked up by name
    in this module's globals (which include everything pulled in by the
    ``from fasterai... import *`` lines), so plain names such as
    ``"large_final"`` resolve exactly as ``eval`` did, but no arbitrary
    expression is ever executed.

    Raises:
        ValueError: If the name is not defined in this module.
    """
    if not isinstance(criteria, str):
        return criteria
    try:
        return globals()[criteria]
    except KeyError:
        raise ValueError(f"Unknown criteria name: {criteria!r}") from None


def optimize_model(input_model, sparsity, context, criteria, comp_path="./comp_model.pth"):
    """Sparsify, prune and quantize a saved model; save it as TorchScript.

    Args:
        input_model: Path to a file holding a whole pickled ``nn.Module``
            (as written by ``torch.save(model, path)``).
        sparsity: Forwarded unchanged to ``Sparsifier.sparsify_model`` and
            ``Pruner``. The UI supplies 0-100, presumably a percentage --
            confirm against the fasterai docs.
        context: fasterai context string (the app passes ``'local'``).
        criteria: A fasterai criteria object, or the name of one defined in
            this module (e.g. ``"large_final"``).
        comp_path: Where the scripted, quantized model is written.

    Returns:
        ``comp_path``.
    """
    # NOTE(review): weights_only=False performs full unpickling, which can
    # execute arbitrary code -- only load trusted files. It is needed because
    # the file stores a whole module object, not just a state_dict.
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
    model = torch.load(input_model, weights_only=False, map_location='cpu')
    model = model.eval().to('cpu')

    crit = _resolve_criteria(criteria)

    sp = Sparsifier(model, 'filter', context, criteria=crit)
    sp.sparsify_model(sparsity)
    # Private fasterai helper; presumably drops the sparsifier's bookkeeping
    # buffers before pruning -- confirm against fasterai.
    sp._clean_buffers()

    pr = Pruner(model, sparsity, context, criteria=crit)
    pr.prune_model()

    qu_model = Quant().quantize(model)

    scripted = torch.jit.script(qu_model)
    torch.jit.save(scripted, comp_path)
    return comp_path

import matplotlib.pyplot as plt
import seaborn as sns
import io
import numpy as np

def get_model_size(model_path):
    """Return the size of the file at ``model_path`` in MB (1 MB = 1024 * 1024 bytes).

    The value is rounded to two decimal places.
    """
    bytes_per_mb = 2 ** 20
    return round(os.path.getsize(model_path) / bytes_per_mb, 2)

def create_size_comparison_plot(original_size, compressed_size):
    """Create a bar chart comparing original and compressed model sizes.

    Args:
        original_size: Size of the original model in MB.
        compressed_size: Size of the compressed model in MB.

    Returns:
        A matplotlib ``Figure`` with a transparent background and white text.
        The title shows the size reduction as a percentage of
        ``original_size``, or "n/a" when ``original_size`` is not positive.

    The figure is removed from pyplot's registry (``plt.close``) before being
    returned, so one call per request does not accumulate open figures. The
    returned object can still be drawn or saved.
    """
    sns.set_style("darkgrid")

    # Higher DPI for a sharper image.
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)

    # Transparent figure and axes backgrounds.
    fig.patch.set_alpha(0.0)
    ax.patch.set_alpha(0.0)

    bars = ax.bar(['Original', 'Compressed'],
                  [original_size, compressed_size],
                  color=['#FF6B00', '#FF9F1C'],
                  alpha=0.8,
                  width=0.6)

    # Size label just above each bar.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + (height * 0.01),
                f'{height:.2f} MB',
                ha='center', va='bottom',
                fontsize=11,
                fontweight='bold',
                color='white')

    # A tiny original file rounds to 0.00 MB; guard against dividing by zero.
    if original_size > 0:
        compression_ratio = ((original_size - compressed_size) / original_size) * 100
        compression_text = f'{compression_ratio:.1f}%'
    else:
        compression_text = 'n/a'

    ax.set_title(f'Model Size Comparison\nCompression: {compression_text}',
                 fontsize=14,
                 fontweight='bold',
                 pad=20,
                 color='white')
    ax.set_xlabel('Model Version',
                  fontsize=12,
                  fontweight='bold',
                  labelpad=10,
                  color='white')
    ax.set_ylabel('Size (MB)',
                  fontsize=12,
                  fontweight='bold',
                  labelpad=10,
                  color='white')

    ax.grid(alpha=0.2, color='gray')

    # Remove top and right spines.
    sns.despine(ax=ax)

    # 20% headroom above the tallest bar; fall back to 1.0 so the axis is
    # never degenerate when both sizes are 0.
    max_value = max(original_size, compressed_size)
    y_top = max_value * 1.2 if max_value > 0 else 1.0
    ax.set_ylim(0, y_top)
    ax.set_yticks(np.linspace(0, y_top, 10))
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda value, _pos: f'{value:.1f}'))

    # White ticks, tick labels and spines for a dark page background.
    ax.tick_params(colors='white')
    for spine in ax.spines.values():
        spine.set_color('white')

    fig.tight_layout()
    plt.close(fig)
    return fig

def main_interface(model_name, sparsity, action=None):
    """Compress a randomly initialised torchvision model and plot the size change.

    Args:
        model_name: One of the keys of ``builders`` below (the names offered
            in the UI dropdown).
        sparsity: Compression level forwarded to ``optimize_model``.
        action: Unused. It now defaults to ``None`` because the Gradio
            interface supplies only two inputs; as a required parameter it
            made every UI call fail with a TypeError.

    Returns:
        The figure returned by ``create_size_comparison_plot``.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    import tempfile
    import torchvision.models as models

    # Map names to constructors so only the requested network is built,
    # instead of instantiating all six (including VGG16) on every call.
    builders = {
        'ResNet18': models.resnet18,
        'ResNet50': models.resnet50,
        'MobileNetV2': models.mobilenet_v2,
        'EfficientNet-B0': models.efficientnet_b0,
        'VGG16': models.vgg16,
        'DenseNet121': models.densenet121,
    }
    try:
        builder = builders[model_name]
    except KeyError:
        raise ValueError(
            f"Unsupported model {model_name!r}; choose one of {sorted(builders)}"
        ) from None

    # weights=None gives random initialisation; it replaces the deprecated
    # pretrained=False.
    model = builder(weights=None)

    # Unique temporary checkpoint so overlapping requests don't clobber it.
    fd, temp_path = tempfile.mkstemp(suffix=".pth")
    os.close(fd)
    try:
        torch.save(model, temp_path)
        original_size = get_model_size(temp_path)

        compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
        compressed_size = get_model_size(compressed_path)

        return create_size_comparison_plot(original_size, compressed_size)
    finally:
        # Clean up the temporary checkpoint even if saving or compression failed.
        if os.path.exists(temp_path):
            os.remove(temp_path)


# Dropdown choices; each entry must match a model name handled by main_interface.
available_models = ['ResNet18', 'ResNet50', 'MobileNetV2', 'EfficientNet-B0', 'VGG16', 'DenseNet121']

# UI components: model picker, compression slider (0-100) and the size chart.
_model_picker = gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18')
_level_slider = gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50)
_size_plot = gr.Plot(label="Size Comparison")

iface = gr.Interface(
    fn=main_interface,
    inputs=[_model_picker, _level_slider],
    outputs=[_size_plot],
)

# Starts the web server as soon as this module runs.
iface.launch()