"""
Implementation of Jensen-Shannon Divergence (JSD) for comparing language model outputs.

This module provides functions to compute the Jensen-Shannon Divergence between
probability distributions output by two language models, measuring their similarity
in output space rather than parameter space.
"""

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from tracing.utils.evaluate import (
    prepare_hf_dataset,
    prepare_hf_dataloader,
)
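

# For reference, the Jensen-Shannon Divergence between distributions P and Q is
#     JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M),  with M = 0.5 * (P + Q).
# The helper below is an illustrative sketch only (it is not used by the functions in
# this module) and assumes `p` and `q` are strictly positive probability tensors whose
# last dimension sums to 1.
def _jsd_reference(p, q):
    """Illustrative per-row JSD between probability tensors `p` and `q`."""
    m = 0.5 * (p + q)
    kl_pm = (p * (p.log() - m.log())).sum(dim=-1)
    kl_qm = (q * (q.log() - m.log())).sum(dim=-1)
    return 0.5 * (kl_pm + kl_qm)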


def statistic(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute Jensen-Shannon Divergence between outputs of two language models.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    return compute_jsd(base_model, ft_model, dataloader, device)


def statistic_stable(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute numerically stable Jensen-Shannon Divergence between outputs of two models.

    This version works in log space, avoiding the numerical underflow issues of the standard version.

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    return compute_jsd_stable(base_model, ft_model, dataloader, device)


def compute_jsd(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute Jensen-Shannon Divergence between two models using softmax outputs.

    Processes each batch in the dataloader and computes JSD between the models'
    probability distributions over vocabulary tokens. Handles potential vocabulary
    size differences by truncating to a common size (32000 tokens).

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    jsds = []

    base_model.to(device)
    ft_model.to(device)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs_base = base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            outputs_ft = ft_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            # The dataloader is expected to use batch size 1, so squeeze the batch dimension
            logits_base = outputs_base.logits.squeeze()
            logits_ft = outputs_ft.logits.squeeze()

            softmax_base = torch.softmax(logits_base, dim=-1)
            softmax_ft = torch.softmax(logits_ft, dim=-1)

            # Truncate the softmax outputs to the first 32000 vocabulary entries (a common
            # LLaMA-family vocabulary size) so models with padded vocabularies stay comparable
            softmax_base = softmax_base[:, :32000]
            softmax_ft = softmax_ft[:, :32000]

            m = 0.5 * (softmax_base + softmax_ft)
            # F.kl_div expects log-probabilities as input and probabilities as target
            jsd = 0.5 * (F.kl_div(m.log(), softmax_base, reduction="batchmean")
                         + F.kl_div(m.log(), softmax_ft, reduction="batchmean"))

            jsds.append(jsd.item())

    base_model.to("cpu")
    ft_model.to("cpu")
    return sum(jsds)


def compute_jsd_stable(base_model, ft_model, dataloader, device="cuda"):
    """
    Compute numerically stable Jensen-Shannon Divergence between two models.

    A more robust implementation that:
    1. Handles vocabulary size mismatches by truncating to the minimum size
    2. Uses log-space calculations to avoid numerical underflow
    3. Computes JSD directly from log probabilities for better stability

    Args:
        base_model: Base model to compare
        ft_model: Fine-tuned or target model to compare against the base model
        dataloader: DataLoader providing input data for model evaluation
        device: Device to run the computation on (default: "cuda")

    Returns:
        float: Sum of Jensen-Shannon Divergence values across all batches
    """
    jsds = []

    base_model.to(device)
    ft_model.to(device)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs_base = base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            outputs_ft = ft_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            # As above, assume batch size 1 and squeeze the batch dimension
            logits_base = outputs_base.logits.squeeze()
            logits_ft = outputs_ft.logits.squeeze()

            # Determine the minimum vocabulary size between the two models
            min_vocab_size = min(logits_base.size(-1), logits_ft.size(-1))

            # Truncate the logits to the minimum vocabulary size
            logits_base = logits_base[..., :min_vocab_size]
            logits_ft = logits_ft[..., :min_vocab_size]

            log_probs_base = F.log_softmax(logits_base, dim=-1)
            log_probs_ft = F.log_softmax(logits_ft, dim=-1)

            m = 0.5 * (log_probs_base.exp() + log_probs_ft.exp())
            log_m = m.log()

            # KL(P || M) = sum_i P_i * (log P_i - log M_i), evaluated from log-probabilities
            kl_div_base_m = (log_probs_base.exp() * (log_probs_base - log_m)).sum(dim=-1)
            kl_div_ft_m = (log_probs_ft.exp() * (log_probs_ft - log_m)).sum(dim=-1)

            jsd = 0.5 * (kl_div_base_m + kl_div_ft_m).mean()
            jsds.append(jsd.item())

    base_model.to("cpu")
    ft_model.to("cpu")

    return sum(jsds)


if __name__ == "__main__":

    base_model_name = "LLM360/Amber"  # 'openlm-research/open_llama_7b' # 'lmsys/vicuna-7b-v1.5'
    ft_model_name = "LLM360/AmberChat"  # 'openlm-research/open_llama_7b_v2' # 'LLM360/Amber' # "lmsys/vicuna-7b-v1.1"

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
    ft_model = AutoModelForCausalLM.from_pretrained(ft_model_name, torch_dtype=torch.bfloat16)
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)

    # dataset = load_generated_datasets(base_model_name, ft_model_name, 512, base_tokenizer, ["text"])
    # dataloader = prepare_hf_dataloader(dataset, 1)

    dataset = prepare_hf_dataset("dlwh/wikitext_103_detokenized", 512, base_tokenizer)
    dataloader = prepare_hf_dataloader(dataset, 1)

    print(statistic(base_model, ft_model, dataloader))
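    # The stable variant takes the same arguments and can be run as an additional
    # check (assuming both models fit in memory at once), e.g.:
    # print(statistic_stable(base_model, ft_model, dataloader))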