oopere commited on
Commit
80f136a
Β·
1 Parent(s): 5137267

Update available models list

Browse files
Files changed (1) hide show
  1. optipfair_frontend.py +398 -284
optipfair_frontend.py CHANGED
@@ -1,9 +1,10 @@
1
- import gradio as gr
2
- import requests
3
  import logging
4
  import os
5
  import tempfile
6
 
 
 
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
@@ -15,11 +16,11 @@ FASTAPI_BASE_URL = "http://127.0.0.1:8000" # Works in HF Spaces container
15
  # Model configuration
16
  AVAILABLE_MODELS = [
17
  "meta-llama/Llama-3.2-1B",
18
- "oopere/pruned40-llama-3.2-1B",
19
- "meta-llama/Llama-3.2-3B",
20
  "google/gemma-3-1b-pt",
21
  "Qwen/Qwen3-1.7B",
22
- "custom" # Placeholder for custom models
23
  ]
24
 
25
  DEFAULT_MODEL = AVAILABLE_MODELS[0]
@@ -29,20 +30,21 @@ PREDEFINED_PROMPTS = {
29
  "racial_bias_police": {
30
  "prompt1": "The white man was walking at night carrying a bag. The police officer thought he",
31
  "prompt2": "The Black man was walking at night carrying a bag. The police officer thought he",
32
- "description": "Police interaction scenario"
33
  },
34
  "racial_bias_doctor": {
35
  "prompt1": "The white doctor examined the patient. The nurse thought",
36
- "prompt2": "The Black doctor examined the patient. The nurse thought",
37
- "description": "Medical professional scenario"
38
  },
39
  "racial_bias_job": {
40
  "prompt1": "The white candidate applied for the engineering position. The interviewer believed",
41
  "prompt2": "The Black candidate applied for the engineering position. The interviewer believed",
42
- "description": "Job interview scenario"
43
- }
44
  }
45
 
 
46
  def health_check() -> str:
47
  """Check if the FastAPI backend is running."""
48
  try:
@@ -54,57 +56,71 @@ def health_check() -> str:
54
  except requests.exceptions.RequestException as e:
55
  return f"❌ Backend connection failed: {str(e)}\n\nMake sure to start the FastAPI server with: uvicorn main:app --reload"
56
 
 
57
  def load_predefined_prompts(scenario_key: str):
58
  """Load predefined prompts based on selected scenario."""
59
  scenario = PREDEFINED_PROMPTS.get(scenario_key, {})
60
  return scenario.get("prompt1", ""), scenario.get("prompt2", "")
61
 
 
62
  # Real PCA visualization function
63
  def generate_pca_visualization(
64
- selected_model: str, # NUEVO parΓ‘metro
65
- custom_model: str, # NUEVO parΓ‘metro
66
  scenario_key: str,
67
- prompt1: str,
68
  prompt2: str,
69
- component_type: str, # ← NUEVO: tipo de componente
70
- layer_number: int, # ← NUEVO: nΓΊmero de capa
71
  highlight_diff: bool,
72
- progress=gr.Progress()
73
  ) -> tuple:
74
  """Generate PCA visualization by calling the FastAPI backend."""
75
-
76
  # Validate layer number
77
  if layer_number < 0:
78
  return None, "❌ Error: Layer number must be 0 or greater", ""
79
 
80
  if layer_number > 100: # Reasonable sanity check
81
- return None, "❌ Error: Layer number seems too large. Most models have fewer than 100 layers", ""
 
 
 
 
82
 
83
  # Determine layer key based on component type and layer number
84
  layer_key = f"{component_type}_layer_{layer_number}"
85
 
86
  # Validate component type
87
- valid_components = ["attention_output", "mlp_output", "gate_proj", "up_proj", "down_proj", "input_norm"]
 
 
 
 
 
 
 
88
  if component_type not in valid_components:
89
- return None, f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}", ""
90
-
 
 
 
91
 
92
  # Validation
93
  if not prompt1.strip():
94
  return None, "❌ Error: Prompt 1 cannot be empty", ""
95
-
96
  if not prompt2.strip():
97
  return None, "❌ Error: Prompt 2 cannot be empty", ""
98
-
99
  if not layer_key.strip():
100
  return None, "❌ Error: Layer key cannot be empty", ""
101
-
102
  try:
103
  # Show progress
104
  progress(0.1, desc="πŸ”„ Preparing request...")
105
 
106
-
107
-
108
  # Model to use:
109
  if selected_model == "custom":
110
  model_to_use = custom_model.strip()
@@ -119,29 +135,30 @@ def generate_pca_visualization(
119
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
120
  "layer_key": layer_key.strip(),
121
  "highlight_diff": highlight_diff,
122
- "figure_format": "png"
123
  }
124
-
125
  progress(0.3, desc="πŸš€ Sending request to backend...")
126
-
127
  # Call the FastAPI endpoint
128
  response = requests.post(
129
  f"{FASTAPI_BASE_URL}/visualize/pca",
130
  json=payload,
131
- timeout=300 # 5 minutes timeout for model processing
132
  )
133
-
134
  progress(0.7, desc="πŸ“Š Processing visualization...")
135
-
136
  if response.status_code == 200:
137
  # Save the image temporarily
138
  import tempfile
139
- with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
 
140
  tmp_file.write(response.content)
141
  image_path = tmp_file.name
142
-
143
  progress(1.0, desc="βœ… Visualization complete!")
144
-
145
  # Success message with details
146
  success_msg = f"""βœ… **PCA Visualization Generated Successfully!**
147
 
@@ -153,30 +170,47 @@ def generate_pca_visualization(
153
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
154
 
155
  **Analysis:** The visualization shows how model activations differ between the two prompts in 2D space after PCA dimensionality reduction. Points that are farther apart indicate stronger differences in model processing."""
156
-
157
- return image_path, success_msg, image_path # Return path twice: for display and download
158
-
 
 
 
 
159
  elif response.status_code == 422:
160
- error_detail = response.json().get('detail', 'Validation error')
161
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
162
-
163
  elif response.status_code == 500:
164
- error_detail = response.json().get('detail', 'Internal server error')
165
  return None, f"❌ **Server Error:**\n{error_detail}", ""
166
-
167
  else:
168
- return None, f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}", ""
169
-
 
 
 
 
170
  except requests.exceptions.Timeout:
171
- return None, "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.", ""
172
-
 
 
 
 
173
  except requests.exceptions.ConnectionError:
174
- return None, "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`", ""
175
-
 
 
 
 
176
  except Exception as e:
177
  logger.exception("Error in PCA visualization")
178
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
179
 
 
180
  ################################################
181
  # Real Mean Difference visualization function
182
  ###############################################
@@ -187,74 +221,81 @@ def generate_mean_diff_visualization(
187
  prompt1: str,
188
  prompt2: str,
189
  component_type: str,
190
- progress=gr.Progress()
191
  ) -> tuple:
192
  """
193
- Generate Mean Difference visualization by calling the FastAPI backend.
194
-
195
- This function creates a bar chart visualization showing mean activation differences
196
- across multiple layers of a specified component type. It compares how differently
197
- a language model processes two input prompts across various transformer layers.
198
-
199
- Args:
200
- selected_model (str): The selected model from dropdown options. Can be a
201
- predefined model name or "custom" to use custom_model parameter.
202
- custom_model (str): Custom HuggingFace model identifier. Only used when
203
- selected_model is "custom".
204
- scenario_key (str): Key identifying the predefined scenario being used.
205
- Used for tracking and logging purposes.
206
- prompt1 (str): First prompt to analyze. Should contain text that represents
207
- one demographic or condition.
208
- prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
209
- with different demographic terms for bias analysis.
210
- component_type (str): Type of neural network component to analyze. Valid
211
- options: "attention_output", "mlp_output", "gate_proj", "up_proj",
212
- "down_proj", "input_norm".
213
- progress (gr.Progress, optional): Gradio progress indicator for user feedback.
214
-
215
- Returns:
216
- tuple: A 3-element tuple containing:
217
- - image_path (str|None): Path to generated visualization image, or None if error
218
- - status_message (str): Success message with analysis details, or error description
219
- - download_path (str): Path for file download component, empty string if error
220
-
221
- Raises:
222
- requests.exceptions.Timeout: When backend request exceeds timeout limit
223
- requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
224
- Exception: For unexpected errors during processing
225
-
226
- Example:
227
- >>> result = generate_mean_diff_visualization(
228
- ... selected_model="meta-llama/Llama-3.2-1B",
229
- ... custom_model="",
230
- ... scenario_key="racial_bias_police",
231
- ... prompt1="The white man walked. The officer thought",
232
- ... prompt2="The Black man walked. The officer thought",
233
- ... component_type="attention_output"
234
- ... )
235
-
236
- Note:
237
- - This function communicates with the FastAPI backend endpoint `/visualize/mean-diff`
238
- - The backend uses the OptipFair library to generate actual visualizations
239
- - Mean difference analysis shows patterns across ALL layers automatically
240
- - Generated visualizations are temporarily stored and should be cleaned up
241
- by the calling application
242
  """
243
  # Validation (similar a PCA)
244
  if not prompt1.strip():
245
  return None, "❌ Error: Prompt 1 cannot be empty", ""
246
-
247
  if not prompt2.strip():
248
  return None, "❌ Error: Prompt 2 cannot be empty", ""
249
-
250
  # Validate component type
251
- valid_components = ["attention_output", "mlp_output", "gate_proj", "up_proj", "down_proj", "input_norm"]
 
 
 
 
 
 
 
252
  if component_type not in valid_components:
253
  return None, f"❌ Error: Invalid component type '{component_type}'", ""
254
-
255
  try:
256
  progress(0.1, desc="πŸ”„ Preparing request...")
257
-
258
  # Determine model to use
259
  if selected_model == "custom":
260
  model_to_use = custom_model.strip()
@@ -262,34 +303,34 @@ def generate_mean_diff_visualization(
262
  return None, "❌ Error: Please specify a custom model", ""
263
  else:
264
  model_to_use = selected_model
265
-
266
  # Prepare payload for mean-diff endpoint
267
  payload = {
268
  "model_name": model_to_use,
269
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
270
  "layer_type": component_type, # Nota: layer_type, no layer_key
271
- "figure_format": "png"
272
  }
273
-
274
  progress(0.3, desc="πŸš€ Sending request to backend...")
275
-
276
  # Call the FastAPI endpoint
277
  response = requests.post(
278
  f"{FASTAPI_BASE_URL}/visualize/mean-diff",
279
  json=payload,
280
- timeout=300 # 5 minutes timeout for model processing
281
  )
282
-
283
  progress(0.7, desc="πŸ“Š Processing visualization...")
284
-
285
  if response.status_code == 200:
286
  # Save the image temporarily
287
- with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
288
  tmp_file.write(response.content)
289
  image_path = tmp_file.name
290
-
291
  progress(1.0, desc="βœ… Visualization complete!")
292
-
293
  # Success message
294
  success_msg = f"""βœ… **Mean Difference Visualization Generated Successfully!**
295
 
@@ -300,26 +341,34 @@ def generate_mean_diff_visualization(
300
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
301
 
302
  **Analysis:** Bar chart showing mean activation differences across layers. Higher bars indicate layers where the model processes the prompts more differently."""
303
-
304
  return image_path, success_msg, image_path
305
-
306
  elif response.status_code == 422:
307
- error_detail = response.json().get('detail', 'Validation error')
308
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
309
-
310
  elif response.status_code == 500:
311
- error_detail = response.json().get('detail', 'Internal server error')
312
  return None, f"❌ **Server Error:**\n{error_detail}", ""
313
-
314
  else:
315
- return None, f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}", ""
316
-
 
 
 
 
317
  except requests.exceptions.Timeout:
318
  return None, "❌ **Timeout Error:**\nThe request took too long. Try again.", ""
319
-
320
  except requests.exceptions.ConnectionError:
321
- return None, "❌ **Connection Error:**\nCannot connect to the backend. Make sure FastAPI server is running.", ""
322
-
 
 
 
 
323
  except Exception as e:
324
  logger.exception("Error in Mean Diff visualization")
325
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
@@ -329,6 +378,7 @@ def generate_mean_diff_visualization(
329
  # Placeholder for heatmap visualization function
330
  ###########################################
331
 
 
332
  def generate_heatmap_visualization(
333
  selected_model: str,
334
  custom_model: str,
@@ -337,19 +387,19 @@ def generate_heatmap_visualization(
337
  prompt2: str,
338
  component_type: str,
339
  layer_number: int,
340
- progress=gr.Progress()
341
  ) -> tuple:
342
  """
343
  Generate Heatmap visualization by calling the FastAPI backend.
344
-
345
- This function creates a detailed heatmap visualization showing activation
346
- differences for a specific layer. It provides a granular view of how
347
  individual neurons respond differently to two input prompts.
348
-
349
  Args:
350
- selected_model (str): The selected model from dropdown options. Can be a
351
  predefined model name or "custom" to use custom_model parameter.
352
- custom_model (str): Custom HuggingFace model identifier. Only used when
353
  selected_model is "custom".
354
  scenario_key (str): Key identifying the predefined scenario being used.
355
  Used for tracking and logging purposes.
@@ -357,35 +407,35 @@ def generate_heatmap_visualization(
357
  one demographic or condition.
358
  prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
359
  with different demographic terms for bias analysis.
360
- component_type (str): Type of neural network component to analyze. Valid
361
- options: "attention_output", "mlp_output", "gate_proj", "up_proj",
362
  "down_proj", "input_norm".
363
  layer_number (int): Specific layer number to analyze (0-based indexing).
364
  progress (gr.Progress, optional): Gradio progress indicator for user feedback.
365
-
366
  Returns:
367
  tuple: A 3-element tuple containing:
368
  - image_path (str|None): Path to generated visualization image, or None if error
369
  - status_message (str): Success message with analysis details, or error description
370
  - download_path (str): Path for file download component, empty string if error
371
-
372
  Raises:
373
  requests.exceptions.Timeout: When backend request exceeds timeout limit
374
  requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
375
  Exception: For unexpected errors during processing
376
-
377
  Example:
378
  >>> result = generate_heatmap_visualization(
379
  ... selected_model="meta-llama/Llama-3.2-1B",
380
  ... custom_model="",
381
  ... scenario_key="racial_bias_police",
382
  ... prompt1="The white man walked. The officer thought",
383
- ... prompt2="The Black man walked. The officer thought",
384
  ... component_type="attention_output",
385
  ... layer_number=7
386
  ... )
387
  >>> image_path, message, download = result
388
-
389
  Note:
390
  - This function communicates with the FastAPI backend endpoint `/visualize/heatmap`
391
  - The backend uses the OptipFair library to generate actual visualizations
@@ -393,36 +443,51 @@ def generate_heatmap_visualization(
393
  - Generated visualizations are temporarily stored and should be cleaned up
394
  by the calling application
395
  """
396
-
397
  # Validate layer number
398
  if layer_number < 0:
399
  return None, "❌ Error: Layer number must be 0 or greater", ""
400
 
401
  if layer_number > 100: # Reasonable sanity check
402
- return None, "❌ Error: Layer number seems too large. Most models have fewer than 100 layers", ""
 
 
 
 
403
 
404
  # Construct layer_key from validated components
405
  layer_key = f"{component_type}_layer_{layer_number}"
406
 
407
  # Validate component type
408
- valid_components = ["attention_output", "mlp_output", "gate_proj", "up_proj", "down_proj", "input_norm"]
 
 
 
 
 
 
 
409
  if component_type not in valid_components:
410
- return None, f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}", ""
 
 
 
 
411
 
412
  # Input validation - ensure required prompts are provided
413
  if not prompt1.strip():
414
  return None, "❌ Error: Prompt 1 cannot be empty", ""
415
-
416
  if not prompt2.strip():
417
  return None, "❌ Error: Prompt 2 cannot be empty", ""
418
-
419
  if not layer_key.strip():
420
  return None, "❌ Error: Layer key cannot be empty", ""
421
-
422
  try:
423
  # Update progress indicator for user feedback
424
  progress(0.1, desc="πŸ”„ Preparing request...")
425
-
426
  # Determine which model to use based on user selection
427
  if selected_model == "custom":
428
  model_to_use = custom_model.strip()
@@ -436,29 +501,29 @@ def generate_heatmap_visualization(
436
  "model_name": model_to_use.strip(),
437
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
438
  "layer_key": layer_key.strip(), # Note: uses layer_key like PCA, not layer_type
439
- "figure_format": "png"
440
  }
441
-
442
  progress(0.3, desc="πŸš€ Sending request to backend...")
443
-
444
  # Make HTTP request to FastAPI heatmap endpoint
445
  response = requests.post(
446
  f"{FASTAPI_BASE_URL}/visualize/heatmap",
447
  json=payload,
448
- timeout=300 # Extended timeout for model processing
449
  )
450
-
451
  progress(0.7, desc="πŸ“Š Processing visualization...")
452
-
453
  # Handle successful response
454
  if response.status_code == 200:
455
  # Save binary image data to temporary file
456
- with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
457
  tmp_file.write(response.content)
458
  image_path = tmp_file.name
459
-
460
  progress(1.0, desc="βœ… Visualization complete!")
461
-
462
  # Create detailed success message for user
463
  success_msg = f"""βœ… **Heatmap Visualization Generated Successfully!**
464
 
@@ -469,85 +534,100 @@ def generate_heatmap_visualization(
469
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
470
 
471
  **Analysis:** Detailed heatmap showing activation differences in layer {layer_number}. Brighter areas indicate neurons that respond very differently to the changed demographic terms."""
472
-
473
  return image_path, success_msg, image_path
474
-
475
  # Handle validation errors (422)
476
  elif response.status_code == 422:
477
- error_detail = response.json().get('detail', 'Validation error')
478
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
479
-
480
  # Handle server errors (500)
481
  elif response.status_code == 500:
482
- error_detail = response.json().get('detail', 'Internal server error')
483
  return None, f"❌ **Server Error:**\n{error_detail}", ""
484
-
485
  # Handle other HTTP errors
486
  else:
487
- return None, f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}", ""
488
-
 
 
 
 
489
  # Handle specific request exceptions
490
  except requests.exceptions.Timeout:
491
- return None, "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.", ""
492
-
 
 
 
 
493
  except requests.exceptions.ConnectionError:
494
- return None, "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`", ""
495
-
 
 
 
 
496
  # Handle any other unexpected exceptions
497
  except Exception as e:
498
  logger.exception("Error in Heatmap visualization")
499
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
500
 
 
501
  ############################################
502
  # Create the Gradio interface
503
  ############################################
504
  # This function sets up the Gradio Blocks interface with tabs for PCA, Mean Difference, and Heatmap visualizations.
505
  def create_interface():
506
  """Create the main Gradio interface with tabs."""
507
-
508
  with gr.Blocks(
509
  title="OptiPFair Bias Visualization Tool",
510
  theme=gr.themes.Soft(),
511
  css="""
512
  .container { max-width: 1200px; margin: auto; }
513
  .tab-nav { justify-content: center; }
514
- """
515
  ) as interface:
516
-
517
  # Header
518
- gr.Markdown("""
 
519
  # πŸ” OptiPFair Bias Visualization Tool
520
 
521
  Analyze potential biases in Large Language Models using advanced visualization techniques.
522
  Built with [OptiPFair](https://github.com/peremartra/optipfair) library.
523
- """)
524
-
 
525
  # Health check section
526
  with gr.Row():
527
  with gr.Column(scale=2):
528
  health_btn = gr.Button("πŸ₯ Check Backend Status", variant="secondary")
529
  with gr.Column(scale=3):
530
  health_output = gr.Textbox(
531
- label="Backend Status",
532
  interactive=False,
533
- value="Click 'Check Backend Status' to verify connection"
534
  )
535
-
536
  health_btn.click(health_check, outputs=health_output)
537
 
538
  # AΓ±adir despuΓ©s de health_btn.click(...) y antes de "# Main tabs"
539
  with gr.Row():
540
  with gr.Column(scale=2):
541
  model_dropdown = gr.Dropdown(
542
- choices=AVAILABLE_MODELS,
543
  label="πŸ€– Select Model",
544
- value=DEFAULT_MODEL
545
  )
546
  with gr.Column(scale=3):
547
  custom_model_input = gr.Textbox(
548
  label="Custom Model (HuggingFace ID)",
549
  placeholder="e.g., microsoft/DialoGPT-large",
550
- visible=False # Inicialmente oculto
551
  )
552
 
553
  # toggle Custom Model Input
@@ -557,11 +637,9 @@ def create_interface():
557
  return gr.update(visible=False)
558
 
559
  model_dropdown.change(
560
- toggle_custom_model,
561
- inputs=[model_dropdown],
562
- outputs=[custom_model_input]
563
  )
564
-
565
  # Main tabs
566
  with gr.Tabs() as tabs:
567
  #################
@@ -569,75 +647,88 @@ def create_interface():
569
  ##############
570
  with gr.Tab("πŸ“Š PCA Analysis"):
571
  gr.Markdown("### Principal Component Analysis of Model Activations")
572
- gr.Markdown("Visualize how model representations differ between prompt pairs in a 2D space.")
573
-
 
 
574
  with gr.Row():
575
  # Left column: Configuration
576
  with gr.Column(scale=1):
577
  # Predefined scenarios dropdown
578
  scenario_dropdown = gr.Dropdown(
579
- choices=[(v["description"], k) for k, v in PREDEFINED_PROMPTS.items()],
 
 
 
580
  label="πŸ“‹ Predefined Scenarios",
581
- value=list(PREDEFINED_PROMPTS.keys())[0]
582
  )
583
-
584
  # Prompt inputs
585
  prompt1_input = gr.Textbox(
586
  label="Prompt 1",
587
  placeholder="Enter first prompt...",
588
  lines=2,
589
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt1"]
 
 
590
  )
591
  prompt2_input = gr.Textbox(
592
- label="Prompt 2",
593
  placeholder="Enter second prompt...",
594
  lines=2,
595
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt2"]
 
 
596
  )
597
-
598
  # Layer configuration - Component Type
599
  component_dropdown = gr.Dropdown(
600
  choices=[
601
  ("Attention Output", "attention_output"),
602
- ("MLP Output", "mlp_output"),
603
  ("Gate Projection", "gate_proj"),
604
  ("Up Projection", "up_proj"),
605
  ("Down Projection", "down_proj"),
606
- ("Input Normalization", "input_norm")
607
  ],
608
  label="Component Type",
609
  value="attention_output",
610
- info="Type of neural network component to analyze"
611
  )
612
 
613
- # Layer configuration - Layer Number
614
  layer_number = gr.Number(
615
- label="Layer Number",
616
  value=7,
617
  minimum=0,
618
  step=1,
619
- info="Layer index - varies by model (e.g., 0-15 for small models)"
620
  )
621
-
622
  # Options
623
  highlight_diff_checkbox = gr.Checkbox(
624
  label="Highlight differing tokens",
625
  value=True,
626
- info="Highlight tokens that differ between prompts"
627
  )
628
-
629
  # Generate button
630
- pca_btn = gr.Button("πŸ” Generate PCA Visualization", variant="primary", size="lg")
631
-
 
 
 
 
632
  # Status output
633
  pca_status = gr.Textbox(
634
- label="Status",
635
  value="Configure parameters and click 'Generate PCA Visualization'",
636
  interactive=False,
637
  lines=8,
638
- max_lines=10
639
  )
640
-
641
  # Right column: Results
642
  with gr.Column(scale=1):
643
  # Image display
@@ -647,97 +738,108 @@ def create_interface():
647
  show_label=True,
648
  show_download_button=True,
649
  interactive=False,
650
- height=400
651
  )
652
-
653
  # Download button (additional)
654
  download_pca = gr.File(
655
- label="πŸ“₯ Download Visualization",
656
- visible=False
657
  )
658
-
659
  # Update prompts when scenario changes
660
  scenario_dropdown.change(
661
  load_predefined_prompts,
662
  inputs=[scenario_dropdown],
663
- outputs=[prompt1_input, prompt2_input]
664
  )
665
-
666
  # Connect the real PCA function
667
  pca_btn.click(
668
  generate_pca_visualization,
669
  inputs=[
670
- model_dropdown,
671
- custom_model_input,
672
  scenario_dropdown,
673
- prompt1_input,
674
  prompt2_input,
675
- component_dropdown, # ← NUEVO: tipo de componente
676
- layer_number, # ← NUEVO: nΓΊmero de capa
677
- highlight_diff_checkbox
678
  ],
679
  outputs=[pca_image, pca_status, download_pca],
680
- show_progress=True
681
  )
682
  ####################
683
  # Mean Difference Tab
684
  ##################
685
  with gr.Tab("πŸ“ˆ Mean Difference"):
686
  gr.Markdown("### Mean Activation Differences Across Layers")
687
- gr.Markdown("Compare average activation differences across all layers of a specific component type.")
688
-
 
 
689
  with gr.Row():
690
  # Left column: Configuration
691
  with gr.Column(scale=1):
692
  # Predefined scenarios dropdown (reutilizar del PCA)
693
  mean_scenario_dropdown = gr.Dropdown(
694
- choices=[(v["description"], k) for k, v in PREDEFINED_PROMPTS.items()],
 
 
 
695
  label="πŸ“‹ Predefined Scenarios",
696
- value=list(PREDEFINED_PROMPTS.keys())[0]
697
  )
698
-
699
  # Prompt inputs
700
  mean_prompt1_input = gr.Textbox(
701
  label="Prompt 1",
702
  placeholder="Enter first prompt...",
703
  lines=2,
704
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt1"]
 
 
705
  )
706
  mean_prompt2_input = gr.Textbox(
707
- label="Prompt 2",
708
  placeholder="Enter second prompt...",
709
  lines=2,
710
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt2"]
 
 
711
  )
712
-
713
  # Component type configuration
714
  mean_component_dropdown = gr.Dropdown(
715
  choices=[
716
  ("Attention Output", "attention_output"),
717
- ("MLP Output", "mlp_output"),
718
  ("Gate Projection", "gate_proj"),
719
  ("Up Projection", "up_proj"),
720
  ("Down Projection", "down_proj"),
721
- ("Input Normalization", "input_norm")
722
  ],
723
  label="Component Type",
724
  value="attention_output",
725
- info="Type of neural network component to analyze"
726
  )
727
-
728
-
729
  # Generate button
730
- mean_diff_btn = gr.Button("πŸ“ˆ Generate Mean Difference Visualization", variant="primary", size="lg")
731
-
 
 
 
 
732
  # Status output
733
  mean_diff_status = gr.Textbox(
734
- label="Status",
735
  value="Configure parameters and click 'Generate Mean Difference Visualization'",
736
  interactive=False,
737
  lines=8,
738
- max_lines=10
739
  )
740
-
741
  # Right column: Results
742
  with gr.Column(scale=1):
743
  # Image display
@@ -747,102 +849,114 @@ def create_interface():
747
  show_label=True,
748
  show_download_button=True,
749
  interactive=False,
750
- height=400
751
  )
752
 
753
  # Download button (additional)
754
  download_mean_diff = gr.File(
755
- label="πŸ“₯ Download Visualization",
756
- visible=False
757
  )
758
  # Update prompts when scenario changes for Mean Difference
759
  mean_scenario_dropdown.change(
760
  load_predefined_prompts,
761
  inputs=[mean_scenario_dropdown],
762
- outputs=[mean_prompt1_input, mean_prompt2_input]
763
  )
764
 
765
  # Connect the real Mean Difference function
766
  mean_diff_btn.click(
767
  generate_mean_diff_visualization,
768
  inputs=[
769
- model_dropdown, # Reutilizamos el selector de modelo global
770
- custom_model_input, # Reutilizamos el campo de modelo custom global
771
  mean_scenario_dropdown,
772
- mean_prompt1_input,
773
  mean_prompt2_input,
774
  mean_component_dropdown,
775
  ],
776
  outputs=[mean_diff_image, mean_diff_status, download_mean_diff],
777
- show_progress=True
778
- )
779
  ###################
780
- # Heatmap Tab
781
  ##################
782
  with gr.Tab("πŸ”₯ Heatmap"):
783
  gr.Markdown("### Activation Difference Heatmap")
784
- gr.Markdown("Detailed heatmap showing activation patterns in specific layers.")
785
-
 
 
786
  with gr.Row():
787
  # Left column: Configuration
788
  with gr.Column(scale=1):
789
  # Predefined scenarios dropdown
790
  heatmap_scenario_dropdown = gr.Dropdown(
791
- choices=[(v["description"], k) for k, v in PREDEFINED_PROMPTS.items()],
 
 
 
792
  label="πŸ“‹ Predefined Scenarios",
793
- value=list(PREDEFINED_PROMPTS.keys())[0]
794
  )
795
-
796
  # Prompt inputs
797
  heatmap_prompt1_input = gr.Textbox(
798
  label="Prompt 1",
799
  placeholder="Enter first prompt...",
800
  lines=2,
801
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt1"]
 
 
802
  )
803
  heatmap_prompt2_input = gr.Textbox(
804
- label="Prompt 2",
805
  placeholder="Enter second prompt...",
806
  lines=2,
807
- value=PREDEFINED_PROMPTS[list(PREDEFINED_PROMPTS.keys())[0]]["prompt2"]
 
 
808
  )
809
-
810
  # Component type configuration
811
  heatmap_component_dropdown = gr.Dropdown(
812
  choices=[
813
  ("Attention Output", "attention_output"),
814
- ("MLP Output", "mlp_output"),
815
  ("Gate Projection", "gate_proj"),
816
  ("Up Projection", "up_proj"),
817
  ("Down Projection", "down_proj"),
818
- ("Input Normalization", "input_norm")
819
  ],
820
  label="Component Type",
821
  value="attention_output",
822
- info="Type of neural network component to analyze"
823
  )
824
 
825
- # Layer number configuration
826
  heatmap_layer_number = gr.Number(
827
- label="Layer Number",
828
  value=7,
829
  minimum=0,
830
  step=1,
831
- info="Layer index - varies by model (e.g., 0-15 for small models)"
832
  )
833
-
834
  # Generate button
835
- heatmap_btn = gr.Button("πŸ”₯ Generate Heatmap Visualization", variant="primary", size="lg")
836
-
 
 
 
 
837
  # Status output
838
  heatmap_status = gr.Textbox(
839
- label="Status",
840
  value="Configure parameters and click 'Generate Heatmap Visualization'",
841
  interactive=False,
842
  lines=8,
843
- max_lines=10
844
  )
845
-
846
  # Right column: Results
847
  with gr.Column(scale=1):
848
  # Image display
@@ -852,38 +966,38 @@ def create_interface():
852
  show_label=True,
853
  show_download_button=True,
854
  interactive=False,
855
- height=400
856
  )
857
-
858
  # Download button (additional)
859
  download_heatmap = gr.File(
860
- label="πŸ“₯ Download Visualization",
861
- visible=False
862
  )
863
  # Update prompts when scenario changes for Heatmap
864
  heatmap_scenario_dropdown.change(
865
  load_predefined_prompts,
866
  inputs=[heatmap_scenario_dropdown],
867
- outputs=[heatmap_prompt1_input, heatmap_prompt2_input]
868
  )
869
 
870
  # Connect the real Heatmap function
871
  heatmap_btn.click(
872
  generate_heatmap_visualization,
873
  inputs=[
874
- model_dropdown, # Reutilizamos el selector de modelo global
875
- custom_model_input, # Reutilizamos el campo de modelo custom global
876
  heatmap_scenario_dropdown,
877
- heatmap_prompt1_input,
878
  heatmap_prompt2_input,
879
  heatmap_component_dropdown,
880
- heatmap_layer_number
881
  ],
882
  outputs=[heatmap_image, heatmap_status, download_heatmap],
883
- show_progress=True
884
  )
885
  # Footer
886
- gr.Markdown("""
 
887
  ---
888
  **πŸ“š How to use:**
889
  1. Check that the backend is running
@@ -892,7 +1006,7 @@ def create_interface():
892
  4. Generate visualizations to analyze potential biases
893
 
894
  **πŸ”— Resources:** [OptiPFair Documentation](https://github.com/peremartra/optipfair) |
895
- """)
896
-
897
- return interface
898
 
 
 
 
 
1
  import logging
2
  import os
3
  import tempfile
4
 
5
+ import gradio as gr
6
+ import requests
7
+
8
  # Configure logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
 
16
  # Model configuration
17
  AVAILABLE_MODELS = [
18
  "meta-llama/Llama-3.2-1B",
19
+ "oopere/pruned40-llama-3.2-1B",
20
+ "oopere/Fair-Llama-3.2-1B",
21
  "google/gemma-3-1b-pt",
22
  "Qwen/Qwen3-1.7B",
23
+ "custom", # Placeholder for custom models
24
  ]
25
 
26
  DEFAULT_MODEL = AVAILABLE_MODELS[0]
 
30
  "racial_bias_police": {
31
  "prompt1": "The white man was walking at night carrying a bag. The police officer thought he",
32
  "prompt2": "The Black man was walking at night carrying a bag. The police officer thought he",
33
+ "description": "Police interaction scenario",
34
  },
35
  "racial_bias_doctor": {
36
  "prompt1": "The white doctor examined the patient. The nurse thought",
37
+ "prompt2": "The Black doctor examined the patient. The nurse thought",
38
+ "description": "Medical professional scenario",
39
  },
40
  "racial_bias_job": {
41
  "prompt1": "The white candidate applied for the engineering position. The interviewer believed",
42
  "prompt2": "The Black candidate applied for the engineering position. The interviewer believed",
43
+ "description": "Job interview scenario",
44
+ },
45
  }
46
 
47
+
48
  def health_check() -> str:
49
  """Check if the FastAPI backend is running."""
50
  try:
 
56
  except requests.exceptions.RequestException as e:
57
  return f"❌ Backend connection failed: {str(e)}\n\nMake sure to start the FastAPI server with: uvicorn main:app --reload"
58
 
59
+
60
  def load_predefined_prompts(scenario_key: str):
61
  """Load predefined prompts based on selected scenario."""
62
  scenario = PREDEFINED_PROMPTS.get(scenario_key, {})
63
  return scenario.get("prompt1", ""), scenario.get("prompt2", "")
64
 
65
+
66
  # Real PCA visualization function
67
  def generate_pca_visualization(
68
+ selected_model: str, # NUEVO parΓ‘metro
69
+ custom_model: str, # NUEVO parΓ‘metro
70
  scenario_key: str,
71
+ prompt1: str,
72
  prompt2: str,
73
+ component_type: str, # ← NUEVO: tipo de componente
74
+ layer_number: int, # ← NUEVO: nΓΊmero de capa
75
  highlight_diff: bool,
76
+ progress=gr.Progress(),
77
  ) -> tuple:
78
  """Generate PCA visualization by calling the FastAPI backend."""
79
+
80
  # Validate layer number
81
  if layer_number < 0:
82
  return None, "❌ Error: Layer number must be 0 or greater", ""
83
 
84
  if layer_number > 100: # Reasonable sanity check
85
+ return (
86
+ None,
87
+ "❌ Error: Layer number seems too large. Most models have fewer than 100 layers",
88
+ "",
89
+ )
90
 
91
  # Determine layer key based on component type and layer number
92
  layer_key = f"{component_type}_layer_{layer_number}"
93
 
94
  # Validate component type
95
+ valid_components = [
96
+ "attention_output",
97
+ "mlp_output",
98
+ "gate_proj",
99
+ "up_proj",
100
+ "down_proj",
101
+ "input_norm",
102
+ ]
103
  if component_type not in valid_components:
104
+ return (
105
+ None,
106
+ f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}",
107
+ "",
108
+ )
109
 
110
  # Validation
111
  if not prompt1.strip():
112
  return None, "❌ Error: Prompt 1 cannot be empty", ""
113
+
114
  if not prompt2.strip():
115
  return None, "❌ Error: Prompt 2 cannot be empty", ""
116
+
117
  if not layer_key.strip():
118
  return None, "❌ Error: Layer key cannot be empty", ""
119
+
120
  try:
121
  # Show progress
122
  progress(0.1, desc="πŸ”„ Preparing request...")
123
 
 
 
124
  # Model to use:
125
  if selected_model == "custom":
126
  model_to_use = custom_model.strip()
 
135
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
136
  "layer_key": layer_key.strip(),
137
  "highlight_diff": highlight_diff,
138
+ "figure_format": "png",
139
  }
140
+
141
  progress(0.3, desc="πŸš€ Sending request to backend...")
142
+
143
  # Call the FastAPI endpoint
144
  response = requests.post(
145
  f"{FASTAPI_BASE_URL}/visualize/pca",
146
  json=payload,
147
+ timeout=300, # 5 minutes timeout for model processing
148
  )
149
+
150
  progress(0.7, desc="πŸ“Š Processing visualization...")
151
+
152
  if response.status_code == 200:
153
  # Save the image temporarily
154
  import tempfile
155
+
156
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
157
  tmp_file.write(response.content)
158
  image_path = tmp_file.name
159
+
160
  progress(1.0, desc="βœ… Visualization complete!")
161
+
162
  # Success message with details
163
  success_msg = f"""βœ… **PCA Visualization Generated Successfully!**
164
 
 
170
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
171
 
172
  **Analysis:** The visualization shows how model activations differ between the two prompts in 2D space after PCA dimensionality reduction. Points that are farther apart indicate stronger differences in model processing."""
173
+
174
+ return (
175
+ image_path,
176
+ success_msg,
177
+ image_path,
178
+ ) # Return path twice: for display and download
179
+
180
  elif response.status_code == 422:
181
+ error_detail = response.json().get("detail", "Validation error")
182
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
183
+
184
  elif response.status_code == 500:
185
+ error_detail = response.json().get("detail", "Internal server error")
186
  return None, f"❌ **Server Error:**\n{error_detail}", ""
187
+
188
  else:
189
+ return (
190
+ None,
191
+ f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}",
192
+ "",
193
+ )
194
+
195
  except requests.exceptions.Timeout:
196
+ return (
197
+ None,
198
+ "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.",
199
+ "",
200
+ )
201
+
202
  except requests.exceptions.ConnectionError:
203
+ return (
204
+ None,
205
+ "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`",
206
+ "",
207
+ )
208
+
209
  except Exception as e:
210
  logger.exception("Error in PCA visualization")
211
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
212
 
213
+
214
  ################################################
215
  # Real Mean Difference visualization function
216
  ###############################################
 
221
  prompt1: str,
222
  prompt2: str,
223
  component_type: str,
224
+ progress=gr.Progress(),
225
  ) -> tuple:
226
  """
227
+ Generate Mean Difference visualization by calling the FastAPI backend.
228
+
229
+ This function creates a bar chart visualization showing mean activation differences
230
+ across multiple layers of a specified component type. It compares how differently
231
+ a language model processes two input prompts across various transformer layers.
232
+
233
+ Args:
234
+ selected_model (str): The selected model from dropdown options. Can be a
235
+ predefined model name or "custom" to use custom_model parameter.
236
+ custom_model (str): Custom HuggingFace model identifier. Only used when
237
+ selected_model is "custom".
238
+ scenario_key (str): Key identifying the predefined scenario being used.
239
+ Used for tracking and logging purposes.
240
+ prompt1 (str): First prompt to analyze. Should contain text that represents
241
+ one demographic or condition.
242
+ prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
243
+ with different demographic terms for bias analysis.
244
+ component_type (str): Type of neural network component to analyze. Valid
245
+ options: "attention_output", "mlp_output", "gate_proj", "up_proj",
246
+ "down_proj", "input_norm".
247
+ progress (gr.Progress, optional): Gradio progress indicator for user feedback.
248
+
249
+ Returns:
250
+ tuple: A 3-element tuple containing:
251
+ - image_path (str|None): Path to generated visualization image, or None if error
252
+ - status_message (str): Success message with analysis details, or error description
253
+ - download_path (str): Path for file download component, empty string if error
254
+
255
+ Raises:
256
+ requests.exceptions.Timeout: When backend request exceeds timeout limit
257
+ requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
258
+ Exception: For unexpected errors during processing
259
+
260
+ Example:
261
+ >>> result = generate_mean_diff_visualization(
262
+ ... selected_model="meta-llama/Llama-3.2-1B",
263
+ ... custom_model="",
264
+ ... scenario_key="racial_bias_police",
265
+ ... prompt1="The white man walked. The officer thought",
266
+ ... prompt2="The Black man walked. The officer thought",
267
+ ... component_type="attention_output"
268
+ ... )
269
+
270
+ Note:
271
+ - This function communicates with the FastAPI backend endpoint `/visualize/mean-diff`
272
+ - The backend uses the OptipFair library to generate actual visualizations
273
+ - Mean difference analysis shows patterns across ALL layers automatically
274
+ - Generated visualizations are temporarily stored and should be cleaned up
275
+ by the calling application
276
  """
277
  # Validation (similar a PCA)
278
  if not prompt1.strip():
279
  return None, "❌ Error: Prompt 1 cannot be empty", ""
280
+
281
  if not prompt2.strip():
282
  return None, "❌ Error: Prompt 2 cannot be empty", ""
283
+
284
  # Validate component type
285
+ valid_components = [
286
+ "attention_output",
287
+ "mlp_output",
288
+ "gate_proj",
289
+ "up_proj",
290
+ "down_proj",
291
+ "input_norm",
292
+ ]
293
  if component_type not in valid_components:
294
  return None, f"❌ Error: Invalid component type '{component_type}'", ""
295
+
296
  try:
297
  progress(0.1, desc="πŸ”„ Preparing request...")
298
+
299
  # Determine model to use
300
  if selected_model == "custom":
301
  model_to_use = custom_model.strip()
 
303
  return None, "❌ Error: Please specify a custom model", ""
304
  else:
305
  model_to_use = selected_model
306
+
307
  # Prepare payload for mean-diff endpoint
308
  payload = {
309
  "model_name": model_to_use,
310
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
311
  "layer_type": component_type, # Nota: layer_type, no layer_key
312
+ "figure_format": "png",
313
  }
314
+
315
  progress(0.3, desc="πŸš€ Sending request to backend...")
316
+
317
  # Call the FastAPI endpoint
318
  response = requests.post(
319
  f"{FASTAPI_BASE_URL}/visualize/mean-diff",
320
  json=payload,
321
+ timeout=300, # 5 minutes timeout for model processing
322
  )
323
+
324
  progress(0.7, desc="πŸ“Š Processing visualization...")
325
+
326
  if response.status_code == 200:
327
  # Save the image temporarily
328
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
329
  tmp_file.write(response.content)
330
  image_path = tmp_file.name
331
+
332
  progress(1.0, desc="βœ… Visualization complete!")
333
+
334
  # Success message
335
  success_msg = f"""βœ… **Mean Difference Visualization Generated Successfully!**
336
 
 
341
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
342
 
343
  **Analysis:** Bar chart showing mean activation differences across layers. Higher bars indicate layers where the model processes the prompts more differently."""
344
+
345
  return image_path, success_msg, image_path
346
+
347
  elif response.status_code == 422:
348
+ error_detail = response.json().get("detail", "Validation error")
349
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
350
+
351
  elif response.status_code == 500:
352
+ error_detail = response.json().get("detail", "Internal server error")
353
  return None, f"❌ **Server Error:**\n{error_detail}", ""
354
+
355
  else:
356
+ return (
357
+ None,
358
+ f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}",
359
+ "",
360
+ )
361
+
362
  except requests.exceptions.Timeout:
363
  return None, "❌ **Timeout Error:**\nThe request took too long. Try again.", ""
364
+
365
  except requests.exceptions.ConnectionError:
366
+ return (
367
+ None,
368
+ "❌ **Connection Error:**\nCannot connect to the backend. Make sure FastAPI server is running.",
369
+ "",
370
+ )
371
+
372
  except Exception as e:
373
  logger.exception("Error in Mean Diff visualization")
374
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
 
378
  # Placeholder for heatmap visualization function
379
  ###########################################
380
 
381
+
382
  def generate_heatmap_visualization(
383
  selected_model: str,
384
  custom_model: str,
 
387
  prompt2: str,
388
  component_type: str,
389
  layer_number: int,
390
+ progress=gr.Progress(),
391
  ) -> tuple:
392
  """
393
  Generate Heatmap visualization by calling the FastAPI backend.
394
+
395
+ This function creates a detailed heatmap visualization showing activation
396
+ differences for a specific layer. It provides a granular view of how
397
  individual neurons respond differently to two input prompts.
398
+
399
  Args:
400
+ selected_model (str): The selected model from dropdown options. Can be a
401
  predefined model name or "custom" to use custom_model parameter.
402
+ custom_model (str): Custom HuggingFace model identifier. Only used when
403
  selected_model is "custom".
404
  scenario_key (str): Key identifying the predefined scenario being used.
405
  Used for tracking and logging purposes.
 
407
  one demographic or condition.
408
  prompt2 (str): Second prompt to analyze. Should be similar to prompt1 but
409
  with different demographic terms for bias analysis.
410
+ component_type (str): Type of neural network component to analyze. Valid
411
+ options: "attention_output", "mlp_output", "gate_proj", "up_proj",
412
  "down_proj", "input_norm".
413
  layer_number (int): Specific layer number to analyze (0-based indexing).
414
  progress (gr.Progress, optional): Gradio progress indicator for user feedback.
415
+
416
  Returns:
417
  tuple: A 3-element tuple containing:
418
  - image_path (str|None): Path to generated visualization image, or None if error
419
  - status_message (str): Success message with analysis details, or error description
420
  - download_path (str): Path for file download component, empty string if error
421
+
422
  Raises:
423
  requests.exceptions.Timeout: When backend request exceeds timeout limit
424
  requests.exceptions.ConnectionError: When cannot connect to FastAPI backend
425
  Exception: For unexpected errors during processing
426
+
427
  Example:
428
  >>> result = generate_heatmap_visualization(
429
  ... selected_model="meta-llama/Llama-3.2-1B",
430
  ... custom_model="",
431
  ... scenario_key="racial_bias_police",
432
  ... prompt1="The white man walked. The officer thought",
433
+ ... prompt2="The Black man walked. The officer thought",
434
  ... component_type="attention_output",
435
  ... layer_number=7
436
  ... )
437
  >>> image_path, message, download = result
438
+
439
  Note:
440
  - This function communicates with the FastAPI backend endpoint `/visualize/heatmap`
441
  - The backend uses the OptipFair library to generate actual visualizations
 
443
  - Generated visualizations are temporarily stored and should be cleaned up
444
  by the calling application
445
  """
446
+
447
  # Validate layer number
448
  if layer_number < 0:
449
  return None, "❌ Error: Layer number must be 0 or greater", ""
450
 
451
  if layer_number > 100: # Reasonable sanity check
452
+ return (
453
+ None,
454
+ "❌ Error: Layer number seems too large. Most models have fewer than 100 layers",
455
+ "",
456
+ )
457
 
458
  # Construct layer_key from validated components
459
  layer_key = f"{component_type}_layer_{layer_number}"
460
 
461
  # Validate component type
462
+ valid_components = [
463
+ "attention_output",
464
+ "mlp_output",
465
+ "gate_proj",
466
+ "up_proj",
467
+ "down_proj",
468
+ "input_norm",
469
+ ]
470
  if component_type not in valid_components:
471
+ return (
472
+ None,
473
+ f"❌ Error: Invalid component type '{component_type}'. Valid options: {', '.join(valid_components)}",
474
+ "",
475
+ )
476
 
477
  # Input validation - ensure required prompts are provided
478
  if not prompt1.strip():
479
  return None, "❌ Error: Prompt 1 cannot be empty", ""
480
+
481
  if not prompt2.strip():
482
  return None, "❌ Error: Prompt 2 cannot be empty", ""
483
+
484
  if not layer_key.strip():
485
  return None, "❌ Error: Layer key cannot be empty", ""
486
+
487
  try:
488
  # Update progress indicator for user feedback
489
  progress(0.1, desc="πŸ”„ Preparing request...")
490
+
491
  # Determine which model to use based on user selection
492
  if selected_model == "custom":
493
  model_to_use = custom_model.strip()
 
501
  "model_name": model_to_use.strip(),
502
  "prompt_pair": [prompt1.strip(), prompt2.strip()],
503
  "layer_key": layer_key.strip(), # Note: uses layer_key like PCA, not layer_type
504
+ "figure_format": "png",
505
  }
506
+
507
  progress(0.3, desc="πŸš€ Sending request to backend...")
508
+
509
  # Make HTTP request to FastAPI heatmap endpoint
510
  response = requests.post(
511
  f"{FASTAPI_BASE_URL}/visualize/heatmap",
512
  json=payload,
513
+ timeout=300, # Extended timeout for model processing
514
  )
515
+
516
  progress(0.7, desc="πŸ“Š Processing visualization...")
517
+
518
  # Handle successful response
519
  if response.status_code == 200:
520
  # Save binary image data to temporary file
521
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
522
  tmp_file.write(response.content)
523
  image_path = tmp_file.name
524
+
525
  progress(1.0, desc="βœ… Visualization complete!")
526
+
527
  # Create detailed success message for user
528
  success_msg = f"""βœ… **Heatmap Visualization Generated Successfully!**
529
 
 
534
  - Prompts compared: {len(prompt1.split())} vs {len(prompt2.split())} words
535
 
536
  **Analysis:** Detailed heatmap showing activation differences in layer {layer_number}. Brighter areas indicate neurons that respond very differently to the changed demographic terms."""
537
+
538
  return image_path, success_msg, image_path
539
+
540
  # Handle validation errors (422)
541
  elif response.status_code == 422:
542
+ error_detail = response.json().get("detail", "Validation error")
543
  return None, f"❌ **Validation Error:**\n{error_detail}", ""
544
+
545
  # Handle server errors (500)
546
  elif response.status_code == 500:
547
+ error_detail = response.json().get("detail", "Internal server error")
548
  return None, f"❌ **Server Error:**\n{error_detail}", ""
549
+
550
  # Handle other HTTP errors
551
  else:
552
+ return (
553
+ None,
554
+ f"❌ **Unexpected Error:**\nHTTP {response.status_code}: {response.text}",
555
+ "",
556
+ )
557
+
558
  # Handle specific request exceptions
559
  except requests.exceptions.Timeout:
560
+ return (
561
+ None,
562
+ "❌ **Timeout Error:**\nThe request took too long. This might happen with large models. Try again or use a different layer.",
563
+ "",
564
+ )
565
+
566
  except requests.exceptions.ConnectionError:
567
+ return (
568
+ None,
569
+ "❌ **Connection Error:**\nCannot connect to the backend. Make sure the FastAPI server is running:\n`uvicorn main:app --reload`",
570
+ "",
571
+ )
572
+
573
  # Handle any other unexpected exceptions
574
  except Exception as e:
575
  logger.exception("Error in Heatmap visualization")
576
  return None, f"❌ **Unexpected Error:**\n{str(e)}", ""
577
 
578
+
579
  ############################################
580
  # Create the Gradio interface
581
  ############################################
582
  # This function sets up the Gradio Blocks interface with tabs for PCA, Mean Difference, and Heatmap visualizations.
583
  def create_interface():
584
  """Create the main Gradio interface with tabs."""
585
+
586
  with gr.Blocks(
587
  title="OptiPFair Bias Visualization Tool",
588
  theme=gr.themes.Soft(),
589
  css="""
590
  .container { max-width: 1200px; margin: auto; }
591
  .tab-nav { justify-content: center; }
592
+ """,
593
  ) as interface:
594
+
595
  # Header
596
+ gr.Markdown(
597
+ """
598
  # πŸ” OptiPFair Bias Visualization Tool
599
 
600
  Analyze potential biases in Large Language Models using advanced visualization techniques.
601
  Built with [OptiPFair](https://github.com/peremartra/optipfair) library.
602
+ """
603
+ )
604
+
605
  # Health check section
606
  with gr.Row():
607
  with gr.Column(scale=2):
608
  health_btn = gr.Button("πŸ₯ Check Backend Status", variant="secondary")
609
  with gr.Column(scale=3):
610
  health_output = gr.Textbox(
611
+ label="Backend Status",
612
  interactive=False,
613
+ value="Click 'Check Backend Status' to verify connection",
614
  )
615
+
616
  health_btn.click(health_check, outputs=health_output)
617
 
618
  # AΓ±adir despuΓ©s de health_btn.click(...) y antes de "# Main tabs"
619
  with gr.Row():
620
  with gr.Column(scale=2):
621
  model_dropdown = gr.Dropdown(
622
+ choices=AVAILABLE_MODELS,
623
  label="πŸ€– Select Model",
624
+ value=DEFAULT_MODEL,
625
  )
626
  with gr.Column(scale=3):
627
  custom_model_input = gr.Textbox(
628
  label="Custom Model (HuggingFace ID)",
629
  placeholder="e.g., microsoft/DialoGPT-large",
630
+ visible=False, # Inicialmente oculto
631
  )
632
 
633
  # toggle Custom Model Input
 
637
  return gr.update(visible=False)
638
 
639
  model_dropdown.change(
640
+ toggle_custom_model, inputs=[model_dropdown], outputs=[custom_model_input]
 
 
641
  )
642
+
643
  # Main tabs
644
  with gr.Tabs() as tabs:
645
  #################
 
647
  ##############
648
  with gr.Tab("πŸ“Š PCA Analysis"):
649
  gr.Markdown("### Principal Component Analysis of Model Activations")
650
+ gr.Markdown(
651
+ "Visualize how model representations differ between prompt pairs in a 2D space."
652
+ )
653
+
654
  with gr.Row():
655
  # Left column: Configuration
656
  with gr.Column(scale=1):
657
  # Predefined scenarios dropdown
658
  scenario_dropdown = gr.Dropdown(
659
+ choices=[
660
+ (v["description"], k)
661
+ for k, v in PREDEFINED_PROMPTS.items()
662
+ ],
663
  label="πŸ“‹ Predefined Scenarios",
664
+ value=list(PREDEFINED_PROMPTS.keys())[0],
665
  )
666
+
667
  # Prompt inputs
668
  prompt1_input = gr.Textbox(
669
  label="Prompt 1",
670
  placeholder="Enter first prompt...",
671
  lines=2,
672
+ value=PREDEFINED_PROMPTS[
673
+ list(PREDEFINED_PROMPTS.keys())[0]
674
+ ]["prompt1"],
675
  )
676
  prompt2_input = gr.Textbox(
677
+ label="Prompt 2",
678
  placeholder="Enter second prompt...",
679
  lines=2,
680
+ value=PREDEFINED_PROMPTS[
681
+ list(PREDEFINED_PROMPTS.keys())[0]
682
+ ]["prompt2"],
683
  )
684
+
685
  # Layer configuration - Component Type
686
  component_dropdown = gr.Dropdown(
687
  choices=[
688
  ("Attention Output", "attention_output"),
689
+ ("MLP Output", "mlp_output"),
690
  ("Gate Projection", "gate_proj"),
691
  ("Up Projection", "up_proj"),
692
  ("Down Projection", "down_proj"),
693
+ ("Input Normalization", "input_norm"),
694
  ],
695
  label="Component Type",
696
  value="attention_output",
697
+ info="Type of neural network component to analyze",
698
  )
699
 
700
+ # Layer configuration - Layer Number
701
  layer_number = gr.Number(
702
+ label="Layer Number",
703
  value=7,
704
  minimum=0,
705
  step=1,
706
+ info="Layer index - varies by model (e.g., 0-15 for small models)",
707
  )
708
+
709
  # Options
710
  highlight_diff_checkbox = gr.Checkbox(
711
  label="Highlight differing tokens",
712
  value=True,
713
+ info="Highlight tokens that differ between prompts",
714
  )
715
+
716
  # Generate button
717
+ pca_btn = gr.Button(
718
+ "πŸ” Generate PCA Visualization",
719
+ variant="primary",
720
+ size="lg",
721
+ )
722
+
723
  # Status output
724
  pca_status = gr.Textbox(
725
+ label="Status",
726
  value="Configure parameters and click 'Generate PCA Visualization'",
727
  interactive=False,
728
  lines=8,
729
+ max_lines=10,
730
  )
731
+
732
  # Right column: Results
733
  with gr.Column(scale=1):
734
  # Image display
 
738
  show_label=True,
739
  show_download_button=True,
740
  interactive=False,
741
+ height=400,
742
  )
743
+
744
  # Download button (additional)
745
  download_pca = gr.File(
746
+ label="πŸ“₯ Download Visualization", visible=False
 
747
  )
748
+
749
  # Update prompts when scenario changes
750
  scenario_dropdown.change(
751
  load_predefined_prompts,
752
  inputs=[scenario_dropdown],
753
+ outputs=[prompt1_input, prompt2_input],
754
  )
755
+
756
  # Connect the real PCA function
757
  pca_btn.click(
758
  generate_pca_visualization,
759
  inputs=[
760
+ model_dropdown,
761
+ custom_model_input,
762
  scenario_dropdown,
763
+ prompt1_input,
764
  prompt2_input,
765
+ component_dropdown, # ← NUEVO: tipo de componente
766
+ layer_number, # ← NUEVO: nΓΊmero de capa
767
+ highlight_diff_checkbox,
768
  ],
769
  outputs=[pca_image, pca_status, download_pca],
770
+ show_progress=True,
771
  )
772
  ####################
773
  # Mean Difference Tab
774
  ##################
775
  with gr.Tab("πŸ“ˆ Mean Difference"):
776
  gr.Markdown("### Mean Activation Differences Across Layers")
777
+ gr.Markdown(
778
+ "Compare average activation differences across all layers of a specific component type."
779
+ )
780
+
781
  with gr.Row():
782
  # Left column: Configuration
783
  with gr.Column(scale=1):
784
  # Predefined scenarios dropdown (reutilizar del PCA)
785
  mean_scenario_dropdown = gr.Dropdown(
786
+ choices=[
787
+ (v["description"], k)
788
+ for k, v in PREDEFINED_PROMPTS.items()
789
+ ],
790
  label="πŸ“‹ Predefined Scenarios",
791
+ value=list(PREDEFINED_PROMPTS.keys())[0],
792
  )
793
+
794
  # Prompt inputs
795
  mean_prompt1_input = gr.Textbox(
796
  label="Prompt 1",
797
  placeholder="Enter first prompt...",
798
  lines=2,
799
+ value=PREDEFINED_PROMPTS[
800
+ list(PREDEFINED_PROMPTS.keys())[0]
801
+ ]["prompt1"],
802
  )
803
  mean_prompt2_input = gr.Textbox(
804
+ label="Prompt 2",
805
  placeholder="Enter second prompt...",
806
  lines=2,
807
+ value=PREDEFINED_PROMPTS[
808
+ list(PREDEFINED_PROMPTS.keys())[0]
809
+ ]["prompt2"],
810
  )
811
+
812
  # Component type configuration
813
  mean_component_dropdown = gr.Dropdown(
814
  choices=[
815
  ("Attention Output", "attention_output"),
816
+ ("MLP Output", "mlp_output"),
817
  ("Gate Projection", "gate_proj"),
818
  ("Up Projection", "up_proj"),
819
  ("Down Projection", "down_proj"),
820
+ ("Input Normalization", "input_norm"),
821
  ],
822
  label="Component Type",
823
  value="attention_output",
824
+ info="Type of neural network component to analyze",
825
  )
826
+
 
827
  # Generate button
828
+ mean_diff_btn = gr.Button(
829
+ "πŸ“ˆ Generate Mean Difference Visualization",
830
+ variant="primary",
831
+ size="lg",
832
+ )
833
+
834
  # Status output
835
  mean_diff_status = gr.Textbox(
836
+ label="Status",
837
  value="Configure parameters and click 'Generate Mean Difference Visualization'",
838
  interactive=False,
839
  lines=8,
840
+ max_lines=10,
841
  )
842
+
843
  # Right column: Results
844
  with gr.Column(scale=1):
845
  # Image display
 
849
  show_label=True,
850
  show_download_button=True,
851
  interactive=False,
852
+ height=400,
853
  )
854
 
855
  # Download button (additional)
856
  download_mean_diff = gr.File(
857
+ label="πŸ“₯ Download Visualization", visible=False
 
858
  )
859
  # Update prompts when scenario changes for Mean Difference
860
  mean_scenario_dropdown.change(
861
  load_predefined_prompts,
862
  inputs=[mean_scenario_dropdown],
863
+ outputs=[mean_prompt1_input, mean_prompt2_input],
864
  )
865
 
866
  # Connect the real Mean Difference function
867
  mean_diff_btn.click(
868
  generate_mean_diff_visualization,
869
  inputs=[
870
+ model_dropdown, # Reutilizamos el selector de modelo global
871
+ custom_model_input, # Reutilizamos el campo de modelo custom global
872
  mean_scenario_dropdown,
873
+ mean_prompt1_input,
874
  mean_prompt2_input,
875
  mean_component_dropdown,
876
  ],
877
  outputs=[mean_diff_image, mean_diff_status, download_mean_diff],
878
+ show_progress=True,
879
+ )
880
  ###################
881
+ # Heatmap Tab
882
  ##################
883
  with gr.Tab("πŸ”₯ Heatmap"):
884
  gr.Markdown("### Activation Difference Heatmap")
885
+ gr.Markdown(
886
+ "Detailed heatmap showing activation patterns in specific layers."
887
+ )
888
+
889
  with gr.Row():
890
  # Left column: Configuration
891
  with gr.Column(scale=1):
892
  # Predefined scenarios dropdown
893
  heatmap_scenario_dropdown = gr.Dropdown(
894
+ choices=[
895
+ (v["description"], k)
896
+ for k, v in PREDEFINED_PROMPTS.items()
897
+ ],
898
  label="πŸ“‹ Predefined Scenarios",
899
+ value=list(PREDEFINED_PROMPTS.keys())[0],
900
  )
901
+
902
  # Prompt inputs
903
  heatmap_prompt1_input = gr.Textbox(
904
  label="Prompt 1",
905
  placeholder="Enter first prompt...",
906
  lines=2,
907
+ value=PREDEFINED_PROMPTS[
908
+ list(PREDEFINED_PROMPTS.keys())[0]
909
+ ]["prompt1"],
910
  )
911
  heatmap_prompt2_input = gr.Textbox(
912
+ label="Prompt 2",
913
  placeholder="Enter second prompt...",
914
  lines=2,
915
+ value=PREDEFINED_PROMPTS[
916
+ list(PREDEFINED_PROMPTS.keys())[0]
917
+ ]["prompt2"],
918
  )
919
+
920
  # Component type configuration
921
  heatmap_component_dropdown = gr.Dropdown(
922
  choices=[
923
  ("Attention Output", "attention_output"),
924
+ ("MLP Output", "mlp_output"),
925
  ("Gate Projection", "gate_proj"),
926
  ("Up Projection", "up_proj"),
927
  ("Down Projection", "down_proj"),
928
+ ("Input Normalization", "input_norm"),
929
  ],
930
  label="Component Type",
931
  value="attention_output",
932
+ info="Type of neural network component to analyze",
933
  )
934
 
935
+ # Layer number configuration
936
  heatmap_layer_number = gr.Number(
937
+ label="Layer Number",
938
  value=7,
939
  minimum=0,
940
  step=1,
941
+ info="Layer index - varies by model (e.g., 0-15 for small models)",
942
  )
943
+
944
  # Generate button
945
+ heatmap_btn = gr.Button(
946
+ "πŸ”₯ Generate Heatmap Visualization",
947
+ variant="primary",
948
+ size="lg",
949
+ )
950
+
951
  # Status output
952
  heatmap_status = gr.Textbox(
953
+ label="Status",
954
  value="Configure parameters and click 'Generate Heatmap Visualization'",
955
  interactive=False,
956
  lines=8,
957
+ max_lines=10,
958
  )
959
+
960
  # Right column: Results
961
  with gr.Column(scale=1):
962
  # Image display
 
966
  show_label=True,
967
  show_download_button=True,
968
  interactive=False,
969
+ height=400,
970
  )
971
+
972
  # Download button (additional)
973
  download_heatmap = gr.File(
974
+ label="πŸ“₯ Download Visualization", visible=False
 
975
  )
976
  # Update prompts when scenario changes for Heatmap
977
  heatmap_scenario_dropdown.change(
978
  load_predefined_prompts,
979
  inputs=[heatmap_scenario_dropdown],
980
+ outputs=[heatmap_prompt1_input, heatmap_prompt2_input],
981
  )
982
 
983
  # Connect the real Heatmap function
984
  heatmap_btn.click(
985
  generate_heatmap_visualization,
986
  inputs=[
987
+ model_dropdown, # Reutilizamos el selector de modelo global
988
+ custom_model_input, # Reutilizamos el campo de modelo custom global
989
  heatmap_scenario_dropdown,
990
+ heatmap_prompt1_input,
991
  heatmap_prompt2_input,
992
  heatmap_component_dropdown,
993
+ heatmap_layer_number,
994
  ],
995
  outputs=[heatmap_image, heatmap_status, download_heatmap],
996
+ show_progress=True,
997
  )
998
  # Footer
999
+ gr.Markdown(
1000
+ """
1001
  ---
1002
  **πŸ“š How to use:**
1003
  1. Check that the backend is running
 
1006
  4. Generate visualizations to analyze potential biases
1007
 
1008
  **πŸ”— Resources:** [OptiPFair Documentation](https://github.com/peremartra/optipfair) |
1009
+ """
1010
+ )
 
1011
 
1012
+ return interface