# KaniTTS / util.py
# Last commit eb18e14 by Den Pavloff: "multispeaker, multilang" (file size 14.9 kB).
import torch
import librosa
import requests
import time
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from omegaconf import OmegaConf, DictConfig
def load_config(config_path: str):
    """Read a YAML configuration file with OmegaConf.

    The path is made absolute before use, so the error message names
    exactly which file was looked up.

    Args:
        config_path (str): Location of the YAML file, absolute or relative
            to the current working directory.

    Returns:
        Any: The object returned by ``OmegaConf.load`` for that file.

    Raises:
        FileNotFoundError: If nothing exists at the resolved path.
    """
    absolute_path = os.path.abspath(config_path)
    if os.path.exists(absolute_path):
        return OmegaConf.load(absolute_path)
    raise FileNotFoundError(f"Config file not found: {absolute_path}")
class NemoAudioPlayer:
    """
    High-level audio reconstruction helper built on NeMo Nano Codec.

    Converts the discrete codec token sequences produced by the language
    model into time-domain audio waveforms with
    `nemo.collections.tts.models.AudioCodecModel`. When a compatible text
    tokenizer is provided, it can also decode the text span embedded in the
    generated token stream (for debugging/inspection).

    Parameters
    ----------
    config : OmegaConf | DictConfig
        Configuration block under `nemo_player` from `model_config.yaml`.
        Expected fields:
        - `audiocodec_name` (str): HuggingFace model id for the NeMo codec
        - `tokeniser_length` (int): Size of the base tokenizer vocabulary
        - `start_of_text`, `end_of_text` (int): Special text token ids
    text_tokenizer_name : str, optional
        HF repo id or local path of the tokenizer used by the LLM. If
        provided, `get_waveform` also decodes the text segment of the
        generated ids.

    Notes
    -----
    - Special token ids sit at fixed offsets above `tokeniser_length`
      (see `__init__`). Audio codes come in frames of 4 interleaved
      codebooks. Codebook `i` occupies the ids
      `audio_tokens_start + i * codebook_size + [0, codebook_size)`.
    - Device selection is automatic (`cuda` if available else `cpu`).

    Typical Usage
    -------------
    1) The model generates a sequence of token ids that contains both text
       and audio sections delimited by special markers.
    2) Call `get_waveform(model_output_ids)` to obtain
       `(waveform, text_or_None)`. The waveform is a NumPy array ready to
       be played or saved.
    """

    # Number of interleaved codebooks per audio frame.
    NUM_CODEBOOKS = 4

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        # Load the NeMo codec model and switch it to eval mode for inference.
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        self.text_tokenizer_name = text_tokenizer_name
        # Always define the attribute so get_text can report a clear error
        # instead of failing with AttributeError.
        self.tokenizer = None
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration: control tokens live just above the base vocabulary.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    @staticmethod
    def _first_index(ids, token_id):
        """Return the position of the first `token_id` in 1-D `ids`, or None if absent."""
        positions = (ids == token_id).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            return None
        return positions[0].item()

    def output_validation(self, out_ids):
        """Validate that output contains both speech start and end tokens.

        Raises:
            ValueError: If either speech marker is missing.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids):
        """Extract per-codebook codec indices from a flat 1-D token tensor.

        The span between the first `start_of_speech` and the first
        `end_of_speech` is reshaped into frames of `NUM_CODEBOOKS` tokens.
        Each token is then mapped back to its index inside its own codebook.

        Returns:
            tuple: `(codes, lengths)`. `codes` has shape
            `(1, NUM_CODEBOOKS, n_frames)`. `lengths` is a 1-element tensor
            holding `n_frames`.

        Raises:
            ValueError: If a marker is missing, the markers are out of order,
                the span is empty or not a multiple of the frame size, or any
                code falls outside `[0, codebook_size)` for its codebook.
        """
        # Use the first occurrence of each marker. A repeated marker would
        # otherwise make `.item()` raise a RuntimeError.
        start_a_idx = self._first_index(out_ids, self.start_of_speech)
        end_a_idx = self._first_index(out_ids, self.end_of_speech)
        if start_a_idx is None or end_a_idx is None:
            raise ValueError('Speech start/end tokens not found!')
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if len(audio_codes) == 0:
            raise ValueError('No audio codes found between speech tokens!')
        if len(audio_codes) % self.NUM_CODEBOOKS:
            raise ValueError('Audio codes length must be multiple of 4!')
        audio_codes = audio_codes.reshape(-1, self.NUM_CODEBOOKS)

        # Undo the per-codebook offset and the global audio-token offset.
        # The offsets tensor is built on the input's device to avoid mismatches.
        offsets = torch.arange(self.NUM_CODEBOOKS, device=audio_codes.device) * self.codebook_size
        audio_codes = audio_codes - offsets - self.audio_tokens_start

        # Every code must lie inside its own codebook. Values >= codebook_size
        # would belong to a different codebook (or none) and break decoding.
        out_of_range = (audio_codes < 0) | (audio_codes >= self.codebook_size)
        if out_of_range.any().item():
            raise ValueError('Invalid audio tokens detected!')

        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids):
        """Decode the text span between the first start/end text markers (inclusive).

        Raises:
            ValueError: If no text tokenizer was configured or a marker is missing.
        """
        tokenizer = getattr(self, 'tokenizer', None)
        if tokenizer is None:
            raise ValueError('Text tokenizer not provided!')
        start_t_idx = self._first_index(out_ids, self.start_of_text)
        end_t_idx = self._first_index(out_ids, self.end_of_text)
        if start_t_idx is None or end_t_idx is None:
            raise ValueError('Text start/end tokens not found!')
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        return tokenizer.decode(txt_tokens, skip_special_tokens=True)

    def get_waveform(self, out_ids):
        """Convert model output ids to an audio waveform.

        Args:
            out_ids: Token id tensor. It is flattened before processing.

        Returns:
            tuple: `(audio, text)`. `audio` is the decoded waveform as a
            squeezed NumPy array. `text` is the decoded text span when a text
            tokenizer was configured, otherwise None.
        """
        out_ids = out_ids.flatten()
        self.output_validation(out_ids)

        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
        output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        text = self.get_text(out_ids) if self.text_tokenizer_name else None
        return output_audio, text
class KaniModel:
    """
    Wrapper around a causal LLM that emits NeMo codec tokens for TTS.

    Responsibilities
    -----------------
    - Load the LLM and tokenizer from HuggingFace with the provided
      configuration (model id, device mapping, auth token, and
      `trust_remote_code`).
    - Prepare inputs by wrapping the prompt in the control tokens
      `START_OF_HUMAN ... END_OF_TEXT END_OF_HUMAN`, optionally prefixing
      the text with a speaker id tag.
    - Run sampled generation and return the raw token ids.
    - Delegate waveform reconstruction to `NemoAudioPlayer`.

    Parameters
    ----------
    config : OmegaConf | DictConfig
        Model configuration block from `models[...]` in `model_config.yaml`.
        Expected fields:
        - `model_name` (str): HF repo id of the LLM
        - `device_map` (str | dict): Device mapping strategy for HF
    player : NemoAudioPlayer
        Audio decoder that turns generated token ids into a waveform. It also
        supplies the special token ids.
    token : str
        HuggingFace access token (if the model requires authentication).

    Key Methods
    -----------
    - `get_input_ids(text, speaker_id)`: builds the prompt with control
      tokens and returns `(input_ids, attention_mask)` tensors.
    - `model_request(...)`: runs `generate` with sampling controls.
    - `run_model(...)`: end-to-end pipeline returning `(audio, text, report)`.
    """

    def __init__(self, config, player: "NemoAudioPlayer", token: str) -> None:
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")
        # NOTE(review): older transformers releases expect `torch_dtype`
        # instead of `dtype`; confirm the pinned version accepts `dtype`.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str, speaker_id: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the prompt tensor `[START_OF_HUMAN] + text + [END_OF_TEXT, END_OF_HUMAN]`.

        Args:
            text_prompt: Text to synthesize.
            speaker_id: Optional speaker tag. When not None, the tokenized text
                becomes `"{speaker_id}: {text_prompt}"`.

        Returns:
            `(input_ids, attention_mask)`, both int64 with shape `(1, L)`.
            The mask is all ones.
        """
        prompt = f"{speaker_id}: {text_prompt}" if speaker_id is not None else text_prompt
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        start_token = torch.tensor([[self.player.start_of_human]], dtype=torch.int64)
        end_tokens = torch.tensor(
            [[self.player.end_of_text, self.player.end_of_human]], dtype=torch.int64
        )
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        return modified_input_ids, attention_mask

    def _resolve_pad_token_id(self):
        """Return the tokenizer's pad id, falling back to EOS only when pad is undefined."""
        pad_id = self.tokenizer.pad_token_id
        # Compare with None rather than truthiness: 0 is a valid token id.
        return pad_id if pad_id is not None else self.tokenizer.eos_token_id

    def model_request(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            t: float,
            top_p: float,
            rp: float,
            max_tok: int) -> torch.Tensor:
        """Sample tokens from the LLM and return them on the CPU.

        Generation stops at `player.end_of_speech` (passed as
        `eos_token_id`) or after `max_tok` new tokens.

        Args:
            t: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            rp: Repetition penalty.
            max_tok: Maximum number of new tokens.

        Returns:
            The `generate` output moved to CPU. For decoder-only models this
            presumably includes the prompt tokens; confirm for the checkpoint.
        """
        # Send inputs to the device the model was actually placed on by
        # `device_map`, rather than assuming cuda/cpu.
        target_device = self.model.device
        input_ids = input_ids.to(target_device)
        attention_mask = attention_mask.to(target_device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tok,
                do_sample=True,
                temperature=t,
                top_p=top_p,
                repetition_penalty=rp,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=self._resolve_pad_token_id()
            )
        return generated_ids.to('cpu')

    def time_report(self, point_1, point_2, point_3):
        """Format generation, codec and total durations (seconds) from three timestamps."""
        generation_time = point_2 - point_1
        player_time = point_3 - point_2
        total_time = point_3 - point_1
        report = f"SPEECH TOKENS: {generation_time:.2f}\nCODEC: {player_time:.2f}\nTOTAL: {total_time:.2f}"
        return report

    def run_model(self, text: str, speaker_id: str, t: float, top_p: float, rp: float, max_tok: int):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            `(audio, text, report)`. `audio` is the waveform from
            `NemoAudioPlayer.get_waveform`, `text` is the input text
            unchanged, and `report` is the timing string from `time_report`.
        """
        input_ids, attention_mask = self.get_input_ids(text, speaker_id)
        # perf_counter is monotonic and high-resolution, unlike time.time.
        point_1 = time.perf_counter()
        model_output = self.model_request(input_ids, attention_mask, t, top_p, rp, max_tok)
        point_2 = time.perf_counter()
        audio, _ = self.player.get_waveform(model_output)
        point_3 = time.perf_counter()
        return audio, text, self.time_report(point_1, point_2, point_3)
class InitModels:
    """
    Deferred loader that builds a mapping of model name -> KaniModel.

    Constructing the object only stores its arguments. Calling the instance
    loads every configured model in one pass.

    Parameters
    ----------
    models_configs : OmegaConf | DictConfig
        The `models` section from `model_config.yaml`: a mapping from model
        name to that model's configuration block.
    player : NemoAudioPlayer
        Shared audio decoder instance reused by every loaded model.
    token_ : str
        HuggingFace token forwarded to each `KaniModel`.

    Returns
    -------
    dict
        Calling the instance returns `{model_name: KaniModel}`, in the
        iteration order of `models_configs`.
    """

    def __init__(self, models_configs:OmegaConf, player: NemoAudioPlayer, token_:str):
        self.models_configs = models_configs
        self.player = player
        self.token_ = token_

    def _load_one(self, name, model_cfg):
        """Construct a single KaniModel, logging before and after loading."""
        print(f"Loading {name}...")
        kani = KaniModel(model_cfg, self.player, self.token_)
        print(f"{name} loaded!")
        return kani

    def __call__(self):
        loaded = {
            name: self._load_one(name, model_cfg)
            for name, model_cfg in self.models_configs.items()
        }
        print("All models loaded!")
        return loaded
class Examples:
    """
    Adapter that turns YAML examples into Gradio `gr.Examples` rows.

    Parameters
    ----------
    exam_cfg : OmegaConf | DictConfig
        Parsed contents of `examples.yaml`. Expected structure:
        `examples: [ {text, speaker_id?, model, temperature?, top_p?,
        repetition_penalty?, max_len?}, ... ]`.

    Behavior
    --------
    - Returns a list of lists whose column order must match the `inputs`
      order used when constructing `gr.Examples` in `app.py`:
      `[text, model_dropdown, speaker_dropdown, temp, top_p, rp, max_tok]`.
    - Missing optional sampling settings fall back to temperature 1.4,
      top_p 0.95, repetition_penalty 1.1 and max_len 1200. Missing
      `text`, `model` or `speaker_id` become None.

    Why this exists
    ---------------
    - Keeps the row format and defaults in one place, so changing the UI
      input order only requires a change here and in `app.py`.
    """

    def __init__(self, exam_cfg: OmegaConf):
        self.exam_cfg = exam_cfg

    def __call__(self)->list[list]:
        def to_row(entry):
            # Order must match gr.Examples inputs: [text, model_dropdown, speaker_dropdown, temp, top_p, rp, max_tok]
            return [
                entry.get("text"),
                entry.get("model"),
                entry.get("speaker_id"),
                entry.get("temperature", 1.4),
                entry.get("top_p", 0.95),
                entry.get("repetition_penalty", 1.1),
                entry.get("max_len", 1200),
            ]

        return [to_row(entry) for entry in self.exam_cfg.examples]