"""
Implementation of Jensen-Shannon Divergence (JSD) for comparing language model outputs.
This module provides functions to compute the Jensen-Shannon Divergence between
probability distributions output by two language models, measuring their similarity
in output space rather than parameter space.
"""
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tracing.utils.evaluate import (
prepare_hf_dataset,
prepare_hf_dataloader,
)
def statistic(base_model, ft_model, dataloader, device="cuda"):
"""
Compute Jensen-Shannon Divergence between outputs of two language models.
Args:
base_model: Base model to compare
ft_model: Fine-tuned or target model to compare against the base model
dataloader: DataLoader providing input data for model evaluation
device: Device to run the computation on (default: "cuda")
Returns:
float: Sum of Jensen-Shannon Divergence values across all batches
"""
return compute_jsd(base_model, ft_model, dataloader, device)
def statistic_stable(base_model, ft_model, dataloader, device="cuda"):
"""
Compute numerically stable Jensen-Shannon Divergence between outputs of two models.
This version handles potential numerical issues better than the standard version.
Args:
base_model: Base model to compare
ft_model: Fine-tuned or target model to compare against the base model
dataloader: DataLoader providing input data for model evaluation
device: Device to run the computation on (default: "cuda")
Returns:
float: Sum of Jensen-Shannon Divergence values across all batches
"""
return compute_jsd_stable(base_model, ft_model, dataloader, device)
def compute_jsd(base_model, ft_model, dataloader, device="cuda"):
"""
Compute Jensen-Shannon Divergence between two models using softmax outputs.
Processes each batch in the dataloader and computes JSD between the models'
probability distributions over vocabulary tokens. Handles potential vocabulary
size differences by truncating to a common size (32000 tokens).
Args:
base_model: Base model to compare
ft_model: Fine-tuned or target model to compare against the base model
dataloader: DataLoader providing input data for model evaluation
device: Device to run the computation on (default: "cuda")
Returns:
float: Sum of Jensen-Shannon Divergence values across all batches
"""
jsds = []
base_model.to(device)
ft_model.to(device)
with torch.no_grad():
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs_base = base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
outputs_ft = ft_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
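            # Drop the singleton batch dimension (the dataloader in __main__ uses batch
            # size 1), leaving logits of shape (seq_len, vocab_size).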
logits_base = outputs_base.logits.squeeze()
logits_ft = outputs_ft.logits.squeeze()
softmax_base = torch.softmax(logits_base, dim=-1)
softmax_ft = torch.softmax(logits_ft, dim=-1)
            # Truncate both distributions to the first 32000 entries (the Llama vocabulary
            # size) so models with extended vocabularies remain comparable; note the
            # truncated rows may no longer sum exactly to 1.
softmax_base = softmax_base[:, :32000]
softmax_ft = softmax_ft[:, :32000]
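            # Mixture distribution and JSD; F.kl_div expects log-probabilities as its
            # first argument and computes KL(target || input).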
m = 0.5 * (softmax_base + softmax_ft)
            jsd = 0.5 * (F.kl_div(m.log(), softmax_base, reduction="batchmean")
                         + F.kl_div(m.log(), softmax_ft, reduction="batchmean"))
jsds.append(jsd.item())
base_model.to("cpu")
ft_model.to("cpu")
return sum(jsds)
def compute_jsd_stable(base_model, ft_model, dataloader, device="cuda"):
"""
Compute numerically stable Jensen-Shannon Divergence between two models.
A more robust implementation that:
1. Handles vocabulary size mismatches by truncating to the minimum size
2. Uses log-space calculations to avoid numerical underflow
3. Computes JSD directly from log probabilities for better stability
Args:
base_model: Base model to compare
ft_model: Fine-tuned or target model to compare against the base model
dataloader: DataLoader providing input data for model evaluation
device: Device to run the computation on (default: "cuda")
Returns:
float: Sum of Jensen-Shannon Divergence values across all batches
"""
jsds = []
base_model.to(device)
ft_model.to(device)
with torch.no_grad():
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs_base = base_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
outputs_ft = ft_model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
logits_base = outputs_base.logits.squeeze()
logits_ft = outputs_ft.logits.squeeze()
# Determine the minimum vocabulary size between the two models
min_vocab_size = min(logits_base.size(-1), logits_ft.size(-1))
# Truncate the logits to the minimum vocabulary size
logits_base = logits_base[..., :min_vocab_size]
logits_ft = logits_ft[..., :min_vocab_size]
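            # log_softmax avoids the underflow that log(softmax(...)) can hit on
            # low-probability tokens.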
log_probs_base = F.log_softmax(logits_base, dim=-1)
log_probs_ft = F.log_softmax(logits_ft, dim=-1)
m = 0.5 * (log_probs_base.exp() + log_probs_ft.exp())
log_m = m.log()
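            # KL(P || M) = sum_i p_i * (log p_i - log m_i), evaluated per token position.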
            kl_div_base_m = (log_probs_base.exp() * (log_probs_base - log_m)).sum(dim=-1)
            kl_div_ft_m = (log_probs_ft.exp() * (log_probs_ft - log_m)).sum(dim=-1)
jsd = 0.5 * (kl_div_base_m + kl_div_ft_m).mean()
jsds.append(jsd.item())
base_model.to("cpu")
ft_model.to("cpu")
return sum(jsds)
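
# Illustrative sketch (not part of the original pipeline): a reference JSD for two
# explicit probability tensors, useful as a sanity check for the functions above.
# The helper name `_jsd_reference` is hypothetical.
def _jsd_reference(p, q, eps=1e-12):
    """Return the Jensen-Shannon Divergence of probability tensors p and q.

    Both tensors are assumed to hold valid distributions along the last
    dimension; `eps` guards the logarithms against zero entries.
    """
    m = 0.5 * (p + q)
    kl_pm = (p * ((p + eps).log() - (m + eps).log())).sum(dim=-1)
    kl_qm = (q * ((q + eps).log() - (m + eps).log())).sum(dim=-1)
    return 0.5 * (kl_pm + kl_qm)
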
if __name__ == "__main__":
base_model_name = "LLM360/Amber" # 'openlm-research/open_llama_7b' # 'lmsys/vicuna-7b-v1.5'
ft_model_name = "LLM360/AmberChat" # 'openlm-research/open_llama_7b_v2' # 'LLM360/Amber' # "lmsys/vicuna-7b-v1.1"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
ft_model = AutoModelForCausalLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
# dataset = load_generated_datasets(base_model_name, ft_model_name, 512, base_tokenizer, ["text"])
# dataloader = prepare_hf_dataloader(dataset, 1)
dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
dataloader = prepare_hf_dataloader(dataset, 1)
print(statistic(base_model, ft_model, dataloader))
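    # The numerically stable variant runs the same way, e.g.:
    # print(statistic_stable(base_model, ft_model, dataloader))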