import torch
from torch import nn
from transformers import PreTrainedModel


class CustomClassifier(PreTrainedModel):
    """Simple two-layer MLP classifier wrapped as a Hugging Face model.

    The model encodes a dense feature vector with a small MLP and projects
    it to ``config.num_classes`` logits. It follows the HF ``forward``
    convention (``labels`` optional, dict output) so it works with
    ``transformers.Trainer``.
    """

    # NOTE(review): CustomClassificationConfig is neither defined nor
    # imported in this chunk — presumably declared elsewhere in the
    # project; confirm the import exists at file level.
    config_class = CustomClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        # Encoder: input_dim -> hidden_dim -> hidden_dim, ReLU after each
        # linear layer.
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Compute logits and (optionally) a cross-entropy loss.

        Args:
            input_ids: Float tensor of shape ``(batch_size, input_dim)``.
                Despite the HF-conventional name, these are dense features
                fed directly into ``nn.Linear`` — not integer token ids.
            labels: Optional class-index tensor of shape ``(batch_size,)``;
                when provided, a cross-entropy loss is computed.
            **kwargs: Ignored; accepted for ``Trainer`` compatibility.

        Returns:
            dict with keys ``"loss"`` (tensor, or ``None`` when no labels
            were given) and ``"logits"`` (``(batch_size, num_classes)``).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # Fail fast with a clear message instead of an opaque
            # TypeError from inside nn.Linear.
            raise ValueError(
                "input_ids (the feature tensor) must be provided"
            )
        hidden = self.encoder(input_ids)
        logits = self.classifier(hidden)
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}