# -*- coding: utf-8 -*-
"""
Created on Tue May 20 11:00:14 2025
@author: ColinWang
"""
import streamlit as st
import streamlit.components.v1 as components
import cv2
import time
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
from PIL import Image
from transformers import pipeline
import os
import base64
from twilio.rest import Client
from collections import Counter
import uuid
import pandas as pd
# ======================
# Model Loading Functions
# ======================
# st.cache_resource keeps a single pipeline instance across Streamlit reruns,
# as the docstrings promise.
@st.cache_resource
def load_smoke_pipeline():
    """Initialize and cache the smoking image classification pipeline."""
    return pipeline("image-classification", model="ccclllwww/smoker_cls_base_V9", use_fast=True)

@st.cache_resource
def load_gender_pipeline():
    """Initialize and cache the gender image classification pipeline."""
    return pipeline("image-classification", model="rizvandwiki/gender-classification-2", use_fast=True)

@st.cache_resource
def load_age_pipeline():
    """Initialize and cache the age image classification pipeline."""
    return pipeline("image-classification", model="cledoux42/Age_Classify_v001", use_fast=True)

# Preload all models
smoke_pipeline = load_smoke_pipeline()
gender_pipeline = load_gender_pipeline()
age_pipeline = load_age_pipeline()
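# Note: all three models are fetched from the Hugging Face Hub on first run, so
# initial startup can take a while; later reruns reuse the cached pipelines.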
# ======================
# Twilio Configuration
# ======================
def create_twilio_token():
    """Create a Twilio Network Traversal Service token using environment variables.

    The returned token (not a bare client) carries the ICE server list that
    WebRTC needs for STUN/TURN.
    """
    account_sid = os.environ.get('TWILIO_ACCOUNT_SID')
    auth_token = os.environ.get('TWILIO_AUTH_TOKEN')
    if not account_sid or not auth_token:
        st.error("Twilio credentials not found in environment variables.")
        st.stop()
    client = Client(account_sid, auth_token)
    return client.tokens.create()

token = create_twilio_token()
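# token.ice_servers is passed to webrtc_streamer's rtc_configuration below so the
# browser can traverse NATs and firewalls via Twilio's STUN/TURN servers.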
# ======================
# Audio Loading Function
# ======================
def load_audio_files():
    """Load all .wav files from the audio directory into a dictionary."""
    audio_dir = "audio"
    if not os.path.exists(audio_dir):
        st.error(f"Audio directory '{audio_dir}' not found.")
        st.stop()
    audio_files = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
    audio_dict = {}
    for audio_file in audio_files:
        with open(os.path.join(audio_dir, audio_file), "rb") as file:
            audio_dict[os.path.splitext(audio_file)[0]] = file.read()
    return audio_dict

# Load audio files at startup
audio_data = load_audio_files()
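# Keys are filename stems, e.g. "audio/10-19 male.wav" -> audio_data["10-19 male"];
# these must match the "<age label> <gender label>" keys built by the pages below.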
# ======================
# Image Processing Functions
# ======================
def detect_smoking(image: Image.Image) -> str:
    """Classify an image for smoking activity."""
    try:
        output = smoke_pipeline(image)
        return output[0]["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

def detect_gender(image: Image.Image) -> str:
    """Classify an image for gender."""
    try:
        output = gender_pipeline(image)
        return output[0]["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

def detect_age(image: Image.Image) -> str:
    """Classify an image for age range."""
    try:
        output = age_pipeline(image)
        return output[0]["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
# ======================
# Real-Time Classification Functions
# ======================
def classify_smoking(image: Image.Image) -> str:
    """Classify an image for smoking and return the label with highest confidence."""
    try:
        output = smoke_pipeline(image)
        return max(output, key=lambda x: x["score"])["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

def classify_gender(image: Image.Image) -> str:
    """Classify an image for gender and return the label with highest confidence."""
    try:
        output = gender_pipeline(image)
        return max(output, key=lambda x: x["score"])["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

def classify_age(image: Image.Image) -> str:
    """Classify an image for age range and return the label with highest confidence."""
    try:
        output = age_pipeline(image)
        return max(output, key=lambda x: x["score"])["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()
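# Note: the transformers image-classification pipeline returns candidates sorted by
# descending score, so detect_*'s output[0]["label"] and classify_*'s
# max(output, key=score) select the same top label.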
# ======================
# Audio Playback Function
# ======================
def play_audio(audio_bytes: bytes):
    """Play audio using HTML and JavaScript with Base64-encoded audio data."""
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    # A unique element id avoids collisions when several players are rendered.
    audio_id = f"audio_player_{uuid.uuid4()}"
    html_content = f"""
    <audio id="{audio_id}" controls style="width: 100%;">
        <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
        Your browser does not support the audio element.
    </audio>
    <script type="text/javascript">
        window.addEventListener('DOMContentLoaded', function() {{
            setTimeout(function() {{
                var audioElement = document.getElementById("{audio_id}");
                if (audioElement) {{
                    // Autoplay may be blocked by browser policy; log instead of failing.
                    audioElement.play().catch(function(e) {{
                        console.log("Playback prevented by browser:", e);
                    }});
                }}
            }}, 1000);
        }});
    </script>
    """
    components.html(html_content, height=150)
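# Example usage (assumes audio/10-19 male.wav exists and was loaded at startup):
#     play_audio(audio_data["10-19 male"])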
# ======================
# Video Transformer Class
# ======================
class VideoTransformer(VideoTransformerBase):
    def __init__(self):
        self.snapshots = []
        self.last_capture_time = time.time()
        self.capture_interval = 1  # Capture one snapshot per second
        self.max_snapshots = 5

    def transform(self, frame):
        """Process a video frame and capture periodic snapshots."""
        img = frame.to_ndarray(format="bgr24")
        current_time = time.time()
        if (current_time - self.last_capture_time >= self.capture_interval and
                len(self.snapshots) < self.max_snapshots):
            # Convert OpenCV's BGR frame to RGB before wrapping it as a PIL image.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.snapshots.append(Image.fromarray(img_rgb))
            self.last_capture_time = current_time
            # Capture progress is reported from the main script loop; Streamlit
            # calls such as st.write() are not safe from this background thread.
        return img
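# Note: recent streamlit-webrtc releases prefer VideoProcessorBase with a recv()
# method; VideoTransformerBase/transform still works but may emit deprecation warnings.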
# ======================
# Cover Page
# ======================
def cover_page():
    """Display an enhanced cover page with project overview and instructions."""
    st.title("Smoking Detection System", anchor=False)
    st.markdown("### Welcome to the Smoking Detection System")
    st.markdown("""
    This Streamlit application uses pretrained image-classification models to detect smoking behavior in images and real-time video streams. By analyzing smoking activity, gender, and age demographics, it provides insights for public health monitoring and policy enforcement.
    """)
    st.markdown("#### Project Overview")
    st.markdown("""
    - **Purpose**: Automatically identify smoking behavior in public or controlled environments to support compliance with no-smoking policies and facilitate behavioral studies.
    - **Significance**: Supports public health initiatives by enabling real-time monitoring and demographic analysis of smoking activity.
    - **Features**:
        - **Photo Detection**: Analyze a single image (uploaded or captured) for smoking, gender, and age.
        - **Real-Time Video Detection**: Process webcam streams, capturing snapshots to detect smoking and demographics.
        - **Audio Feedback**: Play alerts based on detected gender and age when smoking is confirmed.
    """)
    st.markdown("#### How to Use")
    st.markdown("""
    1. **Navigate**: Use the sidebar to select a page:
        - **Cover Page**: View this overview.
        - **Photo Detection**: Upload or capture an image for analysis.
        - **Real-Time Video Detection**: Monitor a live webcam feed.
    2. **Photo Detection**:
        - Upload an image or capture one via webcam.
        - The system detects smoking; if detected, it analyzes gender and age and plays a corresponding audio alert.
    3. **Real-Time Video Detection**:
        - Captures 5 snapshots at one-second intervals (about five seconds in total).
        - If smoking is detected in more than 2 snapshots, it analyzes gender and age, displays results in a table, and plays an audio alert.
    4. **Setup Requirements**:
        - Ensure the 'audio' directory contains .wav files named as '<age_range> <gender>.wav' (e.g., '10-19 male.wav').
        - Configure Twilio environment variables (`TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN`) for WebRTC functionality.
    """)
    st.markdown("#### Get Started")
    st.markdown("Select a page from the sidebar to begin analyzing images or video streams.")
# ======================
# Photo Detection Page
# ======================
def photo_detection_page():
    """Handle photo detection page for smoking, gender, and age classification."""
    audio_placeholder = st.empty()
    st.title("Photo Detection", anchor=False)
    st.markdown("Upload an image or capture a photo to detect smoking behavior. If smoking is detected, gender and age will be analyzed.")
    # Image input selection
    option = st.radio("Choose input method", ["Upload Image", "Capture with Camera"], horizontal=True)
    image = None
    if option == "Upload Image":
        uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)
    else:
        enable = st.checkbox("Enable Camera")
        camera_file = st.camera_input("Capture Photo", disabled=not enable)
        if camera_file:
            image = Image.open(camera_file)
            st.image(image, caption="Captured Photo", use_container_width=True)
    if image:
        with st.spinner("Detecting smoking..."):
            smoke_result = detect_smoking(image)
        st.success(f"Smoking Status: {smoke_result}")
        if smoke_result.lower() == "smoking":
            with st.spinner("Detecting gender..."):
                gender_result = detect_gender(image)
            st.success(f"Gender: {gender_result}")
            with st.spinner("Detecting age..."):
                age_result = detect_age(image)
            st.success(f"Age Range: {age_result}")
            audio_placeholder.empty()
            audio_key = f"{age_result} {gender_result.lower()}"
            if audio_key in audio_data:
                play_audio(audio_data[audio_key])
            else:
                st.error(f"Audio file not found: {audio_key}.wav")
# ======================
# Real-Time Detection Page
# ======================
def real_time_detection_page():
    """Handle real-time video detection with snapshot capture and analysis."""
    st.title("Real-Time Video Detection", anchor=False)
    st.markdown("Captures 5 snapshots at one-second intervals to detect smoking. If smoking is detected in more than 2 snapshots, results include gender, age, and a snapshot in a table.")
    # Initialize session state for detection results
    if 'detection_results' not in st.session_state:
        st.session_state.detection_results = []
    # Placeholders for UI elements
    capture_text = st.empty()
    capture_progress = st.empty()
    classification_text = st.empty()
    classification_progress = st.empty()
    detection_info = st.empty()
    status_alert = st.empty()  # Placeholder for the smoking/no-smoking status alert
    table = st.empty()
    image_display = st.empty()
    audio = st.empty()
    # Start video stream
    ctx = webrtc_streamer(
        key="unique_example",
        video_transformer_factory=VideoTransformer,
        rtc_configuration={"iceServers": token.ice_servers}
    )
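    # The loop below polls the transformer's snapshot buffer from the main script
    # thread and drives all UI updates here, since Streamlit calls are not
    # thread-safe inside transform().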
    capture_target = 5
    if ctx.video_transformer:
        detection_info.info("Starting detection...")
        while True:
            snapshots = ctx.video_transformer.snapshots
            if len(snapshots) < capture_target:
                capture_text.text(f"Capture Progress: {len(snapshots)}/{capture_target} snapshots")
                capture_progress.progress(int(len(snapshots) / capture_target * 100))
            else:
                capture_text.text("Capture Progress: Completed!")
                capture_progress.empty()
                detection_info.empty()
                classification_text.text("Classification Progress: Analyzing...")
                classification = classification_progress.progress(0)
                # Classify snapshots
                smoke_results = [classify_smoking(img) for img in snapshots]
                smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
                classification.progress(33)
                if smoking_count > 2:
                    status_alert.error("Smoking Detected!")  # Red alert for smoking
                    gender_results = [classify_gender(img) for img in snapshots]
                    classification.progress(66)
                    age_results = [classify_age(img) for img in snapshots]
                    classification.progress(100)
                    classification_text.text("Classification Progress: Completed!")
                    # Determine most common gender and age
                    most_common_gender = Counter(gender_results).most_common(1)[0][0]
                    most_common_age = Counter(age_results).most_common(1)[0][0]
                    # Select the first snapshot classified as smoking
                    smoking_image = next((snapshots[i] for i, label in enumerate(smoke_results) if label.lower() == "smoking"), snapshots[0])
                    # Store results
                    st.session_state.detection_results.append({
                        "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                        "Snapshot": smoking_image,
                        "Gender": most_common_gender,
                        "Age Range": most_common_age,
                        "Smoking Count": smoking_count
                    })
                    # Update table
                    df = pd.DataFrame([
                        {
                            "Timestamp": result["Timestamp"],
                            "Gender": result["Gender"],
                            "Age Range": result["Age Range"],
                            "Smoking Count": result["Smoking Count"]
                        } for result in st.session_state.detection_results
                    ])
                    table.dataframe(df, use_container_width=True)
                    # Display snapshot
                    image_display.image(smoking_image, caption="Detected Smoking Snapshot", use_container_width=True)
                    # Play audio
                    audio.empty()
                    audio_key = f"{most_common_age} {most_common_gender.lower()}"
                    if audio_key in audio_data:
                        play_audio(audio_data[audio_key])
                    else:
                        st.error(f"Audio file not found: {audio_key}.wav")
                else:
                    status_alert.success("No Smoking Detected")  # Green alert for no smoking
                    image_display.empty()
                    audio.empty()
                    classification_text.text("Classification Progress: Completed!")
                    classification_progress.progress(100)
                    # Update table if results exist
                    if st.session_state.detection_results:
                        df = pd.DataFrame([
                            {
                                "Timestamp": result["Timestamp"],
                                "Gender": result["Gender"],
                                "Age Range": result["Age Range"],
                                "Smoking Count": result["Smoking Count"]
                            } for result in st.session_state.detection_results
                        ])
                        table.dataframe(df, use_container_width=True)
                # Reset for next cycle
                time.sleep(5)
                classification_progress.empty()
                classification_text.empty()
                capture_text.empty()
                status_alert.empty()  # Clear the alert for the next cycle
                detection_info.info("Starting detection...")
                ctx.video_transformer.snapshots = []
                ctx.video_transformer.last_capture_time = time.time()
            time.sleep(0.1)
# ======================
# Main Application
# ======================
def main():
    """Main function to handle page navigation."""
    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox("Select Page", ["Cover Page", "Photo Detection", "Real-Time Video Detection"])
    if page == "Cover Page":
        cover_page()
    elif page == "Photo Detection":
        photo_detection_page()
    elif page == "Real-Time Video Detection":
        real_time_detection_page()

if __name__ == "__main__":
    main()