Hunter-Pax committed on
Commit e7a44ba · verified · 1 Parent(s): a4973a5

Upload 18 files

app_gradio.py ADDED
@@ -0,0 +1,108 @@
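+ # Gradio demo for the emotion classifier: loads the trained RNN, LSTM, and
+ # Transformer models and shows live predictions from all three as the user types.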
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ import pickle
+
+ from models.rnn import RNNClassifier
+ from models.lstm import LSTMClassifier
+ from models.transformer import TransformerClassifier
+ from utility import simple_tokenizer
+
+ # =========================
+ # Load models and vocab
+ # =========================
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model_name = "prajjwal1/bert-tiny"
+
+ def load_vocab():
+     with open("pretrained_models/vocab.pkl", "rb") as f:
+         return pickle.load(f)
+
+ def load_models(vocab_size, output_dim=6, padding_idx=0):
+     rnn_model = RNNClassifier(vocab_size, 128, 128, output_dim, padding_idx)
+     rnn_model.load_state_dict(torch.load("pretrained_models/best_rnn.pt", map_location=device))
+     rnn_model = rnn_model.to(device)
+     rnn_model.eval()
+
+     lstm_model = LSTMClassifier(vocab_size, 128, 128, output_dim, padding_idx)
+     lstm_model.load_state_dict(torch.load("pretrained_models/best_lstm.pt", map_location=device))
+     lstm_model = lstm_model.to(device)
+     lstm_model.eval()
+
+     transformer_model = TransformerClassifier(model_name, output_dim)
+     transformer_model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
+     transformer_model = transformer_model.to(device)
+     transformer_model.eval()
+
+     return rnn_model, lstm_model, transformer_model
+
+
+ # Load vocabulary, tokenizer, and the three trained models
+ vocab = load_vocab()
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ rnn_model, lstm_model, transformer_model = load_models(len(vocab))
+
+ emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]
+
+ def predict(model, text, model_type, vocab, tokenizer=None, max_length=32):
+     if model_type in ["rnn", "lstm"]:
+         # Match collate_fn_rnn but with no random truncation
+         tokens = simple_tokenizer(text)
+         ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
+
+         if len(ids) < max_length:
+             ids += [vocab["<PAD>"]] * (max_length - len(ids))
+         else:
+             ids = ids[:max_length]
+
+         input_ids = torch.tensor([ids], dtype=torch.long).to(device)
+         outputs = model(input_ids)
+
+     else:
+         # Match collate_fn_transformer but with no partial_prob
+         encoding = tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=128,
+             return_tensors="pt"
+         )
+         input_ids = encoding["input_ids"].to(device)
+         attention_mask = encoding["attention_mask"].to(device)
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+
+     probs = F.softmax(outputs, dim=-1)
+     return probs.squeeze().detach().cpu().numpy()
+
+ # =========================
+ # Gradio App
+ # =========================
+
+ def emotion_typeahead(text):
+     if len(text.strip()) <= 2:
+         return {}, {}, {}
+
+     rnn_probs = predict(rnn_model, text.strip(), "rnn", vocab)
+     lstm_probs = predict(lstm_model, text.strip(), "lstm", vocab)
+     transformer_probs = predict(transformer_model, text.strip(), "transformer", vocab, tokenizer)
+
+     rnn_dict = {emo: float(prob) for emo, prob in zip(emotions, rnn_probs)}
+     lstm_dict = {emo: float(prob) for emo, prob in zip(emotions, lstm_probs)}
+     transformer_dict = {emo: float(prob) for emo, prob in zip(emotions, transformer_probs)}
+
+     return rnn_dict, lstm_dict, transformer_dict
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## 🎯 Emotion Typeahead Predictor (RNN, LSTM, Transformer)")
+
+     text_input = gr.Textbox(label="Type your sentence here...")
+
+     with gr.Row():
+         rnn_output = gr.Label(label="🧠 RNN Prediction")
+         lstm_output = gr.Label(label="🧠 LSTM Prediction")
+         transformer_output = gr.Label(label="🧠 Transformer Prediction")
+
+     text_input.change(emotion_typeahead, inputs=text_input, outputs=[rnn_output, lstm_output, transformer_output])
+
+ demo.launch()
models/lstm.py ADDED
@@ -0,0 +1,20 @@
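+ # Bidirectional LSTM classifier: embeds token ids, runs a single BiLSTM layer,
+ # and classifies from the concatenated final hidden states of the forward and
+ # backward directions.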
+ import torch
+ import torch.nn as nn
+
+ class LSTMClassifier(nn.Module):
+     def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, padding_idx):
+         super(LSTMClassifier, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+         # NOTE: with num_layers=1 the dropout argument has no effect (PyTorch only applies it between stacked layers)
+         self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1, dropout=0.3, batch_first=True, bidirectional=True)
+         self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
+         self.relu = nn.ReLU()
+         self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, x):
+         embedded = self.embedding(x)
+         output, (hidden, _) = self.lstm(embedded)
+         hidden_cat = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)  # concatenate final forward/backward hidden states
+         x = self.fc1(hidden_cat)
+         x = self.relu(x)
+         out = self.fc2(x)
+         return out
models/rnn.py ADDED
@@ -0,0 +1,19 @@
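+ # Unidirectional RNN baseline: embeds token ids, runs a single nn.RNN layer,
+ # and classifies from the final hidden state via a small two-layer head.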
+ import torch
+ import torch.nn as nn
+
+ class RNNClassifier(nn.Module):
+     def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, padding_idx):
+         super(RNNClassifier, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+         self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
+         self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)  # intermediate hidden layer
+         self.relu = nn.ReLU()
+         self.fc2 = nn.Linear(hidden_dim // 2, output_dim)
+
+     def forward(self, x):
+         embedded = self.embedding(x)  # [batch_size, seq_len, embed_dim]
+         output, hidden = self.rnn(embedded)  # hidden: [1, batch_size, hidden_dim]
+         x = self.fc1(hidden.squeeze(0))  # [batch_size, hidden_dim//2]
+         x = self.relu(x)
+         out = self.fc2(x)  # [batch_size, output_dim]
+         return out
models/transformer.py ADDED
@@ -0,0 +1,20 @@
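+ # Transformer-based classifier: wraps a pretrained Hugging Face encoder
+ # (prajjwal1/bert-tiny in this project) and classifies from the [CLS] token.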
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel
+
+ class TransformerClassifier(nn.Module):
+     def __init__(self, model_name, output_dim):
+         super(TransformerClassifier, self).__init__()
+         self.transformer = AutoModel.from_pretrained(model_name)
+         # Freeze encoder layers 0-2 (bert-tiny has only 2 encoder layers, so this freezes the whole encoder and leaves the embeddings and classification head trainable)
+         for name, param in self.transformer.named_parameters():
+             if "layer.0" in name or "layer.1" in name or "layer.2" in name:
+                 param.requires_grad = False
+         self.fc = nn.Linear(self.transformer.config.hidden_size, output_dim)
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
+         hidden_state = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
+         pooled_output = hidden_state[:, 0]  # Use CLS token output
+         out = self.fc(pooled_output)
+         return out
plots/class_distribution.png ADDED
plots/lstm_confusion_matrices.png ADDED
plots/lstm_loss_curve.png ADDED
plots/rnn_confusion_matrices.png ADDED
plots/rnn_loss_curve.png ADDED
plots/transformer_confusion_matrices.png ADDED
plots/transformer_loss_curve.png ADDED
pretrained_models/best_lstm.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ab822a8bc9c588f9e4b8dc7a2a85385528657b206b02d62d184217bb379ae76
+ size 4984493
pretrained_models/best_rnn.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ef0a95f7b1655f46b26691395bbc94558c5346c09f7f90177dbf5c8cb578da0
+ size 3959329
pretrained_models/best_transformer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ae3d6c4db4de20f3e2251c67760fa91f6fe3f475135e45c3492114dfd11da0b
+ size 17564687
pretrained_models/vocab.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:276d70ee4c7d337f6d07514ee81043fa7c7475e815fb153a667400a45df99f2d
+ size 93532
requirements.txt ADDED
@@ -0,0 +1,10 @@
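+ # Runtime dependencies for the Gradio demo plus training/plotting dependencies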
+ torch
+ transformers
+ datasets
+ scikit-learn
+ gradio
+ numpy==1.26.4
+ scipy
+ tqdm
+ streamlit
+ matplotlib
+ seaborn
train.py ADDED
@@ -0,0 +1,256 @@
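+ # Training script: loads dair-ai/emotion, builds the RNN/LSTM vocabulary,
+ # trains the classifiers (the RNN/LSTM runs are currently commented out),
+ # and writes loss curves and confusion matrices to plots/.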
+ import os
+ import time
+ import torch
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from collections import Counter
+ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+ from torch.utils.data import DataLoader
+ from transformers import AutoTokenizer
+ from utility import (
+     load_emotion_dataset,
+     encode_labels,
+     build_vocab,
+     collate_fn_rnn,
+     collate_fn_transformer
+ )
+ from models.rnn import RNNClassifier
+ from models.lstm import LSTMClassifier
+ from models.transformer import TransformerClassifier
+ from tqdm import tqdm
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def summarize_class_distribution(dataset, label_encoder):
+     labels = [example["label"] for example in dataset]
+     counter = Counter(labels)
+     print("\n🔍 Class distribution:")
+     for label_idx, count in sorted(counter.items()):
+         label_name = label_encoder.inverse_transform([label_idx])[0]
+         print(f"{label_name:>10}: {count}")
+
+ def plot_class_countplot(dataset, label_encoder):
+     labels = [example["label"] for example in dataset]
+     counts = Counter(labels)
+     label_display = [label_encoder.inverse_transform([i])[0] for i in sorted(counts.keys())]
+     values = [counts[i] for i in sorted(counts.keys())]
+
+     plt.figure(figsize=(8, 5))
+     sns.barplot(x=label_display, y=values)
+     plt.title("Emotion Class Distribution (Training Set)")
+     plt.xlabel("Emotion")
+     plt.ylabel("Count")
+     plt.tight_layout()
+     os.makedirs("plots", exist_ok=True)
+     plt.savefig("plots/class_distribution.png")
+     plt.close()
+
+ def plot_loss_curve(train_losses, test_losses, model_name):
+     plt.figure(figsize=(8, 4))
+     plt.plot(train_losses, label="Train Loss")
+     plt.plot(test_losses, label="Test Loss")
+     plt.xlabel("Epoch")
+     plt.ylabel("Loss")
+     plt.title(f"{model_name} Train vs Test Loss")
+     plt.legend()
+     os.makedirs("plots", exist_ok=True)
+     plt.savefig(f"plots/{model_name.lower()}_loss_curve.png")
+     plt.close()
+
+ def compute_test_loss(model, dataloader, criterion, model_type):
+     total_loss = 0
+     with torch.no_grad():
+         model.eval()
+         for batch in dataloader:
+             if isinstance(batch, tuple):
+                 input_ids, labels = batch
+                 attention_mask = None
+             else:
+                 input_ids = batch["input_ids"]
+                 attention_mask = batch.get("attention_mask", None)
+                 labels = batch["labels"]
+
+             input_ids = input_ids.to(device)
+             labels = labels.to(device)
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(device)
+
+             if model_type == "transformer":
+                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+             else:
+                 outputs = model(input_ids)
+
+             loss = criterion(outputs, labels)
+             total_loss += loss.item()
+     return total_loss / len(dataloader)
+
+ def train_model(model, train_loader, test_loader, optimizer, criterion, epochs, model_type="rnn"):
+     train_losses = []
+     test_losses = []
+
+     for epoch in range(epochs):
+         model.train()
+         start_time = time.time()
+         total_loss = 0
+         progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", ncols=100)
+
+         for batch in progress_bar:
+             optimizer.zero_grad()
+
+             if isinstance(batch, tuple):
+                 input_ids, labels = batch
+                 attention_mask = None
+             else:
+                 input_ids = batch["input_ids"]
+                 attention_mask = batch.get("attention_mask", None)
+                 labels = batch["labels"]
+
+             input_ids = input_ids.to(device)
+             labels = labels.to(device)
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(device)
+
+             if model_type == "transformer":
+                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+             else:
+                 outputs = model(input_ids)
+
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+             total_loss += loss.item()
+             avg_loss = total_loss / len(train_loader)
+             progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})
+
+         test_loss = compute_test_loss(model, test_loader, criterion, model_type)
+         train_losses.append(avg_loss)
+         test_losses.append(test_loss)
+
+         print(f"✅ Epoch {epoch + 1} | Train: {avg_loss:.4f} | Test: {test_loss:.4f} | Time: {time.time() - start_time:.2f}s")
+
+     torch.cuda.empty_cache()
+     del model
+     return train_losses, test_losses
+
+ def evaluate_preds(model, dataloader, model_type="rnn"):
+     model.eval()
+     all_preds = []
+     all_labels = []
+     with torch.no_grad():
+         for batch in dataloader:
+             if isinstance(batch, tuple):
+                 input_ids, labels = batch
+                 attention_mask = None
+             else:
+                 input_ids = batch["input_ids"]
+                 attention_mask = batch.get("attention_mask", None)
+                 labels = batch["labels"]
+
+             input_ids = input_ids.to(device)
+             labels = labels.to(device)
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(device)
+
+             if model_type == "transformer":
+                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+             else:
+                 outputs = model(input_ids)
+
+             preds = torch.argmax(outputs, dim=1)
+             all_preds.extend(preds.cpu().tolist())
+             all_labels.extend(labels.cpu().tolist())
+     return all_labels, all_preds
+
+ def plot_confusion_matrices(y_true_train, y_pred_train, y_true_test, y_pred_test, labels, title, filename):
+     fig, axes = plt.subplots(1, 2, figsize=(14, 6))
+     cm_train = confusion_matrix(y_true_train, y_pred_train)
+     cm_test = confusion_matrix(y_true_test, y_pred_test)
+
+     ConfusionMatrixDisplay(cm_train, display_labels=labels).plot(ax=axes[0], cmap='Blues', colorbar=False)
+     axes[0].set_title(f"{title} - Train")
+
+     ConfusionMatrixDisplay(cm_test, display_labels=labels).plot(ax=axes[1], cmap='Oranges', colorbar=False)
+     axes[1].set_title(f"{title} - Test")
+
+     plt.tight_layout()
+     os.makedirs("plots", exist_ok=True)
+     plt.savefig(f"plots/{filename}")
+     plt.close()
+
+ # Load and encode data
+ data = load_emotion_dataset("train")
+ train_data, label_encoder = encode_labels(data)
+ test_data, _ = encode_labels(load_emotion_dataset("test"))
+ labels = label_encoder.classes_
+ output_dim = len(labels)
+ padding_idx = 0
+
+ summarize_class_distribution(train_data, label_encoder)
+ plot_class_countplot(train_data, label_encoder)
+
+ # Build vocab
+ vocab = build_vocab(train_data)
+
+ model_name = "prajjwal1/bert-tiny"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # DataLoaders (no augmentation)
+ train_loader_rnn = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
+ test_loader_rnn = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
+
+ train_loader_tf = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
+ test_loader_tf = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
+
+ # Initialize and train models
+ rnn = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
+ lstm = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
+ transformer = TransformerClassifier(model_name, output_dim).to(device)
+
+ criterion = torch.nn.CrossEntropyLoss()
+
+ # rnn_train_losses, rnn_test_losses = train_model(rnn, train_loader_rnn, test_loader_rnn, torch.optim.Adam(rnn.parameters(), lr=1e-4), criterion, epochs=50, model_type="rnn")
+ # torch.save(rnn.state_dict(), "pretrained_models/best_rnn.pt")
+ # plot_loss_curve(rnn_train_losses, rnn_test_losses, "RNN")
+ #
+ # lstm_train_losses, lstm_test_losses = train_model(lstm, train_loader_rnn, test_loader_rnn, torch.optim.Adam(lstm.parameters(), lr=1e-4), criterion, epochs=50, model_type="lstm")
+ # torch.save(lstm.state_dict(), "pretrained_models/best_lstm.pt")
+ # plot_loss_curve(lstm_train_losses, lstm_test_losses, "LSTM")
+
+ tf_train_losses, tf_test_losses = train_model(transformer, train_loader_tf, test_loader_tf, torch.optim.Adam(transformer.parameters(), lr=2e-5), criterion, epochs=50, model_type="transformer")
+ torch.save(transformer.state_dict(), "pretrained_models/best_transformer.pt")
+ plot_loss_curve(tf_train_losses, tf_test_losses, "Transformer")
+
+ # Evaluate and plot
+ model_paths = {
+     "RNN": "pretrained_models/best_rnn.pt",
+     "LSTM": "pretrained_models/best_lstm.pt",
+     "Transformer": "pretrained_models/best_transformer.pt"
+ }
+
+ for name in ["RNN", "LSTM", "Transformer"]:
+     if name == "RNN":
+         model = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
+         loader = train_loader_rnn
+         test_loader = test_loader_rnn
+     elif name == "LSTM":
+         model = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
+         loader = train_loader_rnn
+         test_loader = test_loader_rnn
+     else:
+         model = TransformerClassifier(model_name, output_dim).to(device)
+         loader = train_loader_tf
+         test_loader = test_loader_tf
+
+     model.load_state_dict(torch.load(model_paths[name], map_location=device))
+     model.eval()
+
+     y_train_true, y_train_pred = evaluate_preds(model, loader, model_type=name.lower())
+     y_test_true, y_test_pred = evaluate_preds(model, test_loader, model_type=name.lower())
+
+     plot_confusion_matrices(
+         y_train_true, y_train_pred, y_test_true, y_test_pred,
+         labels=labels,
+         title=name,
+         filename=f"{name.lower()}_confusion_matrices.png"
+     )
utility.py ADDED
@@ -0,0 +1,94 @@
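+ # Shared helpers: dataset loading, label encoding, the regex tokenizer and
+ # vocabulary used by the RNN/LSTM models, and the collate functions (with
+ # optional random-truncation augmentation) used during training.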
+ import re
+ import torch
+ from collections import Counter
+ from datasets import load_dataset
+ from sklearn.preprocessing import LabelEncoder
+ from transformers import AutoTokenizer
+ import random
+
+ # ====== Dataset Loading ======
+
+ def load_emotion_dataset(split="train"):
+     return load_dataset("dair-ai/emotion", split=split)
+
+ def encode_labels(dataset):
+     le = LabelEncoder()
+     all_labels = [example["label"] for example in dataset]
+     le.fit(all_labels)
+     dataset = dataset.map(lambda x: {"label": le.transform([x["label"]])[0]})
+     return dataset, le
+
+ # ====== Tokenizer for RNN/LSTM ======
+
+ def simple_tokenizer(text):
+     text = text.lower()
+     text = re.sub(r"[^a-z0-9\s]", "", text)  # Remove special characters
+     return text.split()
+
+ # ====== Vocab Builder for RNN/LSTM ======
+
+ def build_vocab(dataset, min_freq=2):
+     counter = Counter()
+     for example in dataset:
+         tokens = simple_tokenizer(example["text"])
+         counter.update(tokens)
+
+     vocab = {"<PAD>": 0, "<UNK>": 1}
+     idx = 2
+     for word, freq in counter.items():
+         if freq >= min_freq:
+             vocab[word] = idx
+             idx += 1
+     return vocab
+
+ # ====== Collate Function for RNN/LSTM ======
+
+ def collate_fn_rnn(batch, vocab, max_length=32, partial_prob=0.0):
+     texts = [item["text"] for item in batch]
+     labels = [item["label"] for item in batch]
+
+     all_input_ids = []
+     for text in texts:
+         tokens = simple_tokenizer(text)
+
+         # 🔥 Randomly truncate tokens with some probability
+         if random.random() < partial_prob and len(tokens) > 5:
+             # Keep between 30% and 70% of the tokens
+             cutoff = random.randint(int(len(tokens)*0.3), int(len(tokens)*0.7))
+             tokens = tokens[:cutoff]
+
+         ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
+         if len(ids) < max_length:
+             ids += [vocab["<PAD>"]] * (max_length - len(ids))
+         else:
+             ids = ids[:max_length]
+         all_input_ids.append(ids)
+
+     input_ids = torch.tensor(all_input_ids)
+     labels = torch.tensor(labels)
+     return input_ids, labels
+
+ # ====== Collate Function for Transformer ======
+
+ def collate_fn_transformer(batch, tokenizer, max_length=128, partial_prob=0.5):
+     texts = []
+     labels = []
+
+     for item in batch:
+         text = item["text"]
+         tokens = text.split()
+
+         # 🔥 Random truncation
+         if random.random() < partial_prob and len(tokens) > 5:
+             cutoff = random.randint(int(len(tokens)*0.3), int(len(tokens)*0.7))
+             tokens = tokens[:cutoff]
+             text = " ".join(tokens)
+
+         texts.append(text)
+         labels.append(item["label"])
+
+     encoding = tokenizer(texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
+     encoding["labels"] = torch.tensor(labels)
+     return encoding
+