import torch
import torch.nn as nn
from transformers import BertModel
from typing import Tuple, Dict


class AttentionLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        attention_weights = torch.softmax(self.attention(x), dim=1)
        attended = torch.sum(attention_weights * x, dim=1)
        return attended, attention_weights
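
# Illustrative shape check for AttentionLayer (a sketch added for clarity, not
# part of the original module; the sizes below are arbitrary examples):
#
#   layer = AttentionLayer(hidden_size=512)
#   x = torch.randn(8, 128, 512)          # (batch, seq_len, hidden)
#   pooled, weights = layer(x)
#   # pooled:  (8, 512)    -- weighted sum over the sequence dimension
#   # weights: (8, 128, 1) -- softmax-normalized score per time step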


class HybridFakeNewsDetector(nn.Module):
    def __init__(self,
                 bert_model_name: str = "bert-base-uncased",
                 lstm_hidden_size: int = 256,
                 lstm_num_layers: int = 2,
                 dropout_rate: float = 0.3,
                 num_classes: int = 2):
        super().__init__()

        # BERT encoder
        self.bert = BertModel.from_pretrained(bert_model_name)
        bert_hidden_size = self.bert.config.hidden_size

        # BiLSTM layer
        self.lstm = nn.LSTM(
            input_size=bert_hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True
        )

        # Attention layer
        self.attention = AttentionLayer(lstm_hidden_size * 2)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(lstm_hidden_size * 2, lstm_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(lstm_hidden_size, num_classes)
        )

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get BERT embeddings
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        bert_embeddings = bert_outputs.last_hidden_state

        # Process through BiLSTM
        lstm_output, _ = self.lstm(bert_embeddings)

        # Apply attention
        attended, attention_weights = self.attention(lstm_output)

        # Classification
        logits = self.classifier(attended)

        return {
            'logits': logits,
            'attention_weights': attention_weights
        }

    def predict(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        """Get model predictions."""
        outputs = self.forward(input_ids, attention_mask)
        return torch.softmax(outputs['logits'], dim=1)

    def get_attention_weights(self, input_ids: torch.Tensor,
                              attention_mask: torch.Tensor) -> torch.Tensor:
        """Get attention weights for interpretability."""
        outputs = self.forward(input_ids, attention_mask)
        return outputs['attention_weights']
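

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the matching BertTokenizer from `transformers` and the default
# "bert-base-uncased" checkpoint; the sample text and max_length are arbitrary.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = HybridFakeNewsDetector()
    model.eval()

    encoded = tokenizer(
        "Scientists discover a new species of deep-sea fish.",
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )

    with torch.no_grad():
        probs = model.predict(encoded["input_ids"], encoded["attention_mask"])
        weights = model.get_attention_weights(encoded["input_ids"],
                                              encoded["attention_mask"])

    print("class probabilities:", probs)        # shape: (1, num_classes)
    print("attention weights:", weights.shape)  # shape: (1, seq_len, 1)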