import torch import torch.nn as nn import torch.nn.functional as F from transformers import ( PreTrainedModel, PretrainedConfig, AutoModel, ) from transformers.modeling_outputs import SequenceClassifierOutput import math from typing import Optional, Dict, Union, List class AttentionPooling(nn.Module): """Attention pooling layer""" def __init__(self, hidden_dim): super().__init__() self.attention_proj = nn.Linear(hidden_dim, hidden_dim) self.attention_vector = nn.Parameter(torch.randn(hidden_dim)) nn.init.xavier_uniform_(self.attention_proj.weight) nn.init.normal_(self.attention_vector, std=1 / math.sqrt(hidden_dim)) def forward(self, sequence_output, attention_mask): proj_sequence = torch.tanh(self.attention_proj(sequence_output)) scores = torch.matmul(proj_sequence, self.attention_vector) scores = scores.masked_fill(~attention_mask.bool(), -1e9) attn_weights = F.softmax(scores, dim=1) weighted_output = torch.bmm(attn_weights.unsqueeze(1), sequence_output).squeeze(1) return weighted_output class MultiScaleFeatures(nn.Module): """Multi-scale feature extraction""" def __init__(self, hidden_dim, scales=[1, 3, 5, 7], out_channels=128): super().__init__() self.convs = nn.ModuleList([nn.Conv1d(hidden_dim, out_channels, kernel_size=k, padding=k // 2) for k in scales]) self.activation = nn.GELU() def forward(self, sequence_output, attention_mask): x = sequence_output.transpose(1, 2) mask = attention_mask.unsqueeze(1).float() multi_scale_features = [] for conv in self.convs: feat = conv(x) feat = self.activation(feat) feat = feat * mask feat, _ = feat.max(dim=2) multi_scale_features.append(feat) return torch.cat(multi_scale_features, dim=1) class PromptInjectionConfig(PretrainedConfig): """Configuration class for PromptInjectionModel""" model_type = "prompt_injection_detector" def __init__( self, base_model_name="distilbert-base-uncased", hidden_size=768, dropout_rate=0.3, use_multi_sample_dropout=True, freeze_layers=2, multi_scale_kernels=[1, 3, 5, 7], multi_scale_channels=128, stats_dim=8, length_weight=0.1, temperature=1.0, **kwargs, ): super().__init__(**kwargs) self.base_model_name = base_model_name self.hidden_size = hidden_size self.dropout_rate = dropout_rate self.use_multi_sample_dropout = use_multi_sample_dropout self.freeze_layers = freeze_layers self.multi_scale_kernels = multi_scale_kernels self.multi_scale_channels = multi_scale_channels self.stats_dim = stats_dim self.length_weight = length_weight self.temperature = temperature class PromptInjectionModel(PreTrainedModel): """ Hugging Face compatible prompt injection detector Usage: from transformers import AutoTokenizer # Load from hub model = PromptInjectionModel.from_pretrained("your-username/prompt-injection-detector") tokenizer = AutoTokenizer.from_pretrained("your-username/prompt-injection-detector") # Inference inputs = tokenizer("Ignore all previous instructions", return_tensors="pt") outputs = model(**inputs) probability = outputs.injection_probability """ config_class = PromptInjectionConfig base_model_prefix = "encoder" # Changed from "base_model" to avoid recursion def __init__(self, config): super().__init__(config) # Base transformer model - renamed to avoid conflict with base_model property self.encoder = AutoModel.from_pretrained(config.base_model_name) # Get the actual hidden size from the loaded model if hasattr(self.encoder.config, "hidden_size"): self.hidden_size = self.encoder.config.hidden_size else: self.hidden_size = config.hidden_size # Freeze early layers if specified if config.freeze_layers > 0: self._freeze_base_layers(config.freeze_layers) # Attention pooling self.attention_pooling = AttentionPooling(self.hidden_size) # Multi-scale feature extraction self.multi_scale = MultiScaleFeatures( self.hidden_size, scales=config.multi_scale_kernels, out_channels=config.multi_scale_channels ) # Statistical features normalization self.stats_normalizer = nn.BatchNorm1d(config.stats_dim) # Calculate total features total_features = self.hidden_size + len(config.multi_scale_kernels) * config.multi_scale_channels + config.stats_dim # Main classifier self.classifier = nn.Sequential( nn.Linear(total_features, 512), nn.BatchNorm1d(512), nn.GELU(), nn.Dropout(config.dropout_rate), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.GELU(), nn.Dropout(config.dropout_rate * 0.8), nn.Linear(256, 64), nn.BatchNorm1d(64), nn.GELU(), nn.Dropout(config.dropout_rate * 0.6), nn.Linear(64, 1), ) # Length predictor self.length_predictor = nn.Sequential(nn.Linear(self.hidden_size, 128), nn.GELU(), nn.Linear(128, 1)) # Temperature for calibration self.temperature = nn.Parameter(torch.tensor(config.temperature)) # Initialize weights self.post_init() def _freeze_base_layers(self, num_layers: int): """Freeze layers in the base model - handles different architectures""" def get_layers(model) -> List[nn.Module]: """Get transformer layers from different model architectures""" # Check if this IS a base model (not wrapped) model_class_name = model.__class__.__name__ # DistilBERT if model_class_name == "DistilBertModel": return model.transformer.layer elif hasattr(model, "distilbert") and hasattr(model.distilbert, "transformer"): return model.distilbert.transformer.layer # BERT / RoBERTa elif model_class_name in ["BertModel", "RobertaModel"]: return model.encoder.layer elif hasattr(model, "bert") and hasattr(model.bert, "encoder"): return model.bert.encoder.layer elif hasattr(model, "roberta") and hasattr(model.roberta, "encoder"): return model.roberta.encoder.layer # ELECTRA elif model_class_name == "ElectraModel": return model.encoder.layer elif hasattr(model, "electra") and hasattr(model.electra, "encoder"): return model.electra.encoder.layer # GPT-2 elif model_class_name == "GPT2Model": return model.h elif hasattr(model, "transformer") and hasattr(model.transformer, "h"): return model.transformer.h # DeBERTa elif model_class_name in ["DebertaModel", "DebertaV2Model"]: return model.encoder.layer elif hasattr(model, "deberta") and hasattr(model.deberta, "encoder"): return model.deberta.encoder.layer # Generic fallback - check for common patterns elif hasattr(model, "transformer") and hasattr(model.transformer, "layer"): return model.transformer.layer elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"): return model.encoder.layer else: print(f"Warning: Could not find layers to freeze for model type {type(model).__name__}") print(f"Model attributes: {[attr for attr in dir(model) if not attr.startswith('_')][:10]}...") return [] layers = get_layers(self.encoder) layers_to_freeze = min(num_layers, len(layers)) for i in range(layers_to_freeze): for param in layers[i].parameters(): param.requires_grad = False print(f"Froze {layers_to_freeze} layers in {type(self.encoder).__name__}") def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs, ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]: """ Forward pass compatible with Hugging Face's expected interface Args: input_ids: Token IDs [batch_size, sequence_length] attention_mask: Attention mask [batch_size, sequence_length] labels: Binary labels (0=safe, 1=injection) [batch_size] return_dict: Whether to return a dict or tuple Returns: SequenceClassifierOutput or dict with: - loss: Classification loss (if labels provided) - logits: Raw logits before sigmoid [batch_size, 1] - injection_probability: Sigmoid probability [batch_size] - length_prediction: Predicted sequence length [batch_size] - hidden_states: Optional hidden states - attentions: Optional attention weights """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Filter out kwargs that DistilBERT doesn't accept # Common trainer kwargs to exclude excluded_keys = {"num_items_in_batch", "loss_kwargs"} encoder_kwargs = {k: v for k, v in kwargs.items() if k not in excluded_keys} # Get base model outputs base_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, output_attentions=output_attentions, return_dict=True, **encoder_kwargs, ) sequence_output = base_outputs.last_hidden_state # Extract features pooled_output = self.attention_pooling(sequence_output, attention_mask) multi_scale_features = self.multi_scale(sequence_output, attention_mask) statistical_features = self._extract_statistical_features(input_ids, attention_mask) # Combine features combined_features = torch.cat([pooled_output, multi_scale_features, statistical_features], dim=1) # Classification if self.config.use_multi_sample_dropout and self.training: logits_list = [] for _ in range(3): logits_list.append(self.classifier(combined_features)) logits = torch.stack(logits_list).mean(dim=0) else: logits = self.classifier(combined_features) # Apply temperature scaling logits = logits / self.temperature # Length prediction length_pred = self.length_predictor(pooled_output) # Calculate loss if labels provided loss = None if labels is not None: # Binary cross entropy loss bce_loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) # Length prediction auxiliary loss true_lengths = attention_mask.sum(dim=1).float() length_loss = F.mse_loss(length_pred.squeeze(-1), true_lengths) # Combined loss loss = bce_loss + self.config.length_weight * length_loss # Prepare outputs injection_probability = torch.sigmoid(logits.squeeze(-1)) if not return_dict: output = (logits,) + base_outputs[2:] return ((loss,) + output) if loss is not None else output # Create custom output that extends SequenceClassifierOutput output = SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=base_outputs.hidden_states, attentions=base_outputs.attentions, ) # Add custom attributes output.injection_probability = injection_probability output.length_prediction = length_pred.squeeze(-1) output.pooled_features = pooled_output output.statistical_features = statistical_features return output def _extract_statistical_features(self, input_ids, attention_mask): """Extract statistical features for injection detection""" batch_size = input_ids.size(0) device = input_ids.device stats = [] for i in range(batch_size): valid_mask = attention_mask[i].bool() valid_tokens = input_ids[i][valid_mask] seq_len = valid_mask.sum().float() # Avoid division by zero if seq_len == 0: stats.append(torch.zeros(8, device=device)) continue # Calculate features norm_length = seq_len / 512.0 unique_ratio = len(torch.unique(valid_tokens)) / seq_len special_tokens = (valid_tokens < 1000).float().mean() punct_density = ((valid_tokens > 999) & (valid_tokens < 2000)).float().mean() uppercase_proxy = (valid_tokens > 2000).float().mean() if len(valid_tokens) > 1: repetitions = (valid_tokens[1:] == valid_tokens[:-1]).float().mean() else: repetitions = torch.tensor(0.0, device=device) token_variance = valid_tokens.float().std() / 10000.0 if len(valid_tokens) > 0: _, counts = torch.unique(valid_tokens, return_counts=True) max_freq = counts.max().float() / seq_len else: max_freq = torch.tensor(0.0, device=device) stats.append( torch.tensor( [ norm_length, unique_ratio, special_tokens, punct_density, uppercase_proxy, repetitions, token_variance, max_freq, ], device=device, ) ) stats_tensor = torch.stack(stats) return self.stats_normalizer(stats_tensor)