BiBiER / data_loading /dataset_multimodal.py
farbverlauf's picture
gpu
960b1a0
# -*- coding: utf-8 -*-
import os
import random
import logging
import torch
import torchaudio
import whisper
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import pickle
from tqdm import tqdm
# from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor
class DatasetMultiModalWithPretrainedExtractors(Dataset):
"""
Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия).
При каждом вызове __getitem__:
- Загружает WAV по video_name из CSV.
- Для обучающей выборки (split="train"):
Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки
(по merge_probability). Если да – выполняется "chain merge":
выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее,
и итоговое аудио затем обрезается до точной длины.
- Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями.
- Текст выбирается так:
• Если аудио было merged (склеено) – вызывается Whisper для получения нового текста.
• Если merge не происходило и CSV-текст не пуст – используется CSV-текст.
• Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper.
- Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }.
"""
def __init__(
self,
csv_path,
wav_dir,
emotion_columns,
config,
split,
audio_feature_extractor,
text_feature_extractor,
whisper_model,
dataset_name
):
"""
:param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text).
:param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav).
:param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...].
:param split: "train", "dev" или "test".
:param audio_feature_extractor: Экстрактор аудио признаков
:param text_feature_extractor: Экстрактор текстовых признаков
:param sample_rate: Целевая частота дискретизации (например, 16000).
:param wav_length: Целевая длина аудио в секундах.
:param whisper_model: Mодель Whisper ("tiny", "base", "small", ...).
:param max_text_tokens: (Не используется) – ограничение на число токенов.
:param text_column: Название колонки с текстом в CSV.
:param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper.
:param whisper_device: "cuda" или "cpu" – устройство для модели Whisper.
:param subset_size: Если > 0, используется только первые N записей из CSV (для отладки).
:param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче.
:param dataset_name: Название корпуса
"""
super().__init__()
self.split = split
self.sample_rate = config.sample_rate
self.target_samples = int(config.wav_length * self.sample_rate)
self.emotion_columns = emotion_columns
self.whisper_model = whisper_model
self.text_column = config.text_column
self.use_whisper_for_nontrain_if_no_text = config.use_whisper_for_nontrain_if_no_text
self.whisper_device = config.whisper_device
self.merge_probability = config.merge_probability
self.audio_feature_extractor = audio_feature_extractor
self.text_feature_extractor = text_feature_extractor
self.subset_size = config.subset_size
self.save_prepared_data = config.save_prepared_data
self.seed = config.random_seed
self.dataset_name = dataset_name
self.save_feature_path = config.save_feature_path
self.use_synthetic_data = config.use_synthetic_data
self.synthetic_path = config.synthetic_path
self.synthetic_ratio = config.synthetic_ratio
# Загружаем CSV
if not os.path.exists(csv_path):
raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}")
df = pd.read_csv(csv_path)
if self.subset_size > 0:
df = df.head(self.subset_size)
logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={self.subset_size}).")
#копия для сохранения текста Wisper
self.original_df = df.copy()
self.whisper_csv_update_log = []
# Проверяем наличие всех колонок эмоций
missing = [c for c in emotion_columns if c not in df.columns]
if missing:
raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}")
# Проверяем существование папки с аудио
if not os.path.exists(wav_dir):
raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!")
self.wav_dir = wav_dir
# Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть)
self.rows = []
for i, rowi in df.iterrows():
audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav")
if not os.path.exists(audio_path):
continue
# Определяем доминирующую эмоцию (максимальное значение)
# print(self.emotion_columns)
emotion_values = rowi[self.emotion_columns].values.astype(float)
max_idx = np.argmax(emotion_values)
emotion_label = self.emotion_columns[max_idx]
# Извлекаем текст из CSV (если есть)
csv_text = ""
if self.text_column in rowi and isinstance(rowi[self.text_column], str):
csv_text = rowi[self.text_column]
self.rows.append({
"audio_path": audio_path,
"label": emotion_label,
"csv_text": csv_text
})
if self.use_synthetic_data and self.split == "train" and self.dataset_name.lower() == "meld":
logging.info(f"🧪 Включена синтетика для датасета '{self.dataset_name}' — добавляем примеры из: {self.synthetic_path}")
self._add_synthetic_data(self.synthetic_ratio)
# Создаем карту для поиска файлов по эмоции
self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows}
logging.info("📊 Анализ распределения файлов по эмоциям:")
emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())}
for path, emotion in self.audio_class_map.items():
emotion_counts[emotion] += 1
for emotion, count in emotion_counts.items():
logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.")
logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}")
# === Процентное семплирование ===
total_files = len(self.rows)
num_to_merge = int(total_files * self.merge_probability)
# <<< NEW: Кешируем длины (eq_len) для всех файлов >>>
self.path_info = {}
for row in self.rows:
p = row["audio_path"]
try:
info = torchaudio.info(p)
length = info.num_frames
sr_ = info.sample_rate
# переводим длину в "эквивалент self.sample_rate"
if sr_ != self.sample_rate:
ratio = sr_ / self.sample_rate
eq_len = int(length / ratio)
else:
eq_len = length
self.path_info[p] = eq_len
except Exception as e:
logging.warning(f"⚠️ Ошибка чтения {p}: {e}")
self.path_info[p] = 0 # Если не смогли прочитать, ставим 0
# Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short
self.mergable_files = [
row["audio_path"] # вместо целого dict берём строку
for row in self.rows
if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию
]
short_count = len(self.mergable_files)
# Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие.
if short_count > num_to_merge:
self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge))
else:
self.files_to_merge = set(self.mergable_files)
logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)")
logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}")
if self.save_prepared_data:
self.meta = []
if self.use_synthetic_data:
meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_synthetic_true_pct_{}_pred.pickle'.format(
self.dataset_name,
self.split,
config.audio_classifier_checkpoint[-4:-3],
self.seed,
self.subset_size,
config.emb_normalize,
int(self.synthetic_ratio * 100)
)
else:
meta_filename = '{}_{}_seed_{}_subset_size_{}_audio_model_{}_feature_norm_{}_merge_prob_{}_pred.pickle'.format(
self.dataset_name,
self.split,
config.audio_classifier_checkpoint[-4:-3],
self.seed,
self.subset_size,
config.emb_normalize,
self.merge_probability
)
pickle_path = os.path.join(self.save_feature_path, meta_filename)
self.load_data(pickle_path)
if not self.meta:
self.prepare_data()
os.makedirs(self.save_feature_path, exist_ok=True)
self.save_data(pickle_path)
def save_data(self, filename):
with open(filename, 'wb') as handle:
pickle.dump(self.meta, handle, protocol=pickle.HIGHEST_PROTOCOL)
def load_data(self, filename):
if os.path.exists(filename):
with open(filename, 'rb') as handle:
self.meta = pickle.load(handle)
else:
self.meta = []
def _is_too_short(self, audio_path):
"""
(Оригинальная) Проверяем, является ли файл короче target_samples.
Использует torchaudio.info(audio_path).
Но теперь этот метод не используется, поскольку мы кешируем длины.
"""
try:
info = torchaudio.info(audio_path)
length = info.num_frames
sr_ = info.sample_rate
# переводим длину в "эквивалент self.sample_rate"
if sr_ != self.sample_rate:
ratio = sr_ / self.sample_rate
eq_len = int(length / ratio)
else:
eq_len = length
return eq_len < self.target_samples
except Exception as e:
logging.warning(f"Ошибка _is_too_short({audio_path}): {e}")
return False
def _is_too_short_cached(self, audio_path):
"""
(Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info.
"""
eq_len = self.path_info.get(audio_path, 0)
return eq_len < self.target_samples
def __len__(self):
if self.save_prepared_data:
return len(self.meta)
else:
return len(self.rows)
def get_data(self, row):
audio_path = row["audio_path"]
label_name = row["label"]
csv_text = row["csv_text"]
# Преобразуем label в one-hot вектор
label_vec = self.emotion_to_vector(label_name)
# Шаг 1. Загружаем аудио
waveform, sr = self.load_audio(audio_path)
if waveform is None:
return None
orig_len = waveform.shape[1]
logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек")
was_merged = False
merged_texts = [csv_text] # Тексты исходного файла + добавленных
# Шаг 2. Для train, если аудио короче target_samples, проверяем:
# попал ли данный row в files_to_merge?
if self.split == "train" and row["audio_path"] in self.files_to_merge:
# chain merge
current_length = orig_len
used_candidates = set()
while current_length < self.target_samples:
needed = self.target_samples - current_length
candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10)
if candidate is None or candidate in used_candidates:
break
used_candidates.add(candidate)
add_wf, add_sr = self.load_audio(candidate)
if add_wf is None:
break
logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})")
waveform = torch.cat((waveform, add_wf), dim=1)
current_length = waveform.shape[1]
was_merged = True
# Получаем текст второго файла (если есть в CSV)
add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "")
merged_texts.append(add_csv_text)
logging.debug(f"📜 Текст первого файла: {csv_text}")
logging.debug(f"📜 Текст добавленного файла: {add_csv_text}")
else:
# Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge
logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.")
if was_merged:
logging.debug("📝 Текст: аудио было merged – вызываем Whisper.")
text_final = self.run_whisper(waveform)
logging.debug(f"🆕 Whisper предсказал: {text_final}")
merge_components = [os.path.splitext(os.path.basename(audio_path))[0]]
merge_components += [os.path.splitext(os.path.basename(p))[0] for p in used_candidates]
self.whisper_csv_update_log.append({
"video_name": os.path.splitext(os.path.basename(audio_path))[0],
"text_new": text_final,
"text_old": csv_text,
"was_merged": True,
"merge_components": merge_components
})
else:
if csv_text.strip():
logging.debug("Текст: используем CSV-текст (не пуст).")
text_final = csv_text
else:
if self.split == "train" or self.use_whisper_for_nontrain_if_no_text:
logging.debug("Текст: CSV пустой – вызываем Whisper.")
text_final = self.run_whisper(waveform)
else:
logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.")
text_final = ""
audio_pred, audion_emb = self.audio_feature_extractor.extract(waveform[0], self.sample_rate)
text_pred, text_emb = self.text_feature_extractor.extract(text_final)
return {
"audio_path": os.path.basename(audio_path),
"audio": audion_emb[0],
"label": label_vec,
"text": text_emb[0],
"audio_pred": audio_pred[0],
"text_pred": text_pred[0]
}
def prepare_data(self):
"""
Загружает и обрабатывает один элемент датасета,
сохраняет эмбеддинги и обновлённый текст (если было склеено).
"""
for idx, row in enumerate(tqdm(self.rows)):
curr_dict = self.get_data(row)
if curr_dict is not None:
self.meta.append(curr_dict)
# === Сохраняем CSV с обновлёнными текстами (только если был merge) ===
if self.whisper_csv_update_log:
df_log = pd.DataFrame(self.whisper_csv_update_log)
# Копия исходного CSV
df_out = self.original_df.copy()
# Мержим по video_name
df_out = df_out.merge(df_log, on="video_name", how="left")
# Обновляем текст: заменяем только если Whisper сгенерировал
df_out["text_final"] = df_out["text_new"].combine_first(df_out["text"])
df_out["text_old"] = df_out["text"]
df_out["text"] = df_out["text_final"]
df_out["was_merged"] = df_out["was_merged"].fillna(False).astype(bool)
# Преобразуем merge_components в строку
df_out["merge_components"] = df_out["merge_components"].apply(
lambda x: ";".join(x) if isinstance(x, list) else ""
)
# Чистим временные колонки
df_out = df_out.drop(columns=["text_new", "text_final"])
# Сохраняем как CSV
output_path = os.path.join(self.save_feature_path, f"{self.dataset_name}_{self.split}_merged_whisper_{self.merge_probability *100}.csv")
os.makedirs(self.save_feature_path, exist_ok=True)
df_out.to_csv(output_path, index=False, encoding="utf-8")
logging.info(f"📄 Обновлённый merged CSV сохранён: {output_path}")
def __getitem__(self, index):
if self.save_prepared_data:
return self.meta[index]
else:
return self.get_data(self.rows[index])
def load_audio(self, path):
"""
Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо.
"""
if not os.path.exists(path):
logging.warning(f"Файл отсутствует: {path}")
return None, None
try:
wf, sr = torchaudio.load(path)
if sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
wf = resampler(wf)
sr = self.sample_rate
return wf, sr
except Exception as e:
logging.error(f"Ошибка загрузки {path}: {e}")
return None, None
def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5):
"""
Ищет аудиофайл с той же эмоцией.
1) Если есть файлы >= min_needed, выбираем случайно из них.
2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный.
"""
candidates = [p for p, lbl in self.audio_class_map.items()
if lbl == label_name and p != exclude_path]
logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'")
# Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info
all_info = []
for path in candidates:
# <<< NEW: вместо info = torchaudio.info(path) ...
eq_len = self.path_info.get(path, 0) # Получаем из кэша
all_info.append((eq_len, path))
valid = [(l, p) for l, p in all_info if l >= min_needed]
logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})")
if valid:
# Если есть идеальные — берём случайно из них
random.shuffle(valid)
chosen = random.choice(valid)[1]
return chosen
else:
# 2) Если идеальных нет — берём топ-K по длине
sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True)
top_k_list = sorted_by_len[:top_k]
if not top_k_list:
logging.debug("Нет доступных кандидатов вообще.")
return None # вообще нет кандидатов
random.shuffle(top_k_list)
chosen = top_k_list[0][1]
logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}")
return chosen
def run_whisper(self, waveform):
"""
Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов).
"""
arr = waveform.squeeze().cpu().numpy()
try:
with torch.no_grad():
result = self.whisper_model.transcribe(arr, fp16=False)
text = result["text"].strip()
return text
except Exception as e:
logging.error(f"Whisper ошибка: {e}")
return ""
def _add_synthetic_data(self, synthetic_ratio):
"""
Добавляет synthetic_ratio (0..1) от количества доступных синтетических файлов на каждую эмоцию.
"""
if not self.synthetic_path:
logging.warning("⚠ Путь к синтетическим данным не указан.")
return
random.seed(self.seed)
synth_csv_path = os.path.join(self.synthetic_path, "meld_s_train_labels.csv")
synth_wav_dir = os.path.join(self.synthetic_path, "wavs")
if not (os.path.exists(synth_csv_path) and os.path.exists(synth_wav_dir)):
logging.warning("⚠ Синтетические данные не найдены.")
return
df_synth = pd.read_csv(synth_csv_path)
rows_by_label = {emotion: [] for emotion in self.emotion_columns}
for _, row in df_synth.iterrows():
audio_path = os.path.join(synth_wav_dir, f"{row['video_name']}.wav")
if not os.path.exists(audio_path):
continue
emotion_values = row[self.emotion_columns].values.astype(float)
max_idx = np.argmax(emotion_values)
label = self.emotion_columns[max_idx]
csv_text = row[self.text_column] if self.text_column in row and isinstance(row[self.text_column], str) else ""
rows_by_label[label].append({
"audio_path": audio_path,
"label": label,
"csv_text": csv_text
})
added = 0
for label in self.emotion_columns:
candidates = rows_by_label[label]
if not candidates:
continue
count_synth = int(len(candidates) * synthetic_ratio)
if count_synth <= 0:
continue
selected = random.sample(candidates, count_synth)
self.rows.extend(selected)
added += len(selected)
logging.info(f"➕ Добавлено {len(selected)} синтетических примеров для эмоции '{label}'")
logging.info(f"📦 Всего добавлено {added} синтетических примеров из MELD_S")
def emotion_to_vector(self, label_name):
"""
Преобразует название эмоции в one-hot вектор (torch.tensor).
"""
v = np.zeros(len(self.emotion_columns), dtype=np.float32)
if label_name in self.emotion_columns:
idx = self.emotion_columns.index(label_name)
v[idx] = 1.0
return torch.tensor(v, dtype=torch.float32)
class DatasetMultiModal(Dataset):
"""
Мультимодальный датасет для аудио, текста и эмоций (он‑the‑fly версия).
При каждом вызове __getitem__:
- Загружает WAV по video_name из CSV.
- Для обучающей выборки (split="train"):
Если аудио короче target_samples, проверяем, выбрали ли мы этот файл для склейки
(по merge_probability). Если да – выполняется "chain merge":
выбирается один или несколько дополнительных файлов того же класса, даже если один кандидат длиннее,
и итоговое аудио затем обрезается до точной длины.
- Если итоговое аудио всё ещё меньше target_samples, выполняется паддинг нулями.
- Текст выбирается так:
• Если аудио было merged (склеено) – вызывается Whisper для получения нового текста.
• Если merge не происходило и CSV-текст не пуст – используется CSV-текст.
• Если CSV-текст пустой – для train (или, при условии, для dev/test) вызывается Whisper.
- Возвращает словарь { "audio": waveform, "label": label_vector, "text": text_final }.
"""
def __init__(
self,
csv_path,
wav_dir,
emotion_columns,
split="train",
sample_rate=16000,
wav_length=4,
whisper_model="tiny",
text_column="text",
use_whisper_for_nontrain_if_no_text=True,
whisper_device="cuda",
subset_size=0,
merge_probability=1.0 # <-- Новый параметр: доля от ОБЩЕГО числа файлов
):
"""
:param csv_path: Путь к CSV-файлу (с колонками video_name, emotion_columns, возможно text).
:param wav_dir: Папка с аудиофайлами (имя файла: video_name.wav).
:param emotion_columns: Список колонок эмоций, например ["neutral", "happy", "sad", ...].
:param split: "train", "dev" или "test".
:param sample_rate: Целевая частота дискретизации (например, 16000).
:param wav_length: Целевая длина аудио в секундах.
:param whisper_model: Название модели Whisper ("tiny", "base", "small", ...).
:param max_text_tokens: (Не используется) – ограничение на число токенов.
:param text_column: Название колонки с текстом в CSV.
:param use_whisper_for_nontrain_if_no_text: Если True, для dev/test при отсутствии CSV-текста вызывается Whisper.
:param whisper_device: "cuda" или "cpu" – устройство для модели Whisper.
:param subset_size: Если > 0, используется только первые N записей из CSV (для отладки).
:param merge_probability: Процент (0..1) от всего числа файлов, которые будут склеиваться, если они короче.
"""
super().__init__()
self.split = split
self.sample_rate = sample_rate
self.target_samples = int(wav_length * sample_rate)
self.emotion_columns = emotion_columns
self.whisper_model_name = whisper_model
self.text_column = text_column
self.use_whisper_for_nontrain_if_no_text = use_whisper_for_nontrain_if_no_text
self.whisper_device = whisper_device
self.merge_probability = merge_probability
# Загружаем CSV
if not os.path.exists(csv_path):
raise ValueError(f"Ошибка: файл CSV не найден: {csv_path}")
df = pd.read_csv(csv_path)
if subset_size > 0:
df = df.head(subset_size)
logging.info(f"[DatasetMultiModal] Используем только первые {len(df)} записей (subset_size={subset_size}).")
# Проверяем наличие всех колонок эмоций
missing = [c for c in emotion_columns if c not in df.columns]
if missing:
raise ValueError(f"В CSV отсутствуют необходимые колонки эмоций: {missing}")
# Проверяем существование папки с аудио
if not os.path.exists(wav_dir):
raise ValueError(f"Ошибка: директория с аудио {wav_dir} не существует!")
self.wav_dir = wav_dir
# Собираем список строк: для каждой записи получаем путь к аудио, label и CSV-текст (если есть)
self.rows = []
for i, rowi in df.iterrows():
audio_path = os.path.join(wav_dir, f"{rowi['video_name']}.wav")
if not os.path.exists(audio_path):
continue
# Определяем доминирующую эмоцию (максимальное значение)
emotion_values = rowi[self.emotion_columns].values.astype(float)
max_idx = np.argmax(emotion_values)
emotion_label = self.emotion_columns[max_idx]
# Извлекаем текст из CSV (если есть)
csv_text = ""
if self.text_column in rowi and isinstance(rowi[self.text_column], str):
csv_text = rowi[self.text_column]
self.rows.append({
"audio_path": audio_path,
"label": emotion_label,
"csv_text": csv_text
})
# Создаем карту для поиска файлов по эмоции
self.audio_class_map = {entry["audio_path"]: entry["label"] for entry in self.rows}
logging.info("📊 Анализ распределения файлов по эмоциям:")
emotion_counts = {emotion: 0 for emotion in set(self.audio_class_map.values())}
for path, emotion in self.audio_class_map.items():
emotion_counts[emotion] += 1
for emotion, count in emotion_counts.items():
logging.info(f"🎭 Эмоция '{emotion}': {count} файлов.")
logging.info(f"[DatasetMultiModal] Сплит={split}, всего строк: {len(self.rows)}")
# === Процентное семплирование ===
total_files = len(self.rows)
num_to_merge = int(total_files * self.merge_probability)
# <<< NEW: Кешируем длины (eq_len) для всех файлов >>>
self.path_info = {}
for row in self.rows:
p = row["audio_path"]
try:
info = torchaudio.info(p)
length = info.num_frames
sr_ = info.sample_rate
# переводим длину в "эквивалент self.sample_rate"
if sr_ != self.sample_rate:
ratio = sr_ / self.sample_rate
eq_len = int(length / ratio)
else:
eq_len = length
self.path_info[p] = eq_len
except Exception as e:
logging.warning(f"⚠️ Ошибка чтения {p}: {e}")
self.path_info[p] = 0 # Если не смогли прочитать, ставим 0
# Определим, какие файлы "короткие" (могут нуждаться в склейке) - используем кэш вместо старого _is_too_short
self.mergable_files = [
row["audio_path"] # вместо целого dict берём строку
for row in self.rows
if self._is_too_short_cached(row["audio_path"]) # <<< теперь тут используем новую функцию
]
short_count = len(self.mergable_files)
# Если коротких файлов больше нужного числа, выберем случайные. Иначе все короткие.
if short_count > num_to_merge:
self.files_to_merge = set(random.sample(self.mergable_files, num_to_merge))
else:
self.files_to_merge = set(self.mergable_files)
logging.info(f"🔗 Всего файлов: {total_files}, нужно склеить: {num_to_merge} ({self.merge_probability*100:.0f}%)")
logging.info(f"🔗 Коротких файлов: {short_count}, выбрано для склейки: {len(self.files_to_merge)}")
# Инициализируем Whisper-модель один раз
logging.info(f"Инициализация Whisper: модель={whisper_model}, устройство={whisper_device}")
self.whisper_model = whisper.load_model(whisper_model, device=whisper_device).eval()
# print(f"📦 Whisper работает на устройстве: {self.whisper_model.device}")
def _is_too_short(self, audio_path):
"""
(Оригинальная) Проверяем, является ли файл короче target_samples.
Использует torchaudio.info(audio_path).
Но теперь этот метод не используется, поскольку мы кешируем длины.
"""
try:
info = torchaudio.info(audio_path)
length = info.num_frames
sr_ = info.sample_rate
# переводим длину в "эквивалент self.sample_rate"
if sr_ != self.sample_rate:
ratio = sr_ / self.sample_rate
eq_len = int(length / ratio)
else:
eq_len = length
return eq_len < self.target_samples
except Exception as e:
logging.warning(f"Ошибка _is_too_short({audio_path}): {e}")
return False
def _is_too_short_cached(self, audio_path):
"""
(Новая) Проверяем, является ли файл короче target_samples, используя закешированную длину в self.path_info.
"""
eq_len = self.path_info.get(audio_path, 0)
return eq_len < self.target_samples
def __len__(self):
return len(self.rows)
def __getitem__(self, index):
"""
Загружает и обрабатывает один элемент датасета (он‑the‑fly).
"""
row = self.rows[index]
audio_path = row["audio_path"]
label_name = row["label"]
csv_text = row["csv_text"]
# Преобразуем label в one-hot вектор
label_vec = self.emotion_to_vector(label_name)
# Шаг 1. Загружаем аудио
waveform, sr = self.load_audio(audio_path)
if waveform is None:
return None
orig_len = waveform.shape[1]
logging.debug(f"Исходная длина {os.path.basename(audio_path)}: {orig_len/sr:.2f} сек")
was_merged = False
merged_texts = [csv_text] # Тексты исходного файла + добавленных
# Шаг 2. Для train, если аудио короче target_samples, проверяем:
# попал ли данный row в files_to_merge?
if self.split == "train" and row["audio_path"] in self.files_to_merge:
# chain merge
current_length = orig_len
used_candidates = set()
while current_length < self.target_samples:
needed = self.target_samples - current_length
candidate = self.get_suitable_audio(label_name, exclude_path=audio_path, min_needed=needed, top_k=10)
if candidate is None or candidate in used_candidates:
break
used_candidates.add(candidate)
add_wf, add_sr = self.load_audio(candidate)
if add_wf is None:
break
logging.debug(f"Склейка: добавляем {os.path.basename(candidate)} (необходимых сэмплов: {needed})")
waveform = torch.cat((waveform, add_wf), dim=1)
current_length = waveform.shape[1]
was_merged = True
# Получаем текст второго файла (если есть в CSV)
add_csv_text = next((r["csv_text"] for r in self.rows if r["audio_path"] == candidate), "")
merged_texts.append(add_csv_text)
logging.debug(f"📜 Текст первого файла: {csv_text}")
logging.debug(f"📜 Текст добавленного файла: {add_csv_text}")
else:
# Если файл не в списке "должны склеить" или сплит не train, пропускаем chain-merge
logging.debug("Файл не выбран для склейки (или не train), пропускаем chain merge.")
# Шаг 3. Если итоговая длина меньше target_samples, паддинг нулями
curr_len = waveform.shape[1]
if curr_len < self.target_samples:
pad_size = self.target_samples - curr_len
logging.debug(f"Паддинг {os.path.basename(audio_path)}: +{pad_size} сэмплов")
waveform = torch.nn.functional.pad(waveform, (0, pad_size))
# Шаг 4. Обрезаем аудио до target_samples (если вышло больше)
waveform = waveform[:, :self.target_samples]
logging.debug(f"Финальная длина {os.path.basename(audio_path)}: {waveform.shape[1]/sr:.2f} сек; was_merged={was_merged}")
# Шаг 5. Получаем текст
if was_merged:
logging.debug("📝 Текст: аудио было merged – вызываем Whisper.")
text_final = self.run_whisper(waveform)
logging.debug(f"🆕 Whisper предсказал: {text_final}")
else:
if csv_text.strip():
logging.debug("Текст: используем CSV-текст (не пуст).")
text_final = csv_text
else:
if self.split == "train" or self.use_whisper_for_nontrain_if_no_text:
logging.debug("Текст: CSV пустой – вызываем Whisper.")
text_final = self.run_whisper(waveform)
else:
logging.debug("Текст: CSV пустой и не вызываем Whisper для dev/test.")
text_final = ""
return {
"audio_path": os.path.basename(audio_path), # new
"audio": waveform,
"label": label_vec,
"text": text_final
}
def load_audio(self, path):
"""
Загружает аудио по указанному пути и ресэмплирует его до self.sample_rate, если необходимо.
"""
if not os.path.exists(path):
logging.warning(f"Файл отсутствует: {path}")
return None, None
try:
wf, sr = torchaudio.load(path)
if sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
wf = resampler(wf)
sr = self.sample_rate
return wf, sr
except Exception as e:
logging.error(f"Ошибка загрузки {path}: {e}")
return None, None
def get_suitable_audio(self, label_name, exclude_path, min_needed, top_k=5):
"""
Ищет аудиофайл с той же эмоцией.
1) Если есть файлы >= min_needed, выбираем случайно из них.
2) Если таких нет, берём топ-K самых длинных, потом из них берём случайный.
"""
candidates = [p for p, lbl in self.audio_class_map.items()
if lbl == label_name and p != exclude_path]
logging.debug(f"🔍 Найдено {len(candidates)} кандидатов для эмоции '{label_name}'")
# Сохраним: (eq_len, path) для всех кандидатов, но БЕЗ повторного чтения torchaudio.info
all_info = []
for path in candidates:
# <<< NEW: вместо info = torchaudio.info(path) ...
eq_len = self.path_info.get(path, 0) # Получаем из кэша
all_info.append((eq_len, path))
# --- Ниже старый код, который был:
# for path in candidates:
# try:
# info = torchaudio.info(path)
# length = info.num_frames
# sr_ = info.sample_rate
# eq_len = int(length / (sr_ / self.sample_rate)) if sr_ != self.sample_rate else length
# all_info.append((eq_len, path))
# except Exception as e:
# logging.warning(f"⚠ Ошибка чтения {path}: {e}")
# 1) Фильтруем только >= min_needed
valid = [(l, p) for l, p in all_info if l >= min_needed]
logging.debug(f"✅ Подходящих (>= {min_needed}): {len(valid)} (из {len(all_info)})")
if valid:
# Если есть идеальные — берём случайно из них
random.shuffle(valid)
chosen = random.choice(valid)[1]
return chosen
else:
# 2) Если идеальных нет — берём топ-K по длине
sorted_by_len = sorted(all_info, key=lambda x: x[0], reverse=True)
top_k_list = sorted_by_len[:top_k]
if not top_k_list:
logging.debug("Нет доступных кандидатов вообще.")
return None # вообще нет кандидатов
random.shuffle(top_k_list)
chosen = top_k_list[0][1]
logging.info(f"Из топ-{top_k} выбран кандидат: {chosen}")
return chosen
def run_whisper(self, waveform):
"""
Вызывает Whisper на аудиосигнале и возвращает полный текст (без ограничения по количеству слов).
"""
arr = waveform.squeeze().cpu().numpy()
try:
result = self.whisper_model.transcribe(arr, fp16=False)
text = result["text"].strip()
return text
except Exception as e:
logging.error(f"Whisper ошибка: {e}")
return ""
def emotion_to_vector(self, label_name):
"""
Преобразует название эмоции в one-hot вектор (torch.tensor).
"""
v = np.zeros(len(self.emotion_columns), dtype=np.float32)
if label_name in self.emotion_columns:
idx = self.emotion_columns.index(label_name)
v[idx] = 1.0
return torch.tensor(v, dtype=torch.float32)