# Scraped from the Hugging Face Space page of "Mohaddz":
# commit 6b60724 ("Update app.py", verified).
import gradio as gr
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
from typing import Dict, List, Tuple, Optional
import io
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import json
import traceback
import spaces # Import the spaces library
import tempfile
from dotenv import load_dotenv
import os
# Populate os.environ from a local .env file FIRST, so that an HF_TOKEN defined
# there is visible to the os.getenv call below.
load_dotenv()
# Hugging Face access token (None when unset); passed to SentenceTransformer
# whenever a model is loaded.
token_hf = os.getenv('HF_TOKEN')
class MultiClientThemeClassifier:
    """Embedding-based theme classifier holding a separate theme set per client.

    Each theme is represented by the embedding of its own label text (its
    "prototype"). A text is assigned the theme whose prototype has the highest
    cosine similarity with the text's embedding.
    """

    def __init__(self):
        self.model = None  # SentenceTransformer once a model has been loaded
        # client_id -> {theme label -> prototype embedding}
        self.client_themes = {}
        self.model_loaded = False
        self.default_model = 'google/embeddinggemma-300m'
        # Model that is loaded, or that will be loaded lazily on first use.
        self.current_model_name = self.default_model

    def load_model(self, model_name: str):
        """Load ``model_name`` onto the GPU and remember it as the current model.

        Switching to a different model discards every client's themes, since
        their prototypes were embedded with the previous model.

        Returns:
            A status string starting with "✅" on success or "❌" on failure.
        """
        try:
            # Prevent reloading the same model
            if self.model_loaded and self.current_model_name == model_name:
                return f"✅ Model '{model_name}' is already loaded."
            # Release the old model and its embeddings before loading the new one.
            self.model = None
            self.client_themes = {}
            self.model_loaded = False
            print(f"Loading model: {model_name} onto CUDA device")
            # SECURITY(review): trust_remote_code=True runs code shipped in the model
            # repository, and model_name comes from a free-text UI field. Consider an
            # allow-list of model names before exposing this app publicly.
            self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True, token=token_hf)
            self.model_loaded = True
            self.current_model_name = model_name
            return f"✅ Model '{model_name}' loaded successfully onto GPU!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def _ensure_model_is_loaded(self) -> Optional[str]:
        """Lazily load ``self.current_model_name`` if no model is loaded yet.

        Returns:
            None when a model is ready, otherwise the failure status from load_model.
        """
        if self.model_loaded:
            return None
        print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
        status = self.load_model(self.current_model_name)
        # Decide on the flag, not on the message text (a model name could
        # itself contain the word "Error").
        return None if self.model_loaded else status

    def add_client_themes(self, client_id: str, themes: List[str]):
        """Replace the theme set of ``client_id`` with embeddings of ``themes``.

        Duplicate labels are embedded once. The client's previous themes are
        replaced only after every label has been embedded successfully.

        Returns:
            A status string starting with "✅" on success or "❌" on failure.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return error_status
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_themes = list(dict.fromkeys(themes))
        if not unique_themes:
            return "❌ No themes provided!"
        try:
            # One batched encode call; each row of the result is one prototype.
            prototypes = self.model.encode(unique_themes, convert_to_tensor=True)
        except Exception as e:
            return f"❌ Error adding themes: {str(e)}"
        self.client_themes[client_id] = dict(zip(unique_themes, prototypes))
        return f"✅ Added {len(unique_themes)} themes for client '{client_id}'"

    def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify ``text`` against the themes registered for ``client_id``.

        Returns:
            ``(label, best_score, similarities)``. ``similarities`` maps every
            theme to its cosine similarity with ``text``. ``label`` is the best
            theme, "UNKNOWN_THEME" when the best score is below
            ``confidence_threshold``, or a status message ("Client not found",
            "No themes for client", "Error: ...") paired with 0.0 and ``{}``.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return f"Error: {error_status}", 0.0, {}
        themes = self.client_themes.get(client_id)
        if themes is None:
            return "Client not found", 0.0, {}
        if not themes:
            return "No themes for client", 0.0, {}
        try:
            text_embedding = self.model.encode(text, convert_to_tensor=True)
            similarities = {theme: util.cos_sim(text_embedding, prototype).item()
                            for theme, prototype in themes.items()}
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}
        best_theme = max(similarities, key=similarities.get)
        best_score = similarities[best_theme]
        if best_score < confidence_threshold:
            return "UNKNOWN_THEME", best_score, similarities
        return best_theme, best_score, similarities

    @staticmethod
    def _read_csv_with_fallback(csv_filepath: str) -> Optional[pd.DataFrame]:
        """Read ``csv_filepath`` trying several encodings; None if none yields a table.

        'utf-8-sig' also accepts BOM-less UTF-8, 'cp1256' is the Windows Arabic
        code page, and 'latin1' decodes any byte sequence (last resort).
        """
        for encoding in ('utf-8-sig', 'cp1256', 'latin1'):
            try:
                df = pd.read_csv(csv_filepath, encoding=encoding)
            except (UnicodeDecodeError, pd.errors.ParserError):
                continue
            except pd.errors.EmptyDataError:
                # No header/rows at all: another encoding will not help.
                return None
            print(f"Successfully read CSV with encoding: {encoding}")
            return df
        return None

    def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str,
                      confidence_threshold: float = 0.3,
                      max_text_chars: int = 500) -> Tuple[str, Optional[str], Optional[str]]:
        """Benchmark ``model_name`` on a labelled CSV file.

        The CSV needs ``text`` and ``real_tag`` columns. Every distinct
        ``real_tag`` becomes a theme of ``client_id`` (replacing that client's
        existing themes). Each text, truncated to ``max_text_chars`` characters,
        is classified with ``confidence_threshold``. Accuracy is the share of
        rows whose prediction equals ``real_tag``, so "UNKNOWN_THEME" counts as wrong.

        Returns:
            ``(markdown_summary, results_csv_path, plot_html)``; the last two are
            None when the benchmark could not run.
        """
        # Load the requested model; "already loaded" is fine, any failure stops here.
        load_status = self.load_model(model_name)
        if "❌" in load_status:
            return f"❌ Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None

        df = self._read_csv_with_fallback(csv_filepath)
        if df is None:
            return "❌ Could not read the CSV (empty, unparsable, or unsupported encoding). Please save it as 'UTF-8' and try again.", None, None

        try:
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
            df.dropna(subset=['text', 'real_tag'], inplace=True)
            if df.empty:
                return "❌ The CSV has no rows with both 'text' and 'real_tag' filled in.", None, None
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)

            theme_status = self.add_client_themes(client_id, df['real_tag'].unique().tolist())
            if theme_status.startswith("❌"):
                return theme_status, None, None

            texts = df['text'].str.slice(0, max_text_chars).tolist()
            results = [self.classify_text(text, client_id, confidence_threshold) for text in texts]
            df['predicted_tag'] = [label for label, _, _ in results]
            df['confidence'] = [score for _, score, _ in results]

            correct = int((df['real_tag'] == df['predicted_tag']).sum())
            total = len(df)  # > 0, guaranteed by the df.empty check above
            accuracy = correct / total
            results_summary = f"📊 **Benchmarking Results for `{self.current_model_name}`**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"

            # Build explicit columns so axis labels do not depend on the pandas
            # version's naming of value_counts() output.
            counts = df['real_tag'].value_counts().rename_axis('Theme').reset_index(name='Count')
            fig = px.bar(counts, x='Theme', y='Count', title="Theme Distribution")
            visualization_html = fig.to_html()

            # Write through the temp file handle (closed by `with`) so the
            # utf-8-sig BOM is actually emitted; delete=False keeps it for download.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False,
                                             encoding='utf-8-sig', newline='') as tmp:
                df.to_csv(tmp, index=False)
                temp_file_path = tmp.name
            return results_summary, temp_file_path, visualization_html
        except Exception as e:
            error_details = traceback.format_exc()
            return f"❌ Error during benchmarking: {str(e)}\n\n{error_details}", None, None
# Single shared classifier instance used by every Gradio callback.
classifier = MultiClientThemeClassifier()


@spaces.GPU
def load_model_interface(model_name: str):
    """Gradio callback: load the embedding model named in the textbox.

    Returns a status string for the UI.
    """
    model_name = (model_name or "").strip()
    if not model_name:
        return "❌ Please enter a model name!"
    return classifier.load_model(model_name)
@spaces.GPU
def add_themes_interface(client_id: str, themes_text: str):
    """Gradio callback: register one theme per non-blank line of ``themes_text``.

    Returns a status string for the UI.
    """
    themes = [line.strip() for line in (themes_text or "").splitlines() if line.strip()]
    if not themes:
        return "❌ Please enter themes!"
    # Without this check, themes would silently be stored under the key "".
    if not (client_id or "").strip():
        return "❌ Please enter a client ID!"
    return classifier.add_client_themes(client_id, themes)
@spaces.GPU
def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """Gradio callback: classify one text and render the result as Markdown.

    Returns ``(markdown, "")``; the click handler wires the second value to a
    hidden Textbox.
    """
    if not (text or "").strip():
        return "Please enter text to classify!", ""
    pred_theme, confidence, similarities = classifier.classify_text(text, client_id, confidence_threshold)
    ranked = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
    sim_display = "**Similarity Scores:**\n" + "\n".join(f"- {theme}: {sim:.3f}" for theme, sim in ranked)
    # Blank lines between items: a single "\n" is collapsed into one line by Markdown.
    result = f"🎯 **Predicted Theme:** {pred_theme}\n\n🔥 **Confidence:** {confidence:.3f}\n\n{sim_display}"
    return result, ""
@spaces.GPU(duration=300)
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """Gradio callback: benchmark the model named on the Setup tab on an uploaded CSV.

    Returns ``(markdown_summary, results_csv_path, plot_html)`` for the three
    benchmark outputs; the last two are None on failure.
    """
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None
    model_name = (model_name or "").strip()
    if not model_name:
        return "Please enter a model name for the benchmark!", None, None
    try:
        # Depending on the Gradio version the upload arrives either as an
        # object exposing the path as ``.name`` or as a plain path string.
        csv_filepath = getattr(csv_file_obj, "name", csv_file_obj)
        return classifier.benchmark_csv(csv_filepath, client_id, model_name)
    except Exception as e:
        error_details = traceback.format_exc()
        return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
# --- Gradio Interface ---
# Three tabs share the module-level `classifier` through the *_interface callbacks.
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Custom Themes Classification - MVP")

    with gr.Tab("🚀 Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
        gr.Markdown("A default model (`google/embeddinggemma-300m`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
        with gr.Row():
            # Also an input of the benchmark button on the CSV Benchmarking tab.
            model_input = gr.Textbox(label="HuggingFace Model Name", value="google/embeddinggemma-300m")
            load_btn = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)

        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)
        add_themes_btn.click(add_themes_interface, inputs=[client_input, themes_input], outputs=themes_status)

    with gr.Tab("🔍 Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Text to Classify", lines=3)
                client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
                confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Confidence Threshold")
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                classification_result = gr.Markdown(label="Results")
        # classify_interface returns two values; the second goes to a hidden Textbox.
        classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])

    with gr.Tab("📊 CSV Benchmarking"):
        gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")
        with gr.Row():
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            visualization = gr.HTML(label="Visualization")
        # The model name comes from the Setup tab's textbox (model_input).
        benchmark_btn.click(
            benchmark_interface,
            inputs=[csv_upload, benchmark_client, model_input],
            outputs=[benchmark_results, results_csv, visualization]
        )

# Launch the app
if __name__ == "__main__":
    # NOTE(review): share=True creates a public link for local runs; confirm it
    # is intended (Hugging Face Spaces serve the app without it).
    demo.launch(share=True)