# Source: BiBiER / synthetic_utils / dia_tts_wrapper.py
# Revision 960b1a0 ("gpu"), farbverlauf
import logging
import os
import time
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from dia.model import Dia

class DiaTTSWrapper:
    """Thin wrapper around the Dia text-to-speech model.

    Loads a Dia checkpoint once and exposes helpers that synthesize speech from
    text (optionally with a parenthesized non-verbal cue appended), optionally
    trim it to a maximum duration, and save it as a WAV file.
    """

    def __init__(self, model_name="nari-labs/Dia-1.6B", device="cuda", dtype="float16"):
        """Load the Dia model.

        Args:
            model_name: Checkpoint identifier passed to ``Dia.from_pretrained``.
            device: Device the model is loaded on (e.g. ``"cuda"`` or ``"cpu"``).
            dtype: Forwarded to ``Dia.from_pretrained`` as ``compute_dtype``.
        """
        self.device = device
        # Hard-coded sample rate (Hz) used for trimming and for writing files;
        # assumed to match the model's output rate — TODO confirm if the
        # checkpoint changes.
        self.sr = 44100
        logging.info(f"[DiaTTS] Загрузка модели {model_name} на {device} (dtype={dtype})")
        self.model = Dia.from_pretrained(
            model_name,
            device=device,
            compute_dtype=dtype,
        )

    def generate_audio_from_text(
        self,
        text: str,
        paralinguistic: str = "",
        max_duration: Optional[float] = None,
    ) -> torch.Tensor:
        """Synthesize speech for *text* and return it as a ``(1, T)`` float32 tensor.

        This is best-effort: any failure is logged and one second of silence is
        returned instead of raising.

        Args:
            text: Text to speak.
            paralinguistic: Optional non-verbal cue such as ``"laughs"`` or
                ``"(Laughs)"``. Surrounding whitespace and parentheses are
                removed, the cue is lower-cased and appended to *text* as
                ``" (cue)"``. A cue that is empty after cleaning (``""``,
                ``"()"``) adds nothing.
            max_duration: If positive, keep at most this many seconds of audio.
                ``None``, ``0`` or a negative value disables trimming.

        Returns:
            A CPU float32 tensor of shape ``(1, T)``, or ``torch.zeros(1, self.sr)``
            if generation failed or the model returned no audio.
        """
        try:
            if paralinguistic:
                # Strip whitespace before and after the parentheses so inputs
                # like " (Sighs) " don't end up double-wrapped.
                cue = paralinguistic.strip().strip("()").strip().lower()
                if cue:
                    text = f"{text} ({cue})"
            audio_np = self.model.generate(
                text,
                use_torch_compile=False,
                verbose=False,
            )
            if audio_np is None:
                logging.warning("[DiaTTS] Модель не вернула аудио, возвращается тишина")
                return torch.zeros(1, self.sr)
            # Accept a numpy array or a tensor and move it to CPU float32.
            # Flatten to (1, T): assumes mono output (matching the (1, T)
            # silence fallback), whether the model returns (T,) or (1, T).
            wf = torch.as_tensor(audio_np, dtype=torch.float32, device="cpu").reshape(1, -1)
            if max_duration is not None and max_duration > 0:
                wf = wf[:, : int(self.sr * max_duration)]
            return wf
        except Exception as e:
            # logging.exception logs at ERROR level and includes the traceback.
            logging.exception(f"[DiaTTS] Ошибка генерации аудио: {e}")
            return torch.zeros(1, self.sr)

    def generate_and_save_audio(
        self,
        text: str,
        paralinguistic: str = "",
        out_dir="tts_outputs",
        filename_prefix="tts",
        max_duration: Optional[float] = None,
        use_timestamp=True,
        skip_if_exists=True,
        max_trim_duration: Optional[float] = None,
    ) -> Optional[torch.Tensor]:
        """Synthesize speech, write it to a WAV file and return the waveform.

        The file is written into *out_dir* (created if missing). It is named
        ``{filename_prefix}_{YYYYmmdd_HHMMSS}.wav`` when *use_timestamp* is
        true, otherwise ``{filename_prefix}.wav``.

        Args:
            text: Text to speak (see ``generate_audio_from_text``).
            paralinguistic: Optional non-verbal cue (see ``generate_audio_from_text``).
            out_dir: Output directory.
            filename_prefix: Prefix of the output file name.
            max_duration: Forwarded to ``generate_audio_from_text``.
            use_timestamp: Append a second-resolution timestamp to the file name.
            skip_if_exists: If the target file already exists, generate
                nothing and return ``None``.
            max_trim_duration: If given, the *saved* audio is cut to at most
                this many seconds. ``0`` or a negative value yields an empty
                file. The returned tensor is not affected by this trim.

        Returns:
            The ``(1, T)`` waveform from ``generate_audio_from_text`` (one
            second of silence if generation failed), or ``None`` when skipped.

        Raises:
            Whatever ``os.makedirs`` or ``sf.write`` raise (e.g. ``OSError``).
        """
        os.makedirs(out_dir, exist_ok=True)
        if use_timestamp:
            # NOTE(review): second resolution — two calls within the same
            # second resolve to the same path (skipped or overwritten).
            timestr = time.strftime("%Y%m%d_%H%M%S")
            filename = f"{filename_prefix}_{timestr}.wav"
        else:
            filename = f"{filename_prefix}.wav"
        out_path = os.path.join(out_dir, filename)
        if skip_if_exists and os.path.exists(out_path):
            logging.info(f"[DiaTTS] ⏭️ Пропущено — уже существует: {out_path}")
            return None

        wf = self.generate_audio_from_text(text, paralinguistic, max_duration)
        # Flatten to 1-D rather than squeeze(): squeeze() would turn a
        # 1-sample clip into a 0-d array that len()/sf.write cannot handle.
        np_wf = wf.reshape(-1).cpu().numpy()
        if max_trim_duration is not None:
            # Clamp so a negative duration can't slice from the end.
            max_len = max(0, int(self.sr * max_trim_duration))
            if len(np_wf) > max_len:
                logging.info(f"[DiaTTS] ✂️ Обрезка аудио до {max_trim_duration} сек.")
                np_wf = np_wf[:max_len]
        sf.write(out_path, np_wf, self.sr)
        logging.info(f"[DiaTTS] 💾 Сохранено аудио: {out_path}")
        return wf

    def get_sample_rate(self):
        """Return the sample rate (Hz) used for generated and saved audio."""
        return self.sr