import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import IMDB
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
import pandas as pd


# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size, n_layers, dropout):
        super(RNN, self).__init__()
        # Store the sizes on the module so forward() does not depend on globals
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(self.embedding(x))
        h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x, h0)
        # Classify from the hidden state at the last time step
        out = self.fc(out[:, -1, :])
        return out


# Custom collate function to pad variable-length sequences in a batch.
# Note: this relies on the module-level `vocab` assigned after load_data()
# returns; the DataLoaders only invoke it after that assignment has happened.
def collate_batch(batch):
    texts, labels = zip(*batch)
    text_lengths = [len(text) for text in texts]
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
    return texts_padded, torch.tensor(labels, dtype=torch.float), text_lengths


# Function to load the data. DataLoaders are not reliably picklable,
# so cache them as a resource rather than as data.
@st.cache_resource
def load_data():
    tokenizer = get_tokenizer("basic_english")
    train_iter, test_iter = IMDB(split=('train', 'test'))
    # Materialize the iterators so they can be traversed more than once:
    # once to build the vocabulary and once to numericalize the text.
    train_data = list(train_iter)
    test_data = list(test_iter)

    def yield_tokens(data):
        for _, text in data:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])

    # Define the text and label processing pipelines.
    # Recent torchtext releases yield integer labels (1 = neg, 2 = pos),
    # while older ones yield 'neg'/'pos' strings; handle both.
    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: 1.0 if x in (2, 'pos') else 0.0

    # Process the data into (token-tensor, label) pairs
    def process_data(data):
        texts, labels = [], []
        for label, text in data:
            texts.append(torch.tensor(text_pipeline(text), dtype=torch.long))
            labels.append(label_pipeline(label))
        return texts, labels

    train_texts, train_labels = process_data(train_data)
    test_texts, test_labels = process_data(test_data)

    # Create DataLoaders
    train_dataset = list(zip(train_texts, train_labels))
    test_dataset = list(zip(test_texts, test_labels))
    train_size = int(0.8 * len(train_dataset))
    valid_size = len(train_dataset) - train_size
    train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

    BATCH_SIZE = 64
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

    return vocab, train_loader, valid_loader, test_loader


# Function to train the network
def train_network(net, iterator, optimizer, criterion, epochs):
    loss_values = []
    for epoch in range(epochs):
        epoch_loss = 0
        net.train()
        for texts, labels, _ in iterator:
            texts, labels = texts.to(device), labels.to(device)
            optimizer.zero_grad()
            predictions = net(texts).squeeze(1)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(iterator)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values


# Function to evaluate the network
def evaluate_network(net, iterator, criterion):
    epoch_loss = 0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []
    net.eval()
    with torch.no_grad():
        for texts, labels, _ in iterator:
            texts, labels = texts.to(device), labels.to(device)
            predictions = net(texts).squeeze(1)
            loss = criterion(predictions, labels)
            epoch_loss += loss.item()
            # The model outputs raw logits (BCEWithLogitsLoss applies the
            # sigmoid internally), so apply sigmoid here before rounding.
            rounded_preds = torch.round(torch.sigmoid(predictions))
            correct += (rounded_preds == labels).sum().item()
            total += len(labels)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(rounded_preds.cpu().numpy())
    accuracy = 100 * correct / total
    st.write(f'Loss: {epoch_loss / len(iterator):.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy, all_labels, all_predictions


# Select the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
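
# Optional reproducibility: the script does not seed its RNGs, so shuffling,
# the train/validation split, and weight initialization differ between runs.
# A minimal sketch, assuming a fixed (hypothetical) seed value:
SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)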

# Display a loading message with some vertical space. (The original inline
# HTML was garbled; the margin value below is an arbitrary stand-in.)
st.markdown("<div style='margin-top: 2em;'>Loading data...</div>", unsafe_allow_html=True)
", unsafe_allow_html=True) vocab, train_loader, valid_loader, test_loader = load_data() # Streamlit interface st.title("RNN for Text Classification on IMDb Dataset") st.write(""" This application demonstrates how to build and train a Recurrent Neural Network (RNN) for text classification using the IMDb dataset. You can adjust hyperparameters, visualize sample data, and see the model's performance. """) # Sidebar for input parameters st.sidebar.header('Model Hyperparameters') embed_size = st.sidebar.slider('Embedding Size', 50, 300, 100) hidden_size = st.sidebar.slider('Hidden Size', 50, 300, 256) n_layers = st.sidebar.slider('Number of RNN Layers', 1, 3, 2) dropout = st.sidebar.slider('Dropout', 0.0, 0.5, 0.2, step=0.1) learning_rate = st.sidebar.slider('Learning Rate', 0.001, 0.1, 0.01, step=0.001) epochs = st.sidebar.slider('Epochs', 1, 20, 5) # Create the network vocab_size = len(vocab) output_size = 1 net = RNN(vocab_size, embed_size, hidden_size, output_size, n_layers, dropout).to(device) criterion = nn.BCEWithLogitsLoss() optimizer = optim.Adam(net.parameters(), lr=learning_rate) # Add vertical space st.write('\n' * 10) # Train the network if st.sidebar.button('Train Network'): loss_values = train_network(net, train_loader, optimizer, criterion, epochs) # Plot the loss values plt.figure(figsize=(10, 5)) plt.plot(range(1, epochs + 1), loss_values, marker='o') plt.title('Training Loss Over Epochs') plt.xlabel('Epoch') plt.ylabel('Loss') plt.grid(True) st.pyplot(plt) # Store the trained model in the session state st.session_state['trained_model'] = net # Test the network if 'trained_model' in st.session_state and st.sidebar.button('Test Network'): accuracy, all_labels, all_predictions = evaluate_network(st.session_state['trained_model'], test_loader, criterion) st.write(f'Test Accuracy: {accuracy:.2f}%') # Display results in a table st.write('Ground Truth vs Predicted') results = pd.DataFrame({ 'Ground Truth': all_labels, 'Predicted': all_predictions }) st.table(results.head(50)) # Display first 50 results for brevity # Visualize some test results def visualize_text_predictions(iterator, net): net.eval() samples = [] with torch.no_grad(): for texts, labels, _ in iterator: predictions = torch.round(torch.sigmoid(net(texts).squeeze(1))) samples.extend(zip(texts.cpu(), labels.cpu(), predictions.cpu())) if len(samples) >= 10: break return samples[:10] if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'): samples = visualize_text_predictions(test_loader, st.session_state['trained_model']) st.write('Ground Truth vs Predicted for Sample Texts') for i, (text, true_label, predicted) in enumerate(samples): st.write(f'Sample {i+1}') st.text(' '.join([vocab.get_itos()[token] for token in text])) st.write(f'Ground Truth: {true_label.item()}, Predicted: {predicted.item()}')