# utils/visualize_pca.py
import os
import tempfile
import logging
from functools import lru_cache
from typing import Tuple, Optional, Union, List

import matplotlib
matplotlib.use('Agg')  # Select the non-GUI 'Agg' backend before anything imports pyplot

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optipfair.bias import visualize_pca, visualize_mean_differences, visualize_heatmap

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@lru_cache(maxsize=None)
def load_model_tokenizer(model_name: str):
    """
    Load the model and tokenizer once, move the model to the best available
    device, and cache the result so repeated calls reuse the same objects.
    """
    logger.info(f"Loading model and tokenizer for '{model_name}'")

    # Get HF token from environment for gated models
    hf_token = os.getenv("HF_TOKEN")

    # Device selection: CUDA > MPS (Apple Silicon) > CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,  # pass the token so gated models can be downloaded
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token,
    )
    model = model.to(device)
    logger.info(f"Model loaded on device: {next(model.parameters()).device}")
    return model, tokenizer


def run_visualize_pca(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_key: str,
    highlight_diff: bool = True,
    output_dir: Optional[str] = None,
    figure_format: str = "png",
    pair_index: int = 0,
) -> str:
    """
    Generate a PCA visualization for a prompt pair and return the path to the
    saved image.
    """
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="optipfair_pca_")
    os.makedirs(output_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)

    visualize_pca(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_key=layer_key,
        highlight_diff=highlight_diff,
        output_dir=output_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )

    # Split the layer key into its type and number components ("<layer_type>_<n>")
    layer_parts = layer_key.split("_")
    layer_type = "_".join(layer_parts[:-1])
    layer_num = layer_parts[-1]
    filename = build_visualization_filename(
        vis_type="pca",
        layer_type=layer_type,
        layer_num=layer_num,
        pair_index=pair_index,
        figure_format=figure_format,
    )

    filepath = os.path.join(output_dir, filename)
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Expected image file not found: {filepath}")

    logger.info(f"PCA image saved at {filepath}")
    return filepath


def run_visualize_mean_diff(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_type: str,  # takes a layer type (not a layer_key); all layers of that type are plotted
    figure_format: str = "png",
    output_dir: Optional[str] = None,
    pair_index: int = 0,
) -> str:
    """
    Generate a mean-differences plot across all layers of the given type and
    return the path to the saved image.
    """
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="optipfair_mean_diff_")
    os.makedirs(output_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)

    visualize_mean_differences(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_type=layer_type,
        layers="all",  # By default, show all layers
        output_dir=output_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )

    filename = build_visualization_filename(
        vis_type="mean_diff",
        layer_type=layer_type,
        pair_index=pair_index,
        figure_format=figure_format,
    )

    filepath = os.path.join(output_dir, filename)
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Expected image file not found: {filepath}")

    logger.info(f"Mean-diff image saved at {filepath}")
    return filepath


def run_visualize_heatmap(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_key: str,
    figure_format: str = "png",
    output_dir: Optional[str] = None,
    pair_index: int = 0,
) -> str:
    """
    Generate an activation-difference heatmap for a prompt pair and return the
    path to the saved image.
    """
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="optipfair_heatmap_")
    os.makedirs(output_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)

    visualize_heatmap(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_key=layer_key,
        output_dir=output_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )

    # Split the layer key into its type and number components ("<layer_type>_<n>")
    parts = layer_key.split("_")
    layer_type = "_".join(parts[:-1])
    layer_num = parts[-1]
    filename = build_visualization_filename(
        vis_type="heatmap",
        layer_type=layer_type,
        layer_num=layer_num,
        pair_index=pair_index,
        figure_format=figure_format,
    )

    filepath = os.path.join(output_dir, filename)
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Expected image file not found: {filepath}")

    logger.info(f"Heatmap image saved at {filepath}")
    return filepath


def build_visualization_filename(
    vis_type: str,
    layer_type: str,
    layer_num: Optional[str] = None,
    layers: Union[str, List[int], None] = None,
    pair_index: int = 0,
    figure_format: str = "png",
) -> str:
    """
    Build the filename under which a visualization is saved, so callers can
    locate the generated image on disk.
    """
    if vis_type == "mean_diff":
        # visualize_mean_differences does not include a layer number in the filename
        return f"mean_diff_{layer_type}_pair{pair_index}.{figure_format}"
    elif vis_type in ("pca", "heatmap"):
        return f"{vis_type}_{layer_type}_{layer_num}_pair{pair_index}.{figure_format}"
    else:
        raise ValueError(f"Unknown visualization type: {vis_type}")