import json
import os
import re
import time
import unicodedata
from types import SimpleNamespace

import numpy as np
import regex
import torch
from scipy.io.wavfile import write

from models import DurationNet, SynthesizerTrn
from process import print_percent_done

torch.manual_seed(42)

title = "LightSpeed: Vietnamese Male Voice TTS"
description = "Vietnamese Male Voice TTS."
config_file = "config.json"
duration_model_path = "vbx_duration_model.pth"
lightspeed_model_path = "gen_619k.pth"
phone_set_file = "vbx_phone_set.json"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model hyperparameters as nested namespaces so they can be read
# with attribute access (e.g. hps.data.sampling_rate).
with open(config_file, "rb") as f:
    hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))

# Load the phone set; the first entry must be the separator token
# (e.g. "<SEP>") and a "sil" (silence) phone must exist.
with open(phone_set_file, "r") as f:
    phone_set = json.load(f)

assert phone_set[0][1:-1] == "SEP"
assert "sil" in phone_set
sil_idx = phone_set.index("sil")

# Patterns and symbols for Vietnamese text normalization
# (only num_re and digits are used by the rest of this script).
space_re = regex.compile(r"\s+")
number_re = regex.compile("([0-9]+)")
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
num_re = regex.compile(r"([0-9.,]*[0-9])")
alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
keep_text_re = regex.compile(rf"[^\s{alphabet}]")


def read_number(num: str) -> str:
    """Spell out a numeric string (optionally with '.' or ',') in Vietnamese words.

    Falls back to returning the input unchanged when no rule matches.
    """
    if len(num) == 1:
        return digits[int(num)]
    elif len(num) == 2 and num.isdigit():
        n = int(num)
        end = digits[n % 10]
        if n == 10:
            return "mười"
        if n % 10 == 5:
            end = "lăm"
        if n % 10 == 0:
            return digits[n // 10] + " mươi"
        elif n < 20:
            return "mười " + end
        else:
            if n % 10 == 1:
                end = "mốt"
            return digits[n // 10] + " mươi " + end
    elif len(num) == 3 and num.isdigit():
        n = int(num)
        if n % 100 == 0:
            return digits[n // 100] + " trăm"
        elif num[1] == "0":
            return digits[n // 100] + " trăm lẻ " + digits[n % 100]
        else:
            return digits[n // 100] + " trăm " + read_number(num[1:])
    elif 4 <= len(num) <= 6 and num.isdigit():
        n = int(num)
        n1 = n // 1000
        return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
    elif "," in num:
        n1, n2 = num.split(",")
        return read_number(n1) + " phẩy " + read_number(n2)
    elif "." in num:
        parts = num.split(".")
        if len(parts) == 2:
            if parts[1] == "000":
                return read_number(parts[0]) + " ngàn"
            elif parts[1].startswith("00"):
                end = digits[int(parts[1][2:])]
                return read_number(parts[0]) + " ngàn lẻ " + end
            else:
                return read_number(parts[0]) + " ngàn " + read_number(parts[1])
        elif len(parts) == 3:
            return (
                read_number(parts[0])
                + " triệu "
                + read_number(parts[1])
                + " ngàn "
                + read_number(parts[2])
            )
    return num


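# Example conversions traced from the rules above (illustrative):
#   read_number("5")     -> "năm"
#   read_number("25")    -> "hai mươi lăm"
#   read_number("105")   -> "một trăm lẻ năm"
#   read_number("1.000") -> "một ngàn"
#   read_number("3,5")   -> "ba phẩy năm"

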
def text_to_phone_idx(text):
    """Normalize Vietnamese text and map it to a list of phone indices."""
    text = text.lower()
    text = unicodedata.normalize("NFKC", text)

    # Surround punctuation with spaces so each mark becomes its own token.
    text = text.replace(".", " . ")
    text = text.replace(",", " , ")
    text = text.replace(";", " ; ")
    text = text.replace(":", " : ")
    text = text.replace("!", " ! ")
    text = text.replace("?", " ? ")
    text = text.replace("(", " ( ")

    # Isolate digit groups and spell them out in Vietnamese.
    text = num_re.sub(r" \1 ", text)
    words = text.split()
    words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
    text = " ".join(words)

    # Collapse whitespace.
    text = re.sub(r"\s+", " ", text)
    text = text.strip()

    # Map each character to a phone index: punctuation -> silence,
    # letters -> their phone_set index, space -> 0 (the separator token,
    # phone_set[0]). Characters outside the phone set are dropped.
    tokens = []
    for c in text:
        if c in ":,.!?;(":
            tokens.append(sil_idx)
        elif c in phone_set:
            tokens.append(phone_set.index(c))
        elif c == " ":
            tokens.append(0)
    if len(tokens) == 0:
        return tokens

    # Pad the sequence with silence at both ends.
    if tokens[0] != sil_idx:
        tokens = [sil_idx, 0] + tokens
    if tokens[-1] != sil_idx:
        tokens = tokens + [0, sil_idx]
    return tokens


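# Illustrative output layout (exact indices depend on the contents of
# vbx_phone_set.json; letters missing from the phone set are dropped):
#   text_to_phone_idx("xin chào")
#   -> [sil_idx, 0, <x>, <i>, <n>, 0, <c>, <h>, <à>, <o>, 0, sil_idx]

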
def text_to_speech(duration_net, generator, text):
    # Read "chapter:verse" references aloud, e.g. "3:16" -> "chương 3 câu 16".
    text = re.sub(r"(\d+):(\d+)", r"chương \1 câu \2", text)

    # Title-case hyphenated names and replace the hyphens with spaces.
    def capitalize_name(match):
        return match.group(0).replace("-", " ").title()

    text = re.sub(r"\b\w+(?:-\w+)+\b", capitalize_name, text)

    # Insert a space between a number and the text that follows it.
    text = re.sub(r"(\d+)(\D+)", r"\1 \2", text)

    phone_idx = text_to_phone_idx(text)
    batch = {
        "phone_idx": np.array([phone_idx]),
        "phone_length": np.array([len(phone_idx)]),
    }

    # Predict per-phone durations in milliseconds.
    phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
    phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
    with torch.inference_mode():
        phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
    # Silence phones last at least 200 ms; separator tokens get zero duration.
    phone_duration = torch.where(
        phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
    )
    phone_duration = torch.where(phone_idx == 0, 0, phone_duration)

    # Expand the durations into a hard phone-to-frame alignment (attention) matrix.
    end_time = torch.cumsum(phone_duration, dim=-1)
    start_time = end_time - phone_duration
    start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
    end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
    spec_length = end_frame.max(dim=-1).values
    pos = torch.arange(0, spec_length.item(), device=device)
    attn = torch.logical_and(
        pos[None, :, None] >= start_frame[:, None, :],
        pos[None, :, None] < end_frame[:, None, :],
    ).float()

    # Generate the waveform and convert it to 16-bit PCM.
    with torch.inference_mode():
        y_hat = generator.infer(
            phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.667
        )[0]
    wave = y_hat[0, 0].data.cpu().numpy()
    return (wave * (2**15)).astype(np.int16)


def load_models():
    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
    duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
    duration_net = duration_net.eval()

    generator = SynthesizerTrn(
        hps.data.vocab_size,
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **vars(hps.model),
    ).to(device)
    # The posterior encoder is only needed for training.
    del generator.enc_q
    ckpt = torch.load(lightspeed_model_path, map_location=device)
    # Strip the "module." prefix left by (Distributed)DataParallel checkpoints.
    params = {}
    for k, v in ckpt["net_g"].items():
        k = k[7:] if k.startswith("module.") else k
        params[k] = v
    generator.load_state_dict(params, strict=False)
    del ckpt, params
    generator = generator.eval()
    return duration_net, generator


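# Minimal single-utterance sketch (illustrative; assumes config.json and the
# two checkpoints listed at the top of this file are present, and the output
# file name is arbitrary):
#
#   duration_net, generator = load_models()
#   audio = text_to_speech(duration_net, generator, "xin chào")
#   write("xin_chao.wav", hps.data.sampling_rate, audio)

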
def speak(text, filename):
    duration_net, generator = load_models()
    paragraphs = text.split("\n")
    clips = []

    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        # Drop markup characters the synthesizer cannot read.
        paragraph = re.sub(r"[*#&^@\[\]{}]", "", paragraph)
        if paragraph == "":
            continue
        clips.append(text_to_speech(duration_net, generator, paragraph))

        progress = round(len(clips) / len(paragraphs) * 100)
        print_percent_done(progress, 100, 50, 'Processing ' + filename)

    # Join all paragraph clips into one waveform and write a single WAV file;
    # the timestamp suffix gives each run a unique file name.
    y = np.concatenate(clips)
    write(f"/kaggle/working/{filename}{time.time()}.wav", hps.data.sampling_rate, y)
    return hps.data.sampling_rate, y


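# The loop below batch-processes every file in the books directory; a single
# string can also be synthesized directly (illustrative, output name "demo"
# is arbitrary):
#
#   speak("xin chào thế giới", "demo")

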
books_dir = '/kaggle/working/vi-tts/books'

for filename in os.listdir(books_dir):
    with open(os.path.join(books_dir, filename), "r", encoding="utf-8") as fs:
        text = fs.read()
    speak(text, filename.split('.')[0])
    print('Saved: ' + filename)