import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    configure_pad_token,
    handle_stop_sequences,
    undistribute,
)
from lm_eval.utils import (
    eval_logger,
    get_rolling_token_windows,
    make_disjoint_window,
)


# vLLM / Ray are optional extras; VLLM.__init__ raises a helpful error if
# `vllm` is missing, so a failed import here is tolerated.
try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass

if TYPE_CHECKING:
    pass

eval_logger = eval_logger


@register_model("vllm")
class VLLM(TemplateLM):
    """lm-eval ``TemplateLM`` backend that runs models through vLLM.

    Registered under the model name ``"vllm"``. Implements loglikelihood,
    rolling loglikelihood and ``generate_until`` on top of ``vllm.LLM``.
    With ``data_parallel_size > 1`` no engine is built in ``__init__``;
    instead each ``_model_generate`` call launches one Ray task per
    data-parallel shard, each constructing its own ``LLM``.
    """

    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: Optional[int] = None,
        max_model_len: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: Optional[str] = None,
        **kwargs,
    ):
        """Configure the vLLM engine arguments, tokenizer and generation defaults.

        :param pretrained: model name/path, passed to vLLM as ``model`` and used
            for the tokenizer when ``tokenizer`` is not given.
        :param max_length: alias of ``max_model_len``; at most one may be set.
            The chosen value is forwarded to vLLM as ``max_model_len``.
        :param batch_size: an int, or any string containing ``"auto"``.
            Forced to ``"auto"`` when ``data_parallel_size > 1``.
        :param max_batch_size: accepted for interface compatibility; unused here.
        :param device: must be ``None`` or contain ``"cuda"``.
        :param data_parallel_size: when > 1, engines are created lazily in Ray
            workers (see ``_model_generate``) instead of here.
        :param lora_local_path: path to a LoRA adapter; requires vllm > 0.3.0.
        :param kwargs: merged into the keyword arguments for ``vllm.LLM``.
        :raises ModuleNotFoundError: if ``vllm`` is not installed.
        """
        super().__init__()

        if not find_spec("vllm"):
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        # Test for None first: `"cuda" in None` raises TypeError, so the previous
        # operand order made the `device is None` alternative unreachable.
        assert device is None or "cuda" in device, "vLLM only supports CUDA"
        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else int(batch_size)
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

            from transformers import AutoConfig

            # No local engine exists in this mode, so `max_length` falls back
            # to reading sequence-length attributes from the HF config.
            self._config = AutoConfig.from_pretrained(
                pretrained, trust_remote_code=trust_remote_code, revision=revision
            )

        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            revision=tokenizer_revision,
        )
        self.tokenizer = configure_pad_token(self.tokenizer)
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
            eval_logger.info(
                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
            )

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

        if lora_local_path is not None:
            assert parse_version(version("vllm")) > parse_version("0.3.0"), (
                "lora adapters only compatible with vllm > v0.3.0."
            )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None

    @property
    def eot_token_id(self):
        """Token id treated as end-of-text (the tokenizer's EOS id)."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        """Token prepended for loglikelihood scoring.

        Priority: user-supplied ``prefix_token_id``, then BOS, then EOS.
        """
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum total sequence length (context + continuation) in tokens."""
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len
        else:
            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
            for attr in seqlen_config_attrs:
                if hasattr(self._config, attr):
                    return getattr(self._config, attr)
            if hasattr(self.tokenizer, "model_max_length"):
                # int(1e30): presumably the placeholder HF tokenizers report
                # when no real limit is configured, so fall back to the default.
                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                    return self._DEFAULT_MAX_LENGTH
                return self.tokenizer.model_max_length
            return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        """Default generation budget when a request does not set ``max_gen_toks``."""
        return self._max_gen_toks

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.

        When ``add_generation_prompt`` is False the final message is continued
        instead (``continue_final_message=True``). Returns the untokenized string.
        """
        chat_templated = self.tokenizer.apply_chat_template(
            chat_history,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

        return chat_templated

    @property
    def tokenizer_name(self) -> str:
        """Tokenizer path with ``/`` replaced by ``__`` (filesystem-safe)."""
        return self.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(
        self,
        string: Union[str, List[str]],
        left_truncate_len: Optional[int] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> Union[List[int], List[List[int]]]:
        """Tokenize one string or a batch of strings.

        :param string: a single string (returns ``List[int]``) or a list of
            strings (returns ``List[List[int]]``).
        :param left_truncate_len: if truthy, keep only the last this-many tokens
            of each encoding.
        :param add_special_tokens: when False, falls back to ``self.add_bos_token``.
        :param truncation: forwarded to the tokenizer.
        """
        if not add_special_tokens:
            add_special_tokens = self.add_bos_token
        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            return_attention_mask=False,
        ).input_ids

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            if not isinstance(string, str):
                encoding = [enc[-left_truncate_len:] for enc in encoding]
            else:
                encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run vLLM on pre-tokenized prompts.

        :param requests: list of token-id prompts.
        :param generate: True for free-form generation using ``max_tokens``,
            ``stop`` and ``kwargs``; False for scoring (greedy, one generated
            token, ``prompt_logprobs=1``, no detokenization).
        :return: the vLLM outputs, one per prompt, in input order.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: dict,
                sampling_params,
                requests: List[List[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests,
                    sampling_params=sampling_params,
                    lora_request=lora_request,
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = (
                (self.model_args, sampling_params, req, self.lora_request)
                for req in requests
            )
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)
            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
            ray.shutdown()
            # flatten results (undo the interleaving above)
            return undistribute(results)

        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
            lora_request=self.lora_request,
        )
        return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Total loglikelihood of each full string, scored in rolling windows.

        Every request is split into disjoint windows of at most
        ``max_length - 1`` predicted tokens; all windows from all requests are
        scored in batches and the per-window loglikelihoods are summed back
        per request.
        """
        # First, collect all windows from all requests
        all_windows = []  # List of (request_idx, window) tuples
        request_window_counts = []  # Track number of windows per request

        for req_idx, (string,) in enumerate(
            tqdm(
                [req.args for req in requests],
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        # max_seq_len - (1 for context)
                        max_seq_len=self.max_length - 1,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            windows = [(None,) + x for x in rolling_token_windows]

            # Store windows with their request index
            all_windows.extend((req_idx, window) for window in windows)
            request_window_counts.append(len(windows))

        if self.batch_size == "auto":
            batch_size = len(requests)
        else:
            batch_size = int(self.batch_size)
        # range() rejects a zero step: an empty request list under "auto" (or an
        # explicit batch_size of 0) previously raised ValueError here.
        batch_size = max(1, batch_size)

        all_nlls = []
        for i in range(0, len(all_windows), batch_size):
            batch = all_windows[i : i + batch_size]
            # Extract just the windows for processing, keeping track of request indices
            batch_indices, batch_windows = zip(*batch)

            batch_nlls = self._loglikelihood_tokens(
                requests=batch_windows,
                disable_tqdm=False,
            )
            # Store results with their request indices
            all_nlls.extend(zip(batch_indices, batch_nlls))

        # Reconstruct per-request loglikelihoods; windows were appended in
        # request order, so consecutive slices belong to consecutive requests.
        loglikelihoods = []
        current_idx = 0
        for request, window_count in zip(requests, request_window_counts):
            request_nlls = all_nlls[current_idx : current_idx + window_count]
            # Sum up the nlls for this request (discarding is_greedy)
            request_total = sum(nll[0] for _, nll in request_nlls)
            loglikelihoods.append(request_total)
            current_idx += window_count

            string = request.args[0]
            self.cache_hook.add_partial(
                "loglikelihood_rolling", (string,), request_total
            )

        return loglikelihoods

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate a continuation for each ``(context, gen_kwargs)`` request.

        Requests are grouped by identical ``gen_kwargs`` and sorted longest
        context first; results are returned in the original request order.

        :raises ValueError: if a request's ``gen_kwargs`` is not a dict.
        """
        # zip(*()) below cannot be unpacked into two names; nothing to do anyway.
        if not requests:
            return []

        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: List[List[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        # for each different set of kwargs, we execute all requests, by batch.
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                # add EOS token to stop sequences
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
            if max_ctx_len <= 0:
                # x[-0:] would keep the whole context and a negative bound would
                # drop a prefix instead; keep just the final context token.
                eval_logger.warning(
                    f"max_gen_toks ({max_gen_toks}) leaves no room for context within "
                    f"max_length ({self.max_length}); keeping only the last context token."
                )
                max_ctx_len = 1
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, ctx in zip(cont, context):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (ctx, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized ``(cache_key, context_enc, continuation_enc)`` triples.

        Inputs longer than ``max_length`` are left-truncated and the context
        length is reduced by the same amount. Returns ``(logprob_sum,
        is_greedy)`` per request, in the original order.
        """
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                # shrink ctxlen by however many tokens were cut from the left
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                if cache_key is not None:
                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
                    # all with cache key None. instead do add_partial on the per-example level
                    # in the loglikelihood_rolling() function for those.
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)
        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
        continuation_logprobs_dicts = outputs.prompt_logprobs

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in continuation_logprobs_dicts
        ]

        # Calculate continuation_logprobs
        # assume ctxlen always >= 1
        continuation_logprobs = sum(
            logprob_dict.get(token)
            for token, logprob_dict in zip(
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for token, logprob_dict in zip(
            tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
        ):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        """Translate HF-style generation kwargs into vLLM ``SamplingParams`` kwargs.

        Drops ``do_sample`` (setting ``temperature=0.0`` when it is False and no
        temperature was given) and defaults ``skip_special_tokens`` and
        ``spaces_between_special_tokens`` to False. Mutates and returns ``kwargs``.
        """
        # sampling_params
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False and "temperature" not in kwargs:
            eval_logger.debug(
                "Got `do_sample=False` and no temperature value, setting VLLM temperature to 0.0 ..."
            )
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs