File size: 8,178 Bytes
250a0ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from medmnist import INFO
import gradio as gr
import os
import base64
from io import BytesIO
from huggingface_hub import HfApi
from datetime import datetime
import io

from model import resnet18, resnet50

# Pick the best available accelerator: CUDA GPU, then Apple MPS, then CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
AUTH_TOKEN = os.getenv("APP_TOKEN")  # password for Gradio basic auth (access to the app)
DATASET_REPO = os.getenv("Dataset_repo")  # HF dataset repo for uploads, e.g. "G44mlops/API_received"
HF_TOKEN = os.getenv("HF_TOKEN")  # token to access the dataset repo
MODEL = os.getenv("Model_repo")  # HF model repo, e.g. "G44mlops/ResNet-medmnist"

#taken from Mikolaj code with closed PR
def load_model_from_hf(
    repo_id: str,
    filename: str,
    model_type: str,
    num_classes: int,
    in_channels: int,
    device: str | torch.device,
) -> torch.nn.Module:
    """Load a trained model checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hugging Face repository ID.
        filename: Model checkpoint filename inside the repo.
        model_type: Type of model ('resnet18' or 'resnet50'; anything else
            falls back to resnet50).
        num_classes: Number of output classes.
        in_channels: Number of input channels.
        device: Device to load the model on (string or torch.device).

    Returns:
        Loaded model in eval mode, moved to `device`.
    """
    # Bug fix: the log line previously printed a literal placeholder
    # instead of the actual checkpoint filename.
    print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Instantiate the requested architecture (resnet50 is the fallback).
    if model_type == "resnet18":
        model = resnet18(num_classes=num_classes, in_channels=in_channels)
    else:
        model = resnet50(num_classes=num_classes, in_channels=in_channels)

    # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model
#taken from Mikolaj code with closed PR


# Image preprocessing pipeline (basic so far, can be improved)
def get_preprocessing_pipeline() -> transforms.Compose:
    """Build the image preprocessing pipeline applied before inference."""
    # Channel count the trained model expects; organamnist is the reference
    # dataset for this project.
    n_channels = INFO["organamnist"]["n_channels"]
    # Generic per-channel normalization stats (0.5), used because real
    # dataset statistics are not available.
    norm_stats = (0.5,) * n_channels
    # Resize -> center crop -> tensor -> normalize, returned as one Compose.
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_stats, std=norm_stats),
    ])
def get_class_labels(data_flag: str = "organamnist") -> dict[str, str]:
    """Return the class-label mapping for a MedMNIST dataset.

    Args:
        data_flag: MedMNIST dataset key to look up in INFO.

    Returns:
        Mapping from stringified class index to human-readable label.
        (Annotation fixed: INFO[...]["label"] is a dict keyed by class-index
        strings, as used by `class_labels[str(top_class)]` elsewhere in the
        app — it was previously mis-annotated as list[str].)
    """
    info = INFO[data_flag]
    labels = info["label"]
    return labels

def save_image_to_hf_folder(image_path, prediction_label):
    """Archive an uploaded image and its prediction to the HF dataset repo.

    Uploads two files under `uploads/`: the raw image (timestamp-prefixed so
    repeated filenames do not collide) and a small *_metadata.txt with the
    predicted label and upload timestamp.

    Args:
        image_path: Local path of the image file to upload.
        prediction_label: Predicted class label to store alongside the image.
    """
    api = HfApi()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Metadata travels as a sibling text file, built in memory (no temp file).
    # (Removed an unused `metadata_path` local that was computed but never used.)
    metadata = f"prediction: {prediction_label}\ntimestamp: {timestamp}"

    # Upload the raw image.
    api.upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"uploads/{timestamp}_{Path(image_path).name}",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN
    )
    # Upload the metadata as a separate in-memory file.
    api.upload_file(
        path_or_fileobj=io.BytesIO(metadata.encode()),
        path_in_repo=f"uploads/{timestamp}_{Path(image_path).stem}_metadata.txt",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN
    )

def classify_images(images) -> str:
    """Classify uploaded images and return an HTML gallery of results.

    Each image is converted to grayscale, classified with the module-level
    `model`, rendered into self-contained HTML (base64-embedded JPEG), and
    archived to the HF dataset repo together with its predicted label.

    Args:
        images: List of file paths from the Gradio File component (or a
            single path string, or None/empty when nothing was uploaded).

    Returns:
        HTML string with one card per image (filename, preview, label).
    """
    # Robustness fix: treat an empty list the same as None (no uploads).
    if not images:
        return "<p>No images uploaded</p>"
    # Gradio may hand over a bare path instead of a list for a single upload.
    if isinstance(images, str):
        images = [images]
    html = "<div style='display: flex; flex-wrap: wrap; gap: 30px; padding: 20px; justify-content: center;'>"
    for image_path in images:
        # Convert to grayscale ("L") — the model was trained on 1-channel images.
        img = Image.open(image_path).convert("L")
        # Bug fix: move the input to the model's device; the model is placed
        # on DEVICE at load time, so a CPU tensor would fail on CUDA/MPS.
        input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)
        # Forward pass + softmax to get class probabilities.
        with torch.no_grad():
            output = model(input_tensor)
            probs = torch.nn.functional.softmax(output[0], dim=0)
            top_class = probs.argmax().item()
        # MedMNIST label maps are keyed by stringified class index.
        label = class_labels[str(top_class)]
        filename = Path(image_path).name
        # Embed the image inline as base64 JPEG so the HTML is self-contained.
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        # Bug fix: show the actual filename (it was computed but a literal
        # placeholder was rendered instead).
        html += f"""
        <div style='border: 2px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9; width: 280px;'>
            <p style='font-size: 14px; color: #666; margin: 0 0 10px 0; text-align: center; font-weight: bold;'>{filename}</p>
            <img src='data:image/jpeg;base64,{img_str}' style='width: 250px; height: 250px; object-fit: contain; display: block; margin: 0 auto 10px;'>
            <p style='font-size: 18px; color: #0066cc; margin: 10px 0 0 0; text-align: center; font-weight: bold;'>{label}</p>
        </div>
        """
        # Archive the image and its prediction to the HF dataset repo.
        save_image_to_hf_folder(image_path, label)
    html += "</div>"
    return html

###main code to launch Gradio app###

# Prepare model and preprocessing pipeline (the "backend" of the app).
# Download and load the trained ResNet-18 checkpoint from the HF model repo;
# the 11-class, 1-channel configuration matches the organamnist setup used
# throughout this file.
model =  load_model_from_hf(#taken from Mikolaj code with closed PR
            repo_id=MODEL, 
            filename="resnet18_best.pth",
            model_type="resnet18",
            num_classes=11,
            in_channels=1,
            device=DEVICE,
        )
# Shared preprocessing pipeline and class-label mapping used by classify_images.
preprocess = get_preprocessing_pipeline()
class_labels = get_class_labels()
# Gradio interface (frontend). Structure: title, disclaimer, upload column
# with Classify/Reset buttons, and an HTML results pane.
with gr.Blocks() as demo:
    # App title banner.
    gr.Markdown("<h1 style='text-align: center;'> MLOps project - MedMNIST dataset Image Classifier</h1>")
    # User-facing data-storage disclaimer (typos fixed: "agrree" -> "agree",
    # "insures" -> "ensures", "somewhat" -> "somehow").
    gr.Markdown("This is a Gradio web application for MLOps course project. Given images are stored in our dataset. " \
    "By uploading images you agree that they will be stored by us and ensure that they can be stored by us. " \
    "If you somehow passed the login and are not connected to the project, please do not upload any images. " )
    # Main vertical layout.
    with gr.Column():
        # Upload section.
        gr.Markdown("<h2 style='text-align: center;'> Upload Images</h2>")
        # Multi-file image upload widget.
        images_input = gr.File(file_count="multiple", file_types=["image"], label="Upload Images")
        # Action buttons side by side.
        with gr.Row():
            submit_btn = gr.Button("Classify")
            reset_btn = gr.Button("Reset")
        # Results section.
        gr.Markdown("<h2 style='text-align: center;'> Results</h2>")
        # Classification results rendered as raw HTML.
        output = gr.HTML(label="Results")
    # Callback clearing both the file input and the results pane.
    def reset():
        return None, ""
    # Wire buttons to their callbacks.
    submit_btn.click(classify_images, inputs=images_input, outputs=output)
    reset_btn.click(reset, outputs=[images_input, output])


    # Launch the app. GRADIO_SERVER_NAME allows binding to 0.0.0.0 in
    # containers; default is localhost only. When APP_TOKEN is set, a simple
    # single-user basic-auth prompt protects the app.
    server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    demo.launch(
        server_name=server_name,
        auth=[("user", AUTH_TOKEN)] if AUTH_TOKEN else None
        )