# Spaces: Running on Zero
import gradio as gr | |
import pandas as pd | |
import torch | |
from sentence_transformers import SentenceTransformer, util | |
import numpy as np | |
from typing import Dict, List, Tuple, Optional | |
import io | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from collections import defaultdict | |
import json | |
import traceback | |
import spaces # Import the spaces library | |
import tempfile | |
from dotenv import load_dotenv | |
import os | |
# Load variables from a local .env file BEFORE reading HF_TOKEN; otherwise a
# token that is only defined in .env is not yet in the environment when
# os.getenv() runs and token_hf ends up None.
load_dotenv()
token_hf = os.getenv('HF_TOKEN')
class MultiClientThemeClassifier:
    """Embedding-based theme classifier that keeps a separate theme set per client.

    Every theme is represented by the embedding of its own name (its
    "prototype"). A text is assigned the theme whose prototype has the highest
    cosine similarity with the text embedding, or ``UNKNOWN_THEME`` when that
    best similarity is below the confidence threshold.

    Public methods report failures through their return values (status strings
    or error tuples) instead of raising.
    """

    def __init__(self):
        self.model = None
        # client_id -> {theme name -> prototype embedding}
        self.client_themes = {}
        self.model_loaded = False
        self.default_model = 'google/embeddinggemma-300m'
        # Name of the model that _ensure_model_is_loaded() loads on demand.
        self.current_model_name = self.default_model

    def load_model(self, model_name: str):
        """Load the embedding model onto the GPU, remembering the choice.

        Asking for the model that is already loaded is a no-op. Loading any
        other model first drops the current model and ALL stored client themes
        (presumably because prototypes from different models are not
        comparable), so themes must be re-added afterwards.

        Args:
            model_name: Hugging Face model identifier.

        Returns:
            A human-readable status string. Errors are reported in the string
            (including a traceback); this method does not raise. Callers that
            need a reliable success signal should check ``self.model_loaded``
            and ``self.current_model_name`` rather than parse the text.
        """
        try:
            # Prevent reloading the same model
            if self.model_loaded and self.current_model_name == model_name:
                return f"β Model '{model_name}' is already loaded."
            self.model = None
            self.client_themes = {}
            self.model_loaded = False
            print(f"Loading model: {model_name} onto CUDA device")
            # NOTE(review): trust_remote_code=True lets the model repository run
            # its own Python code, and model_name can come straight from a UI
            # text box. Confirm this is acceptable for the deployment.
            self.model = SentenceTransformer(
                model_name,
                device='cuda',
                trust_remote_code=True,
                token=token_hf,
            )
            self.model_loaded = True
            self.current_model_name = model_name
            return f"β Model '{model_name}' loaded successfully onto GPU!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"β Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def _ensure_model_is_loaded(self) -> Optional[str]:
        """Load ``self.current_model_name`` if no model is loaded yet.

        Returns:
            ``None`` when a model is available, otherwise the error status
            string produced by :meth:`load_model`.
        """
        if self.model_loaded:
            return None
        print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
        status = self.load_model(self.current_model_name)
        # Decide from the state flag, not from the wording of the status text
        # (a model name could itself contain the word "Error").
        return None if self.model_loaded else status

    def _register_themes(self, client_id: str, themes: List[str]) -> None:
        """Encode every theme and store the result as the client's theme set.

        The new set is built completely before it is stored, so an encoding
        failure leaves the client's previous themes untouched. Raises whatever
        the model raises; the model must already be loaded.
        """
        prototypes = {
            theme: self.model.encode(theme, convert_to_tensor=True)
            for theme in themes
        }
        self.client_themes[client_id] = prototypes

    def add_client_themes(self, client_id: str, themes: List[str]):
        """Replace the theme set of ``client_id`` with ``themes``.

        Loads the current model first if necessary.

        Args:
            client_id: Key under which the themes are stored.
            themes: Theme names; each one is embedded as its own prototype.

        Returns:
            A status string describing success or the error; does not raise.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return error_status
        try:
            self._register_themes(client_id, themes)
            return f"β Added {len(themes)} themes for client '{client_id}'"
        except Exception as e:
            return f"β Error adding themes: {str(e)}"

    def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify a single text against the themes of one client.

        Args:
            text: Text to classify.
            client_id: Client whose themes are used.
            confidence_threshold: Minimum cosine similarity for the best theme
                to be accepted.

        Returns:
            ``(label, best_score, similarities)``. ``label`` is the best theme,
            ``"UNKNOWN_THEME"`` when the best score is below the threshold, or a
            diagnostic string (``"Client not found"``, ``"No themes for
            client"``, ``"Error: ..."``); the diagnostic cases come with a score
            of ``0.0`` and an empty similarity dict. Does not raise.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return f"Error: {error_status}", 0.0, {}
        if client_id not in self.client_themes:
            return "Client not found", 0.0, {}
        try:
            text_embedding = self.model.encode(text, convert_to_tensor=True)
            similarities = {
                theme: util.cos_sim(text_embedding, prototype).item()
                for theme, prototype in self.client_themes[client_id].items()
            }
            if not similarities:
                return "No themes for client", 0.0, {}
            best_theme = max(similarities, key=similarities.get)
            best_score = similarities[best_theme]
            if best_score < confidence_threshold:
                return "UNKNOWN_THEME", best_score, similarities
            return best_theme, best_score, similarities
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}

    @staticmethod
    def _read_csv_with_fallback(csv_filepath: str) -> Optional[pd.DataFrame]:
        """Read a CSV, trying several encodings in order.

        Returns the first DataFrame that parses, or ``None`` when every
        encoding fails with a decode or parser error.
        """
        for encoding in ('utf-8-sig', 'utf-8', 'cp1256', 'latin1'):
            try:
                df = pd.read_csv(csv_filepath, encoding=encoding)
            except (UnicodeDecodeError, pd.errors.ParserError):
                continue
            print(f"Successfully read CSV with encoding: {encoding}")
            return df
        return None

    # CORRECTED: The benchmark function now takes the model_name as an argument
    def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
        """Benchmark a specific model on a CSV file.

        The CSV must contain ``text`` and ``real_tag`` columns. The distinct
        ``real_tag`` values become the themes of ``client_id`` (replacing any
        existing ones); each text, truncated to 500 characters, is classified
        with the default confidence threshold and compared with its
        ``real_tag``.

        Args:
            csv_filepath: Path of the CSV file to read.
            client_id: Client under which the benchmark themes are registered.
            model_name: Model to load for this run.

        Returns:
            ``(summary_markdown, results_csv_path, chart_html)``. On any
            failure the first element holds the error message and the other two
            are ``None``.
        """
        # Step 1: Explicitly load the model requested by the user for this benchmark run.
        load_status = self.load_model(model_name)
        # Success (fresh load or "already loaded") is read from the state flags.
        # The status text is unsuitable for this: its success and error
        # variants start with the same marker.
        if not (self.model_loaded and self.current_model_name == model_name):
            return f"β Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None

        # Step 2: Read the CSV, tolerating a few common encodings.
        df = self._read_csv_with_fallback(csv_filepath)
        if df is None:
            return "β Could not decode the CSV. Please save it as 'UTF-8' and try again.", None, None

        try:
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return f"β CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
            df.dropna(subset=['text', 'real_tag'], inplace=True)
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)

            unique_themes = df['real_tag'].unique().tolist()
            # Raises on failure so that a broken theme set aborts the run
            # instead of silently producing a meaningless accuracy.
            self._register_themes(client_id, unique_themes)

            texts = df['text'].str.slice(0, 500).tolist()
            results = [self.classify_text(text, client_id) for text in texts]
            df['predicted_tag'] = [res[0] for res in results]
            df['confidence'] = [res[1] for res in results]

            correct = (df['real_tag'] == df['predicted_tag']).sum()
            total = len(df)
            accuracy = correct / total if total > 0 else 0
            results_summary = f"π **Benchmarking Results for `{self.current_model_name}`**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"

            fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution", labels={'index': 'Theme', 'value': 'Count'})
            visualization_html = fig.to_html()

            # Write through the handle and close it: no descriptor is leaked
            # and the utf-8-sig encoding (BOM) requested for the temp file is
            # actually applied. newline='' is what pandas expects for text
            # handles. delete=False keeps the file for the caller to serve.
            with tempfile.NamedTemporaryFile(
                mode='w',
                suffix='.csv',
                delete=False,
                encoding='utf-8-sig',
                newline='',
            ) as tmp:
                df.to_csv(tmp, index=False)
                temp_file_path = tmp.name
            return results_summary, temp_file_path, visualization_html
        except Exception as e:
            error_details = traceback.format_exc()
            return f"β Error during benchmarking: {str(e)}\n\n{error_details}", None, None
# Initialize the classifier
# One module-level instance is used by every *_interface callback in this
# file, so the loaded model and the registered client themes live on it.
classifier = MultiClientThemeClassifier()
def load_model_interface(model_name: str):
    """UI callback: load the model named in the text box.

    A blank name is rejected up front. Passing it on would make the classifier
    discard its current model and all client themes before the load fails.

    Returns:
        The status string to display.
    """
    cleaned_name = (model_name or "").strip()
    if not cleaned_name:
        return "Please enter a model name!"
    return classifier.load_model(cleaned_name)
def add_themes_interface(client_id: str, themes_text: str):
    """UI callback: register the themes typed in the text area for a client.

    Args:
        client_id: Client the themes belong to; must not be blank. It is passed
            on unchanged so that it matches the ID typed in the other tabs.
        themes_text: One theme per line; blank lines are ignored.

    Returns:
        The status string to display.
    """
    if not (themes_text or "").strip():
        return "β Please enter themes!"
    # Without this check the themes would be stored under an empty client ID.
    if not (client_id or "").strip():
        return "β Please enter a client ID!"
    themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
    return classifier.add_client_themes(client_id, themes)
def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """UI callback: classify one text and format the outcome as Markdown.

    Returns:
        A pair ``(markdown, "")``. The second element is always an empty
        string; for blank input the first element is a prompt to enter text.
    """
    if not text.strip():
        return "Please enter text to classify!", ""

    label, score, per_theme = classifier.classify_text(text, client_id, confidence_threshold)

    # Highest similarity first.
    ranked = sorted(per_theme.items(), key=lambda pair: pair[1], reverse=True)
    score_lines = [f"- {name}: {value:.3f}" for name, value in ranked]
    scores_block = "**Similarity Scores:**\n" + "\n".join(score_lines)

    report = f"π― **Predicted Theme:** {label}\nπ₯ **Confidence:** {score:.3f}\n\n{scores_block}"
    return report, ""
# CORRECTED: The interface now accepts model_name | |
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """UI callback: run a benchmark on the uploaded CSV with the given model.

    Args:
        csv_file_obj: The upload, either a filepath string or an object whose
            ``name`` attribute is the filepath; ``None`` when nothing was
            uploaded.
        client_id: Client under which the benchmark themes are registered.
        model_name: Model to benchmark; must not be blank.

    Returns:
        ``(summary_markdown, results_csv_path, chart_html)``; on failure the
        first element is the message and the other two are ``None``.
    """
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None
    if not model_name.strip():
        return "Please enter a model name for the benchmark!", None, None
    try:
        # Accept a plain path as well as a file-like object with `.name`, since
        # the upload component can be configured to deliver either.
        if isinstance(csv_file_obj, str):
            csv_filepath = csv_file_obj
        else:
            csv_filepath = csv_file_obj.name
        # Pass the model name from the UI down to the classifier method
        return classifier.benchmark_csv(csv_filepath, client_id, model_name.strip())
    except Exception as e:
        error_details = traceback.format_exc()
        return f"β Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
# --- Gradio Interface --- | |
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π― Custom Themes Classification - MVP")

    # ---- Tab 1: model selection and per-client theme registration ----
    with gr.Tab("π Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
        gr.Markdown("A default model (`google/embeddinggemma-300m`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
        with gr.Row():
            # This input is now used by the benchmark tab as well
            model_input = gr.Textbox(
                label="HuggingFace Model Name",
                value="google/embeddinggemma-300m",
            )
            load_btn = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)

        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)
        add_themes_btn.click(
            add_themes_interface,
            inputs=[client_input, themes_input],
            outputs=themes_status,
        )

    # ---- Tab 2: classify a single piece of text ----
    with gr.Tab("π Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Text to Classify", lines=3)
                client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Confidence Threshold",
                )
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                classification_result = gr.Markdown(label="Results")
        # classify_interface returns two values; the second one goes to this
        # hidden text box and is never shown.
        unused_classify_output = gr.Textbox(visible=False)
        classify_btn.click(
            classify_interface,
            inputs=[text_input, client_select, confidence_slider],
            outputs=[classification_result, unused_classify_output],
        )

    # ---- Tab 3: benchmark a model on a labelled CSV ----
    with gr.Tab("π CSV Benchmarking"):
        gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                benchmark_client = gr.Textbox(
                    label="Client ID for Benchmark",
                    placeholder="e.g., benchmark_client",
                )
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")
        with gr.Row():
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            visualization = gr.HTML(label="Visualization")
        # CORRECTED: The button now sends the model_input value to the benchmark function
        benchmark_btn.click(
            benchmark_interface,
            inputs=[csv_upload, benchmark_client, model_input],
            outputs=[benchmark_results, results_csv, visualization],
        )
# Launch the app
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    # share=True requests a public share link from Gradio.
    # NOTE(review): this is likely unnecessary when hosted on Spaces — confirm.
    demo.launch(share=True)