"""Streamlit demo: segment teeth in panoramic dental X-rays with a U-Net.

The pre-trained Keras model is loaded from the Hugging Face Hub. The user
uploads an image (or picks one of the example images), the model predicts a
mask, and the mask's contours are drawn in red over the original image.
"""
import streamlit as st
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
from huggingface_hub import from_pretrained_keras

MODEL_REPO_ID = "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"
MODEL_INPUT_SIZE = 512  # the model is fed (1, 512, 512, 1) float32 batches
EXAMPLES = ["107.png", "108.png", "109.png"]


# --- Model loading ---

@st.cache_resource
def _fetch_model():
    """Load the Keras model from the Hugging Face Hub; cached by Streamlit.

    Exceptions are deliberately allowed to propagate: a raised exception is
    not stored in the cache, so a failed attempt is retried on the next rerun
    instead of permanently caching a ``None`` result.
    """
    return from_pretrained_keras(MODEL_REPO_ID)


def load_keras_model():
    """Return the cached pre-trained Keras model, or None if loading fails.

    On failure the error is reported in the app with ``st.error``.
    """
    try:
        return _fetch_model()
    except Exception as e:  # top-level UI boundary: report, let the caller stop
        st.error(f"Error loading the model: {e}")
        return None


# --- Helper Functions ---

def load_image(image_file):
    """Open an image from a file path or an uploaded file object (PIL image)."""
    return Image.open(image_file)


def convert_one_channel(img_array):
    """Return ``img_array`` as a 2-D single-channel (grayscale) array.

    Accepts a 2-D array (returned unchanged) or an H x W x C array with
    C = 1 (squeezed), 2 (gray + alpha; alpha dropped), 3 (RGB) or 4 (RGBA).
    Colour input is treated as RGB order, which is what ``np.array`` of a
    PIL image produces.

    Raises:
        ValueError: for any other channel count.
    """
    if img_array.ndim == 2:
        return img_array
    channels = img_array.shape[2]
    if channels in (1, 2):
        return np.ascontiguousarray(img_array[:, :, 0])
    if channels == 3:
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    if channels == 4:
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
    raise ValueError(f"Unsupported number of channels: {channels}")


def convert_rgb(img_array):
    """Return ``img_array`` as an H x W x 3 RGB array for coloured drawing.

    2-D / single-channel / gray+alpha input is expanded to RGB. RGBA input
    has its alpha channel dropped. 3-channel input is returned as-is
    (not copied).
    """
    if img_array.ndim == 2:
        return cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
    channels = img_array.shape[2]
    if channels in (1, 2):
        gray = np.ascontiguousarray(img_array[:, :, 0])
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    if channels == 4:
        # Drawing colour (255, 0, 0) on a 4-channel image writes alpha 0,
        # which would make the contours invisible; drop alpha first.
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
    return img_array


def _preprocess(img_array, size=MODEL_INPUT_SIZE):
    """Turn a uint8 image array into a (1, size, size, 1) float32 batch in [0, 1]."""
    gray = convert_one_channel(img_array)
    resized = cv2.resize(gray, (size, size), interpolation=cv2.INTER_LANCZOS4)
    normalized = np.float32(resized / 255.0)
    return np.reshape(normalized, (1, size, size, 1))


def _postprocess_mask(prediction, height, width):
    """Resize the first predicted mask to (height, width) and binarize it (0/255 uint8)."""
    # Assumes a single-channel output map per image -- squeeze drops the
    # trailing channel axis. NOTE(review): confirm the model's output shape.
    mask = np.squeeze(prediction[0])
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LANCZOS4)
    # Lanczos ringing can push values slightly outside [0, 1]; without the clip
    # the uint8 cast below turns those into garbage (e.g. bright pixels).
    mask_8bit = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

    # Otsu picks the binarization threshold automatically.
    _, binary = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Opening removes small specks; closing fills small holes.
    kernel = np.ones((5, 5), dtype=np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)
    return binary


def _draw_segmentation(img_array, mask):
    """Draw the contours of ``mask`` in red (thickness 3) on an RGB copy of ``img_array``."""
    canvas = convert_rgb(img_array.copy())
    # [-2] selects the contour list under both OpenCV 3.x (3 return values)
    # and OpenCV 4.x (2 return values).
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return cv2.drawContours(canvas, contours, -1, (255, 0, 0), 3)


# --- Streamlit App Layout ---

st.header("Segmentation of Teeth in Panoramic X-ray Image Using UNet")
link = 'Check Out Our Github Repo! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
st.markdown(link)

# Load the model and stop the app if it fails.
model = load_keras_model()
if model is None:
    st.warning("Model could not be loaded. The application cannot proceed.")
    st.stop()

# --- Image Selection Section ---

st.subheader("Upload a Dental Panoramic X-ray Image or Select an Example")
image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

st.write("---")
st.write("Or choose an example:")

# One column per example. A button press only returns True on the rerun it
# triggers, so a clicked example overrides the upload for that run.
columns = st.columns(len(EXAMPLES))
for number, (column, example_path) in enumerate(zip(columns, EXAMPLES), start=1):
    with column:
        st.image(example_path, caption=f'Example {number}', use_column_width=True)
        if st.button(f'Use Example {number}'):
            image_file = example_path

# --- Processing and Prediction Section ---

if image_file is not None:
    st.write("---")

    original_pil_img = load_image(image_file)
    st.image(original_pil_img, caption="Original Image", use_column_width=True)

    with st.spinner("Analyzing image and predicting segmentation..."):
        # np.array() of a palette ("P") image yields palette indices, and
        # alpha modes ("LA", "RGBA") add channels; normalize to plain RGB.
        if original_pil_img.mode not in ("L", "RGB"):
            original_pil_img = original_pil_img.convert("RGB")
        original_np_img = np.array(original_pil_img)

        # 1. Pre-process for the model.
        img_input = _preprocess(original_np_img)

        # 2. Make prediction.
        prediction = model.predict(img_input)

        # 3. Post-process the prediction mask back to the original size.
        height, width = original_np_img.shape[:2]
        final_mask = _postprocess_mask(prediction, height, width)

        # 4. Draw red contours on a colour version of the original image.
        output_image = _draw_segmentation(original_np_img, final_mask)

    st.subheader("Predicted Segmentation")
    st.image(output_image, caption="Image with Segmented Teeth", use_column_width=True)
    st.success("Prediction complete!")