"""Train and evaluate RNN / LSTM / Transformer emotion classifiers.

Loads the emotion dataset, plots the class distribution, trains the selected
model(s), saves checkpoints and loss curves, and renders train/test confusion
matrices for every checkpoint found on disk.
"""

import os
import time
from collections import Counter

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from models.lstm import LSTMClassifier
from models.rnn import RNNClassifier
from models.transformer import TransformerClassifier
from utility import (
    load_emotion_dataset,
    encode_labels,
    build_vocab,
    collate_fn_rnn,
    collate_fn_transformer,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def summarize_class_distribution(dataset, label_encoder):
    """Print one line per class: human-readable label name and example count."""
    labels = [example["label"] for example in dataset]
    counter = Counter(labels)
    print("\nšŸ” Class distribution:")
    for label_idx, count in sorted(counter.items()):
        label_name = label_encoder.inverse_transform([label_idx])[0]
        print(f"{label_name:>10}: {count}")


def plot_class_countplot(dataset, label_encoder):
    """Save a bar plot of the training-set class distribution to plots/."""
    labels = [example["label"] for example in dataset]
    counts = Counter(labels)
    label_display = [label_encoder.inverse_transform([i])[0] for i in sorted(counts.keys())]
    values = [counts[i] for i in sorted(counts.keys())]

    plt.figure(figsize=(8, 5))
    sns.barplot(x=label_display, y=values)
    plt.title("Emotion Class Distribution (Training Set)")
    plt.xlabel("Emotion")
    plt.ylabel("Count")
    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/class_distribution.png")
    plt.close()


def plot_loss_curve(train_losses, test_losses, model_name):
    """Save a train-vs-test loss curve as plots/<model_name>_loss_curve.png."""
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{model_name} Train vs Test Loss")
    plt.legend()
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{model_name.lower()}_loss_curve.png")
    plt.close()


def _prepare_batch(batch):
    """Unpack a dataloader batch and move tensors to *device*.

    RNN/LSTM loaders yield ``(input_ids, labels)`` tuples; the transformer
    loader yields a dict with ``input_ids``, ``labels``, and optionally
    ``attention_mask``.  Returns ``(input_ids, attention_mask, labels)``
    where ``attention_mask`` may be ``None``.
    """
    if isinstance(batch, tuple):
        input_ids, labels = batch
        attention_mask = None
    else:
        input_ids = batch["input_ids"]
        attention_mask = batch.get("attention_mask", None)
        labels = batch["labels"]

    input_ids = input_ids.to(device)
    labels = labels.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    return input_ids, attention_mask, labels


def _forward(model, input_ids, attention_mask, model_type):
    """Run the forward pass appropriate to the model family and return logits."""
    if model_type == "transformer":
        return model(input_ids=input_ids, attention_mask=attention_mask)
    return model(input_ids)


def compute_test_loss(model, dataloader, criterion, model_type):
    """Return the mean per-batch loss over *dataloader*, without gradients."""
    total_loss = 0.0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, labels = _prepare_batch(batch)
            outputs = _forward(model, input_ids, attention_mask, model_type)
            total_loss += criterion(outputs, labels).item()
    return total_loss / len(dataloader)


def train_model(model, train_loader, test_loader, optimizer, criterion, epochs, model_type="rnn"):
    """Train *model* for *epochs* epochs; return (train_losses, test_losses).

    Logs a per-epoch summary and shows a tqdm bar with the running mean loss.
    ``model_type`` selects the forward-call convention ("transformer" passes
    an attention mask; anything else calls ``model(input_ids)``).
    """
    train_losses = []
    test_losses = []

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", ncols=100)

        for batch_count, batch in enumerate(progress_bar, start=1):
            optimizer.zero_grad()
            input_ids, attention_mask, labels = _prepare_batch(batch)
            outputs = _forward(model, input_ids, attention_mask, model_type)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Running mean over batches seen so far.  (The previous version
            # divided by len(train_loader) on every batch, under-reporting
            # the loss for most of the epoch.)
            progress_bar.set_postfix({"Avg Loss": f"{total_loss / batch_count:.4f}"})

        avg_loss = total_loss / len(train_loader)
        test_loss = compute_test_loss(model, test_loader, criterion, model_type)
        train_losses.append(avg_loss)
        test_losses.append(test_loss)
        print(
            f"āœ… Epoch {epoch + 1} | Train: {avg_loss:.4f} | Test: {test_loss:.4f} "
            f"| Time: {time.time() - start_time:.2f}s"
        )

    # NOTE: the previous `del model` here only dropped the local reference
    # (the caller still holds the model), so it was removed as a no-op.
    torch.cuda.empty_cache()
    return train_losses, test_losses


def evaluate_preds(model, dataloader, model_type="rnn"):
    """Collect predictions over *dataloader*; return (all_labels, all_preds)."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, labels = _prepare_batch(batch)
            outputs = _forward(model, input_ids, attention_mask, model_type)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
    return all_labels, all_preds


def plot_confusion_matrices(y_true_train, y_pred_train, y_true_test, y_pred_test,
                            labels, title, filename):
    """Save side-by-side train/test confusion matrices to plots/<filename>."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    cm_train = confusion_matrix(y_true_train, y_pred_train)
    cm_test = confusion_matrix(y_true_test, y_pred_test)

    ConfusionMatrixDisplay(cm_train, display_labels=labels).plot(
        ax=axes[0], cmap='Blues', colorbar=False)
    axes[0].set_title(f"{title} - Train")
    ConfusionMatrixDisplay(cm_test, display_labels=labels).plot(
        ax=axes[1], cmap='Oranges', colorbar=False)
    axes[1].set_title(f"{title} - Test")

    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    # BUG FIX: previously saved to the literal path "plots/(unknown)",
    # ignoring `filename` and overwriting the same file for every model.
    plt.savefig(f"plots/{filename}")
    plt.close()


# ---------------------------------------------------------------------------
# Script body
# ---------------------------------------------------------------------------

# Load and encode data
data = load_emotion_dataset("train")
train_data, label_encoder = encode_labels(data)
test_data, _ = encode_labels(load_emotion_dataset("test"))
labels = label_encoder.classes_
output_dim = len(labels)
padding_idx = 0

summarize_class_distribution(train_data, label_encoder)
plot_class_countplot(train_data, label_encoder)

# Build vocab / tokenizer
vocab = build_vocab(train_data)
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# DataLoaders (no augmentation: partial_prob=0.0)
train_loader_rnn = DataLoader(train_data, batch_size=64, shuffle=True,
                              collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
test_loader_rnn = DataLoader(test_data, batch_size=64, shuffle=False,
                             collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
train_loader_tf = DataLoader(train_data, batch_size=64, shuffle=True,
                             collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
test_loader_tf = DataLoader(test_data, batch_size=64, shuffle=False,
                            collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))

# Initialize models
rnn = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
lstm = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
transformer = TransformerClassifier(model_name, output_dim).to(device)

criterion = torch.nn.CrossEntropyLoss()
os.makedirs("pretrained_models", exist_ok=True)  # torch.save does not create dirs

# rnn_train_losses, rnn_test_losses = train_model(rnn, train_loader_rnn, test_loader_rnn, torch.optim.Adam(rnn.parameters(), lr=1e-4), criterion, epochs=50, model_type="rnn")
# torch.save(rnn.state_dict(), "pretrained_models/best_rnn.pt")
# plot_loss_curve(rnn_train_losses, rnn_test_losses, "RNN")
#
# lstm_train_losses, lstm_test_losses = train_model(lstm, train_loader_rnn, test_loader_rnn, torch.optim.Adam(lstm.parameters(), lr=1e-4), criterion, epochs=50, model_type="lstm")
# torch.save(lstm.state_dict(), "pretrained_models/best_lstm.pt")
# plot_loss_curve(lstm_train_losses, lstm_test_losses, "LSTM")

tf_train_losses, tf_test_losses = train_model(
    transformer, train_loader_tf, test_loader_tf,
    torch.optim.Adam(transformer.parameters(), lr=2e-5),
    criterion, epochs=50, model_type="transformer")
torch.save(transformer.state_dict(), "pretrained_models/best_transformer.pt")
plot_loss_curve(tf_train_losses, tf_test_losses, "Transformer")

# Evaluate and plot confusion matrices for every checkpoint on disk
model_paths = {
    "RNN": "pretrained_models/best_rnn.pt",
    "LSTM": "pretrained_models/best_lstm.pt",
    "Transformer": "pretrained_models/best_transformer.pt",
}

for name in ["RNN", "LSTM", "Transformer"]:
    if name == "RNN":
        model = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn
    elif name == "LSTM":
        model = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn
    else:
        model = TransformerClassifier(model_name, output_dim).to(device)
        loader = train_loader_tf
        test_loader = test_loader_tf

    # Skip models whose training run is commented out above — previously this
    # crashed with FileNotFoundError on best_rnn.pt / best_lstm.pt.
    if not os.path.exists(model_paths[name]):
        print(f"āš ļø  Skipping {name}: no checkpoint at {model_paths[name]}")
        continue

    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_paths[name], map_location=device))
    model.eval()

    y_train_true, y_train_pred = evaluate_preds(model, loader, model_type=name.lower())
    y_test_true, y_test_pred = evaluate_preds(model, test_loader, model_type=name.lower())

    plot_confusion_matrices(
        y_train_true, y_train_pred,
        y_test_true, y_test_pred,
        labels=labels,
        title=name,
        filename=f"{name.lower()}_confusion_matrices.png",
    )