# magic-bert-50m-roformer-classification / modeling_roformer_classification.py
"""RoFormer model with projection head for classification.
This module provides a RoFormer-based model with a projection head for
contrastive learning, enabling both classification and embedding-based
similarity search for file type detection.
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import RoFormerModel, RoFormerPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

try:
    from .configuration_roformer_classification import RoFormerClassificationConfig
except ImportError:
    from configuration_roformer_classification import RoFormerClassificationConfig
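
# Loading sketch (assumptions: the Hub repo's config registers this class via
# `auto_map`, so `trust_remote_code=True` resolves the Auto classes to it, and
# the repo id below is inferred from the repo name and may differ):
#
#     from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
#     repo_id = "mjbommar/magic-bert-50m-roformer-classification"
#     tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
#     model = AutoModelForSequenceClassification.from_pretrained(
#         repo_id, trust_remote_code=True
#     )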


class RoFormerForSequenceClassificationWithProjection(RoFormerPreTrainedModel):
    """RoFormer with projection head for file type classification.

    This model extends RoFormer with a projection head that produces
    L2-normalized embeddings for similarity search, alongside a linear
    classifier over the pooled [CLS] representation. The architecture is:

        RoFormer -> CLS pooling -> Projection -> L2 Norm   (embeddings)
        RoFormer -> CLS pooling -> Classifier               (logits)

    The projection head enables contrastive learning and produces
    embeddings for similarity-based file type matching.
    """

    config_class = RoFormerClassificationConfig

    def __init__(self, config: RoFormerClassificationConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.projection_dim = getattr(config, "projection_dim", 256)

        self.roformer = RoFormerModel(config)

        # Projection head for contrastive learning embeddings
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.projection_dim),
        )

        # Classifier on the pooled output (hidden_size, not projection_dim):
        # this architecture uses the hidden representation for classification,
        # while the projection is reserved for embedding similarity search.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()
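
    # Shape sketch of the two heads defined above (illustrative; B is the batch
    # size, H is config.hidden_size, P is config.projection_dim, C is
    # config.num_labels):
    #
    #     [CLS] pooled output [B, H] --projection--> [B, P] --L2 norm--> embeddings
    #     [CLS] pooled output [B, H] --classifier--> [B, C]              logits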

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            head_mask: Head mask for attention (optional)
            inputs_embeds: Input embeddings (optional, alternative to input_ids)
            labels: Labels for computing loss [batch_size]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a SequenceClassifierOutput

        Returns:
            SequenceClassifierOutput with loss, logits, and optional hidden states
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool using CLS token
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0, :]

        # Classify from pooled output directly
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        if not return_dict:
            # RoFormerModel returns no pooled output, so any extra outputs
            # (hidden states, attentions) start at index 1 of the tuple.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
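
    # Usage sketch for the forward pass above (illustrative; `tokenizer`, `model`,
    # and the example label id are assumptions, not defined in this module):
    #
    #     batch = tokenizer(["..."], return_tensors="pt")
    #     out = model(**batch, labels=torch.tensor([3]))
    #     out.loss                     # cross-entropy loss (scalar tensor)
    #     out.logits                   # [batch_size, num_labels]
    #     out.logits.argmax(dim=-1)    # predicted label ids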

    def get_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get normalized projection embeddings for similarity search.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)

        Returns:
            L2-normalized embeddings [batch_size, projection_dim]
        """
        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = outputs.last_hidden_state[:, 0, :]
        projections = self.projection(pooled_output)
        return F.normalize(projections, p=2, dim=1)
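

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the modeling code. Assumptions:
    # the Hub repo id below, and that the repo ships a compatible tokenizer;
    # the sample strings are illustrative file-header snippets only.
    from transformers import AutoTokenizer

    repo_id = "mjbommar/magic-bert-50m-roformer-classification"
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = RoFormerForSequenceClassificationWithProjection.from_pretrained(repo_id)
    model.eval()

    samples = ["%PDF-1.7 ...", "PK\x03\x04 ..."]
    batch = tokenizer(samples, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        # Classification head: logits over file type labels.
        logits = model(**batch).logits
        # Projection head: L2-normalized embeddings, so cosine similarity
        # reduces to a dot product.
        embeddings = model.get_embeddings(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask"),
        )
        similarity = embeddings @ embeddings.T

    print("predicted label ids:", logits.argmax(dim=-1).tolist())
    print("pairwise cosine similarity:", similarity)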