from __future__ import annotations

import warnings
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import nested_from_padded
from transformers import (
    LlamaConfig,
    LlamaModel,
    LlamaPreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    rotate_half,
)
from transformers.processing_utils import Unpack


class ModifiedLlamaAttention(LlamaAttention):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get(
                "output_attentions", False
            ):
                warnings.warn(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
                    "`output_attentions=True`. Falling back to eager attention. This warning can be "
                    'removed using the argument `attn_implementation="eager"` when loading the model.'
                )

        attn_output, attn_weights = sdpa_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


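# Illustrative sketch (not part of the model): `_demo_bidirectional_vs_causal_sdpa` is a
# hypothetical helper added here only to show why the attention above forces `is_causal=False`.
# With a causal mask, each query position only sees earlier tokens; with `is_causal=False`
# every token attends to the whole sequence, which is what the bidirectional encoder needs.
def _demo_bidirectional_vs_causal_sdpa() -> None:
    batch, heads, seq, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)

    bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # The last query position can already attend to the whole sequence under a causal mask,
    # so it matches the bidirectional result; earlier positions differ because causal masking
    # hides the tokens that come after them.
    assert torch.allclose(bidirectional[:, :, -1], causal[:, :, -1], atol=1e-5)
    assert not torch.allclose(bidirectional[:, :, 0], causal[:, :, 0], atol=1e-5)

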
def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs: Any,
) -> Tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        if key.is_nested:
            key = repeat_jagged_kv(key, module.num_key_value_groups)
            value = repeat_jagged_kv(value, module.num_key_value_groups)
        else:
            key = repeat_dense_kv(key, module.num_key_value_groups)
            value = repeat_dense_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with the memory-efficient backend is bugged with non-contiguous inputs and a custom
    # attn_mask for some torch versions.
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement
    # instead of an inline conditional assignment in SDPA to support both torch.compile's dynamic
    # shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check for the shape first, otherwise compile will fail with
    # `argument 'is_causal' must be bool, not SymBool`.
    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being
    # a tensor. We convert it to a bool for the SDPA kernel, which only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


def repeat_jagged_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    expand_shape = (batch, num_key_value_heads, -1, n_rep, head_dim)
    if n_rep == 1:
        return hidden_states
    hidden_states = (
        hidden_states.unsqueeze(3)
        .expand(expand_shape)
        .transpose(1, 2)
        .flatten(2, 3)
        .transpose(1, 2)
    )
    return hidden_states


def repeat_dense_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


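# Illustrative sketch (not part of the model): `_demo_repeat_dense_kv_equivalence` is a
# hypothetical helper showing the equivalence claimed in the docstring above, i.e. that
# `repeat_dense_kv` matches torch.repeat_interleave along the head dimension when expanding
# grouped-query KV heads to one KV head per attention head. Shapes are toy values.
def _demo_repeat_dense_kv_equivalence() -> None:
    batch, num_kv_heads, seq, head_dim, n_rep = 2, 4, 5, 8, 3
    kv = torch.randn(batch, num_kv_heads, seq, head_dim)

    expanded = repeat_dense_kv(kv, n_rep)
    reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)

    # Same shape and the same element ordering as repeat_interleave on dim=1.
    assert expanded.shape == (batch, num_kv_heads * n_rep, seq, head_dim)
    assert torch.equal(expanded, reference)

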
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
            sin so that they can be properly broadcasted to the dimensions of q and k. For example,
            cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the
            shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes cos and sin
            broadcastable to the shapes of q and k. Similarly, if q and k have the shape
            [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary
        Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    if q.is_nested and k.is_nested:
        if q.layout != torch.jagged:
            raise NotImplementedError(f"Unsupported layout: {q.layout}")
        if k.layout != torch.jagged:
            raise NotImplementedError(f"Unsupported layout: {k.layout}")
        return _jagged_tensor_forward(q, k, cos, sin)
    else:
        return _padded_tensor_forward(q, k, cos, sin)


def _jagged_tensor_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_dense = q.to_padded_tensor(0.0)
    k_dense = k.to_padded_tensor(0.0)
    q_dense_embed = (q_dense * cos) + (rotate_half(q_dense) * sin)
    k_dense_embed = (k_dense * cos) + (rotate_half(k_dense) * sin)
    q_jagged_embed = convert_dense_to_jagged(q, q_dense_embed)
    k_jagged_embed = convert_dense_to_jagged(k, k_dense_embed)
    return q_jagged_embed, k_jagged_embed


def _padded_tensor_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def convert_dense_to_jagged(nested_q: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    padded_max_S = nested_q._get_max_seqlen()
    total_L = nested_q._values.shape[nested_q._ragged_idx - 1]
    if padded_max_S is None:
        # use an upper bound on max seqlen if it is not present
        padded_max_S = total_L

    # convert dense tensor -> jagged
    q = q.expand(
        [
            x if i != nested_q._ragged_idx else padded_max_S
            for i, x in enumerate(q.shape)
        ]
    )
    nested_result = nested_from_padded(
        q,
        offsets=nested_q._offsets,
        ragged_idx=nested_q._ragged_idx,
        sum_S=total_L,
        min_seqlen=nested_q._get_min_seqlen(),
        max_seqlen=padded_max_S,
    )
    return nested_result


class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
        nn.Module.__init__(self)
        self.hidden_size: int = config.hidden_size

        self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


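# Illustrative sketch (not part of the model): `_demo_rotary_broadcasting` is a hypothetical
# helper showing the broadcasting convention that `apply_rotary_pos_emb` relies on for dense
# inputs. The cos/sin tensors here are random placeholders of the right shape, not real rotary
# frequencies, so only the shapes are meaningful.
def _demo_rotary_broadcasting() -> None:
    batch, heads, seq, head_dim = 2, 4, 6, 8
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)

    # cos/sin arrive as (batch, seq_len, head_dim); unsqueeze_dim=1 inserts the head axis so
    # they broadcast against q/k of shape (batch, heads, seq_len, head_dim).
    cos = torch.randn(batch, seq, head_dim)
    sin = torch.randn(batch, seq, head_dim)

    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape

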
class LlamaBiModel(LlamaModel):
    def __init__(self, config: LlamaConfig) -> None:
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx: int = config.pad_token_id
        self.vocab_size: int = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                ModifiedLlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """
        Builds the attention mask for bi-directional attention: only padding positions are
        masked, no causal structure is applied.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask

        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        # use_cache = use_cache if use_cache is not None else self.config.use_cache
        use_cache = False
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, "
                "and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            warnings.warn(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.",
                DeprecationWarning,
                stacklevel=2,
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if (
            use_cache and not isinstance(past_key_values, Cache) and not self.training
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            warnings.warn(
                "We detected that you are passing `past_key_values` as a tuple and this is "
                "deprecated and will be removed in v4.43. Please use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)",
                DeprecationWarning,
                stacklevel=2,
            )

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            if inputs_embeds.is_nested:
                seq_len = inputs_embeds._get_max_seqlen()
            else:
                seq_len = inputs_embeds.shape[1]
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + seq_len,
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if not inputs_embeds.is_nested:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values,
            )
        else:
            causal_mask = None

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class DramaModel(LlamaBiModel):
    """
    DramaModel is a modified version of LlamaModel that supports bi-directional attention
    and provides query and document encoding functionality.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in the self-attention layers.
        """
        super().__init__(config)
        for layer in self.layers:
            layer.self_attn.is_causal = False
        # query prefix
        self.query_prefix = "Query: "
        self.max_seq_len = 8192
        self.hidden_size = config.hidden_size

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average pooled representation of the last hidden states.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: Optional[int] = None,
        use_nested: bool = False,
    ):
        """
        Tokenizes input text sequences with an optional sequence-length limit.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        if use_nested:
            tokenized = tokenizer(
                texts,
                truncation=True,
                max_length=max_seq_len,
                return_length=True,
            )
            tokenized.input_ids = torch.nested.nested_tensor(
                tokenized.input_ids, layout=torch.jagged
            ).to(self.device)
            tokenized.attention_mask = None
        else:
            tokenized = tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=max_seq_len,
                return_tensors="pt",
            ).to(self.device)
        tokenizer_output = {}
        tokenizer_output["input_ids"] = tokenized.input_ids
        tokenizer_output["attention_mask"] = tokenized.attention_mask
        return tokenizer_output

    def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
        """
        Pass through the model and compute normalized embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Normalized output embeddings.
        """
        outputs = self.forward(
            input_ids, attention_mask, *args, **kwargs
        ).last_hidden_state
        if not outputs.is_nested:
            if dim is not None:
                outputs = outputs[:, :, :dim]
            embeddings = self._average_pool(outputs, attention_mask)
        else:
            if dim is not None:
                outputs, _ = outputs.split_with_sizes(
                    split_sizes=[dim, outputs.shape[-1] - dim], dim=-1
                )
            embeddings = outputs.sum(dim=-2)
        # normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: Optional[int] = None,
        dim: Optional[int] = None,
        use_nested: bool = False,
    ):
        """
        Encodes a list of queries into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
        """
        if not queries:
            raise ValueError("queries must not be empty.")
        if not isinstance(queries, list) or not all(
            isinstance(q, str) for q in queries
        ):
            raise ValueError("queries must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")

        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
        embeddings = self.encode(**tokenized_queries, dim=dim)
        return embeddings

    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: Optional[int] = None,
        dim: Optional[int] = None,
        use_nested: bool = False,
    ):
        """
        Encodes a list of documents into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
        """
        if not documents:
            raise ValueError("documents must not be empty.")
        if not isinstance(documents, list) or not all(
            isinstance(d, str) for d in documents
        ):
            raise ValueError("documents must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")

        tokenized_documents = self._tokenize(
            tokenizer, documents, max_seq_len, use_nested
        )
        embeddings = self.encode(**tokenized_documents, dim=dim)
        return embeddings

