Upload 18 files
- app_gradio.py +108 -0
- models/lstm.py +20 -0
- models/rnn.py +19 -0
- models/transformer.py +20 -0
- plots/class_distribution.png +0 -0
- plots/lstm_confusion_matrices.png +0 -0
- plots/lstm_loss_curve.png +0 -0
- plots/rnn_confusion_matrices.png +0 -0
- plots/rnn_loss_curve.png +0 -0
- plots/transformer_confusion_matrices.png +0 -0
- plots/transformer_loss_curve.png +0 -0
- pretrained_models/best_lstm.pt +3 -0
- pretrained_models/best_rnn.pt +3 -0
- pretrained_models/best_transformer.pt +3 -0
- pretrained_models/vocab.pkl +3 -0
- requirements.txt +10 -0
- train.py +256 -0
- utility.py +94 -0
app_gradio.py
ADDED
@@ -0,0 +1,108 @@
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
import pickle

from models.rnn import RNNClassifier
from models.lstm import LSTMClassifier
from models.transformer import TransformerClassifier
from utility import simple_tokenizer

# =========================
# Load models and vocab
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "prajjwal1/bert-tiny"

def load_vocab():
    with open("pretrained_models/vocab.pkl", "rb") as f:
        return pickle.load(f)

def load_models(vocab_size, output_dim=6, padding_idx=0):
    rnn_model = RNNClassifier(vocab_size, 128, 128, output_dim, padding_idx)
    rnn_model.load_state_dict(torch.load("pretrained_models/best_rnn.pt", map_location=device))
    rnn_model = rnn_model.to(device)
    rnn_model.eval()

    lstm_model = LSTMClassifier(vocab_size, 128, 128, output_dim, padding_idx)
    lstm_model.load_state_dict(torch.load("pretrained_models/best_lstm.pt", map_location=device))
    lstm_model = lstm_model.to(device)
    lstm_model.eval()

    transformer_model = TransformerClassifier(model_name, output_dim)
    transformer_model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
    transformer_model = transformer_model.to(device)
    transformer_model.eval()

    return rnn_model, lstm_model, transformer_model


vocab = load_vocab()
tokenizer = AutoTokenizer.from_pretrained(model_name)
rnn_model, lstm_model, transformer_model = load_models(len(vocab))

emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]

def predict(model, text, model_type, vocab, tokenizer=None, max_length=32):
    if model_type in ["rnn", "lstm"]:
        # Match collate_fn_rnn but with no random truncation
        tokens = simple_tokenizer(text)
        ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]

        if len(ids) < max_length:
            ids += [vocab["<PAD>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]

        input_ids = torch.tensor([ids], dtype=torch.long).to(device)
        outputs = model(input_ids)

    else:
        # Match collate_fn_transformer but with no partial_prob
        encoding = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    probs = F.softmax(outputs, dim=-1)
    return probs.squeeze().detach().cpu().numpy()

# =========================
# Gradio App
# =========================

def emotion_typeahead(text):
    if len(text.strip()) <= 2:
        return {}, {}, {}

    rnn_probs = predict(rnn_model, text.strip(), "rnn", vocab)
    lstm_probs = predict(lstm_model, text.strip(), "lstm", vocab)
    transformer_probs = predict(transformer_model, text.strip(), "transformer", vocab, tokenizer)

    rnn_dict = {emo: float(prob) for emo, prob in zip(emotions, rnn_probs)}
    lstm_dict = {emo: float(prob) for emo, prob in zip(emotions, lstm_probs)}
    transformer_dict = {emo: float(prob) for emo, prob in zip(emotions, transformer_probs)}

    return rnn_dict, lstm_dict, transformer_dict

with gr.Blocks() as demo:
    gr.Markdown("## 🎯 Emotion Typeahead Predictor (RNN, LSTM, Transformer)")

    text_input = gr.Textbox(label="Type your sentence here...")

    with gr.Row():
        rnn_output = gr.Label(label="🧠 RNN Prediction")
        lstm_output = gr.Label(label="🧠 LSTM Prediction")
        transformer_output = gr.Label(label="🧠 Transformer Prediction")

    text_input.change(emotion_typeahead, inputs=text_input, outputs=[rnn_output, lstm_output, transformer_output])

demo.launch()
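For quick local verification outside the Gradio UI, the transformer checkpoint can be exercised directly. The snippet below is a sketch, not part of the upload; it assumes the repository checkout (including pretrained_models/) is the working directory and reuses the label order assumed in app_gradio.py.

import pickle
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from models.transformer import TransformerClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = TransformerClassifier("prajjwal1/bert-tiny", output_dim=6)
model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
model.to(device).eval()

enc = tokenizer("i feel really happy today", padding="max_length",
                truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    logits = model(enc["input_ids"].to(device), enc["attention_mask"].to(device))
probs = F.softmax(logits, dim=-1).squeeze()
emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]  # order as assumed in app_gradio.py
print(emotions[int(probs.argmax())], float(probs.max()))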
models/lstm.py
ADDED
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, padding_idx):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # Note: with num_layers=1 the dropout argument has no effect
        # (PyTorch only applies it between stacked LSTM layers).
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1, dropout=0.3, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        embedded = self.embedding(x)               # [batch_size, seq_len, embed_dim]
        output, (hidden, _) = self.lstm(embedded)  # hidden: [2, batch_size, hidden_dim]
        hidden_cat = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)  # concatenate last forward/backward hidden states
        x = self.fc1(hidden_cat)
        x = self.relu(x)
        out = self.fc2(x)
        return out
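A minimal shape check (not part of the upload) confirms the classifier maps a padded token batch to one logit per emotion class; the sizes mirror those used in train.py, and the toy vocabulary size is an assumption.

import torch
from models.lstm import LSTMClassifier

# Toy vocabulary of 1000 ids; embed/hidden sizes match train.py (128/128).
model = LSTMClassifier(vocab_size=1000, embed_dim=128, hidden_dim=128,
                       output_dim=6, padding_idx=0)
x = torch.randint(0, 1000, (4, 32))  # [batch_size, seq_len]
print(model(x).shape)                # expected: torch.Size([4, 6])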
models/rnn.py
ADDED
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn

class RNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, padding_idx):
        super(RNNClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)  # New hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x):
        embedded = self.embedding(x)         # [batch_size, seq_len, embed_dim]
        output, hidden = self.rnn(embedded)  # hidden: [1, batch_size, hidden_dim]
        x = self.fc1(hidden.squeeze(0))      # [batch_size, hidden_dim//2]
        x = self.relu(x)
        out = self.fc2(x)                    # [batch_size, output_dim]
        return out
models/transformer.py
ADDED
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn
from transformers import AutoModel

class TransformerClassifier(nn.Module):
    def __init__(self, model_name, output_dim):
        super(TransformerClassifier, self).__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
        # Freeze the lower encoder layers (layer.0, layer.1, layer.2);
        # all remaining parameters stay trainable.
        for name, param in self.transformer.named_parameters():
            if "layer.0" in name or "layer.1" in name or "layer.2" in name:
                param.requires_grad = False
        self.fc = nn.Linear(self.transformer.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
        pooled_output = hidden_state[:, 0]        # Use CLS token output
        out = self.fc(pooled_output)
        return out
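As a quick check (not part of the upload), the effect of the freezing loop can be inspected by counting frozen versus trainable parameters; downloading prajjwal1/bert-tiny is assumed to be possible.

from models.transformer import TransformerClassifier

m = TransformerClassifier("prajjwal1/bert-tiny", output_dim=6)
frozen = sum(p.numel() for p in m.parameters() if not p.requires_grad)
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"frozen: {frozen:,}  trainable: {trainable:,}")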
plots/class_distribution.png
ADDED
plots/lstm_confusion_matrices.png
ADDED
plots/lstm_loss_curve.png
ADDED
plots/rnn_confusion_matrices.png
ADDED
plots/rnn_loss_curve.png
ADDED
plots/transformer_confusion_matrices.png
ADDED
plots/transformer_loss_curve.png
ADDED
pretrained_models/best_lstm.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2ab822a8bc9c588f9e4b8dc7a2a85385528657b206b02d62d184217bb379ae76
size 4984493
pretrained_models/best_rnn.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ef0a95f7b1655f46b26691395bbc94558c5346c09f7f90177dbf5c8cb578da0
size 3959329
pretrained_models/best_transformer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ae3d6c4db4de20f3e2251c67760fa91f6fe3f475135e45c3492114dfd11da0b
size 17564687
pretrained_models/vocab.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:276d70ee4c7d337f6d07514ee81043fa7c7475e815fb153a667400a45df99f2d
size 93532
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch
transformers
datasets
scikit-learn
gradio
numpy==1.26.4
scipy
tqdm
streamlit
train.py
ADDED
@@ -0,0 +1,256 @@
import os
import time
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utility import (
    load_emotion_dataset,
    encode_labels,
    build_vocab,
    collate_fn_rnn,
    collate_fn_transformer
)
from models.rnn import RNNClassifier
from models.lstm import LSTMClassifier
from models.transformer import TransformerClassifier
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def summarize_class_distribution(dataset, label_encoder):
    labels = [example["label"] for example in dataset]
    counter = Counter(labels)
    print("\n🔍 Class distribution:")
    for label_idx, count in sorted(counter.items()):
        label_name = label_encoder.inverse_transform([label_idx])[0]
        print(f"{label_name:>10}: {count}")

def plot_class_countplot(dataset, label_encoder):
    labels = [example["label"] for example in dataset]
    counts = Counter(labels)
    label_display = [label_encoder.inverse_transform([i])[0] for i in sorted(counts.keys())]
    values = [counts[i] for i in sorted(counts.keys())]

    plt.figure(figsize=(8, 5))
    sns.barplot(x=label_display, y=values)
    plt.title("Emotion Class Distribution (Training Set)")
    plt.xlabel("Emotion")
    plt.ylabel("Count")
    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/class_distribution.png")
    plt.close()

def plot_loss_curve(train_losses, test_losses, model_name):
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{model_name} Train vs Test Loss")
    plt.legend()
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{model_name.lower()}_loss_curve.png")
    plt.close()

def compute_test_loss(model, dataloader, criterion, model_type):
    total_loss = 0
    with torch.no_grad():
        model.eval()
        for batch in dataloader:
            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]

            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            if model_type == "transformer":
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)

            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)

def train_model(model, train_loader, test_loader, optimizer, criterion, epochs, model_type="rnn"):
    train_losses = []
    test_losses = []

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", ncols=100)

        for batch in progress_bar:
            optimizer.zero_grad()

            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]

            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            if model_type == "transformer":
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            avg_loss = total_loss / len(train_loader)
            progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})

        test_loss = compute_test_loss(model, test_loader, criterion, model_type)
        train_losses.append(avg_loss)
        test_losses.append(test_loss)

        print(f"✅ Epoch {epoch + 1} | Train: {avg_loss:.4f} | Test: {test_loss:.4f} | Time: {time.time() - start_time:.2f}s")

    torch.cuda.empty_cache()
    del model
    return train_losses, test_losses

def evaluate_preds(model, dataloader, model_type="rnn"):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]

            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            if model_type == "transformer":
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
    return all_labels, all_preds

def plot_confusion_matrices(y_true_train, y_pred_train, y_true_test, y_pred_test, labels, title, filename):
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    cm_train = confusion_matrix(y_true_train, y_pred_train)
    cm_test = confusion_matrix(y_true_test, y_pred_test)

    ConfusionMatrixDisplay(cm_train, display_labels=labels).plot(ax=axes[0], cmap='Blues', colorbar=False)
    axes[0].set_title(f"{title} - Train")

    ConfusionMatrixDisplay(cm_test, display_labels=labels).plot(ax=axes[1], cmap='Oranges', colorbar=False)
    axes[1].set_title(f"{title} - Test")

    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{filename}")
    plt.close()

# Load and encode data
data = load_emotion_dataset("train")
train_data, label_encoder = encode_labels(data)
test_data, _ = encode_labels(load_emotion_dataset("test"))
labels = label_encoder.classes_
output_dim = len(labels)
padding_idx = 0

summarize_class_distribution(train_data, label_encoder)
plot_class_countplot(train_data, label_encoder)

# Build vocab
vocab = build_vocab(train_data)

model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# DataLoaders (no augmentation)
train_loader_rnn = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
test_loader_rnn = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))

train_loader_tf = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
test_loader_tf = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))

# Initialize and train models
rnn = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
lstm = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
transformer = TransformerClassifier(model_name, output_dim).to(device)

criterion = torch.nn.CrossEntropyLoss()

# rnn_train_losses, rnn_test_losses = train_model(rnn, train_loader_rnn, test_loader_rnn, torch.optim.Adam(rnn.parameters(), lr=1e-4), criterion, epochs=50, model_type="rnn")
# torch.save(rnn.state_dict(), "pretrained_models/best_rnn.pt")
# plot_loss_curve(rnn_train_losses, rnn_test_losses, "RNN")
#
# lstm_train_losses, lstm_test_losses = train_model(lstm, train_loader_rnn, test_loader_rnn, torch.optim.Adam(lstm.parameters(), lr=1e-4), criterion, epochs=50, model_type="lstm")
# torch.save(lstm.state_dict(), "pretrained_models/best_lstm.pt")
# plot_loss_curve(lstm_train_losses, lstm_test_losses, "LSTM")

tf_train_losses, tf_test_losses = train_model(transformer, train_loader_tf, test_loader_tf, torch.optim.Adam(transformer.parameters(), lr=2e-5), criterion, epochs=50, model_type="transformer")
torch.save(transformer.state_dict(), "pretrained_models/best_transformer.pt")
plot_loss_curve(tf_train_losses, tf_test_losses, "Transformer")

# Evaluate and plot
model_paths = {
    "RNN": "pretrained_models/best_rnn.pt",
    "LSTM": "pretrained_models/best_lstm.pt",
    "Transformer": "pretrained_models/best_transformer.pt"
}

for name in ["RNN", "LSTM", "Transformer"]:
    if name == "RNN":
        model = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn
    elif name == "LSTM":
        model = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn
    else:
        model = TransformerClassifier(model_name, output_dim).to(device)
        loader = train_loader_tf
        test_loader = test_loader_tf

    model.load_state_dict(torch.load(model_paths[name], map_location=device))
    model.eval()

    y_train_true, y_train_pred = evaluate_preds(model, loader, model_type=name.lower())
    y_test_true, y_test_pred = evaluate_preds(model, test_loader, model_type=name.lower())

    plot_confusion_matrices(
        y_train_true, y_train_pred, y_test_true, y_test_pred,
        labels=labels,
        title=name,
        filename=f"{name.lower()}_confusion_matrices.png"
    )
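Beyond confusion matrices, per-model accuracy can be reported from the same predictions. This sketch is not part of the upload; it assumes it runs at the end of the evaluation loop above and reuses name, y_test_true, and y_test_pred from that loop.

from sklearn.metrics import accuracy_score, classification_report

# Assumes y_test_true / y_test_pred from the loop above for the current model.
print(f"{name} test accuracy: {accuracy_score(y_test_true, y_test_pred):.4f}")
print(classification_report(y_test_true, y_test_pred))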
utility.py
ADDED
@@ -0,0 +1,94 @@
import re
import torch
from collections import Counter
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer
import random

# ====== Dataset Loading ======

def load_emotion_dataset(split="train"):
    return load_dataset("dair-ai/emotion", split=split)

def encode_labels(dataset):
    le = LabelEncoder()
    all_labels = [example["label"] for example in dataset]
    le.fit(all_labels)
    dataset = dataset.map(lambda x: {"label": le.transform([x["label"]])[0]})
    return dataset, le

# ====== Tokenizer for RNN/LSTM ======

def simple_tokenizer(text):
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)  # Remove special characters
    return text.split()

# ====== Vocab Builder for RNN/LSTM ======

def build_vocab(dataset, min_freq=2):
    counter = Counter()
    for example in dataset:
        tokens = simple_tokenizer(example["text"])
        counter.update(tokens)

    vocab = {"<PAD>": 0, "<UNK>": 1}
    idx = 2
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

# ====== Collate Function for RNN/LSTM ======

def collate_fn_rnn(batch, vocab, max_length=32, partial_prob=0.0):
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]

    all_input_ids = []
    for text in texts:
        tokens = simple_tokenizer(text)

        # 🔥 Randomly truncate tokens with some probability
        if random.random() < partial_prob and len(tokens) > 5:
            # Keep between 30% to 70% of the tokens
            cutoff = random.randint(int(len(tokens)*0.3), int(len(tokens)*0.7))
            tokens = tokens[:cutoff]

        ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
        if len(ids) < max_length:
            ids += [vocab["<PAD>"]] * (max_length - len(ids))
        else:
            ids = ids[:max_length]
        all_input_ids.append(ids)

    input_ids = torch.tensor(all_input_ids)
    labels = torch.tensor(labels)
    return input_ids, labels

# ====== Collate Function for Transformer ======

def collate_fn_transformer(batch, tokenizer, max_length=128, partial_prob=0.5):
    texts = []
    labels = []

    for item in batch:
        text = item["text"]
        tokens = text.split()

        # 🔥 Random truncation
        if random.random() < partial_prob and len(tokens) > 5:
            cutoff = random.randint(int(len(tokens)*0.3), int(len(tokens)*0.7))
            tokens = tokens[:cutoff]
            text = " ".join(tokens)

        texts.append(text)
        labels.append(item["label"])

    encoding = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
    encoding["labels"] = torch.tensor(labels)
    return encoding
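The preprocessing helpers can be checked on a toy batch without downloading the dataset. This example is not part of the upload; the two hand-written items are assumptions, written in the same dict format the collate functions expect.

from utility import build_vocab, collate_fn_rnn

toy = [
    {"text": "i feel happy happy today", "label": 1},
    {"text": "i feel so sad and alone sad", "label": 0},
]
vocab = build_vocab(toy, min_freq=2)  # keeps tokens that appear at least twice
input_ids, labels = collate_fn_rnn(toy, vocab, max_length=8, partial_prob=0.0)
print(input_ids.shape, labels)        # torch.Size([2, 8]) tensor([1, 0])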