import nltk
import numpy as np
import pandas as pd
import torch as ch
from numpy.typing import NDArray
from spacy.lang.en import English
from tqdm.auto import tqdm
from typing import Any, List, Optional, Tuple
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq



def split_text(text: str, split_by: str) -> Tuple[List[str], List[str], List[int]]:
    """Split a response into parts and return the parts, separators, and start indices."""
    parts = []
    separators = []
    start_indices = []

    if split_by == "sentence":
        # Tokenize line by line so that sentences never span newlines
        for line in text.splitlines():
            parts.extend(nltk.sent_tokenize(line))
    elif split_by == "word":
        tokenizer = English().tokenizer
        parts = [token.text for token in tokenizer(text)]
    else:
        raise ValueError(f"Cannot split response by '{split_by}'")

    # Recover each part's separator (the text preceding it) and its start
    # index within the original text
    cur_start = 0
    for part in parts:
        cur_end = text.find(part, cur_start)
        separators.append(text[cur_start:cur_end])
        start_indices.append(cur_end)
        cur_start = cur_end + len(part)

    return parts, separators, start_indices
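
# Illustrative example of `split_text` (assumes nltk's "punkt" sentence
# tokenizer data is available; outputs worked out by hand):
#
#     parts, separators, start_indices = split_text("Hi there. Bye.", "sentence")
#     parts          == ["Hi there.", "Bye."]
#     separators     == ["", " "]
#     start_indices  == [0, 10]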


def highlight_word_indices(words, indices, separators, color: bool):
    formatted_words = []

    if color:
        CYAN = "\033[36m"  # ANSI escape code for cyan
        RESET = "\033[0m"  # Reset color to default
    else:
        CYAN = ""
        RESET = ""

    for word, idx in zip(words, indices):
        # Prefix each word with its (optionally colored) index
        formatted_words.append(f"{CYAN}[{idx}]{RESET}{word}")

    result = "".join(sep + word for sep, word in zip(separators, formatted_words))
    return result
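
# Illustrative example of `highlight_word_indices` with color=False (so no
# ANSI codes appear in the output):
#
#     highlight_word_indices(["Hi", "there"], [0, 1], ["", " "], color=False)
#     == "[0]Hi [1]there"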


def _create_mask(num_sources, alpha, seed):
    # Sample a boolean mask that keeps each source independently with
    # probability alpha
    random = np.random.RandomState(seed)
    p = [1 - alpha, alpha]
    return random.choice([False, True], size=num_sources, p=p)
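
# For example, _create_mask(5, 0.5, seed=0) returns a length-5 boolean vector
# in which each source is kept (True) with probability 0.5.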


def _create_regression_dataset(
    num_masks, num_sources, get_prompt_ids, response_ids, alpha, base_seed=0
):
    masks = np.zeros((num_masks, num_sources), dtype=bool)
    data_dict = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
    }
    for seed in range(num_masks):
        mask = _create_mask(num_sources, alpha, seed + base_seed)
        masks[seed] = mask
        prompt_ids = get_prompt_ids(mask=mask)
        input_ids = prompt_ids + response_ids
        data_dict["input_ids"].append(input_ids)
        data_dict["attention_mask"].append([1] * len(input_ids))
        data_dict["labels"].append([-100] * len(prompt_ids) + response_ids)
    return masks, Dataset.from_dict(data_dict)
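
# `get_prompt_ids` is expected to be a callable that maps a boolean mask over
# sources to the token ids of the prompt containing only the kept sources.
# A minimal sketch, assuming a hypothetical partitioner/tokenizer/template:
#
#     def get_prompt_ids(mask):
#         context = partitioner.get_context(mask)  # hypothetical helper
#         prompt = prompt_template.format(context=context)
#         return tokenizer(prompt)["input_ids"]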


def _compute_logit_probs(logits, labels):
    # For each token, compute log(p) - log(1 - p), where p is the model's
    # probability of the correct token: this equals the correct token's logit
    # minus the logsumexp of all other logits
    batch_size, seq_length = labels.shape
    # [num_tokens x vocab_size]
    reshaped_logits = logits.reshape(batch_size * seq_length, -1)
    reshaped_labels = labels.reshape(batch_size * seq_length)
    correct_logits = reshaped_logits.gather(-1, reshaped_labels[:, None])[:, 0]
    # Mask out the correct token before taking logsumexp over the rest
    cloned_logits = reshaped_logits.clone()
    cloned_logits.scatter_(-1, reshaped_labels[:, None], -ch.inf)
    other_logits = cloned_logits.logsumexp(dim=-1)
    reshaped_outputs = correct_logits - other_logits
    return reshaped_outputs.reshape(batch_size, seq_length)
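
# Sanity check (illustrative): with logits [2.0, 0.0] over a two-token
# vocabulary and correct label 0, p = e^2 / (e^2 + e^0), so
# log(p / (1 - p)) = 2.0 - 0.0 = 2.0:
#
#     logits = ch.tensor([[[2.0, 0.0]]])    # [batch=1, seq=1, vocab=2]
#     labels = ch.tensor([[0]])
#     _compute_logit_probs(logits, labels)  # tensor([[2.]])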


def _make_loader(dataset, tokenizer, batch_size):
    collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding="longest")
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
    )
    return loader


def _get_response_logit_probs(dataset, model, tokenizer, response_length, batch_size):
    # With right padding, the response tokens would not be aligned across
    # examples in a batch, so batched inference requires left padding
    if batch_size > 1:
        assert tokenizer.padding_side == "left", "Tokenizer must use left padding"
    loader = _make_loader(dataset, tokenizer, batch_size)
    logit_probs = ch.zeros((len(dataset), response_length), device=model.device)

    start_index = 0
    for batch in tqdm(loader):
        batch = {key: value.to(model.device) for key, value in batch.items()}
        with ch.no_grad(), ch.cuda.amp.autocast():
            output = model(**batch)
        # The logits at position i predict token i + 1, hence the shift by one
        logits = output.logits[:, -(response_length + 1) : -1]
        labels = batch["labels"][:, -response_length:]
        cur_batch_size, _ = labels.shape  # the last batch may be smaller
        cur_logit_probs = _compute_logit_probs(logits, labels)
        logit_probs[start_index : start_index + cur_batch_size] = cur_logit_probs
        start_index += cur_batch_size

    return logit_probs.cpu().numpy()


def get_masks_and_logit_probs(
    model,
    tokenizer,
    num_masks,
    num_sources,
    get_prompt_ids,
    response_ids,
    ablation_keep_prob,
    batch_size,
    base_seed=0,
):
    masks, dataset = _create_regression_dataset(
        num_masks,
        num_sources,
        get_prompt_ids,
        response_ids,
        ablation_keep_prob,
        base_seed=base_seed,
    )
    logit_probs = _get_response_logit_probs(
        dataset, model, tokenizer, len(response_ids), batch_size
    )
    return masks, logit_probs.astype(np.float32)
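
# Usage sketch (hypothetical model and tokenizer; `get_prompt_ids` as
# sketched above). For batch_size > 1, set tokenizer.padding_side = "left"
# so that responses stay aligned across examples:
#
#     masks, logit_probs = get_masks_and_logit_probs(
#         model, tokenizer,
#         num_masks=64, num_sources=10,
#         get_prompt_ids=get_prompt_ids,
#         response_ids=response_ids,
#         ablation_keep_prob=0.5,
#         batch_size=8,
#     )
#     # masks: [64 x 10] bool; logit_probs: [64 x len(response_ids)] float32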


def aggregate_logit_probs(logit_probs, output_type="logit_prob"):
    """Compute sequence-level outputs from token-level logit-probabilities."""
    logit_probs = ch.tensor(logit_probs)
    # Each token's log-probability is log(sigmoid(logit_prob)); summing over
    # tokens yields the log-probability of the full response
    log_probs = ch.nn.functional.logsigmoid(logit_probs).sum(dim=1)
    if output_type == "log_prob":
        return log_probs.numpy()
    elif output_type == "logit_prob":
        # Convert the sequence log-probability log(p) into log(p / (1 - p))
        log_1mprobs = ch.log1p(-ch.exp(log_probs))
        return (log_probs - log_1mprobs).numpy()
    elif output_type == "total_token_logit_prob":
        # Average per-token logit-probability across the response
        return logit_probs.mean(dim=1).numpy()
    else:
        raise ValueError(f"Cannot aggregate log probs for output type '{output_type}'")
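
# Illustrative example: two tokens, each with logit-probability 0.0 (i.e.,
# per-token probability 0.5), so the sequence probability is 0.25:
#
#     aggregate_logit_probs([[0.0, 0.0]], "log_prob")    # ~ log(0.25)
#     aggregate_logit_probs([[0.0, 0.0]], "logit_prob")  # ~ log(0.25 / 0.75)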


def _color_scale(val, max_val):
    start_color = (255, 255, 255)  # white
    end_color = (80, 180, 80)  # green
    if val <= 0:
        return f"background-color: rgb{start_color}"
    elif val >= max_val:
        return f"background-color: rgb{end_color}"
    else:
        # Linearly interpolate between white and green, rounding to integer
        # RGB values for valid CSS
        fraction = val / max_val
        interpolated_color = tuple(
            round(start_color[i] + (end_color[i] - start_color[i]) * fraction)
            for i in range(3)
        )
        return f"background-color: rgb{interpolated_color}"


def _apply_color_scale(df):
    # A score of np.log(10) means that ablating this source causes the
    # logit-probability to drop by np.log(10), which (roughly) corresponds to
    # a 10x decrease in probability.
    max_val = max([df["Score"].max(), np.log(10)])
    return df.style.applymap(lambda val: _color_scale(val, max_val), subset=["Score"])


def get_attributions_df(
    attributions: NDArray[Any],
    context_partitioner,
    top_k: Optional[int] = None,
) -> Any:
    order = attributions.argsort()[::-1]
    selected_attributions = []
    selected_sources = []

    if top_k is not None:
        order = order[:top_k]

    for i in order:
        selected_attributions.append(attributions[i])
        selected_sources.append(context_partitioner.get_source(i))

    df = pd.DataFrame.from_dict(
        {"Score": selected_attributions, "Source": selected_sources}
    )
    df = _apply_color_scale(df).format(precision=3)
    return df
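
# Usage sketch (hypothetical attribution scores and partitioner): show the
# three sources with the largest scores as a styled DataFrame:
#
#     df = get_attributions_df(attributions, partitioner, top_k=3)
#     df  # render in a notebook to see the color scale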


# Llama 3's char_to_token is buggy (the start and end chars for a given token
# are often the same), so we implement our own
def char_to_token(output_tokens, char_index):
    """Return the index of the token whose character span contains char_index."""
    num_tokens = len(output_tokens["input_ids"])
    for i in range(num_tokens - 1):
        if char_index < output_tokens.token_to_chars(i + 1).start:
            return i
    # char_index falls within the last token (this also avoids an unbound `i`
    # when there is only a single token)
    return num_tokens - 1
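
# Illustrative example, assuming `output_tokens` comes from calling a Hugging
# Face *fast* tokenizer (needed for token_to_chars) on the response text:
#
#     output_tokens = tokenizer(response_text)
#     char_to_token(output_tokens, 0)  # index of the token covering char 0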