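# app.py -- Streamlit tooth-segmentation demo.
# Provenance (from the hosting page): author avatar "svsaurav95's picture",
# commit message "Update app.py", revision 174f191 (verified).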
import os
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import numpy as np
import cv2
from PIL import Image
import warnings
# Silence every warning category for the rest of the process.
warnings.simplefilter("ignore")
# Optional: ask Streamlit to turn off file watchers in HF Spaces, to avoid
# torch-related warnings.
# NOTE(review): Streamlit's config-to-env mapping suggests the variable for
# server.fileWatcherType is STREAMLIT_SERVER_FILE_WATCHER_TYPE, and this
# assignment runs after `import streamlit` -- confirm it has any effect.
os.environ.update({"STREAMLIT_WATCHER_TYPE": "none"})
# Define the model class
class MobileViTSegmentation(nn.Module):
    """Single-channel segmentation net: timm feature backbone + conv decoder.

    Only the last feature map returned by the backbone is decoded. The
    decoder applies three conv/2x-upsample stages, a 1x1 conv down to one
    channel and a sigmoid; ``forward`` then resizes the result to the
    spatial size of the input.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=False):
        """Create the backbone via ``timm.create_model`` and build the decoder.

        Args:
            encoder_name: timm model name passed to ``timm.create_model``.
            pretrained: forwarded to ``timm.create_model``.
        """
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name,
            features_only=True,
            pretrained=pretrained,
        )
        self.encoder_channels = self.backbone.feature_info.channels()
        deepest_channels = self.encoder_channels[-1]

        # The position of each layer in the Sequential determines its
        # state_dict key (decoder.0, decoder.2, ...), so the order must stay
        # as it is for saved weights to keep loading.
        stages = [
            nn.Conv2d(deepest_channels, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the decoder output resized to ``x``'s height and width.

        ``x`` is indexed as a 4-D tensor (dims 2 and 3 are used as the
        target size).
        """
        deepest = self.backbone(x)[-1]
        coarse = self.decoder(deepest)
        target_size = (x.shape[2], x.shape[3])
        return nn.functional.interpolate(
            coarse,
            size=target_size,
            mode='bilinear',
            align_corners=False,
        )
# Load model function with spinner and error handling
@st.cache_resource
def load_model():
    """Build the segmentation model and load its weights onto the CPU.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    model instead of rebuilding it. A spinner is shown while loading.

    Returns:
        The ``MobileViTSegmentation`` instance switched to eval mode.
        If anything fails, the error is shown with ``st.error`` and
        ``st.stop()`` is called.
    """
    try:
        with st.spinner("Loading model..."):
            net = MobileViTSegmentation()
            # NOTE(review): torch.load unpickles the file; only ship
            # checkpoints from a trusted source alongside this app.
            weights = torch.load(
                "mobilevit_teeth_segmentation.pth",
                map_location="cpu",
            )
            net.load_state_dict(weights)
            net.eval()
    except Exception as e:
        st.error(f"❌ Failed to load model: {e}")
        st.stop()
    else:
        return net
# Inference function
def predict_mask(image, model, threshold=0.7):
    """Run ``model`` on one image and return a thresholded binary mask.

    Args:
        image: Image to segment. PIL images in any mode are accepted;
            non-RGB ones are converted to RGB first. Any other input type
            is handed to the torchvision transforms unchanged.
        model: Callable taking the batched image tensor and returning a
            prediction tensor (called under ``torch.no_grad()``).
        threshold: Values strictly greater than this become 1, the rest 0.

    Returns:
        ``np.uint8`` array of 0/1 values with the prediction's singleton
        dimensions squeezed out, or ``None`` if any step raised (the error
        is reported through ``st.error``).
    """
    try:
        # Grayscale, RGBA or palette images would otherwise produce a tensor
        # with a channel count other than three (presumably what the model
        # expects -- the caller in this file already passes RGB).
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        # Add the batch dimension the model is called with.
        img_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            pred = model(img_tensor)

        pred_mask = pred.squeeze().numpy()
        return (pred_mask > threshold).astype(np.uint8)
    except Exception as e:
        st.error(f"❌ Prediction failed: {e}")
        return None
# Overlay mask on image
def overlay_mask(image, mask, color=(0, 0, 255), alpha=0.4):
    """Tint the masked pixels of ``image`` with ``color``.

    Pixels outside the mask are returned unchanged; pixels inside it are
    ``(1 - alpha) * image + alpha * color``.

    Args:
        image: PIL image; converted to RGB before blending.
        mask: 2-D array whose non-zero entries mark the pixels to tint. It
            is resized to the image's width and height.
        color: Tint colour, in the channel order of the RGB-converted image
            (so the default ``(0, 0, 255)`` is blue).
        alpha: Weight of ``color`` inside the mask.

    Returns:
        RGB array with the same shape as the converted image. If anything
        raises, the error is reported through ``st.error`` and
        ``np.array(image)`` is returned instead.
    """
    try:
        image_np = np.array(image.convert("RGB"))
        height, width = image_np.shape[:2]

        # Normalise to 0/1 uint8 so boolean or wider-integer masks are
        # accepted by cv2.resize and the == 1 test below stays valid.
        binary = (np.asarray(mask) != 0).astype(np.uint8)
        mask_resized = cv2.resize(binary, (width, height))
        selected = mask_resized[..., None] == 1

        color_layer = np.zeros_like(image_np)
        color_layer[:, :] = color
        tinted = cv2.addWeighted(image_np, 1 - alpha, color_layer, alpha, 0)

        # Blending the whole frame against a layer that is black outside the
        # mask would dim every unmasked pixel by (1 - alpha); take the tinted
        # value only where the mask is set.
        return np.where(selected, tinted, image_np)
    except Exception as e:
        st.error(f"❌ Mask overlay failed: {e}")
        return np.array(image)
# Streamlit UI
# Page chrome and the upload widget.
st.set_page_config(page_title="Tooth Segmentation", layout="wide")
st.title("🦷 Tooth Segmentation from Mouth Images")
st.markdown("Upload a **face or mouth image**, and this app will overlay the **predicted tooth segmentation mask**.")
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Once a file is uploaded: segment it and show the original and the overlay
# side by side. Any failure is reported on the page.
if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
        model = load_model()
        pred_mask = predict_mask(image, model)
        # predict_mask returns None after reporting its own error.
        if pred_mask is not None:
            overlayed_img = overlay_mask(image, pred_mask, color=(0, 0, 255), alpha=0.4)
            panels = (
                (image, "Original Image"),
                (overlayed_img, "Tooth Mask Overlay"),
            )
            for column, (picture, caption) in zip(st.columns(2), panels):
                with column:
                    st.image(picture, caption=caption, use_container_width=True)
    except Exception as e:
        st.error(f"❌ Error processing image: {e}")