File size: 10,990 Bytes
9d5b280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import copy
from typing import Dict, List, Optional

import transformers
from more_itertools import distribute
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    handle_stop_sequences,
    replace_placeholders,
    undistribute,
)
from lm_eval.models.vllm_causallms import VLLM
from lm_eval.utils import eval_logger


try:
    import ray
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest  # noqa: F401
    from vllm.transformers_utils.tokenizer import get_tokenizer  # noqa: F401
except ModuleNotFoundError:
    pass


DEFAULT_IMAGE_PLACEHOLDER = "<image>"


@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
    MULTIMODAL = True

    def __init__(
        self,
        pretrained: str,
        trust_remote_code: Optional[bool] = False,
        revision: Optional[str] = None,
        interleave: bool = True,
        # TODO<baber>: handle max_images and limit_mm_per_prompt better
        max_images: int = 999,
        **kwargs,
    ):
        """Set up the vLLM backend and load the matching HF processor.

        Args:
            pretrained: Model identifier, forwarded to ``VLLM.__init__`` and to
                ``transformers.AutoProcessor.from_pretrained``.
            trust_remote_code: Forwarded unchanged to both of the above.
            revision: Forwarded unchanged to both of the above.
            interleave: Stored on the instance; ``apply_chat_template`` reads
                it to choose between interleaved and images-first content.
            max_images: Per-prompt image cap. The default ``999`` acts as a
                "not set" sentinel: only other values are written into
                ``limit_mm_per_prompt``.
            **kwargs: Remaining options, passed through to ``VLLM.__init__``.
        """
        # Arguments that the vLLM engine and the HF processor must agree on.
        hub_options = {"revision": revision, "trust_remote_code": trust_remote_code}

        if max_images != 999:
            # An explicit cap was requested; this overwrites any
            # ``limit_mm_per_prompt`` the caller supplied through kwargs.
            kwargs["limit_mm_per_prompt"] = {"image": max_images}
            eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")

        super().__init__(pretrained=pretrained, **hub_options, **kwargs)

        self.interleave = interleave
        self.max_images = max_images
        self.processor = transformers.AutoProcessor.from_pretrained(
            pretrained, **hub_options
        )
        # Flipped to True by ``apply_chat_template``; consulted when prompts
        # are encoded to decide whether placeholders still need processing.
        self.chat_applied: bool = False

    def tok_batch_multimodal_encode(
        self,
        strings: List[str],  # note that input signature of this fn is different
        images,  # TODO: typehint on this
        left_truncate_len: int = None,
        truncation: bool = False,
    ):
        """Pair every prompt with its images as a request dict.

        Despite the name, no tokenization happens here: the method only
        assembles ``{"prompt": ..., "multi_modal_data": {"image": ...}}``
        dicts.

        Args:
            strings: Prompt texts, one per request.
            images: One sliceable collection of images per prompt; each is
                cut down to its first ``self.max_images`` entries.
            left_truncate_len: Accepted but not used by this method.
            truncation: Accepted but not used by this method.

        Returns:
            A list of request dicts, one per (prompt, images) pair. ``zip``
            is used, so the result is as long as the shorter input.
        """
        image_cap = self.max_images
        capped_images = [per_prompt[:image_cap] for per_prompt in images]

        prompts = strings
        # TODO<baber>: is the default placeholder always <image>?
        if self.chat_applied is False:
            # No chat template ran, so the raw prompts go through
            # ``replace_placeholders`` with the default placeholder as both
            # source and target (presumably this caps the placeholder count
            # at ``max_images`` -- helper semantics are not visible here).
            prompts = [
                replace_placeholders(
                    prompt,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    DEFAULT_IMAGE_PLACEHOLDER,
                    image_cap,
                )
                for prompt in strings
            ]

        return [
            {"prompt": prompt, "multi_modal_data": {"image": prompt_images}}
            for prompt, prompt_images in zip(prompts, capped_images)
        ]

    def _model_generate(
        self,
        requests: Optional[List[dict]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run the given requests through vLLM and return its raw outputs.

        Args:
            requests: Inputs handed to ``generate``. ``generate_until`` passes
                the prompt/image dicts built by ``tok_batch_multimodal_encode``.
            generate: When true, sample with ``max_tokens``, ``stop`` and the
                extra ``kwargs`` (after ``self.modify_gen_kwargs``). When
                false, a fixed scoring configuration is used instead
                (``temperature=0, prompt_logprobs=1, max_tokens=1``) and
                ``max_tokens``/``stop``/``kwargs`` are ignored.
            max_tokens: Generation budget, used only when ``generate`` is true.
            stop: Stop sequences, used only when ``generate`` is true.
            **kwargs: Additional ``SamplingParams`` arguments.

        Returns:
            Whatever ``generate`` returns. In the data-parallel path the
            per-worker results are recombined with ``undistribute``.

        Note:
            ``ray``, ``LLM`` and ``SamplingParams`` come from a guarded import
            at module level, so they are undefined if vLLM is not installed.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )

        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[dict]
            ):
                llm = LLM(**model_args)
                return llm.generate(requests, sampling_params=sampling_params)

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            try:
                object_refs = [run_inference_one_model.remote(*x) for x in inputs]
                results = ray.get(object_refs)
            finally:
                # Invoke ray.shutdown() to prevent hang-ups if subsequent calls
                # are required. Done in ``finally`` so a failed worker does not
                # leave the Ray runtime running.
                ray.shutdown()
            # flatten results
            return undistribute(results)

        generate_kwargs = {
            "sampling_params": sampling_params,
            "use_tqdm": self.batch_size == "auto",
        }
        # Only pass ``lora_request`` when one is configured, matching the
        # call made when no adapter is in use.
        if self.lora_request is not None:
            generate_kwargs["lora_request"] = self.lora_request
        return self.model.generate(requests, **generate_kwargs)

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """Render a chat history through the processor's chat template.

        Each message's string ``content`` is converted into a list of typed
        parts before templating. With ``self.interleave`` false, all image
        entries come first, followed by one text entry with the placeholders
        stripped. With ``self.interleave`` true, text and image entries follow
        the order of the placeholders in the text. At most ``self.max_images``
        image entries are emitted per message.

        The caller's ``chat_history`` is left untouched: converted messages
        are built as new dicts, so the same history can be passed again.

        Args:
            chat_history: Messages whose ``"content"`` is a string that may
                contain ``DEFAULT_IMAGE_PLACEHOLDER`` markers.
            add_generation_prompt: Forwarded to the processor; its negation is
                passed as ``continue_final_message``.

        Returns:
            The result of ``self.processor.apply_chat_template``.

        Raises:
            ValueError: In interleaved mode, if the number of image entries
                emitted differs from the expected (capped) placeholder count.
        """
        self.chat_applied = True

        structured_history = []
        for message in chat_history:
            text = message["content"]
            expected_image_count = min(
                self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
            )

            if not self.interleave:
                # Images first, then a single text entry without placeholders.
                parts = [
                    {"type": "image", "image": None}
                    for _ in range(expected_image_count)
                ]
                parts.append(
                    {"type": "text", "text": text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")}
                )
            else:
                parts = []
                actual_image_count = 0
                text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
                last_index = len(text_parts) - 1

                for i, part in enumerate(text_parts):
                    # TODO: concatenate text parts (esp. if skipping images)?
                    if part:  # Add non-empty text parts
                        parts.append({"type": "text", "text": part})
                    # An image follows every split except the last one, until
                    # the per-message cap is reached.
                    if i < last_index and i < self.max_images:
                        parts.append({"type": "image"})
                        actual_image_count += 1

                if actual_image_count != expected_image_count:
                    raise ValueError(
                        f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
                    )

            # Shallow-copy the message instead of rewriting it in place, so
            # the caller's history keeps its original string content.
            structured_history.append({**message, "content": parts})

        return self.processor.apply_chat_template(
            structured_history,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate a continuation for every text+image request.

        Requests are grouped by their generation kwargs and batched through
        ``Collator``. Each batch is turned into prompt/image inputs, run
        through ``_model_generate``, and the generated text is recorded in
        the cache hook. Results are returned in the original request order.

        Args:
            requests: Instances whose ``args`` unpack to
                ``(context, gen_kwargs, aux_arguments)``, where
                ``aux_arguments["visual"]`` holds the images.
            disable_tqdm: Hide the progress bar (also hidden when
                ``self.rank`` is not 0).

        Returns:
            Generated strings, one per request.

        Raises:
            ValueError: If a batch's generation kwargs are not a ``dict``.
        """
        # TODO: support text-only reqs
        res = []

        def _collate(req_args):
            # Sort key is the negated token count, i.e. longest context first:
            # - time estimates are then over- rather than under-estimates,
            #   which is more useful for planning
            # - the first element of a batch always has the padded context
            #   length, which simplifies batching and makes automatic
            #   adaptive batches much easier to implement
            # - any OOMs happen right away rather than near the end
            context = req_args[0]
            return -len(self.tok_encode(context)), context

        hide_progress = disable_tqdm or (self.rank != 0)
        pbar = tqdm(
            total=len(requests),
            disable=hide_progress,
            desc="Running generate_until requests with text+image input",
        )
        # TODO: port auto-batch sizing into this.

        # Requests are grouped by generation kwargs so that e.g. greedy
        # sampling and temp=0.8 sampling never share a batch.
        re_ords = Collator(
            [reg.args for reg in requests],
            _collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        eos = self.tokenizer.decode(self.eot_token_id)

        for chunk in chunks:
            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
            visuals = [aux["visual"] for aux in aux_arguments]

            # for Qwen2-VL, processor is unhappy accepting a tuple of strings
            # instead of a list.
            # TODO: could we upstream this workaround to HF?
            if not isinstance(contexts, list):
                contexts = list(contexts)

            # The grouping above guarantees every request in the batch shares
            # the same gen kwargs, so the first entry stands for all of them.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )

            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            # add EOS token to stop sequences
            until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)

            max_gen_toks = (
                kwargs.pop("max_gen_toks")
                if "max_gen_toks" in kwargs
                else self.max_gen_toks
            )
            max_ctx_len = self.max_length - max_gen_toks

            inputs = self.tok_batch_multimodal_encode(
                contexts,
                visuals,
                left_truncate_len=max_ctx_len,
            )
            cont = self._model_generate(
                inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
            )

            for output, context in zip(cont, contexts):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
                )
                pbar.update(1)

        # reorder the results back to the original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res