import os
import sys
import json
import tempfile

import numpy as np
import triton_python_backend_utils as pb_utils
from TTS.utils.synthesizer import Synthesizer

ENABLE_XLIT = True

# Make the local inference package importable before importing from it
INFERENCE_MODULE_DIR = "/home/app"
sys.path.insert(0, INFERENCE_MODULE_DIR)
from src.inference import TextToSpeechEngine

PWD = os.path.dirname(__file__)
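
# Expected checkpoint layout, inferred from the paths used in `initialize`
# below (the language codes shown are illustrative):
#
#   /models/checkpoints/
#   +-- <lang_code>/            e.g. "hi", "en+hi"
#       +-- fastpitch/
#       |   +-- best_model.pth
#       |   +-- config.json
#       |   +-- speakers.pth
#       +-- hifigan/
#           +-- best_model.pth
#           +-- config.json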

class TritonPythonModel:
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing the `initialize` function is optional. It allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        # Triton passes the model config as a JSON string; parse it into a dict
        self.model_config = json.loads(args["model_config"])
        self.model_instance_device_id = json.loads(args["model_instance_device_id"])
        # checkpoints_root_dir = os.path.join(PWD, "checkpoints")
        checkpoints_root_dir = "/models/checkpoints"
        checkpoint_folders = [f.path for f in os.scandir(checkpoints_root_dir) if f.is_dir()]

        # Each checkpoint folder is assumed to be named after the language code it serves
        self.supported_speaker_ids = {"male", "female"}
        self.supported_lang_codes = set()
        self.models = {}
        for checkpoint_folder in checkpoint_folders:
            lang_code = os.path.basename(checkpoint_folder)

            # Patch the hardcoded speakers-file paths in the shipped config so
            # they point at this checkpoint folder
            tts_config_path = os.path.join(checkpoint_folder, "fastpitch/config.json")
            with open(tts_config_path, encoding="utf-8") as f:
                tts_config = json.load(f)
            speakers_file = tts_config_path.replace("config.json", "speakers.pth")
            tts_config["model_args"]["speakers_file"] = speakers_file
            tts_config["speakers_file"] = speakers_file

            # Write the patched config to a temporary path so that it can be
            # passed to the Synthesizer class
            patched_tts_config_file = tempfile.NamedTemporaryFile(
                suffix=".json", mode="w", encoding="utf-8", delete=False
            )
            json.dump(tts_config, patched_tts_config_file)
            patched_tts_config_file.close()

            self.models[lang_code] = Synthesizer(
                tts_checkpoint=os.path.join(checkpoint_folder, "fastpitch/best_model.pth"),
                tts_config_path=patched_tts_config_file.name,
                vocoder_checkpoint=os.path.join(checkpoint_folder, "hifigan/best_model.pth"),
                vocoder_config=os.path.join(checkpoint_folder, "hifigan/config.json"),
                use_cuda=True,
            )
            self.supported_lang_codes.add(lang_code)

            # The Synthesizer has read the config by now; delete the temp file
            os.unlink(patched_tts_config_file.name)
if "en+hi" in self.supported_lang_codes and "en" not in self.supported_lang_codes:
self.supported_lang_codes.add("en")
self.engine = TextToSpeechEngine(
self.models,
allow_transliteration=ENABLE_XLIT,
enable_denoiser=False,
)
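
    # A config.pbtxt along these lines would pair with this backend -- a
    # sketch, assuming TYPE_STRING inputs and a variable-length FP32 output;
    # only the tensor names are taken from the code, dims are illustrative:
    #
    #   backend: "python"
    #   input [
    #     { name: "INPUT_TEXT"        data_type: TYPE_STRING dims: [ 1 ] },
    #     { name: "INPUT_SPEAKER_ID"  data_type: TYPE_STRING dims: [ 1 ] },
    #     { name: "INPUT_LANGUAGE_ID" data_type: TYPE_STRING dims: [ 1 ] }
    #   ]
    #   output [
    #     { name: "OUTPUT_GENERATED_AUDIO" data_type: TYPE_FP32 dims: [ -1 ] }
    #   ]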

    def execute(self, requests):
        responses = []
        for request in requests:
            input_texts = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
            speaker_ids = pb_utils.get_input_tensor_by_name(request, "INPUT_SPEAKER_ID").as_numpy()
            lang_ids = pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID").as_numpy()

            # String tensors arrive as raw bytes; decode them to str
            input_texts = [input_text.decode("utf-8", "ignore") for input_text in input_texts]
            speaker_ids = [speaker_id.decode("utf-8", "ignore") for speaker_id in speaker_ids]
            lang_ids = [lang_id.decode("utf-8", "ignore") for lang_id in lang_ids]

            generated_audios = []
            for input_text, speaker_id, lang_id in zip(input_texts, speaker_ids, lang_ids):
                if lang_id not in self.supported_lang_codes:
                    raise NotImplementedError(f"Language not supported: {lang_id}")
                if speaker_id not in self.supported_speaker_ids:
                    raise NotImplementedError(f"Speaker not supported: {speaker_id}")
                generated_audio = self.engine.infer_from_text(
                    input_text,
                    lang=lang_id,
                    speaker_name=speaker_id,
                    transliterate_roman_to_native=ENABLE_XLIT,
                )
                generated_audios.append(generated_audio)

            # Note: np.array() requires all audios in the batch to have the
            # same length; variable-length outputs would need padding
            out_tensor_0 = pb_utils.Tensor(
                "OUTPUT_GENERATED_AUDIO",
                np.array(generated_audios, dtype=np.float32),
            )
            inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0])
            responses.append(inference_response)
        return responses
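
# A minimal client sketch -- assumes Triton is serving this model under the
# hypothetical name "tts" on localhost:8000; tensor names match the code above:
#
#   import numpy as np
#   import tritonclient.http as http_client
#
#   client = http_client.InferenceServerClient(url="localhost:8000")
#   inputs = []
#   for name, value in [("INPUT_TEXT", "How are you?"),
#                       ("INPUT_SPEAKER_ID", "female"),
#                       ("INPUT_LANGUAGE_ID", "en")]:
#       tensor = http_client.InferInput(name, [1], "BYTES")
#       tensor.set_data_from_numpy(np.array([value.encode("utf-8")], dtype=object))
#       inputs.append(tensor)
#   result = client.infer("tts", inputs)
#   audio = result.as_numpy("OUTPUT_GENERATED_AUDIO")  # float32 waveform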