File size: 10,990 Bytes
9d5b280 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
import copy
from typing import Dict, List, Optional
import transformers
from more_itertools import distribute
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
Collator,
handle_stop_sequences,
replace_placeholders,
undistribute,
)
from lm_eval.models.vllm_causallms import VLLM
from lm_eval.utils import eval_logger
try:
import ray
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest # noqa: F401
from vllm.transformers_utils.tokenizer import get_tokenizer # noqa: F401
except ModuleNotFoundError:
pass
# Marker string searched for in prompt/chat text to locate image positions;
# the class below counts, splits on, and strips this token.
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
    """vLLM-backed vision-language model, registered as ``vllm-vlm``.

    Extends :class:`VLLM` with image inputs: prompts are handed to vLLM as
    ``{"prompt": ..., "multi_modal_data": {"image": ...}}`` dictionaries and
    chat templating is delegated to a Hugging Face ``AutoProcessor``.
    """

    MULTIMODAL = True

    def __init__(
        self,
        pretrained: str,
        trust_remote_code: Optional[bool] = False,
        revision: Optional[str] = None,
        interleave: bool = True,
        # TODO<baber>: handle max_images and limit_mm_per_prompt better
        max_images: int = 999,
        **kwargs,
    ):
        """Build the vLLM engine (via the parent class) and load the processor.

        Args:
            pretrained: Model identifier, forwarded both to the parent
                constructor and to ``AutoProcessor.from_pretrained``.
            trust_remote_code: Forwarded to the parent constructor and to the
                processor loader.
            revision: Forwarded to the parent constructor and to the
                processor loader.
            interleave: When true, ``apply_chat_template`` keeps image entries
                at the positions of their placeholders; otherwise all image
                entries are placed before a single text entry.
            max_images: Upper bound on images kept per prompt. ``999`` acts as
                the "not set" sentinel: any other value is also forwarded to
                the parent as ``limit_mm_per_prompt={"image": max_images}``,
                replacing any ``limit_mm_per_prompt`` given in ``kwargs``.
            **kwargs: Remaining arguments for the parent constructor.
        """
        if max_images != 999:
            kwargs["limit_mm_per_prompt"] = {"image": max_images}
            eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
        super().__init__(
            pretrained=pretrained,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
        self.interleave = interleave
        self.max_images = max_images
        self.processor = transformers.AutoProcessor.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        # Flipped to True by apply_chat_template; tok_batch_multimodal_encode
        # only rewrites placeholders while this is still False.
        self.chat_applied: bool = False

    def tok_batch_multimodal_encode(
        self,
        strings: List[str],  # note that input signature of this fn is different
        images,  # TODO: typehint on this
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ):
        """Pair each prompt with its images in the dict form passed to vLLM.

        Despite the name, no tokenization happens here.

        Args:
            strings: Prompt strings, one per request.
            images: One sliceable collection of images per prompt; each is
                cut down to the first ``self.max_images`` entries.
            left_truncate_len: Accepted for signature compatibility;
                currently unused.
            truncation: Accepted for signature compatibility; currently
                unused.

        Returns:
            A list of ``{"prompt": str, "multi_modal_data": {"image": imgs}}``
            dictionaries, one per (prompt, images) pair.
        """
        capped_images = [img[: self.max_images] for img in images]
        # TODO<baber>: is the default placeholder always <image>?
        if not self.chat_applied:
            # NOTE(review): presumably limits the number of placeholders kept
            # in each prompt to max_images -- confirm against
            # lm_eval.models.utils.replace_placeholders.
            strings = [
                replace_placeholders(
                    string,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    self.max_images,
                )
                for string in strings
            ]
        return [
            {"prompt": prompt, "multi_modal_data": {"image": prompt_images}}
            for prompt, prompt_images in zip(strings, capped_images)
        ]

    def _model_generate(
        self,
        requests: List[List[dict]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run the prepared multimodal requests through vLLM.

        Args:
            requests: Prompt dictionaries as produced by
                ``tok_batch_multimodal_encode``.
            generate: When true, sample using ``max_tokens``, ``stop`` and
                ``kwargs`` (after ``self.modify_gen_kwargs``). Otherwise use
                fixed parameters requesting prompt logprobs and one token.
            max_tokens: Generation budget used when ``generate`` is true.
            stop: Stop sequences used when ``generate`` is true.
            **kwargs: Extra sampling arguments used when ``generate`` is true.

        Returns:
            The vLLM outputs; in data-parallel mode, the per-worker results
            merged by ``undistribute``.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )

        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[dict]]
            ):
                llm = LLM(**model_args)
                return llm.generate(requests, sampling_params=sampling_params)

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            try:
                object_refs = [run_inference_one_model.remote(*x) for x in inputs]
                results = ray.get(object_refs)
            finally:
                # Invoke ray.shutdown() to prevent hang-ups if subsequent calls
                # required. Done in `finally` so a failing worker does not
                # leave Ray running.
                ray.shutdown()
            # flatten results
            return undistribute(results)

        # Only pass lora_request when one is configured, matching the call
        # shape used when no adapter is set.
        extra_generate_args = {}
        if self.lora_request is not None:
            extra_generate_args["lora_request"] = self.lora_request
        return self.model.generate(
            requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
            **extra_generate_args,
        )

    def _images_first_content(self, text: str) -> List[dict]:
        """Return image entries followed by one text entry with placeholders removed."""
        image_count = min(self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER))
        parts: List[dict] = [
            {"type": "image", "image": None} for _ in range(image_count)
        ]
        parts.append(
            {"type": "text", "text": text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")}
        )
        return parts

    def _interleaved_content(self, text: str) -> List[dict]:
        """Return text and image entries in the order they occur in ``text``."""
        expected_image_count = min(
            self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
        )
        actual_image_count = 0
        parts: List[dict] = []
        text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
        last_index = len(text_parts) - 1
        for i, part in enumerate(text_parts):
            # TODO: concatenate text parts (esp. if skipping images)?
            if part:  # Add non-empty text parts
                parts.append({"type": "text", "text": part})
            # Add image placeholder after each split except the last, and
            # only for the first max_images placeholders.
            if i < last_index and i < self.max_images:
                parts.append({"type": "image"})
                actual_image_count += 1
        if actual_image_count != expected_image_count:
            raise ValueError(
                f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
            )
        return parts

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """Render a chat history containing image placeholders with the processor.

        Each message's string ``content`` is converted into a list of
        ``{"type": "text"}`` / ``{"type": "image"}`` entries and the result is
        passed to ``self.processor.apply_chat_template``. The caller's
        ``chat_history`` and its message dicts are left unmodified, so the
        same history can be templated more than once.

        Args:
            chat_history: Messages whose ``"content"`` is a string that may
                contain ``DEFAULT_IMAGE_PLACEHOLDER`` markers.
            add_generation_prompt: Forwarded to the processor; its negation is
                passed as ``continue_final_message``.

        Returns:
            Whatever the processor's ``apply_chat_template`` returns.

        Raises:
            ValueError: In interleaved mode, if the number of emitted image
                entries differs from the expected count.
        """
        self.chat_applied = True
        if self.interleave:
            build_content = self._interleaved_content
        else:
            build_content = self._images_first_content
        # Shallow-copy each message so the caller's dicts keep their string
        # content instead of being overwritten with the structured form.
        structured_history = [
            {**message, "content": build_content(message["content"])}
            for message in chat_history
        ]
        return self.processor.apply_chat_template(
            structured_history,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate a continuation for every text+image request.

        Each request's ``args`` is unpacked as ``(context, gen_kwargs,
        aux_arguments)``, with the images read from
        ``aux_arguments["visual"]``.

        Args:
            requests: Instances to generate for.
            disable_tqdm: Disable the progress bar (it is also disabled when
                ``self.rank != 0``).

        Returns:
            Generated strings, restored to the original request order.

        Raises:
            ValueError: If a batch's generation kwargs are not a ``dict``.
        """
        # TODO: support text-only reqs
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests with text+image input",
        )
        # TODO: port auto-batch sizing into this.

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(
            [req.args for req in requests],
            _collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        eos = self.tokenizer.decode(self.eot_token_id)

        for chunk in chunks:
            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
            visuals = [arg["visual"] for arg in aux_arguments]
            # zip() yields tuples; for Qwen2-VL, processor is unhappy accepting
            # a tuple of strings instead of a list.
            # TODO: could we upstream this workaround to HF?
            contexts = list(contexts)

            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # unpack our keyword arguments.
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            # add EOS token to stop sequences
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            max_ctx_len = self.max_length - max_gen_toks

            inputs = self.tok_batch_multimodal_encode(
                contexts,
                visuals,
                left_truncate_len=max_ctx_len,
            )

            cont = self._model_generate(
                inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
            )

            for output, context in zip(cont, contexts):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res
|