import abc
import hashlib
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union

import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm

from lm_eval import utils


eval_logger = logging.getLogger("lm-eval")

T = TypeVar("T", bound="LM")


class LM(abc.ABC):
    def __init__(self) -> None:
        """Defines the interface that should be implemented by all LM subclasses.

        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic).
        """
        self._rank = 0
        self._world_size = 1
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.

        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy: bool`
                Whether `continuation` would be generated by greedy sampling from `context`.
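
        Example (illustrative only; `lm` is an instance of a hypothetical concrete
        subclass and the numeric results are made up):

            # requests[0].args == ("hello", " world")
            # lm.loglikelihood(requests)
            # -> [(-2.31, True), ...]  # logprob of " world" given "hello", and whether it is the greedy continuation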
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
        the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
        which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
        multiple chunks, the last input will still have a full-sized context.
        Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

        Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (string,).
            `string: str`
                String for which we are computing overall loglikelihood
        :return: list[float]
            A list of log probabilities, one per request
            `logprob: float`
                The log probability of `string` conditioned on the BOS/EOS token.
                Can also be overridden for custom cases by `prefix_token_id`.
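
        Example (illustrative only; `lm` is an instance of a hypothetical concrete
        subclass and the values are made up):

            # requests[0].args == ("The quick brown fox jumps over the lazy dog",)
            # lm.loglikelihood_rolling(requests)
            # -> [-34.7, -102.3]  # one summed token-level log probability per input string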
        """
        pass

    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
            `context: str`
                Context string
            `gen_kwargs: dict`
                A dictionary of keyword arguments to pass to the generation function, e.g. top_k, until, etc.
        :return: list[str]
            A list of model-generated continuations.
            `continuation: str`
                The generated continuation.
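
        Example (illustrative only; the exact gen_kwargs accepted depend on the
        concrete implementation, and the output is made up):

            # requests[0].args == ("Question: What is 2+2? Answer:", {"until": ["Question:"], "max_gen_toks": 32})
            # lm.generate_until(requests)
            # -> [" 4", ...]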
        """
        pass

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """
        Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

        :param chat_history: list[dict[str, str]]
            A list of dictionaries with keys 'role' and 'content'.
            Values are strings representing the role name and the content of the message, respectively.
        :param add_generation_prompt: bool
            Whether to append an assistant generation prefix (e.g. <|assistant|>) to the end of the chat history. Set to False when prefilling an assistant message.
        :return: str
            A string representing the chat history in a format that can be used as input to the LM.
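
        Example (illustrative only; the rendered string depends entirely on the
        model's own chat template):

            # chat_history = [
            #     {"role": "system", "content": "You are a helpful assistant."},
            #     {"role": "user", "content": "What is the capital of France?"},
            # ]
            # lm.apply_chat_template(chat_history)
            # -> "<|system|>You are a helpful assistant.<|user|>What is the capital of France?<|assistant|>"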
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
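
        Example (illustrative only; `MyLM` and its keyword arguments are hypothetical):

            # MyLM.create_from_arg_string("pretrained=my-model,batch_size=8", {"device": "cpu"})
            # is roughly equivalent to MyLM(pretrained="my-model", batch_size=8, device="cpu")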
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given arg_dict and additional config.

        Parameters:
        - arg_dict: A dict of keyword arguments to pass to the LM constructor.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """
        additional_config = {} if additional_config is None else additional_config
        additional_config = {
            k: v for k, v in additional_config.items() if v is not None
        }

        return cls(**arg_dict, **additional_config)

    @property
    def rank(self):
        # Used in the case of parallelism. Hardcoded to 0 here to
        # ensure no errors arise when using API models which do
        # not support multi-device parallelism nor expect it.
        return self._rank

    @property
    def world_size(self):
        # Used in the case of parallelism. Hardcoded to 1 here to
        # ensure no errors arise when using API models which do
        # not support multi-device parallelism nor expect it.
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses which implement Chat Templating.
        Should return the name of the tokenizer or chat template used.
        Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Returns the chat template structure for user/assistant messages if a template is provided.
        This method is intended to be overridden in a subclass to define a specific chat template format.
        The base implementation returns an empty string, signalling that no chat template is applied.
        """
        return ""

    def set_cache_hook(self, cache_hook) -> None:
        self.cache_hook = cache_hook


def hash_args(attr, args):
    # Deterministically fingerprint a request (method name + its arguments) for use as a cache key.
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm) -> None:
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
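
        Example (illustrative only; `MyLM` and the cache path are hypothetical):

            # lm = CachingLM(MyLM(...), "lm_cache/my_model.db")
            # lm.generate_until(requests)  # results are served from, and written back to, the SQLite cache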
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to the underlying LM so partial results are cached as they are produced
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr: str):
        lm_attr = getattr(self.lm, attr)
        if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which requests are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            if remaining_reqs:
                # actually run the LM on the requests that do not have cached results
                rem_res = getattr(self.lm, attr)(remaining_reqs)
            else:
                rem_res = []

            # stick the new results back into the list, and cache each of them
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
                self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    tokenizer = None

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        pass

    @property
    def prefix_token_id(self):
        # used as the prefix token for loglikelihood computations
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        # Move any trailing whitespace from the context onto the continuation,
        # so the word boundary is tokenized as part of the continuation.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            # Encoder-decoder models: encode context and continuation independently.
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            # Decoder-only models: tokenize the concatenation, then take the
            # continuation tokens as everything after the tokenized context's length.
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # end of text as context: use BOS/EOS (or prefix_token_id) as the context token
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Get the appropriate chat template for the model.
        This method selects the template to use and returns it as a string for reproducibility.

        The template selection logic is adapted from the Transformers library's `apply_chat_template`
        method in the Tokenizer class. The original implementation can be found at:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        This method ensures that the right template is chosen based on the following:
        0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
        1. If the model's tokenizer has multiple templates:
            a. Use the specified template if it exists in the dictionary.
            b. Use the default template from the dictionary if no specific template is provided.
            c. Raise an error if no default template exists and no specific template is provided.
        2. If the model's tokenizer has a single template or no template:
            a. Use the tokenizer's chat template if available.
            b. Fall back to the default chat template if no tokenizer chat template exists.

        Args:
            chat_template (Union[bool, str]): Specifies the chat template to use.
                - If False or None, no template is applied.
                - If True, the default or only available template is used.
                - If a string, the template with the matching name is used.

        Returns:
            Optional[str]: The selected chat template, or None if no template is applied.
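
        Example (illustrative only; the template name "tool_use" is hypothetical and
        depends on the tokenizer in use):

            # lm.chat_template(True)        -> the tokenizer's default (or only) template
            # lm.chat_template("tool_use")  -> the template registered under "tool_use", if the tokenizer defines several
            # lm.chat_template(False)       -> None (no template applied)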
        """
        if self.tokenizer is None:
            return ""

        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # Convert a boolean chat_template to None to ensure compatibility with the adapted logic below
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # First, handle the case where the tokenizer exposes a dict of multiple templates
        try:
            template = (
                self.tokenizer.chat_template or self.tokenizer.default_chat_template
            )
        except AttributeError:
            return None

        if isinstance(template, dict):
            using_default_dict = self.tokenizer.chat_template is None

            if chat_template is not None:
                if chat_template in template:
                    selected_template = template[chat_template]
                    if using_default_dict:
                        using_default_template = True
                else:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
            else:
                # If no template name was passed, use the 'default' entry from the dict
                if "default" in template:
                    selected_template = template["default"]
                    using_default_template = True
                else:
                    raise ValueError(
                        "This model has multiple chat templates with no default specified! Please either pass a chat "
                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
                        f"template names are {sorted(template.keys())}."
                    )

        else:
            # Cases when the model has a single template or no template
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            if self.tokenizer.chat_template is not None:
                selected_template = self.tokenizer.chat_template
            else:
                selected_template = self.tokenizer.default_chat_template
                using_default_template = True

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template