import torch
from torch import nn


class ClassifierHead(nn.Module):
    """Basically a fancy MLP: a 3-layer classifier head with GELU, LayerNorm, and skip connections."""

    def __init__(self, hidden_size, num_labels, dropout_prob):
        super().__init__()

        # First residual block
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Second residual block
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Final projection to label logits
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Residual block 1 (pre-norm): LayerNorm -> Linear -> GELU -> Dropout, then add the skip
        identity1 = features
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = x + identity1

        # Residual block 2
        identity2 = x
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        x = x + identity2

        logits = self.out_proj(x)
        return logits
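

# Usage sketch (illustrative only; the sizes below are assumptions, not values
# required by this module). ClassifierHead expects pooled features of shape
# (batch, hidden_size) and returns logits of shape (batch, num_labels):
#
#     head = ClassifierHead(hidden_size=768, num_labels=2, dropout_prob=0.1)
#     logits = head(torch.randn(8, 768))  # -> shape (8, 2)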


class ConcatClassifierHead(nn.Module):
    """
    An enhanced classifier head designed for concatenated CLS + mean-pooling input.
    Includes an initial projection layer before the standard residual blocks.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()

        # Initial projection: input_size (e.g. 2 * hidden_size) -> hidden_size
        self.initial_projection = nn.Linear(input_size, hidden_size)
        self.initial_norm = nn.LayerNorm(hidden_size)
        self.initial_activation = nn.GELU()
        self.initial_dropout = nn.Dropout(dropout_prob)

        # First residual block
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Second residual block
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Final projection to label logits
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Project the concatenated features into the hidden dimension
        x = self.initial_projection(features)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        x = self.initial_dropout(x)

        # Residual block 1 (pre-norm): LayerNorm -> Linear -> GELU -> Dropout, then add the skip
        identity1 = x
        x_res = self.norm1(x)
        x_res = self.dense1(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout1(x_res)
        x = identity1 + x_res

        # Residual block 2
        identity2 = x
        x_res = self.norm2(x)
        x_res = self.dense2(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout2(x_res)
        x = identity2 + x_res

        logits = self.out_proj(x)
        return logits
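

# Usage sketch (illustrative; sizes are assumptions). ConcatClassifierHead takes the
# concatenation of the CLS embedding and the mean-pooled token embeddings, so
# input_size is typically 2 * hidden_size:
#
#     head = ConcatClassifierHead(input_size=1536, hidden_size=768, num_labels=2, dropout_prob=0.1)
#     logits = head(torch.cat([cls_emb, mean_emb], dim=-1))  # (batch, 1536) -> (batch, 2)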


class ExpansionClassifierHead(nn.Module):
    """
    A classifier head using FFN-style expansion (input -> 4 * hidden -> hidden -> labels).
    Takes concatenated CLS + mean-pooled features as input.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        intermediate_size = hidden_size * 4

        # Expansion: input_size -> 4 * hidden_size
        self.norm1 = nn.LayerNorm(input_size)
        self.dense1 = nn.Linear(input_size, intermediate_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Contraction: 4 * hidden_size -> hidden_size
        self.norm2 = nn.LayerNorm(intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Final projection to label logits
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Expand into the intermediate dimension
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)

        # Contract back down to the hidden dimension
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)

        logits = self.out_proj(x)
        return logits
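

# Minimal smoke test (a sketch, not part of any training pipeline): run this file
# directly to check output shapes. The batch size, hidden size, and label count
# below are illustrative assumptions, not values required by these modules.
if __name__ == "__main__":
    batch, hidden, labels = 4, 768, 3

    pooled = torch.randn(batch, hidden)       # e.g. a pooled [CLS] embedding
    concat = torch.randn(batch, 2 * hidden)   # e.g. [CLS] and mean pooling, concatenated

    print(ClassifierHead(hidden, labels, 0.1)(pooled).shape)                       # torch.Size([4, 3])
    print(ConcatClassifierHead(2 * hidden, hidden, labels, 0.1)(concat).shape)     # torch.Size([4, 3])
    print(ExpansionClassifierHead(2 * hidden, hidden, labels, 0.1)(concat).shape)  # torch.Size([4, 3])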