# -*- coding: utf-8 -*-
"""
Created on Tue May 20 11:00:14 2025
@author: ColinWang
"""
import streamlit as st
import streamlit.components.v1 as components
import cv2
import time
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
from PIL import Image
from transformers import pipeline
import os
import base64
from twilio.rest import Client
from collections import Counter
import uuid
import pandas as pd
# ======================
# Model Loading Functions
# ======================
@st.cache_resource
def load_smoke_pipeline():
"""Initialize and cache the smoking image classification pipeline."""
return pipeline("image-classification", model="ccclllwww/smoker_cls_base_V9", use_fast=True)
@st.cache_resource
def load_gender_pipeline():
"""Initialize and cache the gender image classification pipeline."""
return pipeline("image-classification", model="rizvandwiki/gender-classification-2", use_fast=True)
@st.cache_resource
def load_age_pipeline():
"""Initialize and cache the age image classification pipeline."""
return pipeline("image-classification", model="cledoux42/Age_Classify_v001", use_fast=True)
# Preload all models
smoke_pipeline = load_smoke_pipeline()
gender_pipeline = load_gender_pipeline()
age_pipeline = load_age_pipeline()
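# Preloading at import time means the first page load pays the model download and
# warm-up cost once; st.cache_resource then shares the same pipeline objects across
# all sessions and reruns of this script.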
# ======================
# Twilio Configuration
# ======================
def create_twilio_token():
    """Create and return a Twilio network-traversal token (ICE server credentials) from environment variables."""
account_sid = os.environ.get('TWILIO_ACCOUNT_SID')
auth_token = os.environ.get('TWILIO_AUTH_TOKEN')
if not account_sid or not auth_token:
st.error("Twilio credentials not found in environment variables.")
st.stop()
client = Client(account_sid, auth_token)
return client.tokens.create()
token = create_twilio_token()
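# The token's `ice_servers` field (consumed below by `webrtc_streamer`) lists Twilio's
# STUN/TURN servers. Twilio network-traversal tokens are time-limited, so a very
# long-lived deployment may eventually need to refresh this token.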
# ======================
# Audio Loading Function
# ======================
@st.cache_resource
def load_audio_files():
"""Load all .wav files from the audio directory into a dictionary."""
audio_dir = "audio"
if not os.path.exists(audio_dir):
st.error(f"Audio directory '{audio_dir}' not found.")
st.stop()
audio_files = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
audio_dict = {}
for audio_file in audio_files:
with open(os.path.join(audio_dir, audio_file), "rb") as file:
audio_dict[os.path.splitext(audio_file)[0]] = file.read()
return audio_dict
# Load audio files at startup
audio_data = load_audio_files()
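# Keys in `audio_data` are the file names without the .wav extension. The detection
# pages look up keys of the form "<age_range> <gender>" (e.g. "10-19 male"), so files
# in audio/ must follow that naming convention.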
# ======================
# Image Processing Functions
# ======================
def detect_smoking(image: Image.Image) -> str:
"""Classify an image for smoking activity."""
try:
output = smoke_pipeline(image)
return output[0]["label"]
except Exception as e:
st.error(f"Image processing error: {str(e)}")
st.stop()
def detect_gender(image: Image.Image) -> str:
"""Classify an image for gender."""
try:
output = gender_pipeline(image)
return output[0]["label"]
except Exception as e:
st.error(f"Image processing error: {str(e)}")
st.stop()
def detect_age(image: Image.Image) -> str:
"""Classify an image for age range."""
try:
output = age_pipeline(image)
return output[0]["label"]
except Exception as e:
st.error(f"Image processing error: {str(e)}")
st.stop()
# ======================
# Real-Time Classification Functions
# ======================
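# These mirror the detect_* functions above (a transformers pipeline returns labels
# sorted by score, so output[0] and max(..., key=score) pick the same label), but are
# wrapped in st.cache_data so that re-classifying an identical snapshot during the
# video loop can reuse a cached result; max_entries=3 keeps the cache small.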
@st.cache_data(show_spinner=False, max_entries=3)
def classify_smoking(image: Image.Image) -> str:
"""Classify an image for smoking and return the label with highest confidence."""
try:
output = smoke_pipeline(image)
return max(output, key=lambda x: x["score"])["label"]
except Exception as e:
st.error(f"Image processing error: {str(e)}")
st.stop()
@st.cache_data(show_spinner=False, max_entries=3)
def classify_gender(image: Image.Image) -> str:
"""Classify an image for gender and return the label with highest confidence."""
try:
output = gender_pipeline(image)
return max(output, key=lambda x: x["score"])["label"]
except Exception as e:
st.error(f"Image processing error: {str(e)}")
st.stop()
@st.cache_data(show_spinner=False, max_entries=3)
def classify_age(image: Image.Image) -> str:
"""Classify an image for age range and return the label with highest confidence."""
try:
output = age_pipeline(image)
return max(output, key=lambda x: x["score"])["label"]
except Exception as e:
st.error(f"Image processing error: {str(e)}")
st.stop()
# ======================
# Audio Playback Function
# ======================
def play_audio(audio_bytes: bytes):
"""Play audio using HTML and JavaScript with Base64-encoded audio data."""
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
audio_id = f"audio_player_{uuid.uuid4()}"
html_content = f"""
<audio id="{audio_id}" controls style="width: 100%;">
<source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
Your browser does not support the audio element.
</audio>
<script type="text/javascript">
window.addEventListener('DOMContentLoaded', function() {{
setTimeout(function() {{
var audioElement = document.getElementById("{audio_id}");
if (audioElement) {{
audioElement.play().catch(function(e) {{
console.log("Playback prevented by browser:", e);
}});
}}
}}, 1000);
}});
</script>
"""
    components.html(html_content, height=150)
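    # A minimal autoplay workaround: the delayed play() call may still be blocked by
    # browser autoplay policies, in which case the rendered controls serve as a manual
    # fallback. The uuid-based element id keeps ids unique across reruns.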
# ======================
# Video Transformer Class
# ======================
class VideoTransformer(VideoTransformerBase):
def __init__(self):
self.snapshots = []
self.last_capture_time = time.time()
self.capture_interval = 1 # Capture every 1 second
self.max_snapshots = 5
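    # transform() is invoked by streamlit_webrtc on a background worker thread, so it
    # must not call Streamlit APIs; the page's polling loop reads `snapshots` instead.
    # (VideoTransformerBase is the older streamlit_webrtc interface; newer releases
    # prefer VideoProcessorBase with a recv() method, but this one still works.)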
def transform(self, frame):
"""Process video frame and capture snapshots."""
img = frame.to_ndarray(format="bgr24")
current_time = time.time()
if (current_time - self.last_capture_time >= self.capture_interval and
len(self.snapshots) < self.max_snapshots):
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
self.snapshots.append(Image.fromarray(img_rgb))
self.last_capture_time = current_time
st.write(f"Captured snapshot {len(self.snapshots)}/{self.max_snapshots}")
return img
# ======================
# Cover Page
# ======================
def cover_page():
"""Display an enhanced cover page with project overview and instructions."""
st.title("Smoking Detection System", anchor=False)
st.markdown("### Welcome to the Smoking Detection System")
st.markdown("""
This Streamlit application uses image-classification models to detect smoking behavior in uploaded photos and real-time video streams. Alongside smoking status it estimates gender and age range, supporting public health monitoring and no-smoking policy enforcement.
""")
st.markdown("#### Project Overview")
st.markdown("""
- **Purpose**: Automatically identify smoking behavior in public or controlled environments to support compliance with no-smoking policies and facilitate behavioral studies.
- **Significance**: Enhances public health initiatives by enabling real-time monitoring and demographic analysis of smoking activities.
- **Features**:
- **Photo Detection**: Analyze a single image (uploaded or captured) for smoking, gender, and age.
- **Real-Time Video Detection**: Process webcam streams, capturing snapshots to detect smoking and demographics.
- **Audio Feedback**: Play alerts based on detected gender and age when smoking is confirmed.
""")
st.markdown("#### How to Use")
st.markdown("""
1. **Navigate**: Use the sidebar to select a page:
- **Cover Page**: View this overview.
- **Photo Detection**: Upload or capture an image for analysis.
- **Real-Time Video Detection**: Monitor live webcam feed.
2. **Photo Detection**:
- Upload an image or capture one via webcam.
- The system detects smoking; if detected, it analyzes gender and age, playing a corresponding audio alert.
3. **Real-Time Video Detection**:
   - Captures five snapshots, one per second.
   - If smoking is detected in at least three of the five snapshots, the app analyzes gender and age, displays the results in a table, and plays an audio alert.
4. **Setup Requirements**:
- Ensure the 'audio' directory contains .wav files named as '<age_range> <gender>.wav' (e.g., '10-19 male.wav').
- Configure Twilio environment variables (`TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN`) for WebRTC functionality.
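   For example (placeholder values), run `export TWILIO_ACCOUNT_SID="ACxxxxxxxx"` and `export TWILIO_AUTH_TOKEN="your_auth_token"` in the shell before launching the app.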
""")
st.markdown("#### Get Started")
st.markdown("Select a page from the sidebar to begin analyzing images or video streams.")
# ======================
# Photo Detection Page
# ======================
def photo_detection_page():
"""Handle photo detection page for smoking, gender, and age classification."""
audio_placeholder = st.empty()
st.title("Photo Detection", anchor=False)
st.markdown("Upload an image or capture a photo to detect smoking behavior. If smoking is detected, gender and age will be analyzed.")
# Image input selection
option = st.radio("Choose input method", ["Upload Image", "Capture with Camera"], horizontal=True)
image = None
if option == "Upload Image":
uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
if uploaded_file:
image = Image.open(uploaded_file)
st.image(image, caption="Uploaded Image", use_container_width=True)
else:
enable = st.checkbox("Enable Camera")
camera_file = st.camera_input("Capture Photo", disabled=not enable)
if camera_file:
image = Image.open(camera_file)
st.image(image, caption="Captured Photo", use_container_width=True)
if image:
with st.spinner("Detecting smoking..."):
smoke_result = detect_smoking(image)
st.success(f"Smoking Status: {smoke_result}")
if smoke_result.lower() == "smoking":
with st.spinner("Detecting gender..."):
gender_result = detect_gender(image)
st.success(f"Gender: {gender_result}")
with st.spinner("Detecting age..."):
age_result = detect_age(image)
st.success(f"Age Range: {age_result}")
audio_placeholder.empty()
audio_key = f"{age_result} {gender_result.lower()}"
if audio_key in audio_data:
play_audio(audio_data[audio_key])
else:
st.error(f"Audio file not found: {audio_key}.wav")
# ======================
# Real-Time Detection Page
# ======================
def real_time_detection_page():
"""Handle real-time video detection with snapshot capture and analysis."""
st.title("Real-Time Video Detection", anchor=False)
st.markdown("Captures 5 snapshots over one minute to detect smoking. If smoking is detected in more than 2 snapshots, results include gender, age, and a snapshot in a table.")
# Initialize session state for detection results
if 'detection_results' not in st.session_state:
st.session_state.detection_results = []
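    # st.session_state persists across Streamlit reruns within a browser session, so
    # detection results accumulate in the table over successive capture cycles.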
# Placeholders for UI elements
capture_text = st.empty()
capture_progress = st.empty()
classification_text = st.empty()
classification_progress = st.empty()
detection_info = st.empty()
    status_alert = st.empty()  # Smoking / no-smoking status banner
table = st.empty()
image_display = st.empty()
audio = st.empty()
# Start video stream
ctx = webrtc_streamer(
key="unique_example",
video_transformer_factory=VideoTransformer,
rtc_configuration={"iceServers": token.ice_servers}
)
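    # Passing Twilio's ICE servers gives the browser STUN/TURN candidates, which lets
    # the WebRTC connection be established even when the client sits behind a NAT or
    # firewall that blocks direct peer-to-peer traffic.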
capture_target = 5
if ctx.video_transformer:
detection_info.info("Starting detection...")
while True:
snapshots = ctx.video_transformer.snapshots
if len(snapshots) < capture_target:
capture_text.text(f"Capture Progress: {len(snapshots)}/{capture_target} snapshots")
capture_progress.progress(int(len(snapshots) / capture_target * 100))
else:
capture_text.text("Capture Progress: Completed!")
capture_progress.empty()
detection_info.empty()
classification_text.text("Classification Progress: Analyzing...")
classification = classification_progress.progress(0)
# Classify snapshots
smoke_results = [classify_smoking(img) for img in snapshots]
smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
classification.progress(33)
if smoking_count > 2:
status_alert.error("Smoking Detected!") # Red alert for smoking
gender_results = [classify_gender(img) for img in snapshots]
classification.progress(66)
age_results = [classify_age(img) for img in snapshots]
classification.progress(100)
classification_text.text("Classification Progress: Completed!")
# Determine most common gender and age
most_common_gender = Counter(gender_results).most_common(1)[0][0]
most_common_age = Counter(age_results).most_common(1)[0][0]
# Select first smoking snapshot
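                    # (With smoking_count > 2 a smoking frame always exists; the
                    # snapshots[0] default is only a defensive fallback.)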
smoking_image = next((snapshots[i] for i, label in enumerate(smoke_results) if label.lower() == "smoking"), snapshots[0])
# Store results
st.session_state.detection_results.append({
"Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"Snapshot": smoking_image,
"Gender": most_common_gender,
"Age Range": most_common_age,
"Smoking Count": smoking_count
})
# Update table
df = pd.DataFrame([
{
"Timestamp": result["Timestamp"],
"Gender": result["Gender"],
"Age Range": result["Age Range"],
"Smoking Count": result["Smoking Count"]
} for result in st.session_state.detection_results
])
table.dataframe(df, use_container_width=True)
# Display snapshot
image_display.image(smoking_image, caption="Detected Smoking Snapshot", use_container_width=True)
# Play audio
audio.empty()
audio_key = f"{most_common_age} {most_common_gender.lower()}"
if audio_key in audio_data:
play_audio(audio_data[audio_key])
else:
st.error(f"Audio file not found: {audio_key}.wav")
else:
status_alert.success("No Smoking Detected") # Green alert for no smoking
image_display.empty()
audio.empty()
classification_text.text("Classification Progress: Completed!")
classification_progress.progress(100)
# Update table if results exist
if st.session_state.detection_results:
df = pd.DataFrame([
{
"Timestamp": result["Timestamp"],
"Gender": result["Gender"],
"Age Range": result["Age Range"],
"Smoking Count": result["Smoking Count"]
} for result in st.session_state.detection_results
])
table.dataframe(df, use_container_width=True)
# Reset for next cycle
time.sleep(5)
classification_progress.empty()
classification_text.empty()
capture_text.empty()
status_alert.empty() # Clear the alert for the next cycle
detection_info.info("Starting detection...")
ctx.video_transformer.snapshots = []
ctx.video_transformer.last_capture_time = time.time()
time.sleep(0.1)
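            # This ~10 Hz polling loop keeps the script alive while the worker thread
            # captures frames; the main thread only ever reads the transformer's state.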
# ======================
# Main Application
# ======================
def main():
"""Main function to handle page navigation."""
st.sidebar.title("Navigation")
page = st.sidebar.selectbox("Select Page", ["Cover Page", "Photo Detection", "Real-Time Video Detection"])
if page == "Cover Page":
cover_page()
elif page == "Photo Detection":
photo_detection_page()
elif page == "Real-Time Video Detection":
real_time_detection_page()
if __name__ == "__main__":
main()