# tts.py
from gruut import sentences
import re
import numpy as np
import onnxruntime as ort
from pathlib import Path
import json
import string
import soundfile as sf
import logging
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load models
lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")
lightspeech_processor_config = Path("./models/processor.json")
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)
tokenizer = processor["symbol_to_id"]
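# `symbol_to_id` maps each symbol to an integer id; phoneme symbols are
# assumed to be stored with an "@" prefix (see TTS.tokenize below), while
# punctuation symbols are stored bare.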
class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        # Full pipeline: split text into sections, synthesize each section,
        # then join the audio (with silence for pause markers).
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        concatenated_audio = TTS.concatenate_audio_sections(audio_sections)
        return concatenated_audio
    @staticmethod
    def split_text(text: str) -> list:
        # Split the text into sentences based on terminal punctuation
        sentence_list = re.split(r'(?<=[.!?])\s*', text)
        # Drop empty strings left by the zero-width split at the end
        sentence_list = [s for s in sentence_list if s.strip()]
        # For testing, keep only the first 3 sentences
        sentence_list = sentence_list[:3]
        sections = []
        for sentence in sentence_list:
            # Split each sentence by commas for short pauses
            parts = re.split(r',\s*', sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append('*')  # Short pause marker
            sections.append('**')  # Long pause marker after each sentence
        # Remove empty sections
        sections = [section for section in sections if section]
        # Trim the trailing long pause marker
        if sections and sections[-1] == '**':
            sections = sections[:-1]
        logger.info(f"Split text into sections: {sections}")
        return sections
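    # Sectioning example (hypothetical input):
    #   TTS.split_text("Habari, rafiki. Karibu!")
    #   -> ['Habari', '*', 'rafiki.', '**', 'Karibu!']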
    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        sample_rate = 44100
        for section in sections:
            if section == '**':
                # Long pause: 0.4 s of silence
                audio_sections.append(np.zeros(int(0.4 * sample_rate)))
            elif section == '*':
                # Short pause: 0.2 s of silence
                audio_sections.append(np.zeros(int(0.2 * sample_rate)))
            else:
                mel_output, _ = TTS.text2mel(section)
                audio_array = TTS.mel2wav(mel_output)
                audio_sections.append(audio_array)
        return audio_sections
    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        return np.concatenate(audio_sections)
    @staticmethod
    def phonemize(text: str) -> list:
        ipa = []
        for sentence in sentences(text, lang="sw"):
            for word in sentence:
                if word.is_major_break or word.is_minor_break:
                    ipa += [word.text]
                    continue
                phonemes = word.phonemes[:]
                NG_GRAPHEME = "ng'"
                NG_PRENASALIZED_PHONEME = "ᵑg"
                NG_PHONEME = "ŋ"
                if NG_GRAPHEME in word.text:
                    # gruut emits the prenasalized stop for both "ng" and
                    # "ng'"; remap occurrences spelled "ng'" to the velar nasal
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)
                    ng_phonemes_idx = [
                        i for i, p in enumerate(phonemes)
                        if p == NG_PRENASALIZED_PHONEME
                    ]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        phonemes[i] = NG_PHONEME if g == NG_GRAPHEME else phonemes[i]
                ipa += phonemes
        return ipa
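    # Illustration (assuming gruut emits ᵑg for both spellings): "ng'ombe"
    # would be remapped to use ŋ, while plain "ngoma" keeps ᵑg.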
    @staticmethod
    def tokenize(phonemes: list) -> list:
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                # Break/punctuation tokens are looked up directly
                input_ids.append(tokenizer[phoneme])
            else:
                # Phoneme symbols are stored with an "@" prefix
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids
    @staticmethod
    def text2mel(text: str) -> tuple:
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)
        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }
        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations
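    # Shape sketch (inferred from the indexing in mel2wav, not verified
    # against the models): mel_output is (1, frames, n_mels) and the vocoder
    # returns (1, samples, 1).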
    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        # Prepare input for the vocoder model
        inputs = {
            "mels": mel_output,
        }
        # Run inference and squeeze the batch and channel dimensions
        outputs = mbmelgan.run(None, inputs)
        audio_array = outputs[0][0, :, 0]
        return audio_array
    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        # Synthesize a single chunk of text, without pause handling
        mel_output, _ = TTS.text2mel(text)
        audio_array = TTS.mel2wav(mel_output)
        return audio_array
    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        sf.write(path, audio_array, 44100)
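# Minimal usage sketch. Assumes the ONNX models referenced above exist
# locally and that the vocoder really produces 44.1 kHz audio, as the
# hard-coded sample rate implies; the input sentence is a placeholder.
if __name__ == "__main__":
    audio = TTS.generate("Habari ya leo. Karibu sana, rafiki!")
    TTS.save_audio(audio, "output.wav")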