# Web4's picture
# Update app.py
# 4277bd8 verified
import gradio as gr
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
import os
# Location of the trained model on the Hugging Face Hub.
# The filename includes the 'data/' subdirectory inside the repository.
repo_id = "Web4/LS-W4-Mini-RF_Addiction_Impact"
model_file = "data/LS-W4-Mini-RF_Addiction_Impact.joblib"

# Hugging Face access token, read from the HF_TOKEN environment variable.
# It is needed for gated repositories. When the variable is unset this is
# None, and that None is passed straight through to hf_hub_download.
token = os.environ.get("HF_TOKEN")

# Download the model file. Only the download call is inside the try, so the
# error message below is printed for download failures and nothing else.
try:
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file, token=token)
except Exception as e:
    # Possible causes include a missing file, denied access (e.g. a gated
    # repository without a valid token) or a network problem. Report the
    # error and re-raise so the app fails at startup, not at prediction time.
    print(f"Error downloading model: {e}")
    raise
else:
    print(f"Model downloaded to: {model_path}")

# Load the scikit-learn pipeline from the downloaded joblib file.
# NOTE(security): joblib.load unpickles the file, and unpickling can execute
# arbitrary code. Only point repo_id/model_file at a repository you trust.
pipeline = joblib.load(model_path)
# Prediction logic for the Gradio interface.
def _is_positive_label(label):
    """Return True if a predicted class label means "impacts performance".

    The pipeline's class labels are not visible in this file. Numeric labels
    (including numpy scalars and booleans) are compared with ``== 1``, exactly
    as before. String labels such as "Yes", "1" or "true" (any case, surrounding
    whitespace ignored) are also treated as positive. A string never compares
    equal to 1, so without this every string prediction would be reported as
    "No".
    """
    if isinstance(label, str):
        return label.strip().lower() in {"yes", "1", "true"}
    return label == 1


def predict_impact(
    gender,
    academic_level,
    most_used_platform,
    relationship_status,
    age,
    avg_daily_usage_hours,
    sleep_hours_per_night,
    mental_health_score,
    addicted_score,
    conflicts_over_social_media,
):
    """Predict whether social media use impacts a student's academic performance.

    Builds a one-row pandas DataFrame from the user inputs and calls
    ``predict`` on the module-level scikit-learn ``pipeline``.

    Args:
        gender: Selected gender (categorical).
        academic_level: Selected academic level (categorical).
        most_used_platform: Selected social media platform (categorical).
        relationship_status: Selected relationship status (categorical).
        age: Student age.
        avg_daily_usage_hours: Average daily social media usage, in hours.
        sleep_hours_per_night: Average sleep per night, in hours.
        mental_health_score: Self-reported mental health score.
        addicted_score: Self-reported addiction score.
        conflicts_over_social_media: Conflicts-over-social-media value.

    Returns:
        A user-facing string that starts with "Prediction: Yes" or
        "Prediction: No".
    """
    # Column names must match the ones the pipeline was trained on.
    input_data = pd.DataFrame({
        'Gender': [gender],
        'Academic_Level': [academic_level],
        'Most_Used_Platform': [most_used_platform],
        'Relationship_Status': [relationship_status],
        'Age': [age],
        'Avg_Daily_Usage_Hours': [avg_daily_usage_hours],
        'Sleep_Hours_Per_Night': [sleep_hours_per_night],
        'Mental_Health_Score': [mental_health_score],
        'Addicted_Score': [addicted_score],
        'Conflicts_Over_Social_Media': [conflicts_over_social_media]
    })

    # Raw features go straight in. This assumes preprocessing (e.g. encoding
    # the categorical columns) is part of the saved pipeline; confirm against
    # the training code.
    prediction = pipeline.predict(input_data)[0]

    if _is_positive_label(prediction):
        return "Prediction: Yes, social media use is likely to impact academic performance."
    return "Prediction: No, social media use is likely not to impact academic performance."
# Gradio UI: one input component per predict_impact parameter, in the same
# order as the function signature, with a plain-text output.
demo = gr.Interface(
    predict_impact,
    inputs=[
        # Categorical features.
        gr.Dropdown(choices=["Male", "Female"], label="Gender"),
        gr.Dropdown(
            choices=["Undergraduate", "Postgraduate", "High School"],
            label="Academic_Level",
        ),
        gr.Dropdown(
            choices=["Instagram", "Facebook", "Twitter", "YouTube", "WhatsApp", "Other"],
            label="Most_Used_Platform",
        ),
        gr.Dropdown(
            choices=["Single", "In a relationship"],
            label="Relationship_Status",
        ),
        # Numeric features: (minimum, maximum, initial value).
        gr.Slider(minimum=16, maximum=25, value=20, label="Age"),
        gr.Slider(minimum=0, maximum=24, value=3.0, label="Avg_Daily_Usage_Hours"),
        gr.Slider(minimum=0, maximum=12, value=7, label="Sleep_Hours_Per_Night"),
        gr.Slider(minimum=0, maximum=10, value=5, label="Mental_Health_Score (0-10)"),
        gr.Slider(minimum=0, maximum=10, value=5, label="Addicted_Score (0-10)"),
        gr.Dropdown(
            choices=[0, 1],
            label="Conflicts_Over_Social_Media (0=No, 1=Yes)",
        ),
    ],
    outputs="text",
    title="Social Media Addiction Impact on Academic Performance",
    description=(
        "A Random Forest model to predict if social media use impacts a "
        "student's academic performance. This is not a diagnostic tool."
    ),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()