# -*- coding: utf-8 -*-
"""
Created on Tue May 20 11:00:14 2025
@author: ColinWang
"""
import streamlit as st
import streamlit.components.v1  # explicit import so st.components.v1.html is available in play_audio()
import cv2
import time
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
from PIL import Image
from transformers import pipeline
import os
import base64
from twilio.rest import Client
from collections import Counter
import uuid
import pandas as pd

# ======================
# Model Loading Functions
# ======================
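# Each pipeline is wrapped in st.cache_resource so the model is downloaded and
# loaded once per server process and shared across reruns and sessions.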

@st.cache_resource
def load_smoke_pipeline():
    """Initialize and cache the smoking image classification pipeline."""
    return pipeline("image-classification", model="ccclllwww/smoker_cls_base_V9", use_fast=True)

@st.cache_resource
def load_gender_pipeline():
    """Initialize and cache the gender image classification pipeline."""
    return pipeline("image-classification", model="rizvandwiki/gender-classification-2", use_fast=True)

@st.cache_resource
def load_age_pipeline():
    """Initialize and cache the age image classification pipeline."""
    return pipeline("image-classification", model="cledoux42/Age_Classify_v001", use_fast=True)

# Preload all models
smoke_pipeline = load_smoke_pipeline()
gender_pipeline = load_gender_pipeline()
age_pipeline = load_age_pipeline()

# ======================
# Twilio Configuration
# ======================

def initialize_twilio_client():
    """Initialize Twilio client using environment variables."""
    account_sid = os.environ.get('TWILIO_ACCOUNT_SID')
    auth_token = os.environ.get('TWILIO_AUTH_TOKEN')
    if not account_sid or not auth_token:
        st.error("Twilio credentials not found in environment variables.")
        st.stop()
    client = Client(account_sid, auth_token)
    return client.tokens.create()

token = initialize_twilio_client()
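# token.ice_servers contains short-lived STUN/TURN servers from Twilio's Network
# Traversal Service; they are passed to webrtc_streamer so browsers behind NATs
# and firewalls can establish the WebRTC connection.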

# ======================
# Audio Loading Function
# ======================

@st.cache_resource
def load_audio_files():
    """Load all .wav files from the audio directory into a dictionary."""
    audio_dir = "audio"
    if not os.path.exists(audio_dir):
        st.error(f"Audio directory '{audio_dir}' not found.")
        st.stop()
    audio_files = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
    audio_dict = {}
    for audio_file in audio_files:
        with open(os.path.join(audio_dir, audio_file), "rb") as file:
            audio_dict[os.path.splitext(audio_file)[0]] = file.read()
    return audio_dict

# Load audio files at startup
audio_data = load_audio_files()

# ======================
# Image Processing Functions
# ======================

def detect_smoking(image: Image.Image) -> str:
    """Classify an image for smoking activity."""
    try:
        output = smoke_pipeline(image)
        return output[0]["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

def detect_gender(image: Image.Image) -> str:
    """Classify an image for gender."""
    try:
        output = gender_pipeline(image)
        return output[0]["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

def detect_age(image: Image.Image) -> str:
    """Classify an image for age range."""
    try:
        output = age_pipeline(image)
        return output[0]["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

# ======================
# Real-Time Classification Functions
# ======================
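# These mirror the detect_* helpers above but add st.cache_data, so identical
# snapshots are not re-classified during the real-time loop; max_entries keeps
# the cache small because cached snapshots can be large.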

@st.cache_data(show_spinner=False, max_entries=3)
def classify_smoking(image: Image.Image) -> str:
    """Classify an image for smoking and return the label with highest confidence."""
    try:
        output = smoke_pipeline(image)
        return max(output, key=lambda x: x["score"])["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

@st.cache_data(show_spinner=False, max_entries=3)
def classify_gender(image: Image.Image) -> str:
    """Classify an image for gender and return the label with highest confidence."""
    try:
        output = gender_pipeline(image)
        return max(output, key=lambda x: x["score"])["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

@st.cache_data(show_spinner=False, max_entries=3)
def classify_age(image: Image.Image) -> str:
    """Classify an image for age range and return the label with highest confidence."""
    try:
        output = age_pipeline(image)
        return max(output, key=lambda x: x["score"])["label"]
    except Exception as e:
        st.error(f"Image processing error: {str(e)}")
        st.stop()

# ======================
# Audio Playback Function
# ======================

def play_audio(audio_bytes: bytes):
    """Play audio using HTML and JavaScript with Base64-encoded audio data."""
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    audio_id = f"audio_player_{uuid.uuid4()}"
    html_content = f"""
    <audio id="{audio_id}" controls style="width: 100%;">
        <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
        Your browser does not support the audio element.
    </audio>
    <script type="text/javascript">
        window.addEventListener('DOMContentLoaded', function() {{
            setTimeout(function() {{
                var audioElement = document.getElementById("{audio_id}");
                if (audioElement) {{
                    audioElement.play().catch(function(e) {{
                        console.log("Playback prevented by browser:", e);
                    }});
                }}
            }}, 1000);
        }});
    </script>
    """
    st.components.v1.html(html_content, height=150)

# ======================
# Video Transformer Class
# ======================
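# transform() is invoked on streamlit-webrtc's worker thread, not the Streamlit
# script thread, so this class only buffers snapshots; all UI updates happen in
# the polling loop of real_time_detection_page() below.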

class VideoTransformer(VideoTransformerBase):
    def __init__(self):
        self.snapshots = []
        self.last_capture_time = time.time()
        self.capture_interval = 1  # Capture every 1 second
        self.max_snapshots = 5

    def transform(self, frame):
        """Process video frame and capture snapshots."""
        img = frame.to_ndarray(format="bgr24")
        current_time = time.time()
        if (current_time - self.last_capture_time >= self.capture_interval and 
            len(self.snapshots) < self.max_snapshots):
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.snapshots.append(Image.fromarray(img_rgb))
            self.last_capture_time = current_time
            st.write(f"Captured snapshot {len(self.snapshots)}/{self.max_snapshots}")
        return img

# ======================
# Cover Page
# ======================

def cover_page():
    """Display an enhanced cover page with project overview and instructions."""
    st.title("Smoking Detection System", anchor=False)
    
    st.markdown("### Welcome to the Smoking Detection System")
    st.markdown("""
    This Streamlit application uses image-classification models to detect smoking behavior in images and real-time video streams. By also analyzing gender and age demographics, it provides insights for public health monitoring and policy enforcement.
    """)
    
    st.markdown("#### Project Overview")
    st.markdown("""
    - **Purpose**: Automatically identify smoking behavior in public or controlled environments to support compliance with no-smoking policies and facilitate behavioral studies.
    - **Significance**: Enhances public health initiatives by enabling real-time monitoring and demographic analysis of smoking activities.
    - **Features**:
      - **Photo Detection**: Analyze a single image (uploaded or captured) for smoking, gender, and age.
      - **Real-Time Video Detection**: Process webcam streams, capturing snapshots to detect smoking and demographics.
      - **Audio Feedback**: Play alerts based on detected gender and age when smoking is confirmed.
    """)
    
    st.markdown("#### How to Use")
    st.markdown("""
    1. **Navigate**: Use the sidebar to select a page:
       - **Cover Page**: View this overview.
       - **Photo Detection**: Upload or capture an image for analysis.
       - **Real-Time Video Detection**: Monitor live webcam feed.
    2. **Photo Detection**:
       - Upload an image or capture one via webcam.
       - The system detects smoking; if detected, it analyzes gender and age, playing a corresponding audio alert.
    3. **Real-Time Video Detection**:
       - Captures 5 snapshots, roughly one per second.
       - If smoking is detected in at least 3 of the 5 snapshots, it analyzes gender and age, displays the results in a table, and plays an audio alert.
    4. **Setup Requirements**:
       - Ensure the 'audio' directory contains .wav files named as '<age_range> <gender>.wav' (e.g., '10-19 male.wav').
       - Configure Twilio environment variables (`TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN`) for WebRTC functionality.
    """)
    
    st.markdown("#### Get Started")
    st.markdown("Select a page from the sidebar to begin analyzing images or video streams.")

# ======================
# Photo Detection Page
# ======================

def photo_detection_page():
    """Handle photo detection page for smoking, gender, and age classification."""
    audio_placeholder = st.empty()
    st.title("Photo Detection", anchor=False)
    st.markdown("Upload an image or capture a photo to detect smoking behavior. If smoking is detected, gender and age will be analyzed.")

    # Image input selection
    option = st.radio("Choose input method", ["Upload Image", "Capture with Camera"], horizontal=True)
    image = None

    if option == "Upload Image":
        uploaded_file = st.file_uploader("Select an image", type=["jpg", "jpeg", "png"])
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)
    else:
        enable = st.checkbox("Enable Camera")
        camera_file = st.camera_input("Capture Photo", disabled=not enable)
        if camera_file:
            image = Image.open(camera_file)
            st.image(image, caption="Captured Photo", use_container_width=True)

    if image:
        with st.spinner("Detecting smoking..."):
            smoke_result = detect_smoking(image)
        st.success(f"Smoking Status: {smoke_result}")

        if smoke_result.lower() == "smoking":
            with st.spinner("Detecting gender..."):
                gender_result = detect_gender(image)
            st.success(f"Gender: {gender_result}")

            with st.spinner("Detecting age..."):
                age_result = detect_age(image)
            st.success(f"Age Range: {age_result}")

            audio_placeholder.empty()
            audio_key = f"{age_result} {gender_result.lower()}"
            if audio_key in audio_data:
                play_audio(audio_data[audio_key])
            else:
                st.error(f"Audio file not found: {audio_key}.wav")

# ======================
# Real-Time Detection Page
# ======================

def real_time_detection_page():
    """Handle real-time video detection with snapshot capture and analysis."""
    st.title("Real-Time Video Detection", anchor=False)
    st.markdown("Captures 5 snapshots over one minute to detect smoking. If smoking is detected in more than 2 snapshots, results include gender, age, and a snapshot in a table.")

    # Initialize session state for detection results
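    # st.session_state persists across reruns, so the results table accumulates
    # rows for the lifetime of the browser session.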
    if 'detection_results' not in st.session_state:
        st.session_state.detection_results = []

    # Placeholders for UI elements
    capture_text = st.empty()
    capture_progress = st.empty()
    classification_text = st.empty()
    classification_progress = st.empty()
    detection_info = st.empty()
    status_alert = st.empty()  # Placeholder for the smoking / no-smoking status alert
    table = st.empty()
    image_display = st.empty()
    audio = st.empty()

    # Start video stream
    ctx = webrtc_streamer(
        key="unique_example",
        video_transformer_factory=VideoTransformer,
        rtc_configuration={"iceServers": token.ice_servers}
    )

    capture_target = 5

    if ctx.video_transformer:
        detection_info.info("Starting detection...")

        # Poll the transformer's snapshot buffer from the main script thread;
        # this loop intentionally blocks the current script run while the
        # WebRTC stream is active.
        while True:
            snapshots = ctx.video_transformer.snapshots

            if len(snapshots) < capture_target:
                capture_text.text(f"Capture Progress: {len(snapshots)}/{capture_target} snapshots")
                capture_progress.progress(int(len(snapshots) / capture_target * 100))
            else:
                capture_text.text("Capture Progress: Completed!")
                capture_progress.empty()
                detection_info.empty()

                classification_text.text("Classification Progress: Analyzing...")
                classification = classification_progress.progress(0)

                # Classify snapshots
                smoke_results = [classify_smoking(img) for img in snapshots]
                smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
                classification.progress(33)

                if smoking_count > 2:  # majority vote: at least 3 of the 5 snapshots show smoking
                    status_alert.error("Smoking Detected!")  # Red alert for smoking
                    gender_results = [classify_gender(img) for img in snapshots]
                    classification.progress(66)
                    age_results = [classify_age(img) for img in snapshots]
                    classification.progress(100)
                    classification_text.text("Classification Progress: Completed!")

                    # Determine most common gender and age
                    most_common_gender = Counter(gender_results).most_common(1)[0][0]
                    most_common_age = Counter(age_results).most_common(1)[0][0]

                    # Select first smoking snapshot
                    smoking_image = next((snapshots[i] for i, label in enumerate(smoke_results) if label.lower() == "smoking"), snapshots[0])

                    # Store results
                    st.session_state.detection_results.append({
                        "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                        "Snapshot": smoking_image,
                        "Gender": most_common_gender,
                        "Age Range": most_common_age,
                        "Smoking Count": smoking_count
                    })

                    # Update table
                    df = pd.DataFrame([
                        {
                            "Timestamp": result["Timestamp"],
                            "Gender": result["Gender"],
                            "Age Range": result["Age Range"],
                            "Smoking Count": result["Smoking Count"]
                        } for result in st.session_state.detection_results
                    ])
                    table.dataframe(df, use_container_width=True)

                    # Display snapshot
                    image_display.image(smoking_image, caption="Detected Smoking Snapshot", use_container_width=True)

                    # Play audio
                    audio.empty()
                    audio_key = f"{most_common_age} {most_common_gender.lower()}"
                    if audio_key in audio_data:
                        play_audio(audio_data[audio_key])
                    else:
                        st.error(f"Audio file not found: {audio_key}.wav")
                else:
                    status_alert.success("No Smoking Detected")  # Green alert for no smoking
                    image_display.empty()
                    audio.empty()
                    classification_text.text("Classification Progress: Completed!")
                    classification_progress.progress(100)

                # Update table if results exist
                if st.session_state.detection_results:
                    df = pd.DataFrame([
                        {
                            "Timestamp": result["Timestamp"],
                            "Gender": result["Gender"],
                            "Age Range": result["Age Range"],
                            "Smoking Count": result["Smoking Count"]
                        } for result in st.session_state.detection_results
                    ])
                    table.dataframe(df, use_container_width=True)

                # Reset for next cycle
                time.sleep(5)
                classification_progress.empty()
                classification_text.empty()
                capture_text.empty()
                status_alert.empty()  # Clear the alert for the next cycle
                detection_info.info("Starting detection...")
                ctx.video_transformer.snapshots = []
                ctx.video_transformer.last_capture_time = time.time()

            time.sleep(0.1)

# ======================
# Main Application
# ======================

def main():
    """Main function to handle page navigation."""
    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox("Select Page", ["Cover Page", "Photo Detection", "Real-Time Video Detection"])

    if page == "Cover Page":
        cover_page()
    elif page == "Photo Detection":
        photo_detection_page()
    elif page == "Real-Time Video Detection":
        real_time_detection_page()

if __name__ == "__main__":
    main()