ufc-predictor / app.py
AlvaroMros's picture
Refactor imports to use absolute paths and clean up scripts
9678fdb
import gradio as gr
import joblib
from datetime import datetime
import os
from src.predict.models import (
BaseMLModel,
EloBaselineModel,
LogisticRegressionModel,
XGBoostModel,
SVCModel,
RandomForestModel,
BernoulliNBModel,
LGBMModel
)
from src.config import MODELS_DIR
# --- Model Cache ---
# This global dictionary will store loaded models to avoid reloading them from disk.
MODEL_CACHE = {}
# --- Gradio App Setup ---
if not os.path.exists(MODELS_DIR):
os.makedirs(MODELS_DIR)
print(f"Warning: Models directory not found. Created a dummy directory at '{MODELS_DIR}'.")
# Get a list of available models
available_models = [f for f in os.listdir(MODELS_DIR) if f.endswith(".joblib")]
if not available_models:
print(f"Warning: No models found in '{MODELS_DIR}'. The dropdown will be empty.")
available_models.append("No models found")
# --- Prediction Function ---
def predict_fight(model_name, fighter1_name, fighter2_name):
"""
Loads the selected model and predicts the winner of a fight.
"""
if model_name == "No models found" or not fighter1_name or not fighter2_name:
return "Please select a model and enter both fighter names.", ""
try:
# Load model from cache or from disk if it's the first time
if model_name not in MODEL_CACHE:
print(f"Loading and caching model: {model_name}...")
model_path = os.path.join(MODELS_DIR, model_name)
MODEL_CACHE[model_name] = joblib.load(model_path)
print("...model cached.")
model = MODEL_CACHE[model_name]
fight = {
'fighter_1': fighter1_name,
'fighter_2': fighter2_name,
'event_date': datetime.now().strftime('%B %d, %Y')
}
prediction_result = model.predict(fight)
if prediction_result and prediction_result.get('winner'):
winner = prediction_result['winner']
prob = prediction_result['probability']
return winner, f"{prob:.1%}"
else:
return "Could not make a prediction.", ""
except FileNotFoundError:
return f"Error: Model file '{model_name}' not found.", ""
except Exception as e:
print(f"An error occurred during prediction: {e}")
return f"An error occurred: {e}", ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ₯‹ UFC Fight Predictor πŸ₯Š")
gr.Markdown("Select a prediction model and enter two fighter names to predict the outcome.")
with gr.Column():
model_dropdown = gr.Dropdown(
label="Select Model",
choices=available_models,
value=available_models[0] if available_models else None
)
with gr.Row():
fighter1_input = gr.Textbox(label="Fighter 1", placeholder="e.g., Jon Jones")
fighter2_input = gr.Textbox(label="Fighter 2", placeholder="e.g., Stipe Miocic")
predict_button = gr.Button("Predict Winner")
with gr.Column():
winner_output = gr.Textbox(label="Predicted Winner", interactive=False)
prob_output = gr.Textbox(label="Confidence", interactive=False)
predict_button.click(
fn=predict_fight,
inputs=[model_dropdown, fighter1_input, fighter2_input],
outputs=[winner_output, prob_output]
)
# --- Launch the App ---
demo.launch()