import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import gradio as gr
# Config
MODEL_NAME = "yiyanghkust/finbert-tone"
FACTOR_MODEL_PATH = "finbert_factors.pth"
FRAMING_MODEL_PATH = "finbert_framing.pth"
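# NOTE (assumption): both .pth files are state_dict checkpoints saved via
# torch.save(model.state_dict(), PATH) after fine-tuning, placed alongside this script.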
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# Label maps
label_map_factors = {0: 'Internal Factor', 1: 'External Factor', 2: 'No Factor'}
label_map_framing = {0: 'Internal Framing', 1: 'External Framing', 2: 'No Framing'}
# Unified model class (used for both)
class SingleTaskClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)
# Load models
factor_model = SingleTaskClassifier().to(device)
framing_model = SingleTaskClassifier().to(device)
factor_model.load_state_dict(torch.load(FACTOR_MODEL_PATH, map_location=device))
framing_model.load_state_dict(torch.load(FRAMING_MODEL_PATH, map_location=device))
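# eval() disables dropout so repeated predictions on the same input are deterministic.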
factor_model.eval()
framing_model.eval()
# Prediction function
def predict(text):
    encoding = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    with torch.no_grad():
        logits_factors = factor_model(input_ids, attention_mask)
        logits_framing = framing_model(input_ids, attention_mask)
    pred_factors = torch.argmax(logits_factors, dim=1).item()
    pred_framing = torch.argmax(logits_framing, dim=1).item()
    return label_map_factors[pred_factors], label_map_framing[pred_framing]
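# Example call (hypothetical input and output; actual labels depend on the fine-tuned weights):
#   predict("Margins improved thanks to internal cost controls.")
#   -> ("Internal Factor", "Internal Framing")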
# Gradio interface
gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence to analyze..."),
    outputs=[gr.Textbox(label="Factors"), gr.Textbox(label="Framing")],
    title="FinBERT Dual Classifier",
    description="This demo independently predicts both Factors and Framing using two fine-tuned FinBERT models."
).launch()  # share=True is ignored (with a warning) when running on Hugging Face Spaces