Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,133 Bytes
0c5a75b 18a59a4 0c5a75b 18a59a4 40a7521 b60d459 7b53477 6b60724 6cbb6d1 0c5a75b 18a59a4 30f9702 18a59a4 407491f 30f9702 18a59a4 30f9702 13c7fa5 30f9702 92204cf 18a59a4 b60d459 cd9bb6e 18a59a4 30f9702 b60d459 13c7fa5 18a59a4 92204cf 30f9702 92204cf 30f9702 92204cf 30f9702 18a59a4 92204cf 7b53477 18a59a4 0c5a75b 18a59a4 92204cf 0c5a75b 18a59a4 92204cf 7b53477 92204cf 18a59a4 92204cf 18a59a4 7b53477 92204cf 18a59a4 26f50fc 18a59a4 26f50fc 30f9702 febe156 30f9702 febe156 30f9702 febe156 0c5a75b 2c7390f 92204cf 18a59a4 0c5a75b 92204cf 40a7521 30f9702 0c5a75b 92204cf 18a59a4 92204cf 40a7521 18a59a4 30f9702 18a59a4 2c7390f 7b53477 5a8e848 7b53477 92204cf 0c5a75b 18a59a4 92204cf 0c5a75b 18a59a4 0c5a75b b60d459 18a59a4 0c5a75b b60d459 18a59a4 7b53477 0c5a75b 18a59a4 b60d459 18a59a4 7b53477 18a59a4 92204cf 18a59a4 0c5a75b 26f50fc 92204cf 26f50fc 1bb89fc fdc6f42 26f50fc 0c5a75b 1bb89fc 26f50fc 0c5a75b 7b53477 1bb89fc 0c5a75b 2c7390f 40a7521 7b53477 0c5a75b 92204cf 407491f 18a59a4 26f50fc 407491f 18a59a4 0c5a75b 18a59a4 7b53477 0c5a75b 7b53477 18a59a4 7b53477 18a59a4 7b53477 18a59a4 0c5a75b 26f50fc 18a59a4 7b53477 18a59a4 26f50fc 0c5a75b 18a59a4 0c5a75b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
import gradio as gr
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
from typing import Dict, List, Tuple, Optional
import io
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import json
import traceback
import spaces # Import the spaces library
import tempfile
from dotenv import load_dotenv
import os
# Load variables from a local .env file BEFORE reading them; in the original
# order the HF_TOKEN from .env was read before it was loaded and came back None.
load_dotenv()
token_hf = os.getenv('HF_TOKEN')  # may still be None if no token is configured
class MultiClientThemeClassifier:
    """Embedding-based text classifier with an independent theme set per client.

    Each client registers a list of theme labels; every label is embedded once
    as a "prototype" vector, and incoming texts are assigned to the theme whose
    prototype has the highest cosine similarity to the text embedding.
    """

    def __init__(self):
        self.model = None            # SentenceTransformer instance once loaded
        self.client_themes = {}      # client_id -> {theme label: prototype tensor}
        self.model_loaded = False
        self.default_model = 'google/embeddinggemma-300m'
        self.current_model_name = self.default_model

    def load_model(self, model_name: str):
        """Load the embedding model onto the GPU, remembering the choice.

        Returns a human-readable status string starting with ✅ on success
        (or no-op) and ❌ on failure; callers key off those markers.
        """
        try:
            # Prevent reloading the same model
            if self.model_loaded and self.current_model_name == model_name:
                return f"✅ Model '{model_name}' is already loaded."
            # Drop the previous model and its cached prototypes before loading,
            # since prototypes from a different model are not comparable.
            self.model = None
            self.client_themes = {}
            self.model_loaded = False
            print(f"Loading model: {model_name} onto CUDA device")
            self.model = SentenceTransformer(model_name, device='cuda',
                                             trust_remote_code=True, token=token_hf)
            self.model_loaded = True
            self.current_model_name = model_name
            return f"✅ Model '{model_name}' loaded successfully onto GPU!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def _ensure_model_is_loaded(self) -> Optional[str]:
        """Internal helper: lazily load the last selected model.

        Returns None when a model is ready, or the load-error string otherwise.
        """
        if not self.model_loaded:
            print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
            status = self.load_model(self.current_model_name)
            if "Error" in status:
                return status
        return None

    def add_client_themes(self, client_id: str, themes: List[str]):
        """Embed and store one prototype per theme label for *client_id*.

        Replaces any themes previously registered for that client.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return error_status
        try:
            self.client_themes[client_id] = {}
            for theme in themes:
                # The label text itself serves as the theme prototype.
                prototype = self.model.encode(theme, convert_to_tensor=True)
                self.client_themes[client_id][theme] = prototype
            return f"✅ Added {len(themes)} themes for client '{client_id}'"
        except Exception as e:
            return f"❌ Error adding themes: {str(e)}"

    def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify a single text for a specific client.

        Returns (best theme or sentinel string, best score, all similarity scores).
        Scores below *confidence_threshold* yield the "UNKNOWN_THEME" sentinel.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return f"Error: {error_status}", 0.0, {}
        if client_id not in self.client_themes:
            return "Client not found", 0.0, {}
        try:
            text_embedding = self.model.encode(text, convert_to_tensor=True)
            similarities = {theme: util.cos_sim(text_embedding, prototype).item()
                            for theme, prototype in self.client_themes[client_id].items()}
            if not similarities:
                return "No themes for client", 0.0, {}
            best_theme = max(similarities, key=similarities.get)
            best_score = similarities[best_theme]
            if best_score < confidence_threshold:
                return "UNKNOWN_THEME", best_score, similarities
            return best_theme, best_score, similarities
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}

    def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
        """Benchmark *model_name* on a CSV with 'text' and 'real_tag' columns.

        Returns (markdown summary, path to a detailed-results CSV, plot HTML);
        the last two are None on any failure.
        """
        # Step 1: explicitly load the model requested for this benchmark run.
        load_status = self.load_model(model_name)
        # "already loaded" is fine; only a real load error (❌ marker) aborts.
        if "❌" in load_status:
            return f"❌ Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None
        # Step 2: read the CSV, trying a few common encodings.
        encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
        df = None
        for encoding in encodings_to_try:
            try:
                df = pd.read_csv(csv_filepath, encoding=encoding)
                print(f"Successfully read CSV with encoding: {encoding}")
                break
            except (UnicodeDecodeError, pd.errors.ParserError):
                continue
        if df is None:
            return "❌ Could not decode the CSV. Please save it as 'UTF-8' and try again.", None, None
        try:
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
            df.dropna(subset=['text', 'real_tag'], inplace=True)
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)
            # The ground-truth labels double as the theme set for this run.
            unique_themes = df['real_tag'].unique().tolist()
            self.add_client_themes(client_id, unique_themes)
            # Truncate long posts to keep encoding fast.
            texts = df['text'].str.slice(0, 500).tolist()
            results = [self.classify_text(text, client_id) for text in texts]
            df['predicted_tag'] = [res[0] for res in results]
            df['confidence'] = [res[1] for res in results]
            correct = (df['real_tag'] == df['predicted_tag']).sum()
            total = len(df)
            accuracy = correct / total if total > 0 else 0
            results_summary = f"📊 **Benchmarking Results for `{self.current_model_name}`**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
            fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution", labels={'index': 'Theme', 'value': 'Count'})
            visualization_html = fig.to_html()
            # Reserve a temp path and close the handle before pandas writes to it
            # (the original leaked the open file object).
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8-sig') as tmp:
                temp_file_path = tmp.name
            df.to_csv(temp_file_path, index=False)
            return results_summary, temp_file_path, visualization_html
        except Exception as e:
            error_details = traceback.format_exc()
            return f"❌ Error during benchmarking: {str(e)}\n\n{error_details}", None, None
# Initialize the classifier
# Module-level singleton shared by all Gradio callbacks below; per-client
# theme state therefore persists across requests within one process.
classifier = MultiClientThemeClassifier()
@spaces.GPU
def load_model_interface(model_name: str):
    """Gradio callback: load the requested embedding model onto the GPU."""
    requested = model_name.strip()
    return classifier.load_model(requested)
@spaces.GPU
def add_themes_interface(client_id: str, themes_text: str):
    """Gradio callback: register one theme per non-empty line for a client.

    NOTE(review): error-marker emoji restored from mojibake ("β" in the
    scraped source) — confirm original glyph.
    """
    if not themes_text.strip():
        return "❌ Please enter themes!"
    themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
    return classifier.add_client_themes(client_id, themes)
@spaces.GPU
def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """Gradio callback: classify one text and format a markdown report.

    Returns (markdown result, "") — the second value clears a hidden textbox.
    NOTE(review): decorative emoji restored from mojibake — confirm glyphs.
    """
    if not text.strip():
        return "Please enter text to classify!", ""
    pred_theme, confidence, similarities = classifier.classify_text(text, client_id, confidence_threshold)
    # Show all theme similarities, highest first.
    ranked = sorted(similarities.items(), key=lambda kv: kv[1], reverse=True)
    sim_display = "**Similarity Scores:**\n" + "\n".join(f"- {theme}: {sim:.3f}" for theme, sim in ranked)
    result = f"🎯 **Predicted Theme:** {pred_theme}\n🔥 **Confidence:** {confidence:.3f}\n\n{sim_display}"
    return result, ""
# The interface accepts model_name so each benchmark run can pick its model.
@spaces.GPU(duration=300)
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """Gradio callback: run a CSV benchmark with the selected model.

    Returns (summary markdown, results-CSV path or None, plot HTML or None).
    """
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None
    if not model_name.strip():
        return "Please enter a model name for the benchmark!", None, None
    try:
        # gradio's File component exposes the uploaded temp file's path as .name
        csv_filepath = csv_file_obj.name
        # Pass the model name from the UI down to the classifier method
        return classifier.benchmark_csv(csv_filepath, client_id, model_name.strip())
    except Exception as e:
        error_details = traceback.format_exc()
        return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
# --- Gradio Interface ---
# NOTE(review): tab/title emoji restored from mojibake ("π…" in the scraped
# source) — confirm original glyphs against the deployed Space.
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Custom Themes Classification - MVP")

    with gr.Tab("🚀 Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
        gr.Markdown("A default model (`google/embeddinggemma-300m`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
        with gr.Row():
            # This input is used by the benchmark tab as well
            model_input = gr.Textbox(label="HuggingFace Model Name", value="google/embeddinggemma-300m")
            load_btn = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)

        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)
        add_themes_btn.click(add_themes_interface, inputs=[client_input, themes_input], outputs=themes_status)

    with gr.Tab("🔍 Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Text to Classify", lines=3)
                client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
                confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Confidence Threshold")
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                classification_result = gr.Markdown(label="Results")
        classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])

    with gr.Tab("📊 CSV Benchmarking"):
        gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")
        with gr.Row():
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            visualization = gr.HTML(label="Visualization")
        # The button sends the model_input value to the benchmark function
        benchmark_btn.click(
            benchmark_interface,
            inputs=[csv_upload, benchmark_client, model_input],
            outputs=[benchmark_results, results_csv, visualization]
        )
# Launch the app only when run as a script (not when imported).
# share=True exposes a temporary public Gradio link.
if __name__ == "__main__":
    demo.launch(share=True)