"""RoFormer model with projection head for classification. |
|
|
|
|
|
This module provides a RoFormer-based model with a projection head for |
|
|
contrastive learning, enabling both classification and embedding-based |
|
|
similarity search for file type detection. |
|
|
""" |
|
|
|
|
|
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RoFormerModel, RoFormerPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

try:
    from .configuration_roformer_classification import RoFormerClassificationConfig
except ImportError:
    from configuration_roformer_classification import RoFormerClassificationConfig


class RoFormerForSequenceClassificationWithProjection(RoFormerPreTrainedModel):
    """RoFormer with projection head for file type classification.

    This model extends RoFormer with two heads on top of the pooled [CLS]
    representation: a classification head that produces file type logits,
    and a projection head that produces L2-normalized embeddings for
    similarity search. The architecture is:

        RoFormer -> CLS pooling -> Classifier             (logits)
        RoFormer -> CLS pooling -> Projection -> L2 Norm  (embeddings)

    The projection head enables contrastive learning and produces
    embeddings for similarity-based file type matching (see
    ``get_embeddings``).
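
    Example (illustrative sketch; assumes ``RoFormerClassificationConfig``
    accepts ``num_labels`` and ``projection_dim`` keyword arguments in
    addition to the standard RoFormer fields):

        >>> config = RoFormerClassificationConfig(num_labels=10, projection_dim=256)
        >>> model = RoFormerForSequenceClassificationWithProjection(config)
        >>> input_ids = torch.randint(0, config.vocab_size, (2, 128))
        >>> outputs = model(input_ids=input_ids)           # logits: (2, num_labels)
        >>> embeddings = model.get_embeddings(input_ids)   # (2, projection_dim)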
    """

    config_class = RoFormerClassificationConfig

    def __init__(self, config: RoFormerClassificationConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.projection_dim = getattr(config, "projection_dim", 256)

        self.roformer = RoFormerModel(config)

        # Projection head: maps the pooled [CLS] representation into a
        # lower-dimensional space whose L2-normalized output is used for
        # contrastive learning and similarity search (see get_embeddings).
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.projection_dim),
        )

        # Classification head operates directly on the pooled [CLS]
        # representation.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            head_mask: Head mask for attention (optional)
            inputs_embeds: Input embeddings (optional, alternative to input_ids)
            labels: Labels for computing the classification loss [batch_size]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a SequenceClassifierOutput

        Returns:
            SequenceClassifierOutput with logits, loss (when ``labels`` are
            provided), and optional hidden states and attentions; a plain
            tuple when ``return_dict`` is False.
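
        Example (illustrative; ``model``, ``config``, and ``input_ids`` as in
        the class-level example):

            >>> labels = torch.randint(0, config.num_labels, (2,))
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> outputs.loss.backward()   # cross-entropy loss over the logits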
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool by taking the hidden state of the first ([CLS]) token.
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0, :]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            # RoFormerModel returns no pooled output, so the extra outputs
            # (hidden states, attentions) start at index 1.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get normalized projection embeddings for similarity search.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)

        Returns:
            L2-normalized embeddings [batch_size, projection_dim]
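
        Example (illustrative; ``model`` and ``input_ids`` as in the
        class-level example, and ``reference_ids`` standing in for a
        hypothetical tokenized batch of known file type samples):

            >>> query = model.get_embeddings(input_ids)      # (batch, dim)
            >>> refs = model.get_embeddings(reference_ids)   # (n_refs, dim)
            >>> scores = query @ refs.T                      # cosine similarities
            >>> best_match = scores.argmax(dim=1)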
        """
        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        # Pool the [CLS] token, project, and L2-normalize so that dot
        # products between embeddings are cosine similarities.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        projections = self.projection(pooled_output)
        return F.normalize(projections, p=2, dim=1)
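

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model API). It exercises both
# heads with randomly initialized weights; RoFormerClassificationConfig
# arguments other than num_labels/projection_dim are assumed to fall back to
# the RoFormer defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = RoFormerClassificationConfig(num_labels=4, projection_dim=128)
    model = RoFormerForSequenceClassificationWithProjection(config)
    model.eval()

    dummy_ids = torch.randint(0, config.vocab_size, (2, 64))
    with torch.no_grad():
        logits = model(input_ids=dummy_ids).logits
        embeddings = model.get_embeddings(dummy_ids)

    print("logits:", tuple(logits.shape))          # (2, num_labels)
    print("embeddings:", tuple(embeddings.shape))  # (2, projection_dim)
    # Embeddings are L2-normalized, so pairwise dot products are cosine
    # similarities in [-1, 1].
    print("self-similarity:", (embeddings @ embeddings.T).diag())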