File size: 12,133 Bytes
0c5a75b
 
 
 
 
18a59a4
0c5a75b
 
18a59a4
 
 
40a7521
b60d459
7b53477
6b60724
6cbb6d1
 
 
0c5a75b
 
 
18a59a4
30f9702
18a59a4
407491f
30f9702
18a59a4
30f9702
 
13c7fa5
30f9702
 
92204cf
 
 
 
 
18a59a4
b60d459
cd9bb6e
18a59a4
30f9702
b60d459
13c7fa5
18a59a4
 
 
 
92204cf
30f9702
92204cf
30f9702
 
92204cf
 
 
 
30f9702
18a59a4
92204cf
7b53477
18a59a4
0c5a75b
 
18a59a4
92204cf
0c5a75b
 
 
 
18a59a4
 
 
92204cf
7b53477
92204cf
18a59a4
 
 
 
 
92204cf
 
18a59a4
7b53477
92204cf
18a59a4
 
 
 
 
 
 
 
 
 
26f50fc
 
 
 
 
 
 
 
18a59a4
26f50fc
30f9702
febe156
 
 
 
 
30f9702
febe156
 
 
 
30f9702
febe156
0c5a75b
 
2c7390f
 
92204cf
18a59a4
 
 
0c5a75b
92204cf
40a7521
30f9702
 
0c5a75b
92204cf
 
18a59a4
92204cf
 
40a7521
18a59a4
30f9702
18a59a4
2c7390f
7b53477
 
5a8e848
7b53477
 
 
92204cf
0c5a75b
18a59a4
92204cf
0c5a75b
18a59a4
 
0c5a75b
b60d459
18a59a4
 
0c5a75b
b60d459
18a59a4
7b53477
0c5a75b
18a59a4
 
b60d459
18a59a4
7b53477
18a59a4
 
 
92204cf
 
18a59a4
 
0c5a75b
26f50fc
92204cf
26f50fc
1bb89fc
fdc6f42
26f50fc
 
0c5a75b
1bb89fc
26f50fc
 
0c5a75b
7b53477
1bb89fc
0c5a75b
2c7390f
40a7521
7b53477
0c5a75b
 
92204cf
407491f
18a59a4
26f50fc
407491f
18a59a4
0c5a75b
18a59a4
 
 
 
 
7b53477
0c5a75b
 
7b53477
18a59a4
 
 
 
 
7b53477
 
 
18a59a4
 
 
7b53477
18a59a4
0c5a75b
26f50fc
18a59a4
 
7b53477
 
18a59a4
 
 
 
 
 
26f50fc
 
 
 
 
 
 
0c5a75b
18a59a4
0c5a75b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import gradio as gr
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
from typing import Dict, List, Tuple, Optional
import io
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
import json
import traceback
import spaces # Import the spaces library
import tempfile
from dotenv import load_dotenv
import os
# FIX: load the .env file *before* reading variables from the environment;
# the original read HF_TOKEN first, so a token defined only in .env was None.
load_dotenv()
token_hf = os.getenv('HF_TOKEN')

class MultiClientThemeClassifier:
    """Zero-shot theme classifier with per-client theme sets.

    Each client gets a dict of theme "prototypes" — sentence embeddings of the
    theme names themselves. A text is classified by cosine similarity of its
    embedding against those prototypes; the best-scoring theme wins if it
    clears a confidence threshold.
    """

    def __init__(self):
        self.model = None                     # SentenceTransformer, set by load_model()
        self.client_themes = {}               # client_id -> {theme name: prototype tensor}
        self.model_loaded = False
        self.default_model = 'google/embeddinggemma-300m'
        self.current_model_name = self.default_model

    def load_model(self, model_name: str):
        """Load the embedding model onto the GPU, remembering the choice.

        Returns a human-readable status string; never raises.
        """
        try:
            # Prevent reloading the same model
            if self.model_loaded and self.current_model_name == model_name:
                return f"βœ… Model '{model_name}' is already loaded."

            # Drop the previous model AND all cached prototypes: embeddings
            # produced by different models live in different vector spaces.
            self.model = None
            self.client_themes = {}
            self.model_loaded = False

            print(f"Loading model: {model_name} onto CUDA device")
            self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True, token=token_hf)
            self.model_loaded = True
            self.current_model_name = model_name
            return f"βœ… Model '{model_name}' loaded successfully onto GPU!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def _ensure_model_is_loaded(self) -> Optional[str]:
        """Internal helper to load the correct model if it's not already loaded.

        Returns None on success, or the load_model error string on failure
        (every error message contains the word "Error").
        """
        if not self.model_loaded:
            print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
            status = self.load_model(self.current_model_name)
            if "Error" in status:
                return status
        return None

    def add_client_themes(self, client_id: str, themes: List[str]):
        """(Re)build the prototype embeddings for *client_id* from theme names.

        Replaces any previously registered themes for that client.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status: return error_status

        try:
            self.client_themes[client_id] = {}
            for theme in themes:
                # The theme *name itself* is embedded as the prototype.
                prototype = self.model.encode(theme, convert_to_tensor=True)
                self.client_themes[client_id][theme] = prototype
            return f"βœ… Added {len(themes)} themes for client '{client_id}'"
        except Exception as e:
            return f"❌ Error adding themes: {str(e)}"

    def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify a single text for a specific client.

        Returns (best_theme, best_score, {theme: cosine similarity}); the
        sentinel 'UNKNOWN_THEME' is used when no theme clears the threshold.
        Errors are reported through the first tuple element, never raised.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status: return f"Error: {error_status}", 0.0, {}

        if client_id not in self.client_themes:
            return "Client not found", 0.0, {}

        try:
            text_embedding = self.model.encode(text, convert_to_tensor=True)
            similarities = {theme: util.cos_sim(text_embedding, prototype).item()
                            for theme, prototype in self.client_themes[client_id].items()}

            if not similarities: return "No themes for client", 0.0, {}

            best_theme = max(similarities, key=similarities.get)
            best_score = similarities[best_theme]

            if best_score < confidence_threshold:
                return "UNKNOWN_THEME", best_score, similarities

            return best_theme, best_score, similarities
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}

    def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
        """Benchmark a specific model on a CSV file.

        The CSV must contain 'text' and 'real_tag' columns. Returns a tuple of
        (markdown summary, path to a detailed-results CSV, plotly HTML), with
        the last two None on failure.
        """
        # Step 1: Explicitly load the model requested by the user for this benchmark run.
        load_status = self.load_model(model_name)
        # We allow the function to proceed if the model is "already loaded", but stop for any other error.
        if "❌" in load_status:
             return f"❌ Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None

        # Step 2: Read the CSV, trying a few encodings common in user exports.
        encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
        df = None
        for encoding in encodings_to_try:
            try:
                df = pd.read_csv(csv_filepath, encoding=encoding)
                print(f"Successfully read CSV with encoding: {encoding}")
                break
            except (UnicodeDecodeError, pd.errors.ParserError):
                continue

        if df is None:
            return "❌ Could not decode the CSV. Please save it as 'UTF-8' and try again.", None, None

        try:
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None

            df.dropna(subset=['text', 'real_tag'], inplace=True)
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)

            # The ground-truth tag set doubles as the theme list for the run.
            unique_themes = df['real_tag'].unique().tolist()
            self.add_client_themes(client_id, unique_themes)

            # Truncate very long posts so a single row can't dominate runtime.
            texts = df['text'].str.slice(0, 500).tolist()
            results = [self.classify_text(text, client_id) for text in texts]

            df['predicted_tag'] = [res[0] for res in results]
            df['confidence'] = [res[1] for res in results]

            correct = (df['real_tag'] == df['predicted_tag']).sum()
            total = len(df)
            accuracy = correct / total if total > 0 else 0

            results_summary = f"πŸ“Š **Benchmarking Results for `{self.current_model_name}`**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"

            fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution", labels={'index': 'Theme', 'value': 'Count'})
            visualization_html = fig.to_html()

            # FIX: the original created a NamedTemporaryFile whose open handle
            # was never closed, and its encoding='utf-8-sig' applied only to
            # that unused handle — df.to_csv reopened the path with plain
            # utf-8, silently dropping the BOM Excel users need. Create a
            # closed temp path and pass the encoding to to_csv directly.
            fd, temp_file_path = tempfile.mkstemp(suffix='.csv')
            os.close(fd)
            df.to_csv(temp_file_path, index=False, encoding='utf-8-sig')

            return results_summary, temp_file_path, visualization_html

        except Exception as e:
            error_details = traceback.format_exc()
            return f"❌ Error during benchmarking: {str(e)}\n\n{error_details}", None, None

# Module-level classifier instance shared by every Gradio callback below.
classifier = MultiClientThemeClassifier()

@spaces.GPU
def load_model_interface(model_name: str):
    """Gradio callback: load the embedding model named in the textbox."""
    cleaned_name = model_name.strip()
    return classifier.load_model(cleaned_name)

@spaces.GPU
def add_themes_interface(client_id: str, themes_text: str):
    """Gradio callback: register one theme per non-blank line for a client.

    FIX: also reject a blank client ID — the original silently registered
    themes under the empty-string client, which no other tab could reference.
    """
    if not client_id.strip():
        return "❌ Please enter a client ID!"
    if not themes_text.strip(): return "❌ Please enter themes!"
    themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
    return classifier.add_client_themes(client_id, themes)

@spaces.GPU
def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """Gradio callback: classify one text and format a markdown report."""
    if not text.strip():
        return "Please enter text to classify!", ""

    pred_theme, confidence, similarities = classifier.classify_text(text, client_id, confidence_threshold)

    # Rank themes best-first for display.
    ranked = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
    score_lines = [f"- {theme}: {sim:.3f}" for theme, sim in ranked]
    sim_display = "**Similarity Scores:**\n" + "\n".join(score_lines)
    result = f"🎯 **Predicted Theme:** {pred_theme}\nπŸ”₯ **Confidence:** {confidence:.3f}\n\n{sim_display}"

    return result, ""

@spaces.GPU(duration=300)
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """Gradio callback: benchmark the UI-selected model on an uploaded CSV."""
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None
    if not model_name.strip():
        return "Please enter a model name for the benchmark!", None, None
    try:
        # gr.File hands us a wrapper object; .name is its on-disk temp path.
        # The model name typed in the Setup tab is forwarded to the classifier.
        return classifier.benchmark_csv(csv_file_obj.name, client_id, model_name.strip())
    except Exception as e:
        error_details = traceback.format_exc()
        return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None

# --- Gradio Interface ---
# Three tabs: model/theme setup, single-text classification, CSV benchmarking.
# Component handles created in one tab (e.g. model_input) are reused as event
# inputs in later tabs, so creation order matters.
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Custom Themes Classification - MVP")
    
    with gr.Tab("πŸš€ Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
        gr.Markdown("A default model (`google/embeddinggemma-300m`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
        with gr.Row():
            # This input is now used by the benchmark tab as well
            model_input = gr.Textbox(label="HuggingFace Model Name", value="google/embeddinggemma-300m")
            load_btn = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)
        
        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)
        add_themes_btn.click(add_themes_interface, inputs=[client_input, themes_input], outputs=themes_status)
    
    with gr.Tab("πŸ” Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Text to Classify", lines=3)
                client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
                confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Confidence Threshold")
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                classification_result = gr.Markdown(label="Results")
        # NOTE(review): the hidden Textbox absorbs classify_interface's unused
        # second return value; the callback always returns "" for it.
        classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
    
    with gr.Tab("πŸ“Š CSV Benchmarking"):
        gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")
        with gr.Row():
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            visualization = gr.HTML(label="Visualization")
        
        # CORRECTED: The button now sends the model_input value to the benchmark function
        benchmark_btn.click(
            benchmark_interface, 
            inputs=[csv_upload, benchmark_client, model_input], 
            outputs=[benchmark_results, results_csv, visualization]
        )

# Launch the app
if __name__ == "__main__":
    # share=True publishes a temporary public gradio.live URL in addition
    # to the local server.
    demo.launch(share=True)