# -*- coding: utf-8 -*-
"""
Smoking Detection System — Streamlit application.

Created on Tue May 20 11:00:14 2025

@author: ColinWang

Classifies smoking behavior, gender, and age from a single photo or a live
webcam stream (via streamlit-webrtc), and plays a matching pre-recorded
audio alert when smoking is detected.
"""
import base64
import os
import time
import uuid
from collections import Counter

import cv2
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
from transformers import pipeline
from twilio.rest import Client


# ======================
# Model Loading Functions
# ======================
@st.cache_resource
def load_smoke_pipeline():
    """Initialize and cache the smoking image classification pipeline."""
    return pipeline(
        "image-classification",
        model="ccclllwww/smoker_cls_base_V9",
        use_fast=True,
    )


@st.cache_resource
def load_gender_pipeline():
    """Initialize and cache the gender image classification pipeline."""
    return pipeline(
        "image-classification",
        model="rizvandwiki/gender-classification-2",
        use_fast=True,
    )


@st.cache_resource
def load_age_pipeline():
    """Initialize and cache the age image classification pipeline."""
    return pipeline(
        "image-classification",
        model="cledoux42/Age_Classify_v001",
        use_fast=True,
    )


# Preload all models at startup; st.cache_resource shares one instance
# across sessions and reruns, so this only pays the load cost once.
smoke_pipeline = load_smoke_pipeline()
gender_pipeline = load_gender_pipeline()
age_pipeline = load_age_pipeline()


# ======================
# Twilio Configuration
# ======================
def initialize_twilio_client():
    """Create a Twilio token whose ``ice_servers`` configure WebRTC TURN/STUN.

    Credentials are read from the TWILIO_ACCOUNT_SID / TWILIO_AUTH_TOKEN
    environment variables; the app halts with an error if either is missing.

    Returns:
        The Twilio token object (its ``ice_servers`` attribute is used by
        ``webrtc_streamer``).
    """
    account_sid = os.environ.get('TWILIO_ACCOUNT_SID')
    auth_token = os.environ.get('TWILIO_AUTH_TOKEN')
    if not account_sid or not auth_token:
        st.error("Twilio credentials not found in environment variables.")
        st.stop()
    client = Client(account_sid, auth_token)
    return client.tokens.create()


token = initialize_twilio_client()


# ======================
# Audio Loading Function
# ======================
@st.cache_resource
def load_audio_files():
    """Load all .wav files from the audio directory into a dictionary.

    Returns:
        dict mapping file stem (e.g. ``"10-19 male"``) to raw WAV bytes.
        Keys must match the ``"<age> <gender>"`` lookup built by the
        detection pages.
    """
    audio_dir = "audio"
    if not os.path.exists(audio_dir):
        st.error(f"Audio directory '{audio_dir}' not found.")
        st.stop()
    audio_files = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
    audio_dict = {}
    for audio_file in audio_files:
        with open(os.path.join(audio_dir, audio_file), "rb") as file:
            audio_dict[os.path.splitext(audio_file)[0]] = file.read()
    return audio_dict


# Load audio files at startup
audio_data = load_audio_files()


# ======================
# Image Processing Functions
# ======================
def _run_pipeline(pipe, image: Image.Image):
    """Run *pipe* on *image*, halting the app with an error on failure.

    Shared error-handling wrapper for all detect/classify helpers below.
    """
    try:
        return pipe(image)
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()


def detect_smoking(image: Image.Image) -> str:
    """Classify an image for smoking activity; returns the top label."""
    return _run_pipeline(smoke_pipeline, image)[0]["label"]


def detect_gender(image: Image.Image) -> str:
    """Classify an image for gender; returns the top label."""
    return _run_pipeline(gender_pipeline, image)[0]["label"]


def detect_age(image: Image.Image) -> str:
    """Classify an image for age range; returns the top label."""
    return _run_pipeline(age_pipeline, image)[0]["label"]


# ======================
# Real-Time Classification Functions
# ======================
# These are cached (max_entries bounds memory) because the real-time page
# re-classifies the same snapshot list on every polling iteration.
@st.cache_data(show_spinner=False, max_entries=3)
def classify_smoking(image: Image.Image) -> str:
    """Classify an image for smoking and return the label with highest confidence."""
    output = _run_pipeline(smoke_pipeline, image)
    return max(output, key=lambda x: x["score"])["label"]


@st.cache_data(show_spinner=False, max_entries=3)
def classify_gender(image: Image.Image) -> str:
    """Classify an image for gender and return the label with highest confidence."""
    output = _run_pipeline(gender_pipeline, image)
    return max(output, key=lambda x: x["score"])["label"]


@st.cache_data(show_spinner=False, max_entries=3)
def classify_age(image: Image.Image) -> str:
    """Classify an image for age range and return the label with highest confidence."""
    output = _run_pipeline(age_pipeline, image)
    return max(output, key=lambda x: x["score"])["label"]


# ======================
# Audio Playback Function
# ======================
def play_audio(audio_bytes: bytes):
    """Play audio in the browser via an autoplaying HTML <audio> element.

    The WAV bytes are embedded as a Base64 data URI. A unique element id
    (uuid) is used so successive alerts render as distinct components and
    are not de-duplicated by Streamlit.

    NOTE(review): browsers may block autoplay until the user has interacted
    with the page — confirm behavior in the target deployment.
    """
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    audio_id = f"audio_player_{uuid.uuid4()}"
    html_content = f"""
        <audio id="{audio_id}" autoplay>
            <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
        </audio>
        <script>
            document.getElementById("{audio_id}").play();
        </script>
    """
    st.components.v1.html(html_content, height=150)


# ======================
# Video Transformer Class
# ======================
class VideoTransformer(VideoTransformerBase):
    """Captures periodic PIL snapshots from the live video stream.

    Runs on streamlit-webrtc's worker thread, so it must not call
    Streamlit UI functions — the main script polls ``snapshots`` instead.
    """

    def __init__(self):
        self.snapshots = []                    # captured PIL images (RGB)
        self.last_capture_time = time.time()   # timestamp of last capture
        self.capture_interval = 1              # Capture every 1 second
        self.max_snapshots = 5                 # stop after this many

    def transform(self, frame):
        """Process video frame and capture snapshots at the set interval."""
        img = frame.to_ndarray(format="bgr24")
        current_time = time.time()
        if (current_time - self.last_capture_time >= self.capture_interval
                and len(self.snapshots) < self.max_snapshots):
            # Frames arrive BGR from OpenCV/webrtc; convert for PIL/models.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.snapshots.append(Image.fromarray(img_rgb))
            self.last_capture_time = current_time
        return img


# ======================
# Cover Page
# ======================
def cover_page():
    """Display an enhanced cover page with project overview and instructions."""
    st.title("Smoking Detection System", anchor=False)
    st.markdown("### Welcome to the Smoking Detection System")
    st.markdown("""
    This Streamlit-based application harnesses cutting-edge machine learning to detect smoking behavior in images and real-time video streams. By analyzing smoking activity, gender, and age demographics, it provides valuable insights for public health monitoring and policy enforcement.
    """)
    st.markdown("#### Project Overview")
    st.markdown("""
    - **Purpose**: Automatically identify smoking behavior in public or controlled environments to support compliance with no-smoking policies and facilitate behavioral studies.
    - **Significance**: Enhances public health initiatives by enabling real-time monitoring and demographic analysis of smoking activities.
    - **Features**:
      - **Photo Detection**: Analyze a single image (uploaded or captured) for smoking, gender, and age.
      - **Real-Time Video Detection**: Process webcam streams, capturing snapshots to detect smoking and demographics.
      - **Audio Feedback**: Play alerts based on detected gender and age when smoking is confirmed.
    """)
    st.markdown("#### How to Use")
    st.markdown("""
    1. **Navigate**: Use the sidebar to select a page:
       - **Cover Page**: View this overview.
       - **Photo Detection**: Upload or capture an image for analysis.
       - **Real-Time Video Detection**: Monitor live webcam feed.
    2. **Photo Detection**:
       - Upload an image or capture one via webcam.
       - The system detects smoking; if detected, it analyzes gender and age, playing a corresponding audio alert.
    3. **Real-Time Video Detection**:
       - Captures 5 snapshots over one minute.
       - If smoking is detected in more than 2 snapshots, it analyzes gender and age, displays results in a table, and plays an audio alert.
    4. **Setup Requirements**:
       - Ensure the 'audio' directory contains .wav files named as '[age range] [gender].wav' (e.g., '10-19 male.wav').
       - Configure Twilio environment variables (`TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN`) for WebRTC functionality.
    """)
    st.markdown("#### Get Started")
    st.markdown("Select a page from the sidebar to begin analyzing images or video streams.")


# ======================
# Photo Detection Page
# ======================
def photo_detection_page():
    """Handle photo detection page for smoking, gender, and age classification."""
    audio_placeholder = st.empty()
    st.title("Photo Detection", anchor=False)
    st.markdown("Upload an image or capture a photo to detect smoking behavior. If smoking is detected, gender and age will be analyzed.")

    # Image input selection
    option = st.radio("Choose input method", ["Upload Image", "Capture with Camera"], horizontal=True)
    image = None
    if option == "Upload Image":
        uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)
    else:
        enable = st.checkbox("Enable Camera")
        camera_file = st.camera_input("Capture Photo", disabled=not enable)
        if camera_file:
            image = Image.open(camera_file)
            st.image(image, caption="Captured Photo", use_container_width=True)

    if image:
        with st.spinner("Detecting smoking..."):
            smoke_result = detect_smoking(image)
        st.success(f"Smoking Status: {smoke_result}")
        # Only analyze demographics (and alert) when smoking is detected.
        if smoke_result.lower() == "smoking":
            with st.spinner("Detecting gender..."):
                gender_result = detect_gender(image)
            st.success(f"Gender: {gender_result}")
            with st.spinner("Detecting age..."):
                age_result = detect_age(image)
            st.success(f"Age Range: {age_result}")
            audio_placeholder.empty()
            # Audio filenames follow the "<age> <gender>" convention.
            audio_key = f"{age_result} {gender_result.lower()}"
            if audio_key in audio_data:
                play_audio(audio_data[audio_key])
            else:
                st.error(f"Audio file not found: {audio_key}.wav")


# ======================
# Real-Time Detection Page
# ======================
def _detection_results_dataframe() -> pd.DataFrame:
    """Build the display table from accumulated session detection results."""
    return pd.DataFrame([
        {
            "Timestamp": result["Timestamp"],
            "Gender": result["Gender"],
            "Age Range": result["Age Range"],
            "Smoking Count": result["Smoking Count"],
        }
        for result in st.session_state.detection_results
    ])


def real_time_detection_page():
    """Handle real-time video detection with snapshot capture and analysis.

    Polls the VideoTransformer for snapshots; once 5 are captured, it
    classifies them, reports results, then resets for the next cycle.
    """
    st.title("Real-Time Video Detection", anchor=False)
    st.markdown("Captures 5 snapshots over one minute to detect smoking. If smoking is detected in more than 2 snapshots, results include gender, age, and a snapshot in a table.")

    # Initialize session state for detection results
    if 'detection_results' not in st.session_state:
        st.session_state.detection_results = []

    # Placeholders for UI elements
    capture_text = st.empty()
    capture_progress = st.empty()
    classification_text = st.empty()
    classification_progress = st.empty()
    detection_info = st.empty()
    status_alert = st.empty()  # New placeholder for status alert
    table = st.empty()
    image_display = st.empty()
    audio = st.empty()

    # Start video stream
    ctx = webrtc_streamer(
        key="unique_example",
        video_transformer_factory=VideoTransformer,
        rtc_configuration={"iceServers": token.ice_servers}
    )

    capture_target = 5
    if ctx.video_transformer:
        detection_info.info("Starting detection...")
        # Deliberate polling loop: keeps this script run alive while the
        # worker thread accumulates snapshots, cycling indefinitely.
        while True:
            snapshots = ctx.video_transformer.snapshots
            if len(snapshots) < capture_target:
                capture_text.text(f"Capture Progress: {len(snapshots)}/{capture_target} snapshots")
                capture_progress.progress(int(len(snapshots) / capture_target * 100))
            else:
                capture_text.text("Capture Progress: Completed!")
                capture_progress.empty()
                detection_info.empty()
                classification_text.text("Classification Progress: Analyzing...")
                classification = classification_progress.progress(0)

                # Classify snapshots
                smoke_results = [classify_smoking(img) for img in snapshots]
                smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
                classification.progress(33)

                # Majority-style threshold: more than 2 of 5 snapshots.
                if smoking_count > 2:
                    status_alert.error("Smoking Detected!")  # Red alert for smoking
                    gender_results = [classify_gender(img) for img in snapshots]
                    classification.progress(66)
                    age_results = [classify_age(img) for img in snapshots]
                    classification.progress(100)
                    classification_text.text("Classification Progress: Completed!")

                    # Determine most common gender and age
                    most_common_gender = Counter(gender_results).most_common(1)[0][0]
                    most_common_age = Counter(age_results).most_common(1)[0][0]

                    # Select first smoking snapshot (fallback: first snapshot)
                    smoking_image = next(
                        (snapshots[i] for i, label in enumerate(smoke_results) if label.lower() == "smoking"),
                        snapshots[0]
                    )

                    # Store results
                    st.session_state.detection_results.append({
                        "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                        "Snapshot": smoking_image,
                        "Gender": most_common_gender,
                        "Age Range": most_common_age,
                        "Smoking Count": smoking_count
                    })

                    # Update table
                    table.dataframe(_detection_results_dataframe(), use_container_width=True)

                    # Display snapshot
                    image_display.image(smoking_image, caption="Detected Smoking Snapshot", use_container_width=True)

                    # Play audio alert keyed by "<age> <gender>"
                    audio.empty()
                    audio_key = f"{most_common_age} {most_common_gender.lower()}"
                    if audio_key in audio_data:
                        play_audio(audio_data[audio_key])
                    else:
                        st.error(f"Audio file not found: {audio_key}.wav")
                else:
                    status_alert.success("No Smoking Detected")  # Green alert for no smoking
                    image_display.empty()
                    audio.empty()
                    classification_text.text("Classification Progress: Completed!")
                    classification_progress.progress(100)
                    # Update table if results exist
                    if st.session_state.detection_results:
                        table.dataframe(_detection_results_dataframe(), use_container_width=True)

                # Reset for next cycle (pause so the user can read results)
                time.sleep(5)
                classification_progress.empty()
                classification_text.empty()
                capture_text.empty()
                status_alert.empty()  # Clear the alert for the next cycle
                detection_info.info("Starting detection...")
                ctx.video_transformer.snapshots = []
                ctx.video_transformer.last_capture_time = time.time()
            time.sleep(0.1)


# ======================
# Main Application
# ======================
def main():
    """Main function to handle page navigation."""
    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox("Select Page", ["Cover Page", "Photo Detection", "Real-Time Video Detection"])
    if page == "Cover Page":
        cover_page()
    elif page == "Photo Detection":
        photo_detection_page()
    elif page == "Real-Time Video Detection":
        real_time_detection_page()


if __name__ == "__main__":
    main()