"""Gradio demo that runs two independently fine-tuned FinBERT classifiers.

One model predicts a "Factors" label and the other a "Framing" label for the
same input sentence. Both share the same architecture: a FinBERT encoder
followed by dropout and a linear classification head.
"""

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import gradio as gr

# Config
MODEL_NAME = "yiyanghkust/finbert-tone"
FACTOR_MODEL_PATH = "finbert_factors.pth"
FRAMING_MODEL_PATH = "finbert_framing.pth"
MAX_LENGTH = 128  # token budget per input; longer text is truncated
NUM_LABELS = 3  # size of each classifier head; matches both label maps below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Label maps: predicted class index -> display label.
# NOTE(review): presumably these orders match the label encoding used during
# fine-tuning — verify against the training code.
label_map_factors = {0: 'Internal Factor', 1: 'External Factor', 2: 'No Factor'}
label_map_framing = {0: 'Internal Framing', 1: 'External Framing', 2: 'No Framing'}


class SingleTaskClassifier(nn.Module):
    """FinBERT encoder + dropout + linear head; used for both tasks.

    Args:
        num_labels: Number of output classes. Defaults to ``NUM_LABELS`` (3),
            which is what the saved checkpoints were built with.
    """

    def __init__(self, num_labels=NUM_LABELS):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalized) class logits, one row per input."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output is BERT's pooled [CLS] representation.
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)


def _load_classifier(path):
    """Build a SingleTaskClassifier on ``device``, load weights from *path*, set eval mode.

    The checkpoint must be a state_dict whose keys match SingleTaskClassifier
    (``load_state_dict`` is strict by default and raises otherwise).
    """
    model = SingleTaskClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute code on load (torch >= 1.13).
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    return model


# Load models
factor_model = _load_classifier(FACTOR_MODEL_PATH)
framing_model = _load_classifier(FRAMING_MODEL_PATH)


def predict(text):
    """Classify *text* with both the factor and the framing model.

    Args:
        text: Sentence entered in the Gradio textbox (may be ``None`` or blank).

    Returns:
        A ``(factor_label, framing_label)`` tuple of strings. Blank input
        returns ``("", "")`` rather than a label for an empty sequence.
    """
    if text is None or not text.strip():
        return "", ""

    encoding = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Both models consume the same encoding; gradients are not needed.
    with torch.no_grad():
        logits_factors = factor_model(input_ids, attention_mask)
        logits_framing = framing_model(input_ids, attention_mask)

    pred_factors = torch.argmax(logits_factors, dim=1).item()
    pred_framing = torch.argmax(logits_framing, dim=1).item()
    return label_map_factors[pred_factors], label_map_framing[pred_framing]


# Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence to analyze..."),
    outputs=[gr.Textbox(label="Factor"), gr.Textbox(label="Framing")],
    title="FinBERT Dual Classifier",
    description="This demo independently predicts both Factors and Framing using two fine-tuned FinBERT models.",
)

if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)