{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "vRCda9pWYcZu" }, "source": [ "# Pytorch Implementation with Attention Layer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "zzqBIRVxcO_t", "outputId": "c07d1938-6818-4bc6-8347-0d686de098b5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Using device: cuda\n", "\n", "Downloading datasets via kagglehub...\n", "All datasets downloaded.\n", "\n", "Dataset fused and balanced. Final class distribution:\n", "sentiment\n", "neutral 1663\n", "positive 1663\n", "negative 1663\n", "Name: count, dtype: int64\n", "\n", "Extracting features from the balanced dataset...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4989/4989 [01:09<00:00, 72.25it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Feature extraction and DataLoaders with augmentation created.\n", "\n", "Model with Attention created.\n", "\n", "Training the v3 model (Augmentation + Attention)...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/80: 100%|██████████| 110/110 [00:03<00:00, 32.02it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 | Train Loss: 1.0775 | Val Loss: 0.9739 | Train Acc: 0.3918 | Val Acc: 0.5094\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/80: 100%|██████████| 110/110 [00:02<00:00, 36.77it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 | Train Loss: 0.9884 | Val Loss: 0.9089 | Train Acc: 0.4785 | Val Acc: 0.5602\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 3/80: 100%|██████████| 110/110 [00:03<00:00, 29.39it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 | Train Loss: 0.9531 | Val Loss: 0.8522 | Train Acc: 0.5072 | Val Acc: 0.5829\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 4/80: 100%|██████████| 110/110 [00:03<00:00, 31.00it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 | Train Loss: 0.9110 | Val Loss: 0.8120 | Train Acc: 0.5438 | Val Acc: 0.6230\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 5/80: 100%|██████████| 110/110 [00:02<00:00, 37.34it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 | Train Loss: 0.8803 | Val Loss: 0.7971 | Train Acc: 0.5633 | Val Acc: 0.6096\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 6/80: 100%|██████████| 110/110 [00:03<00:00, 30.88it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 | Train Loss: 0.8571 | Val Loss: 0.8141 | Train Acc: 0.5819 | Val Acc: 0.6203\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 7/80: 100%|██████████| 110/110 [00:03<00:00, 33.63it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 | Train Loss: 0.8575 | Val Loss: 0.7452 | Train Acc: 0.5839 | Val Acc: 0.6644\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 8/80: 100%|██████████| 110/110 [00:02<00:00, 37.73it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 | Train Loss: 0.8260 | Val Loss: 1.0207 | Train Acc: 0.6065 | Val Acc: 0.4706\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 9/80: 100%|██████████| 110/110 [00:02<00:00, 38.23it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 | Train Loss: 0.8333 | Val Loss: 0.7718 | Train Acc: 0.6025 | Val Acc: 0.6524\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 10/80: 100%|██████████| 110/110 [00:03<00:00, 29.97it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 | Train Loss: 0.8338 | Val Loss: 0.7468 | Train Acc: 0.5985 | Val Acc: 0.6324\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 11/80: 100%|██████████| 110/110 [00:02<00:00, 37.12it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 11 | Train Loss: 0.8018 | Val Loss: 0.7142 | Train Acc: 0.6231 | Val Acc: 0.6604\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 12/80: 100%|██████████| 110/110 [00:02<00:00, 37.23it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 12 | Train Loss: 0.7818 | Val Loss: 0.7175 | Train Acc: 0.6380 | Val Acc: 0.6791\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 13/80: 100%|██████████| 110/110 [00:02<00:00, 37.67it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 13 | Train Loss: 0.7669 | Val Loss: 0.7158 | Train Acc: 0.6466 | Val Acc: 0.6791\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 14/80: 100%|██████████| 110/110 [00:03<00:00, 29.35it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 14 | Train Loss: 0.7457 | Val Loss: 0.7036 | Train Acc: 0.6523 | Val Acc: 0.6751\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 15/80: 100%|██████████| 110/110 [00:02<00:00, 37.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15 | Train Loss: 0.7342 | Val Loss: 0.7170 | Train Acc: 0.6661 | Val Acc: 0.6698\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 16/80: 100%|██████████| 110/110 [00:02<00:00, 37.92it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16 | Train Loss: 0.7273 | Val Loss: 0.6677 | Train Acc: 0.6764 | Val Acc: 0.6872\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 17/80: 100%|██████████| 110/110 [00:03<00:00, 34.81it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 17 | Train Loss: 0.7175 | Val Loss: 0.6622 | Train Acc: 0.6767 | Val Acc: 0.6885\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 18/80: 100%|██████████| 110/110 [00:03<00:00, 32.46it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18 | Train Loss: 0.7013 | Val Loss: 0.6525 | Train Acc: 0.6850 | Val Acc: 0.6979\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 19/80: 100%|██████████| 110/110 [00:02<00:00, 37.85it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 19 | Train Loss: 0.6755 | Val Loss: 0.7202 | Train Acc: 0.6953 | Val Acc: 0.6765\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 20/80: 100%|██████████| 110/110 [00:02<00:00, 37.70it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 20 | Train Loss: 0.6726 | Val Loss: 0.6369 | Train Acc: 0.7027 | Val Acc: 0.6992\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 21/80: 100%|██████████| 110/110 [00:03<00:00, 30.93it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 21 | Train Loss: 0.6574 | Val Loss: 0.6206 | Train Acc: 0.7148 | Val Acc: 0.7126\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 22/80: 100%|██████████| 110/110 [00:03<00:00, 34.72it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 22 | Train Loss: 0.6522 | Val Loss: 0.6574 | Train Acc: 0.7222 | Val Acc: 0.7219\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 23/80: 100%|██████████| 110/110 [00:02<00:00, 37.16it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 23 | Train Loss: 0.6612 | Val Loss: 0.6821 | Train Acc: 0.7182 | Val Acc: 0.7059\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 24/80: 100%|██████████| 110/110 [00:02<00:00, 37.63it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 24 | Train Loss: 0.6384 | Val Loss: 0.6364 | Train Acc: 0.7211 | Val Acc: 0.7406\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 25/80: 100%|██████████| 110/110 [00:03<00:00, 29.91it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 25 | Train Loss: 0.6289 | Val Loss: 0.6569 | Train Acc: 0.7285 | Val Acc: 0.7112\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 26/80: 100%|██████████| 110/110 [00:02<00:00, 37.44it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 26 | Train Loss: 0.5836 | Val Loss: 0.6210 | Train Acc: 0.7506 | Val Acc: 0.7433\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 27/80: 100%|██████████| 110/110 [00:02<00:00, 36.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 27 | Train Loss: 0.5548 | Val Loss: 0.6165 | Train Acc: 0.7617 | Val Acc: 0.7353\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 28/80: 100%|██████████| 110/110 [00:03<00:00, 36.42it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 28 | Train Loss: 0.5445 | Val Loss: 0.6026 | Train Acc: 0.7655 | Val Acc: 0.7340\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 29/80: 100%|██████████| 110/110 [00:03<00:00, 30.49it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 29 | Train Loss: 0.5244 | Val Loss: 0.6278 | Train Acc: 0.7738 | Val Acc: 0.7406\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 30/80: 100%|██████████| 110/110 [00:02<00:00, 37.83it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 30 | Train Loss: 0.5259 | Val Loss: 0.6247 | Train Acc: 0.7763 | Val Acc: 0.7567\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 31/80: 100%|██████████| 110/110 [00:02<00:00, 36.70it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 31 | Train Loss: 0.5153 | Val Loss: 0.6223 | Train Acc: 0.7849 | Val Acc: 0.7366\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 32/80: 100%|██████████| 110/110 [00:03<00:00, 33.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 32 | Train Loss: 0.4965 | Val Loss: 0.6087 | Train Acc: 0.7927 | Val Acc: 0.7687\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 33/80: 100%|██████████| 110/110 [00:03<00:00, 33.92it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 33 | Train Loss: 0.4631 | Val Loss: 0.6442 | Train Acc: 0.8047 | Val Acc: 0.7647\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 34/80: 100%|██████████| 110/110 [00:02<00:00, 37.67it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 34 | Train Loss: 0.4657 | Val Loss: 0.6116 | Train Acc: 0.8067 | Val Acc: 0.7687\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 35/80: 100%|██████████| 110/110 [00:02<00:00, 38.09it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 35 | Train Loss: 0.4447 | Val Loss: 0.6205 | Train Acc: 0.8202 | Val Acc: 0.7741\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 36/80: 100%|██████████| 110/110 [00:04<00:00, 25.29it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 36 | Train Loss: 0.4459 | Val Loss: 0.6515 | Train Acc: 0.8107 | Val Acc: 0.7567\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 37/80: 100%|██████████| 110/110 [00:03<00:00, 35.74it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 37 | Train Loss: 0.4193 | Val Loss: 0.6158 | Train Acc: 0.8276 | Val Acc: 0.7834\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 38/80: 100%|██████████| 110/110 [00:02<00:00, 37.16it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 38 | Train Loss: 0.4063 | Val Loss: 0.6454 | Train Acc: 0.8325 | Val Acc: 0.7714\n", "Early stopping triggered.\n", "\n", "Evaluating the final v3 model...\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "--- Final Classification Report on Fused Test Set (v3 Model) ---\n", " precision recall f1-score support\n", "\n", " negative 0.68 0.60 0.64 249\n", " neutral 0.74 0.89 0.81 250\n", " positive 0.75 0.68 0.71 250\n", "\n", " accuracy 0.72 749\n", " macro avg 0.72 0.72 0.72 749\n", "weighted avg 0.72 0.72 0.72 749\n", "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# =============================================================================\n", "# SETUP & LIBRARIES\n", "# =============================================================================\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import librosa\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from tqdm import tqdm\n", "import kagglehub\n", "\n", "# Import PyTorch essentials\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, Dataset\n", "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "\n", "# --- Setup Device (GPU/CPU) ---\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"\\nUsing device: {device}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "93e14618" }, "outputs": [], "source": [ "# =============================================================================\n", "# DOWNLOAD & PROCESS DATASETS\n", "# =============================================================================\n", "print(\"\\nDownloading datasets via kagglehub...\")\n", "RAVDESS_PATH = kagglehub.dataset_download(\"uwrfkaggler/ravdess-emotional-speech-audio\")\n", "CREMA_D_PATH = kagglehub.dataset_download(\"ejlok1/cremad\")\n", "print(\"All datasets downloaded.\")\n", "\n", "# --- Define processing functions and fuse datasets ---\n", "sentiment_map = {'happy': 'positive', 'surprised': 'positive', 'sad': 'negative', 'angry': 'negative', 'fearful': 'negative', 'disgust': 'negative', 'neutral': 'neutral', 'calm': 'neutral'}\n", "ravdess_emotion_map = {'01': 'neutral', '02': 'calm', '03': 'happy', '04': 'sad', '05': 'angry', '06': 'fearful', '07': 'disgust', '08': 'surprised'}\n", "ravdess_data = []\n", "for dirpath, _, filenames in os.walk(RAVDESS_PATH):\n", " for filename in filenames:\n", " if filename.endswith('.wav'):\n", " emotion_code = filename.split('-')[2]\n", " emotion = ravdess_emotion_map.get(emotion_code)\n", " sentiment = sentiment_map.get(emotion)\n", " if sentiment: ravdess_data.append({\"filepath\": os.path.join(dirpath, filename), \"sentiment\": sentiment})\n", "ravdess_df = pd.DataFrame(ravdess_data)\n", "\n", "crema_emotion_map = {'HAP': 'happy', 'SAD': 'sad', 'ANG': 'angry', 'FEA': 'fearful', 'DIS': 'disgust', 'NEU': 'neutral'}\n", "crema_data = []\n", "crema_audio_path = os.path.join(CREMA_D_PATH, \"AudioWAV\")\n", "for filename in os.listdir(crema_audio_path):\n", " if filename.endswith('.wav'):\n", " emotion_code = filename.split('_')[2]\n", " emotion = crema_emotion_map.get(emotion_code)\n", " sentiment = sentiment_map.get(emotion)\n", " if sentiment: crema_data.append({\"filepath\": os.path.join(crema_audio_path, filename), \"sentiment\": sentiment})\n", "crema_df = pd.DataFrame(crema_data)\n", "\n", "combined_df = pd.concat([ravdess_df, crema_df], ignore_index=True)\n", "min_class_size = combined_df['sentiment'].value_counts().min()\n", "df_list = [combined_df[combined_df['sentiment'] == s].sample(min_class_size, random_state=42) for s in combined_df['sentiment'].unique()]\n", "balanced_df = pd.concat(df_list).sample(frac=1, random_state=42).reset_index(drop=True)\n", "print(\"\\nDataset fused and balanced. Final class distribution:\")\n", "print(balanced_df['sentiment'].value_counts())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8dfb5a81" }, "outputs": [], "source": [ "# =============================================================================\n", "# STEP 3: FEATURE EXTRACTION & DATASET CLASS (WITH AUGMENTATION)\n", "# =============================================================================\n", "def extract_features(file_path, n_mfcc=40, max_pad_len=216):\n", " try:\n", " audio, sr = librosa.load(file_path, res_type='kaiser_fast', sr=None)\n", " mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc)\n", " delta_mfccs = librosa.feature.delta(mfccs)\n", " delta2_mfccs = librosa.feature.delta(mfccs, order=2)\n", " all_features = np.concatenate((mfccs, delta_mfccs, delta2_mfccs))\n", " if all_features.shape[1] > max_pad_len: all_features = all_features[:, :max_pad_len]\n", " else:\n", " pad_width = max_pad_len - all_features.shape[1]\n", " all_features = np.pad(all_features, pad_width=((0, 0), (0, pad_width)), mode='constant')\n", " except Exception as e: return None\n", " return all_features\n", "\n", "class SpeechDataset(Dataset):\n", " def __init__(self, features, labels, augmentation=False):\n", " # Store features as a list of numpy arrays for easier augmentation\n", " self.features_list = [np.array(f) for f in features]\n", " self.labels = torch.tensor(labels, dtype=torch.long)\n", " self.augmentation = augmentation\n", "\n", " def __len__(self):\n", " return len(self.features_list)\n", "\n", " def __getitem__(self, idx):\n", " feature = self.features_list[idx]\n", "\n", " if self.augmentation:\n", " # 50% chance of adding noise\n", " if np.random.rand() < 0.5:\n", " noise_amp = 0.005 * np.random.uniform() * np.amax(feature)\n", " feature = feature + noise_amp * np.random.normal(size=feature.shape)\n", "\n", " # 50% chance of simulating a pitch shift\n", " if np.random.rand() < 0.5:\n", " steps = np.random.randint(-2, 2)\n", " feature = np.roll(feature, steps, axis=0)\n", "\n", " return torch.tensor(feature, dtype=torch.float32), self.labels[idx]\n", "\n", "print(\"\\nExtracting features from the balanced dataset...\")\n", "X = [extract_features(fp) for fp in tqdm(balanced_df['filepath'])]\n", "# Filter out any None values that may have occurred during extraction\n", "y_list = [y for x, y in zip(X, balanced_df['sentiment']) if x is not None]\n", "X_list = [x for x in X if x is not None]\n", "\n", "le = LabelEncoder()\n", "y = le.fit_transform(y_list)\n", "\n", "# Split data and create DataLoaders\n", "X_train, X_temp, y_train, y_temp = train_test_split(X_list, y, test_size=0.3, random_state=42, stratify=y)\n", "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)\n", "\n", "train_dataset = SpeechDataset(X_train, y_train, augmentation=True)\n", "val_dataset = SpeechDataset(X_val, y_val, augmentation=False)\n", "test_dataset = SpeechDataset(X_test, y_test, augmentation=False)\n", "\n", "BATCH_SIZE = 32\n", "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)\n", "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)\n", "print(\"\\nFeature extraction and DataLoaders with augmentation created.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e9cb6730" }, "outputs": [], "source": [ "# =============================================================================\n", "# STEP 4: MODEL ARCHITECTURE (WITH ATTENTION)\n", "# =============================================================================\n", "class Attention(nn.Module):\n", " def __init__(self, hidden_size):\n", " super(Attention, self).__init__()\n", " self.attention = nn.Linear(hidden_size, 1, bias=False)\n", "\n", " def forward(self, lstm_output):\n", " # (batch, seq_len, hidden_size) -> (batch, seq_len, 1)\n", " scores = self.attention(lstm_output)\n", " scores = scores.squeeze(2)\n", " # (batch, seq_len)\n", " weights = F.softmax(scores, dim=1)\n", " # (batch, 1, seq_len)\n", " weights = weights.unsqueeze(1)\n", " # (batch, 1, seq_len) x (batch, seq_len, hidden_size) -> (batch, 1, hidden_size)\n", " context = torch.bmm(weights, lstm_output)\n", " # (batch, hidden_size)\n", " return context.squeeze(1)\n", "\n", "class SpeechModel(nn.Module):\n", " def __init__(self, input_shape, num_classes):\n", " super(SpeechModel, self).__init__()\n", " self.conv_layers = nn.Sequential(\n", " nn.Conv1d(in_channels=input_shape[0], out_channels=128, kernel_size=5, padding='same'),\n", " nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(kernel_size=2), nn.Dropout(0.2),\n", " nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, padding='same'),\n", " nn.BatchNorm1d(256), nn.ReLU(), nn.MaxPool1d(kernel_size=2), nn.Dropout(0.2)\n", " )\n", " with torch.no_grad():\n", " dummy_input = torch.zeros(1, input_shape[0], input_shape[1])\n", " dummy_output = self.conv_layers(dummy_input)\n", " lstm_input_size = dummy_output.shape[1]\n", "\n", " self.lstm = nn.LSTM(input_size=lstm_input_size, hidden_size=256, batch_first=True, bidirectional=True)\n", " self.attention = Attention(hidden_size=256 * 2)\n", " self.classifier = nn.Sequential(\n", " nn.Linear(256 * 2, 128),\n", " nn.ReLU(), nn.Dropout(0.4),\n", " nn.Linear(128, num_classes)\n", " )\n", "\n", " def forward(self, x):\n", " x = self.conv_layers(x)\n", " x = x.permute(0, 2, 1)\n", " lstm_out, _ = self.lstm(x)\n", " context_vector = self.attention(lstm_out)\n", " output = self.classifier(context_vector)\n", " return output\n", "\n", "# --- Instantiate model, loss, and optimizer ---\n", "INPUT_SHAPE = (120, 216) # (n_mfcc * 3, max_len)\n", "NUM_CLASSES = len(le.classes_)\n", "model = SpeechModel(input_shape=INPUT_SHAPE, num_classes=NUM_CLASSES).to(device)\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001) # Using AdamW\n", "scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)\n", "print(\"\\nModel with Attention created.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "90e7efd0" }, "outputs": [], "source": [ "# =============================================================================\n", "# STEP 5: TRAIN THE MODEL\n", "# =============================================================================\n", "print(\"\\nTraining the v3 model (Augmentation + Attention)...\")\n", "EPOCHS = 80 # Increased epochs for this more complex setup\n", "patience = 10\n", "best_val_loss = float('inf')\n", "patience_counter = 0\n", "history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}\n", "\n", "for epoch in range(EPOCHS):\n", " model.train(); train_loss, train_correct = 0.0, 0\n", " for features, labels in tqdm(train_loader, desc=f\"Epoch {epoch+1}/{EPOCHS}\"):\n", " features, labels = features.to(device), labels.to(device)\n", " optimizer.zero_grad()\n", " outputs = model(features)\n", " loss = criterion(outputs, labels)\n", " loss.backward()\n", " optimizer.step()\n", " train_loss += loss.item() * features.size(0)\n", " train_correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()\n", "\n", " model.eval(); val_loss, val_correct = 0.0, 0\n", " with torch.no_grad():\n", " for features, labels in val_loader:\n", " features, labels = features.to(device), labels.to(device)\n", " outputs = model(features)\n", " loss = criterion(outputs, labels)\n", " val_loss += loss.item() * features.size(0)\n", " val_correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()\n", "\n", " avg_train_loss=train_loss/len(train_loader.dataset); avg_val_loss=val_loss/len(val_loader.dataset)\n", " train_acc=train_correct/len(train_loader.dataset); val_acc=val_correct/len(val_loader.dataset)\n", " history['train_loss'].append(avg_train_loss); history['val_loss'].append(avg_val_loss)\n", " history['train_acc'].append(train_acc); history['val_acc'].append(val_acc)\n", " print(f\"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}\")\n", "\n", " scheduler.step(avg_val_loss)\n", " if avg_val_loss < best_val_loss:\n", " best_val_loss = avg_val_loss; torch.save(model.state_dict(), 'best_fused_model_v3_attn.pth'); patience_counter = 0\n", " else:\n", " patience_counter += 1\n", " if patience_counter >= patience: print(\"Early stopping triggered.\"); break\n", "\n", "model.load_state_dict(torch.load('best_fused_model_v3_attn.pth'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a841f14f" }, "outputs": [], "source": [ "# =============================================================================\n", "# STEP 6: EVALUATE THE FINAL MODEL\n", "# =============================================================================\n", "print(\"\\nEvaluating the final v3 model...\")\n", "plt.figure(figsize=(12, 4))\n", "plt.subplot(1, 2, 1); plt.plot(history['train_acc'], label='Train Acc'); plt.plot(history['val_acc'], label='Val Acc')\n", "plt.title('Accuracy'); plt.legend(); plt.grid(True)\n", "plt.subplot(1, 2, 2); plt.plot(history['train_loss'], label='Train Loss'); plt.plot(history['val_loss'], label='Val Loss')\n", "plt.title('Loss'); plt.legend(); plt.grid(True)\n", "plt.show()\n", "\n", "model.eval()\n", "all_preds, all_labels = [], []\n", "with torch.no_grad():\n", " for features, labels in test_loader:\n", " features, labels = features.to(device), labels.to(device)\n", " outputs = model(features)\n", " _, predicted = torch.max(outputs.data, 1)\n", " all_preds.extend(predicted.cpu().numpy())\n", " all_labels.extend(labels.cpu().numpy())\n", "\n", "print(\"\\n--- Final Classification Report on Fused Test Set (v3 Model) ---\")\n", "print(classification_report(all_labels, all_preds, target_names=le.classes_))\n", "cm = confusion_matrix(all_labels, all_preds)\n", "plt.figure(figsize=(8, 6))\n", "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_, yticklabels=le.classes_)\n", "plt.title('Confusion Matrix on Fused Test Set (v3 Model)'); plt.ylabel('True Label'); plt.xlabel('Predicted Label')\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "08d5f1ac3e9e43d28bf0f2a1ca3780a9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "099820472ef94622bc930c7dcc27392a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "12ef45e353234cd187b89bd3b4adc885": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_8de1ae82920f4af996346673195773f2", "max": 159, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_ffb9e0b266b04194bb777818b49b7b1e", "value": 159 } }, "1622a1fc464742bdaeeec936fb81fd2c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": "20px" } }, "1b77a27420aa4d7d90af158073b07e0c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "25ce37ab6aaf4bf6b0dcb920e2d3cb3e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2d1043343d8d4d1f8eb0044e1fb0e48c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ba121eca7b8f47559d54e7a3e02cc32c", "max": 380204696, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_8ab02e1b70cf46b6b2c2a38699d64de5", "value": 380204696 } }, "3719dfed7c0840e6bc152595a3c85487": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5534f8d2d13b4f47b0dfacb3be042b68", "placeholder": "​", "style": "IPY_MODEL_099820472ef94622bc930c7dcc27392a", "value": " 380M/380M [00:06<00:00, 34.9MB/s]" } }, "3a4ece5788cb4e4fa02ed2e22c6c7416": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_880aab255d3c44798ec400644408b089", "placeholder": "​", "style": "IPY_MODEL_f8f99acab78d454b8505aa41c5b2fdb9", "value": "model.safetensors: 100%" } }, "3c66cc92b6dc443daa79dd82e564f407": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "40e143e654424e8f96a8acc62749b56c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8a2b1e5cd1ed41208a04ce10bd37505d", "IPY_MODEL_12ef45e353234cd187b89bd3b4adc885", "IPY_MODEL_4934e2a2e84948d993219a887e6f932a" ], "layout": "IPY_MODEL_08d5f1ac3e9e43d28bf0f2a1ca3780a9" } }, "4934e2a2e84948d993219a887e6f932a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_fbdc89bc7bca4e7b9af17480f1b51770", "placeholder": "​", "style": "IPY_MODEL_9051309263bf4dcab2a84a14085e010f", "value": " 159/159 [00:00<00:00, 4.26kB/s]" } }, "4a05f692345540a7ab688eae2ff4c635": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3a4ece5788cb4e4fa02ed2e22c6c7416", "IPY_MODEL_2d1043343d8d4d1f8eb0044e1fb0e48c", "IPY_MODEL_a7a6f6ff7cac4c1b8dce48b4038af21f" ], "layout": "IPY_MODEL_d90a3d5595744e57b456a64dce5a6662" } }, "4ce71846fdcc4524ae55b33671370e81": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5534f8d2d13b4f47b0dfacb3be042b68": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5d5dc262dd2a4ec282a1125b9dbaf63a": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5e10377b33264c4a9ac9727d8b9e8e44": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "678984ff6763443b9e57ad16e65e0c14": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "789e1107230e4431a3b8fc36ccc0f370": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_aea70cfcdeb94841bc1bc616cbbd1530", "max": 380267417, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_678984ff6763443b9e57ad16e65e0c14", "value": 380267417 } }, "7bcc87bdc1de4542be9816c23ef918d8": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ddf902c6d15240f294215ca9ebcec5b9", "placeholder": "​", "style": "IPY_MODEL_3c66cc92b6dc443daa79dd82e564f407", "value": "pytorch_model.bin: 100%" } }, "82dc818a6f8049b08e9482b672cf3418": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "880aab255d3c44798ec400644408b089": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8a2b1e5cd1ed41208a04ce10bd37505d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_25ce37ab6aaf4bf6b0dcb920e2d3cb3e", "placeholder": "​", "style": "IPY_MODEL_5e10377b33264c4a9ac9727d8b9e8e44", "value": "preprocessor_config.json: 100%" } }, "8ab02e1b70cf46b6b2c2a38699d64de5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "8de1ae82920f4af996346673195773f2": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9051309263bf4dcab2a84a14085e010f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "99248a3e649b4b4fb4a6693fc781a3ae": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_82dc818a6f8049b08e9482b672cf3418", "placeholder": "​", "style": "IPY_MODEL_e27a9962ac30459c8dae3d495c40ab8c", "value": "config.json: " } }, "a7a6f6ff7cac4c1b8dce48b4038af21f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c85639bd3d5240278f96912779017988", "placeholder": "​", "style": "IPY_MODEL_4ce71846fdcc4524ae55b33671370e81", "value": " 380M/380M [00:02<00:00, 167MB/s]" } }, "abc8a5b72f034606b3fe3b9e5d0ef8bb": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aea70cfcdeb94841bc1bc616cbbd1530": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b63361a7078549919d77d03dbe797aa9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_99248a3e649b4b4fb4a6693fc781a3ae", "IPY_MODEL_f77ea67f2d644e949477239def77cce3", "IPY_MODEL_d5b9e59652574ae39463a62e91acbc6e" ], "layout": "IPY_MODEL_ef93f9eb83cb461f90c8b5cb85fab024" } }, "ba121eca7b8f47559d54e7a3e02cc32c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c85639bd3d5240278f96912779017988": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d5b9e59652574ae39463a62e91acbc6e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_abc8a5b72f034606b3fe3b9e5d0ef8bb", "placeholder": "​", "style": "IPY_MODEL_1b77a27420aa4d7d90af158073b07e0c", "value": " 1.84k/? [00:00<00:00, 33.1kB/s]" } }, "d90a3d5595744e57b456a64dce5a6662": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "dbfe5751c7a843cbac2f760232c268b2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_7bcc87bdc1de4542be9816c23ef918d8", "IPY_MODEL_789e1107230e4431a3b8fc36ccc0f370", "IPY_MODEL_3719dfed7c0840e6bc152595a3c85487" ], "layout": "IPY_MODEL_5d5dc262dd2a4ec282a1125b9dbaf63a" } }, "ddf902c6d15240f294215ca9ebcec5b9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e27a9962ac30459c8dae3d495c40ab8c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ef93f9eb83cb461f90c8b5cb85fab024": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f4f3e6103c844b89b42c71709ee180ba": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "f77ea67f2d644e949477239def77cce3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1622a1fc464742bdaeeec936fb81fd2c", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f4f3e6103c844b89b42c71709ee180ba", "value": 1 } }, "f8f99acab78d454b8505aa41c5b2fdb9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "fbdc89bc7bca4e7b9af17480f1b51770": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ffb9e0b266b04194bb777818b49b7b1e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 0 }