import numpy as np import pandas as pd import torch as ch from numpy.typing import NDArray from typing import Dict, Any, Optional, List import logging from transformers import AutoTokenizer, AutoModelForCausalLM from .context_partitioner import BaseContextPartitioner, SimpleContextPartitioner from .solver import BaseSolver, LassoRegression from .utils import ( get_masks_and_logit_probs, aggregate_logit_probs, split_text, highlight_word_indices, get_attributions_df, char_to_token, ) DEFAULT_GENERATE_KWARGS = {"max_new_tokens": 512, "do_sample": False} DEFAULT_PROMPT_TEMPLATE = "Context: {context}\n\nQuery: {query}" class ContextCiter: def __init__( self, model: Any, tokenizer: Any, context: str, query: str, source_type: str = "sentence", generate_kwargs: Optional[Dict[str, Any]] = None, num_ablations: int = 64, ablation_keep_prob: float = 0.5, batch_size: int = 1, solver: Optional[BaseSolver] = None, prompt_template: str = DEFAULT_PROMPT_TEMPLATE, partitioner: Optional[BaseContextPartitioner] = None, ) -> None: """ Initializes a new instance of the ContextCiter class, which is designed to assist in generating contextualized responses using a given machine learning model and tokenizer, tailored to specific queries and contexts. Arguments: model (Any): The model to apply ContextCite to (a HuggingFace ModelForCausalLM). tokenizer (Any): The tokenizer associated with the provided model. context (str): The context provided to the model query (str): The query to pose to the model. source_type (str, optional): The type of source to partition the context into. Defaults to "sentence", can also be "word". generate_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments to pass to the model's generate method. num_ablations (int, optional): The number of ablations used to train the surrogate model. Defaults to 64. ablation_keep_prob (float, optional): The probability of keeping a source when ablating the context. Defaults to 0.5. batch_size (int, optional): The batch size used when performing inference using ablated contexts. Defaults to 1. solver (Optional[Solver], optional): The solver to use to compute the linear surrogate model. Lasso regression is used by default. prompt_template (str, optional): A template string used to create the prompt from the context and query. partitioner (Optional[BaseContextPartitioner], optional): A custom partitioner to split the context into sources. This will override "source_type" if specified. """ self.model = model self.tokenizer = tokenizer if partitioner is None: self.partitioner = SimpleContextPartitioner( context, source_type=source_type ) else: self.partitioner = partitioner if self.partitioner.context != context: raise ValueError("Partitioner context does not match provided context.") self.query = query self.generate_kwargs = generate_kwargs or DEFAULT_GENERATE_KWARGS self.num_ablations = num_ablations self.ablation_keep_prob = ablation_keep_prob self.batch_size = batch_size self.solver = solver or LassoRegression() self.prompt_template = prompt_template self._cache = {} self.logger = logging.getLogger("ContextCite") self.logger.setLevel(logging.DEBUG) # TODO: change to INFO later if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token @classmethod def from_pretrained( cls, pretrained_model_name_or_path, context: str, query: str, device: str = "cuda", model_kwargs: Dict[str, Any] = {}, tokenizer_kwargs: Dict[str, Any] = {}, **kwargs: Dict[str, Any], ) -> "ContextCiter": """ Load a ContextCiter instance from a pretrained model. Arguments: pretrained_model_name_or_path (str): The name or path of the pretrained model. This can be a local path or a model name on the HuggingFace model hub. context (str): The context provided to the model. The context and query will be used to construct a prompt for the model, using the prompt template. query (str): The query provided to the model. The context and query will be used to construct a prompt for the model, using the prompt template. device (str, optional): The device to use. Defaults to "cuda". model_kwargs (Dict[str, Any], optional): Additional keyword arguments to pass to the model's constructor. tokenizer_kwargs (Dict[str, Any], optional): Additional keyword arguments to pass to the tokenizer's constructor. **kwargs (Dict[str, Any], optional): Additional keyword arguments to pass to the ContextCiter constructor. Returns: ContextCiter: A ContextCiter instance initialized with the provided model, tokenizer, context, query, and other keyword arguments. """ model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, **model_kwargs ) model.to(device) tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path, **tokenizer_kwargs ) tokenizer.padding_side = "left" return cls(model, tokenizer, context, query, **kwargs) def _get_prompt_ids( self, mask: Optional[NDArray] = None, return_prompt: bool = False, ): context = self.partitioner.get_context(mask) prompt = self.prompt_template.format(context=context, query=self.query) messages = [{"role": "user", "content": prompt}] chat_prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) chat_prompt_ids = self.tokenizer.encode(chat_prompt, add_special_tokens=False) if return_prompt: return chat_prompt_ids, chat_prompt else: return chat_prompt_ids @property def _response_start(self): prompt_ids = self._get_prompt_ids() return len(prompt_ids) @property def _output(self): if self._cache.get("output") is None: prompt_ids, prompt = self._get_prompt_ids(return_prompt=True) input_ids = ch.tensor([prompt_ids], device=self.model.device) output_ids = self.model.generate(input_ids, **self.generate_kwargs)[0] # We take the original prompt because sometimes encoding and decoding changes it raw_output = self.tokenizer.decode(output_ids) prompt_length = len(self.tokenizer.decode(prompt_ids)) self._cache["output"] = prompt + raw_output[prompt_length:] return self._cache["output"] @property def _output_tokens(self): return self.tokenizer(self._output, add_special_tokens=False) @property def _response_ids(self): return self._output_tokens["input_ids"][self._response_start :] @property def response(self): """ The response generated by the model (excluding the prompt). This property is cached. Returns: str: The response generated by the model. """ output_tokens = self._output_tokens char_response_start = output_tokens.token_to_chars(self._response_start).start return self._output[char_response_start:] @property def response_with_indices(self, split_by="word", color=True) -> [str, pd.DataFrame]: """ The response generated by the model, annotated with the starting index of each part. Arguments: split_by (str, optional): The method to split the response by. Can be "word" or "sentence". Defaults to "word". color (bool, optional): Whether to color the starting index of each part. Defaults to True. Returns: str: The response with the starting index of each part highlighted. """ start_indices = [] parts, separators, start_indices = split_text(self.response, split_by) separated_str = highlight_word_indices(parts, start_indices, separators, color) return separated_str @property def num_sources(self) -> int: """ The number of sources within the context. I.e., the number of sources that the context is partitioned into. Returns: int: The number of sources in the context. """ return self.partitioner.num_sources @property def sources(self) -> List[str]: """ The sources within the context. I.e., the context as a list where each element is a source. Returns: List[str]: The sources within the context. """ return self.partitioner.sources def _char_range_to_token_range(self, start_index, end_index): output_tokens = self._output_tokens response_start = self._response_start offset = output_tokens.token_to_chars(response_start).start ids_start_index = char_to_token(output_tokens, start_index + offset) ids_end_index = char_to_token(output_tokens, end_index + offset - 1) + 1 return ids_start_index - response_start, ids_end_index - response_start def _indices_to_token_indices(self, start_index=None, end_index=None): if start_index is None or end_index is None: start_index = 0 end_index = len(self.response) if not (0 <= start_index < end_index <= len(self.response)): raise ValueError( f"Invalid selection range ({start_index}, {end_index}). " f"Please select any range within (0, {len(self.response)})." ) return self._char_range_to_token_range(start_index, end_index) def _compute_masks_and_logit_probs(self) -> None: self._cache["reg_masks"], self._cache["reg_logit_probs"] = ( get_masks_and_logit_probs( self.model, self.tokenizer, self.num_ablations, self.num_sources, self._get_prompt_ids, self._response_ids, self.ablation_keep_prob, self.batch_size, ) ) @property def _masks(self): if self._cache.get("reg_masks") is None: self._compute_masks_and_logit_probs() return self._cache["reg_masks"] @property def _logit_probs(self): if self._cache.get("reg_logit_probs") is None: self._compute_masks_and_logit_probs() return self._cache["reg_logit_probs"] def _get_attributions_for_ids_range(self, ids_start_idx, ids_end_idx) -> tuple: outputs = aggregate_logit_probs(self._logit_probs[:, ids_start_idx:ids_end_idx]) num_output_tokens = ids_end_idx - ids_start_idx weight, bias = self.solver.fit(self._masks, outputs, num_output_tokens) return weight, bias def get_attributions( self, start_idx: Optional[int] = None, end_idx: Optional[int] = None, as_dataframe: bool = False, top_k: Optional[int] = None, verbose: bool = True, ): """ Get the attributions for (part of) the response. Arguments: start_idx (int, optional): Start index of the part to attribute to. If None, defaults to the start of the response. end_idx (int, optional): End index of the part to attribute to. If None, defaults to the end of the response. as_dataframe (bool, optional): If True, return the attributions as a stylized dataframe in sorted order. Otherwise, return them as a numpy array where the ith element corresponds to the score of the ith source within the context. Defaults to False. top_k (int, optional): Only used if as_dataframe is True. Number of top attributions to return. If None, all attributions are returned. Defaults to None. verbose (bool, optional): If True, print the selected part of the response. Defaults to True. Returns: NDArray | Any: If as_dataframe is False, return a numpy array where the ith element corresponds to the score of the ith source within the context. Otherwise, return a stylized dataframe in sorted order. """ if self.num_sources == 0: self.logger.warning("No sources to attribute to!") return np.array([]) if not as_dataframe and top_k is not None: self.logger.warning("top_k is ignored when not using dataframes.") ids_start_idx, ids_end_idx = self._indices_to_token_indices(start_idx, end_idx) selected_text = self.response[start_idx:end_idx] selected_tokens = self._response_ids[ids_start_idx:ids_end_idx] decoded_text = self.tokenizer.decode(selected_tokens) if selected_text.strip() not in decoded_text.strip(): self.logger.warning( "Decoded selected tokens do not match selected text.\n" "If the following look close enough, feel free to ignore:\n" "What you selected: %s\nWhat is being attributed: %s", selected_text.strip(), decoded_text.strip(), ) if verbose: print(f"Attributed: {decoded_text.strip()}") # _bias is the bias term in the l1 regression attributions, _bias = self._get_attributions_for_ids_range( ids_start_idx, ids_end_idx, ) if as_dataframe: return get_attributions_df(attributions, self.partitioner, top_k=top_k) else: return attributions