"""Streamlit app that overlays a predicted tooth-segmentation mask on an uploaded image."""

import os
import warnings

import cv2
import numpy as np
import streamlit as st
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

warnings.filterwarnings("ignore")

# Optional: Turn off file watchers in HF Spaces to avoid torch-related warnings
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"

# Checkpoint file, resolved relative to the working directory.
MODEL_PATH = "mobilevit_teeth_segmentation.pth"

# Spatial size (height, width) every image is resized to before inference.
INPUT_SIZE = (256, 256)

# Built once at import time instead of on every prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
])


# Define the model class
class MobileViTSegmentation(nn.Module):
    """Single-channel segmentation network: timm feature backbone + small conv decoder.

    The decoder consumes only the last backbone feature map, applies three
    conv + 2x bilinear upsample stages, projects to one channel and applies a
    sigmoid. ``forward`` then resizes the result to the input's spatial size.

    Args:
        encoder_name: timm model name used to build the backbone.
        pretrained: Passed to ``timm.create_model``; whether to fetch
            pretrained backbone weights.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name, features_only=True, pretrained=pretrained
        )
        self.encoder_channels = self.backbone.feature_info.channels()
        # NOTE: layer order must not change -- the checkpoint's state_dict keys
        # are the positional indices of this Sequential.
        self.decoder = nn.Sequential(
            nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a one-channel sigmoid map with the same height/width as ``x``."""
        feats = self.backbone(x)
        out = self.decoder(feats[-1])
        # Three 2x upsamples do not necessarily reach the input resolution, so
        # resize explicitly to the input's height and width.
        out = nn.functional.interpolate(
            out,
            size=(x.shape[2], x.shape[3]),
            mode='bilinear',
            align_corners=False,
        )
        return out


# Load model function with spinner and error handling
@st.cache_resource
def load_model():
    """Build the network, load the checkpoint on CPU and return it in eval mode.

    On any failure an error is shown in the app and the script run is halted
    with ``st.stop()``.
    """
    try:
        with st.spinner("Loading model..."):
            model = MobileViTSegmentation()
            # NOTE(security): torch.load unpickles the file; only point
            # MODEL_PATH at a checkpoint you trust.
            state_dict = torch.load(MODEL_PATH, map_location="cpu")
            model.load_state_dict(state_dict)
            model.eval()
        return model
    except Exception as e:
        st.error(f"❌ Failed to load model: {e}")
        st.stop()


# Inference function
def predict_mask(image, model, threshold=0.7):
    """Run the model on ``image`` and return a binary mask.

    Args:
        image: Input image. PIL images in a mode other than RGB are converted
            to RGB first; other inputs are passed to the transforms unchanged.
        model: Callable returning a one-channel probability map for a
            ``(1, C, H, W)`` tensor.
        threshold: Values strictly greater than this become 1.

    Returns:
        A ``uint8`` array of 0/1 values at ``INPUT_SIZE`` resolution, or
        ``None`` if prediction failed (the error is shown in the app).
    """
    try:
        # RGBA / grayscale / palette uploads would otherwise produce a tensor
        # with the wrong channel count.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        img_tensor = _PREPROCESS(image).unsqueeze(0)
        with torch.no_grad():
            pred = model(img_tensor)

        pred_mask = pred.squeeze().numpy()
        return (pred_mask > threshold).astype(np.uint8)
    except Exception as e:
        st.error(f"❌ Prediction failed: {e}")
        return None


# Overlay mask on image
def overlay_mask(image, mask, color=(0, 0, 255), alpha=0.4):
    """Tint the masked pixels of ``image`` with ``color``.

    Pixels where the (resized) mask equals 1 become
    ``(1 - alpha) * pixel + alpha * color``; all other pixels are returned
    unchanged.

    Args:
        image: PIL image; converted to RGB before blending.
        mask: 2-D ``uint8`` array of 0/1 values; resized to the image size.
        color: Tint as a 3-tuple in the image's channel order (RGB).
        alpha: Weight of the tint inside the mask, from 0 to 1.

    Returns:
        The blended image as an array; on failure the error is shown in the
        app and the unmodified image array is returned.
    """
    try:
        image_np = np.array(image.convert("RGB"))
        height, width = image_np.shape[:2]
        mask_resized = cv2.resize(mask, (width, height))  # dsize is (w, h)

        color_layer = np.zeros_like(image_np)
        color_layer[:, :] = color

        # Blend against the full colour layer, then keep the blended pixels
        # only inside the mask. Blending against a layer that is zero outside
        # the mask would darken the whole background by a factor of 1 - alpha.
        tinted = cv2.addWeighted(image_np, 1 - alpha, color_layer, alpha, 0)
        inside = (mask_resized == 1)[..., None]
        return np.where(inside, tinted, image_np)
    except Exception as e:
        st.error(f"❌ Mask overlay failed: {e}")
        return np.array(image)


# Streamlit UI
st.set_page_config(page_title="Tooth Segmentation", layout="wide")
st.title("🦷 Tooth Segmentation from Mouth Images")
st.markdown("Upload a **face or mouth image**, and this app will overlay the **predicted tooth segmentation mask**.")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
        model = load_model()
        pred_mask = predict_mask(image, model)

        # predict_mask has already reported the error when it returns None.
        if pred_mask is not None:
            overlayed_img = overlay_mask(image, pred_mask, color=(0, 0, 255), alpha=0.4)

            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption="Original Image", use_container_width=True)
            with col2:
                st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
    except Exception as e:
        st.error(f"❌ Error processing image: {e}")