File size: 4,913 Bytes
a3149ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import from_pretrained_keras

# Cache the loaded model with st.cache_resource so it is created once and
# reused across reruns instead of being reloaded every time.
@st.cache_resource
def load_keras_model():
    """Fetch the pre-trained U-Net from the Hugging Face Hub.

    Returns:
        The Keras model on success, or ``None`` after reporting the failure
        in the UI via ``st.error``.

    NOTE(review): st.cache_resource presumably caches the ``None`` result as
    well, so a transient download failure may persist until the cache is
    cleared — confirm.
    """
    repo_id = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"
    try:
        keras_model = from_pretrained_keras(repo_id)
    except Exception as err:
        # Surface the problem to the user; the caller stops the app on None.
        st.error(f"Error loading the model: {err}")
        return None
    return keras_model

# --- Helper Functions ---
def load_image(image_file):
    """Open an image from a file path or an uploaded file object.

    Palette-based images (PIL modes "P" and "PA") are expanded to real colour
    values. Without this, ``np.array(img)`` on such an image returns palette
    *indices*, which downstream code would treat as grey levels.

    Args:
        image_file: A filesystem path or a binary file-like object (for
            example a Streamlit upload) accepted by ``PIL.Image.open``.

    Returns:
        A ``PIL.Image.Image``; palette images are converted to "RGB"
        ("P") or "RGBA" ("PA"), other modes are returned as opened.
    """
    img = Image.open(image_file)
    if img.mode == "P":
        img = img.convert("RGB")
    elif img.mode == "PA":
        img = img.convert("RGBA")
    return img

def convert_one_channel(img_array):
    """Return a single-channel (grayscale) version of an image array.

    In this app the array comes from ``np.array(PIL_image)``, so colour input
    is in RGB(A) channel order rather than OpenCV's default BGR order.

    Args:
        img_array: NumPy image array. Arrays with fewer than three dimensions
            are treated as already single-channel. A trailing channel axis of
            size 1 (grey), 2 (grey + alpha), 3 (RGB) or 4 (RGBA) is supported.

    Returns:
        A 2-D ``(H, W)`` array for 3-D input; lower-dimensional input is
        returned unchanged (same object).

    Raises:
        ValueError: If the channel axis has more than four entries.
    """
    if img_array.ndim < 3:
        return img_array

    channels = img_array.shape[2]
    if channels in (1, 2):
        # Grey (optionally with alpha): intensity is already channel 0.
        # Make it contiguous so OpenCV can consume it without surprises.
        return np.ascontiguousarray(img_array[:, :, 0])
    if channels == 3:
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    if channels == 4:
        # Alpha is ignored; only the colour channels contribute to grey.
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
    raise ValueError(f"Unsupported number of channels: {channels}")

def convert_rgb(img_array):
    """Return a 3-channel RGB array suitable for drawing coloured contours.

    Args:
        img_array: NumPy image array, either 2-D ``(H, W)`` grey or 3-D
            ``(H, W, C)`` with C of 1 (grey), 2 (grey + alpha), 3 (RGB) or
            4 (RGBA).

    Returns:
        A C-contiguous ``(H, W, 3)`` array for grey, grey + alpha and RGBA
        input (OpenCV drawing functions need a contiguous output buffer).
        3-channel input, and any shape not listed above, is returned
        unchanged (same object).
    """
    if img_array.ndim == 2:
        gray = img_array
    elif img_array.ndim == 3 and img_array.shape[2] in (1, 2):
        # Grey (optionally with alpha): intensity lives in channel 0.
        gray = img_array[:, :, 0]
    elif img_array.ndim == 3 and img_array.shape[2] == 4:
        # Drop alpha: drawing a 3-tuple colour on a 4-channel image sets the
        # alpha of contour pixels to 0, making the contours transparent.
        return np.ascontiguousarray(img_array[:, :, :3])
    else:
        return img_array
    # Replicate the grey plane into R, G and B (same result as GRAY2RGB).
    return np.stack((gray, gray, gray), axis=-1)

# --- Streamlit App Layout ---
st.header("Segmentation of Teeth in Panoramic X-ray Image Using UNet")

link = 'Check Out Our Github Repo! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link, unsafe_allow_html=True)

# Load the model and stop the app if it fails
model = load_keras_model()
if model is None:
    st.warning("Model could not be loaded. The application cannot proceed.")
    st.stop()

# --- Image Selection Section ---
st.subheader("Upload a Dental Panoramic X-ray Image or Select an Example")
image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

st.write("---")
st.write("Or choose an example:")
examples = ["107.png", "108.png", "109.png"]

# One column per example: a preview plus a button that selects it. A clicked
# example overrides the uploaded file because it is assigned afterwards.
example_columns = st.columns(len(examples))
for number, (column, example_path) in enumerate(zip(example_columns, examples), start=1):
    with column:
        st.image(example_path, caption=f'Example {number}', use_column_width=True)
        if st.button(f'Use Example {number}'):
            image_file = example_path

# --- Processing and Prediction Section ---
if image_file is not None:
    st.write("---")

    # Load and display the selected image
    original_pil_img = load_image(image_file)
    st.image(original_pil_img, caption="Original Image", use_column_width=True)

    with st.spinner("Analyzing image and predicting segmentation..."):
        # Convert PIL image to NumPy array for processing
        original_np_img = np.array(original_pil_img)
        height, width = original_np_img.shape[:2]

        # 1. Pre-process for the model: grey, 512x512, scaled to [0, 1].
        # NOTE(review): dividing by 255 assumes 8-bit pixel data; 16-bit
        # images would need a different scale factor.
        img_gray = convert_one_channel(original_np_img.copy())
        img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        img_normalized = np.float32(img_resized / 255.0)
        img_input = np.reshape(img_normalized, (1, 512, 512, 1))

        # 2. Make prediction
        prediction = model.predict(img_input)

        # 3. Post-process the prediction mask back to the original resolution.
        predicted_mask = prediction[0]
        resized_mask = cv2.resize(predicted_mask, (width, height), interpolation=cv2.INTER_LANCZOS4)

        # Lanczos interpolation rings, so the resized mask can leave [0, 1]
        # (assuming the model itself outputs values in [0, 1]). Clip before
        # scaling: out-of-range floats do not saturate when cast to uint8,
        # which would flip confident foreground/background pixels.
        mask_8bit = (np.clip(resized_mask, 0.0, 1.0) * 255).astype(np.uint8)

        # Binarize the mask using Otsu's thresholding
        _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Clean up mask with morphological operations (opening removes specks,
        # closing fills small holes)
        kernel = np.ones((5, 5), dtype=np.uint8)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=1)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=1)

        # Find contours on the final mask
        contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # Draw contours on a color version of the original image
        img_for_drawing = convert_rgb(original_np_img.copy())
        output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3)  # Draw red contours

        st.subheader("Predicted Segmentation")
        st.image(output_image, caption="Image with Segmented Teeth", use_column_width=True)

    st.success("Prediction complete!")