# R-help-chat / pipeline.py
# Commit 7e18a82 (jedick): Enable FlashAttention

from transformers.pipelines.text_generation import Chat
from transformers import TextGenerationPipeline
from typing import Dict
class MyTextGenerationPipeline(TextGenerationPipeline):
"""
This subclass overrides the preprocess method to add pad_to_multiple_of=8 to tokenizer_kwargs.
Fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
https://github.com/google-deepmind/gemma/issues/169
NOTE: we also need padding="longest", which is set during class instantiation
"""
    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=None,
        truncation=None,
        padding=None,
        max_length=None,
        continue_final_message=None,
        **generate_kwargs,
    ):
        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
        tokenizer_kwargs = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,
            # Added relative to the upstream pipeline: pad to a multiple of 8 so the
            # padded sequence length satisfies the attention kernel's alignment
            # requirement (see the class docstring)
            "pad_to_multiple_of": 8,
        }
        tokenizer_kwargs = {
            key: value for key, value in tokenizer_kwargs.items() if value is not None
        }
        if isinstance(prompt_text, Chat):
            tokenizer_kwargs.pop(
                "add_special_tokens", None
            )  # ignore add_special_tokens on chats
            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
            # because very few models support multiple separate, consecutive assistant messages
            if continue_final_message is None:
                continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                add_generation_prompt=not continue_final_message,
                continue_final_message=continue_final_message,
                return_dict=True,
                return_tensors=self.framework,
                **tokenizer_kwargs,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs
            )

        inputs["prompt_text"] = prompt_text
        # "hole" mode: if the prompt plus the requested new tokens would not fit in
        # the model's maximum length, truncate the prompt from the left to make room
        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = (
                    generate_kwargs.get("max_length", self.generation_config.max_length)
                    - cur_len
                )
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation: the number of"
                        " desired tokens exceeds the model's max length"
                    )
                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][
                        :, -keep_length:
                    ]

        return inputs
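

# --- Example usage (not part of the original file) ---
# A minimal sketch of how this subclass could be instantiated. The model name,
# dtype, and FlashAttention settings below are assumptions, not taken from this
# repository; adjust them for your setup. The key point is padding="longest",
# which together with pad_to_multiple_of=8 above keeps padded sequence lengths
# aligned, as required when FlashAttention is enabled.
if __name__ == "__main__":
    import torch
    from transformers import pipeline

    chat_pipeline = pipeline(
        "text-generation",
        model="google/gemma-2-2b-it",  # assumed model; any chat-capable causal LM works
        torch_dtype=torch.bfloat16,
        model_kwargs={"attn_implementation": "flash_attention_2"},  # requires flash-attn
        pipeline_class=MyTextGenerationPipeline,  # use the subclass defined above
        padding="longest",  # forwarded to preprocess(); see the class docstring
    )

    messages = [{"role": "user", "content": "What does lapply() do in R?"}]
    print(chat_pipeline(messages, max_new_tokens=64)[0]["generated_text"])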