import gradio as gr import torch import joblib import pandas as pd import os import json from safetensors.torch import load_file from typing import List, Tuple from network import PricePredictor MODEL_DIR = "model" DATA_DIR = "data" SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl") DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv") TARGET_COLUMN = 'price_will_rise_30_in_6m' def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]: config_path = os.path.join(model_dir, "config.json") with open(config_path, "r") as f: model_config = json.load(f) model = PricePredictor(input_size=model_config["input_size"]) weights_path = os.path.join(model_dir, "model.safetensors") model.load_state_dict(load_file(weights_path)) model.eval() return model, model_config["feature_columns"] def perform_prediction(model: torch.nn.Module, scaler, input_features: pd.Series) -> Tuple[bool, float]: features_np = input_features.to_numpy(dtype="float32").reshape(1, -1) features_scaled = scaler.transform(features_np) features_tensor = torch.tensor(features_scaled, dtype=torch.float32) with torch.no_grad(): logit = model(features_tensor) probability = torch.sigmoid(logit).item() predicted_class = bool(round(probability)) return predicted_class, probability try: model, feature_columns = load_model_and_config(MODEL_DIR) scaler = joblib.load(SCALER_PATH) full_data = pd.read_csv(DATA_PATH) ASSETS_LOADED = True except FileNotFoundError as e: print(f"Error loading necessary files: {e}") print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.") ASSETS_LOADED = False def predict_price_trend(card_identifier: str) -> str: if not ASSETS_LOADED: return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories." if not card_identifier or not card_identifier.strip().isdigit(): return "## Input Error\nPlease enter a valid, numeric TCGPlayer ID." card_id = int(card_identifier.strip()) card_data = full_data[full_data['tcgplayer_id'] == card_id] if card_data.empty: return f"## Card Not Found\nCould not find a card with TCGPlayer ID '{card_id}'. Please check the ID and try again." card_sample = card_data.iloc[0] sample_features = card_sample[feature_columns] predicted_class, probability = perform_prediction(model, scaler, sample_features) prediction_text = "**RISE**" if predicted_class else "**NOT RISE**" confidence = probability if predicted_class else 1 - probability tcgplayer_id = card_sample['tcgplayer_id'] tcgplayer_link = f"https://www.tcgplayer.com/product/{tcgplayer_id}?Language=English" true_label_text = "" try: if TARGET_COLUMN in card_sample and pd.notna(card_sample[TARGET_COLUMN]): true_label = bool(card_sample[TARGET_COLUMN]) true_label_text = f"\n- **Actual Result in Dataset:** The price did **{'RISE' if true_label else 'NOT RISE'}**." except (KeyError, TypeError): pass output = f""" ## 🔮 Prediction Report for {card_sample['name']} - **Prediction:** The model predicts the card's price will {prediction_text} by 30% in the next 6 months. - **Confidence:** {confidence:.2%} - **View on TCGPlayer:** [Check Current Price]({tcgplayer_link}) {true_label_text} """ return output with gr.Blocks(theme=gr.themes.Soft(), title="PricePoke Predictor") as demo: gr.Markdown( """ # 📈 PricePoke: Pokémon Card Price Trend Predictor Enter a Pokémon card's TCGPlayer ID to predict whether its market price will increase by 30% or more over the next 6 months. This model was trained on historical TCGPlayer market data. """ ) with gr.Row(): with gr.Column(scale=1): card_input = gr.Textbox( label="TCGPlayer ID", placeholder="e.g., '84198'", info="Find the ID in the card's URL on TCGPlayer's website (e.g., tcgplayer.com/product/84198/... has ID 84198)." ) predict_button = gr.Button("Predict Trend", variant="primary") gr.Markdown("---") gr.Markdown("### Example Cards") if ASSETS_LOADED: example_df = full_data.sample(5, random_state=42)[['name', 'tcgplayer_id']] gr.Markdown(example_df.to_markdown(index=False)) else: gr.Markdown("Could not load examples.") with gr.Column(scale=2): output_markdown = gr.Markdown() predict_button.click(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown]) card_input.submit(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown]) if __name__ == "__main__": demo.launch()