# SentiNet / train.py
# Hunter-Pax's picture
# Upload 18 files
# e7a44ba verified
import os
import time
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utility import (
load_emotion_dataset,
encode_labels,
build_vocab,
collate_fn_rnn,
collate_fn_transformer
)
from models.rnn import RNNClassifier
from models.lstm import LSTMClassifier
from models.transformer import TransformerClassifier
from tqdm import tqdm
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def summarize_class_distribution(dataset, label_encoder):
    """Print how many examples each class has, ordered by encoded label id.

    Args:
        dataset: Iterable of examples; each one is indexed with ``"label"``
            to obtain its encoded class id.
        label_encoder: Object whose ``inverse_transform`` maps a sequence of
            encoded ids back to the corresponding class names.
    """
    counts = Counter(example["label"] for example in dataset)
    print("\n🔍 Class distribution:")
    if not counts:
        # Nothing to decode; avoid calling the encoder with an empty sequence.
        return
    label_ids = sorted(counts)
    # Decode every id in one call instead of one encoder call per class.
    label_names = label_encoder.inverse_transform(label_ids)
    for label_id, label_name in zip(label_ids, label_names):
        print(f"{label_name:>10}: {counts[label_id]}")
def plot_class_countplot(dataset, label_encoder):
    """Save a bar chart of per-class example counts.

    The chart is written to ``plots/class_distribution.png``; the ``plots``
    directory is created if it does not exist.

    Args:
        dataset: Iterable of examples, each indexed with ``"label"``.
        label_encoder: Object whose ``inverse_transform`` turns encoded ids
            back into class names (used for the x-axis tick labels).
    """
    tally = Counter(example["label"] for example in dataset)
    ordered_ids = sorted(tally)

    # One bar per class, left to right in ascending encoded-id order.
    bar_names = [
        label_encoder.inverse_transform([class_id])[0] for class_id in ordered_ids
    ]
    bar_heights = [tally[class_id] for class_id in ordered_ids]

    plt.figure(figsize=(8, 5))
    sns.barplot(x=bar_names, y=bar_heights)
    plt.title("Emotion Class Distribution (Training Set)")
    plt.xlabel("Emotion")
    plt.ylabel("Count")
    plt.tight_layout()

    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/class_distribution.png")
    plt.close()
def plot_loss_curve(train_losses, test_losses, model_name):
    """Plot train and test loss per epoch on one figure and save it.

    The image is written to ``plots/<model_name lower-cased>_loss_curve.png``;
    the ``plots`` directory is created if it does not exist.

    Args:
        train_losses: Sequence of training losses, one per epoch.
        test_losses: Sequence of test losses, one per epoch.
        model_name: Name used in the figure title and the output file name.
    """
    plt.figure(figsize=(8, 4))
    curves = ((train_losses, "Train Loss"), (test_losses, "Test Loss"))
    for series, legend_text in curves:
        plt.plot(series, label=legend_text)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{model_name} Train vs Test Loss")
    plt.legend()

    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{model_name.lower()}_loss_curve.png")
    plt.close()
def compute_test_loss(model, dataloader, criterion, model_type):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.

    Args:
        model: Module to evaluate.
        dataloader: Yields either ``(input_ids, labels)`` tuples or mappings
            with ``"input_ids"``, ``"labels"`` and an optional
            ``"attention_mask"``.
        criterion: Callable ``criterion(outputs, labels)`` returning a loss
            tensor.
        model_type: ``"transformer"`` calls the model with ``input_ids`` and
            ``attention_mask`` keywords; any other value calls it with
            ``input_ids`` positionally.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    uses_mask = model_type == "transformer"
    running = 0
    with torch.no_grad():
        for batch in dataloader:
            # Two batch layouts are supported: plain tuple or dict-like.
            if isinstance(batch, tuple):
                tokens, targets = batch
                mask = None
            else:
                tokens = batch["input_ids"]
                mask = batch.get("attention_mask", None)
                targets = batch["labels"]

            tokens = tokens.to(device)
            targets = targets.to(device)
            if mask is not None:
                mask = mask.to(device)

            if uses_mask:
                logits = model(input_ids=tokens, attention_mask=mask)
            else:
                logits = model(tokens)

            running += criterion(logits, targets).item()
    return running / len(dataloader)
def train_model(model, train_loader, test_loader, optimizer, criterion, epochs, model_type="rnn"):
    """Train ``model`` and record the train and test loss of every epoch.

    After each epoch the test loss is computed with ``compute_test_loss`` and
    a one-line summary is printed.

    Args:
        model: Module to train; it is updated in place and remains owned by
            the caller.
        train_loader: Yields either ``(input_ids, labels)`` tuples or mappings
            with ``"input_ids"``, ``"labels"`` and an optional
            ``"attention_mask"``.
        test_loader: Loader passed to ``compute_test_loss`` after each epoch.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning a loss
            tensor.
        epochs: Number of passes over ``train_loader``.
        model_type: ``"transformer"`` calls the model with ``input_ids`` and
            ``attention_mask`` keywords; any other value calls it with
            ``input_ids`` positionally.

    Returns:
        ``(train_losses, test_losses)``: two lists holding one mean per-batch
        loss per epoch.
    """
    train_losses = []
    test_losses = []
    is_transformer = model_type == "transformer"

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        total_loss = 0.0
        # Stays NaN only when the loader yields no batches at all.
        avg_loss = float("nan")
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", ncols=100)

        for step, batch in enumerate(progress_bar, start=1):
            optimizer.zero_grad()

            if isinstance(batch, tuple):
                input_ids, labels = batch
                attention_mask = None
            else:
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", None)
                labels = batch["labels"]

            input_ids = input_ids.to(device)
            labels = labels.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            if is_transformer:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Divide by the batches seen so far so the bar shows a true running
            # mean; after the last batch this equals the epoch mean.
            avg_loss = total_loss / step
            progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})

        test_loss = compute_test_loss(model, test_loader, criterion, model_type)
        train_losses.append(avg_loss)
        test_losses.append(test_loss)
        print(f"✅ Epoch {epoch + 1} | Train: {avg_loss:.4f} | Test: {test_loss:.4f} | Time: {time.time() - start_time:.2f}s")

    # Hand cached, unused CUDA memory back to the driver. The model itself is
    # still referenced by the caller, so it cannot be freed from here.
    torch.cuda.empty_cache()
    return train_losses, test_losses
def evaluate_preds(model, dataloader, model_type="rnn"):
    """Collect ground-truth labels and argmax predictions over ``dataloader``.

    The model is put in eval mode and gradients are disabled for the pass.

    Args:
        model: Module to run.
        dataloader: Yields either ``(input_ids, labels)`` tuples or mappings
            with ``"input_ids"``, ``"labels"`` and an optional
            ``"attention_mask"``.
        model_type: ``"transformer"`` calls the model with ``input_ids`` and
            ``attention_mask`` keywords; any other value calls it with
            ``input_ids`` positionally.

    Returns:
        ``(labels, predictions)``: two flat Python lists in iteration order.
    """
    model.eval()
    uses_mask = model_type == "transformer"
    truth = []
    predicted = []
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, tuple):
                tokens, targets = batch
                mask = None
            else:
                tokens = batch["input_ids"]
                mask = batch.get("attention_mask", None)
                targets = batch["labels"]

            tokens = tokens.to(device)
            targets = targets.to(device)
            if mask is not None:
                mask = mask.to(device)

            if uses_mask:
                logits = model(input_ids=tokens, attention_mask=mask)
            else:
                logits = model(tokens)

            # Predicted class = index of the largest score along dim 1.
            best = torch.argmax(logits, dim=1)
            predicted.extend(best.cpu().tolist())
            truth.extend(targets.cpu().tolist())
    return truth, predicted
def plot_confusion_matrices(y_true_train, y_pred_train, y_true_test, y_pred_test, labels, title, filename):
    """Draw train and test confusion matrices side by side and save the figure.

    The image is written to ``plots/<filename>``; the ``plots`` directory is
    created if it does not exist.

    Args:
        y_true_train: True labels for the training split.
        y_pred_train: Predicted labels for the training split.
        y_true_test: True labels for the test split.
        y_pred_test: Predicted labels for the test split.
        labels: Display labels for the matrix axes.
        title: Prefix for each subplot title.
        filename: Output file name inside ``plots``.
    """
    _fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Compute both matrices before any drawing happens.
    matrices = (
        confusion_matrix(y_true_train, y_pred_train),
        confusion_matrix(y_true_test, y_pred_test),
    )
    panels = (
        (f"{title} - Train", 'Blues'),
        (f"{title} - Test", 'Oranges'),
    )
    for axis, matrix, (panel_title, palette) in zip(axes, matrices, panels):
        display = ConfusionMatrixDisplay(matrix, display_labels=labels)
        display.plot(ax=axis, cmap=palette, colorbar=False)
        axis.set_title(panel_title)

    plt.tight_layout()
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/{filename}")
    plt.close()
# Load and encode data
# The encoder returned for the "train" split is kept: its classes_ supply the
# display labels and the number of output classes used further down.
data = load_emotion_dataset("train")
train_data, label_encoder = encode_labels(data)
# NOTE(review): the test split goes through a separate encode_labels call and
# its encoder is discarded — presumably both calls produce the same
# label-to-id mapping; TODO confirm against utility.encode_labels.
test_data, _ = encode_labels(load_emotion_dataset("test"))
labels = label_encoder.classes_
output_dim = len(labels)
# Passed to the RNN/LSTM constructors below.
# NOTE(review): assumes index 0 is the padding token in build_vocab — confirm.
padding_idx = 0
# Report the training-set class balance on stdout and as plots/class_distribution.png.
summarize_class_distribution(train_data, label_encoder)
plot_class_countplot(train_data, label_encoder)
# Build vocab
# Vocabulary is built from the training split only; it is used by the RNN/LSTM
# collate function and sizes those models (len(vocab)).
vocab = build_vocab(train_data)
# The same checkpoint name is used for the tokenizer here and for
# TransformerClassifier below.
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# DataLoaders (no augmentation)
# partial_prob=0.0 is passed to every collate function; the *_rnn loaders feed
# both the RNN and the LSTM, the *_tf loaders feed the transformer.
train_loader_rnn = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
test_loader_rnn = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=lambda b: collate_fn_rnn(b, vocab, partial_prob=0.0))
train_loader_tf = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
test_loader_tf = DataLoader(test_data, batch_size=64, shuffle=False, collate_fn=lambda b: collate_fn_transformer(b, tokenizer, partial_prob=0.0))
# Initialize and train models
rnn = RNNClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
lstm = LSTMClassifier(len(vocab), 128, 128, output_dim, padding_idx).to(device)
transformer = TransformerClassifier(model_name, output_dim).to(device)
criterion = torch.nn.CrossEntropyLoss()

# Create the checkpoint directory before training starts: torch.save does not
# create missing parent directories, and failing only after the full training
# run would throw the trained weights away.
os.makedirs("pretrained_models", exist_ok=True)

# RNN and LSTM training is currently disabled; uncomment to retrain them.
# rnn_train_losses, rnn_test_losses = train_model(rnn, train_loader_rnn, test_loader_rnn, torch.optim.Adam(rnn.parameters(), lr=1e-4), criterion, epochs=50, model_type="rnn")
# torch.save(rnn.state_dict(), "pretrained_models/best_rnn.pt")
# plot_loss_curve(rnn_train_losses, rnn_test_losses, "RNN")
#
# lstm_train_losses, lstm_test_losses = train_model(lstm, train_loader_rnn, test_loader_rnn, torch.optim.Adam(lstm.parameters(), lr=1e-4), criterion, epochs=50, model_type="lstm")
# torch.save(lstm.state_dict(), "pretrained_models/best_lstm.pt")
# plot_loss_curve(lstm_train_losses, lstm_test_losses, "LSTM")

# Train the transformer, save the weights from the final epoch, and plot its
# train/test loss curves.
tf_train_losses, tf_test_losses = train_model(transformer, train_loader_tf, test_loader_tf, torch.optim.Adam(transformer.parameters(), lr=2e-5), criterion, epochs=50, model_type="transformer")
torch.save(transformer.state_dict(), "pretrained_models/best_transformer.pt")
plot_loss_curve(tf_train_losses, tf_test_losses, "Transformer")
# Evaluate and plot
# For every architecture: rebuild the model, load its saved weights, predict on
# the train and test loaders, and save a pair of confusion matrices.
model_paths = {
    "RNN": "pretrained_models/best_rnn.pt",
    "LSTM": "pretrained_models/best_lstm.pt",
    "Transformer": "pretrained_models/best_transformer.pt"
}
for name in ["RNN", "LSTM", "Transformer"]:
    checkpoint_path = model_paths[name]
    if not os.path.isfile(checkpoint_path):
        # A model whose training is disabled above may have no checkpoint yet;
        # skip it instead of aborting before the remaining models are evaluated.
        print(f"Skipping {name}: checkpoint not found at {checkpoint_path}")
        continue

    if name == "Transformer":
        model = TransformerClassifier(model_name, output_dim).to(device)
        loader = train_loader_tf
        test_loader = test_loader_tf
    else:
        # RNN and LSTM share constructor arguments and data loaders.
        recurrent_cls = RNNClassifier if name == "RNN" else LSTMClassifier
        model = recurrent_cls(len(vocab), 128, 128, output_dim, padding_idx).to(device)
        loader = train_loader_rnn
        test_loader = test_loader_rnn

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    y_train_true, y_train_pred = evaluate_preds(model, loader, model_type=name.lower())
    y_test_true, y_test_pred = evaluate_preds(model, test_loader, model_type=name.lower())
    plot_confusion_matrices(
        y_train_true, y_train_pred, y_test_true, y_test_pred,
        labels=labels,
        title=name,
        filename=f"{name.lower()}_confusion_matrices.png"
    )