"""
Implementation of Chi-Squared Hypothesis (CSH) test for comparing neural network models.

This module provides functions to test whether two models have similar activation patterns
across different layers using Chi-Squared statistical tests.
"""
from collections import defaultdict

import numpy as np
import scipy
import torch
from scipy.optimize import linear_sum_assignment as LAP
from scipy.stats import chi2

from tracing.utils.evaluate import evaluate
from tracing.utils.utils import cossim
def statistic(base_model, ft_model, dataloader):
    """Compute the Chi-Squared Hypothesis (CSH) test statistic for two models.

    Thin convenience wrapper that delegates to ``csh_sp_dataloader`` with its
    default number of blocks.

    Args:
        base_model: Base model to compare.
        ft_model: Fine-tuned or target model to compare against the base model.
        dataloader: DataLoader providing input data for activation collection.

    Returns:
        tuple: (combined_p_value, p_values_per_layer) from the CSH test.
    """
    return csh_sp_dataloader(base_model, ft_model, dataloader)
def hook(m, inp, op, feats, name):
    """Forward hook that records a module's output activations.

    Args:
        m: The hooked module (unused).
        inp: The module's inputs (unused).
        op: The module's output tensor.
        feats: Dict mapping a key to a list of captured tensors.
        name: Key under which the activation is appended.
    """
    # Detach from the autograd graph and move to host memory before storing,
    # so captured activations do not pin GPU memory or the graph.
    captured = op.detach().cpu()
    feats[name].append(captured)
def hook_in(m, inp, op, feats, name):
    """Forward hook that records a module's input activations.

    Args:
        m: The hooked module (unused).
        inp: Tuple of inputs passed to the module; only the first is kept.
        op: The module's output (unused).
        feats: Dict mapping a key to a list of captured tensors.
        name: Key under which the activation is appended.
    """
    # Hooks receive the positional inputs as a tuple; keep the first one,
    # detached and on CPU so nothing holds the graph or device memory.
    first_input = inp[0]
    feats[name].append(first_input.detach().cpu())
def csh_sp_dataloader_block(base_model, ft_model, dataloader, i):
    """
    Apply the CSH test to a single transformer block.

    Captures ``down_proj`` output activations from block ``i`` of both models
    on the same data, greedily matches each base neuron to its most-similar
    ft neuron by cosine similarity, then tests whether the matching preserves
    the original neuron ordering via Spearman correlation.

    Args:
        base_model: Base model to compare.
        ft_model: Fine-tuned or target model to compare against the base model.
        dataloader: DataLoader providing input data for activation collection.
        i: Block index to analyze.

    Returns:
        float: p-value indicating the statistical similarity between models at block i.
    """
    feats = defaultdict(list)
    # Keep the hook handles so they can be removed afterwards. The previous
    # version discarded them, so repeated calls accumulated hooks on the same
    # layers and captured duplicate activations.
    base_handle = base_model.model.layers[i].mlp.down_proj.register_forward_hook(
        lambda *args: hook(*args, feats, "base")
    )
    ft_handle = ft_model.model.layers[i].mlp.down_proj.register_forward_hook(
        lambda *args: hook(*args, feats, "ft")
    )
    try:
        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        base_handle.remove()
        ft_handle.remove()
    # Stack per-batch activations and flatten to (features, samples).
    base_mat = torch.vstack(feats["base"])
    ft_mat = torch.vstack(feats["ft"])
    base_mat = base_mat.view(-1, base_mat.shape[-1]).T
    ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T
    # Greedy match: each base neuron pairs with the most cosine-similar ft neuron.
    matched = torch.argmax(cossim(base_mat, ft_mat), axis=-1)
    orig = torch.arange(len(matched))
    # If the models are related, the matched indices track the identity ordering,
    # which Spearman correlation against arange detects.
    cor, pvalue = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
    return pvalue
def csh_sp_dataloader(base_model, ft_model, dataloader, n_blocks=32):
    """
    Apply the CSH test across all model blocks using activations from a dataloader.

    Performs the Chi-Squared Hypothesis test by:
    1. Collecting ``up_proj`` activations from both models on the same inputs.
    2. Computing an optimal neuron assignment per block (linear assignment on
       cosine similarity).
    3. Calculating the Spearman correlation p-value between matched indices
       and the identity ordering.
    4. Combining per-block p-values with Fisher's method.

    Args:
        base_model: Base model to compare.
        ft_model: Fine-tuned or target model to compare against the base model.
        dataloader: DataLoader providing input data for activation collection.
        n_blocks: Number of transformer blocks to analyze (default: 32).

    Returns:
        tuple: (combined_p_value, list_of_p_values_per_layer)
    """
    feats = defaultdict(list)
    # Keep the hook handles so they can be removed once activations are
    # collected. The previous version stored only the lambdas and leaked the
    # hooks, so repeated calls registered duplicates on every layer.
    handles = []
    for i in range(n_blocks):
        layer = str(i)
        # Bind `layer` as a default so each closure keeps its own block index.
        handles.append(
            base_model.model.layers[i].mlp.up_proj.register_forward_hook(
                lambda m, inp, op, layer=layer: hook(m, inp, op, feats, "base_" + layer)
            )
        )
        handles.append(
            ft_model.model.layers[i].mlp.up_proj.register_forward_hook(
                lambda m, inp, op, layer=layer: hook(m, inp, op, feats, "ft_" + layer)
            )
        )
    try:
        evaluate(base_model, dataloader)
        evaluate(ft_model, dataloader)
    finally:
        for handle in handles:
            handle.remove()
    chi_squared = 0.0
    p_values = []
    count = 0
    for i in range(n_blocks):
        # Flatten (batches, seq, hidden) activations to (hidden, samples).
        base_mat = torch.vstack(feats["base_" + str(i)])
        ft_mat = torch.vstack(feats["ft_" + str(i)])
        base_mat = base_mat.reshape(-1, base_mat.shape[-1]).T
        ft_mat = ft_mat.reshape(-1, ft_mat.shape[-1]).T
        # Optimal one-to-one neuron assignment maximizing total cosine similarity.
        matched = LAP(
            cossim(base_mat.type(torch.float64), ft_mat.type(torch.float64)), maximize=True
        )
        matched = matched[1]
        orig = torch.arange(len(matched))
        cor, temp = scipy.stats.spearmanr(matched.tolist(), orig.tolist())
        # Fisher's method: sum -2*ln(p) over valid per-block p-values; NaN
        # p-values (e.g. constant rankings) are excluded from the combination.
        if not np.isnan(temp):
            chi_squared -= 2 * np.log(temp)
            count += 1
        print(i, temp)  # per-block diagnostic output
        p_values.append(temp)
    # Under H0, the Fisher statistic is chi-squared with 2*count dof.
    p_value = chi2.sf(chi_squared, df=2 * count)
    return p_value, p_values