import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class CustomClassificationConfig(PretrainedConfig):
    """Configuration for :class:`CustomClassifier`.

    Args:
        input_dim: Size of each input feature vector. Defaults to 32.
        hidden_dim: Width of the two hidden layers in the MLP encoder.
            Defaults to 64.
        num_classes: Number of output classes. Defaults to 2.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "custom_classifier"

    def __init__(self, input_dim=32, hidden_dim=64, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes


class CustomClassifier(PreTrainedModel):
    """A small MLP classifier wrapped in the Hugging Face model interface.

    Two hidden layers with ReLU activations feed a linear classification
    head. The ``forward`` signature follows Trainer conventions
    (``input_ids``/``labels``) so it can be used with
    :class:`transformers.Trainer` out of the box.
    """

    config_class = CustomClassificationConfig

    def __init__(self, config):
        super().__init__(config)
        # Two-layer MLP encoder: input_dim -> hidden_dim -> hidden_dim.
        self.encoder = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
        )
        # Linear classification head on top of the encoded features.
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Run the classifier and optionally compute a training loss.

        Args:
            input_ids: Float tensor of shape ``(batch_size, input_dim)``.
                NOTE: despite the HF-conventional name, these are dense
                feature vectors, not token ids.
            labels: Optional integer class-index tensor of shape
                ``(batch_size,)``; when given, a cross-entropy loss is
                computed.
            **kwargs: Ignored; accepted so Trainer-style callers may pass
                extra batch keys (e.g. ``attention_mask``) harmlessly.

        Returns:
            dict with ``"logits"`` of shape ``(batch_size, num_classes)``
            and ``"loss"`` (a scalar tensor, or ``None`` when no labels
            were provided).

        Raises:
            ValueError: If ``input_ids`` is ``None``. Without this guard
                the ``None`` would crash deep inside ``nn.Linear`` with an
                opaque ``TypeError``.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided to CustomClassifier.forward")
        hidden = self.encoder(input_ids)
        logits = self.classifier(hidden)
        loss = None
        if labels is not None:
            # CrossEntropyLoss expects raw (unnormalized) logits and
            # integer class indices.
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}