Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import os | |
import random | |
import logging | |
import torch | |
import torchaudio | |
import whisper | |
import numpy as np | |
import pandas as pd | |
from torch.utils.data import Dataset | |
import pickle | |
from tqdm import tqdm | |
# from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor | |
class DatasetMultiModalWithPretrainedExtractors(Dataset): | |
""" | |
Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия). | |
При каждом вызове __getitem__: | |
- Загружает WAV по video_name из CSV. | |
- Для обучающей выборки (split="train"): | |
Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки | |
(по merge_probability). Если да – выполняется "chain merge": | |
выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее, | |
и итоговое аудио затем обрезается до точной длины. | |
- Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями. | |
- Текст выбирается так: | |
• Если аудио было merged (склеено) – вызывается Whisper для получения нового текста. | |
• Если merge не происходило и CSV-текст не пуст – используется CSV-текст. | |
• Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper. | |
- Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }. | |
""" | |
def __init__( | |
self, | |
csv_path, | |
wav_dir, | |
emotion_columns, | |
config, | |
split, | |
audio_feature_extractor, | |
text_feature_extractor, | |
whisper_model, | |
dataset_name | |
): | |
""" | |
:param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text). | |
:param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav). | |
:param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...]. | |
:param split: "train", "dev" или "test". | |
:param audio_feature_extractor: Экстрактор аудио признаков | |
:param text_feature_extractor: Экстрактор текстовых признаков | |
:param sample_rate: Целевая частота дискретизации (например, 16000). | |
:param wav_length: Целевая длина аудио в секундах. | |
:param whisper_model: Mодель Whisper ("tiny", "base", "small", ...). | |
:param max_text_tokens: (Не используется) – ограничение на число токенов. | |
:param text_column: Название колонки с текстом в CSV. | |
:param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper. | |
:param whisper_device: "cuda" или "cpu" – устройство для модели Whisper. | |
:param subset_size: Если > 0, используется только первые N записей из CSV (для отладки). | |
:param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче. | |
:param dataset_name: Название корпуса | |
""" | |
super().__init__() | |
self.split = split | |
self.sample_rate = config.sample_rate | |
self.target_samples = int(config.wav_length * self.sample_rate) | |
self.emotion_columns = emotion_columns | |
self.whisper_model = whisper_model | |
self.text_column = config.text_column | |
self.use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text | |
self.whisper_device = config.whisper_device | |
self.merge_probability = config.merge_probability | |
self.audio_feature_extractor = audio_feature_extractor | |
self.text_feature_extractor = text_feature_extractor | |
self.subset_size = config.subset_size | |
self.save_prepared_data = config.save_prepared_data | |
self.seed = config.random_seed | |
self.dataset_name = dataset_name | |
self.save_feature_path = config.save_feature_path | |
self.use_synthetic_data = config.use_synthetic_data | |
self.synthetic_path = config.synthetic_path | |
self.synthetic_ratio = config.synthetic_ratio | |
# Загружаем CSV | |
if not os.path.exists(csv_path): | |
raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}") | |
df = pd.read_csv(csv_path) | |
if self.subset_size > 0: | |
df = df.head(self.subset_size) | |
logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={self.subset_size}).") | |
#копия для сохранения текста Wisper | |
self.original_df = df.copy() | |
self.whisper_csv_update_log = [] | |
# Проверяем наличие всех колонок эмоций | |
missing = [c for c in emotion_columns if c not in df.columns] | |
if missing: | |
raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}") | |
# Проверяем существование папки с аудио | |
if not os.path.exists(wav_dir): | |
raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!") | |
self.wav_dir = wav_dir | |
# Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть) | |
self.rows = [] | |
for i, rowi in df.iterrows(): | |
audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav") | |
if not os.path.exists(audio_path): | |
continue | |
# Определяем доминирующую эмоцию (максимальное значение) | |
# print(self.emotion_columns) | |
emotion_values = rowi[self.emotion_columns].values.astype(float) | |
max_idx = np.argmax(emotion_values) | |
emotion_label = self.emotion_columns[max_idx] | |
# Извлекаем текст из CSV (если есть) | |
csv_text = "" | |
if self.text_column in rowi and isinstance(rowi[self.text_column], str): | |
csv_text = rowi[self.text_column] | |
self.rows.append({ | |
"audio_path": audio_path, | |
"label": emotion_label, | |
"csv_text": csv_text | |
}) | |
if self.use_synthetic_data and self.split == "train" and self.dataset_name.lower() == "meld": | |
logging.info(f"🧪 Включена синтетика для датасета '{self.dataset_name}' — добавляем примеры из: {self.synthetic_path}") | |
self._add_synthetic_data(self.synthetic_ratio) | |
# Создаем карту для поиска файлов по эмоции | |
self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows} | |
logging.info("📊 Анализ распределения файлов по эмоциям:") | |
emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())} | |
for path, emotion in self.audio_class_map.items(): | |
emotion_counts[emotion] += 1 | |
for emotion, count in emotion_counts.items(): | |
logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.") | |
logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}") | |
# === Процентное семплирование === | |
total_files = len(self.rows) | |
num_to_merge = int(total_files * self.merge_probability) | |
# <<< NEW: Кешируем длины (eq_len) для всех файлов >>> | |
self.path_info = {} | |
for row in self.rows: | |
p = row["audio_path"] | |
try: | |
info = torchaudio.info(p) | |
length = info.num_frames | |
sr_ = info.sample_rate | |
# переводим длину в "эквивалент self.sample_rate" | |
if sr_ != self.sample_rate: | |
ratio = sr_ / self.sample_rate | |
eq_len = int(length / ratio) | |
else: | |
eq_len = length | |
self.path_info[p] = eq_len | |
except Exception as e: | |
logging.warning(f"⚠️ Ошибка чтения {p}: {e}") | |
self.path_info[p] = 0 # Если не смогли прочитать, ставим 0 | |
# Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short | |
self.mergable_files = [ | |
row["audio_path"] # вместо целого dict берём строку | |
for row in self.rows | |
if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию | |
] | |
short_count = len(self.mergable_files) | |
# Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие. | |
if short_count > num_to_merge: | |
self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge)) | |
else: | |
self.files_to_merge = set(self.mergable_files) | |
logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)") | |
logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}") | |
if self.save_prepared_data: | |
self.meta = [] | |
if self.use_synthetic_data: | |
meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_synthetic_true_pct_{}_pred.pickle'.format( | |
self.dataset_name, | |
self.split, | |
config.audio_classifier_checkpoint[-4:-3], | |
self.seed, | |
self.subset_size, | |
config.emb_normalize, | |
int(self.synthetic_ratio * 100) | |
) | |
else: | |
meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_merge_prob_{}_pred.pickle'.format( | |
self.dataset_name, | |
self.split, | |
config.audio_classifier_checkpoint[-4:-3], | |
self.seed, | |
self.subset_size, | |
config.emb_normalize, | |
self.merge_probability | |
) | |
pickle_path = os.path.join(self.save_feature_path, meta_filename) | |
self.load_data(pickle_path) | |
if not self.meta: | |
self.prepare_data() | |
os.makedirs(self.save_feature_path, exist_ok=True) | |
self.save_data(pickle_path) | |
def save_data(self, filename): | |
with open(filename, 'wb') as handle: | |
pickle.dump(self.meta, handle, protocol=pickle.HIGHEST_PROTOCOL) | |
def load_data(self, filename): | |
if os.path.exists(filename): | |
with open(filename, 'rb') as handle: | |
self.meta = pickle.load(handle) | |
else: | |
self.meta = [] | |
def _is_too_short(self, audio_path): | |
""" | |
(Оригинальная) Проверяем, является ли файл короче target_samples. | |
Использует torchaudio.info(audio_path). | |
Но теперь этот метод не используется, поскольку мы кешируем длины. | |
""" | |
try: | |
info = torchaudio.info(audio_path) | |
length = info.num_frames | |
sr_ = info.sample_rate | |
# переводим длину в "эквивалент self.sample_rate" | |
if sr_ != self.sample_rate: | |
ratio = sr_ / self.sample_rate | |
eq_len = int(length / ratio) | |
else: | |
eq_len = length | |
return eq_len < self.target_samples | |
except Exception as e: | |
logging.warning(f"Ошибка _is_too_short({audio_path}): {e}") | |
return False | |
def _is_too_short_cached(self, audio_path): | |
""" | |
(Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info. | |
""" | |
eq_len = self.path_info.get(audio_path, 0) | |
return eq_len < self.target_samples | |
def __len__(self): | |
if self.save_prepared_data: | |
return len(self.meta) | |
else: | |
return len(self.rows) | |
def get_data(self, row): | |
audio_path = row["audio_path"] | |
label_name = row["label"] | |
csv_text = row["csv_text"] | |
# Преобразуем label в one-hot вектор | |
label_vec = self.emotion_to_vector(label_name) | |
# Шаг 1. Загружаем аудио | |
waveform, sr = self.load_audio(audio_path) | |
if waveform is None: | |
return None | |
orig_len = waveform.shape[1] | |
logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек") | |
was_merged = False | |
merged_texts = [csv_text] # Тексты исходного файла + добавленных | |
# Шаг 2. Для train, если аудио короче target_samples, проверяем: | |
# попал ли данный row в files_to_merge? | |
if self.split == "train" and row["audio_path"] in self.files_to_merge: | |
# chain merge | |
current_length = orig_len | |
used_candidates = set() | |
while current_length < self.target_samples: | |
needed = self.target_samples - current_length | |
candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10) | |
if candidate is None or candidate in used_candidates: | |
break | |
used_candidates.add(candidate) | |
add_wf, add_sr = self.load_audio(candidate) | |
if add_wf is None: | |
break | |
logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})") | |
waveform = torch.cat((waveform, add_wf), dim=1) | |
current_length = waveform.shape[1] | |
was_merged = True | |
# Получаем текст второго файла (если есть в CSV) | |
add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "") | |
merged_texts.append(add_csv_text) | |
logging.debug(f"📜 Текст первого файла: {csv_text}") | |
logging.debug(f"📜 Текст добавленного файла: {add_csv_text}") | |
else: | |
# Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge | |
logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.") | |
if was_merged: | |
logging.debug("📝 Текст: аудио было merged – вызываем Whisper.") | |
text_final = self.run_whisper(waveform) | |
logging.debug(f"🆕 Whisper предсказал: {text_final}") | |
merge_components = [os.path.splitext(os.path.basename(audio_path))[0]] | |
merge_components += [os.path.splitext(os.path.basename(p))[0] for p in used_candidates] | |
self.whisper_csv_update_log.append({ | |
"video_name": os.path.splitext(os.path.basename(audio_path))[0], | |
"text_new": text_final, | |
"text_old": csv_text, | |
"was_merged": True, | |
"merge_components": merge_components | |
}) | |
else: | |
if csv_text.strip(): | |
logging.debug("Текст: используем CSV-текст (не пуст).") | |
text_final = csv_text | |
else: | |
if self.split == "train" or self.use_whisper_for_nontrain_if_no_text: | |
logging.debug("Текст: CSV пустой – вызываем Whisper.") | |
text_final = self.run_whisper(waveform) | |
else: | |
logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.") | |
text_final = "" | |
audio_pred, audion_emb = self.audio_feature_extractor.extract(waveform[0], self.sample_rate) | |
text_pred, text_emb = self.text_feature_extractor.extract(text_final) | |
return { | |
"audio_path": os.path.basename(audio_path), | |
"audio": audion_emb[0], | |
"label": label_vec, | |
"text": text_emb[0], | |
"audio_pred": audio_pred[0], | |
"text_pred": text_pred[0] | |
} | |
def prepare_data(self): | |
""" | |
Загружает и обрабатывает один элемент датасета, | |
сохраняет эмбеддинги и обновлённый текст (если было склеено). | |
""" | |
for idx, row in enumerate(tqdm(self.rows)): | |
curr_dict = self.get_data(row) | |
if curr_dict is not None: | |
self.meta.append(curr_dict) | |
# === Сохраняем CSV с обновлёнными текстами (только если был merge) === | |
if self.whisper_csv_update_log: | |
df_log = pd.DataFrame(self.whisper_csv_update_log) | |
# Копия исходного CSV | |
df_out = self.original_df.copy() | |
# Мержим по video_name | |
df_out = df_out.merge(df_log, on="video_name", how="left") | |
# Обновляем текст: заменяем только если Whisper сгенерировал | |
df_out["text_final"] = df_out["text_new"].combine_first(df_out["text"]) | |
df_out["text_old"] = df_out["text"] | |
df_out["text"] = df_out["text_final"] | |
df_out["was_merged"] = df_out["was_merged"].fillna(False).astype(bool) | |
# Преобразуем merge_components в строку | |
df_out["merge_components"] = df_out["merge_components"].apply( | |
lambda x: ";".join(x) if isinstance(x, list) else "" | |
) | |
# Чистим временные колонки | |
df_out = df_out.drop(columns=["text_new", "text_final"]) | |
# Сохраняем как CSV | |
output_path = os.path.join(self.save_feature_path, f"{self.dataset_name}_{self.split}_merged_whisper_{self.merge_probability *100}.csv") | |
os.makedirs(self.save_feature_path, exist_ok=True) | |
df_out.to_csv(output_path, index=False, encoding="utf-8") | |
logging.info(f"📄 Обновлённый merged CSV сохранён: {output_path}") | |
def __getitem__(self, index): | |
if self.save_prepared_data: | |
return self.meta[index] | |
else: | |
return self.get_data(self.rows[index]) | |
def load_audio(self, path): | |
""" | |
Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо. | |
""" | |
if not os.path.exists(path): | |
logging.warning(f"Файл отсутствует: {path}") | |
return None, None | |
try: | |
wf, sr = torchaudio.load(path) | |
if sr != self.sample_rate: | |
resampler = torchaudio.transforms.Resample(sr, self.sample_rate) | |
wf = resampler(wf) | |
sr = self.sample_rate | |
return wf, sr | |
except Exception as e: | |
logging.error(f"Ошибка загрузки {path}: {e}") | |
return None, None | |
def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5): | |
""" | |
Ищет аудиофайл с той же эмоцией. | |
1) Если есть файлы >= min_needed, выбираем случайно из них. | |
2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный. | |
""" | |
candidates = [p for p, lbl in self.audio_class_map.items() | |
if lbl == label_name and p != exclude_path] | |
logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'") | |
# Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info | |
all_info = [] | |
for path in candidates: | |
# <<< NEW: вместо info = torchaudio.info(path) ... | |
eq_len = self.path_info.get(path, 0) # Получаем из кэша | |
all_info.append((eq_len, path)) | |
valid = [(l, p) for l, p in all_info if l >= min_needed] | |
logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})") | |
if valid: | |
# Если есть идеальные — берём случайно из них | |
random.shuffle(valid) | |
chosen = random.choice(valid)[1] | |
return chosen | |
else: | |
# 2) Если идеальных нет — берём топ-K по длине | |
sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True) | |
top_k_list = sorted_by_len[:top_k] | |
if not top_k_list: | |
logging.debug("Нет доступных кандидатов вообще.") | |
return None # вообще нет кандидатов | |
random.shuffle(top_k_list) | |
chosen = top_k_list[0][1] | |
logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}") | |
return chosen | |
def run_whisper(self, waveform): | |
""" | |
Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов). | |
""" | |
arr = waveform.squeeze().cpu().numpy() | |
try: | |
with torch.no_grad(): | |
result = self.whisper_model.transcribe(arr, fp16=False) | |
text = result["text"].strip() | |
return text | |
except Exception as e: | |
logging.error(f"Whisper ошибка: {e}") | |
return "" | |
def _add_synthetic_data(self, synthetic_ratio): | |
""" | |
Добавляет synthetic_ratio (0..1) от количества доступных синтетических файлов на каждую эмоцию. | |
""" | |
if not self.synthetic_path: | |
logging.warning("⚠ Путь к синтетическим данным не указан.") | |
return | |
random.seed(self.seed) | |
synth_csv_path = os.path.join(self.synthetic_path, "meld_s_train_labels.csv") | |
synth_wav_dir = os.path.join(self.synthetic_path, "wavs") | |
if not (os.path.exists(synth_csv_path) and os.path.exists(synth_wav_dir)): | |
logging.warning("⚠ Синтетические данные не найдены.") | |
return | |
df_synth = pd.read_csv(synth_csv_path) | |
rows_by_label = {emotion: [] for emotion in self.emotion_columns} | |
for _, row in df_synth.iterrows(): | |
audio_path = os.path.join(synth_wav_dir, f"{row['video_name']}.wav") | |
if not os.path.exists(audio_path): | |
continue | |
emotion_values = row[self.emotion_columns].values.astype(float) | |
max_idx = np.argmax(emotion_values) | |
label = self.emotion_columns[max_idx] | |
csv_text = row[self.text_column] if self.text_column in row and isinstance(row[self.text_column], str) else "" | |
rows_by_label[label].append({ | |
"audio_path": audio_path, | |
"label": label, | |
"csv_text": csv_text | |
}) | |
added = 0 | |
for label in self.emotion_columns: | |
candidates = rows_by_label[label] | |
if not candidates: | |
continue | |
count_synth = int(len(candidates) * synthetic_ratio) | |
if count_synth <= 0: | |
continue | |
selected = random.sample(candidates, count_synth) | |
self.rows.extend(selected) | |
added += len(selected) | |
logging.info(f"➕ Добавлено {len(selected)} синтетических примеров для эмоции '{label}'") | |
logging.info(f"📦 Всего добавлено {added} синтетических примеров из MELD_S") | |
def emotion_to_vector(self, label_name): | |
""" | |
Преобразует название эмоции в one-hot вектор (torch.tensor). | |
""" | |
v = np.zeros(len(self.emotion_columns), dtype=np.float32) | |
if label_name in self.emotion_columns: | |
idx = self.emotion_columns.index(label_name) | |
v[idx] = 1.0 | |
return torch.tensor(v, dtype=torch.float32) | |
class DatasetMultiModal(Dataset): | |
""" | |
Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия). | |
При каждом вызове __getitem__: | |
- Загружает WAV по video_name из CSV. | |
- Для обучающей выборки (split="train"): | |
Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки | |
(по merge_probability). Если да – выполняется "chain merge": | |
выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее, | |
и итоговое аудио затем обрезается до точной длины. | |
- Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями. | |
- Текст выбирается так: | |
• Если аудио было merged (склеено) – вызывается Whisper для получения нового текста. | |
• Если merge не происходило и CSV-текст не пуст – используется CSV-текст. | |
• Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper. | |
- Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }. | |
""" | |
def __init__( | |
self, | |
csv_path, | |
wav_dir, | |
emotion_columns, | |
split="train", | |
sample_rate=16000, | |
wav_length=4, | |
whisper_model="tiny", | |
text_column="text", | |
use_whisper_for_nontrain_if_no_text=True, | |
whisper_device="cuda", | |
subset_size=0, | |
merge_probability=1.0 # <-- Новый параметр: доля от ОБЩЕГО числа файлов | |
): | |
""" | |
:param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text). | |
:param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav). | |
:param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...]. | |
:param split: "train", "dev" или "test". | |
:param sample_rate: Целевая частота дискретизации (например, 16000). | |
:param wav_length: Целевая длина аудио в секундах. | |
:param whisper_model: Название модели Whisper ("tiny", "base", "small", ...). | |
:param max_text_tokens: (Не используется) – ограничение на число токенов. | |
:param text_column: Название колонки с текстом в CSV. | |
:param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper. | |
:param whisper_device: "cuda" или "cpu" – устройство для модели Whisper. | |
:param subset_size: Если > 0, используется только первые N записей из CSV (для отладки). | |
:param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче. | |
""" | |
super().__init__() | |
self.split = split | |
self.sample_rate = sample_rate | |
self.target_samples = int(wav_length * sample_rate) | |
self.emotion_columns = emotion_columns | |
self.whisper_model_name = whisper_model | |
self.text_column = text_column | |
self.use_whisper_for_nontrain_if_no_text = use_whisper_for_nontrain_if_no_text | |
self.whisper_device = whisper_device | |
self.merge_probability = merge_probability | |
# Загружаем CSV | |
if not os.path.exists(csv_path): | |
raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}") | |
df = pd.read_csv(csv_path) | |
if subset_size > 0: | |
df = df.head(subset_size) | |
logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={subset_size}).") | |
# Проверяем наличие всех колонок эмоций | |
missing = [c for c in emotion_columns if c not in df.columns] | |
if missing: | |
raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}") | |
# Проверяем существование папки с аудио | |
if not os.path.exists(wav_dir): | |
raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!") | |
self.wav_dir = wav_dir | |
# Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть) | |
self.rows = [] | |
for i, rowi in df.iterrows(): | |
audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav") | |
if not os.path.exists(audio_path): | |
continue | |
# Определяем доминирующую эмоцию (максимальное значение) | |
emotion_values = rowi[self.emotion_columns].values.astype(float) | |
max_idx = np.argmax(emotion_values) | |
emotion_label = self.emotion_columns[max_idx] | |
# Извлекаем текст из CSV (если есть) | |
csv_text = "" | |
if self.text_column in rowi and isinstance(rowi[self.text_column], str): | |
csv_text = rowi[self.text_column] | |
self.rows.append({ | |
"audio_path": audio_path, | |
"label": emotion_label, | |
"csv_text": csv_text | |
}) | |
# Создаем карту для поиска файлов по эмоции | |
self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows} | |
logging.info("📊 Анализ распределения файлов по эмоциям:") | |
emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())} | |
for path, emotion in self.audio_class_map.items(): | |
emotion_counts[emotion] += 1 | |
for emotion, count in emotion_counts.items(): | |
logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.") | |
logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}") | |
# === Процентное семплирование === | |
total_files = len(self.rows) | |
num_to_merge = int(total_files * self.merge_probability) | |
# <<< NEW: Кешируем длины (eq_len) для всех файлов >>> | |
self.path_info = {} | |
for row in self.rows: | |
p = row["audio_path"] | |
try: | |
info = torchaudio.info(p) | |
length = info.num_frames | |
sr_ = info.sample_rate | |
# переводим длину в "эквивалент self.sample_rate" | |
if sr_ != self.sample_rate: | |
ratio = sr_ / self.sample_rate | |
eq_len = int(length / ratio) | |
else: | |
eq_len = length | |
self.path_info[p] = eq_len | |
except Exception as e: | |
logging.warning(f"⚠️ Ошибка чтения {p}: {e}") | |
self.path_info[p] = 0 # Если не смогли прочитать, ставим 0 | |
# Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short | |
self.mergable_files = [ | |
row["audio_path"] # вместо целого dict берём строку | |
for row in self.rows | |
if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию | |
] | |
short_count = len(self.mergable_files) | |
# Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие. | |
if short_count > num_to_merge: | |
self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge)) | |
else: | |
self.files_to_merge = set(self.mergable_files) | |
logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)") | |
logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}") | |
# Инициализируем Whisper-модель один раз | |
logging.info(f"Инициализация Whisper: модель={whisper_model}, устройство={whisper_device}") | |
self.whisper_model = whisper.load_model(whisper_model, device=whisper_device).eval() | |
# print(f"📦 Whisper работает на устройстве: {self.whisper_model.device}") | |
def _is_too_short(self, audio_path): | |
""" | |
(Оригинальная) Проверяем, является ли файл короче target_samples. | |
Использует torchaudio.info(audio_path). | |
Но теперь этот метод не используется, поскольку мы кешируем длины. | |
""" | |
try: | |
info = torchaudio.info(audio_path) | |
length = info.num_frames | |
sr_ = info.sample_rate | |
# переводим длину в "эквивалент self.sample_rate" | |
if sr_ != self.sample_rate: | |
ratio = sr_ / self.sample_rate | |
eq_len = int(length / ratio) | |
else: | |
eq_len = length | |
return eq_len < self.target_samples | |
except Exception as e: | |
logging.warning(f"Ошибка _is_too_short({audio_path}): {e}") | |
return False | |
def _is_too_short_cached(self, audio_path): | |
""" | |
(Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info. | |
""" | |
eq_len = self.path_info.get(audio_path, 0) | |
return eq_len < self.target_samples | |
def __len__(self): | |
return len(self.rows) | |
def __getitem__(self, index): | |
""" | |
Загружает и обрабатывает один элемент датасета (он‑the‑fly). | |
""" | |
row = self.rows[index] | |
audio_path = row["audio_path"] | |
label_name = row["label"] | |
csv_text = row["csv_text"] | |
# Преобразуем label в one-hot вектор | |
label_vec = self.emotion_to_vector(label_name) | |
# Шаг 1. Загружаем аудио | |
waveform, sr = self.load_audio(audio_path) | |
if waveform is None: | |
return None | |
orig_len = waveform.shape[1] | |
logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек") | |
was_merged = False | |
merged_texts = [csv_text] # Тексты исходного файла + добавленных | |
# Шаг 2. Для train, если аудио короче target_samples, проверяем: | |
# попал ли данный row в files_to_merge? | |
if self.split == "train" and row["audio_path"] in self.files_to_merge: | |
# chain merge | |
current_length = orig_len | |
used_candidates = set() | |
while current_length < self.target_samples: | |
needed = self.target_samples - current_length | |
candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10) | |
if candidate is None or candidate in used_candidates: | |
break | |
used_candidates.add(candidate) | |
add_wf, add_sr = self.load_audio(candidate) | |
if add_wf is None: | |
break | |
logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})") | |
waveform = torch.cat((waveform, add_wf), dim=1) | |
current_length = waveform.shape[1] | |
was_merged = True | |
# Получаем текст второго файла (если есть в CSV) | |
add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "") | |
merged_texts.append(add_csv_text) | |
logging.debug(f"📜 Текст первого файла: {csv_text}") | |
logging.debug(f"📜 Текст добавленного файла: {add_csv_text}") | |
else: | |
# Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge | |
logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.") | |
# Шаг 3. Если итоговая длина меньше target_samples, паддинг нулями | |
curr_len = waveform.shape[1] | |
if curr_len < self.target_samples: | |
pad_size = self.target_samples - curr_len | |
logging.debug(f"Паддинг {os.path.basename(audio_path)}: +{pad_size} сэмплов") | |
waveform = torch.nn.functional.pad(waveform, (0, pad_size)) | |
# Шаг 4. Обрезаем аудио до target_samples (если вышло больше) | |
waveform = waveform[:, :self.target_samples] | |
logging.debug(f"Финальная длина {os.path.basename(audio_path)}: {waveform.shape[1]/sr:.2f} сек; was_merged={was_merged}") | |
# Шаг 5. Получаем текст | |
if was_merged: | |
logging.debug("📝 Текст: аудио было merged – вызываем Whisper.") | |
text_final = self.run_whisper(waveform) | |
logging.debug(f"🆕 Whisper предсказал: {text_final}") | |
else: | |
if csv_text.strip(): | |
logging.debug("Текст: используем CSV-текст (не пуст).") | |
text_final = csv_text | |
else: | |
if self.split == "train" or self.use_whisper_for_nontrain_if_no_text: | |
logging.debug("Текст: CSV пустой – вызываем Whisper.") | |
text_final = self.run_whisper(waveform) | |
else: | |
logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.") | |
text_final = "" | |
return { | |
"audio_path": os.path.basename(audio_path), # new | |
"audio": waveform, | |
"label": label_vec, | |
"text": text_final | |
} | |
def load_audio(self, path): | |
""" | |
Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо. | |
""" | |
if not os.path.exists(path): | |
logging.warning(f"Файл отсутствует: {path}") | |
return None, None | |
try: | |
wf, sr = torchaudio.load(path) | |
if sr != self.sample_rate: | |
resampler = torchaudio.transforms.Resample(sr, self.sample_rate) | |
wf = resampler(wf) | |
sr = self.sample_rate | |
return wf, sr | |
except Exception as e: | |
logging.error(f"Ошибка загрузки {path}: {e}") | |
return None, None | |
def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5): | |
""" | |
Ищет аудиофайл с той же эмоцией. | |
1) Если есть файлы >= min_needed, выбираем случайно из них. | |
2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный. | |
""" | |
candidates = [p for p, lbl in self.audio_class_map.items() | |
if lbl == label_name and p != exclude_path] | |
logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'") | |
# Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info | |
all_info = [] | |
for path in candidates: | |
# <<< NEW: вместо info = torchaudio.info(path) ... | |
eq_len = self.path_info.get(path, 0) # Получаем из кэша | |
all_info.append((eq_len, path)) | |
# --- Ниже старый код, который был: | |
# for path in candidates: | |
# try: | |
# info = torchaudio.info(path) | |
# length = info.num_frames | |
# sr_ = info.sample_rate | |
# eq_len = int(length / (sr_ / self.sample_rate)) if sr_ != self.sample_rate else length | |
# all_info.append((eq_len, path)) | |
# except Exception as e: | |
# logging.warning(f"⚠ Ошибка чтения {path}: {e}") | |
# 1) Фильтруем только >= min_needed | |
valid = [(l, p) for l, p in all_info if l >= min_needed] | |
logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})") | |
if valid: | |
# Если есть идеальные — берём случайно из них | |
random.shuffle(valid) | |
chosen = random.choice(valid)[1] | |
return chosen | |
else: | |
# 2) Если идеальных нет — берём топ-K по длине | |
sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True) | |
top_k_list = sorted_by_len[:top_k] | |
if not top_k_list: | |
logging.debug("Нет доступных кандидатов вообще.") | |
return None # вообще нет кандидатов | |
random.shuffle(top_k_list) | |
chosen = top_k_list[0][1] | |
logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}") | |
return chosen | |
def run_whisper(self, waveform): | |
""" | |
Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов). | |
""" | |
arr = waveform.squeeze().cpu().numpy() | |
try: | |
result = self.whisper_model.transcribe(arr, fp16=False) | |
text = result["text"].strip() | |
return text | |
except Exception as e: | |
logging.error(f"Whisper ошибка: {e}") | |
return "" | |
def emotion_to_vector(self, label_name): | |
""" | |
Преобразует название эмоции в one-hot вектор (torch.tensor). | |
""" | |
v = np.zeros(len(self.emotion_columns), dtype=np.float32) | |
if label_name in self.emotion_columns: | |
idx = self.emotion_columns.index(label_name) | |
v[idx] = 1.0 | |
return torch.tensor(v, dtype=torch.float32) | |