# File size: 2,726 Bytes
# 960b1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import logging
import os
import time
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from dia.model import Dia


class DiaTTSWrapper:
    """Convenience wrapper around a ``Dia`` text-to-speech model.

    The wrapper loads the model once in ``__init__`` and exposes helpers that
    turn text (optionally followed by a paralinguistic tag such as
    ``(laughs)``) into a waveform tensor, and optionally write it to a WAV
    file.
    """

    def __init__(
        self,
        model_name: str = "nari-labs/Dia-1.6B",
        device: str = "cuda",
        dtype: str = "float16",
    ) -> None:
        """Load the pretrained model.

        Args:
            model_name: Identifier handed to ``Dia.from_pretrained``.
            device: Device string forwarded to the model loader.
            dtype: Forwarded to the loader as ``compute_dtype``.
        """
        self.device = device
        # Sample rate in Hz used for every seconds-to-samples conversion and
        # when writing WAV files. Assumed to match the model's native output
        # rate -- TODO confirm against the Dia documentation.
        self.sr = 44100
        logging.info(f"[DiaTTS] Загрузка модели {model_name} на {device} (dtype={dtype})")
        self.model = Dia.from_pretrained(
            model_name,
            device=device,
            compute_dtype=dtype
        )

    def generate_audio_from_text(
        self,
        text: str,
        paralinguistic: str = "",
        max_duration: Optional[float] = None,
    ) -> torch.Tensor:
        """Synthesize ``text`` and return the waveform as a float tensor.

        Args:
            text: Text to synthesize.
            paralinguistic: Optional tag (e.g. ``"laughs"`` or ``"(Laughs)"``).
                Surrounding whitespace and parentheses are removed, the tag is
                lower-cased and appended to ``text`` as ``" (tag)"``. A tag
                that is empty after cleaning is ignored.
            max_duration: If truthy, the result is cut to at most
                ``int(self.sr * max_duration)`` samples.

        Returns:
            The model output converted to float with a leading dimension
            added by ``unsqueeze(0)``. If anything fails, the error is logged
            and silence is returned instead of raising: one second of zeros,
            shortened to ``max_duration`` when that is a smaller positive
            number.
        """
        try:
            if paralinguistic:
                # Strip whitespace both outside and inside the parentheses so
                # " (Laughs) " and "( laughs )" both normalise to "laughs".
                tag = paralinguistic.strip().strip("()").strip().lower()
                if tag:  # never append an empty "()" to the prompt
                    text = f"{text} ({tag})"
            audio_np = self.model.generate(
                text,
                use_torch_compile=False,
                verbose=False
            )
            wf = torch.from_numpy(audio_np).float().unsqueeze(0)
            if max_duration:
                max_samples = int(self.sr * max_duration)
                wf = wf[:, :max_samples]
            return wf
        except Exception as e:
            # Deliberate best-effort: callers always get a tensor back.
            # exc_info keeps the traceback, which the message alone would lose.
            logging.error(f"[DiaTTS] Ошибка генерации аудио: {e}", exc_info=True)
            fallback_samples = self.sr
            # Respect the caller's duration cap for the silent fallback too.
            # bool is excluded explicitly; min() is applied before int() so an
            # infinite max_duration cannot overflow the conversion.
            if (
                isinstance(max_duration, (int, float))
                and not isinstance(max_duration, bool)
                and max_duration > 0
            ):
                fallback_samples = int(min(self.sr, self.sr * max_duration))
            return torch.zeros(1, fallback_samples)

    def generate_and_save_audio(
        self,
        text: str,
        paralinguistic: str = "",
        out_dir: str = "tts_outputs",
        filename_prefix: str = "tts",
        max_duration: Optional[float] = None,
        use_timestamp: bool = True,
        skip_if_exists: bool = True,
        max_trim_duration: Optional[float] = None
    ) -> Optional[torch.Tensor]:
        """Synthesize ``text``, write it to a WAV file and return the waveform.

        Args:
            text: Text to synthesize.
            paralinguistic: Optional tag, see ``generate_audio_from_text``.
            out_dir: Output directory; created if missing.
            filename_prefix: File name stem.
            max_duration: Forwarded to ``generate_audio_from_text``.
            use_timestamp: When true, ``_YYYYmmdd_HHMMSS`` is appended to the
                stem. The timestamp has one-second resolution, so two calls
                within the same second resolve to the same path.
            skip_if_exists: When true and the target path already exists,
                nothing is generated and ``None`` is returned.
            max_trim_duration: If not ``None``, the samples written to disk
                are cut to ``int(self.sr * max_trim_duration)``.

        Returns:
            The tensor produced by ``generate_audio_from_text``, or ``None``
            when the file was skipped. Note that ``max_trim_duration`` only
            affects the saved file; the returned tensor is not trimmed by it.
            Because generation falls back to silence on error, a failed
            generation is still written to disk.
        """
        os.makedirs(out_dir, exist_ok=True)
        if use_timestamp:
            timestr = time.strftime("%Y%m%d_%H%M%S")
            filename = f"{filename_prefix}_{timestr}.wav"
        else:
            filename = f"{filename_prefix}.wav"
        out_path = os.path.join(out_dir, filename)

        if skip_if_exists and os.path.exists(out_path):
            logging.info(f"[DiaTTS] ⏭️ Пропущено — уже существует: {out_path}")
            return None

        wf = self.generate_audio_from_text(text, paralinguistic, max_duration)
        # squeeze() turns a single-sample waveform into a 0-d array, on which
        # len() below would raise; atleast_1d keeps it a 1-element vector and
        # leaves every other shape untouched.
        np_wf = np.atleast_1d(wf.squeeze().cpu().numpy())

        if max_trim_duration is not None:
            max_len = int(self.sr * max_trim_duration)
            if len(np_wf) > max_len:
                logging.info(f"[DiaTTS] ✂️ Обрезка аудио до {max_trim_duration} сек.")
                np_wf = np_wf[:max_len]

        sf.write(out_path, np_wf, self.sr)
        logging.info(f"[DiaTTS] 💾 Сохранено аудио: {out_path}")
        return wf

    def get_sample_rate(self) -> int:
        """Return the sample rate (Hz) used by this wrapper."""
        return self.sr