import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from typing import Optional
from torchvision import transforms

from utils import page_utils


class BasicBlock(nn.Module):
    """
    ResNet basic block.

    This class defines a basic building block for ResNet architectures. It consists of two
    convolutional layers with batch normalization and a ReLU activation function. Optionally,
    it can include an identity downsample layer to match the dimensions of the input and
    output when the stride is not 1.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int, optional
        Convolution stride size, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Downsampling layer, by default None.

    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor:
        Apply forward computation.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # Downsample the identity branch when needed so its shape matches the conv2
        # output before the residual addition.
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        return x
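

# Illustrative sanity check, not part of the original app: a stride-2 BasicBlock paired
# with a matching downsample halves the spatial size and changes the channel count, so the
# residual addition stays shape-compatible. The tensor sizes below are arbitrary examples.
_example_block = BasicBlock(
    64, 128, stride=2,
    identity_downsample=nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(128),
    ),
)
assert _example_block(torch.randn(1, 64, 8, 8)).shape == (1, 128, 4, 4)
del _example_block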


class ResNet18(nn.Module):
    """
    Construct the ResNet-18 model.

    This class defines the ResNet-18 architecture, including convolutional layers, basic
    blocks, and a fully connected layer for classification.

    Parameters
    ----------
    input_channels : int
        Number of input channels.
    num_classes : int
        Number of output classes.

    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor:
        Apply forward computation.
    """

    def __init__(self, input_channels: int, num_classes: int):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(input_channels,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int) -> nn.Module:
        """Downsampling block to reduce the feature sizes."""
        return nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Create a sequential pair of basic blocks."""
        identity_downsample = None
        # Only the first block of a stage changes the resolution and channel count, so it
        # is the only one that needs a downsample on its identity branch.
        if stride != 1:
            identity_downsample = self.identity_downsample(in_channels, out_channels)
        return nn.Sequential(
            BasicBlock(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride),
            BasicBlock(out_channels, out_channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
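

# Illustrative sanity check, not part of the original app: a freshly initialized ResNet18
# maps a batch of RGB images to one logit per class. The 64x64 input size is an arbitrary
# example; the adaptive average pooling keeps the head independent of the input resolution.
_example_net = ResNet18(input_channels=3, num_classes=3)
assert _example_net(torch.randn(1, 3, 64, 64)).shape == (1, 3)
del _example_net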


model = ResNet18(3, 3)

checkpoint = torch.load('epoch=49-step=1750.ckpt', map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']

# The checkpoint stores the backbone weights under a "net." prefix (it was saved from a
# training wrapper), so strip that prefix before loading them into the bare ResNet18.
for key in list(state_dict.keys()):
    if 'net.' in key:
        state_dict[key.replace('net.', '')] = state_dict[key]
        del state_dict[key]

model.load_state_dict(state_dict)
model.eval()
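
# Illustrative example with made-up keys, not part of the original app: the renaming loop
# above turns wrapper-prefixed parameter names such as "net.conv1.weight" into the names
# the bare ResNet18 expects, e.g. "conv1.weight".
_example_keys = ['net.conv1.weight', 'net.fc.bias']
assert [k.replace('net.', '') for k in _example_keys] == ['conv1.weight', 'fc.bias']
del _example_keys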

class_names = ['benign', 'malignant', 'normal']
class_names.sort()

example_dir = "SAMPLES"

transformation_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 256)),
    # Note: RandomRotation is an augmentation, so predictions can vary slightly between
    # runs on the same image at inference time.
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.233827, 0.2338219, 0.23378967],
                         std=[0.2016421162328173, 0.20164345656093885, 0.20160390432148026])
])


def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Preprocess the input image.

    Note that the input image is expected in RGB mode.

    Parameters
    ----------
    image : np.ndarray
        Input image from the Gradio callback.

    Returns
    -------
    torch.Tensor
        Transformed image tensor with a leading batch dimension.
    """
    image = transformation_pipeline(image)
    image = torch.unsqueeze(image, 0)
    return image
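

# Illustrative usage with a dummy image, not part of the original app: preprocessing an
# arbitrary H x W x 3 uint8 array yields a normalized 1 x 3 x 256 x 256 float tensor
# ready for the model.
_example_batch = preprocess_image(np.zeros((100, 100, 3), dtype=np.uint8))
assert _example_batch.shape == (1, 3, 256, 256)
del _example_batch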


def image_classifier(inp: Optional[np.ndarray] = None) -> dict:
    """Classify an input image.

    Parameters
    ----------
    inp : Optional[np.ndarray], optional
        Input image from the Gradio callback, by default None.

    Returns
    -------
    dict
        A dictionary mapping each class name to its predicted probability.
    """
    # If the input is not valid, surface an error in the UI.
    if inp is None:
        raise gr.Error("Please provide an input image.")

    # Preprocess
    image = preprocess_image(inp)
    image = image.to(dtype=torch.float32)

    # Inference
    with torch.no_grad():
        result = model(image)

    # Postprocess
    result = torch.nn.functional.softmax(result, dim=1)  # apply softmax
    result = result[0].detach().numpy().tolist()  # take the first batch
    labeled_result = {name: score for name, score in zip(class_names, result)}
    return labeled_result
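

# Illustrative usage with a dummy image, not part of the original app: the classifier
# returns one probability per class. The values for a blank image are meaningless, but the
# keys always match class_names and the probabilities sum to one.
_example_prediction = image_classifier(np.zeros((100, 100, 3), dtype=np.uint8))
assert sorted(_example_prediction) == class_names
del _example_prediction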


with open('index.html', encoding="utf-8") as f:
    description = f.read()

with open('author.html', encoding="utf-8") as author:
    author_info = author.read()

with gr.Blocks(
    theme=gr.themes.Default(primary_hue=page_utils.KALBE_THEME_COLOR,
                            secondary_hue=page_utils.KALBE_THEME_COLOR).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    )
) as demo:
    with gr.Column():
        gr.HTML(description)

    with gr.Row():  # Build a row
        with gr.Column():  # Build a column section as the first item
            inp = gr.Image(label="image", image_mode="RGB")  # Image input as the first column item
            with gr.Row():  # Build a row section as the second column item
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

        # Build a label as the second row item
        out = gr.Label(label="prediction", num_top_classes=3)

    # Define button functionalities
    submit_btn.click(fn=image_classifier, inputs=inp, outputs=out)
    clear_btn.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=None),
        ),
        inputs=None,
        outputs=[inp, out],
    )

    # Add examples
    gr.Markdown("## Image Examples")
    gr.Examples(
        example_dir,
        inputs=[inp],
        label="Image Examples",
        cache_examples=False,
    )

    with gr.Column():
        gr.HTML(author_info)

demo.launch(share=True)