import copy from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F import transformers from tqdm import tqdm from transformers import BatchEncoding from lm_eval import utils from lm_eval.api.instance import Instance from lm_eval.api.registry import register_model from lm_eval.models.huggingface import HFLM from lm_eval.models.utils import ( Collator, flatten_image_list, handle_stop_sequences, pad_and_concat, replace_placeholders, stop_sequences_criteria, ) DEFAULT_IMAGE_PLACEHOLDER = "" eval_logger = utils.eval_logger @register_model("hf-multimodal") class HFMultimodalLM(HFLM): """ An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics. """ AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq MULTIMODAL = True # flag to indicate, for now, that this model type can run multimodal requests def __init__( self, pretrained: Union[str, transformers.PreTrainedModel], image_token_id: Optional[int] = None, image_string: Optional[str] = None, interleave: bool = True, # TODO: handle whitespace in image placeholder (replacement) max_images: Optional[int] = 999, convert_img_format=False, **kwargs, ): # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer # modify init behavior. super().__init__(pretrained, **kwargs) assert self.batch_size != "auto", ( "Batch size 'auto' is not yet supported for hf-multimodal models." ) self.chat_applied: bool = False # TODO: phi-3.5 "image placeholders" are , , ... in order. how to handle this case # HF AutoModelForVision2Seq models have an `image_token_id` value in their configs # denoting the token which indicates a location where an image will be substituted in. # This can take different string values across models, e.g. for Idefics2 and <|image_pad|> for Qwen2-VL self.interleave = interleave self.max_images = max_images self.rgb = convert_img_format # WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors! if not image_string: self.image_token_id = ( int(image_token_id) if image_token_id else ( getattr(self.config, "image_token_id", None) or getattr(self.config, "image_token_index", None) ) ) assert self.image_token_id is not None, ( "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one." ) # get the string this token ID corresponds to self.image_token = self.tok_decode( [self.image_token_id], skip_special_tokens=False ) if image_token_id is not None: eval_logger.info( f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!" ) else: eval_logger.info( f"A non-default image_token string with string value image_string='{image_string}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!" ) self.image_token = image_string def _create_tokenizer( self, pretrained: Union[str, transformers.PreTrainedModel], tokenizer: Optional[ Union[ str, transformers.ProcessorMixin, ] ], revision: Optional[str] = "main", trust_remote_code: Optional[bool] = False, **kwargs, ) -> None: """ Helper method during initialization. For the multimodal variant, we initialize not just `self.tokenizer` but also `self.processor`. """ if tokenizer: if isinstance(tokenizer, str): return transformers.AutoProcessor.from_pretrained( tokenizer, revision=revision, trust_remote_code=trust_remote_code, # use_fast=use_fast_tokenizer, ) else: assert isinstance( tokenizer, transformers.ProcessorMixin ) # TODO: check this condition return tokenizer # Get tokenizer based on 'pretrained' if isinstance(pretrained, str): model_name = pretrained else: # get the HF hub name via accessor on model model_name = self.model.name_or_path self.processor = transformers.AutoProcessor.from_pretrained( model_name, revision=revision, trust_remote_code=trust_remote_code, # use_fast=use_fast_tokenizer, ) self.tokenizer = self.processor.tokenizer def tok_multimodal_encode( self, string, images, left_truncate_len=None, add_special_tokens=None ): """Helper function which encodes an image + string combo using AutoProcessor""" # We inherit special token kwarg setup from HFLM.tok_encode # special_tokens_kwargs = {} # by default for CausalLM - false or self.add_bos_token is set # if add_special_tokens is None: # special_tokens_kwargs = {"add_special_tokens": False or self.add_bos_token} # otherwise the method explicitly defines the value # else: # special_tokens_kwargs = {"add_special_tokens": add_special_tokens} # encode text+images # TODO: why does (Qwen2-VL) processor error when attempting to add special tokens to text? encoding = self.processor( text=string, images=images, return_tensors=None ) # , **special_tokens_kwargs) # remove (and store) our tokenized text text_encoding = encoding.pop("input_ids") encoding.pop("attention_mask") # left-truncate the encoded context to be at most `left_truncate_len` tokens long if left_truncate_len: text_encoding = text_encoding[-left_truncate_len:] return text_encoding, encoding # image_encoding is a dict def _encode_multimodal_pair(self, context, continuation, images): """Helper function to perform the role of TemplateLM._encode_pair Except allowing for image input to also be processed alongside `context`. This method is a bit messy due to the need to defer conversion of image and text token input into PyTorch tensors until the main inference loop. """ n_spaces = len(context) - len(context.rstrip()) if n_spaces > 0: continuation = context[-n_spaces:] + continuation context = context[:-n_spaces] # TODO: replace default placeholder with self.image_token, for contexts whole_enc, image_enc = self.tok_multimodal_encode( context + continuation, images ) context_enc, _ = self.tok_multimodal_encode(context, images) # tok_multimodal_encode returns List[List[int]] for tokenized text. Get rid of the batch dim # since we only are encoding a single string. # TODO: this is a bit hacky, it'd be nice to make this generally cleaner whole_enc, context_enc = whole_enc[0], context_enc[0] context_enc_len = len(context_enc) continuation_enc = whole_enc[context_enc_len:] return context_enc, continuation_enc, image_enc def apply_chat_template( self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True ) -> str: self.chat_applied = True if not self.interleave: for content in chat_history: c = [] text = content["content"] # Count and remove image placeholders image_count = min( self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER) ) text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "") # Add image entries for _ in range(image_count): c.append({"type": "image", "image": None}) # Add single text entry at the end c.append({"type": "text", "text": text}) content["content"] = c else: for content in chat_history: c = [] text = content["content"] expected_image_count = min( self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER) ) actual_image_count = 0 text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER) for i, part in enumerate(text_parts): # TODO: concatenate text parts (esp. if skipping images)? if part: # Add non-empty text parts c.append({"type": "text", "text": part}) if ( (i < len(text_parts) - 1) and i < self.max_images ): # Add image placeholder after each split except the last c.append({"type": "image"}) actual_image_count += 1 content["content"] = c if actual_image_count != expected_image_count: raise ValueError( f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}" ) return self.processor.apply_chat_template( chat_history, add_generation_prompt=add_generation_prompt, continue_final_message=not add_generation_prompt, ) def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: if hasattr(self.processor, "apply_chat_template"): _tokenizer = self.tokenizer self.tokenizer = self.processor selected_template = super().chat_template(chat_template) self.tokenizer = _tokenizer return selected_template else: return super().chat_template(chat_template) def tok_batch_multimodal_encode( self, strings: List[str], # note that input signature of this fn is different images: List[List], # TODO: images are pil.Image at the moment, update typehint padding_side: str = "left", left_truncate_len: int = None, truncation: bool = False, ) -> Union[ BatchEncoding, Dict[str, torch.Tensor] ]: # note that this return signature differs from HFLM tok_batch_encode. # NOTE: here, we replace tags with our model's corresponding image_token string value. if not self.chat_applied: # TODO: This still keeps the whitespace in the image placeholder, which is not ideal. strings = [ replace_placeholders( string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images ) for string in strings ] # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. old_padding_side = self.tokenizer.padding_side self.tokenizer.padding_side = padding_side # add_special_tokens = {"add_special_tokens": False or self.add_bos_token} images = [img[: self.max_images] for img in images] if self.rgb: images = [[img.convert("RGB") for img in sublist] for sublist in images] # certain models like llava expect a single-level image list even for bs>1, multi-image. TODO: port this over to loglikelihoods if getattr(self.config, "model_type", "") == "llava": images = flatten_image_list(images) encoding = self.processor( images=images, text=strings, truncation=truncation, padding="longest", return_tensors="pt", # **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added? ) encoding.to( # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention self.device, self.model.dtype ) # TODO: This only casts the pixel values. Should they always be float16? if left_truncate_len: encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] encoding["attention_mask"] = encoding["attention_mask"][ :, -left_truncate_len: ] self.tokenizer.padding_side = old_padding_side return encoding def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None): """ TODO: update docstring """ # note: imgs is a dict. with torch.no_grad(): return self.model(inps, **imgs).logits def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs): generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) do_sample = generation_kwargs.get("do_sample", None) # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies if generation_kwargs.get("temperature") == 0.0 and do_sample is None: generation_kwargs["do_sample"] = do_sample = False if do_sample is False and generation_kwargs.get("temperature") == 0.0: generation_kwargs.pop("temperature") stopping_criteria = stop_sequences_criteria( self.tokenizer, stop, inputs["input_ids"].shape[1], inputs["input_ids"].shape[0], ) return self.model.generate( **inputs, max_length=max_length, stopping_criteria=stopping_criteria, pad_token_id=self.tokenizer.pad_token_id, use_cache=True, **generation_kwargs, ) def _batch_images(self, image_encs): """ Helper function: batch together image encodings across examples in a batch. # TODO: for variable-sized images, this may break down. """ batched_imgs = {} for key in image_encs[0].keys(): batched_imgs[key] = torch.cat( [ torch.tensor( image_enc[key], device=self.device, dtype=self.model.dtype ) for image_enc in image_encs ], dim=0, ) return batched_imgs def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: raise NotImplementedError( "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ", "this is because we do not support measuring the loglikelihood a model assigns to an image.", ) def loglikelihood( self, requests: List[Instance], disable_tqdm: bool = False ) -> List[Tuple[float, bool]]: raise NotImplementedError( "'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!" ) new_reqs = [] for context, continuation, aux_arguments in [req.args for req in requests]: if context == "": raise ValueError( "Must get non-empty context for multimodal requests! You might be trying to run 'loglikelihood_rolling', which is not supported in the multimodal case." ) else: visuals = aux_arguments["visual"] context_enc, continuation_enc, image_enc = self._encode_multimodal_pair( context, continuation, visuals ) # TODO: key to pick for caching images new_reqs.append( ( (context, continuation, visuals), context_enc, continuation_enc, image_enc, ) ) return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) def _loglikelihood_tokens( self, requests: List[ Tuple[Tuple[None, str, str], List[int], List[int], List[int]] ], # TODO: update typehint to be correct disable_tqdm: bool = False, override_bs: int = None, ) -> List[Tuple[float, bool]]: res = [] # TODO: **improve multimodal collation.** We currently ignore image size when ordering docs. ideally we'd take them into account def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): """Defines the key for the sorted method""" # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning # - to know the size of a batch when going through the list, you know the first one is always the batch # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end toks = req[1] + req[2] return -len(toks), tuple(toks) def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): """Defines the key to group and lookup one-token continuations""" # Use with group_by="contexts" (optional)" # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. # speeds up some multiple-choice tasks proportionally to the number of choices. # groups requests by context+continuation[:-1] and infer on one request/group. return req[-1] + req[-3] + req[-2][:-1] re_ord = Collator( requests, sort_fn=_collate, group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM and self.logits_cache else None, group_fn=_lookup_one_token_cont, ) # automatic (variable) batch size detection for vectorization # pull longest context sample from request n_reordered_requests = len(re_ord) batch_size = ( self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0 ) batch_fn = ( self._batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs else None ) chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) pbar = tqdm( total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running loglikelihood requests with text+image input", ) for chunk in chunks: imgs = [] inps = [] cont_toks_list = [] inplens = [] padding_len_inp = None # because vectorizing is annoying, we first convert each (context, continuation) pair to padded # tensors, then we pack them together into a batch, call the model, and then pick it all apart # again because vectorizing is annoying for _, context_enc, continuation_enc, image_enc in chunk: # sanity check assert len(image_enc) > 0 assert len(context_enc) > 0 assert len(continuation_enc) > 0 assert len(continuation_enc) <= self.max_length # how this all works (illustrated on a causal decoder-only setup): # CTX CONT # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] # model \ \ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice # when too long to fit in context, truncate from the left # TODO: assuming that we won't handle enc-dec Vision2Seq models. Is that a safe assumption? inp = torch.tensor( (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], dtype=torch.long, device=self.device, ) (inplen,) = inp.shape padding_len_inp = ( max(padding_len_inp, inplen) if padding_len_inp is not None else inplen ) inps.append(inp) # [1, inp_length] cont_toks_list.append(continuation_enc) inplens.append(inplen) imgs.append(image_enc) # create encoder attn mask and batched conts, if seq2seq call_kwargs = {} batched_inps = pad_and_concat( padding_len_inp, inps, padding_side="right" ) # [batch, padding_len_inp] # batch our examples' image inputs together batched_imgs = self._batch_images( imgs ) # TODO: fix/test for bs>1 case with differently-sized imgs! multi_logits = F.log_softmax( self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs), dim=-1, ) # [batch, padding_length (inp or cont), vocab] for ( request_str, ctx_tokens, _, image_encs, ), logits, inplen, cont_toks in zip( chunk, multi_logits, inplens, cont_toks_list ): # Slice to original seq length contlen = len(cont_toks) # take only logits in the continuation # (discard context toks if decoder-only ; discard right-padding) # also discards + checks for "virtual tokens" in the causal LM's input window # from prompt/prefix tuning tokens, if applicable ctx_len = ( inplen + (logits.shape[0] - padding_len_inp) if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None ) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = logits.unsqueeze(0) # [1, seq, vocab] # Check if per-token argmax is exactly equal to continuation greedy_tokens = logits.argmax(dim=-1) # check for one-token continuation cache hits. # noop in case group_by != "contexts" or no cache hit and returns the # original args. Otherwise, expands the logits batch dimension and yields each # batch along with matching continuation tokens and prompt strings. # logits -> [1, seq, vocab] for request_str, cont_toks, logits in re_ord.get_cache( req_str=request_str, cxt_toks=ctx_tokens, cont_toks=cont_toks, logits=logits, ): cont_toks = torch.tensor( cont_toks, dtype=torch.long, device=self.device ).unsqueeze(0) # [1, seq] max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( -1 ) # [1, seq] # Answer: (log prob, is-exact-match) answer = (float(logits.sum()), bool(max_equal)) res.append(answer) self.cache_hook.add_partial( "loglikelihood", request_str, answer ) # TODO: choose convention for adding images into the cache key pbar.update(1) pbar.close() return re_ord.get_original(res) def generate_until( self, requests: List[Instance], disable_tqdm: bool = False ) -> List[str]: # TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs) res = [] def _collate(x): # the negative sign on len(toks) sorts descending - this has a few advantages: # - time estimates will always be over not underestimates, which is more useful for planning # - to know the size of a batch when going through the list, you know the first one is always the batch # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end toks = self.tok_encode(x[0]) return -len(toks), x[0] pbar = tqdm( total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running generate_until requests with text+image input", ) # TODO: port auto-batch sizing into this. # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # in the same batch. re_ords = Collator( [reg.args for reg in requests], _collate, group_by="gen_kwargs", group_fn=lambda x: x[1], ) chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) ### Up to here: was identical to non-multimodal HFLM generate_until ### eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) for chunk in chunks: contexts, all_gen_kwargs, aux_arguments = zip(*chunk) visuals = [arg["visual"] for arg in aux_arguments] if not isinstance(contexts, list): contexts = list( contexts ) # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list. # TODO: could we upstream this workaround to HF? ### this part onward: same as HFLM ### # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. gen_kwargs = all_gen_kwargs[0] # unpack our keyword arguments. if isinstance(gen_kwargs, dict): kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 # add EOS token to stop sequences until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) else: raise ValueError( f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" ) if "max_gen_toks" in kwargs.keys(): max_gen_toks = kwargs.pop("max_gen_toks") else: max_gen_toks = self.max_gen_toks ### end stuff that's entirely copied verbatim from HFLM ### max_ctx_len = self.max_length - max_gen_toks inputs = self.tok_batch_multimodal_encode( contexts, visuals, left_truncate_len=max_ctx_len, truncation=self.truncation, ) context_enc = inputs["input_ids"] if "max_length" not in kwargs: kwargs["max_length"] = context_enc.shape[1] + max_gen_toks cont = self._model_multimodal_generate(inputs, stop=until, **kwargs) del inputs torch.cuda.empty_cache() import gc gc.collect() ### essentially same as HFLM beyond this line! cont_toks_list = cont.tolist() for cont_toks, context in zip(cont_toks_list, contexts): # discard context + left-padding toks if using causal decoder-only VLM cont_toks = cont_toks[context_enc.shape[1] :] s = self.tok_decode(cont_toks) # use secondary stop seqs to cut off should-have-been-stopped content post-hoc for term in until: if len(term) > 0: # ignore '' separator, # for seq2seq case where self.tok_decode(self.eot_token_id) = '' s = s.split(term)[0] res.append(s) self.cache_hook.add_partial( "generate_until", (context, gen_kwargs), s ) # TODO: cache key for multimodal input should be what? pbar.update(1) # reorder this group of results back to original unsorted form res = re_ords.get_original(res) pbar.close() return res