|
import copy
import gc
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import transformers |
|
from tqdm import tqdm |
|
from transformers import BatchEncoding |
|
|
|
from lm_eval import utils |
|
from lm_eval.api.instance import Instance |
|
from lm_eval.api.registry import register_model |
|
from lm_eval.models.huggingface import HFLM |
|
from lm_eval.models.utils import ( |
|
Collator, |
|
flatten_image_list, |
|
handle_stop_sequences, |
|
pad_and_concat, |
|
replace_placeholders, |
|
stop_sequences_criteria, |
|
) |
|
|
|
|
|
DEFAULT_IMAGE_PLACEHOLDER = "<image>" |
|
|
|
|
|
eval_logger = utils.eval_logger |
|
|
|
|
|
@register_model("hf-multimodal") |
|
class HFMultimodalLM(HFLM): |
|
""" |
|
An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics. |
|
""" |
|
|
|
AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq |
|
MULTIMODAL = True |
|
|
|
def __init__( |
|
self, |
|
pretrained: Union[str, transformers.PreTrainedModel], |
|
image_token_id: Optional[int] = None, |
|
image_string: Optional[str] = None, |
|
interleave: bool = True, |
|
|
|
max_images: Optional[int] = 999, |
|
        convert_img_format: bool = False,
|
**kwargs, |
|
): |
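        # Remaining kwargs (device, dtype, revision, etc.) are forwarded unchanged to
        # the text-only HFLM constructor.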
|
|
|
|
|
super().__init__(pretrained, **kwargs) |
|
|
|
assert self.batch_size != "auto", ( |
|
"Batch size 'auto' is not yet supported for hf-multimodal models." |
|
) |
|
self.chat_applied: bool = False |
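        # Multimodal-specific options: whether chat templates interleave images with
        # text, a cap on images per request, and optional forced RGB conversion.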
|
|
|
|
|
|
|
|
|
|
|
self.interleave = interleave |
|
self.max_images = max_images |
|
self.rgb = convert_img_format |
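        # Resolve the image placeholder token: either decode the model config's
        # image token id, or use the user-supplied `image_string` directly.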
|
|
|
if not image_string: |
|
self.image_token_id = ( |
|
int(image_token_id) |
|
if image_token_id |
|
else ( |
|
getattr(self.config, "image_token_id", None) |
|
or getattr(self.config, "image_token_index", None) |
|
) |
|
) |
|
assert self.image_token_id is not None, ( |
|
"Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one." |
|
) |
|
|
|
self.image_token = self.tok_decode( |
|
[self.image_token_id], skip_special_tokens=False |
|
) |
|
if image_token_id is not None: |
|
eval_logger.info( |
|
f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!" |
|
) |
|
else: |
|
eval_logger.info( |
|
f"A non-default image_token string with string value image_string='{image_string}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!" |
|
) |
|
self.image_token = image_string |
|
|
|
def _create_tokenizer( |
|
self, |
|
pretrained: Union[str, transformers.PreTrainedModel], |
|
tokenizer: Optional[ |
|
Union[ |
|
str, |
|
transformers.ProcessorMixin, |
|
] |
|
], |
|
revision: Optional[str] = "main", |
|
trust_remote_code: Optional[bool] = False, |
|
**kwargs, |
|
) -> None: |
|
""" |
|
Helper method during initialization. |
|
|
|
For the multimodal variant, we initialize not just |
|
`self.tokenizer` but also `self.processor`. |
|
""" |
|
|
|
if tokenizer: |
|
if isinstance(tokenizer, str): |
|
return transformers.AutoProcessor.from_pretrained( |
|
tokenizer, |
|
revision=revision, |
|
trust_remote_code=trust_remote_code, |
|
|
|
) |
|
else: |
|
assert isinstance( |
|
tokenizer, transformers.ProcessorMixin |
|
) |
|
return tokenizer |
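        # Otherwise, load an AutoProcessor from the same checkpoint as the model and
        # expose its underlying tokenizer as `self.tokenizer`.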
|
|
|
|
|
if isinstance(pretrained, str): |
|
model_name = pretrained |
|
else: |
|
|
|
model_name = self.model.name_or_path |
|
|
|
self.processor = transformers.AutoProcessor.from_pretrained( |
|
model_name, |
|
revision=revision, |
|
trust_remote_code=trust_remote_code, |
|
|
|
) |
|
|
|
self.tokenizer = self.processor.tokenizer |
|
|
|
def tok_multimodal_encode( |
|
self, string, images, left_truncate_len=None, add_special_tokens=None |
|
): |
|
"""Helper function which encodes an image + string combo using AutoProcessor""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoding = self.processor( |
|
text=string, images=images, return_tensors=None |
|
) |
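        # Split the tokenized text out of the processor output; the attention mask is
        # discarded since it is not used in the loglikelihood scoring path.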
|
|
|
|
|
text_encoding = encoding.pop("input_ids") |
|
encoding.pop("attention_mask") |
|
|
|
|
|
if left_truncate_len: |
|
text_encoding = text_encoding[-left_truncate_len:] |
|
|
|
return text_encoding, encoding |
|
|
|
def _encode_multimodal_pair(self, context, continuation, images): |
|
"""Helper function to perform the role of TemplateLM._encode_pair |
|
Except allowing for image input to also be processed alongside `context`. |
|
|
|
This method is a bit messy due to the need to defer conversion of image and text token input |
|
into PyTorch tensors until the main inference loop. |
|
""" |
|
|
|
n_spaces = len(context) - len(context.rstrip()) |
|
if n_spaces > 0: |
|
continuation = context[-n_spaces:] + continuation |
|
context = context[:-n_spaces] |
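        # Encode (context + continuation) and context alone; slicing off the context
        # prefix from the former yields the continuation tokens.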
|
|
|
|
|
|
|
whole_enc, image_enc = self.tok_multimodal_encode( |
|
context + continuation, images |
|
) |
|
context_enc, _ = self.tok_multimodal_encode(context, images) |
|
|
|
|
|
|
|
|
|
whole_enc, context_enc = whole_enc[0], context_enc[0] |
|
|
|
context_enc_len = len(context_enc) |
|
continuation_enc = whole_enc[context_enc_len:] |
|
|
|
return context_enc, continuation_enc, image_enc |
|
|
|
def apply_chat_template( |
|
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True |
|
) -> str: |
|
self.chat_applied = True |
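        # Non-interleaved mode: strip all <image> placeholders from the text and
        # prepend one image entry per (capped) placeholder instead.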
|
if not self.interleave: |
|
for content in chat_history: |
|
c = [] |
|
text = content["content"] |
|
|
|
|
|
image_count = min( |
|
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER) |
|
) |
|
text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "") |
|
|
|
|
|
for _ in range(image_count): |
|
c.append({"type": "image", "image": None}) |
|
|
|
|
|
c.append({"type": "text", "text": text}) |
|
|
|
content["content"] = c |
|
else: |
|
for content in chat_history: |
|
c = [] |
|
text = content["content"] |
|
expected_image_count = min( |
|
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER) |
|
) |
|
actual_image_count = 0 |
|
|
|
text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER) |
|
|
|
for i, part in enumerate(text_parts): |
|
|
|
if part: |
|
c.append({"type": "text", "text": part}) |
|
                    if i < len(text_parts) - 1 and i < self.max_images:
|
c.append({"type": "image"}) |
|
actual_image_count += 1 |
|
|
|
content["content"] = c |
|
|
|
if actual_image_count != expected_image_count: |
|
raise ValueError( |
|
f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}" |
|
) |
|
|
|
return self.processor.apply_chat_template( |
|
chat_history, |
|
add_generation_prompt=add_generation_prompt, |
|
continue_final_message=not add_generation_prompt, |
|
) |
|
|
|
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: |
|
if hasattr(self.processor, "apply_chat_template"): |
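            # Temporarily point `self.tokenizer` at the processor so the parent class
            # resolves the chat template from the processor, then restore it.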
|
_tokenizer = self.tokenizer |
|
self.tokenizer = self.processor |
|
|
|
selected_template = super().chat_template(chat_template) |
|
|
|
self.tokenizer = _tokenizer |
|
return selected_template |
|
else: |
|
return super().chat_template(chat_template) |
|
|
|
def tok_batch_multimodal_encode( |
|
self, |
|
strings: List[str], |
|
images: List[List], |
|
padding_side: str = "left", |
|
        left_truncate_len: Optional[int] = None,
|
truncation: bool = False, |
|
) -> Union[ |
|
BatchEncoding, Dict[str, torch.Tensor] |
|
]: |
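        # If no chat template was applied, swap the generic <image> placeholder for
        # this model's own image token (up to `max_images` occurrences per prompt).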
|
|
|
if not self.chat_applied: |
|
|
|
strings = [ |
|
replace_placeholders( |
|
string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images |
|
) |
|
for string in strings |
|
] |
|
|
|
|
|
old_padding_side = self.tokenizer.padding_side |
|
self.tokenizer.padding_side = padding_side |
|
|
|
|
|
|
|
images = [img[: self.max_images] for img in images] |
|
if self.rgb: |
|
images = [[img.convert("RGB") for img in sublist] for sublist in images] |
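        # Certain processors (the llava model type here) expect a flat list of images
        # rather than one list per example.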
|
|
|
|
|
if getattr(self.config, "model_type", "") == "llava": |
|
images = flatten_image_list(images) |
|
|
|
encoding = self.processor( |
|
images=images, |
|
text=strings, |
|
truncation=truncation, |
|
padding="longest", |
|
return_tensors="pt", |
|
|
|
) |
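        # Move inputs to the model's device; the dtype argument only casts the
        # floating-point image features (input_ids stay integer-typed).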
|
|
|
encoding.to( |
|
self.device, self.model.dtype |
|
) |
|
if left_truncate_len: |
|
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] |
|
encoding["attention_mask"] = encoding["attention_mask"][ |
|
:, -left_truncate_len: |
|
] |
|
self.tokenizer.padding_side = old_padding_side |
|
|
|
return encoding |
|
|
|
def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None): |
|
""" |
|
TODO: update docstring |
|
""" |
|
|
|
with torch.no_grad(): |
|
return self.model(inps, **imgs).logits |
|
|
|
def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs): |
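        # Mirror HFLM's generation defaults: temperature=0.0 with unspecified
        # do_sample is treated as greedy decoding, and the redundant temperature
        # argument is dropped before calling generate().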
|
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) |
|
do_sample = generation_kwargs.get("do_sample", None) |
|
|
|
|
|
if generation_kwargs.get("temperature") == 0.0 and do_sample is None: |
|
generation_kwargs["do_sample"] = do_sample = False |
|
|
|
if do_sample is False and generation_kwargs.get("temperature") == 0.0: |
|
generation_kwargs.pop("temperature") |
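        # Build stopping criteria from the requested stop sequences so that
        # generation halts once any stop string has been produced.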
|
|
|
stopping_criteria = stop_sequences_criteria( |
|
self.tokenizer, |
|
stop, |
|
inputs["input_ids"].shape[1], |
|
inputs["input_ids"].shape[0], |
|
) |
|
return self.model.generate( |
|
**inputs, |
|
max_length=max_length, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=self.tokenizer.pad_token_id, |
|
use_cache=True, |
|
**generation_kwargs, |
|
) |
|
|
|
def _batch_images(self, image_encs): |
|
""" |
|
Helper function: batch together image encodings across examples in a batch. |
|
# TODO: for variable-sized images, this may break down. |
|
""" |
|
batched_imgs = {} |
|
for key in image_encs[0].keys(): |
|
batched_imgs[key] = torch.cat( |
|
[ |
|
torch.tensor( |
|
image_enc[key], device=self.device, dtype=self.model.dtype |
|
) |
|
for image_enc in image_encs |
|
], |
|
dim=0, |
|
) |
|
return batched_imgs |
|
|
|
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: |
|
        raise NotImplementedError(
            "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks; "
            "this is because we do not support measuring the loglikelihood a model assigns to an image."
        )
|
|
|
def loglikelihood( |
|
self, requests: List[Instance], disable_tqdm: bool = False |
|
) -> List[Tuple[float, bool]]: |
|
raise NotImplementedError( |
|
"'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!" |
|
) |
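        # NOTE: the code below is currently unreachable; it is kept as the intended
        # implementation for when multimodal loglikelihood support is enabled.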
|
|
|
new_reqs = [] |
|
for context, continuation, aux_arguments in [req.args for req in requests]: |
|
if context == "": |
|
raise ValueError( |
|
"Must get non-empty context for multimodal requests! You might be trying to run 'loglikelihood_rolling', which is not supported in the multimodal case." |
|
) |
|
else: |
|
visuals = aux_arguments["visual"] |
|
|
|
context_enc, continuation_enc, image_enc = self._encode_multimodal_pair( |
|
context, continuation, visuals |
|
) |
|
|
|
new_reqs.append( |
|
( |
|
(context, continuation, visuals), |
|
context_enc, |
|
continuation_enc, |
|
image_enc, |
|
) |
|
) |
|
|
|
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) |
|
|
|
def _loglikelihood_tokens( |
|
self, |
|
requests: List[ |
|
Tuple[Tuple[None, str, str], List[int], List[int], List[int]] |
|
], |
|
disable_tqdm: bool = False, |
|
        override_bs: Optional[int] = None,
|
) -> List[Tuple[float, bool]]: |
|
res = [] |
|
|
|
|
|
def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): |
|
"""Defines the key for the sorted method""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
toks = req[1] + req[2] |
|
return -len(toks), tuple(toks) |
|
|
|
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): |
|
"""Defines the key to group and lookup one-token continuations""" |
|
|
|
|
|
|
|
|
|
return req[-1] + req[-3] + req[-2][:-1] |
|
|
|
re_ord = Collator( |
|
requests, |
|
sort_fn=_collate, |
|
group_by="contexts" |
|
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM |
|
and self.logits_cache |
|
else None, |
|
group_fn=_lookup_one_token_cont, |
|
) |
|
|
|
|
|
|
|
n_reordered_requests = len(re_ord) |
|
batch_size = ( |
|
self.batch_size |
|
if self.batch_size != "auto" |
|
else override_bs |
|
if override_bs is not None |
|
else 0 |
|
) |
|
batch_fn = ( |
|
self._batch_scheduler |
|
if self.batch_size == "auto" |
|
and n_reordered_requests > 0 |
|
and not override_bs |
|
else None |
|
) |
|
|
|
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) |
|
pbar = tqdm( |
|
total=len(requests), |
|
disable=(disable_tqdm or (self.rank != 0)), |
|
desc="Running loglikelihood requests with text+image input", |
|
) |
|
for chunk in chunks: |
|
imgs = [] |
|
inps = [] |
|
cont_toks_list = [] |
|
inplens = [] |
|
|
|
padding_len_inp = None |
|
|
|
|
|
|
|
|
|
for _, context_enc, continuation_enc, image_enc in chunk: |
|
|
|
assert len(image_enc) > 0 |
|
assert len(context_enc) > 0 |
|
assert len(continuation_enc) > 0 |
|
assert len(continuation_enc) <= self.max_length |
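                # Build the model input: concatenate context and continuation tokens,
                # keep at most the last max_length + 1 of them, and drop the final
                # token (it is only ever a prediction target, never an input).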
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inp = torch.tensor( |
|
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], |
|
dtype=torch.long, |
|
device=self.device, |
|
) |
|
(inplen,) = inp.shape |
|
|
|
padding_len_inp = ( |
|
max(padding_len_inp, inplen) |
|
if padding_len_inp is not None |
|
else inplen |
|
) |
|
|
|
inps.append(inp) |
|
cont_toks_list.append(continuation_enc) |
|
inplens.append(inplen) |
|
|
|
imgs.append(image_enc) |
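            # Right-pad the text inputs to a shared length, batch the image
            # encodings, and score the whole chunk with one forward pass.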
|
|
|
|
|
call_kwargs = {} |
|
batched_inps = pad_and_concat( |
|
padding_len_inp, inps, padding_side="right" |
|
) |
|
|
|
batched_imgs = self._batch_images( |
|
imgs |
|
) |
|
|
|
multi_logits = F.log_softmax( |
|
self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs), |
|
dim=-1, |
|
) |
|
|
|
for ( |
|
request_str, |
|
ctx_tokens, |
|
_, |
|
image_encs, |
|
), logits, inplen, cont_toks in zip( |
|
chunk, multi_logits, inplens, cont_toks_list |
|
): |
|
|
|
contlen = len(cont_toks) |
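                # Keep only the logits aligned with the continuation tokens,
                # discarding context positions and right-padding.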
|
|
|
|
|
|
|
|
|
ctx_len = ( |
|
inplen + (logits.shape[0] - padding_len_inp) |
|
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM |
|
else None |
|
) |
|
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) |
|
logits = logits.unsqueeze(0) |
|
|
|
|
|
greedy_tokens = logits.argmax(dim=-1) |
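                # Check the one-token-continuation cache: when several requests share
                # this context, get_cache yields each of them against these logits;
                # otherwise it yields just the original request.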
|
|
|
|
|
|
|
|
|
|
|
|
|
for request_str, cont_toks, logits in re_ord.get_cache( |
|
req_str=request_str, |
|
cxt_toks=ctx_tokens, |
|
cont_toks=cont_toks, |
|
logits=logits, |
|
): |
|
cont_toks = torch.tensor( |
|
cont_toks, dtype=torch.long, device=self.device |
|
).unsqueeze(0) |
|
max_equal = (greedy_tokens == cont_toks).all() |
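                    # Gather the log-probability of each target continuation token;
                    # their sum is the continuation's log-likelihood.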
|
|
|
|
|
|
|
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( |
|
-1 |
|
) |
|
|
|
|
|
answer = (float(logits.sum()), bool(max_equal)) |
|
|
|
res.append(answer) |
|
|
|
self.cache_hook.add_partial( |
|
"loglikelihood", request_str, answer |
|
) |
|
pbar.update(1) |
|
|
|
pbar.close() |
|
|
|
return re_ord.get_original(res) |
|
|
|
def generate_until( |
|
self, requests: List[Instance], disable_tqdm: bool = False |
|
) -> List[str]: |
|
|
|
res = [] |
|
|
|
def _collate(x): |
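            # Sort by descending tokenized context length so the largest batches are
            # processed first and memory issues show up immediately.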
|
|
|
|
|
|
|
|
|
|
|
|
|
toks = self.tok_encode(x[0]) |
|
return -len(toks), x[0] |
|
|
|
pbar = tqdm( |
|
total=len(requests), |
|
disable=(disable_tqdm or (self.rank != 0)), |
|
desc="Running generate_until requests with text+image input", |
|
) |
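        # Group requests by their generation kwargs so that requests with different
        # sampling settings or stop sequences are never mixed in a single batch.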
|
|
|
|
|
|
|
|
|
|
|
re_ords = Collator( |
|
[reg.args for reg in requests], |
|
_collate, |
|
group_by="gen_kwargs", |
|
group_fn=lambda x: x[1], |
|
) |
|
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) |
|
|
|
|
|
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) |
|
for chunk in chunks: |
|
contexts, all_gen_kwargs, aux_arguments = zip(*chunk) |
|
|
|
visuals = [arg["visual"] for arg in aux_arguments] |
|
|
|
            if not isinstance(contexts, list):
                contexts = list(contexts)
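            # All requests in a chunk share the same generation kwargs (guaranteed by
            # the gen_kwargs grouping above), so read them from the first request.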
|
|
|
|
|
|
|
|
|
|
|
gen_kwargs = all_gen_kwargs[0] |
|
|
|
if isinstance(gen_kwargs, dict): |
|
kwargs = copy.deepcopy(gen_kwargs) |
|
|
|
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) |
|
else: |
|
raise ValueError( |
|
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" |
|
) |
|
if "max_gen_toks" in kwargs.keys(): |
|
max_gen_toks = kwargs.pop("max_gen_toks") |
|
else: |
|
max_gen_toks = self.max_gen_toks |
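            # Leave room for generation: the encoded prompt may occupy at most
            # max_length - max_gen_toks tokens and is left-truncated if longer.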
|
|
|
|
|
|
|
max_ctx_len = self.max_length - max_gen_toks |
|
|
|
inputs = self.tok_batch_multimodal_encode( |
|
contexts, |
|
visuals, |
|
left_truncate_len=max_ctx_len, |
|
truncation=self.truncation, |
|
) |
|
|
|
context_enc = inputs["input_ids"] |
|
|
|
if "max_length" not in kwargs: |
|
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks |
|
|
|
cont = self._model_multimodal_generate(inputs, stop=until, **kwargs) |
|
|
|
del inputs |
|
torch.cuda.empty_cache() |
|
            gc.collect()
|
|
|
|
|
|
|
cont_toks_list = cont.tolist() |
|
for cont_toks, context in zip(cont_toks_list, contexts): |
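                # Strip the prompt (and any left padding) so that only newly
                # generated tokens remain.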
|
|
|
cont_toks = cont_toks[context_enc.shape[1] :] |
|
|
|
s = self.tok_decode(cont_toks) |
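                # Post-hoc truncation: cut the decoded text at the first occurrence of
                # any stop sequence that generation itself did not catch.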
|
|
|
|
|
for term in until: |
|
if len(term) > 0: |
|
|
|
|
|
s = s.split(term)[0] |
|
|
|
res.append(s) |
|
self.cache_hook.add_partial( |
|
"generate_until", (context, gen_kwargs), s |
|
) |
|
pbar.update(1) |
|
|
|
res = re_ords.get_original(res) |
|
|
|
pbar.close() |
|
return res |
|
|