Spaces:
Sleeping
Sleeping
import os
import sys

import gradio as gr
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download
# Model location on the Hugging Face Hub. The .joblib file lives in the
# repository's "data/" subdirectory, so that path is part of the filename.
repo_id = "Web4/LS-W4-Mini-RF_Addiction_Impact"
model_file = "data/LS-W4-Mini-RF_Addiction_Impact.joblib"

# Access token for gated/private repositories, read from the environment
# (e.g. a Space secret). It is None when HF_TOKEN is unset, in which case
# hf_hub_download falls back to its default authentication behaviour.
token = os.environ.get("HF_TOKEN")

# Download (or reuse the cached copy of) the model file. Any failure here --
# file not found, access denied, network error, ... -- leaves the app without
# a model, so the error is reported on stderr and re-raised to abort startup.
try:
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file, token=token)
except Exception as e:
    print(f"Error downloading model: {e}", file=sys.stderr)
    raise
print(f"Model downloaded to: {model_path}")

# Load the scikit-learn pipeline. joblib files are pickle-based: loading one
# can execute arbitrary code, so only point this at a repository you trust.
pipeline = joblib.load(model_path)
# Define the prediction function for the Gradio interface.
def predict_impact(
    gender,
    academic_level,
    most_used_platform,
    relationship_status,
    age,
    avg_daily_usage_hours,
    sleep_hours_per_night,
    mental_health_score,
    addicted_score,
    conflicts_over_social_media,
):
    """
    Predict whether social media use impacts a student's academic performance.

    Builds a one-row pandas DataFrame from the form inputs (one column per
    feature, named as the pipeline expects) and runs it through the
    module-level scikit-learn ``pipeline`` loaded at startup.

    Parameters are the raw values from the Gradio components, in the same
    order as the interface's ``inputs`` list: four categorical dropdowns
    (gender, academic_level, most_used_platform, relationship_status), five
    numeric sliders (age through addicted_score) and the
    conflicts_over_social_media dropdown.

    Returns:
        A human-readable prediction string, or a prompt asking the user to
        fill in every dropdown if one of them was left empty.
    """
    # Depending on the Gradio version, a dropdown with no selection may submit
    # None; don't feed missing categories to the model.
    dropdown_values = (
        gender,
        academic_level,
        most_used_platform,
        relationship_status,
        conflicts_over_social_media,
    )
    if any(value is None for value in dropdown_values):
        return "Please select a value for every dropdown field before predicting."

    # Create a pandas DataFrame from the user inputs.
    input_data = pd.DataFrame({
        'Gender': [gender],
        'Academic_Level': [academic_level],
        'Most_Used_Platform': [most_used_platform],
        'Relationship_Status': [relationship_status],
        'Age': [age],
        'Avg_Daily_Usage_Hours': [avg_daily_usage_hours],
        'Sleep_Hours_Per_Night': [sleep_hours_per_night],
        'Mental_Health_Score': [mental_health_score],
        'Addicted_Score': [addicted_score],
        'Conflicts_Over_Social_Media': [conflicts_over_social_media],
    })

    # Make a prediction. The pipeline handles the preprocessing automatically.
    prediction = pipeline.predict(input_data)[0]

    # The target's label encoding is not visible here: accept a numeric
    # positive class (1 / True) as well as string labels such as "Yes".
    # numpy.str_ subclasses str, so numpy string labels take the first branch.
    if isinstance(prediction, str):
        is_impacted = prediction.strip().casefold() in {"yes", "true", "1"}
    else:
        is_impacted = prediction == 1

    # Return a user-friendly result based on the prediction.
    if is_impacted:
        return "Prediction: Yes, social media use is likely to impact academic performance."
    return "Prediction: No, social media use is likely not to impact academic performance."
# Gradio UI: one input component per model feature, listed in the same order
# as predict_impact's parameters because Gradio passes the values positionally.
demo = gr.Interface(
    predict_impact,
    [
        gr.Dropdown(choices=["Male", "Female"], label="Gender"),
        gr.Dropdown(
            choices=["Undergraduate", "Postgraduate", "High School"],
            label="Academic_Level",
        ),
        gr.Dropdown(
            choices=["Instagram", "Facebook", "Twitter", "YouTube", "WhatsApp", "Other"],
            label="Most_Used_Platform",
        ),
        gr.Dropdown(choices=["Single", "In a relationship"], label="Relationship_Status"),
        gr.Slider(minimum=16, maximum=25, value=20, label="Age"),
        gr.Slider(minimum=0, maximum=24, value=3.0, label="Avg_Daily_Usage_Hours"),
        gr.Slider(minimum=0, maximum=12, value=7, label="Sleep_Hours_Per_Night"),
        gr.Slider(minimum=0, maximum=10, value=5, label="Mental_Health_Score (0-10)"),
        gr.Slider(minimum=0, maximum=10, value=5, label="Addicted_Score (0-10)"),
        gr.Dropdown(choices=[0, 1], label="Conflicts_Over_Social_Media (0=No, 1=Yes)"),
    ],
    "text",
    title="Social Media Addiction Impact on Academic Performance",
    description=(
        "A Random Forest model to predict if social media use impacts a "
        "student's academic performance. This is not a diagnostic tool."
    ),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()