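"""Train and evaluate RNN, LSTM, and Transformer emotion classifiers.

Loads the emotion dataset, trains each model, and saves a class-distribution
plot, per-model loss curves, and train/test confusion matrices under plots/.
"""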
import os
import time

import torch
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

from utility import (
    load_emotion_dataset,
    encode_labels,
    build_vocab,
    collate_fn_rnn,
    collate_fn_transformer,
)
from models.rnn import RNNClassifier
from models.lstm import LSTMClassifier
from models.transformer import TransformerClassifier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
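
# Plotting and summary helpers below all write into plots/, creating it on demand.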
def summarize_class_distribution(dataset, label_encoder):
    """Print the number of examples per emotion class."""
    labels = [example["label"] for example in dataset]
    counter = Counter(labels)
    print("\nClass distribution:")
    for label_idx, count in sorted(counter.items()):
        label_name = label_encoder.inverse_transform([label_idx])[0]
        print(f"{label_name:>10}: {count}")
def plot_class_countplot(dataset, label_encoder):
    """Save a bar plot of the class distribution to plots/class_distribution.png."""
    labels = [example["label"] for example in dataset]
    counts = Counter(labels)
    label_display = [label_encoder.inverse_transform([i])[0] for i in sorted(counts.keys())]
    values = [counts[i] for i in sorted(counts.keys())]
    plt.figure(figsize=(8, 5))
    sns.barplot(x=label_display, y=values)
    plt.title("Emotion Class Distribution (Training Set)")
    plt.xlabel("Emotion")
    plt.ylabel("Count")
    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/class_distribution.png")
    plt.close()
def plot_loss_curve(train_losses, test_losses, model_name):
    """Save a train-vs-test loss curve for one model to plots/."""
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{model_name} Train vs Test Loss")
    plt.legend()
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{model_name.lower()}_loss_curve.png")
    plt.close()
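
# Note: the branching on batch type in the loops below assumes collate_fn_rnn
# yields (input_ids, labels) tuples while collate_fn_transformer yields dicts
# with input_ids, attention_mask, and labels (presumably defined in utility.py).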
def compute_test_loss(model, dataloader, criterion, model_type):
    """Return the average loss over a dataloader without updating weights."""
    total_loss = 0.0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            if model_type == "transformer":
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)
def train_model(model, train_loader, test_loader, optimizer, criterion, epochs, model_type="rnn"):
    """Train for `epochs` epochs; return per-epoch train and test loss histories."""
    train_losses = []
    test_losses = []
    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", ncols=100)
        for batch_idx, batch in enumerate(progress_bar):
            optimizer.zero_grad()
            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            if model_type == "transformer":
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Average over the batches seen so far, not over the whole epoch.
            avg_loss = total_loss / (batch_idx + 1)
            progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})
        test_loss = compute_test_loss(model, test_loader, criterion, model_type)
        train_losses.append(avg_loss)
        test_losses.append(test_loss)
        print(f"Epoch {epoch + 1} | Train: {avg_loss:.4f} | Test: {test_loss:.4f} | Time: {time.time() - start_time:.2f}s")
    # Release cached GPU memory before the next model trains.
    torch.cuda.empty_cache()
    return train_losses, test_losses
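
# evaluate_preds moves predictions to the CPU batch by batch, so GPU memory
# stays flat even on large splits.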
def evaluate_preds(model, dataloader, model_type="rnn"):
    """Return (true_labels, predicted_labels) for every example in the dataloader."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            if model_type == "transformer":
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
    return all_labels, all_preds
def plot_confusion_matrices(y_true_train, y_pred_train, y_true_test, y_pred_test, labels, title, filename):
    """Save side-by-side train/test confusion matrices to plots/."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    cm_train = confusion_matrix(y_true_train, y_pred_train)
    cm_test = confusion_matrix(y_true_test, y_pred_test)
    ConfusionMatrixDisplay(cm_train, display_labels=labels).plot(ax=axes[0], cmap="Blues", colorbar=False)
    axes[0].set_title(f"{title} - Train")
    ConfusionMatrixDisplay(cm_test, display_labels=labels).plot(ax=axes[1], cmap="Oranges", colorbar=False)
    axes[1].set_title(f"{title} - Test")
    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{filename}")
    plt.close()
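
# Note: this file runs as a script; the statements below execute on import.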
# Load and encode data
data = load_emotion_dataset("train")
train_data, label_encoder = encode_labels(data)
# The test split gets its own encoder, which is discarded; this assumes
# encode_labels produces the same label-to-index mapping for both splits.
test_data, _ = encode_labels(load_emotion_dataset("test"))
labels = label_encoder.classes_
output_dim = len(labels)
padding_idx = 0

summarize_class_distribution(train_data, label_encoder)
plot_class_countplot(train_data, label_encoder)
# Build vocab and tokenizer
vocab = build_vocab(train_data)
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# DataLoaders (no augmentation)
train_loader_rnn = DataLoader(train_data, batch_size=64, shuffle=True,
                              collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
test_loader_rnn = DataLoader(test_data, batch_size=64, shuffle=False,
                             collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
train_loader_tf = DataLoader(train_data, batch_size=64, shuffle=True,
                             collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
test_loader_tf = DataLoader(test_data, batch_size=64, shuffle=False,
                            collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
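# partial_prob=0.0 presumably disables the partial-text augmentation in the
# collate functions, so train and test batches both use full, unmodified inputs.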
# Initialize and train models
rnn = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
lstm = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
transformer = TransformerClassifier(model_name, output_dim).to(device)
criterion = torch.nn.CrossEntropyLoss()

# Ensure the checkpoint directory exists before any torch.save call.
os.makedirs("pretrained_models", exist_ok=True)

# rnn_train_losses, rnn_test_losses = train_model(rnn, train_loader_rnn, test_loader_rnn, torch.optim.Adam(rnn.parameters(), lr=1e-4), criterion, epochs=50, model_type="rnn")
# torch.save(rnn.state_dict(), "pretrained_models/best_rnn.pt")
# plot_loss_curve(rnn_train_losses, rnn_test_losses, "RNN")
#
# lstm_train_losses, lstm_test_losses = train_model(lstm, train_loader_rnn, test_loader_rnn, torch.optim.Adam(lstm.parameters(), lr=1e-4), criterion, epochs=50, model_type="lstm")
# torch.save(lstm.state_dict(), "pretrained_models/best_lstm.pt")
# plot_loss_curve(lstm_train_losses, lstm_test_losses, "LSTM")

tf_train_losses, tf_test_losses = train_model(transformer, train_loader_tf, test_loader_tf, torch.optim.Adam(transformer.parameters(), lr=2e-5), criterion, epochs=50, model_type="transformer")
torch.save(transformer.state_dict(), "pretrained_models/best_transformer.pt")
plot_loss_curve(tf_train_losses, tf_test_losses, "Transformer")
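
# The RNN and LSTM training blocks above are commented out, so the evaluation
# loop below assumes pretrained_models/best_rnn.pt and best_lstm.pt already
# exist on disk from earlier runs.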
# Evaluate and plot
model_paths = {
    "RNN": "pretrained_models/best_rnn.pt",
    "LSTM": "pretrained_models/best_lstm.pt",
    "Transformer": "pretrained_models/best_transformer.pt",
}

for name in ["RNN", "LSTM", "Transformer"]:
    if name == "RNN":
        model = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn
    elif name == "LSTM":
        model = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn
    else:
        model = TransformerClassifier(model_name, output_dim).to(device)
        loader = train_loader_tf
        test_loader = test_loader_tf
    # map_location lets CUDA-trained checkpoints load on a CPU-only machine.
    model.load_state_dict(torch.load(model_paths[name], map_location=device))
    model.eval()
    y_train_true, y_train_pred = evaluate_preds(model, loader, model_type=name.lower())
    y_test_true, y_test_pred = evaluate_preds(model, test_loader, model_type=name.lower())
    plot_confusion_matrices(
        y_train_true, y_train_pred, y_test_true, y_test_pred,
        labels=labels,
        title=name,
        filename=f"{name.lower()}_confusion_matrices.png",
    )
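
# All figures end up under plots/: class_distribution.png, one loss curve per
# trained model, and one confusion-matrix pair per evaluated model.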