# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This is a fully self-contained version of the model script.
# It includes the MDMGenerationMixin and all necessary utilities for public release.

import copy
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch import nn
from torch.nn import functional as F

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation.configuration_utils import GenerationConfig
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    ModelOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.processing_utils import Unpack
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

# Use the `transformers` logger: it provides `warning_once`, which the code below
# relies on and which the stdlib `logging.getLogger` does not offer.
logger = logging.get_logger(__name__)


# ==============================================================================
# Start of Generation Utilities (Integrated directly into this file)
# ==============================================================================


def top_p_logits(logits, top_p=None):
    """Keep the smallest set of logits whose cumulative probability exceeds `top_p`."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token that crosses the threshold is still kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits


def top_k_logits(logits, top_k=None):
    """Mask every logit that falls below the k-th largest value."""
    if top_k is None or top_k == 0:
        return logits
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits
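# Illustrative sketch (not executed on import): the two filters above compose, and
# `sample_tokens` below applies temperature first, then top-p, then top-k. The
# shapes and vocab size here are made up for the example.
#
#   logits = torch.randn(8, 151936)                        # [positions, vocab]
#   filtered = top_k_logits(top_p_logits(logits, 0.9), 50)
#   probs = torch.softmax(filtered.float(), dim=-1)        # filtered entries ~ 0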
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample token proposals and a per-position confidence used to rank unmasking candidates."""
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits.float(), dim=-1)

    if temperature > 0:
        x0 = dists.Categorical(probs=probs).sample()
    else:
        _, x0 = probs.max(dim=-1)

    # Default confidence: the probability of the sampled token.
    confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)

    if margin_confidence:
        # Margin between the two most likely tokens.
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        top1_probs = sorted_probs[..., 0]
        top2_probs = sorted_probs[..., 1]
        confidence = top1_probs - top2_probs
    elif neg_entropy:
        # Negative entropy: peakier distributions score higher.
        log_probs = torch.log(probs.clamp(min=1e-10))
        confidence = (probs * log_probs).sum(dim=-1)

    return confidence, x0


@dataclass
class MDMModelOutput(ModelOutput):
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None


class MDMGenerationConfig(GenerationConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        self.alg: str = kwargs.pop("alg", "entropy")
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", 0.0)
        self.output_history: bool = kwargs.pop("output_history", False)
        self.mask_token_id = kwargs.pop("mask_token_id", None)


class MDMGenerationMixin:
    """
    Mixin class for Masked Diffusion Model generation.
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.LongTensor], Optional[torch.LongTensor]]:
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], **kwargs
    ) -> MDMGenerationConfig:
        if generation_config is None:
            generation_config = self.generation_config
        if not isinstance(generation_config, MDMGenerationConfig):
            generation_config = MDMGenerationConfig.from_dict(generation_config.to_dict())
        generation_config.update(**kwargs)
        return generation_config

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[MDMGenerationConfig] = None,
        **kwargs,
    ) -> Union[MDMModelOutput, torch.LongTensor]:
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        input_ids = inputs
        attention_mask = kwargs.get("attention_mask", None)
        if input_ids is None:
            raise ValueError("`inputs` must be provided for diffusion generation.")
        if generation_config.max_new_tokens is not None:
            generation_config.max_length = input_ids.shape[-1] + generation_config.max_new_tokens
        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return self._sample(input_ids, attention_mask=attention_mask, generation_config=generation_config)
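    # Denoising schedule sketch (illustrative numbers, not from the source): with
    # `steps` timesteps linearly spaced from 1 down to `eps`, each step reveals a
    # fraction of the still-masked tokens:
    #
    #   gamma = 1 - s / t      # t = current timestep, s = next timestep
    #
    # e.g. steps=4, eps=1e-3 gives t roughly 1.00 -> 0.75 -> 0.50 -> 0.25 -> eps,
    # so the first step unmasks about 1 - 0.75/1.00 = 25% of the masked positions.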
    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: MDMGenerationConfig,
    ) -> Union[MDMModelOutput, torch.LongTensor]:
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        if mask_token_id is None:
            raise ValueError("`mask_token_id` must be set in the generation config.")
        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k

        histories = [] if generation_config.output_history else None

        # Pad the prompt to `max_length` with mask tokens; generation fills them in.
        # Note: the padding mask is recomputed from `pad_token_id` here; the
        # `attention_mask` argument is kept for interface parity.
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
        gen_attention_mask = (x != self.config.pad_token_id).long() if self.config.pad_token_id is not None else None

        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        for i in range(steps):
            mask_index = x == mask_token_id
            if not mask_index.any():
                break

            outputs = self(input_ids=x, attention_mask=gen_attention_mask, is_causal=False)
            logits = outputs.logits
            # The logit at position i predicts token i + 1; shift right so logits
            # align with the positions they fill in.
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            mask_logits = logits[mask_index]

            t = timesteps[i]
            s = timesteps[i + 1]

            # 'topk_margin' ranks candidates by the top-1/top-2 margin, 'entropy'
            # by negative entropy, and 'maskgit_plus' by the sampled token's
            # probability. (The two flags are mutually exclusive; setting both
            # would make the entropy branch unreachable.)
            is_margin_conf = alg == "topk_margin"
            is_neg_entropy = alg == "entropy"
            confidence, x0 = sample_tokens(
                mask_logits,
                temperature,
                top_p,
                top_k,
                margin_confidence=is_margin_conf,
                neg_entropy=is_neg_entropy,
            )

            num_masked = mask_index.sum(dim=-1, keepdim=True)
            gamma = 1 - s / t
            num_to_unmask = (num_masked * gamma).long()
            if i == steps - 1:
                # Final step: reveal everything that is still masked so no mask
                # tokens survive in the output.
                num_to_unmask = num_masked

            full_confidence = torch.full_like(x, -torch.inf, dtype=confidence.dtype)
            full_confidence[mask_index] = confidence

            # `topk`/`multinomial` need a Python int; keep k >= 1 and let the
            # per-row trim below enforce each row's true budget.
            k = max(int(num_to_unmask.max().item()), 1)
            if alg_temp is not None and alg_temp > 0:
                unmask_probs = F.softmax(full_confidence / alg_temp, dim=-1)
                unmask_indices = torch.multinomial(unmask_probs, num_samples=k, replacement=False)
            else:
                _, unmask_indices = torch.topk(full_confidence, k=k, dim=-1)

            rows = torch.arange(x.size(0), device=x.device).unsqueeze(1)
            unmask_selection_mask = torch.zeros_like(x, dtype=torch.bool)
            unmask_selection_mask[rows, unmask_indices] = True
            # Rows may need fewer than `k` tokens; trim the selection per row.
            unmask_selection_mask = unmask_selection_mask & (
                torch.cumsum(unmask_selection_mask.long(), dim=-1) <= num_to_unmask
            )

            x_unmasked_proposals = torch.full_like(x, fill_value=mask_token_id)
            x_unmasked_proposals[mask_index] = x0
            x[unmask_selection_mask] = x_unmasked_proposals[unmask_selection_mask]

            if histories is not None:
                histories.append(x.clone())

        if generation_config.return_dict_in_generate:
            return MDMModelOutput(sequences=x, history=histories)
        else:
            return x


# ==============================================================================
# End of Generation Utilities
# ==============================================================================


_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
_CONFIG_FOR_DOC = "Qwen2Config"


class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU-style gating: down(act(gate(x)) * up(x)).
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies rotary position embeddings to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
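# Illustrative RoPE sketch (shapes assumed, not from the source): `cos`/`sin`
# come from Qwen2RotaryEmbedding below as [batch, seq, head_dim], while q/k are
# [batch, num_heads, seq, head_dim]; `unsqueeze_dim=1` broadcasts over heads.
#
#   q = torch.randn(1, 28, 16, 128)    # hypothetical Qwen2-7B head layout
#   k = torch.randn(1, 4, 16, 128)
#   cos = torch.randn(1, 16, 128)
#   sin = torch.randn(1, 16, 128)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # same shapes out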
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads so grouped-query attention matches the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Qwen2Attention(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: bool = True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        hidden_shape = (bsz, q_len, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get(self.config._attn_implementation, None)
        if attention_interface is None:
            raise ValueError(f"Attention implementation {self.config._attn_implementation} not found.")
        if self.config._attn_implementation == "sdpa" and output_attentions:
            logger.warning_once("Using SDPA with `output_attentions=True` requires eager attention.")
            attention_interface = ALL_ATTENTION_FUNCTIONS["eager"]

        # Record the per-call causality flag where some shared kernels (e.g. flash
        # attention) expect to find it on the module.
        self.is_causal = is_causal

        # The shared implementations in `ALL_ATTENTION_FUNCTIONS` take the module as
        # their first argument and an explicit `scaling`, and return the output
        # already transposed back to [batch, seq, heads, head_dim]. When a float
        # mask is supplied it already encodes causality, so `is_causal` is only
        # forwarded when no mask is given.
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            is_causal=is_causal if attention_mask is None else False,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
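# Illustrative GQA sketch (hypothetical sizes): with 28 query heads and 4 KV
# heads, each KV head is shared by num_key_value_groups = 28 // 4 = 7 query
# heads; `repeat_kv` expands the KV tensors to match before the dot product.
#
#   kv = torch.randn(1, 4, 16, 128)     # [batch, kv_heads, seq, head_dim]
#   expanded = repeat_kv(kv, 7)         # -> [1, 28, 16, 128]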
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
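# RMSNorm computes y = w * x / sqrt(mean(x^2) + eps) in float32 (no mean
# subtraction or bias, unlike LayerNorm). Tiny sketch with made-up numbers:
#
#   norm = Qwen2RMSNorm(hidden_size=4)
#   y = norm(torch.tensor([[2.0, 2.0, 2.0, 2.0]]))   # mean(x^2)=4 -> y ~ all ones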
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: bool = True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Pre-norm residual block: attention, then MLP.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            is_causal=is_causal,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs


class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, config: Qwen2Config, device=None):
        super().__init__()
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len
        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Qwen2PreTrainedModel(PreTrainedModel):
    config_class = Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
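# Note on masking: for diffusion generation the mixin calls the model with
# `is_causal=False`, so `_update_causal_mask` below returns the (optional)
# padding mask unchanged and every token attends bidirectionally. Minimal
# sketch of the two modes (hypothetical call sites):
#
#   model(input_ids, is_causal=True)    # autoregressive: lower-triangular mask
#   model(input_ids, is_causal=False)   # diffusion: full bidirectional attention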
class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: bool = True,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_key_values_length = 0
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            past_key_values_length = past_key_values.get_seq_length()

        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, is_causal)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                is_causal=is_causal,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, is_causal):
        # Bidirectional (diffusion) mode: pass any padding mask through untouched.
        if not is_causal:
            return attention_mask
        seq_len = input_tensor.shape[1]
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        dtype = input_tensor.dtype
        device = input_tensor.device
        causal_mask = torch.triu(torch.full((seq_len, seq_len), torch.finfo(dtype).min, device=device), 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            # Merge the 0/1 padding mask: padded key positions get the minimum
            # value so softmax ignores them (adding the raw 0/1 mask would bias
            # the attention logits instead).
            causal_mask = causal_mask.clone()
            causal_mask = causal_mask.masked_fill(attention_mask[:, None, None, :].eq(0), torch.finfo(dtype).min)
        return causal_mask


class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model
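    # This head serves two roles: the standard causal-LM `forward` below (with the
    # usual shift-by-one loss when `labels` are given: with labels [t0, t1, t2, t3],
    # the logit at position i is trained to predict label i + 1, so logits[:, :-1]
    # pairs with labels[:, 1:]), and the bidirectional denoising passes issued by
    # `MDMGenerationMixin.diffusion_generate`, which call it with `is_causal=False`.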
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_causal: bool = True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Returns:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            is_causal=is_causal,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Standard causal-LM loss: shift so position i predicts token i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


ModelClass = Qwen2ForCausalLM

__all__ = ["Qwen2ForCausalLM", "Qwen2Model", "Qwen2PreTrainedModel", "MDMGenerationMixin"]
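

# ==============================================================================
# Minimal usage sketch (illustrative only; the checkpoint path and the mask token
# source below are placeholders, not values shipped with this file):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
#   model = Qwen2ForCausalLM.from_pretrained("path/to/checkpoint")
#   inputs = tokenizer("The capital of France is", return_tensors="pt")
#   out = model.diffusion_generate(
#       inputs.input_ids,
#       max_new_tokens=32,
#       steps=32,
#       temperature=0.0,
#       alg="entropy",
#       mask_token_id=tokenizer.mask_token_id,  # must be set for diffusion
#   )
#   print(tokenizer.decode(out[0], skip_special_tokens=True))
# ==============================================================================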