File size: 2,379 Bytes
960b1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# parler_tts_wrapper.py

import torch
import soundfile as sf
import time
import os
import logging
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

class ParlerTTS:
    def __init__(self, model_name="parler-tts/parler-tts-mini-v1", device="cuda"):
        self.device = device
        logging.info(f"[ParlerTTS] Загрузка модели {model_name} на {device} ...")

        self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.sr = self.model.config.sampling_rate

    def generate_audio_from_text(self, text: str, description: str) -> torch.Tensor:
        """
        Генерирует аудио (без сохранения на диск).
        Возвращает PyTorch-тензор формы (1, num_samples).
        """
        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

        with torch.no_grad():
            generation = self.model.generate(
                input_ids=input_ids,
                prompt_input_ids=prompt_input_ids
            )

        audio_arr = generation.cpu().numpy().squeeze()  # (samples,)
        wf = torch.from_numpy(audio_arr).unsqueeze(0)    # -> (1, samples)
        return wf

    def generate_and_save_audio(self, text: str, description: str, out_dir="tts_outputs", filename_prefix="tts") -> torch.Tensor:
        """
        Генерирует аудио И сохраняет результат в WAV-файл (для отладки/проверки).
        Возвращает PyTorch-тензор (1, num_samples).
        """
        os.makedirs(out_dir, exist_ok=True)

        wf = self.generate_audio_from_text(text, description)
        np_wf = wf.squeeze().cpu().numpy()

        # Формируем имя файла
        timestr = time.strftime("%Y%m%d_%H%M%S")
        filename = f"{filename_prefix}_{timestr}.wav"
        out_path = os.path.join(out_dir, filename)

        # Сохраняем
        sf.write(out_path, np_wf, self.sr)
        logging.info(f"[ParlerTTS] Сохранено аудио: {out_path}")

        return wf

    def get_sample_rate(self):
        return self.sr