Spaces:
Running
Running
# utils/visualize_pca.py
import os | |
import tempfile | |
import logging | |
from functools import lru_cache | |
from typing import Tuple, Optional, Union, List | |
import torch | |
from optipfair.bias import visualize_pca, visualize_mean_differences, visualize_heatmap | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import matplotlib | |
# Select the non-interactive 'Agg' backend so no GUI is required.
matplotlib.use("Agg")

# Module-level logger with its threshold set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@lru_cache(maxsize=1)
def load_model_tokenizer(model_name: str):
    """
    Load a causal language model and its tokenizer, reusing the last result.

    The most recently loaded (model, tokenizer) pair is memoised, so
    consecutive calls with the same ``model_name`` do not reload the weights.
    Only one entry is kept so that switching between models does not
    accumulate several models in memory.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to the
        selected device (CUDA if available, otherwise MPS, otherwise CPU).
    """
    logger.info(f"Loading model and tokenizer for '{model_name}'")

    # HF token from the environment for gated models (None when unset).
    hf_token = os.getenv("HF_TOKEN")

    # Device selection: CUDA > MPS (Apple Silicon) > CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        # torch.backends.mps is the documented MPS availability check.
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token,
    )

    model = model.to(device)
    logger.info(f"Model loaded on device: {next(model.parameters()).device}")
    return model, tokenizer
def run_visualize_pca(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_key: str,
    highlight_diff: bool = True,
    output_dir: Optional[str] = None,
    figure_format: str = "png",
    pair_index: int = 0,
) -> str:
    """
    Run the PCA visualization for one prompt pair and return the image path.

    Args:
        model_name: Model identifier handed to ``load_model_tokenizer``.
        prompt_pair: The two prompts forwarded to ``visualize_pca``.
        layer_key: Layer identifier; everything before the last underscore
            is used as the layer type and the remainder as the layer number
            when building the expected filename.
        highlight_diff: Forwarded unchanged to ``visualize_pca``.
        output_dir: Destination directory; a temporary directory is created
            when this is None.
        figure_format: File extension used for the image.
        pair_index: Index forwarded to ``visualize_pca`` and used in the
            expected filename.

    Returns:
        Path of the image file inside the output directory.

    Raises:
        FileNotFoundError: If the expected image file does not exist after
            ``visualize_pca`` has run.
    """
    # Fall back to a fresh temporary directory when none was supplied.
    target_dir = (
        output_dir
        if output_dir is not None
        else tempfile.mkdtemp(prefix="optipfair_pca_")
    )
    os.makedirs(target_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)
    visualize_pca(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_key=layer_key,
        highlight_diff=highlight_diff,
        output_dir=target_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )

    # Split "<layer_type>_<layer_num>" at the last underscore.
    layer_type, _, layer_num = layer_key.rpartition("_")
    expected_name = build_visualization_filename(
        vis_type="pca",
        layer_type=layer_type,
        layer_num=layer_num,
        pair_index=pair_index,
        figure_format=figure_format,
    )

    filepath = os.path.join(target_dir, expected_name)
    if os.path.isfile(filepath):
        logger.info(f"PCA image saved at {filepath}")
        return filepath
    raise FileNotFoundError(f"Expected image file not found: {filepath}")
def run_visualize_mean_diff(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_type: str,
    figure_format: str = "png",
    output_dir: Optional[str] = None,
    pair_index: int = 0,
) -> str:
    """
    Run the mean-difference visualization and return the image path.

    Args:
        model_name: Model identifier handed to ``load_model_tokenizer``.
        prompt_pair: The two prompts forwarded to
            ``visualize_mean_differences``.
        layer_type: Layer type (not a single layer key); the visualization
            is always requested with ``layers="all"``.
        figure_format: File extension used for the image.
        output_dir: Destination directory; a temporary directory is created
            when this is None.
        pair_index: Index forwarded to the visualization and used in the
            expected filename.

    Returns:
        Path of the image file inside the output directory.

    Raises:
        FileNotFoundError: If the expected image file does not exist after
            ``visualize_mean_differences`` has run.
    """
    # Fall back to a fresh temporary directory when none was supplied.
    target_dir = (
        output_dir
        if output_dir is not None
        else tempfile.mkdtemp(prefix="optipfair_mean_diff_")
    )
    os.makedirs(target_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)
    visualize_mean_differences(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_type=layer_type,
        layers="all",  # every layer is shown by default
        output_dir=target_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )

    expected_name = build_visualization_filename(
        vis_type="mean_diff",
        layer_type=layer_type,
        pair_index=pair_index,
        figure_format=figure_format,
    )

    filepath = os.path.join(target_dir, expected_name)
    if os.path.isfile(filepath):
        logger.info(f"Mean-diff image saved at {filepath}")
        return filepath
    raise FileNotFoundError(f"Expected image file not found: {filepath}")
def run_visualize_heatmap(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_key: str,
    figure_format: str = "png",
    output_dir: Optional[str] = None,
    pair_index: int = 0,
) -> str:
    """
    Run the heatmap visualization for one prompt pair and return the image path.

    Args:
        model_name: Model identifier handed to ``load_model_tokenizer``.
        prompt_pair: The two prompts forwarded to ``visualize_heatmap``.
        layer_key: Layer identifier; everything before the last underscore
            is used as the layer type and the remainder as the layer number
            when building the expected filename.
        figure_format: File extension used for the image.
        output_dir: Destination directory; a temporary directory is created
            when this is None.
        pair_index: Index forwarded to ``visualize_heatmap`` and used in the
            expected filename.

    Returns:
        Path of the image file inside the output directory.

    Raises:
        FileNotFoundError: If the expected image file does not exist after
            ``visualize_heatmap`` has run.
    """
    # Fall back to a fresh temporary directory when none was supplied.
    target_dir = (
        output_dir
        if output_dir is not None
        else tempfile.mkdtemp(prefix="optipfair_heatmap_")
    )
    os.makedirs(target_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)
    visualize_heatmap(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_key=layer_key,
        output_dir=target_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )

    # Split "<layer_type>_<layer_num>" at the last underscore.
    layer_type, _, layer_num = layer_key.rpartition("_")
    expected_name = build_visualization_filename(
        vis_type="heatmap",
        layer_type=layer_type,
        layer_num=layer_num,
        pair_index=pair_index,
        figure_format=figure_format,
    )

    filepath = os.path.join(target_dir, expected_name)
    if os.path.isfile(filepath):
        logger.info(f"Heatmap image saved at {filepath}")
        return filepath
    raise FileNotFoundError(f"Expected image file not found: {filepath}")
def build_visualization_filename(
    vis_type: str,
    layer_type: str,
    layer_num: Optional[str] = None,
    layers: Optional[Union[str, List[int]]] = None,
    pair_index: int = 0,
    figure_format: str = "png",
) -> str:
    """
    Build the filename for any visualization.

    Args:
        vis_type: One of "mean_diff", "pca" or "heatmap".
        layer_type: Layer type segment of the filename.
        layer_num: Layer number segment; required for "pca" and "heatmap",
            ignored for "mean_diff".
        layers: Accepted for interface compatibility; it does not influence
            the filename.
        pair_index: Index of the prompt pair, rendered as ``pair<index>``.
        figure_format: File extension, without the leading dot.

    Returns:
        The filename (no directory component).

    Raises:
        ValueError: If ``vis_type`` is unknown, or if ``layer_num`` is
            missing for a "pca" or "heatmap" filename.
    """
    if vis_type == "mean_diff":
        # Mean-difference filenames do not include the layer number.
        return f"mean_diff_{layer_type}_pair{pair_index}.{figure_format}"

    if vis_type in ("pca", "heatmap"):
        if layer_num is None:
            # Without this check the name would contain the literal "None"
            # and the caller would only fail later with a missing file.
            raise ValueError(
                f"layer_num is required for visualization type: {vis_type}"
            )
        return f"{vis_type}_{layer_type}_{layer_num}_pair{pair_index}.{figure_format}"

    raise ValueError(f"Unknown visualization type: {vis_type}")