Spaces:
Running
Running
File size: 3,405 Bytes
338c9b0 9678fdb 338c9b0 0e63702 338c9b0 0e63702 338c9b0 0e63702 9678fdb 0e63702 9678fdb 338c9b0 f3ecc65 9678fdb f3ecc65 55723a4 338c9b0 55723a4 9678fdb 338c9b0 55723a4 338c9b0 55723a4 338c9b0 55723a4 338c9b0 9678fdb 338c9b0 9678fdb 338c9b0 9678fdb 338c9b0 9678fdb 55723a4 338c9b0 55723a4 338c9b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
import joblib
from datetime import datetime
import os
from src.predict.models import (
BaseMLModel,
EloBaselineModel,
LogisticRegressionModel,
XGBoostModel,
SVCModel,
RandomForestModel,
BernoulliNBModel,
LGBMModel
)
from src.config import MODELS_DIR
# --- Model Cache ---
# Loaded models are kept here so repeated predictions don't reload from disk.
MODEL_CACHE = {}

# --- Gradio App Setup ---
# Ensure the models directory exists. exist_ok=True guards against a TOCTOU
# race if another process creates the directory between the check and makedirs.
if not os.path.exists(MODELS_DIR):
    os.makedirs(MODELS_DIR, exist_ok=True)
    print(f"Warning: Models directory not found. Created a dummy directory at '{MODELS_DIR}'.")

# Discover serialized models; sorted() makes the dropdown order deterministic
# (os.listdir order is OS-dependent).
available_models = sorted(f for f in os.listdir(MODELS_DIR) if f.endswith(".joblib"))
if not available_models:
    print(f"Warning: No models found in '{MODELS_DIR}'. The dropdown will be empty.")
    available_models.append("No models found")
# --- Prediction Function ---
def predict_fight(model_name, fighter1_name, fighter2_name):
    """
    Load the selected model (caching it) and predict the winner of a fight.

    Args:
        model_name: Filename of a .joblib model inside MODELS_DIR, or the
            "No models found" placeholder shown when the directory is empty.
        fighter1_name: Name of the first fighter.
        fighter2_name: Name of the second fighter.

    Returns:
        A (winner, confidence) tuple of display strings. On any failure the
        first element is a human-readable error message and the second is "".
    """
    # Guard clause: placeholder model or a missing fighter name.
    if model_name == "No models found" or not fighter1_name or not fighter2_name:
        return "Please select a model and enter both fighter names.", ""
    try:
        # Load from the in-process cache, falling back to disk on first use.
        if model_name not in MODEL_CACHE:
            print(f"Loading and caching model: {model_name}...")
            model_path = os.path.join(MODELS_DIR, model_name)
            MODEL_CACHE[model_name] = joblib.load(model_path)
            print("...model cached.")
        model = MODEL_CACHE[model_name]

        fight = {
            'fighter_1': fighter1_name,
            'fighter_2': fighter2_name,
            'event_date': datetime.now().strftime('%B %d, %Y')
        }
        prediction_result = model.predict(fight)

        if prediction_result and prediction_result.get('winner'):
            winner = prediction_result['winner']
            # Fix: use .get() so a result that has a 'winner' but lacks
            # 'probability' doesn't raise KeyError and surface as a generic
            # "An error occurred" message.
            prob = prediction_result.get('probability')
            confidence = f"{prob:.1%}" if prob is not None else ""
            return winner, confidence
        else:
            return "Could not make a prediction.", ""
    except FileNotFoundError:
        return f"Error: Model file '{model_name}' not found.", ""
    except Exception as e:
        # Broad catch is acceptable at the UI boundary: log, then show the
        # error in the output textbox instead of crashing the app.
        print(f"An error occurred during prediction: {e}")
        return f"An error occurred: {e}", ""
# --- Gradio Interface ---
# Layout: an input column (model picker, fighter name boxes, predict button)
# followed by an output column (winner + confidence), wired together via the
# button's click event.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🥋 UFC Fight Predictor 🥊")
    gr.Markdown("Select a prediction model and enter two fighter names to predict the outcome.")

    with gr.Column():
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=available_models,
            value=available_models[0] if available_models else None
        )
        with gr.Row():
            fighter1_input = gr.Textbox(label="Fighter 1", placeholder="e.g., Jon Jones")
            fighter2_input = gr.Textbox(label="Fighter 2", placeholder="e.g., Stipe Miocic")
        predict_button = gr.Button("Predict Winner")

    with gr.Column():
        winner_output = gr.Textbox(label="Predicted Winner", interactive=False)
        prob_output = gr.Textbox(label="Confidence", interactive=False)

    # NOTE(review): nesting reconstructed from a whitespace-stripped source —
    # event wiring is independent of layout nesting, so behavior is unchanged.
    predict_button.click(
        fn=predict_fight,
        inputs=[model_dropdown, fighter1_input, fighter2_input],
        outputs=[winner_output, prob_output]
    )

# --- Launch the App ---
demo.launch()