Dddixyy's picture
Update app.py
a3149ce verified
raw
history blame
4.91 kB
import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import from_pretrained_keras
# Use st.cache_resource to load the model only once, preventing memory errors.
# Only successful loads are cached: st.cache_resource memoizes return values, so
# the cached loader raises on failure (exceptions are not cached) and the public
# wrapper below converts the failure into None. A transient download failure is
# then retried on the next rerun instead of being cached as None.
@st.cache_resource
def _load_keras_model_cached():
    """Load the pre-trained Keras model from Hugging Face Hub; raise on failure."""
    # The model will be downloaded from the Hub and cached.
    return from_pretrained_keras("SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net")


def load_keras_model():
    """Return the pre-trained Keras segmentation model, or None if loading fails.

    The actual load is memoized by ``_load_keras_model_cached``. On failure an
    error message is shown with ``st.error`` and None is returned; the failure
    itself is not cached, so a later call tries again.
    """
    try:
        return _load_keras_model_cached()
    except Exception as e:  # UI boundary: surface any loader failure to the user
        st.error(f"Error loading the model: {e}")
        return None
# --- Helper Functions ---
def load_image(image_file):
    """Open an image with Pillow.

    Args:
        image_file: a filesystem path (the bundled example files) or a
            file-like object such as the one returned by ``st.file_uploader``.

    Returns:
        The image object produced by ``PIL.Image.open``.
    """
    return Image.open(image_file)
def convert_one_channel(img_array):
    """Return *img_array* as a single-channel (grayscale) 2-D array.

    Handling by shape:
      * ``(H, W)`` (or any non-3-D array) is returned unchanged.
      * ``(H, W, 1)`` has its singleton channel axis dropped.
      * ``(H, W, 2)`` is treated as gray + alpha; the gray (first) channel is kept.
        OpenCV's BGR->GRAY conversion rejects 2-channel input, so it is not used here.
      * ``(H, W, 3)`` is converted with ``cv2.COLOR_BGR2GRAY``.
      * ``(H, W, 4)`` is converted with ``cv2.COLOR_BGRA2GRAY`` (alpha ignored).

    Raises:
        ValueError: if a 3-D array has more than 4 channels.
    """
    if img_array.ndim != 3:
        return img_array
    channels = img_array.shape[2]
    if channels in (1, 2):
        # Keep only the intensity channel; make it contiguous for OpenCV consumers.
        return np.ascontiguousarray(img_array[:, :, 0])
    if channels == 3:
        return cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
    if channels == 4:
        return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY)
    raise ValueError(f"Unsupported number of channels for grayscale conversion: {channels}")
def convert_rgb(img_array):
    """Return *img_array* as a 3-channel (RGB) array for drawing colored contours.

    Handling by shape:
      * ``(H, W)`` and ``(H, W, 1)`` gray images are replicated into 3 equal channels.
      * ``(H, W, 2)`` gray+alpha images keep the gray channel, replicated to 3.
      * ``(H, W, 4)`` RGBA images drop the alpha channel. Otherwise a color such as
        ``(255, 0, 0)`` drawn with OpenCV would get alpha 0 and be fully transparent.
      * ``(H, W, 3)`` (and any other shape) is returned unchanged, as the same object.

    Converted results are new C-contiguous arrays of the input dtype, so OpenCV
    drawing functions can write into them in place.
    """
    if img_array.ndim == 2:
        img_array = img_array[:, :, np.newaxis]
    if img_array.ndim != 3:
        return img_array
    channels = img_array.shape[2]
    if channels in (1, 2):
        # Same result as cv2.COLOR_GRAY2RGB: copy the gray channel into R, G and B.
        return np.repeat(img_array[:, :, :1], 3, axis=2)
    if channels == 4:
        return np.ascontiguousarray(img_array[:, :, :3])
    return img_array
# --- Streamlit App Layout ---
st.header("Segmentation of Teeth in Panoramic X-ray Image Using UNet")
# Plain Markdown link: no raw HTML is involved, so unsafe_allow_html is not
# needed (Markdown links render by default) and is left off.
link = 'Check Out Our Github Repo! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link)

# Load the model; without it nothing below can run, so halt this script run.
model = load_keras_model()
if model is None:
    st.warning("Model could not be loaded. The application cannot proceed.")
    st.stop()
# --- Image Selection Section ---
st.subheader("Upload a Dental Panoramic X-ray Image or Select an Example")
image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

st.write("---")
st.write("Or choose an example:")

examples = ["107.png", "108.png", "109.png"]
col1, col2, col3 = st.columns(3)

# One column per example: a preview image plus a button. When a button was
# clicked, that example's path replaces the uploaded file for this script run.
for example_number, (example_column, example_path) in enumerate(zip((col1, col2, col3), examples), start=1):
    with example_column:
        st.image(example_path, caption=f"Example {example_number}", use_column_width=True)
        if st.button(f"Use Example {example_number}"):
            image_file = example_path
# --- Processing and Prediction Section ---
if image_file is not None:
    st.write("---")

    # Load the selected image.
    original_pil_img = load_image(image_file)
    # Normalise pixel formats the pipeline below cannot handle directly:
    # - palette ("P") arrays hold palette indices, not intensities;
    # - bilevel ("1") arrays are boolean, which cv2.resize does not accept;
    # - "LA" yields 2 channels, which the BGR->GRAY conversion rejects;
    # - CMYK channels would be misread as BGR(A);
    # - on RGBA the red contour color (255, 0, 0) would be drawn with alpha 0.
    if original_pil_img.mode in ("1", "P", "LA", "RGBA", "CMYK"):
        original_pil_img = original_pil_img.convert("RGB")
    st.image(original_pil_img, caption="Original Image", use_column_width=True)

    with st.spinner("Analyzing image and predicting segmentation..."):
        # Convert PIL image to NumPy array for processing
        original_np_img = np.array(original_pil_img)
        original_height, original_width = original_np_img.shape[:2]

        # 1. Pre-process for the model: grayscale, resize to 512x512, scale by 1/255.
        # NOTE(review): PIL arrays are RGB while convert_one_channel applies BGR
        # weights; identical for gray-valued pixels, slightly off for color input.
        img_gray = convert_one_channel(original_np_img.copy())
        img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        img_normalized = np.float32(img_resized / 255.0)
        img_input = np.reshape(img_normalized, (1, 512, 512, 1))

        # 2. Make prediction
        prediction = model.predict(img_input)

        # 3. Post-process the prediction mask back to the original resolution.
        predicted_mask = prediction[0]
        resized_mask = cv2.resize(
            predicted_mask,
            (original_width, original_height),
            interpolation=cv2.INTER_LANCZOS4,
        )
        # Lanczos interpolation over/undershoots around sharp mask edges. Clip to
        # [0, 1] before scaling: out-of-range floats do not saturate when cast to
        # uint8 (negatives typically wrap to large values), which would create
        # spurious foreground pixels before thresholding.
        mask_8bit = (np.clip(resized_mask, 0.0, 1.0) * 255).astype(np.uint8)
        # Binarize the mask using Otsu's thresholding
        _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Clean up mask with morphological operations (opening, then closing)
        kernel = np.ones((5, 5), dtype=np.uint8)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=1)
        final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=1)

        # Find contours on the final mask
        contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        # Draw contours on a color version of the original image
        img_for_drawing = convert_rgb(original_np_img.copy())
        output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3)  # Draw red contours

    st.subheader("Predicted Segmentation")
    st.image(output_image, caption="Image with Segmented Teeth", use_column_width=True)
    st.success("Prediction complete!")