|
import gradio as gr |
|
import torch |
|
from torchvision.transforms import transforms |
|
import numpy as np |
|
from typing import Optional |
|
import torch.nn as nn |
|
import os |
|
import shutil |
|
from utils import page_utils |
|
|
|
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and a skip connection.

    Parameters
    ----------
    in_channels : int
        Number of channels of the incoming feature map
    out_channels : int
        Number of channels produced by both convolutions
    stride : int, optional
        Stride of the first convolution, by default 1
    identity_downsample : Optional[torch.nn.Module], optional
        Module applied to the block input before it is added back to the
        convolution branch, by default None (the input is added unchanged)
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()

        # Both convolutions share the same 3x3 / padding-1 geometry; only the
        # first one may change the spatial resolution through ``stride``.
        conv_geometry = dict(kernel_size=3, padding=1)

        # Attribute names double as state_dict key prefixes, so they must
        # stay exactly as they are for saved weights to keep loading.
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_geometry)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_geometry)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        # Convolution branch: conv -> bn -> relu -> conv -> bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip branch: the raw input, or its projection when one was given
        shortcut = x
        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(shortcut)

        out += shortcut
        return self.relu(out)
|
|
|
class ResNet18(nn.Module):
    """ResNet-18 style classifier assembled from pairs of ``BasicBlock``.

    Parameters
    ----------
    input_channels : int
        Number of channels of the input image
    num_classes : int
        Number of values produced by the final linear layer
    """

    def __init__(self, input_channels, num_classes):
        super().__init__()

        # Entry layers: 7x7 stride-2 convolution, batch norm, ReLU and a
        # 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four groups of two blocks each; every group after the first uses
        # stride 2 and doubles the channel count.
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)

        # Head: pool each channel down to 1x1, then a linear layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build the skip-path projection: 3x3 stride-2 conv plus batch norm."""
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        normalisation = nn.BatchNorm2d(out_channels)
        return nn.Sequential(projection, normalisation)

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Create a group of two basic blocks.

        Only the first block receives ``stride``; when that stride is not 1
        it also gets a projection on its skip path.
        """
        shortcut = (
            self.identity_downsample(in_channels, out_channels)
            if stride != 1
            else None
        )
        first_block = BasicBlock(in_channels, out_channels,
                                 stride=stride, identity_downsample=shortcut)
        second_block = BasicBlock(out_channels, out_channels)
        return nn.Sequential(first_block, second_block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the class outputs for a batch of images."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for group in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = group(x)

        x = self.avgpool(x)
        # Collapse everything but the batch dimension before the linear layer
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
|
|
# Network with 3 input channels and 7 outputs.
model = ResNet18(3, 7)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only point this at a checkpoint from a trusted source.
checkpoint = torch.load('ham10000.ckpt', map_location=torch.device('cpu'))

# The weights are stored with a leading "net." in their names (presumably the
# attribute name used by the training wrapper - TODO confirm). Strip it only
# when it is a real prefix, so a key that merely contains "net." somewhere in
# the middle is left untouched instead of being silently corrupted.
_WEIGHT_PREFIX = 'net.'
state_dict = checkpoint['state_dict']
for key in list(state_dict.keys()):
    if key.startswith(_WEIGHT_PREFIX):
        # pop + re-insert renames the entry in place on the same mapping
        state_dict[key[len(_WEIGHT_PREFIX):]] = state_dict.pop(key)

model.load_state_dict(state_dict)
# Switch to evaluation mode for inference.
model.eval()
|
|
|
|
|
# Short class code -> label shown in the UI. Insertion order matters: the
# classifier pairs these entries with the model outputs by position.
class_names = dict((
    ('akk', 'Actinic Keratosis'),
    ('bcc', 'Basal Cell Carcinoma'),
    ('bkl', 'Benign Keratosis'),
    ('df', 'Dermatofibroma'),
    ('mel', 'Melanoma'),
    ('nv', 'Melanocytic Nevi'),
    ('vasc', 'Vascular Lesion'),
))

# Directory holding the example images offered in the UI.
examples_dir = "sample"

# Steps applied, in order, to every incoming image array.
_pipeline_steps = [
    transforms.ToPILImage(),
    # Grayscale, but replicated over 3 channels
    transforms.Grayscale(num_output_channels=3),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transformation_pipeline = transforms.Compose(_pipeline_steps)
|
|
|
def preprocess_image(image: np.ndarray):
    """Convert a raw image array into a one-item batch tensor.

    The array is run through ``transformation_pipeline`` and a new leading
    dimension of size one is inserted in front of the result.
    Note that the input image is in RGB mode.

    Parameters
    ----------
    image: np.ndarray
        Input image from callback.
    """
    return transformation_pipeline(image).unsqueeze(0)
|
|
|
def image_classifier(inp):
    """Image Classifier Function.

    Parameters
    ----------
    inp: Optional[np.ndarray] = None
        Input image from callback; ``None`` when no image was supplied

    Returns
    -------
    Dict
        Maps every display label of ``class_names`` to its softmax
        probability, or to 0.0 for every label when ``inp`` is ``None``
    """
    # Single source of truth for the labels, in model-output order
    labels = list(class_names.values())

    if inp is None:
        # Nothing to classify: report every class with zero confidence
        return {label: 0.0 for label in labels}

    image = preprocess_image(inp)
    image = image.to(dtype=torch.float32)

    # Inference only - no gradients are needed, so skip building the
    # autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(image)
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # The batch holds exactly one image; pair its scores with the labels
    scores = probabilities[0].tolist()
    return dict(zip(labels, scores))
|
|
|
|
|
# HTML shown at the top of the page, read once at start-up.
with open('index.html', encoding="utf-8") as f:
    description = f.read()

# Default theme recoloured with the project colour from page_utils.
app_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

with gr.Blocks(theme=app_theme) as app:
    with gr.Column():
        gr.HTML(description)

    # Left: image input with its two buttons. Right: predicted probabilities.
    with gr.Row():
        with gr.Column():
            inp_img = gr.Image()
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        with gr.Column():
            out_txt = gr.Label(label="Probabilities", num_top_classes=3)

    process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
    # "Clear" resets both the input image and the probability label
    clear_btn.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=None),
        ),
        inputs=None,
        outputs=[inp_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    gr.Examples(
        examples=[
            os.path.join(examples_dir, file_name)
            for file_name in (
                "nv.jpeg",
                "bcc.jpeg",
                "bkl_1.jpeg",
                "akk.jpeg",
                "mel-_3_.jpeg",
            )
        ],
        inputs=inp_img,
        outputs=out_txt,
        fn=image_classifier,
        cache_examples=False,
    )
    # Author credit; the badge link is closed with </a> so the markup is valid
    gr.Markdown(line_breaks=True, value='Author: M HAIKAL FEBRIAN P (haikalphona23@gmail.com) <div class="row"><a href="https://github.com/HAikalfebrianp96?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/haikal%20phona-000000?logo=github"></a> </div>')

app.launch(share=True)
|
|
|
|