"""Gradio app for embedding-based, per-client theme classification.

Each client registers a list of theme strings. Every theme is embedded once with a
SentenceTransformer model and kept as a "prototype" tensor. A text is assigned to the
theme whose prototype has the highest cosine similarity to the text's embedding, or to
"UNKNOWN_THEME" when that best similarity is below a confidence threshold. A CSV
benchmark tab measures accuracy against labelled rows (`text`, `real_tag`).
"""
import gradio as gr
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
from typing import Dict, List, Tuple, Optional
import io
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import json
import traceback
import spaces  # Hugging Face Spaces library (provides the @spaces.GPU decorator)
import tempfile
from dotenv import load_dotenv
import os

# Load .env first so that an HF_TOKEN defined only in .env is visible to os.getenv.
load_dotenv()
token_hf = os.getenv('HF_TOKEN')


class MultiClientThemeClassifier:
    """Holds one embedding model plus per-client theme prototypes.

    Attributes:
        model: the loaded SentenceTransformer, or None when nothing is loaded.
        client_themes: client_id -> {theme text -> embedding tensor from ``model``}.
        model_loaded: True once ``model`` has been loaded successfully.
        default_model: model name used before any other model is requested.
        current_model_name: name of the last successfully loaded model (initially the
            default); reloaded lazily by ``_ensure_model_is_loaded``.
    """

    def __init__(self):
        self.model = None
        self.client_themes = {}
        self.model_loaded = False
        self.default_model = 'google/embeddinggemma-300m'
        self.current_model_name = self.default_model

    def load_model(self, model_name: str) -> str:
        """Load ``model_name`` onto the CUDA device and make it the current model.

        Switching models discards every client's themes, because prototypes embedded
        by one model are not comparable with embeddings produced by another.

        Returns:
            A status string starting with "✅" on success or "❌" on failure.
        """
        try:
            # Prevent reloading the same model.
            if self.model_loaded and self.current_model_name == model_name:
                return f"✅ Model '{model_name}' is already loaded."

            self.model = None
            self.client_themes = {}
            self.model_loaded = False

            print(f"Loading model: {model_name} onto CUDA device")
            # NOTE(review): trust_remote_code=True executes code shipped with the model
            # repository, and model_name comes straight from a UI textbox. Anyone who can
            # reach this app can therefore run arbitrary code here. Consider an allow-list
            # of model names, or trust_remote_code=False for models that do not need it.
            self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True, token=token_hf)
            self.model_loaded = True
            self.current_model_name = model_name
            return f"✅ Model '{model_name}' loaded successfully onto GPU!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def _ensure_model_is_loaded(self) -> Optional[str]:
        """Load ``current_model_name`` if no model is loaded.

        Returns:
            None when a model is ready, otherwise the "❌" status from ``load_model``.
        """
        if not self.model_loaded:
            print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
            status = self.load_model(self.current_model_name)
            if status.startswith("❌"):
                return status
        return None

    def add_client_themes(self, client_id: str, themes: List[str]) -> str:
        """Embed ``themes`` and store them as the prototypes for ``client_id``.

        Any themes previously stored for this client are replaced. Duplicate theme
        strings are stored once. All themes are encoded in one batched call. The
        client's entry is only replaced after encoding succeeds, so a failure leaves
        the previous themes intact.

        Returns:
            A status string starting with "✅" on success or "❌" on failure.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return error_status

        try:
            unique_themes = list(dict.fromkeys(themes))  # de-duplicate, keep order
            prototypes = {}
            if unique_themes:
                embeddings = self.model.encode(unique_themes, convert_to_tensor=True)
                prototypes = dict(zip(unique_themes, embeddings))
            self.client_themes[client_id] = prototypes
            return f"✅ Added {len(prototypes)} themes for client '{client_id}'"
        except Exception as e:
            return f"❌ Error adding themes: {str(e)}"

    def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify a single text for a specific client.

        Returns:
            (label, best_score, similarities). ``label`` is the most similar theme,
            or "UNKNOWN_THEME" if ``best_score < confidence_threshold``.
            ``similarities`` maps every theme to its cosine similarity with ``text``.
            On failure, ``label`` is an error/status string and the score is 0.0.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return f"Error: {error_status}", 0.0, {}

        if client_id not in self.client_themes:
            return "Client not found", 0.0, {}

        try:
            text_embedding = self.model.encode(text, convert_to_tensor=True)
            similarities = {
                theme: util.cos_sim(text_embedding, prototype).item()
                for theme, prototype in self.client_themes[client_id].items()
            }

            if not similarities:
                return "No themes for client", 0.0, {}

            best_theme = max(similarities, key=similarities.get)
            best_score = similarities[best_theme]

            if best_score < confidence_threshold:
                return "UNKNOWN_THEME", best_score, similarities

            return best_theme, best_score, similarities
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}

    def classify_texts(self, texts: List[str], client_id: str, confidence_threshold: float = 0.3) -> List[Tuple[str, float]]:
        """Classify many texts for a client using one batched encode call.

        Applies the same rule as ``classify_text``: pick the theme with the highest
        cosine similarity, or "UNKNOWN_THEME" when that score is below
        ``confidence_threshold``.

        Returns:
            One (label, best_score) pair per input text, in input order. If the
            model, client or themes are unavailable, every entry carries the same
            status label with score 0.0.

        Raises:
            Whatever ``SentenceTransformer.encode`` raises. Unlike ``classify_text``,
            encoding errors are not converted into labels.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return [(f"Error: {error_status}", 0.0)] * len(texts)
        if client_id not in self.client_themes:
            return [("Client not found", 0.0)] * len(texts)
        themes = self.client_themes[client_id]
        if not themes:
            return [("No themes for client", 0.0)] * len(texts)
        if not texts:
            return []

        theme_names = list(themes)
        theme_matrix = torch.stack([themes[name] for name in theme_names])
        text_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # Rows are texts, columns are themes (same order as theme_names).
        scores = util.cos_sim(text_embeddings, theme_matrix)
        best_scores, best_indices = scores.max(dim=1)

        results = []
        for score, index in zip(best_scores.tolist(), best_indices.tolist()):
            label = theme_names[index] if score >= confidence_threshold else "UNKNOWN_THEME"
            results.append((label, score))
        return results

    def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
        """Benchmark ``model_name`` on a labelled CSV file.

        The CSV must contain ``text`` and ``real_tag`` columns. The distinct
        ``real_tag`` values become ``client_id``'s themes. Each text (truncated to
        500 characters) is then classified with the default 0.3 threshold, and the
        prediction is compared with its tag.

        Returns:
            (markdown summary, path to a results CSV or None, plot HTML or None).
        """
        # Step 1: load the model requested for this benchmark run. "Already loaded"
        # is fine; any "❌" status stops the run.
        load_status = self.load_model(model_name)
        if load_status.startswith("❌"):
            return f"❌ Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None

        # Step 2: read the CSV, trying common encodings (cp1256 covers Windows-Arabic files).
        encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
        df = None
        for encoding in encodings_to_try:
            try:
                df = pd.read_csv(csv_filepath, encoding=encoding)
                print(f"Successfully read CSV with encoding: {encoding}")
                break
            except pd.errors.EmptyDataError:
                return "❌ The uploaded CSV file is empty.", None, None
            except (UnicodeDecodeError, pd.errors.ParserError):
                continue

        if df is None:
            return "❌ Could not decode the CSV. Please save it as 'UTF-8' and try again.", None, None

        try:
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None

            df.dropna(subset=['text', 'real_tag'], inplace=True)
            if df.empty:
                return "❌ No rows with both 'text' and 'real_tag' values were found in the CSV.", None, None
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)

            unique_themes = df['real_tag'].unique().tolist()
            add_status = self.add_client_themes(client_id, unique_themes)
            if add_status.startswith("❌"):
                return add_status, None, None

            texts = df['text'].str.slice(0, 500).tolist()
            results = self.classify_texts(texts, client_id)

            df['predicted_tag'] = [label for label, _ in results]
            df['confidence'] = [score for _, score in results]

            correct = (df['real_tag'] == df['predicted_tag']).sum()
            total = len(df)
            accuracy = correct / total if total > 0 else 0

            results_summary = f"📊 **Benchmarking Results for `{self.current_model_name}`**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"

            # Build an explicit two-column frame so axis names do not depend on how
            # this pandas version names value_counts() output.
            theme_counts = df['real_tag'].value_counts().rename_axis('Theme').reset_index(name='Count')
            fig = px.bar(theme_counts, x='Theme', y='Count', title="Theme Distribution")
            visualization_html = fig.to_html()

            # Write through the handle and close it. utf-8-sig adds a BOM so Excel
            # opens non-Latin text correctly. pandas asks for newline='' on text handles.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8-sig', newline='') as tmp:
                df.to_csv(tmp, index=False)
                temp_file_path = tmp.name

            return results_summary, temp_file_path, visualization_html
        except Exception as e:
            error_details = traceback.format_exc()
            return f"❌ Error during benchmarking: {str(e)}\n\n{error_details}", None, None


# Initialize the classifier
classifier = MultiClientThemeClassifier()


@spaces.GPU
def load_model_interface(model_name: str):
    """UI handler: load the model named in the textbox."""
    if not model_name.strip():
        return "❌ Please enter a model name!"
    return classifier.load_model(model_name.strip())


@spaces.GPU
def add_themes_interface(client_id: str, themes_text: str):
    """UI handler: register newline-separated themes for a client."""
    if not themes_text.strip():
        return "❌ Please enter themes!"
    if not client_id.strip():
        return "❌ Please enter a client ID!"

    themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
    return classifier.add_client_themes(client_id.strip(), themes)


@spaces.GPU
def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """UI handler: classify one text and render the scores as Markdown."""
    if not text.strip():
        return "Please enter text to classify!", ""

    pred_theme, confidence, similarities = classifier.classify_text(text, client_id.strip(), confidence_threshold)

    sim_display = "**Similarity Scores:**\n" + "\n".join(
        [f"- {theme}: {sim:.3f}" for theme, sim in sorted(similarities.items(), key=lambda x: x[1], reverse=True)]
    )
    result = f"🎯 **Predicted Theme:** {pred_theme}\n🔥 **Confidence:** {confidence:.3f}\n\n{sim_display}"
    return result, ""


# Benchmarks embed a whole CSV, so request a longer GPU window than the default.
@spaces.GPU(duration=300)
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """UI handler: run the CSV benchmark with the model named on the Setup tab."""
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None
    if not model_name.strip():
        return "Please enter a model name for the benchmark!", None, None

    try:
        # Gradio 4 passes a filepath string; older Gradio passes a tempfile wrapper
        # that exposes the path as `.name`.
        csv_filepath = csv_file_obj if isinstance(csv_file_obj, str) else csv_file_obj.name
        return classifier.benchmark_csv(csv_filepath, client_id.strip(), model_name.strip())
    except Exception as e:
        error_details = traceback.format_exc()
        return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None


# --- Gradio Interface ---
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Custom Themes Classification - MVP")

    with gr.Tab("🚀 Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
        gr.Markdown("A default model (`google/embeddinggemma-300m`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
        with gr.Row():
            # Shared with the benchmark tab, which loads whatever model is named here.
            model_input = gr.Textbox(label="HuggingFace Model Name", value="google/embeddinggemma-300m")
            load_btn = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)

        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)
        add_themes_btn.click(add_themes_interface, inputs=[client_input, themes_input], outputs=themes_status)

    with gr.Tab("🔍 Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Text to Classify", lines=3)
                client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
                confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Confidence Threshold")
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                classification_result = gr.Markdown(label="Results")
        # Hidden sink for classify_interface's second (always empty) return value.
        classify_extra_output = gr.Textbox(visible=False)
        classify_btn.click(
            classify_interface,
            inputs=[text_input, client_select, confidence_slider],
            outputs=[classification_result, classify_extra_output],
        )

    with gr.Tab("📊 CSV Benchmarking"):
        gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")
        with gr.Row():
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            visualization = gr.HTML(label="Visualization")

        # The model name comes from the Setup tab's textbox.
        benchmark_btn.click(
            benchmark_interface,
            inputs=[csv_upload, benchmark_client, model_input],
            outputs=[benchmark_results, results_csv, visualization],
        )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)