import copy
from typing import Dict, List, Optional

import transformers
from more_itertools import distribute
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    handle_stop_sequences,
    replace_placeholders,
    undistribute,
)
from lm_eval.models.vllm_causallms import VLLM
from lm_eval.utils import eval_logger

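# vLLM and Ray are optional dependencies: a missing install is swallowed here and
# only surfaces once this backend is actually instantiated.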
try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
    pass

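# Placeholder string that tasks embed in prompt text to mark where images belong.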
DEFAULT_IMAGE_PLACEHOLDER = "<image>"


@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
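    """vLLM-backed vision-language model adapter.

    Extends the text-only ``VLLM`` wrapper with multimodal prompt handling:
    ``<image>`` placeholders in each context are paired with the request's
    visuals and passed to vLLM as ``multi_modal_data``.
    """
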
    MULTIMODAL = True

    def __init__(
        self,
        pretrained: str,
        trust_remote_code: Optional[bool] = False,
        revision: Optional[str] = None,
        interleave: bool = True,
        max_images: int = 999,
        **kwargs,
    ):
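        """Build the vLLM engine plus a Hugging Face processor for chat templating.

        Args:
            pretrained: model name or path, forwarded to both vLLM and AutoProcessor.
            interleave: if True, chat templates interleave image entries with the
                surrounding text; if False, all images are grouped before the text.
            max_images: maximum number of images per prompt; also sets vLLM's
                ``limit_mm_per_prompt`` when different from the default of 999.
        """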
        if max_images != 999:
            kwargs["limit_mm_per_prompt"] = {"image": max_images}
            eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
        super().__init__(
            pretrained=pretrained,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
        self.interleave = interleave
        self.max_images = max_images
        self.processor = transformers.AutoProcessor.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        self.chat_applied: bool = False

    def tok_batch_multimodal_encode(
        self,
        strings: List[str],
        images,
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ):
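        """Pair each context string with its images as vLLM prompt dicts.

        Note: ``left_truncate_len`` and ``truncation`` are accepted here but not
        applied; vLLM receives the untruncated prompt strings.
        """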
        # Keep at most max_images images per request.
        images = [img[: self.max_images] for img in images]

        if self.chat_applied is False:
            # Cap the number of <image> placeholders in each raw string at max_images;
            # this is skipped when a chat template has already structured the images.
            strings = [
                replace_placeholders(
                    string,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    self.max_images,
                )
                for string in strings
            ]

        outputs = []
        for x, i in zip(strings, images):
            inputs = {
                "prompt": x,
                "multi_modal_data": {"image": i},
            }
            outputs.append(inputs)
        return outputs

    def _model_generate(
        self,
        requests: Optional[List[List[dict]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
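        """Run vLLM generation for a batch of multimodal prompt dicts.

        With ``generate=True`` the user-supplied sampling parameters are used;
        otherwise greedy parameters with ``prompt_logprobs`` are set, as needed
        for loglikelihood-style scoring. When ``data_parallel_size > 1`` the
        batch is sharded across Ray workers, each with its own vLLM engine.
        """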
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
        if self.data_parallel_size > 1:
            # Each Ray worker builds its own vLLM engine from the same model args
            # and generates completions for its shard of the requests.
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[dict]]
            ):
                llm = LLM(**model_args)
                return llm.generate(requests, sampling_params=sampling_params)

            # Shard requests across workers in interleaved fashion to balance
            # context lengths across shards.
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
            results = ray.get(object_refs)

            ray.shutdown()

            # undistribute() interleaves the per-worker results back into the
            # original request order.
            return undistribute(results)

        if self.lora_request is not None:
            outputs = self.model.generate(
                requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
                lora_request=self.lora_request,
            )
        else:
            outputs = self.model.generate(
                requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
            )
        return outputs

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
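        """Render a chat history with the HF processor's chat template.

        ``<image>`` placeholders in each message are converted into structured
        image entries: interleaved with the surrounding text when
        ``self.interleave`` is True, otherwise grouped before the text.
        """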
        self.chat_applied = True
        if not self.interleave:
            for content in chat_history:
                c = []
                text = content["content"]

                # Count and strip image placeholders.
                image_count = min(
                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
                )
                text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")

                # Add image entries first, then a single text entry.
                for _ in range(image_count):
                    c.append({"type": "image", "image": None})
                c.append({"type": "text", "text": text})

                content["content"] = c
        else:
            for content in chat_history:
                c = []
                text = content["content"]
                expected_image_count = min(
                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
                )
                actual_image_count = 0

                text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)

                for i, part in enumerate(text_parts):
                    if part:
                        c.append({"type": "text", "text": part})
                    # Insert an image entry after every split except the last,
                    # up to max_images.
                    if (i < len(text_parts) - 1) and i < self.max_images:
                        c.append({"type": "image"})
                        actual_image_count += 1

                content["content"] = c

                if actual_image_count != expected_image_count:
                    raise ValueError(
                        f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
                    )

        return self.processor.apply_chat_template(
            chat_history,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
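        """Generate free-form continuations for text+image requests.

        Requests are sorted by descending context length, grouped by generation
        kwargs, batched through vLLM, then returned in the original order.
        """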
        res = []

        def _collate(x):
            # Sort by descending token length (hence the negative sign) so the
            # longest contexts come first: batches are formed from similar-length
            # prompts and any out-of-memory failure shows up early in the run.
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests with text+image input",
        )

        # Group requests by their generation kwargs so that, e.g., greedy decoding
        # and sampled decoding are never mixed in the same batch.
        re_ords = Collator(
            [reg.args for reg in requests],
            _collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        eos = self.tokenizer.decode(self.eot_token_id)
        for chunk in chunks:
            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

            visuals = [arg["visual"] for arg in aux_arguments]

            if not isinstance(contexts, list):
                # zip() yields tuples; downstream processing expects a list of strings.
                contexts = list(contexts)

            # All gen_kwargs in a chunk are identical because the Collator groups
            # by them, so the first entry is representative of the whole batch.
            gen_kwargs = all_gen_kwargs[0]

            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)
                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            max_ctx_len = self.max_length - max_gen_toks

            inputs = self.tok_batch_multimodal_encode(
                contexts,
                visuals,
                left_truncate_len=max_ctx_len,
            )

            cont = self._model_generate(
                inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
            )

            for output, context in zip(cont, contexts):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)

        # Reorder results back to the original (unsorted) request order.
        res = re_ords.get_original(res)

        pbar.close()
        return res