import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    configure_pad_token,
    handle_stop_sequences,
    undistribute,
)
from lm_eval.utils import (
    eval_logger,
    get_rolling_token_windows,
    make_disjoint_window,
)


try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass


@register_model("vllm")
class VLLM(TemplateLM):
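    """vLLM-backed language model wrapper: runs generation and loglikelihood
    scoring through a vLLM engine, with optional tensor/data parallelism and
    LoRA adapter support."""
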
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        max_model_len: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        assert device is None or "cuda" in device, "vLLM only supports CUDA"
        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else int(batch_size)
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading "
                "when data_parallel is in use. To ensure stable performance, run with "
                "data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )
        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
        )
        self.tokenizer = configure_pad_token(self.tokenizer)
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name: a BOS token will be used, as Gemma-series "
                "models underperform without it."
            )

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is not None:
            assert parse_version(version("vllm")) > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

    @property
    def eot_token_id(self):
        # end-of-text token id, taken from the tokenizer's EOS token
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # token id used to prefix loglikelihood windows
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        else:
            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
            for attr in seqlen_config_attrs:
                if hasattr(self._config, attr):
                    return getattr(self._config, attr)
            if hasattr(self.tokenizer, "model_max_length"):
                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                    # HF tokenizers report this sentinel when no max length is set
                    return self._DEFAULT_MAX_LENGTH
                return self.tokenizer.model_max_length
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Apply a chat template to a list of chat-history messages between user and model.
        """
        chat_templated = self.tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

        return chat_templated

    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(
        self,
        string: Union[str, List[str]],
        left_truncate_len: Optional[int] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
        # always add special tokens when a BOS token is requested for this model
        add_special_tokens = add_special_tokens or self.add_bos_token
        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids

        # left-truncate the encoded context, keeping at most left_truncate_len tokens
        if left_truncate_len:
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            # loglikelihood scoring: greedy, one dummy generated token, prompt logprobs
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
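            # run one engine replica per data-parallel rank as a Ray task; each
            # task builds its own LLM instance from the shared model_args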
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params,
                requests: List[List[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )
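
            # interleave the requests across workers, run them, then restore the
            # original ordering of the flattened results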
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)

            ray.shutdown()

            return undistribute(results)

        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=(self.batch_size == "auto"),
            lora_request=self.lora_request,
        )
        return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # with "auto" batching, score all windows in one pass
            adaptive_batch_size = len(requests)
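
        # gather (request index, window) pairs across all requests so windows can
        # be batched together regardless of which request they came from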
        all_windows = []  # (request_idx, window) tuples
        request_window_counts = []  # window count per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # leave one position for the single context token
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # the (None,) cache key mirrors the request format of _loglikelihood_tokens
            windows = [(None,) + x for x in rolling_token_windows]

            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        all_nlls = []
        batch_size = adaptive_batch_size or int(self.batch_size)
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # split the (request_idx, window) pairs back apart for scoring
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            # keep the request index alongside each window's (nll, is_greedy) result
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # reassemble the per-window scores into one loglikelihood per request
        loglikelihoods = []
        current_idx = 0
        for window_count in request_window_counts:
            # consume this request's windows
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # sum the window log-likelihoods, dropping the is_greedy flags
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            # cache this request's rolling loglikelihood
            string = requests[len(loglikelihoods) - 1].args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
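            # sort by descending context length so the longest prompts come first
            # and any OOM surfaces on the first batch rather than near the end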
            return -len(_requests[0][1]), _requests[0][0]
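
        # group requests by their generation kwargs so each vLLM call uses a
        # single SamplingParams configuration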
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )

        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)

            # requests within a chunk share gen_kwargs, so the first entry speaks for all
            gen_kwargs = all_gen_kwargs[0]

            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # copy so pops don't mutate the request
                # resolve stop sequences, always stopping at EOS
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
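
            # leave room in the context window for the tokens to be generated:
            # left-truncate each context to max_length - max_gen_toks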
            max_ctx_len = self.max_length - max_gen_toks
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, context_str in zip(cont, context):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context_str, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()

        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        res = []

        def _collate(x):
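            # sort by descending total sequence length, with the token tuple as a
            # tie-breaker for stable, groupable ordering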
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # fit the full sequence into the context window, left-truncating
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                # track how much of the context survives truncation
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                if cache_key is not None:
                    # cache the (loglikelihood, is_greedy) answer for this request
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """
        continuation_logprobs_dicts = outputs.prompt_logprobs

        def coerce_logprob_to_num(logprob):
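            # newer vLLM versions return Logprob objects rather than raw floats;
            # unwrap to the float logprob value in either case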
            return getattr(logprob, "logprob", logprob)

        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in continuation_logprobs_dicts
        ]
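
        # sum the logprobs of the continuation tokens; the context entries
        # (the first ctxlen positions) are sliced away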
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
            )
        )
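
        # the continuation is greedy iff each of its tokens is the argmax of the
        # model's distribution at that position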
        is_greedy = True
        for token, logprob_dict in zip(
            tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
        ):
            if logprob_dict:
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
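        # map HF-style sampling kwargs onto vLLM SamplingParams conventions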
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value; setting vLLM temperature to 0.0 ..."
            )
            kwargs["temperature"] = 0.0

        # match HF detokenization defaults: keep special tokens and do not insert
        # spaces between them unless explicitly requested
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs