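# Streamlit demo: tooth segmentation on mouth/face photos using a MobileViT
# encoder with a lightweight convolutional decoder.
# Run locally with: streamlit run app.py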
import os

# Optional: disable Streamlit's file watcher (useful on HF Spaces) to avoid
# torch-related watcher warnings; set before the streamlit import so the
# setting is in place as early as possible.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"

import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import numpy as np
import cv2
from PIL import Image
import warnings

warnings.filterwarnings("ignore")

# Define the model class: a MobileViT encoder (from timm, features_only mode)
# with a small convolutional decoder that turns the deepest feature map into
# a single-channel probability mask.
class MobileViTSegmentation(nn.Module):
    def __init__(self, encoder_name='mobilevit_s', pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
        self.encoder_channels = self.backbone.feature_info.channels()
        # Three conv + 2x upsample stages (8x total), then a 1x1 conv and a
        # sigmoid for per-pixel foreground probability. align_corners=False
        # matches the bilinear default; making it explicit avoids the warning.
        self.decoder = nn.Sequential(
            nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feats = self.backbone(x)
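        # Decode only the deepest (lowest-resolution) feature map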
        out = self.decoder(feats[-1])
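        # Resize the coarse mask back to the input's spatial size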
        out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
        return out

# Load model function with spinner and error handling
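# @st.cache_resource loads the weights once and reuses them across reruns.
# Expects mobilevit_teeth_segmentation.pth in the working directory.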
@st.cache_resource
def load_model():
    try:
        with st.spinner("Loading model..."):
            model = MobileViTSegmentation()
            model.load_state_dict(torch.load("mobilevit_teeth_segmentation.pth", map_location="cpu"))
            model.eval()
            return model
    except Exception as e:
        st.error(f"❌ Failed to load model: {e}")
        st.stop()

# Inference function
def predict_mask(image, model, threshold=0.7):
    try:
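        # Preprocess to the training resolution; ToTensor scales pixels to [0, 1]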
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
        img_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            pred = model(img_tensor)
        pred_mask = pred.squeeze().numpy()
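        # Binarize the probability map at the given threshold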
        pred_mask = (pred_mask > threshold).astype(np.uint8)
        return pred_mask
    except Exception as e:
        st.error(f"❌ Prediction failed: {e}")
        return None

# Overlay the binary mask on the image as a translucent color, leaving
# unmasked regions at their original brightness.
def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.4):  # red in RGB order (arrays come from PIL, not OpenCV's BGR)
    try:
        image_np = np.array(image.convert("RGB"))
        # Nearest-neighbor keeps the mask binary; bilinear resizing would
        # produce in-between values that the == 1 test below misses.
        mask_resized = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
        color_mask = np.zeros_like(image_np)
        color_mask[:, :] = color
        # Blend toward the solid color only where the mask is set; elsewhere
        # the image is blended with itself, leaving those pixels unchanged.
        overlay = np.where(mask_resized[..., None] == 1, color_mask, image_np)
        blended = cv2.addWeighted(image_np, 1 - alpha, overlay, alpha, 0)
        return blended
    except Exception as e:
        st.error(f"❌ Mask overlay failed: {e}")
        return np.array(image)

# Streamlit UI
st.set_page_config(page_title="Tooth Segmentation", layout="wide")
st.title("🦷 Tooth Segmentation from Mouth Images")
st.markdown("Upload a **face or mouth image**, and this app will overlay the **predicted tooth segmentation mask**.")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
        model = load_model()
        pred_mask = predict_mask(image, model)

        if pred_mask is not None:
            overlayed_img = overlay_mask(image, pred_mask, color=(255, 0, 0), alpha=0.4)  # red in RGB order

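            # Show the original image and the overlay side by side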
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption="Original Image", use_container_width=True)
            with col2:
                st.image(overlayed_img, caption="Tooth Mask Overlay", use_container_width=True)
    except Exception as e:
        st.error(f"❌ Error processing image: {e}")