# tts.py
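"""Swahili text-to-speech pipeline.

Text is phonemized with gruut (lang="sw"), converted to mel spectrograms by a
quantized LightSpeech ONNX model, and vocoded to waveforms with MB-MelGAN.
"""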

import json
import logging
import re
import string
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf
from gruut import sentences

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the acoustic model (text -> mel) and the vocoder (mel -> waveform)
lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")
lightspeech_processor_config = Path("./models/processor.json")

# The processor config provides the symbol-to-id vocabulary for tokenization
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)
    tokenizer = processor["symbol_to_id"]
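
# Sample rate used for pauses and saved audio (44100 throughout the original
# code; assumed here to match the vocoder's output rate)
SAMPLE_RATE = 44100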

class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        # Full pipeline: split text into sections, synthesize each, concatenate
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        return TTS.concatenate_audio_sections(audio_sections)

    @staticmethod
    def split_text(text: str) -> list:
        # Split the text into sentences at sentence-final punctuation, e.g.
        # "Habari, rafiki. Karibu!" -> ["Habari", "*", "rafiki.", "**", "Karibu!"]
        # (named sents to avoid shadowing gruut's sentences import)
        sents = re.split(r'(?<=[.!?])\s*', text)
        # Drop empty strings left over from the split
        sents = [s for s in sents if s.strip()]

        # For testing, use up to the first 3 sentences only
        sents = sents[:3]

        sections = []
        for sentence in sents:
            # Split each sentence at commas for short pauses
            parts = re.split(r',\s*', sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append('*')  # Short pause marker
            sections.append('**')  # Long pause marker after each sentence

        # Remove empty sections
        sections = [section for section in sections if section]

        # Trim the trailing long pause marker
        if sections and sections[-1] == '**':
            sections = sections[:-1]

        logger.info(f"Split text into sections: {sections}")
        return sections

    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        for section in sections:
            if section == '**':
                # Long pause: 0.4 s of silence
                audio_sections.append(np.zeros(int(0.4 * SAMPLE_RATE)))
            elif section == '*':
                # Short pause: 0.2 s of silence
                audio_sections.append(np.zeros(int(0.2 * SAMPLE_RATE)))
            else:
                # Synthesize speech for this text section
                mel_output, _ = TTS.text2mel(section)
                audio_sections.append(TTS.mel2wav(mel_output))
        return audio_sections

    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        # Join speech and pause segments into a single waveform
        return np.concatenate(audio_sections)

    @staticmethod
    def phonemize(text: str) -> list:
        # Convert text to a list of IPA phonemes using gruut (Swahili)
        ipa = []
        for sentence in sentences(text, lang="sw"):
            for word in sentence:
                if word.is_major_break or word.is_minor_break:
                    ipa += [word.text]
                    continue

                phonemes = word.phonemes[:]

                # Swahili "ng'" is the plain velar nasal /ŋ/, while "ng" is
                # prenasalized /ᵑg/. Where the spelling uses the apostrophe,
                # replace the /ᵑg/ that gruut produced with /ŋ/.
                NG_GRAPHEME = "ng'"
                NG_PRENASALIZED_PHONEME = "ᵑg"
                NG_PHONEME = "ŋ"
                if NG_GRAPHEME in word.text:
                    # "ng'?" matches both spellings, in order of appearance
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)
                    ng_phonemes_idx = [i for i, p in enumerate(phonemes) if p == NG_PRENASALIZED_PHONEME]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        phonemes[i] = NG_PHONEME if g == NG_GRAPHEME else phonemes[i]

                ipa += phonemes
        return ipa

    @staticmethod
    def tokenize(phonemes: list) -> list:
        # Map symbols to model input ids: punctuation is stored as-is in the
        # processor vocabulary, phonemes under an "@" prefix
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                input_ids.append(tokenizer[phoneme])
            else:
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids

    @staticmethod
    def text2mel(text: str) -> tuple:
        # Phonemize and tokenize the text, then run the LightSpeech model
        # to predict a mel spectrogram and per-phoneme durations
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)

        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }

        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations

    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        # Prepare input for the MB-MelGAN vocoder
        inputs = {
            "mels": mel_output,
        }

        # Run inference and drop the batch and channel dimensions
        outputs = mbmelgan.run(None, inputs)
        audio_array = outputs[0][0, :, 0]

        return audio_array
    
    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        # Single-shot synthesis without sentence splitting or pause insertion
        mel_output, _ = TTS.text2mel(text)
        audio_array = TTS.mel2wav(mel_output)
        return audio_array

    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        # Write the waveform to disk as a WAV file
        sf.write(path, audio_array, SAMPLE_RATE)
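

# Example usage: a minimal sketch. Assumes the model files above are present;
# the sample text is a hypothetical Swahili input, not part of the original.
if __name__ == "__main__":
    audio = TTS.generate("Habari ya leo. Karibu sana.")
    TTS.save_audio(audio, "output.wav")
    logger.info(f"Wrote {len(audio) / SAMPLE_RATE:.2f}s of audio to output.wav")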