import abc
import asyncio
import copy
import itertools
import json
from functools import cached_property
from importlib.util import find_spec
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)


try:
    import requests
    from aiohttp import ClientSession, ClientTimeout, TCPConnector
    from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
    from tqdm import tqdm
    from tqdm.asyncio import tqdm_asyncio
except ModuleNotFoundError:
    # Optional API dependencies; their absence is reported in TemplateAPI.__init__.
    pass


from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token


# (cache_key, context_tokens, continuation_tokens), where cache_key is the
# (context, continuation) string pair used for request caching
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]


class JsonChatStr(NamedTuple):
    """Wrapper holding a JSON-serialized chat history, used when chat requests are sent without local templating."""

    prompt: str

    def encode(self, encoding):
        return self.prompt.encode(encoding)


eval_logger = utils.eval_logger


class TemplateAPI(TemplateLM):
    def __init__(
        self,
        model: str = None,
        pretrained: str = None,
        base_url: str = None,
        tokenizer: Optional[str] = None,
        # tokenizer backend used for client-side tokenization; "None"/"none" disables it
        tokenizer_backend: Optional[
            Literal["tiktoken", "huggingface", "None", "none"]
        ] = "huggingface",
        truncate: bool = False,
        # number of concurrent API requests (used as the aiohttp connection limit)
        num_concurrent: int = 1,
        max_retries: int = 3,
        max_gen_toks: int = 256,
        batch_size: Union[str, int] = 1,
        seed: int = 1234,
        max_length: Optional[int] = 2048,
        add_bos_token: bool = False,
        custom_prefix_token_id: int = None,
        # whether to send requests as token IDs rather than strings
        tokenized_requests: bool = True,
        trust_remote_code: bool = False,
        revision: Optional[str] = "main",
        use_fast_tokenizer: bool = True,
        verify_certificate: bool = True,
        eos_string: str = None,
        # per-request timeout in seconds
        timeout: int = 300,
        **kwargs,
    ) -> None:
        super().__init__()
        missing_packages = [
            pkg
            for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
            if find_spec(pkg) is None
        ]
        if missing_packages:
            raise ModuleNotFoundError(
                f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
                'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
            )
        self.model = model or pretrained
        self.base_url = base_url
        self.tokenizer = tokenizer
        if not isinstance(batch_size, int) and "auto" in batch_size:
            eval_logger.warning(
                "Automatic batch size is not supported for API models. Defaulting to batch size 1."
            )
        elif int(batch_size) > 1:
            eval_logger.warning(
                "Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
            )
        self._batch_size = int(batch_size) if batch_size != "auto" else 1
        self._truncate = truncate
        self._max_gen_toks = int(max_gen_toks)
        self._seed = int(seed)

        eval_logger.info(f"Using max length {max_length - 1}")
        self.max_length = max_length - 1
        if int(num_concurrent) <= 1:
            eval_logger.info(
                "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
            )
        self._concurrent = int(num_concurrent)
        self.tokenizer_backend = (
            None if tokenizer_backend in ("None", "none") else tokenizer_backend
        )
        self.add_bos_token = add_bos_token
        self.custom_prefix_token_id = custom_prefix_token_id
        self.tokenized_requests = tokenized_requests
        self.max_retries = int(max_retries)
        self.verify_certificate = verify_certificate
        self._eos_string = eos_string
        self.timeout = int(timeout)

        eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
        if self.tokenizer_backend is None:
            self.tokenizer = None
            self.tokenized_requests = False
        else:
            if self.tokenizer is None:
                if self.tokenizer_backend == "huggingface":
                    import transformers

                    self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                        self.tokenizer if self.tokenizer else self.model,
                        trust_remote_code=trust_remote_code,
                        revision=revision,
                        use_fast=use_fast_tokenizer,
                    )
                    # ensure a pad token is configured on the tokenizer
                    self.tokenizer = configure_pad_token(self.tokenizer)
                elif self.tokenizer_backend == "tiktoken":
                    try:
                        import tiktoken

                        self.tokenizer = tiktoken.encoding_for_model(self.model)
                    except ModuleNotFoundError as e:
                        raise ModuleNotFoundError(
                            "Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
                            "Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
                        ) from e
                    if "openai" not in self.base_url:
                        eval_logger.warning(
                            f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
                            "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
                        )
            else:
                import transformers

                assert isinstance(tokenizer, str), "tokenizer must be a string"
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    tokenizer,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    use_fast=use_fast_tokenizer,
                )

    @abc.abstractmethod
    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        *,
        generate: bool = True,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        eos: str = None,
        **kwargs,
    ) -> dict:
        """This method is responsible for creating the JSON payload that will be sent to the API."""
        raise NotImplementedError
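
    # Illustrative sketch only (field names assume an OpenAI-style completions endpoint;
    # adapt them to the target API). A concrete subclass might build payloads roughly like:
    #
    #     gen_kwargs = gen_kwargs or {}
    #     if generate:
    #         return {
    #             "model": self.model,
    #             "prompt": messages,
    #             "max_tokens": gen_kwargs.pop("max_gen_toks", self._max_gen_toks),
    #             "stop": gen_kwargs.pop("until", [eos]),
    #             "seed": seed,
    #             **gen_kwargs,
    #         }
    #     # loglikelihood scoring: echo the prompt and request per-token logprobs
    #     return {"model": self.model, "prompt": messages, "max_tokens": 0,
    #             "logprobs": 1, "echo": True, "seed": seed}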

    def create_message(
        self,
        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
        generate=False,
    ) -> Union[List[List[int]], List[dict], List[str], str]:
        """Helper method to transform the prompt into the expected API input format. messages consist of batched requests."""
        if isinstance(messages[0], JsonChatStr):
            # chat-style requests are passed through as the parsed JSON chat history
            assert self._batch_size == 1, (
                "non-tokenized chat requests are only supported with batch_size=1"
            )
            return json.loads(messages[0].prompt)

        if not self.tokenized_requests:
            # decode token lists back to strings if the API expects plain text
            if isinstance(messages[0][0], int):
                messages = self.decode_batch(messages)
            if self._batch_size <= 1:
                return messages[0]
            else:
                return messages

        # tokenized requests: a list of token lists
        return messages

    @staticmethod
    @abc.abstractmethod
    def parse_logprobs(
        outputs: Union[Any, List[Any]],
        tokens: List[List[int]] = None,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        """Parse the logprobs from the (batched) API response.

        Should return a list of (continuation log-likelihood, is_greedy) tuples, one per request.
        """
        raise NotImplementedError
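
    # Illustrative sketch only: for a completions-style response that returns per-token
    # logprobs, a subclass would typically compute, for each request in the batch,
    #
    #     continuation_logprob = sum(token_logprobs[ctxlen:])                 # float
    #     is_greedy = all tokens after ctxlen were the model's top choice     # bool
    #
    # and return the (continuation_logprob, is_greedy) tuples in request order.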

    @staticmethod
    @abc.abstractmethod
    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
        """Parse the generations from the (batched) API response.

        Should return a list of generated strings, one per request.
        """
        raise NotImplementedError
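
    # Illustrative sketch only: for an OpenAI-style completions response this can be as
    # simple as
    #
    #     return [choice["text"] for out in outputs for choice in out["choices"]]
    #
    # keeping one generated string per request, in request order.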

    @cached_property
    def api_key(self) -> str:
        """Override this property to return the API key for the API request."""
        return ""

    @cached_property
    def header(self) -> dict:
        """Override this property to return the headers for the API request."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses that implement chat templating.

        Should return the name of the tokenizer or chat template used.
        Used only to fingerprint the request cache when `--cache_requests` is enabled; otherwise unused.
        """
        return ""

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> Union[str, JsonChatStr]:
        """Applies a chat template to a list of chat history between user and model."""
        if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
            return self.tokenizer.apply_chat_template(
                chat_history,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=not add_generation_prompt,
            )
        else:
            # defer templating to the API: serialize the chat history and unpack it in create_message()
            return JsonChatStr(json.dumps(chat_history))
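
    # Note: the JsonChatStr round trip is deliberate - apply_chat_template() stores the raw
    # chat history as JSON (e.g. JsonChatStr('[{"role": "user", "content": "Hi"}]')) and
    # create_message() loads it back, so the request is sent as a structured chat payload
    # and the server applies its own chat template.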

    @cached_property
    def eot_token_id(self) -> Optional[int]:
        if self.tokenizer is None:
            return None
        else:
            if self.tokenizer_backend == "huggingface":
                return self.tokenizer.eos_token_id
            elif self.tokenizer_backend == "tiktoken":
                return self.tokenizer.eot_token

    @cached_property
    def eos_string(self) -> Optional[str]:
        if self._eos_string:
            return self._eos_string
        elif self.tokenizer is not None:
            if self.tokenizer_backend == "huggingface":
                return self.tokenizer.eos_token
            elif self.tokenizer_backend == "tiktoken":
                return self.tokenizer.decode([self.tokenizer.eot_token])
        else:
            eval_logger.warning(
                "Cannot determine the EOS string to use as a stop sequence. Set it manually by passing `eos_string` in model_args."
            )
            return None

    @cached_property
    def prefix_token_id(self) -> Optional[int]:
        if self.tokenizer is None:
            return None
        else:
            if self.custom_prefix_token_id is not None:
                return self.custom_prefix_token_id
            if self.tokenizer_backend == "huggingface":
                if self.tokenizer.bos_token_id is not None:
                    return self.tokenizer.bos_token_id
                return self.tokenizer.eos_token_id
            else:
                return self.tokenizer.eot_token

    def tok_encode(
        self,
        string: str,
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        **kwargs,
    ) -> Union[List[List[int]], List[int], List[str]]:
        if self.tokenizer_backend is None:
            return [string]
        elif self.tokenizer_backend == "huggingface":
            # respect add_bos_token when the caller did not explicitly request special tokens
            if not add_special_tokens:
                add_special_tokens = self.add_bos_token
            encoding: Union[List[List[int]], List[int]] = self.tokenizer(
                string,
                add_special_tokens=add_special_tokens,
                truncation=truncation,
                return_attention_mask=False,
            ).input_ids

            # left-truncate the encoded sequence(s) to at most `left_truncate_len` tokens
            if left_truncate_len:
                if not isinstance(string, str):
                    encoding = [enc[-left_truncate_len:] for enc in encoding]
                else:
                    encoding = encoding[-left_truncate_len:]

            return encoding

        else:
            try:
                # single string
                encoding = self.tokenizer.encode(string)
            except Exception:
                # batch of strings (tiktoken)
                encoding = self.tokenizer.encode_batch(string)
            return encoding

    def decode_batch(self, tokens: List[List[int]]) -> List[str]:
        if self.tokenizer_backend == "huggingface":
            return self.tokenizer.batch_decode(tokens)
        elif self.tokenizer_backend == "tiktoken":
            return self.tokenizer.decode_batch(tokens)

    def model_call(
        self,
        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
        *,
        generate: bool = True,
        gen_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> Optional[dict]:
        # copy gen_kwargs so that payload creation and retries do not mutate the caller's dict
        gen_kwargs = copy.deepcopy(gen_kwargs)
        try:
            response = requests.post(
                self.base_url,
                json=self._create_payload(
                    self.create_message(messages),
                    generate=generate,
                    gen_kwargs=gen_kwargs,
                    seed=self._seed,
                    eos=self.eos_string,
                    **kwargs,
                ),
                headers=self.header,
                verify=self.verify_certificate,
            )
            if not response.ok:
                eval_logger.warning(
                    f"API request failed with error message: {response.text}. Retrying..."
                )
            # raise so that tenacity (applied at the call site) can retry failed requests
            response.raise_for_status()
            return response.json()
        except RetryError:
            eval_logger.error(
                "API request failed after multiple retries. Please check the API status."
            )
            return None

    async def amodel_call(
        self,
        session: ClientSession,
        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
        *,
        generate: bool = True,
        cache_keys: list = None,
        ctxlens: Optional[List[int]] = None,
        gen_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> Union[List[str], List[Tuple[float, bool]], None]:
        # copy gen_kwargs so that payload creation and retries do not mutate the caller's dict
        gen_kwargs = copy.deepcopy(gen_kwargs)
        payload = self._create_payload(
            self.create_message(messages),
            generate=generate,
            gen_kwargs=gen_kwargs,
            seed=self._seed,
            eos=self.eos_string,
            **kwargs,
        )
        cache_method = "generate_until" if generate else "loglikelihood"
        try:
            async with session.post(
                self.base_url,
                json=payload,
                headers=self.header,
            ) as response:
                if not response.ok:
                    error_text = await response.text()
                    eval_logger.warning(
                        f"API request failed with error message: {error_text}. Retrying..."
                    )
                # raise so that tenacity (applied at the call site) can retry failed requests
                response.raise_for_status()
                outputs = await response.json()
            answers = (
                self.parse_generations(
                    outputs=outputs,
                )
                if generate
                else self.parse_logprobs(
                    outputs=outputs,
                    tokens=messages,
                    ctxlens=ctxlens,
                )
            )
            if cache_keys:
                for res, cache in zip(answers, cache_keys):
                    self.cache_hook.add_partial(cache_method, cache, res)
            return answers
        except RetryError:
            eval_logger.error(
                "API request failed after multiple retries. Please check the API status."
            )
            return None

    def batch_loglikelihood_requests(
        self, chunks: Iterable[List[LogLikelihoodInputs]]
    ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
        inputs = []
        ctxlens = []
        cache_keys = []
        for chunk in chunks:
            for cache_key, context_enc, continuation_enc in chunk:
                # left-truncate the combined context + continuation to the maximum length
                inp = (context_enc + continuation_enc)[-self.max_length :]
                if len(inp) < len(context_enc + continuation_enc):
                    eval_logger.warning(
                        f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
                    )
                # effective context length after accounting for tokens dropped by left-truncation
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - self.max_length
                )
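
                # e.g. (hypothetical numbers) with max_length=2047, len(context_enc)=2000 and
                # len(continuation_enc)=100, 53 context tokens are dropped on the left, so
                # ctxlen = 2000 - max(0, 2000 + 100 - 2047) = 1947 and the last 100 tokens of
                # `inp` are still scored as the continuation.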
                inputs.append(inp)
                ctxlens.append(ctxlen)
                cache_keys.append(cache_key)
        return inputs, ctxlens, cache_keys

    async def get_batched_requests(
        self,
        requests: list,
        cache_keys: list,
        *,
        generate: bool = True,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
        ctxlens = ctxlens if ctxlens else [None] * len(requests)
        conn = TCPConnector(limit=self._concurrent)
        async with ClientSession(
            connector=conn, timeout=ClientTimeout(total=self.timeout)
        ) as session:
            # wrap the async call with tenacity's retry (exponential backoff) before dispatching
            retry_: Callable[..., Awaitable[Any]] = retry(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential(multiplier=0.5, min=1, max=10),
                reraise=True,
            )(self.amodel_call)

            tasks = [
                asyncio.create_task(
                    retry_(
                        session=session,
                        messages=message,
                        cache_keys=cache_key,
                        generate=generate,
                        ctxlens=ctxlen,
                        **kwargs,
                    )
                )
                for message, cache_key, ctxlen in zip(
                    chunks(requests, n=self._batch_size),
                    chunks(cache_keys, n=self._batch_size),
                    chunks(ctxlens, n=self._batch_size),
                )
            ]

            return await tqdm_asyncio.gather(*tasks, desc="Requesting API")

    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        assert self.tokenizer is not None, (
            "Tokenizer is required for loglikelihood tasks to compute context lengths."
        )
        res = []

        def _collate(req: LogLikelihoodInputs):
            """Defines the sort key for reordering requests."""
            # sort by descending total length so the longest sequences are sent first:
            # progress-bar time estimates stay conservative and any length-related
            # API failures surface early
            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by=None,
        )
        # batch locally only for the sequential path; the concurrent path re-chunks in get_batched_requests
        chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests([chunk])

                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(messages=inputs, generate=False)
                if isinstance(outputs, dict):
                    outputs = [outputs]
                for answer_, cache_key in zip(
                    self.parse_logprobs(
                        outputs=outputs, tokens=inputs, ctxlens=ctxlens
                    ),
                    cache_keys,
                ):
                    if answer_ is not None:
                        res.append(answer_)
                        # partial caching
                        if cache_key is not None:
                            self.cache_hook.add_partial(
                                "loglikelihood", cache_key, answer_
                            )
                        pbar.update(1)
        else:
            inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests(chunked)
            res = itertools.chain.from_iterable(
                asyncio.run(
                    self.get_batched_requests(
                        inputs, cache_keys, generate=False, ctxlens=ctxlens
                    )
                )
            )

        return re_ord.get_original(res)

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate_gen(_requests):
            # sort by the length of the (non-tokenized) context string, descending
            return -len(_requests[0])

        # unzip Instance.args into context strings and per-request generation kwargs
        requests, all_gen_kwargs = zip(*(req.args for req in requests))
        if self.tokenized_requests:
            encodings_list = self.tok_encode(
                requests, add_special_tokens=self.add_bos_token
            )
        else:
            encodings_list = [None] * len(requests)
        requests = [
            (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
        ]

        re_ord = Collator(
            requests,
            sort_fn=_collate_gen,
            group_by="gen_kwargs",
        )
        chunked = re_ord.get_batched(
            n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
        )
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                if self.tokenized_requests:
                    max_gen_toks = all_gen_kwargs[0].get(
                        "max_gen_toks", self._max_gen_toks
                    )
                    max_context_len = self.max_length - max_gen_toks

                    if any(
                        len(x) + max_gen_toks > self.max_length for x in encodings_list
                    ):
                        eval_logger.warning(
                            f"Some contexts exceeded max_length ({self.max_length}) minus max_gen_toks ({max_gen_toks}); left-truncating them to {max_context_len} tokens."
                        )
                    encodings_list = [x[-max_context_len:] for x in encodings_list]
                else:
                    eval_logger.info(
                        "Tokenized requests are disabled. Context + generation length is not checked."
                    )
                req = encodings_list if self.tokenized_requests else contexts
                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(
                    messages=req,
                    generate=True,
                    gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                )
                for generated_text, context in zip(
                    self.parse_generations(
                        outputs=outputs,
                        contexts=contexts,
                    ),
                    contexts,
                ):
                    if generated_text is not None:
                        res.append(generated_text)

                        # partial caching
                        if context is not None:
                            self.cache_hook.add_partial(
                                "generate_until",
                                (context, all_gen_kwargs[0]),
                                generated_text,
                            )
                        pbar.update(1)
        else:
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                if self.tokenized_requests:
                    max_gen_toks = all_gen_kwargs[0].get(
                        "max_gen_toks", self._max_gen_toks
                    )
                    max_context_len = self.max_length - max_gen_toks

                    if any(
                        len(x) + max_gen_toks > self.max_length for x in encodings_list
                    ):
                        eval_logger.warning(
                            f"Some contexts exceeded max_length ({self.max_length}) minus max_gen_toks ({max_gen_toks}); left-truncating them to {max_context_len} tokens."
                        )
                    encodings_list = [x[-max_context_len:] for x in encodings_list]
                else:
                    eval_logger.info(
                        "Tokenized requests are disabled. Context + generation length is not checked."
                    )
                req = encodings_list if self.tokenized_requests else contexts
                results = itertools.chain.from_iterable(
                    asyncio.run(
                        self.get_batched_requests(
                            req,
                            cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
                            generate=True,
                            gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                        )
                    )
                )
                res.extend(results)

        return re_ord.get_original(res)

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # reserve one position for the single-token rolling context
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # prepend a dummy cache key (None) so each window matches the
            # (cache_key, context_enc, continuation_enc) format expected by _loglikelihood_tokens
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy and keep only the summed logprob per window
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

            # cache the full rolling loglikelihood for this string
            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
        return loglikelihoods
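

# Putting the pieces together - an illustrative skeleton only (class name and method bodies
# below are hypothetical, not part of this module):
#
#     class MyCompletionsAPI(TemplateAPI):
#         def _create_payload(self, messages, *, generate=True, gen_kwargs=None, **kwargs):
#             ...  # build the endpoint-specific JSON body
#
#         @staticmethod
#         def parse_generations(outputs, **kwargs):
#             ...  # extract one generated string per request
#
#         @staticmethod
#         def parse_logprobs(outputs, tokens=None, ctxlens=None, **kwargs):
#             ...  # extract one (logprob_sum, is_greedy) tuple per request
#
# Concrete subclasses then inherit batching, concurrency (`num_concurrent`), retries
# (`max_retries`), request caching, and truncation handling from TemplateAPI.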