File size: 2,384 Bytes
51e9bcf
 
8d965e0
51e9bcf
 
1752355
51e9bcf
1752355
 
8d965e0
51e9bcf
1752355
 
 
8d965e0
f730f44
 
8d965e0
1752355
 
 
51e9bcf
1752355
51e9bcf
1752355
51e9bcf
1752355
51e9bcf
 
1752355
51e9bcf
1752355
 
 
 
 
 
 
51e9bcf
8d965e0
51e9bcf
8d965e0
 
 
 
 
 
 
1752355
 
8d965e0
51e9bcf
1752355
 
8d965e0
f730f44
 
8d965e0
f730f44
51e9bcf
8d965e0
 
 
 
f730f44
1752355
 
8d965e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import gradio as gr

# --- Configuration -----------------------------------------------------------
# Checkpoint identifier handed to from_pretrained() for both the tokenizer and
# the BERT encoder inside SingleTaskClassifier.
MODEL_NAME = "yiyanghkust/finbert-tone"
# Fine-tuned state_dict files for the two classifier heads (relative paths, so
# they resolve against the current working directory).
FACTOR_MODEL_PATH = "finbert_factors.pth"
FRAMING_MODEL_PATH = "finbert_framing.pth"
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer shared by both classifiers, since both wrap the same base checkpoint.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

# Label maps: classifier output index -> display label, one map per head.
label_map_factors = dict(enumerate(("Internal Factor", "External Factor", "No Factor")))
label_map_framing = dict(enumerate(("Internal Framing", "External Framing", "No Framing")))

# Unified model class (used for both the factor and the framing classifier)
class SingleTaskClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The same architecture backs both classifiers in this app; each instance
    receives its own fine-tuned weights via ``load_state_dict``.

    Args:
        num_labels: Number of output classes. Defaults to 3, which matches the
            three-entry label maps used with the saved checkpoints.
        dropout: Dropout probability applied to the pooled output.
        pretrained_name: Checkpoint passed to ``BertModel.from_pretrained``;
            ``None`` (the default) uses the module-level ``MODEL_NAME``.
    """

    def __init__(self, num_labels=3, dropout=0.1, pretrained_name=None):
        super().__init__()
        if pretrained_name is None:
            # Looked up at call time so the module-level setting is honoured.
            pretrained_name = MODEL_NAME
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)
        # The attribute names (bert / dropout / classifier) become state_dict key
        # prefixes, so they must stay stable for saved checkpoints to load.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalised) class logits.

        Args:
            input_ids: Token-id tensor produced by the tokenizer
                (presumably shape (batch, seq_len) — confirm with callers).
            attention_mask: Mask tensor aligned with ``input_ids``.

        Returns:
            Logits from the linear head; the last dimension is ``num_labels``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the encoder's pooled output (see BertModel docs for
        # the definition of pooler_output).
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)

# Load models
def _load_classifier(path):
    """Build a SingleTaskClassifier, load the fine-tuned weights stored at
    *path*, and return it on ``device`` in eval mode."""
    model = SingleTaskClassifier().to(device)
    # weights_only=True (torch >= 1.13) restricts unpickling to tensors and
    # plain containers, so a tampered .pth file cannot execute code. The file is
    # fed straight to load_state_dict, so it should be a plain state_dict, which
    # this mode supports.
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()  # inference mode: disables the dropout layer
    return model


factor_model = _load_classifier(FACTOR_MODEL_PATH)
framing_model = _load_classifier(FRAMING_MODEL_PATH)

# Prediction function
def predict(text):
    """Run both fine-tuned classifiers on one piece of text.

    The text is tokenized once (padded/truncated to 128 tokens), and that single
    encoding is fed to each model under ``torch.no_grad()``.

    Args:
        text: Input string, as supplied by the Gradio textbox.

    Returns:
        ``(factor_label, framing_label)``: the arg-max label of each model.
    """
    batch = tokenizer(
        text,
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    heads = (
        (factor_model, label_map_factors),
        (framing_model, label_map_framing),
    )
    labels = []
    with torch.no_grad():
        for model, label_map in heads:
            # One input string -> batch of one: arg-max over the class axis,
            # then unwrap to a Python int for the dict lookup.
            class_idx = model(ids, mask).argmax(dim=1).item()
            labels.append(label_map[class_idx])
    return tuple(labels)

# Gradio interface: one free-text input and two labelled text outputs. The
# output order matches the (factor_label, framing_label) tuple from predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence to analyze..."),
    outputs=[
        gr.Textbox(label="Factor"),
        gr.Textbox(label="Framing"),
    ],
    title="FinBERT Dual Classifier",
    description="This demo independently predicts both Factors and Framing using two fine-tuned FinBERT models."
)
# NOTE(review): share=True requests a public share link when run locally;
# confirm that exposing the demo publicly is intended.
demo.launch(share=True)