# models/parallel_bert_deberta.py
import torch
import torch.nn as nn
from transformers import BertModel, DebertaModel

from config import DROPOUT_RATE, BERT_MODEL_NAME, DEBERTA_MODEL_NAME  # Import model names


class Attention(nn.Module):
    """
    Simple attention layer that computes a context vector from a sequence of hidden states.

    It learns a single score for each hidden state in the sequence, normalizes the scores
    with softmax, and returns the weighted sum of the hidden states.
    """

    def __init__(self, hidden_size):
        """
        Initializes the Attention layer.

        Args:
            hidden_size (int): The dimensionality of the input hidden states.
        """
        super(Attention, self).__init__()
        # A linear layer that projects each hidden state to a single scalar (attention score)
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, encoder_output):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            encoder_output (torch.Tensor): Hidden states from an encoder.
                Shape: (batch_size, sequence_length, hidden_size)

        Returns:
            torch.Tensor: The context vector, a weighted sum of the hidden states.
                Shape: (batch_size, hidden_size)
        """
        # Calculate normalized attention weights.
        # self.attn(encoder_output) -> (batch_size, sequence_length, 1)
        # .squeeze(-1) removes the last dimension -> (batch_size, sequence_length)
        attn_weights = torch.softmax(self.attn(encoder_output).squeeze(-1), dim=1)

        # Compute the context vector as a weighted sum of encoder_output.
        # attn_weights.unsqueeze(-1) adds a dimension for broadcasting: (batch_size, sequence_length, 1),
        # which allows element-wise multiplication with encoder_output.
        # torch.sum(..., dim=1) sums along the sequence_length dimension.
        context_vector = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)
        return context_vector


class ParallelMultiOutputModel(nn.Module):
    """
    Hybrid model that leverages both BERT and DeBERTa in parallel.

    It extracts features from both encoders, applies an attention mechanism to each
    encoder's output, projects the attended features to a common dimension,
    concatenates them, and uses the combined representation for multi-output
    classification.
    """

    # Statically set the tokenizer name to BERT's for this combined model
    # (assuming BERT's tokenizer is compatible or primary for the shared input).
    tokenizer_name = BERT_MODEL_NAME

    def __init__(self, num_labels):
        """
        Initializes the ParallelMultiOutputModel.

        Args:
            num_labels (list): A list where each element is the number of classes
                for a corresponding label column.
        """
        super(ParallelMultiOutputModel, self).__init__()
        # Load pre-trained BERT and DeBERTa backbones
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
        self.deberta = DebertaModel.from_pretrained(DEBERTA_MODEL_NAME)

        # Attention layers, one per backbone
        self.attn_bert = Attention(self.bert.config.hidden_size)
        self.attn_deberta = Attention(self.deberta.config.hidden_size)

        # Projection layers that reduce the dimensionality of the context vectors
        # before concatenation, keeping the combined feature size manageable.
        self.proj_bert = nn.Linear(self.bert.config.hidden_size, 256)
        self.proj_deberta = nn.Linear(self.deberta.config.hidden_size, 256)

        self.dropout = nn.Dropout(DROPOUT_RATE)  # Dropout layer for regularization

        # Classification heads. The input feature size is the sum of the projected
        # sizes from BERT and DeBERTa (256 + 256 = 512).
        self.classifiers = nn.ModuleList([
            nn.Linear(512, n_classes) for n_classes in num_labels
        ])

    def forward(self, input_ids, attention_mask):
        """
        Performs the forward pass of the parallel model.

        Args:
            input_ids (torch.Tensor): Tensor of token IDs.
            attention_mask (torch.Tensor): Mask marking real tokens (1) vs. padding (0).

        Returns:
            list: A list of logit tensors, one for each classification head.
        """
        # Get the last hidden states (one vector per token) from both BERT and DeBERTa;
        # these full sequences are what the attention layers operate on.
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        deberta_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Apply attention to reduce each sequence of hidden states to a single context vector
        context_bert = self.attn_bert(bert_output)
        context_deberta = self.attn_deberta(deberta_output)

        # Project the context vectors down to their reduced dimensions
        reduced_bert = self.proj_bert(context_bert)
        reduced_deberta = self.proj_deberta(context_deberta)

        # Concatenate the reduced feature vectors from both encoders
        combined = torch.cat((reduced_bert, reduced_deberta), dim=1)
        combined = self.dropout(combined)  # Apply dropout to the combined features

        # Pass the combined features through each classification head
        return [classifier(combined) for classifier in self.classifiers]
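

# --- Usage sketch (illustrative only, not part of the training pipeline) ---
# A minimal smoke test showing how this model might be instantiated and called.
# Assumptions: the label counts [2, 5, 3] and the 128-token max length are
# arbitrary examples chosen here, not values taken from config; BertTokenizer
# is used because `tokenizer_name` above points at BERT_MODEL_NAME.
if __name__ == "__main__":
    from transformers import BertTokenizer

    # Hypothetical multi-output setup: three heads with 2, 5, and 3 classes.
    model = ParallelMultiOutputModel(num_labels=[2, 5, 3])
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(ParallelMultiOutputModel.tokenizer_name)
    encoded = tokenizer(
        ["A quick example sentence for a forward-pass shape check."],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(encoded["input_ids"], encoded["attention_mask"])

    # Note: both encoders consume the same BERT-tokenized input_ids, matching
    # the class's tokenizer_name choice. Expect one logits tensor per head:
    # shapes (1, 2), (1, 5), and (1, 3).
    for i, head_logits in enumerate(logits):
        print(f"head {i}: {tuple(head_logits.shape)}")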