File size: 8,178 Bytes
250a0ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from medmnist import INFO
import gradio as gr
import os
import base64
from io import BytesIO
from huggingface_hub import HfApi
from datetime import datetime
import io
from model import resnet18, resnet50
# Inference device: prefer CUDA, then MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Runtime configuration read from environment variables (None when unset).
AUTH_TOKEN = os.environ.get("APP_TOKEN")  # password required to access the app
DATASET_REPO = os.environ.get("Dataset_repo")  # e.g. "G44mlops/API_received"
HF_TOKEN = os.environ.get("HF_TOKEN")  # token used to access the dataset repo
MODEL = os.environ.get("Model_repo")  # e.g. "G44mlops/ResNet-medmnist"
#taken from Mikolaj code with closed PR
def load_model_from_hf(
    repo_id: str,
    filename: str,
    model_type: str,
    num_classes: int,
    in_channels: int,
    device: "str | torch.device",
) -> torch.nn.Module:
    """Load a trained model checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hugging Face repository ID hosting the checkpoint.
        filename: Model checkpoint filename inside the repository.
        model_type: Type of model, either ``'resnet18'`` or ``'resnet50'``.
        num_classes: Number of output classes.
        in_channels: Number of input channels.
        device: Device to load the model on (device string or
            ``torch.device``).

    Returns:
        The model with the checkpoint's ``model_state_dict`` loaded, moved to
        ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """
    # Validate before downloading anything: an unknown type used to fall
    # through silently to ResNet-50, hiding typos in the configuration.
    if model_type not in ("resnet18", "resnet50"):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            "expected 'resnet18' or 'resnet50'."
        )

    print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Create model
    build_model = resnet18 if model_type == "resnet18" else resnet50
    model = build_model(num_classes=num_classes, in_channels=in_channels)

    # Load checkpoint; weights_only=True avoids unpickling arbitrary objects
    # from the downloaded file.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model
#taken from Mikolaj code with closed PR
# Image preprocessing pipeline (basic so far, can be improved)
def get_preprocessing_pipeline(data_flag: str = "organamnist") -> transforms.Compose:
    """Build the preprocessing pipeline applied to images before inference.

    Args:
        data_flag: MedMNIST dataset key whose ``n_channels`` entry determines
            how many channels are normalized. Defaults to ``"organamnist"``,
            the dataset used as reference by this app.

    Returns:
        A ``transforms.Compose`` that resizes with ``Resize(256)``,
        center-crops to 224, converts to a tensor and normalizes every
        channel with mean 0.5 and std 0.5.
    """
    # Number of image channels (RGB or grayscale) of the reference dataset
    n_channels = INFO[data_flag]["n_channels"]

    # 'Standard' normalization values, used because dataset statistics are
    # not available here
    mean = (0.5,) * n_channels
    std = (0.5,) * n_channels

    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
def get_class_labels(data_flag: str = "organamnist") -> dict[str, str]:
    """Return the class-label mapping of a MedMNIST dataset.

    Args:
        data_flag: MedMNIST dataset key. Defaults to ``"organamnist"``.

    Returns:
        The dataset's ``"label"`` entry from ``INFO``. Callers look labels up
        by the class index converted to a string (e.g. ``labels["3"]``), so
        this is a mapping rather than a list.
    """
    return INFO[data_flag]["label"]
def save_image_to_hf_folder(image_path, prediction_label) -> None:
    """Upload an image and its prediction metadata to the HF dataset repo.

    Two files are uploaded under ``uploads/`` in ``DATASET_REPO``: the image
    itself and a ``<stem>_metadata.txt`` text file, both prefixed with the
    same timestamp.

    Args:
        image_path: Local path of the image file to upload.
        prediction_label: Predicted label written into the metadata file.
    """
    api = HfApi()
    source = Path(image_path)
    # NOTE: the timestamp has one-second resolution, so two files with the
    # same name uploaded within the same second map to the same repo path.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    metadata = f"prediction: {prediction_label}\ntimestamp: {timestamp}"

    # Destination shared by both uploads
    destination = {
        "repo_id": DATASET_REPO,
        "repo_type": "dataset",
        "token": HF_TOKEN,
    }

    # Upload image
    api.upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"uploads/{timestamp}_{source.name}",
        **destination,
    )
    # Upload metadata as a separate file, straight from memory
    api.upload_file(
        path_or_fileobj=io.BytesIO(metadata.encode()),
        path_in_repo=f"uploads/{timestamp}_{source.stem}_metadata.txt",
        **destination,
    )
def _encode_jpeg_base64(img) -> str:
    """Return ``img`` saved as JPEG and base64-encoded, as text."""
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def _render_result_card(filename: str, img_b64: str, label: str) -> str:
    """Build the HTML card showing one image, its filename and its label."""
    from html import escape

    # The filename comes from the uploader, so it must be escaped before it
    # is placed into the page; the label is escaped as well for safety.
    safe_name = escape(filename)
    safe_label = escape(label)
    return f"""
<div style='border: 2px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9; width: 280px;'>
<p style='font-size: 14px; color: #666; margin: 0 0 10px 0; text-align: center; font-weight: bold;'>{safe_name}</p>
<img src='data:image/jpeg;base64,{img_b64}' style='width: 250px; height: 250px; object-fit: contain; display: block; margin: 0 auto 10px;'>
<p style='font-size: 18px; color: #0066cc; margin: 10px 0 0 0; text-align: center; font-weight: bold;'>{safe_label}</p>
</div>
"""


def classify_images(images) -> str:
    """Classify images and return formatted HTML with embedded images.

    Every image is converted to grayscale, passed through the model, rendered
    as an HTML card (base64-embedded JPEG, filename, predicted label) and then
    uploaded together with its prediction via ``save_image_to_hf_folder``.

    Args:
        images: A file path or a list of file paths; ``None`` or an empty
            list when nothing was uploaded.

    Returns:
        An HTML fragment with one card per image, or a short message when no
        images were provided.
    """
    # Handle case with no images (None or an empty selection)
    if not images:
        return "<p>No images uploaded</p>"
    # A single upload may arrive as a bare path instead of a list
    if isinstance(images, str):
        images = [images]

    cards = []
    for image_path in images:
        # Close the file handle promptly; convert() returns an independent
        # grayscale copy (the project uses grayscale images).
        with Image.open(image_path) as source:
            img = source.convert("L")

        # The model was loaded onto DEVICE, so the input batch has to live on
        # the same device or the forward pass fails on CUDA/MPS.
        input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)

        # Forward pass + softmax to get probabilities
        with torch.no_grad():
            output = model(input_tensor)
            probs = torch.nn.functional.softmax(output[0], dim=0)
        top_class = probs.argmax().item()
        label = class_labels[str(top_class)]

        cards.append(
            _render_result_card(
                Path(image_path).name, _encode_jpeg_base64(img), label
            )
        )

        # Save image and metadata to HF dataset folder
        save_image_to_hf_folder(image_path, label)

    container_open = (
        "<div style='display: flex; flex-wrap: wrap; gap: 30px; "
        "padding: 20px; justify-content: center;'>"
    )
    return container_open + "".join(cards) + "</div>"
### Main code to launch the Gradio app ###

# Backend: model, preprocessing pipeline and class labels used by
# classify_images
model = load_model_from_hf(  # taken from Mikolaj code with closed PR
    repo_id=MODEL,
    filename="resnet18_best.pth",
    model_type="resnet18",
    num_classes=11,
    in_channels=1,
    device=DEVICE,
)
preprocess = get_preprocessing_pipeline()
class_labels = get_class_labels()

# Frontend: Gradio interface
with gr.Blocks() as demo:
    # App title
    gr.Markdown("<h1 style='text-align: center;'> MLOps project - MedMNIST dataset Image Classifier</h1>")
    # App message / data-storage notice
    gr.Markdown(
        "This is a Gradio web application for MLOps course project. Given images are stored in our dataset. "
        "By uploading images you agree that they will be stored by us and ensure that they can be stored by us. "
        "If you somehow passed the login and are not connected to the project, please do not upload any images. "
    )
    # App spine layout
    with gr.Column():
        # Title of the upload segment
        gr.Markdown("<h2 style='text-align: center;'> Upload Images</h2>")
        # Image loading component
        images_input = gr.File(file_count="multiple", file_types=["image"], label="Upload Images")
        # Button row for app functionality
        with gr.Row():
            submit_btn = gr.Button("Classify")
            reset_btn = gr.Button("Reset")
        # Title of the results segment
        gr.Markdown("<h2 style='text-align: center;'> Results</h2>")
        # Classification results output component
        output = gr.HTML(label="Results")

    def reset():
        """Clear the uploaded files and the results panel."""
        return None, ""

    # Link buttons to their callbacks
    submit_btn.click(classify_images, inputs=images_input, outputs=output)
    reset_btn.click(reset, outputs=[images_input, output])

# Launch the app; login (user "user", password AUTH_TOKEN) is only enabled
# when APP_TOKEN is set.
server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
demo.launch(
    server_name=server_name,
    auth=[("user", AUTH_TOKEN)] if AUTH_TOKEN else None,
)