from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import Chat


class MyTextGenerationPipeline(TextGenerationPipeline):
    """
    This subclass overrides the preprocess method to add pad_to_multiple_of=8 to tokenizer_kwargs.
    Fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
    https://github.com/google-deepmind/gemma/issues/169
    NOTE: we also need padding="longest", which is set during class instantiation
    """

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=None,
        truncation=None,
        padding=None,
        max_length=None,
        continue_final_message=None,
        **generate_kwargs,
    ):
        # Forward only the non-None tokenizer kwargs so the tokenizer's defaults still apply.
        # pad_to_multiple_of=8 is always set; it is the whole point of this subclass
        # (see the class docstring).
        tokenizer_kwargs = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,
            "pad_to_multiple_of": 8,
        }
        tokenizer_kwargs = {
            key: value for key, value in tokenizer_kwargs.items() if value is not None
        }

        if isinstance(prompt_text, Chat):
            tokenizer_kwargs.pop(
                "add_special_tokens", None
            )  # ignore add_special_tokens on chats
            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
            # because very few models support multiple separate, consecutive assistant messages
            if continue_final_message is None:
                continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                add_generation_prompt=not continue_final_message,
                continue_final_message=continue_final_message,
                return_dict=True,
                return_tensors=self.framework,
                **tokenizer_kwargs,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs
            )

        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = (
                    generate_kwargs.get("max_length", self.generation_config.max_length)
                    - cur_len
                )
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation: the number of desired tokens exceeds"
                        " the model's max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][
                        :, -keep_length:
                    ]

        return inputs
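

# Example usage: a minimal sketch, not part of the original file. The checkpoint
# name, prompt, and generation settings are illustrative assumptions. As the class
# docstring notes, padding="longest" is passed at instantiation so that the
# pad_to_multiple_of=8 set in preprocess() actually pads the inputs.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "google/gemma-2b-it"  # assumed checkpoint; any causal LM should work
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    pipe = MyTextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
        padding="longest",  # required alongside pad_to_multiple_of=8 (see docstring)
    )
    out = pipe("Write a one-line summary of attention padding.", max_new_tokens=32)
    print(out[0]["generated_text"])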