""" | |
Implementation of Jensen-Shannon Divergence (JSD) for comparing language model outputs. | |
This module provides functions to compute the Jensen-Shannon Divergence between | |
probability distributions output by two language models, measuring their similarity | |
in output space rather than parameter space. | |
""" | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from tracing.utils.evaluate import ( | |
prepare_hf_dataset, | |
prepare_hf_dataloader, | |
) | |


def statistic(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute Jensen-Shannon Divergence between outputs of two language models.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    return compute_jsd(base_model, ft_model, dataloader, device)


def statistic_stable(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute a numerically stable Jensen-Shannon Divergence between outputs of two models.

    This version handles potential numerical issues better than the standard version.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    return compute_jsd_stable(base_model, ft_model, dataloader, device)


def compute_jsd(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute Jensen-Shannon Divergence between two models using softmax outputs.

    Processes each batch in the dataloader and computes JSD between the models'
    probability distributions over vocabulary tokens. Handles potential vocabulary
    size differences by truncating to a common size (32000 tokens).

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    jsds = []
    base_model.to(device)
    ft_model.to(device)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs_base = base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            outputs_ft = ft_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            logits_base = outputs_base.logits.squeeze()
            logits_ft = outputs_ft.logits.squeeze()

            softmax_base = torch.softmax(logits_base, dim=-1)
            softmax_ft = torch.softmax(logits_ft, dim=-1)

            # Truncate the softmax outputs to the first 32000 dimensions
            softmax_base = softmax_base[:, :32000]
            softmax_ft = softmax_ft[:, :32000]

            m = 0.5 * (softmax_base + softmax_ft)
            # JSD(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = 0.5 * (P + Q).
            # F.kl_div expects log-probabilities first and probabilities second;
            # reduction="batchmean" sums over the vocabulary and averages over token
            # positions, which matches the mathematical definition of KL divergence.
            jsd = 0.5 * (
                F.kl_div(m.log(), softmax_base, reduction="batchmean")
                + F.kl_div(m.log(), softmax_ft, reduction="batchmean")
            )
            jsds.append(jsd.item())

    base_model.to("cpu")
    ft_model.to("cpu")
    return sum(jsds)
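

# Illustrative sketch (not part of the original module): a minimal reference
# implementation of the JSD formula used above, applied to two explicit probability
# tensors. The helper name `_jsd_reference` is introduced here purely as a
# hand-checkable sanity check for the batched code paths; it assumes strictly
# positive probabilities so the logarithms are finite.
def _jsd_reference(p, q):
    """Return per-row JSD(P || Q) for probability tensors of shape (..., vocab)."""
    m = 0.5 * (p + q)
    kl_pm = (p * (p.log() - m.log())).sum(dim=-1)
    kl_qm = (q * (q.log() - m.log())).sum(dim=-1)
    return 0.5 * (kl_pm + kl_qm)


# Example: _jsd_reference(torch.tensor([0.5, 0.5]), torch.tensor([0.9, 0.1])) is
# approximately 0.10 nats.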


def compute_jsd_stable(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute a numerically stable Jensen-Shannon Divergence between two models.

    A more robust implementation that:
    1. Handles vocabulary size mismatches by truncating to the minimum size
    2. Uses log-space calculations to avoid numerical underflow
    3. Computes JSD directly from log probabilities for better stability

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    jsds = []
    base_model.to(device)
    ft_model.to(device)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs_base = base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            outputs_ft = ft_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            logits_base = outputs_base.logits.squeeze()
            logits_ft = outputs_ft.logits.squeeze()

            # Truncate the logits to the minimum vocabulary size shared by both models
            min_vocab_size = min(logits_base.size(-1), logits_ft.size(-1))
            logits_base = logits_base[..., :min_vocab_size]
            logits_ft = logits_ft[..., :min_vocab_size]

            log_probs_base = F.log_softmax(logits_base, dim=-1)
            log_probs_ft = F.log_softmax(logits_ft, dim=-1)

            # Mixture distribution M = 0.5 * (P + Q), taken back to log space for the KL terms
            m = 0.5 * (log_probs_base.exp() + log_probs_ft.exp())
            log_m = m.log()

            # KL(P || M) = sum_i p_i * (log p_i - log m_i); the log-ratios must be weighted
            # by the probabilities, otherwise the sum is not a KL divergence.
            kl_div_base_m = (log_probs_base.exp() * (log_probs_base - log_m)).sum(dim=-1)
            kl_div_ft_m = (log_probs_ft.exp() * (log_probs_ft - log_m)).sum(dim=-1)

            # Average the per-token JSD over positions in the batch
            jsd = 0.5 * (kl_div_base_m + kl_div_ft_m).mean()
            jsds.append(jsd.item())

    base_model.to("cpu")
    ft_model.to("cpu")
    return sum(jsds)
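

# A minimal sketch under stated assumptions (not part of the original module): the
# mixture M can also be formed entirely in log space with torch.logsumexp, avoiding
# the exp-then-log round trip used above. The helper name `_jsd_from_log_probs` and
# the use of `math.log(2.0)` are illustrative choices introduced here.
import math


def _jsd_from_log_probs(log_p, log_q):
    """Return per-token JSD given log-probability tensors of shape (..., vocab)."""
    # log M = logsumexp([log P, log Q], dim=0) - log 2, computed without forming P + Q directly
    log_m = torch.logsumexp(torch.stack([log_p, log_q], dim=0), dim=0) - math.log(2.0)
    kl_pm = (log_p.exp() * (log_p - log_m)).sum(dim=-1)
    kl_qm = (log_q.exp() * (log_q - log_m)).sum(dim=-1)
    return 0.5 * (kl_pm + kl_qm)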


if __name__ == "__main__":
    base_model_name = "LLM360/Amber"  # 'openlm-research/open_llama_7b' # 'lmsys/vicuna-7b-v1.5'
    ft_model_name = "LLM360/AmberChat"  # 'openlm-research/open_llama_7b_v2' # 'LLM360/Amber' # "lmsys/vicuna-7b-v1.1"

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
    ft_model = AutoModelForCausalLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)

    # dataset = load_generated_datasets(base_model_name, ft_model_name, 512, base_tokenizer, ["text"])
    # dataloader = prepare_hf_dataloader(dataset, 1)

    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
    dataloader = prepare_hf_dataloader(dataset, 1)

    print(statistic(base_model, ft_model, dataloader))
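
    # Optional follow-up (an addition, not in the original script): report the
    # numerically stable variant alongside the standard statistic for comparison.
    print(statistic_stable(base_model, ft_model, dataloader))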