import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from configuration_pdeeppp import PDeepPPConfig

logger = logging.get_logger(__name__)


class SelfAttentionGlobalFeatures(nn.Module):
    """Self-attention block that extracts global features and projects them to `output_size`."""

    def __init__(self, config):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            embed_dim=config.input_size,
            num_heads=config.num_heads,
            batch_first=True,
        )
        self.fc1 = nn.Linear(config.input_size, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.output_size)
        self.layer_norm = nn.LayerNorm(config.input_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Residual self-attention followed by a two-layer projection head.
        attn_output, _ = self.self_attention(x, x, x)
        x = self.layer_norm(x + attn_output)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class TransConv1d(nn.Module):
    """Global self-attention features refined by a stack of Transformer encoder layers."""

    def __init__(self, config):
        super().__init__()
        self.self_attention_global_features = SelfAttentionGlobalFeatures(config)
        self.transformer_encoder = nn.TransformerEncoderLayer(
            d_model=config.output_size,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_size * 2,
            dropout=config.dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            self.transformer_encoder,
            num_layers=config.num_transformer_layers,
        )
        self.fc1 = nn.Linear(config.output_size, config.output_size)
        self.fc2 = nn.Linear(config.output_size, config.output_size)
        self.layer_norm = nn.LayerNorm(config.output_size)

    def forward(self, x):
        x = self.self_attention_global_features(x)
        x = self.transformer(x)
        x = self.fc1(x)
        # The residual connection only wraps the second linear layer.
        residual = x
        x = self.fc2(x)
        x = self.layer_norm(x + residual)
        return x


class PosCNN(nn.Module):
    """1D convolution over the input embeddings with an optional learned position encoding."""

    def __init__(self, config, use_position_encoding=True):
        super().__init__()
        self.use_position_encoding = use_position_encoding
        self.conv1d = nn.Conv1d(
            in_channels=config.input_size,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        self.relu = nn.ReLU()
        self.global_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, config.output_size)
        if self.use_position_encoding:
            # Learned position encoding; it is sliced to the sequence length in forward(),
            # so sequences must not be longer than config.input_size.
            self.position_encoding = nn.Parameter(torch.zeros(64, config.input_size))

    def forward(self, x):
        # (batch, seq_len, input_size) -> (batch, input_size, seq_len) for Conv1d.
        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        x = self.relu(x)
        if self.use_position_encoding:
            seq_len = x.size(2)
            pos_encoding = self.position_encoding[:, :seq_len].unsqueeze(0)
            x = x + pos_encoding
        # Global average pooling over the sequence dimension, then project to output_size.
        x = self.global_pooling(x)
        x = x.squeeze(-1)
        x = self.fc(x)
        return x


class PDeepPPPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that provides the methods shared by all PDeepPP models.
    """

    config_class = PDeepPPConfig
    base_model_prefix = "PDeepPP"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class PDeepPPModel(PDeepPPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.transformer = TransConv1d(config)
        self.cnn = PosCNN(config)
        self.cnn_layers = nn.Sequential(
            nn.Conv1d(config.output_size * 2, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(config.dropout / 2),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(config.dropout / 2),
            nn.Flatten(),
            nn.Linear(64, 1),
        )

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_embeds=None,
        labels=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the classification loss.

        Returns:
            dict or tuple: the output format depends on `return_dict`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_output = self.transformer(input_embeds)
        cnn_output = self.cnn(input_embeds)
        # Broadcast the pooled CNN features across the sequence dimension so they can be
        # concatenated with the per-position transformer features.
        cnn_output = cnn_output.unsqueeze(1).expand(-1, transformer_output.size(1), -1)

        combined = torch.cat([transformer_output, cnn_output], dim=2)
        combined = combined.permute(0, 2, 1)
        logits = self.cnn_layers(combined).squeeze(1)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

            # Custom entropy-based regularization, mixed with the BCE loss via lambda_.
            probs = torch.sigmoid(logits)
            ent = -(probs * torch.log(probs + 1e-12) + (1 - probs) * torch.log(1 - probs + 1e-12)).mean()
            cond_ent = -(probs * torch.log(probs + 1e-12)).mean()
            reg_loss = self.config.lambda_ * ent - self.config.lambda_ * cond_ent
            loss = self.config.lambda_ * loss + (1 - self.config.lambda_) * reg_loss

        if return_dict:
            return {
                "loss": loss,
                "logits": logits,
            }
        else:
            return (loss, logits) if loss is not None else logits


PDeepPPModel.register_for_auto_class("AutoModel")
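

# A minimal usage sketch for sanity-checking the forward pass, not part of the model
# definition itself. It assumes PDeepPPConfig accepts the fields read above
# (input_size, hidden_size, output_size, num_heads, num_transformer_layers, dropout,
# lambda_) as keyword arguments; the concrete values below are illustrative only,
# not the published defaults.
if __name__ == "__main__":
    config = PDeepPPConfig(
        input_size=1280,
        hidden_size=512,
        output_size=128,
        num_heads=8,
        num_transformer_layers=2,
        dropout=0.3,
        lambda_=0.5,
    )
    model = PDeepPPModel(config)
    model.eval()

    # Dummy batch: 4 sequences of length 33 with input_size-dimensional embeddings
    # (e.g. precomputed protein language-model embeddings) and binary labels.
    input_embeds = torch.randn(4, 33, config.input_size)
    labels = torch.randint(0, 2, (4,))

    with torch.no_grad():
        outputs = model(input_embeds=input_embeds, labels=labels, return_dict=True)
    print(outputs["logits"].shape, outputs["loss"].item())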