OSUM-EChat / patches /modelling_qwen2_infer_gpu.py
xlgeng's picture
开始部署
841f290
import math
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import transformers.models
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2RotaryEmbedding,
Qwen2ForCausalLM,
Qwen2MLP,
Qwen2RMSNorm,
apply_rotary_pos_emb,
repeat_kv,
_prepare_4d_causal_attention_mask_with_cache_position,
)
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache, StaticCache, SlidingWindowCache
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from .utils import InferTaskCode
logger = logging.get_logger(__name__)
_GPU_QWEN_TORCH_COMPILE = True
# ===================================================================
# =============================Attention=============================
# ===================================================================
class GPUQwen2Attention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = Qwen2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
# Adapted from Qwen2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# NOTE: RoPE return all embedding (to satisfy torch compile)
cos, sin = self.rotary_emb(value_states, seq_len=past_key_value.get_max_length())
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : past_key_value.get_max_length()]
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# ===================================================================
# =============================Layer=================================
# ===================================================================
class GPUQwen2DecoderLayer(nn.Module):
def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
if config.sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
self.self_attn = GPUQwen2Attention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
# ===================================================================
# ========================Qwen2ForCausalLM===========================
# ===================================================================
class InferQwen2ForCausalLM(Qwen2ForCausalLM):
def __init__(self, config):
super().__init__(config)
self.compile_forward = torch.compile(self.simplify_forward, dynamic=False, fullgraph=True) \
if _GPU_QWEN_TORCH_COMPILE else self.simplify_forward
self.text_phase = True
'''
NOTE: 重写原Qwen2ForCausalLM forward函数,torchair直接编译原函数在返回CausalLMOutputWithPast时会出现编译错误
'''
def simplify_forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
def forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
do_compile = True
) -> Union[Tuple, CausalLMOutputWithPast]:
if past_key_values is not None:
past_key_values.training = False
# print(self.text_phase)
if input_ids is not None:
if self.text_phase:
inputs_embeds = self.model.embed_tokens(input_ids)
else:
inputs_embeds = self.speech_token_emded(input_ids)
if torch.isin(input_ids, 151645).any():
self.text_phase = False
input_ids = None
if (inputs_embeds is not None and cache_position[0] == 0) or do_compile==False :
# prefill branch
outputs = self.simplify_forward(input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position)
else:
# decoding
outputs = self.compile_forward(input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position)
last_hidden_states = outputs.last_hidden_state
if self.text_phase:
logits = self.lm_head(last_hidden_states)
else:
logits = self.speech_head(last_hidden_states)
logits = logits.float()
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
"""
Mainly add static cache support
"""
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`,
# as otherwise the input `position_ids` would have various stride during the decoding.
# Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case,
# `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# NOTE: 与上述的position_ids相同,same as position_ids, for torch.compile and cuda graph
input_ids = input_ids.clone(memory_format=torch.contiguous_format)
model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None and cache_position[0] == 0:
# prefill phase, inputs_embeds has shape (B,S,H)
batch_size, sequence_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
device = inputs_embeds.device
else:
# decdoing phase, input_ids has shape (B,S)
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
if inputs_embeds is not None and inputs_embeds.ndim == 2 or input_ids is not None and input_ids.size(-1) == 1:
# we only expand attention mask in docoding mode
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"do_compile": kwargs['do_compile'],
}
)
return model_inputs
# ===================================================================
print("========================= DO Qwen2 PATCH ===========================")
# ===================================================================
transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel._supports_static_cache = True # enable static cache
transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer = GPUQwen2DecoderLayer
transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM = InferQwen2ForCausalLM