# -*- coding: utf-8 -*- import os import random import logging import torch import torchaudio import whisper import numpy as np import pandas as pd from torch.utils.data import Dataset import pickle from tqdm import tqdm # from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor class DatasetMultiModalWithPretrainedExtractors(Dataset): """ Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия). При каждом вызове __getitem__: - Загружает WAV по video_name из CSV. - Для обучающей выборки (split="train"): Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки (по merge_probability). Если да – выполняется "chain merge": выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее, и итоговое аудио затем обрезается до точной длины. - Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями. - Текст выбирается так: • Если аудио было merged (склеено) – вызывается Whisper для получения нового текста. • Если merge не происходило и CSV-текст не пуст – используется CSV-текст. • Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper. - Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }. """ def __init__( self, csv_path, wav_dir, emotion_columns, config, split, audio_feature_extractor, text_feature_extractor, whisper_model, dataset_name ): """ :param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text). :param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav). :param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...]. :param split: "train", "dev" или "test". :param audio_feature_extractor: Экстрактор аудио признаков :param text_feature_extractor: Экстрактор текстовых признаков :param sample_rate: Целевая частота дискретизации (например, 16000). :param wav_length: Целевая длина аудио в секундах. :param whisper_model: Mодель Whisper ("tiny", "base", "small", ...). :param max_text_tokens: (Не используется) – ограничение на число токенов. :param text_column: Название колонки с текстом в CSV. :param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper. :param whisper_device: "cuda" или "cpu" – устройство для модели Whisper. :param subset_size: Если > 0, используется только первые N записей из CSV (для отладки). :param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче. :param dataset_name: Название корпуса """ super().__init__() self.split = split self.sample_rate = config.sample_rate self.target_samples = int(config.wav_length * self.sample_rate) self.emotion_columns = emotion_columns self.whisper_model = whisper_model self.text_column = config.text_column self.use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text self.whisper_device = config.whisper_device self.merge_probability = config.merge_probability self.audio_feature_extractor = audio_feature_extractor self.text_feature_extractor = text_feature_extractor self.subset_size = config.subset_size self.save_prepared_data = config.save_prepared_data self.seed = config.random_seed self.dataset_name = dataset_name self.save_feature_path = config.save_feature_path self.use_synthetic_data = config.use_synthetic_data self.synthetic_path = config.synthetic_path self.synthetic_ratio = config.synthetic_ratio # Загружаем CSV if not os.path.exists(csv_path): raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}") df = pd.read_csv(csv_path) if self.subset_size > 0: df = df.head(self.subset_size) logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={self.subset_size}).") #копия для сохранения текста Wisper self.original_df = df.copy() self.whisper_csv_update_log = [] # Проверяем наличие всех колонок эмоций missing = [c for c in emotion_columns if c not in df.columns] if missing: raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}") # Проверяем существование папки с аудио if not os.path.exists(wav_dir): raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!") self.wav_dir = wav_dir # Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть) self.rows = [] for i, rowi in df.iterrows(): audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav") if not os.path.exists(audio_path): continue # Определяем доминирующую эмоцию (максимальное значение) # print(self.emotion_columns) emotion_values = rowi[self.emotion_columns].values.astype(float) max_idx = np.argmax(emotion_values) emotion_label = self.emotion_columns[max_idx] # Извлекаем текст из CSV (если есть) csv_text = "" if self.text_column in rowi and isinstance(rowi[self.text_column], str): csv_text = rowi[self.text_column] self.rows.append({ "audio_path": audio_path, "label": emotion_label, "csv_text": csv_text }) if self.use_synthetic_data and self.split == "train" and self.dataset_name.lower() == "meld": logging.info(f"🧪 Включена синтетика для датасета '{self.dataset_name}' — добавляем примеры из: {self.synthetic_path}") self._add_synthetic_data(self.synthetic_ratio) # Создаем карту для поиска файлов по эмоции self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows} logging.info("📊 Анализ распределения файлов по эмоциям:") emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())} for path, emotion in self.audio_class_map.items(): emotion_counts[emotion] += 1 for emotion, count in emotion_counts.items(): logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.") logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}") # === Процентное семплирование === total_files = len(self.rows) num_to_merge = int(total_files * self.merge_probability) # <<< NEW: Кешируем длины (eq_len) для всех файлов >>> self.path_info = {} for row in self.rows: p = row["audio_path"] try: info = torchaudio.info(p) length = info.num_frames sr_ = info.sample_rate # переводим длину в "эквивалент self.sample_rate" if sr_ != self.sample_rate: ratio = sr_ / self.sample_rate eq_len = int(length / ratio) else: eq_len = length self.path_info[p] = eq_len except Exception as e: logging.warning(f"⚠️ Ошибка чтения {p}: {e}") self.path_info[p] = 0 # Если не смогли прочитать, ставим 0 # Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short self.mergable_files = [ row["audio_path"] # вместо целого dict берём строку for row in self.rows if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию ] short_count = len(self.mergable_files) # Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие. if short_count > num_to_merge: self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge)) else: self.files_to_merge = set(self.mergable_files) logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)") logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}") if self.save_prepared_data: self.meta = [] if self.use_synthetic_data: meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_synthetic_true_pct_{}_pred.pickle'.format( self.dataset_name, self.split, config.audio_classifier_checkpoint[-4:-3], self.seed, self.subset_size, config.emb_normalize, int(self.synthetic_ratio * 100) ) else: meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_merge_prob_{}_pred.pickle'.format( self.dataset_name, self.split, config.audio_classifier_checkpoint[-4:-3], self.seed, self.subset_size, config.emb_normalize, self.merge_probability ) pickle_path = os.path.join(self.save_feature_path, meta_filename) self.load_data(pickle_path) if not self.meta: self.prepare_data() os.makedirs(self.save_feature_path, exist_ok=True) self.save_data(pickle_path) def save_data(self, filename): with open(filename, 'wb') as handle: pickle.dump(self.meta, handle, protocol=pickle.HIGHEST_PROTOCOL) def load_data(self, filename): if os.path.exists(filename): with open(filename, 'rb') as handle: self.meta = pickle.load(handle) else: self.meta = [] def _is_too_short(self, audio_path): """ (Оригинальная) Проверяем, является ли файл короче target_samples. Использует torchaudio.info(audio_path). Но теперь этот метод не используется, поскольку мы кешируем длины. """ try: info = torchaudio.info(audio_path) length = info.num_frames sr_ = info.sample_rate # переводим длину в "эквивалент self.sample_rate" if sr_ != self.sample_rate: ratio = sr_ / self.sample_rate eq_len = int(length / ratio) else: eq_len = length return eq_len < self.target_samples except Exception as e: logging.warning(f"Ошибка _is_too_short({audio_path}): {e}") return False def _is_too_short_cached(self, audio_path): """ (Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info. """ eq_len = self.path_info.get(audio_path, 0) return eq_len < self.target_samples def __len__(self): if self.save_prepared_data: return len(self.meta) else: return len(self.rows) def get_data(self, row): audio_path = row["audio_path"] label_name = row["label"] csv_text = row["csv_text"] # Преобразуем label в one-hot вектор label_vec = self.emotion_to_vector(label_name) # Шаг 1. Загружаем аудио waveform, sr = self.load_audio(audio_path) if waveform is None: return None orig_len = waveform.shape[1] logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек") was_merged = False merged_texts = [csv_text] # Тексты исходного файла + добавленных # Шаг 2. Для train, если аудио короче target_samples, проверяем: # попал ли данный row в files_to_merge? if self.split == "train" and row["audio_path"] in self.files_to_merge: # chain merge current_length = orig_len used_candidates = set() while current_length < self.target_samples: needed = self.target_samples - current_length candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10) if candidate is None or candidate in used_candidates: break used_candidates.add(candidate) add_wf, add_sr = self.load_audio(candidate) if add_wf is None: break logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})") waveform = torch.cat((waveform, add_wf), dim=1) current_length = waveform.shape[1] was_merged = True # Получаем текст второго файла (если есть в CSV) add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "") merged_texts.append(add_csv_text) logging.debug(f"📜 Текст первого файла: {csv_text}") logging.debug(f"📜 Текст добавленного файла: {add_csv_text}") else: # Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.") if was_merged: logging.debug("📝 Текст: аудио было merged – вызываем Whisper.") text_final = self.run_whisper(waveform) logging.debug(f"🆕 Whisper предсказал: {text_final}") merge_components = [os.path.splitext(os.path.basename(audio_path))[0]] merge_components += [os.path.splitext(os.path.basename(p))[0] for p in used_candidates] self.whisper_csv_update_log.append({ "video_name": os.path.splitext(os.path.basename(audio_path))[0], "text_new": text_final, "text_old": csv_text, "was_merged": True, "merge_components": merge_components }) else: if csv_text.strip(): logging.debug("Текст: используем CSV-текст (не пуст).") text_final = csv_text else: if self.split == "train" or self.use_whisper_for_nontrain_if_no_text: logging.debug("Текст: CSV пустой – вызываем Whisper.") text_final = self.run_whisper(waveform) else: logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.") text_final = "" audio_pred, audion_emb = self.audio_feature_extractor.extract(waveform[0], self.sample_rate) text_pred, text_emb = self.text_feature_extractor.extract(text_final) return { "audio_path": os.path.basename(audio_path), "audio": audion_emb[0], "label": label_vec, "text": text_emb[0], "audio_pred": audio_pred[0], "text_pred": text_pred[0] } def prepare_data(self): """ Загружает и обрабатывает один элемент датасета, сохраняет эмбеддинги и обновлённый текст (если было склеено). """ for idx, row in enumerate(tqdm(self.rows)): curr_dict = self.get_data(row) if curr_dict is not None: self.meta.append(curr_dict) # === Сохраняем CSV с обновлёнными текстами (только если был merge) === if self.whisper_csv_update_log: df_log = pd.DataFrame(self.whisper_csv_update_log) # Копия исходного CSV df_out = self.original_df.copy() # Мержим по video_name df_out = df_out.merge(df_log, on="video_name", how="left") # Обновляем текст: заменяем только если Whisper сгенерировал df_out["text_final"] = df_out["text_new"].combine_first(df_out["text"]) df_out["text_old"] = df_out["text"] df_out["text"] = df_out["text_final"] df_out["was_merged"] = df_out["was_merged"].fillna(False).astype(bool) # Преобразуем merge_components в строку df_out["merge_components"] = df_out["merge_components"].apply( lambda x: ";".join(x) if isinstance(x, list) else "" ) # Чистим временные колонки df_out = df_out.drop(columns=["text_new", "text_final"]) # Сохраняем как CSV output_path = os.path.join(self.save_feature_path, f"{self.dataset_name}_{self.split}_merged_whisper_{self.merge_probability *100}.csv") os.makedirs(self.save_feature_path, exist_ok=True) df_out.to_csv(output_path, index=False, encoding="utf-8") logging.info(f"📄 Обновлённый merged CSV сохранён: {output_path}") def __getitem__(self, index): if self.save_prepared_data: return self.meta[index] else: return self.get_data(self.rows[index]) def load_audio(self, path): """ Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо. """ if not os.path.exists(path): logging.warning(f"Файл отсутствует: {path}") return None, None try: wf, sr = torchaudio.load(path) if sr != self.sample_rate: resampler = torchaudio.transforms.Resample(sr, self.sample_rate) wf = resampler(wf) sr = self.sample_rate return wf, sr except Exception as e: logging.error(f"Ошибка загрузки {path}: {e}") return None, None def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5): """ Ищет аудиофайл с той же эмоцией. 1) Если есть файлы >= min_needed, выбираем случайно из них. 2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный. """ candidates = [p for p, lbl in self.audio_class_map.items() if lbl == label_name and p != exclude_path] logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'") # Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info all_info = [] for path in candidates: # <<< NEW: вместо info = torchaudio.info(path) ... eq_len = self.path_info.get(path, 0) # Получаем из кэша all_info.append((eq_len, path)) valid = [(l, p) for l, p in all_info if l >= min_needed] logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})") if valid: # Если есть идеальные — берём случайно из них random.shuffle(valid) chosen = random.choice(valid)[1] return chosen else: # 2) Если идеальных нет — берём топ-K по длине sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True) top_k_list = sorted_by_len[:top_k] if not top_k_list: logging.debug("Нет доступных кандидатов вообще.") return None # вообще нет кандидатов random.shuffle(top_k_list) chosen = top_k_list[0][1] logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}") return chosen def run_whisper(self, waveform): """ Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов). """ arr = waveform.squeeze().cpu().numpy() try: with torch.no_grad(): result = self.whisper_model.transcribe(arr, fp16=False) text = result["text"].strip() return text except Exception as e: logging.error(f"Whisper ошибка: {e}") return "" def _add_synthetic_data(self, synthetic_ratio): """ Добавляет synthetic_ratio (0..1) от количества доступных синтетических файлов на каждую эмоцию. """ if not self.synthetic_path: logging.warning("⚠ Путь к синтетическим данным не указан.") return random.seed(self.seed) synth_csv_path = os.path.join(self.synthetic_path, "meld_s_train_labels.csv") synth_wav_dir = os.path.join(self.synthetic_path, "wavs") if not (os.path.exists(synth_csv_path) and os.path.exists(synth_wav_dir)): logging.warning("⚠ Синтетические данные не найдены.") return df_synth = pd.read_csv(synth_csv_path) rows_by_label = {emotion: [] for emotion in self.emotion_columns} for _, row in df_synth.iterrows(): audio_path = os.path.join(synth_wav_dir, f"{row['video_name']}.wav") if not os.path.exists(audio_path): continue emotion_values = row[self.emotion_columns].values.astype(float) max_idx = np.argmax(emotion_values) label = self.emotion_columns[max_idx] csv_text = row[self.text_column] if self.text_column in row and isinstance(row[self.text_column], str) else "" rows_by_label[label].append({ "audio_path": audio_path, "label": label, "csv_text": csv_text }) added = 0 for label in self.emotion_columns: candidates = rows_by_label[label] if not candidates: continue count_synth = int(len(candidates) * synthetic_ratio) if count_synth <= 0: continue selected = random.sample(candidates, count_synth) self.rows.extend(selected) added += len(selected) logging.info(f"➕ Добавлено {len(selected)} синтетических примеров для эмоции '{label}'") logging.info(f"📦 Всего добавлено {added} синтетических примеров из MELD_S") def emotion_to_vector(self, label_name): """ Преобразует название эмоции в one-hot вектор (torch.tensor). """ v = np.zeros(len(self.emotion_columns), dtype=np.float32) if label_name in self.emotion_columns: idx = self.emotion_columns.index(label_name) v[idx] = 1.0 return torch.tensor(v, dtype=torch.float32) class DatasetMultiModal(Dataset): """ Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия). При каждом вызове __getitem__: - Загружает WAV по video_name из CSV. - Для обучающей выборки (split="train"): Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки (по merge_probability). Если да – выполняется "chain merge": выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее, и итоговое аудио затем обрезается до точной длины. - Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями. - Текст выбирается так: • Если аудио было merged (склеено) – вызывается Whisper для получения нового текста. • Если merge не происходило и CSV-текст не пуст – используется CSV-текст. • Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper. - Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }. """ def __init__( self, csv_path, wav_dir, emotion_columns, split="train", sample_rate=16000, wav_length=4, whisper_model="tiny", text_column="text", use_whisper_for_nontrain_if_no_text=True, whisper_device="cuda", subset_size=0, merge_probability=1.0 # <-- Новый параметр: доля от ОБЩЕГО числа файлов ): """ :param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text). :param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav). :param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...]. :param split: "train", "dev" или "test". :param sample_rate: Целевая частота дискретизации (например, 16000). :param wav_length: Целевая длина аудио в секундах. :param whisper_model: Название модели Whisper ("tiny", "base", "small", ...). :param max_text_tokens: (Не используется) – ограничение на число токенов. :param text_column: Название колонки с текстом в CSV. :param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper. :param whisper_device: "cuda" или "cpu" – устройство для модели Whisper. :param subset_size: Если > 0, используется только первые N записей из CSV (для отладки). :param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче. """ super().__init__() self.split = split self.sample_rate = sample_rate self.target_samples = int(wav_length * sample_rate) self.emotion_columns = emotion_columns self.whisper_model_name = whisper_model self.text_column = text_column self.use_whisper_for_nontrain_if_no_text = use_whisper_for_nontrain_if_no_text self.whisper_device = whisper_device self.merge_probability = merge_probability # Загружаем CSV if not os.path.exists(csv_path): raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}") df = pd.read_csv(csv_path) if subset_size > 0: df = df.head(subset_size) logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={subset_size}).") # Проверяем наличие всех колонок эмоций missing = [c for c in emotion_columns if c not in df.columns] if missing: raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}") # Проверяем существование папки с аудио if not os.path.exists(wav_dir): raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!") self.wav_dir = wav_dir # Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть) self.rows = [] for i, rowi in df.iterrows(): audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav") if not os.path.exists(audio_path): continue # Определяем доминирующую эмоцию (максимальное значение) emotion_values = rowi[self.emotion_columns].values.astype(float) max_idx = np.argmax(emotion_values) emotion_label = self.emotion_columns[max_idx] # Извлекаем текст из CSV (если есть) csv_text = "" if self.text_column in rowi and isinstance(rowi[self.text_column], str): csv_text = rowi[self.text_column] self.rows.append({ "audio_path": audio_path, "label": emotion_label, "csv_text": csv_text }) # Создаем карту для поиска файлов по эмоции self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows} logging.info("📊 Анализ распределения файлов по эмоциям:") emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())} for path, emotion in self.audio_class_map.items(): emotion_counts[emotion] += 1 for emotion, count in emotion_counts.items(): logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.") logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}") # === Процентное семплирование === total_files = len(self.rows) num_to_merge = int(total_files * self.merge_probability) # <<< NEW: Кешируем длины (eq_len) для всех файлов >>> self.path_info = {} for row in self.rows: p = row["audio_path"] try: info = torchaudio.info(p) length = info.num_frames sr_ = info.sample_rate # переводим длину в "эквивалент self.sample_rate" if sr_ != self.sample_rate: ratio = sr_ / self.sample_rate eq_len = int(length / ratio) else: eq_len = length self.path_info[p] = eq_len except Exception as e: logging.warning(f"⚠️ Ошибка чтения {p}: {e}") self.path_info[p] = 0 # Если не смогли прочитать, ставим 0 # Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short self.mergable_files = [ row["audio_path"] # вместо целого dict берём строку for row in self.rows if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию ] short_count = len(self.mergable_files) # Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие. if short_count > num_to_merge: self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge)) else: self.files_to_merge = set(self.mergable_files) logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)") logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}") # Инициализируем Whisper-модель один раз logging.info(f"Инициализация Whisper: модель={whisper_model}, устройство={whisper_device}") self.whisper_model = whisper.load_model(whisper_model, device=whisper_device).eval() # print(f"📦 Whisper работает на устройстве: {self.whisper_model.device}") def _is_too_short(self, audio_path): """ (Оригинальная) Проверяем, является ли файл короче target_samples. Использует torchaudio.info(audio_path). Но теперь этот метод не используется, поскольку мы кешируем длины. """ try: info = torchaudio.info(audio_path) length = info.num_frames sr_ = info.sample_rate # переводим длину в "эквивалент self.sample_rate" if sr_ != self.sample_rate: ratio = sr_ / self.sample_rate eq_len = int(length / ratio) else: eq_len = length return eq_len < self.target_samples except Exception as e: logging.warning(f"Ошибка _is_too_short({audio_path}): {e}") return False def _is_too_short_cached(self, audio_path): """ (Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info. """ eq_len = self.path_info.get(audio_path, 0) return eq_len < self.target_samples def __len__(self): return len(self.rows) def __getitem__(self, index): """ Загружает и обрабатывает один элемент датасета (он‑the‑fly). """ row = self.rows[index] audio_path = row["audio_path"] label_name = row["label"] csv_text = row["csv_text"] # Преобразуем label в one-hot вектор label_vec = self.emotion_to_vector(label_name) # Шаг 1. Загружаем аудио waveform, sr = self.load_audio(audio_path) if waveform is None: return None orig_len = waveform.shape[1] logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек") was_merged = False merged_texts = [csv_text] # Тексты исходного файла + добавленных # Шаг 2. Для train, если аудио короче target_samples, проверяем: # попал ли данный row в files_to_merge? if self.split == "train" and row["audio_path"] in self.files_to_merge: # chain merge current_length = orig_len used_candidates = set() while current_length < self.target_samples: needed = self.target_samples - current_length candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10) if candidate is None or candidate in used_candidates: break used_candidates.add(candidate) add_wf, add_sr = self.load_audio(candidate) if add_wf is None: break logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})") waveform = torch.cat((waveform, add_wf), dim=1) current_length = waveform.shape[1] was_merged = True # Получаем текст второго файла (если есть в CSV) add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "") merged_texts.append(add_csv_text) logging.debug(f"📜 Текст первого файла: {csv_text}") logging.debug(f"📜 Текст добавленного файла: {add_csv_text}") else: # Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.") # Шаг 3. Если итоговая длина меньше target_samples, паддинг нулями curr_len = waveform.shape[1] if curr_len < self.target_samples: pad_size = self.target_samples - curr_len logging.debug(f"Паддинг {os.path.basename(audio_path)}: +{pad_size} сэмплов") waveform = torch.nn.functional.pad(waveform, (0, pad_size)) # Шаг 4. Обрезаем аудио до target_samples (если вышло больше) waveform = waveform[:, :self.target_samples] logging.debug(f"Финальная длина {os.path.basename(audio_path)}: {waveform.shape[1]/sr:.2f} сек; was_merged={was_merged}") # Шаг 5. Получаем текст if was_merged: logging.debug("📝 Текст: аудио было merged – вызываем Whisper.") text_final = self.run_whisper(waveform) logging.debug(f"🆕 Whisper предсказал: {text_final}") else: if csv_text.strip(): logging.debug("Текст: используем CSV-текст (не пуст).") text_final = csv_text else: if self.split == "train" or self.use_whisper_for_nontrain_if_no_text: logging.debug("Текст: CSV пустой – вызываем Whisper.") text_final = self.run_whisper(waveform) else: logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.") text_final = "" return { "audio_path": os.path.basename(audio_path), # new "audio": waveform, "label": label_vec, "text": text_final } def load_audio(self, path): """ Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо. """ if not os.path.exists(path): logging.warning(f"Файл отсутствует: {path}") return None, None try: wf, sr = torchaudio.load(path) if sr != self.sample_rate: resampler = torchaudio.transforms.Resample(sr, self.sample_rate) wf = resampler(wf) sr = self.sample_rate return wf, sr except Exception as e: logging.error(f"Ошибка загрузки {path}: {e}") return None, None def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5): """ Ищет аудиофайл с той же эмоцией. 1) Если есть файлы >= min_needed, выбираем случайно из них. 2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный. """ candidates = [p for p, lbl in self.audio_class_map.items() if lbl == label_name and p != exclude_path] logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'") # Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info all_info = [] for path in candidates: # <<< NEW: вместо info = torchaudio.info(path) ... eq_len = self.path_info.get(path, 0) # Получаем из кэша all_info.append((eq_len, path)) # --- Ниже старый код, который был: # for path in candidates: # try: # info = torchaudio.info(path) # length = info.num_frames # sr_ = info.sample_rate # eq_len = int(length / (sr_ / self.sample_rate)) if sr_ != self.sample_rate else length # all_info.append((eq_len, path)) # except Exception as e: # logging.warning(f"⚠ Ошибка чтения {path}: {e}") # 1) Фильтруем только >= min_needed valid = [(l, p) for l, p in all_info if l >= min_needed] logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})") if valid: # Если есть идеальные — берём случайно из них random.shuffle(valid) chosen = random.choice(valid)[1] return chosen else: # 2) Если идеальных нет — берём топ-K по длине sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True) top_k_list = sorted_by_len[:top_k] if not top_k_list: logging.debug("Нет доступных кандидатов вообще.") return None # вообще нет кандидатов random.shuffle(top_k_list) chosen = top_k_list[0][1] logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}") return chosen def run_whisper(self, waveform): """ Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов). """ arr = waveform.squeeze().cpu().numpy() try: result = self.whisper_model.transcribe(arr, fp16=False) text = result["text"].strip() return text except Exception as e: logging.error(f"Whisper ошибка: {e}") return "" def emotion_to_vector(self, label_name): """ Преобразует название эмоции в one-hot вектор (torch.tensor). """ v = np.zeros(len(self.emotion_columns), dtype=np.float32) if label_name in self.emotion_columns: idx = self.emotion_columns.index(label_name) v[idx] = 1.0 return torch.tensor(v, dtype=torch.float32)