""" |
|
Change the attention of Gemma3 to be bidirectional. |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from typing import Optional, List, Dict, Any |
|
from functools import partial |
|
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
from transformers import Gemma3ForCausalLM, Gemma3TextConfig |
|
from transformers.models.gemma3.modeling_gemma3 import ( |
|
Gemma3Attention, |
|
Gemma3DecoderLayer, |
|
Gemma3TextModel, |
|
|
|
) |
|
|
|
from transformers.modeling_outputs import TokenClassifierOutput |
|
from transformers.utils import logging |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Gemma3PunctuationConfig(PretrainedConfig):
    """Configuration class for Gemma3 punctuation model."""

    model_type = "cadence_punctuation"

    def __init__(
        self,
        num_labels: int = 31,
        classifier_dropout_prob: float = 0.0,
        use_non_causal_attention: bool = True,
        **kwargs,
    ):
        self.num_labels = num_labels
        self.classifier_dropout_prob = classifier_dropout_prob
        self.use_non_causal_attention = use_non_causal_attention
        super().__init__(**kwargs)


class NonCausalGemma3Attention(Gemma3Attention):
    """Gemma3Attention configured for non-causal token classification."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Disable causal masking and sliding-window attention so every token can
        # attend to every other token in the sequence.
        self.is_causal = False
        self.sliding_window = None


class NonCausalGemma3DecoderLayer(Gemma3DecoderLayer):
    """Decoder layer with non-causal attention for token classification."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the causal self-attention built by the parent class.
        self.self_attn = NonCausalGemma3Attention(config, layer_idx)


class Gemma3TokenClassificationModel(Gemma3TextModel):
    """Gemma3 base model configured for token classification."""

    _no_split_modules = ["NonCausalGemma3DecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        if getattr(config, "use_non_causal_attention", True):
            # Rebuild the decoder stack with non-causal layers.
            self.layers = nn.ModuleList(
                [
                    NonCausalGemma3DecoderLayer(config, layer_idx)
                    for layer_idx in range(config.num_hidden_layers)
                ]
            )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values=None,
        output_attentions: bool = False,
    ):
        """Override to build a bidirectional attention mask (no causal masking)."""
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the 2D padding mask directly (or no mask at all).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = past_key_values is not None and hasattr(
            past_key_values, "get_max_length"
        )

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]

        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # A custom 4D mask is passed through unchanged.
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0"
                )
            causal_mask = attention_mask
        else:
            # Start from an all-zeros mask so every position can attend to every other
            # position (bidirectional), instead of the causal upper-triangular fill.
            causal_mask = torch.zeros(
                (sequence_length, target_length), dtype=dtype, device=device
            )
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )

            if attention_mask is not None:
                # Block attention to padded key positions with the dtype minimum.
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # SDPA can produce NaNs for fully masked rows; unmask them defensively.
            try:
                from transformers.modeling_attn_mask_utils import AttentionMaskConverter

                causal_mask = AttentionMaskConverter._unmask_unattended(
                    causal_mask, min_dtype
                )
            except ImportError:
                pass

        return causal_mask

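# Illustrative note (added here, not from the original file): for attention_mask = [[1, 1, 0]]
# and no cache, the override above returns a (1, 1, 3, 3) float mask equivalent to
#
#   [[0.0, 0.0, MIN],
#    [0.0, 0.0, MIN],
#    [0.0, 0.0, MIN]]        # MIN = torch.finfo(dtype).min
#
# i.e. every query can attend to both earlier and later non-padded keys, and only the
# padded key column is blocked.
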
class Gemma3ForTokenClassification(Gemma3ForCausalLM):
    """
    Gemma3 model for token classification (punctuation prediction).

    Uses a class-based architecture without monkey patching.
    """

    config_class = Gemma3PunctuationConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        if getattr(config, "use_non_causal_attention", True):
            # Replace the causal base model with the bidirectional variant.
            self.model = Gemma3TokenClassificationModel(config)

        # Replace the LM head with a dropout + linear classification head.
        classifier_dropout_prob = getattr(config, "classifier_dropout_prob", 0.0)
        self.lm_head = nn.Sequential(
            nn.Dropout(classifier_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels),
        )

        self.config.num_labels = config.num_labels

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> TokenClassifierOutput:
        """Forward pass for token classification."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = outputs[0]
        # (batch, seq_len, hidden) -> (batch, seq_len, num_labels)
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Register the custom config/model pair so the Auto classes can resolve "cadence_punctuation".
from transformers import AutoConfig, AutoModel

AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
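# Usage sketch (illustrative only; the checkpoint path below is a placeholder and the
# saved config is assumed to carry the usual Gemma3 text-model fields in addition to
# the classification extras defined above):
#
#   config = AutoConfig.from_pretrained("path/to/punctuation-checkpoint")
#   model = Gemma3ForTokenClassification.from_pretrained(
#       "path/to/punctuation-checkpoint", config=config
#   )
#   outputs = model(input_ids=torch.tensor([[2, 10, 20, 30]]))
#   print(outputs.logits.shape)  # (batch_size, seq_len, num_labels)
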
def create_token_classification_model(config: Gemma3PunctuationConfig):
    """Create a token classification model with non-causal attention."""
    return Gemma3ForTokenClassification(config)


def load_from_pretrained_with_config_detection(model_path: str, **kwargs):
    """
    Load a model and auto-detect, based on the config, whether it is for token
    classification or for other bidirectional tasks.
    """
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_path)

    if hasattr(config, "model_type") and config.model_type == "cadence_punctuation":
        return Gemma3ForTokenClassification.from_pretrained(model_path, config=config, **kwargs)