import numpy as np
import pandas as pd
import torch as ch
from numpy.typing import NDArray
from typing import Dict, Any, Optional, List
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from .context_partitioner import BaseContextPartitioner, SimpleContextPartitioner
from .solver import BaseSolver, LassoRegression
from .utils import (
get_masks_and_logit_probs,
aggregate_logit_probs,
split_text,
highlight_word_indices,
get_attributions_df,
char_to_token,
)
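# Default generation settings (greedy decoding, up to 512 new tokens) and the
# prompt template; the {context} and {query} placeholders are filled in by
# _get_prompt_ids.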
DEFAULT_GENERATE_KWARGS = {"max_new_tokens": 512, "do_sample": False}
DEFAULT_PROMPT_TEMPLATE = "Context: {context}\n\nQuery: {query}"
class ContextCiter:
def __init__(
self,
model: Any,
tokenizer: Any,
context: str,
query: str,
source_type: str = "sentence",
generate_kwargs: Optional[Dict[str, Any]] = None,
num_ablations: int = 64,
ablation_keep_prob: float = 0.5,
batch_size: int = 1,
solver: Optional[BaseSolver] = None,
prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
partitioner: Optional[BaseContextPartitioner] = None,
) -> None:
"""
        Initializes a new ContextCiter instance, which uses the given model and
        tokenizer to generate a response to a query from the provided context
        and to attribute parts of that response back to sources in the context.
Arguments:
model (Any):
The model to apply ContextCite to (a HuggingFace
ModelForCausalLM).
tokenizer (Any):
The tokenizer associated with the provided model.
context (str):
                The context provided to the model.
query (str):
The query to pose to the model.
source_type (str, optional):
The type of source to partition the context into. Defaults to
"sentence", can also be "word".
generate_kwargs (Optional[Dict[str, Any]], optional):
Additional keyword arguments to pass to the model's generate
method.
num_ablations (int, optional):
The number of ablations used to train the surrogate model.
Defaults to 64.
ablation_keep_prob (float, optional):
The probability of keeping a source when ablating the context.
Defaults to 0.5.
batch_size (int, optional):
The batch size used when performing inference using ablated
contexts. Defaults to 1.
            solver (Optional[BaseSolver], optional):
The solver to use to compute the linear surrogate model. Lasso
regression is used by default.
prompt_template (str, optional):
A template string used to create the prompt from the context
and query.
partitioner (Optional[BaseContextPartitioner], optional):
A custom partitioner to split the context into sources. This
will override "source_type" if specified.
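        Example (an illustrative sketch; the model name below is an assumption,
        any HuggingFace causal LM with a chat template should work):

            >>> from transformers import AutoModelForCausalLM, AutoTokenizer
            >>> name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed model choice
            >>> model = AutoModelForCausalLM.from_pretrained(name)
            >>> tokenizer = AutoTokenizer.from_pretrained(name)
            >>> cc = ContextCiter(model, tokenizer, context="...", query="...")
            >>> cc.response  # generates and caches the model's response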
"""
self.model = model
self.tokenizer = tokenizer
if partitioner is None:
self.partitioner = SimpleContextPartitioner(
context, source_type=source_type
)
else:
self.partitioner = partitioner
if self.partitioner.context != context:
raise ValueError("Partitioner context does not match provided context.")
self.query = query
self.generate_kwargs = generate_kwargs or DEFAULT_GENERATE_KWARGS
self.num_ablations = num_ablations
self.ablation_keep_prob = ablation_keep_prob
self.batch_size = batch_size
self.solver = solver or LassoRegression()
self.prompt_template = prompt_template
self._cache = {}
self.logger = logging.getLogger("ContextCite")
self.logger.setLevel(logging.DEBUG) # TODO: change to INFO later
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
context: str,
query: str,
device: str = "cuda",
model_kwargs: Dict[str, Any] = {},
tokenizer_kwargs: Dict[str, Any] = {},
**kwargs: Dict[str, Any],
) -> "ContextCiter":
"""
Load a ContextCiter instance from a pretrained model.
Arguments:
pretrained_model_name_or_path (str):
The name or path of the pretrained model. This can be a local
path or a model name on the HuggingFace model hub.
context (str):
The context provided to the model. The context and query will be
used to construct a prompt for the model, using the prompt template.
query (str):
The query provided to the model. The context and query will be
used to construct a prompt for the model, using the prompt template.
device (str, optional):
The device to use. Defaults to "cuda".
model_kwargs (Dict[str, Any], optional):
Additional keyword arguments to pass to the model's constructor.
tokenizer_kwargs (Dict[str, Any], optional):
Additional keyword arguments to pass to the tokenizer's constructor.
**kwargs (Dict[str, Any], optional):
Additional keyword arguments to pass to the ContextCiter constructor.
Returns:
ContextCiter:
A ContextCiter instance initialized with the provided model,
tokenizer, context, query, and other keyword arguments.
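        Example (an illustrative sketch; the model name is an assumption):

            >>> cc = ContextCiter.from_pretrained(
            ...     "mistralai/Mistral-7B-Instruct-v0.2",
            ...     context="The quick brown fox jumps over the lazy dog.",
            ...     query="What does the fox jump over?",
            ...     device="cuda",
            ... )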
"""
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, **model_kwargs
)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, **tokenizer_kwargs
)
tokenizer.padding_side = "left"
return cls(model, tokenizer, context, query, **kwargs)
def _get_prompt_ids(
self,
mask: Optional[NDArray] = None,
return_prompt: bool = False,
):
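        # Build the (optionally ablated) context, format it together with the
        # query using the prompt template, wrap it in the tokenizer's chat
        # template, and encode it without adding special tokens.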
context = self.partitioner.get_context(mask)
prompt = self.prompt_template.format(context=context, query=self.query)
messages = [{"role": "user", "content": prompt}]
chat_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
chat_prompt_ids = self.tokenizer.encode(chat_prompt, add_special_tokens=False)
if return_prompt:
return chat_prompt_ids, chat_prompt
else:
return chat_prompt_ids
@property
def _response_start(self):
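        # The response starts immediately after the prompt, so its starting
        # token index equals the number of prompt tokens.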
prompt_ids = self._get_prompt_ids()
return len(prompt_ids)
@property
def _output(self):
if self._cache.get("output") is None:
prompt_ids, prompt = self._get_prompt_ids(return_prompt=True)
input_ids = ch.tensor([prompt_ids], device=self.model.device)
output_ids = self.model.generate(input_ids, **self.generate_kwargs)[0]
# We take the original prompt because sometimes encoding and decoding changes it
raw_output = self.tokenizer.decode(output_ids)
prompt_length = len(self.tokenizer.decode(prompt_ids))
self._cache["output"] = prompt + raw_output[prompt_length:]
return self._cache["output"]
@property
def _output_tokens(self):
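        # Tokenization of the full output (prompt + response); used to map
        # between character positions and token positions.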
return self.tokenizer(self._output, add_special_tokens=False)
@property
def _response_ids(self):
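        # Token ids of the generated response (everything after the prompt).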
return self._output_tokens["input_ids"][self._response_start :]
@property
def response(self):
"""
The response generated by the model (excluding the prompt). This
property is cached.
Returns:
str:
The response generated by the model.
"""
output_tokens = self._output_tokens
char_response_start = output_tokens.token_to_chars(self._response_start).start
return self._output[char_response_start:]
@property
    def response_with_indices(self, split_by="word", color=True) -> str:
"""
The response generated by the model, annotated with the starting index
of each part.
Arguments:
split_by (str, optional):
The method to split the response by. Can be "word" or "sentence".
Defaults to "word".
color (bool, optional):
Whether to color the starting index of each part. Defaults to True.
Returns:
str:
The response with the starting index of each part highlighted.
"""
        parts, separators, start_indices = split_text(self.response, split_by)
separated_str = highlight_word_indices(parts, start_indices, separators, color)
return separated_str
@property
def num_sources(self) -> int:
"""
The number of sources within the context. I.e., the number of sources
that the context is partitioned into.
Returns:
int:
The number of sources in the context.
"""
return self.partitioner.num_sources
@property
def sources(self) -> List[str]:
"""
The sources within the context. I.e., the context as a list
where each element is a source.
Returns:
List[str]:
The sources within the context.
"""
return self.partitioner.sources
def _char_range_to_token_range(self, start_index, end_index):
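        # Convert a character range within the response into a token range
        # relative to the start of the response. `offset` is the character
        # position at which the response begins in the full output string.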
output_tokens = self._output_tokens
response_start = self._response_start
offset = output_tokens.token_to_chars(response_start).start
ids_start_index = char_to_token(output_tokens, start_index + offset)
ids_end_index = char_to_token(output_tokens, end_index + offset - 1) + 1
return ids_start_index - response_start, ids_end_index - response_start
def _indices_to_token_indices(self, start_index=None, end_index=None):
if start_index is None or end_index is None:
start_index = 0
end_index = len(self.response)
if not (0 <= start_index < end_index <= len(self.response)):
raise ValueError(
f"Invalid selection range ({start_index}, {end_index}). "
f"Please select any range within (0, {len(self.response)})."
)
return self._char_range_to_token_range(start_index, end_index)
def _compute_masks_and_logit_probs(self) -> None:
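        # Sample random ablation masks over the sources and record the model's
        # logit-scaled probabilities of the response tokens under each ablated
        # context; both are cached so they are only computed once.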
self._cache["reg_masks"], self._cache["reg_logit_probs"] = (
get_masks_and_logit_probs(
self.model,
self.tokenizer,
self.num_ablations,
self.num_sources,
self._get_prompt_ids,
self._response_ids,
self.ablation_keep_prob,
self.batch_size,
)
)
@property
def _masks(self):
if self._cache.get("reg_masks") is None:
self._compute_masks_and_logit_probs()
return self._cache["reg_masks"]
@property
def _logit_probs(self):
if self._cache.get("reg_logit_probs") is None:
self._compute_masks_and_logit_probs()
return self._cache["reg_logit_probs"]
def _get_attributions_for_ids_range(self, ids_start_idx, ids_end_idx) -> tuple:
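        # Aggregate the per-token logit-probabilities over the selected token
        # range, then fit the linear surrogate model mapping ablation masks to
        # this aggregate; its weights serve as the attribution scores.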
outputs = aggregate_logit_probs(self._logit_probs[:, ids_start_idx:ids_end_idx])
num_output_tokens = ids_end_idx - ids_start_idx
weight, bias = self.solver.fit(self._masks, outputs, num_output_tokens)
return weight, bias
def get_attributions(
self,
start_idx: Optional[int] = None,
end_idx: Optional[int] = None,
as_dataframe: bool = False,
top_k: Optional[int] = None,
verbose: bool = True,
):
"""
Get the attributions for (part of) the response.
Arguments:
start_idx (int, optional):
Start index of the part to attribute to. If None, defaults to
the start of the response.
end_idx (int, optional):
End index of the part to attribute to. If None, defaults to the
end of the response.
as_dataframe (bool, optional):
If True, return the attributions as a stylized dataframe in
sorted order. Otherwise, return them as a numpy array where
the ith element corresponds to the score of the ith source
within the context. Defaults to False.
top_k (int, optional):
Only used if as_dataframe is True. Number of top attributions to
return. If None, all attributions are returned. Defaults to None.
verbose (bool, optional):
If True, print the selected part of the response. Defaults to
True.
Returns:
NDArray | Any:
If as_dataframe is False, return a numpy array where the ith element
corresponds to the score of the ith source within the context.
Otherwise, return a stylized dataframe in sorted order.
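        Example (an illustrative sketch, assuming `cc` is a ContextCiter
        instance whose response has already been generated):

            >>> scores = cc.get_attributions()  # one score per source, as a numpy array
            >>> df = cc.get_attributions(as_dataframe=True, top_k=5)  # top 5 sources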
"""
if self.num_sources == 0:
self.logger.warning("No sources to attribute to!")
return np.array([])
if not as_dataframe and top_k is not None:
self.logger.warning("top_k is ignored when not using dataframes.")
ids_start_idx, ids_end_idx = self._indices_to_token_indices(start_idx, end_idx)
selected_text = self.response[start_idx:end_idx]
selected_tokens = self._response_ids[ids_start_idx:ids_end_idx]
decoded_text = self.tokenizer.decode(selected_tokens)
if selected_text.strip() not in decoded_text.strip():
self.logger.warning(
"Decoded selected tokens do not match selected text.\n"
"If the following look close enough, feel free to ignore:\n"
"What you selected: %s\nWhat is being attributed: %s",
selected_text.strip(),
decoded_text.strip(),
)
if verbose:
print(f"Attributed: {decoded_text.strip()}")
# _bias is the bias term in the l1 regression
attributions, _bias = self._get_attributions_for_ids_range(
ids_start_idx,
ids_end_idx,
)
if as_dataframe:
return get_attributions_df(attributions, self.partitioner, top_k=top_k)
else:
return attributions