"""
Gemma3 with bidirectional (non-causal) attention for token classification (punctuation prediction).
"""

import torch
import torch.nn as nn
from typing import Optional, List, Dict, Any
from functools import partial

from transformers import PretrainedConfig, PreTrainedModel
from transformers import Gemma3ForCausalLM, Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3Attention,
    Gemma3DecoderLayer,
    Gemma3TextModel,
)

from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Gemma3PunctuationConfig(PretrainedConfig):
    """
    Configuration class for Gemma3 punctuation model.
    """
    model_type = "cadence_punctuation"
    
    def __init__(
        self,
        num_labels: int = 31,
        classifier_dropout_prob: float = 0.0,
        use_non_causal_attention: bool = True,
        **kwargs
    ):
        self.num_labels = num_labels
        self.classifier_dropout_prob = classifier_dropout_prob
        self.use_non_causal_attention = use_non_causal_attention
        super().__init__(**kwargs)

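# Illustrative construction of the config (commented out; the values below are examples
# only, and the remaining Gemma3 architecture fields normally come from the checkpoint's
# config.json rather than being passed by hand):
#
#   config = Gemma3PunctuationConfig(
#       num_labels=31,
#       classifier_dropout_prob=0.1,
#       use_non_causal_attention=True,
#   )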

# ============ Token Classification Model Components ============

class NonCausalGemma3Attention(Gemma3Attention):
    """Gemma3Attention configured for non-causal (bidirectional) token classification."""
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.is_causal = False  # no causal masking: every token can attend to every other token
        self.sliding_window = None  # disable Gemma3's local sliding-window restriction


class NonCausalGemma3DecoderLayer(Gemma3DecoderLayer):
    """Decoder layer with non-causal attention for token classification."""
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = NonCausalGemma3Attention(config, layer_idx)

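# Quick sanity check (illustrative, commented out; the tiny config values below are
# arbitrary placeholders, not taken from any real checkpoint):
#
#   tiny_cfg = Gemma3TextConfig(
#       vocab_size=1000, hidden_size=64, intermediate_size=128,
#       num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=1, head_dim=16,
#   )
#   layer = NonCausalGemma3DecoderLayer(tiny_cfg, layer_idx=0)
#   assert layer.self_attn.is_causal is False
#   assert layer.self_attn.sliding_window is None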

class Gemma3TokenClassificationModel(Gemma3TextModel):
    """Gemma3 base model configured for token classification."""
    _no_split_modules = ["NonCausalGemma3DecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        if getattr(config, 'use_non_causal_attention', True):
            # Replace layers with non-causal versions
            self.layers = nn.ModuleList(
                [
                    NonCausalGemma3DecoderLayer(config, layer_idx)
                    for layer_idx in range(config.num_hidden_layers)
                ]
            )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values = None,
        output_attentions: bool = False,
    ):
        """Override to create bidirectional attention mask (no causal masking)."""
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = past_key_values is not None and hasattr(
            past_key_values, "get_max_length"
        )

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume the mask is already in inverted form and requires no inversion or slicing.
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0"
                )
            causal_mask = attention_mask
        else:
            # KEY CHANGE: Start with zeros (attend to all) instead of min_dtype (mask all)
            causal_mask = torch.zeros(
                (sequence_length, target_length), dtype=dtype, device=device
            )
            # REMOVED: the causal-masking step of the upstream implementation, which
            # (applied to a min_dtype-filled mask) leaves only the lower triangle
            # attendable:
            # if sequence_length != 1:
            #     causal_mask = torch.triu(causal_mask, diagonal=1)
            
            # With an all-zero base mask this multiplication is a no-op; it is kept from
            # the upstream implementation, where it unmasks positions at or before each
            # cache position.
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )
            
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        # Handle SDPA-specific optimizations if needed
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            try:
                from transformers.modeling_attn_mask_utils import AttentionMaskConverter
                causal_mask = AttentionMaskConverter._unmask_unattended(
                    causal_mask, min_dtype
                )
            except ImportError:
                pass  # Fallback for older transformers versions

        return causal_mask

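# Illustrative sketch of the mask produced above (commented out; assumes eager/SDPA
# attention on CPU and a small Gemma3TextConfig such as the tiny_cfg sketched earlier):
#
#   model = Gemma3TokenClassificationModel(tiny_cfg)
#   attn_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # second sequence has one pad
#   hidden = torch.zeros(2, 4, tiny_cfg.hidden_size)
#   mask = model._update_causal_mask(attn_mask, hidden, cache_position=torch.arange(4))
#   # mask.shape == (2, 1, 4, 4); mask[0] is all zeros (fully bidirectional), and the only
#   # masked entries are mask[1, 0, :, 3], filled with torch.finfo(hidden.dtype).min.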

class Gemma3ForTokenClassification(Gemma3ForCausalLM):
    """
    Gemma3 model for token classification (punctuation prediction).
    Uses class-based architecture without monkey patching.
    """
    
    config_class = Gemma3PunctuationConfig
    
    def __init__(self, config):
        # Initialize with base Gemma3ForCausalLM structure
        super().__init__(config)
        self.num_labels = config.num_labels
        
        # Replace the base model with token classification version
        if getattr(config, 'use_non_causal_attention', True):
            self.model = Gemma3TokenClassificationModel(config)
        
        # Replace the lm_head with classification head
        classifier_dropout_prob = getattr(config, 'classifier_dropout_prob', 0.0)
        self.lm_head = nn.Sequential(
            nn.Dropout(classifier_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels)
        )
        
        # Update config for classification
        self.config.num_labels = config.num_labels
        
        # Initialize weights for the new head
        self.post_init()
    
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> TokenClassifierOutput:
        """Forward pass for token classification."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get hidden states from the model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Get the hidden states from the model output
        sequence_output = outputs[0]
        
        # Apply the classification head (which is now self.lm_head)
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

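# Illustrative inference sketch (commented out; the checkpoint path is a hypothetical
# placeholder, not part of this module):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/punctuation-checkpoint")
#   model = Gemma3ForTokenClassification.from_pretrained("path/to/punctuation-checkpoint")
#   model.eval()
#
#   enc = tokenizer("hello how are you", return_tensors="pt")
#   with torch.no_grad():
#       out = model(**enc)
#   predicted_label_ids = out.logits.argmax(dim=-1)[0]  # one punctuation label id per token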

# ============ Model Registration ============

from transformers import AutoConfig, AutoModel

# Register the punctuation config and model
AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)

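# Illustrative use of the registration above (commented out; the checkpoint path is a
# hypothetical placeholder):
#
#   config = AutoConfig.from_pretrained("path/to/punctuation-checkpoint")  # model_type == "cadence_punctuation"
#   model = AutoModel.from_pretrained("path/to/punctuation-checkpoint", config=config)
#   # -> resolves to Gemma3ForTokenClassification via the AutoModel mapping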

# ============ Utility Functions ============


def create_token_classification_model(config: Gemma3PunctuationConfig):
    """Create a token classification model with non-causal attention."""
    return Gemma3ForTokenClassification(config)


def load_from_pretrained_with_config_detection(model_path: str, **kwargs):
    """
    Load a model and auto-detect, based on its config, whether it is for token
    classification or for other bidirectional tasks.
    """
    config = AutoConfig.from_pretrained(model_path)
    
    if hasattr(config, 'model_type') and config.model_type == "cadence_punctuation":
        # Token classification model
        return Gemma3ForTokenClassification.from_pretrained(model_path, config=config, **kwargs)