{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "bubcKFNzLDh_" }, "source": [ "# Pytorch ensemble" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "Uc1RGhnrk9GY", "outputId": "1b3dc026-c74a-46b1-f489-7d8538520300" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/400.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m400.9/400.9 kB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/247.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m247.0/247.0 kB\u001b[0m \u001b[31m18.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h\n", "Using full imbalanced dataset. Class distribution:\n", "sentiment\n", "negative 6620\n", "positive 2039\n", "neutral 1663\n", "Name: count, dtype: int64\n", "\n", "Extracting features from the full dataset...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10322/10322 [03:03<00:00, 56.31it/s]\n", "[I 2025-08-20 03:48:21,149] A new study created in memory with name: no-name-3be815cb-86d8-41a9-abbe-29272bc3b1ab\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "--- PART 1: Starting Hyperparameter Search with Optuna ---\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[I 2025-08-20 03:48:47,544] Trial 0 finished with value: 0.7823361823361823 and parameters: {'conv1_filters': 128, 'conv2_filters': 256, 'lstm_units': 128, 'dropout': 0.30401902591485513, 'focal_gamma': 1.8327338604956367, 'lr': 0.0024542281841373665}. Best is trial 0 with value: 0.7823361823361823.\n", "[I 2025-08-20 03:49:11,879] Trial 1 finished with value: 0.7891737891737892 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2674942423523225, 'focal_gamma': 2.141870013328642, 'lr': 0.00017647796313897198}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:49:47,252] Trial 2 finished with value: 0.7156695156695156 and parameters: {'conv1_filters': 128, 'conv2_filters': 256, 'lstm_units': 256, 'dropout': 0.3964252476312402, 'focal_gamma': 1.7487842282032928, 'lr': 0.004415654681232052}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:50:12,187] Trial 3 finished with value: 0.7612535612535613 and parameters: {'conv1_filters': 128, 'conv2_filters': 256, 'lstm_units': 128, 'dropout': 0.35892734283519545, 'focal_gamma': 2.6212169077280243, 'lr': 0.000977244386983535}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:50:48,087] Trial 4 finished with value: 0.6752136752136753 and parameters: {'conv1_filters': 128, 'conv2_filters': 256, 'lstm_units': 256, 'dropout': 0.4683721405647675, 'focal_gamma': 2.0491982601972865, 'lr': 0.007679164444896244}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:51:22,129] Trial 5 finished with value: 0.7196581196581197 and parameters: {'conv1_filters': 64, 'conv2_filters': 256, 'lstm_units': 256, 'dropout': 0.20045274365995666, 'focal_gamma': 2.1233580280504336, 'lr': 0.005151503536944851}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:51:46,297] Trial 6 finished with value: 0.7612535612535613 and parameters: {'conv1_filters': 64, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.49159292378579106, 'focal_gamma': 1.646030396134748, 'lr': 0.0010966146464538307}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:52:20,007] Trial 7 finished with value: 0.7213675213675214 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 256, 'dropout': 0.35014720698111357, 'focal_gamma': 2.1777018236226855, 'lr': 0.005144771288556656}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:52:44,140] Trial 8 finished with value: 0.7743589743589744 and parameters: {'conv1_filters': 64, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.366078317734969, 'focal_gamma': 2.5440747380614703, 'lr': 0.0006256215192712549}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:53:08,480] Trial 9 finished with value: 0.7031339031339031 and parameters: {'conv1_filters': 64, 'conv2_filters': 256, 'lstm_units': 128, 'dropout': 0.45243738898524233, 'focal_gamma': 2.3541212782046346, 'lr': 0.003328230689442395}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:53:32,847] Trial 10 finished with value: 0.7811965811965812 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2597766006445424, 'focal_gamma': 2.706864616744472, 'lr': 0.00010341008379570829}. Best is trial 1 with value: 0.7891737891737892.\n", "[I 2025-08-20 03:53:57,126] Trial 11 finished with value: 0.8011396011396011 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2812723168755197, 'focal_gamma': 1.8900379069937014, 'lr': 0.00012876302263545235}. Best is trial 11 with value: 0.8011396011396011.\n", "[I 2025-08-20 03:54:21,263] Trial 12 finished with value: 0.8022792022792022 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2746132444767511, 'focal_gamma': 1.929945112033818, 'lr': 0.00012222292984434045}. Best is trial 12 with value: 0.8022792022792022.\n", "[I 2025-08-20 03:54:45,511] Trial 13 finished with value: 0.7766381766381767 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.29883369451981967, 'focal_gamma': 2.9720614082537438, 'lr': 0.00025578454924883457}. Best is trial 12 with value: 0.8022792022792022.\n", "[I 2025-08-20 03:55:09,757] Trial 14 finished with value: 0.8028490028490028 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.23981881296007343, 'focal_gamma': 1.5062287644707946, 'lr': 0.0002948170729512786}. Best is trial 14 with value: 0.8028490028490028.\n", "[I 2025-08-20 03:55:33,981] Trial 15 finished with value: 0.8011396011396011 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2039802068679372, 'focal_gamma': 1.56873102929835, 'lr': 0.00035994555781908527}. Best is trial 14 with value: 0.8028490028490028.\n", "[I 2025-08-20 03:55:58,336] Trial 16 finished with value: 0.8222222222222222 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.23265518178485783, 'focal_gamma': 1.5481662017221482, 'lr': 0.0003782195384596158}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:56:22,575] Trial 17 finished with value: 0.815954415954416 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.23864997474568897, 'focal_gamma': 1.5018900290355484, 'lr': 0.00044924467393715685}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:56:56,695] Trial 18 finished with value: 0.7851851851851852 and parameters: {'conv1_filters': 64, 'conv2_filters': 128, 'lstm_units': 256, 'dropout': 0.23203022683613705, 'focal_gamma': 1.6970525474822304, 'lr': 0.000578099667453005}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:57:20,955] Trial 19 finished with value: 0.7675213675213676 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.3163597299919248, 'focal_gamma': 2.338994878464391, 'lr': 0.0018426772621860663}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:57:45,245] Trial 20 finished with value: 0.7943019943019943 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.4055309381361124, 'focal_gamma': 1.5314538939925675, 'lr': 0.0004686443740368918}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:58:09,421] Trial 21 finished with value: 0.8096866096866097 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.23661718346568342, 'focal_gamma': 1.5656764270829684, 'lr': 0.0002512738715943372}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:58:33,693] Trial 22 finished with value: 0.8022792022792022 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2320740511674958, 'focal_gamma': 1.6772311814797796, 'lr': 0.0008979007608764508}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:58:57,915] Trial 23 finished with value: 0.7943019943019943 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.24325599462079886, 'focal_gamma': 1.7876687110895717, 'lr': 0.0002220460330250616}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:59:22,102] Trial 24 finished with value: 0.81994301994302 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2173743713251179, 'focal_gamma': 1.6110715040133343, 'lr': 0.00042291701388497247}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 03:59:46,421] Trial 25 finished with value: 0.7914529914529914 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.32908379292613893, 'focal_gamma': 1.9009845200658604, 'lr': 0.0005049438759665062}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 04:00:20,200] Trial 26 finished with value: 0.7794871794871795 and parameters: {'conv1_filters': 64, 'conv2_filters': 128, 'lstm_units': 256, 'dropout': 0.21097954690924964, 'focal_gamma': 2.0030413285050472, 'lr': 0.0015449350224921235}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 04:00:44,449] Trial 27 finished with value: 0.807977207977208 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.21929761275181317, 'focal_gamma': 1.6252870525385432, 'lr': 0.00039661789054966075}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 04:01:08,733] Trial 28 finished with value: 0.8068376068376069 and parameters: {'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.2575253128229996, 'focal_gamma': 1.7872225106166337, 'lr': 0.0008149413119779976}. Best is trial 16 with value: 0.8222222222222222.\n", "[I 2025-08-20 04:01:33,499] Trial 29 finished with value: 0.7931623931623931 and parameters: {'conv1_filters': 128, 'conv2_filters': 256, 'lstm_units': 128, 'dropout': 0.29108951231374386, 'focal_gamma': 1.7910526724395877, 'lr': 0.0007207586672504403}. Best is trial 16 with value: 0.8222222222222222.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Optuna search complete. Best parameters found:\n", "{'conv1_filters': 128, 'conv2_filters': 128, 'lstm_units': 128, 'dropout': 0.23265518178485783, 'focal_gamma': 1.5481662017221482, 'lr': 0.0003782195384596158}\n", "\n", "--- PART 2: Training an ensemble of 5 models with the best parameters ---\n", "\n", "Training model 1/5...\n", "\n", "Training model 2/5...\n", "\n", "Training model 3/5...\n", "\n", "Training model 4/5...\n", "\n", "Training model 5/5...\n", "\n", "--- PART 3: Evaluating the ensemble of 5 models on the hold-out test set ---\n", "\n", "--- Final Ensemble Classification Report ---\n", " precision recall f1-score support\n", "\n", " negative 0.86 0.90 0.88 993\n", " neutral 0.73 0.78 0.75 250\n", " positive 0.75 0.58 0.66 306\n", "\n", " accuracy 0.82 1549\n", " macro avg 0.78 0.75 0.76 1549\n", "weighted avg 0.82 0.82 0.81 1549\n", "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# =============================================================================\n", "# STEP 1: SETUP, LIBRARIES & DATA PREPARATION\n", "# =============================================================================\n", "!pip install optuna -q # Install Optuna\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import librosa\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from tqdm import tqdm\n", "import kagglehub\n", "import optuna\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, Dataset\n", "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "\n", "# --- Setup Device (GPU/CPU) ---\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"\\nUsing device: {device}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "25208162" }, "source": [ "## Data Loading and Processing\n", "\n", "This section downloads and processes the RAVDESS and CREMA-D datasets. The audio files are loaded, and sentiment labels are extracted based on the filenames. The datasets are then combined and the full, imbalanced dataset is used for training with Focal Loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d51ad88a" }, "outputs": [], "source": [ "# --- Download & Process Datasets ---\n", "RAVDESS_PATH = kagglehub.dataset_download(\"uwrfkaggler/ravdess-emotional-speech-audio\")\n", "CREMA_D_PATH = kagglehub.dataset_download(\"ejlok1/cremad\")\n", "\n", "sentiment_map = {'happy': 'positive', 'surprised': 'positive', 'sad': 'negative', 'angry': 'negative', 'fearful': 'negative', 'disgust': 'negative', 'neutral': 'neutral', 'calm': 'neutral'}\n", "ravdess_emotion_map = {'01': 'neutral', '02': 'calm', '03': 'happy', '04': 'sad', '05': 'angry', '06': 'fearful', '07': 'disgust', '08': 'surprised'}\n", "ravdess_data = []\n", "for dirpath, _, filenames in os.walk(RAVDESS_PATH):\n", " for filename in filenames:\n", " if filename.endswith('.wav'):\n", " emotion_code = filename.split('-')[2]; emotion = ravdess_emotion_map.get(emotion_code); sentiment = sentiment_map.get(emotion)\n", " if sentiment: ravdess_data.append({\"filepath\": os.path.join(dirpath, filename), \"sentiment\": sentiment})\n", "ravdess_df = pd.DataFrame(ravdess_data)\n", "\n", "crema_emotion_map = {'HAP': 'happy', 'SAD': 'sad', 'ANG': 'angry', 'FEA': 'fearful', 'DIS': 'disgust', 'NEU': 'neutral'}\n", "crema_data = []\n", "crema_audio_path = os.path.join(CREMA_D_PATH, \"AudioWAV\")\n", "for filename in os.listdir(crema_audio_path):\n", " if filename.endswith('.wav'):\n", " emotion_code = filename.split('_')[2]; emotion = crema_emotion_map.get(emotion_code); sentiment = sentiment_map.get(emotion)\n", " if sentiment: crema_data.append({\"filepath\": os.path.join(crema_audio_path, filename), \"sentiment\": sentiment})\n", "crema_df = pd.DataFrame(crema_data)\n", "\n", "# --- Use the FULL, IMBALANCED dataset for Focal Loss ---\n", "combined_df = pd.concat([ravdess_df, crema_df], ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)\n", "print(\"\\nUsing full imbalanced dataset. Class distribution:\")\n", "print(combined_df['sentiment'].value_counts())" ] }, { "cell_type": "markdown", "metadata": { "id": "81862d87" }, "source": [ "## Feature Extraction and Dataset Class\n", "\n", "This section defines the `extract_features` function to extract MFCCs, delta MFCCs, and delta-delta MFCCs from the audio files. It also defines the `SpeechDataset` class to handle the features and labels, preparing the data for use with PyTorch DataLoaders." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fdbdfafa" }, "outputs": [], "source": [ "# --- Define Helper Classes and Functions ---\n", "def extract_features(file_path, n_mfcc=40, max_pad_len=216):\n", " try:\n", " audio, sr = librosa.load(file_path, res_type='kaiser_fast', sr=None)\n", " mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc)\n", " delta_mfccs = librosa.feature.delta(mfccs); delta2_mfccs = librosa.feature.delta(mfccs, order=2)\n", " all_features = np.concatenate((mfccs, delta_mfccs, delta2_mfccs))\n", " if all_features.shape[1] > max_pad_len: all_features = all_features[:, :max_pad_len]\n", " else: all_features = np.pad(all_features, ((0, 0), (0, max_pad_len - all_features.shape[1])), mode='constant')\n", " except Exception as e: return None\n", " return all_features\n", "\n", "class SpeechDataset(Dataset):\n", " def __init__(self, features, labels):\n", " self.features = [torch.tensor(f, dtype=torch.float32) for f in features]\n", " self.labels = torch.tensor(labels, dtype=torch.long)\n", " def __len__(self): return len(self.features)\n", " def __getitem__(self, idx): return self.features[idx], self.labels[idx]\n", "\n", "# --- Feature Extraction (run once) ---\n", "print(\"\\nExtracting features from the full dataset...\")\n", "X_features = [extract_features(fp) for fp in tqdm(combined_df['filepath'])]\n", "y_labels = combined_df['sentiment'].tolist()\n", "X_filtered = [x for x in X_features if x is not None]; y_filtered = [y for x, y in zip(X_features, y_labels) if x is not None]\n", "le = LabelEncoder(); y_encoded = le.fit_transform(y_filtered)\n", "\n", "# --- Final Data Split ---\n", "# We create a final hold-out test set ONCE.\n", "X_train_val, X_test, y_train_val, y_test = train_test_split(X_filtered, y_encoded, test_size=0.15, random_state=42, stratify=y_encoded)\n", "test_dataset = SpeechDataset(X_test, y_test)\n", "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "85540e16" }, "source": [ "## Model Definition and Hyperparameter Search with Optuna\n", "\n", "This section defines the `FocalLoss` class and the `Attention` module. It also defines the `define_model` function for Optuna's hyperparameter search and the `objective` function that Optuna will optimize. The Optuna study is created and optimized to find the best hyperparameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "909cc14d" }, "outputs": [], "source": [ "class FocalLoss(nn.Module):\n", " def __init__(self, alpha=1, gamma=2, reduction='mean'):\n", " super(FocalLoss, self).__init__(); self.alpha = alpha; self.gamma = gamma; self.reduction = reduction\n", " def forward(self, inputs, targets):\n", " ce_loss = F.cross_entropy(inputs, targets, reduction='none'); pt = torch.exp(-ce_loss)\n", " focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss\n", " return focal_loss.mean() if self.reduction == 'mean' else focal_loss.sum()\n", "\n", "class Attention(nn.Module):\n", " def __init__(self, hidden_size):\n", " super(Attention, self).__init__(); self.attention = nn.Linear(hidden_size, 1, bias=False)\n", " def forward(self, lstm_output):\n", " scores = self.attention(lstm_output).squeeze(2); weights = F.softmax(scores, dim=1).unsqueeze(1)\n", " return torch.bmm(weights, lstm_output).squeeze(1)\n", "\n", "def define_model(trial, input_shape, num_classes):\n", " # Suggest hyperparameters for the model architecture\n", " conv1_filters = trial.suggest_categorical(\"conv1_filters\", [64, 128])\n", " conv2_filters = trial.suggest_categorical(\"conv2_filters\", [128, 256])\n", " lstm_units = trial.suggest_categorical(\"lstm_units\", [128, 256])\n", " dropout_rate = trial.suggest_float(\"dropout\", 0.2, 0.5)\n", "\n", " class SpeechModel(nn.Module): # Define inside so it gets recreated with new params\n", " def __init__(self):\n", " super(SpeechModel, self).__init__()\n", " self.conv_layers = nn.Sequential(\n", " nn.Conv1d(input_shape[0], conv1_filters, 5, padding='same'), nn.BatchNorm1d(conv1_filters), nn.ReLU(), nn.MaxPool1d(2), nn.Dropout(dropout_rate),\n", " nn.Conv1d(conv1_filters, conv2_filters, 5, padding='same'), nn.BatchNorm1d(conv2_filters), nn.ReLU(), nn.MaxPool1d(2), nn.Dropout(dropout_rate))\n", " with torch.no_grad():\n", " dummy_output = self.conv_layers(torch.zeros(1, input_shape[0], input_shape[1]))\n", " lstm_input_size = dummy_output.shape[1]\n", " self.lstm = nn.LSTM(lstm_input_size, lstm_units, batch_first=True, bidirectional=True)\n", " self.attention = Attention(lstm_units * 2)\n", " self.classifier = nn.Sequential(nn.Linear(lstm_units * 2, 128), nn.ReLU(), nn.Dropout(dropout_rate), nn.Linear(128, num_classes))\n", " def forward(self, x):\n", " x = self.conv_layers(x).permute(0, 2, 1); lstm_out, _ = self.lstm(x)\n", " return self.classifier(self.attention(lstm_out))\n", " return SpeechModel()\n", "\n", "def objective(trial):\n", " # --- Data for this trial ---\n", " X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.2, random_state=42, stratify=y_train_val)\n", " train_dataset = SpeechDataset(X_train, y_train); val_dataset = SpeechDataset(X_val, y_val)\n", " train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True); val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)\n", "\n", " # --- Model, Loss, Optimizer ---\n", " INPUT_SHAPE = (120, 216); NUM_CLASSES = len(le.classes_)\n", " model = define_model(trial, INPUT_SHAPE, NUM_CLASSES).to(device)\n", "\n", " focal_gamma = trial.suggest_float(\"focal_gamma\", 1.5, 3.0)\n", " criterion = FocalLoss(gamma=focal_gamma)\n", " lr = trial.suggest_float(\"lr\", 1e-4, 1e-2, log=True)\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", "\n", " # --- Training Loop ---\n", " best_val_acc = 0\n", " for epoch in range(15): # A shorter run for hyperparameter search\n", " model.train()\n", " for features, labels in train_loader:\n", " features, labels = features.to(device), labels.to(device); optimizer.zero_grad()\n", " outputs = model(features); loss = criterion(outputs, labels); loss.backward(); optimizer.step()\n", "\n", " model.eval(); val_correct = 0\n", " with torch.no_grad():\n", " for features, labels in val_loader:\n", " features, labels = features.to(device), labels.to(device)\n", " val_correct += (torch.max(model(features).data, 1)[1] == labels).sum().item()\n", " val_acc = val_correct / len(val_dataset)\n", " if val_acc > best_val_acc: best_val_acc = val_acc\n", "\n", " return best_val_acc\n", "\n", "print(\"\\n--- PART 1: Starting Hyperparameter Search with Optuna ---\")\n", "study = optuna.create_study(direction=\"maximize\")\n", "study.optimize(objective, n_trials=30) # Increase trials for better search\n", "best_params = study.best_params\n", "print(\"\\nOptuna search complete. Best parameters found:\")\n", "print(best_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "9584660f" }, "source": [ "## Ensemble Training\n", "\n", "This section trains an ensemble of models using the best hyperparameters found by Optuna. Each model in the ensemble is trained on a different train/validation split of the data to promote diversity." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cee415bd" }, "outputs": [], "source": [ "# =============================================================================\n", "# PART 2: TRAIN THE ENSEMBLE\n", "# =============================================================================\n", "N_ENSEMBLE = 5\n", "print(f\"\\n--- PART 2: Training an ensemble of {N_ENSEMBLE} models with the best parameters ---\")\n", "\n", "def create_final_model(params, input_shape, num_classes):\n", " # Recreate the model structure using the BEST params from Optuna\n", " class FinalModel(nn.Module):\n", " def __init__(self):\n", " super(FinalModel, self).__init__()\n", " self.conv_layers = nn.Sequential(\n", " nn.Conv1d(input_shape[0], params['conv1_filters'], 5, padding='same'), nn.BatchNorm1d(params['conv1_filters']), nn.ReLU(), nn.MaxPool1d(2), nn.Dropout(params['dropout']),\n", " nn.Conv1d(params['conv1_filters'], params['conv2_filters'], 5, padding='same'), nn.BatchNorm1d(params['conv2_filters']), nn.ReLU(), nn.MaxPool1d(2), nn.Dropout(params['dropout']))\n", " with torch.no_grad(): dummy_output = self.conv_layers(torch.zeros(1, input_shape[0], input_shape[1]))\n", " lstm_input_size = dummy_output.shape[1]\n", " self.lstm = nn.LSTM(lstm_input_size, params['lstm_units'], batch_first=True, bidirectional=True)\n", " self.attention = Attention(params['lstm_units'] * 2)\n", " self.classifier = nn.Sequential(nn.Linear(params['lstm_units'] * 2, 128), nn.ReLU(), nn.Dropout(params['dropout']), nn.Linear(128, num_classes))\n", " def forward(self, x):\n", " x = self.conv_layers(x).permute(0, 2, 1); lstm_out, _ = self.lstm(x)\n", " return self.classifier(self.attention(lstm_out))\n", " return FinalModel()\n", "\n", "for i in range(N_ENSEMBLE):\n", " print(f\"\\nTraining model {i+1}/{N_ENSEMBLE}...\")\n", " # Get a different train/val split for each model\n", " X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.2, random_state=i, stratify=y_train_val)\n", " train_dataset = SpeechDataset(X_train, y_train); val_dataset = SpeechDataset(X_val, y_val)\n", " train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True); val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)\n", "\n", " model = create_final_model(best_params, (120, 216), len(le.classes_)).to(device)\n", " criterion = FocalLoss(gamma=best_params['focal_gamma'])\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=best_params['lr'])\n", " scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)\n", "\n", " best_val_loss = float('inf'); patience_counter = 0\n", " for epoch in range(60): # Full training run\n", " model.train()\n", " for features, labels in train_loader:\n", " features, labels = features.to(device), labels.to(device); optimizer.zero_grad()\n", " outputs = model(features); loss = criterion(outputs, labels); loss.backward(); optimizer.step()\n", "\n", " model.eval(); val_loss = 0\n", " with torch.no_grad():\n", " for features, labels in val_loader:\n", " features, labels = features.to(device), labels.to(device); val_loss += criterion(model(features), labels).item()\n", "\n", " avg_val_loss = val_loss / len(val_loader)\n", " scheduler.step(avg_val_loss)\n", " if avg_val_loss < best_val_loss:\n", " best_val_loss = avg_val_loss; torch.save(model.state_dict(), f'ensemble_model_{i}.pth'); patience_counter = 0\n", " else:\n", " patience_counter += 1\n", " if patience_counter >= 10: break # Early stopping" ] }, { "cell_type": "markdown", "metadata": { "id": "089e88a0" }, "source": [ "## Ensemble Evaluation\n", "\n", "This section evaluates the trained ensemble of models on the hold-out test set. The predictions from each model are averaged (in terms of probabilities) to make the final prediction. The classification report and confusion matrix are then displayed to assess the ensemble's performance." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "71b503ed" }, "outputs": [], "source": [ "# =============================================================================\n", "# PART 3: EVALUATE THE ENSEMBLE\n", "# =============================================================================\n", "print(f\"\\n--- PART 3: Evaluating the ensemble of {N_ENSEMBLE} models on the hold-out test set ---\")\n", "models = []\n", "for i in range(N_ENSEMBLE):\n", " model = create_final_model(best_params, (120, 216), len(le.classes_)).to(device)\n", " model.load_state_dict(torch.load(f'ensemble_model_{i}.pth'))\n", " model.eval()\n", " models.append(model)\n", "\n", "all_preds, all_labels = [], []\n", "with torch.no_grad():\n", " for features, labels in test_loader:\n", " features = features.to(device)\n", " # Get predictions from all models in the ensemble\n", " all_model_probs = [F.softmax(model(features), dim=1) for model in models]\n", " # Average the probabilities\n", " mean_probs = torch.mean(torch.stack(all_model_probs), dim=0)\n", " # Get the final prediction\n", " _, predicted = torch.max(mean_probs.data, 1)\n", " all_preds.extend(predicted.cpu().numpy())\n", " all_labels.extend(labels.numpy())\n", "\n", "print(\"\\n--- Final Ensemble Classification Report ---\")\n", "print(classification_report(all_labels, all_preds, target_names=le.classes_))\n", "cm = confusion_matrix(all_labels, all_preds)\n", "plt.figure(figsize=(8, 6))\n", "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_, yticklabels=le.classes_)\n", "plt.title('Final Ensemble Confusion Matrix'); plt.ylabel('True Label'); plt.xlabel('Predicted Label')\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "08d5f1ac3e9e43d28bf0f2a1ca3780a9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "099820472ef94622bc930c7dcc27392a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "12ef45e353234cd187b89bd3b4adc885": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_8de1ae82920f4af996346673195773f2", "max": 159, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_ffb9e0b266b04194bb777818b49b7b1e", "value": 159 } }, "1622a1fc464742bdaeeec936fb81fd2c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": "20px" } }, "1b77a27420aa4d7d90af158073b07e0c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "25ce37ab6aaf4bf6b0dcb920e2d3cb3e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2d1043343d8d4d1f8eb0044e1fb0e48c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ba121eca7b8f47559d54e7a3e02cc32c", "max": 380204696, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_8ab02e1b70cf46b6b2c2a38699d64de5", "value": 380204696 } }, "3719dfed7c0840e6bc152595a3c85487": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5534f8d2d13b4f47b0dfacb3be042b68", "placeholder": "​", "style": "IPY_MODEL_099820472ef94622bc930c7dcc27392a", "value": " 380M/380M [00:06<00:00, 34.9MB/s]" } }, "3a4ece5788cb4e4fa02ed2e22c6c7416": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_880aab255d3c44798ec400644408b089", "placeholder": "​", "style": "IPY_MODEL_f8f99acab78d454b8505aa41c5b2fdb9", "value": "model.safetensors: 100%" } }, "3c66cc92b6dc443daa79dd82e564f407": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "40e143e654424e8f96a8acc62749b56c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8a2b1e5cd1ed41208a04ce10bd37505d", "IPY_MODEL_12ef45e353234cd187b89bd3b4adc885", "IPY_MODEL_4934e2a2e84948d993219a887e6f932a" ], "layout": "IPY_MODEL_08d5f1ac3e9e43d28bf0f2a1ca3780a9" } }, "4934e2a2e84948d993219a887e6f932a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_fbdc89bc7bca4e7b9af17480f1b51770", "placeholder": "​", "style": "IPY_MODEL_9051309263bf4dcab2a84a14085e010f", "value": " 159/159 [00:00<00:00, 4.26kB/s]" } }, "4a05f692345540a7ab688eae2ff4c635": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3a4ece5788cb4e4fa02ed2e22c6c7416", "IPY_MODEL_2d1043343d8d4d1f8eb0044e1fb0e48c", "IPY_MODEL_a7a6f6ff7cac4c1b8dce48b4038af21f" ], "layout": "IPY_MODEL_d90a3d5595744e57b456a64dce5a6662" } }, "4ce71846fdcc4524ae55b33671370e81": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5534f8d2d13b4f47b0dfacb3be042b68": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5d5dc262dd2a4ec282a1125b9dbaf63a": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5e10377b33264c4a9ac9727d8b9e8e44": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "678984ff6763443b9e57ad16e65e0c14": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "789e1107230e4431a3b8fc36ccc0f370": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_aea70cfcdeb94841bc1bc616cbbd1530", "max": 380267417, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_678984ff6763443b9e57ad16e65e0c14", "value": 380267417 } }, "7bcc87bdc1de4542be9816c23ef918d8": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ddf902c6d15240f294215ca9ebcec5b9", "placeholder": "​", "style": "IPY_MODEL_3c66cc92b6dc443daa79dd82e564f407", "value": "pytorch_model.bin: 100%" } }, "82dc818a6f8049b08e9482b672cf3418": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "880aab255d3c44798ec400644408b089": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8a2b1e5cd1ed41208a04ce10bd37505d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_25ce37ab6aaf4bf6b0dcb920e2d3cb3e", "placeholder": "​", "style": "IPY_MODEL_5e10377b33264c4a9ac9727d8b9e8e44", "value": "preprocessor_config.json: 100%" } }, "8ab02e1b70cf46b6b2c2a38699d64de5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "8de1ae82920f4af996346673195773f2": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9051309263bf4dcab2a84a14085e010f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "99248a3e649b4b4fb4a6693fc781a3ae": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_82dc818a6f8049b08e9482b672cf3418", "placeholder": "​", "style": "IPY_MODEL_e27a9962ac30459c8dae3d495c40ab8c", "value": "config.json: " } }, "a7a6f6ff7cac4c1b8dce48b4038af21f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c85639bd3d5240278f96912779017988", "placeholder": "​", "style": "IPY_MODEL_4ce71846fdcc4524ae55b33671370e81", "value": " 380M/380M [00:02<00:00, 167MB/s]" } }, "abc8a5b72f034606b3fe3b9e5d0ef8bb": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aea70cfcdeb94841bc1bc616cbbd1530": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b63361a7078549919d77d03dbe797aa9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_99248a3e649b4b4fb4a6693fc781a3ae", "IPY_MODEL_f77ea67f2d644e949477239def77cce3", "IPY_MODEL_d5b9e59652574ae39463a62e91acbc6e" ], "layout": "IPY_MODEL_ef93f9eb83cb461f90c8b5cb85fab024" } }, "ba121eca7b8f47559d54e7a3e02cc32c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c85639bd3d5240278f96912779017988": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d5b9e59652574ae39463a62e91acbc6e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_abc8a5b72f034606b3fe3b9e5d0ef8bb", "placeholder": "​", "style": "IPY_MODEL_1b77a27420aa4d7d90af158073b07e0c", "value": " 1.84k/? [00:00<00:00, 33.1kB/s]" } }, "d90a3d5595744e57b456a64dce5a6662": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "dbfe5751c7a843cbac2f760232c268b2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_7bcc87bdc1de4542be9816c23ef918d8", "IPY_MODEL_789e1107230e4431a3b8fc36ccc0f370", "IPY_MODEL_3719dfed7c0840e6bc152595a3c85487" ], "layout": "IPY_MODEL_5d5dc262dd2a4ec282a1125b9dbaf63a" } }, "ddf902c6d15240f294215ca9ebcec5b9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e27a9962ac30459c8dae3d495c40ab8c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ef93f9eb83cb461f90c8b5cb85fab024": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f4f3e6103c844b89b42c71709ee180ba": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "f77ea67f2d644e949477239def77cce3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1622a1fc464742bdaeeec936fb81fd2c", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f4f3e6103c844b89b42c71709ee180ba", "value": 1 } }, "f8f99acab78d454b8505aa41c5b2fdb9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "fbdc89bc7bca4e7b9af17480f1b51770": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ffb9e0b266b04194bb777818b49b7b1e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 0 }