import torch
import librosa
import requests
import time
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from omegaconf import OmegaConf, DictConfig


def load_config(config_path: str):
    """Load configuration from a YAML file using OmegaConf.

    Args:
        config_path (str): Path to the YAML configuration file. Relative paths
            are resolved against the current working directory.

    Returns:
        Any: The loaded OmegaConf config (a ``DictConfig`` for a mapping-rooted
        YAML file).

    Raises:
        FileNotFoundError: If no file exists at the resolved absolute path.
    """
    resolved_path = os.path.abspath(config_path)
    if not os.path.exists(resolved_path):
        raise FileNotFoundError(f"Config file not found: {resolved_path}")
    config = OmegaConf.load(resolved_path)
    return config


class NemoAudioPlayer:
    """
    High-level audio reconstruction helper built on NeMo Nano Codec.

    This class converts discrete codec token sequences produced by the language
    model into time-domain audio waveforms using
    `nemo.collections.tts.models.AudioCodecModel`. It also optionally extracts
    and decodes the text span from the generated token stream when a
    compatible text tokenizer is provided.

    Parameters
    ----------
    config : OmegaConf | DictConfig
        Configuration block under `nemo_player` from `model_config.yaml`.
        Expected fields:
        - `audiocodec_name` (str): HuggingFace model id for the NeMo codec
        - `tokeniser_length` (int): Size of the base tokenizer vocabulary
        - `start_of_text`, `end_of_text` (int): Special text token ids
    text_tokenizer_name : str, optional
        HF repo id or local path of the tokenizer used by the LLM. If provided,
        the player can also extract and decode the text segment embedded in the
        generated ids for debugging/inspection.

    Notes
    -----
    - The class defines a fixed layout of special token ids derived from
      `tokeniser_length`. Audio codes are expected to be arranged in
      `num_codebooks` (4) interleaved codebooks. See `get_nano_codes` for
      validation.
    - Device selection is automatic (`cuda` if available, else `cpu`).

    Typical Usage
    -------------
    1) The model generates a sequence of token ids that contains both text and
       audio sections delimited by special markers.
    2) Call `get_waveform(model_output_ids)` to obtain `(waveform, text)`, where
       `waveform` is a NumPy array ready to be played or saved and `text` is
       the decoded text span (or None when no text tokenizer was configured).
    """

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")

        # Load the NeMo codec model in eval mode (it is only used for decoding).
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            # Only created when a name is given; get_text() depends on it.
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration: control ids are fixed offsets above the base
        # vocabulary size `tokeniser_length`.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text

        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        # Codes per codebook. Codebook i occupies the id range
        # [audio_tokens_start + i * codebook_size,
        #  audio_tokens_start + (i + 1) * codebook_size).
        self.codebook_size = 4032
        # Number of interleaved codebooks per audio frame.
        self.num_codebooks = 4

    @staticmethod
    def _first_index(ids, token_id):
        """Return the position of the first `token_id` in 1-D `ids`, or None."""
        positions = (ids == token_id).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            return None
        return positions[0].item()

    def output_validation(self, out_ids):
        """Validate that the output contains the required speech tokens.

        Raises:
            ValueError: If either `start_of_speech` or `end_of_speech` is absent.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids

        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids):
        """Extract NeMo Nano Codec codes from a flat model output.

        The span strictly between the first `start_of_speech` token and the
        first `end_of_speech` token is reshaped into frames of `num_codebooks`
        interleaved tokens. Each column is then shifted back from its
        vocabulary offset to a raw codebook index.

        Args:
            out_ids: 1-D tensor of token ids.

        Returns:
            tuple: `(audio_codes, len_)`, where `audio_codes` has shape
            `(1, num_codebooks, n_frames)` and `len_` is a 1-element tensor
            holding `n_frames`.

        Raises:
            ValueError: If the speech markers are missing or out of order, the
            span length is not a multiple of `num_codebooks`, or any decoded
            code falls outside `[0, codebook_size)`.
        """
        # `.item()` on an empty nonzero() result raises RuntimeError (not
        # IndexError), so presence is checked explicitly.
        start_a_idx = self._first_index(out_ids, self.start_of_speech)
        end_a_idx = self._first_index(out_ids, self.end_of_speech)
        if start_a_idx is None or end_a_idx is None:
            raise ValueError('Speech start/end tokens not found!')

        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if len(audio_codes) % self.num_codebooks:
            raise ValueError(
                f'Audio codes length must be multiple of {self.num_codebooks}!'
            )
        audio_codes = audio_codes.reshape(-1, self.num_codebooks)

        # Undo the per-codebook vocabulary offset: column i was shifted by
        # audio_tokens_start + i * codebook_size.
        offsets = (
            torch.arange(self.num_codebooks, device=audio_codes.device)
            * self.codebook_size
        )
        audio_codes = audio_codes - offsets - self.audio_tokens_start
        # A code outside [0, codebook_size) belongs to another codebook's id
        # range (or is not an audio token at all).
        out_of_range = (audio_codes < 0) | (audio_codes >= self.codebook_size)
        if out_of_range.any().item():
            raise ValueError('Invalid audio tokens detected!')

        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids):
        """Decode the text span of a flat model output.

        Requires `text_tokenizer_name` to have been passed at construction,
        because `self.tokenizer` is only created in that case.

        Args:
            out_ids: 1-D tensor of token ids.

        Returns:
            str: The decoded text between the first `start_of_text` and the
            first `end_of_text` marker (inclusive), with special tokens skipped.

        Raises:
            ValueError: If either text marker is absent.
        """
        start_t_idx = self._first_index(out_ids, self.start_of_text)
        end_t_idx = self._first_index(out_ids, self.end_of_text)
        if start_t_idx is None or end_t_idx is None:
            raise ValueError('Text start/end tokens not found!')

        # The markers are included in the slice; skip_special_tokens removes
        # them only if the tokenizer registers them as special tokens.
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids):
        """Convert model output ids to an audio waveform.

        Args:
            out_ids: Tensor of generated token ids (flattened internally).

        Returns:
            tuple: `(output_audio, text)`. `output_audio` is a squeezed NumPy
            array produced by the codec. `text` is the decoded text span when a
            text tokenizer was configured, otherwise None.

        Raises:
            ValueError: Propagated from validation/extraction when the speech
            (or text) markers or audio codes are invalid.
        """
        out_ids = out_ids.flatten()

        # Validate output
        self.output_validation(out_ids)

        # Extract audio codes
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)

        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None


class KaniModel:
    """
    Wrapper around a causal LLM that emits NeMo codec tokens for TTS.

    Responsibilities
    ----------------
    - Load the LLM and tokenizer from HuggingFace with the provided
      configuration (model id, device mapping, auth token, and
      `trust_remote_code`).
    - Prepare inputs by injecting conversation and modality control tokens
      expected by the decoder (`START_OF_HUMAN`, `END_OF_TEXT`, etc.), and
      optionally prefix the input with a speaker id tag.
    - Perform generation with sampling parameters and return raw token ids.
    - Delegate waveform reconstruction to `NemoAudioPlayer`.

    Parameters
    ----------
    config : OmegaConf | DictConfig
        Model configuration block from `models[...]` in `model_config.yaml`.
        Expected fields:
        - `model_name` (str): HF repo id of the LLM
        - `device_map` (str | dict): Device mapping strategy for HF
    player : NemoAudioPlayer
        Audio decoder that turns generated token ids into a waveform.
    token : str
        HuggingFace access token (if the model requires authentication).
        When truthy it is exported as the `HF_TOKEN` environment variable.

    Key Methods
    -----------
    - `get_input_ids(text, speaker_id)`: builds the prompt with control tokens
      and returns `(input_ids, attention_mask)` tensors.
    - `model_request(...)`: runs `generate` with sampling controls.
    - `run_model(...)`: end-to-end pipeline returning `(audio, text, report)`.
    """

    def __init__(self, config, player: NemoAudioPlayer, token: str) -> None:
        self.conf = config
        self.player = player
        self.hf_token = token
        # Preferred device; the actual placement is decided by `device_map`.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")

        # Set HF_TOKEN in environment to avoid parameter passing issues
        if self.hf_token:
            os.environ['HF_TOKEN'] = self.hf_token

        # Load model with proper configuration.
        # Don't pass token parameter - it will be read from HF_TOKEN env var.
        # NOTE(review): `dtype` is the newer spelling of `torch_dtype`; confirm
        # the pinned transformers version honours it.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            trust_remote_code=True
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            trust_remote_code=True
        )

        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str, speaker_id: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the prompt ids and attention mask for generation.

        The prompt layout is
        `[START_OF_HUMAN] + tokenize(text) + [END_OF_TEXT, END_OF_HUMAN]`. When
        `speaker_id` is not None, the text is tokenized as
        `"{speaker_id}: {text_prompt}"`.

        Args:
            text_prompt: Text to synthesize.
            speaker_id: Optional speaker tag (may be None).

        Returns:
            tuple: `(input_ids, attention_mask)`, both int64 tensors of shape
            `(1, seq_len)`; the mask is all ones.
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        # Tokenize input text
        if speaker_id is not None:
            input_ids = self.tokenizer(f"{speaker_id}: {text_prompt}", return_tensors="pt").input_ids
        else:
            input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids

        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)

        return modified_input_ids, attention_mask

    def model_request(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            t: float,
            top_p: float,
            rp: float,
            max_tok: int) -> torch.Tensor:
        """Sample tokens from the model until `end_of_speech` or `max_tok`.

        Args:
            input_ids: Prompt ids from `get_input_ids`.
            attention_mask: Matching attention mask.
            t: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            rp: Repetition penalty.
            max_tok: Maximum number of new tokens.

        Returns:
            torch.Tensor: Generated ids (prompt included), moved to CPU.
        """
        # Place inputs where the model actually lives; with `device_map` this
        # can differ from the guessed `self.device`.
        model_device = self.model.device
        input_ids = input_ids.to(model_device)
        attention_mask = attention_mask.to(model_device)

        # A pad id of 0 is valid, so only fall back to EOS when it is unset.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tok,
                do_sample=True,
                temperature=t,
                top_p=top_p,
                repetition_penalty=rp,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=pad_token_id,
            )
        return generated_ids.to('cpu')

    def time_report(self, point_1, point_2, point_3):
        """Format stage timings (in seconds) from three timestamps.

        `point_1`→`point_2` is token generation, `point_2`→`point_3` is codec
        decoding.
        """
        model_request = point_2 - point_1
        player_time = point_3 - point_2
        total_time = point_3 - point_1
        report = f"SPEECH TOKENS: {model_request:.2f}\nCODEC: {player_time:.2f}\nTOTAL: {total_time:.2f}"
        return report

    def run_model(self, text: str, speaker_id: str, t: float, top_p: float, rp: float, max_tok: int):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            tuple: `(audio, text, report)`, where `text` is the input text
            (not the decoded span) and `report` is the timing summary from
            `time_report`.
        """
        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text, speaker_id)

        # Generate tokens (perf_counter is monotonic, suited to intervals)
        point_1 = time.perf_counter()
        model_output = self.model_request(input_ids, attention_mask, t, top_p, rp, max_tok)

        # Convert to audio
        point_2 = time.perf_counter()
        audio, _ = self.player.get_waveform(model_output)

        point_3 = time.perf_counter()
        return audio, text, self.time_report(point_1, point_2, point_3)


class InitModels:
    """
    Deferred initializer that constructs a map of model name -> KaniModel.

    Parameters
    ----------
    models_configs : OmegaConf | DictConfig
        The `models` section from `model_config.yaml` describing one or more
        HF LLM checkpoints and their options (device map, speakers).
    player : NemoAudioPlayer
        Shared audio decoder instance reused across all models.
    token_ : str
        HuggingFace token passed to each `KaniModel` for loading.

    Returns
    -------
    dict
        When called, returns a dictionary `{model_name: KaniModel}`.

    Notes
    -----
    - Construction only stores arguments; all models are loaded together when
      the instance is called, so the UI can list them and switch between them
      without extra latency.
    """

    def __init__(self, models_configs: OmegaConf, player: NemoAudioPlayer, token_: str):
        self.models_configs = models_configs
        self.player = player
        self.token_ = token_

    def __call__(self):
        """Load every configured model and return `{model_name: KaniModel}`."""
        models = {}
        for model_name, config in self.models_configs.items():
            print(f"Loading {model_name}...")
            models[model_name] = KaniModel(config, self.player, self.token_)
            print(f"{model_name} loaded!")
        print("All models loaded!")
        return models


class Examples:
    """
    Adapter that converts YAML examples into Gradio `gr.Examples` rows.

    Parameters
    ----------
    exam_cfg : OmegaConf | DictConfig
        Parsed contents of `examples.yaml`. Expected structure:
        `examples: [ {text, speaker_id?, model, temperature?, top_p?,
        repetition_penalty?, max_len?}, ... ]`.

    Behavior
    --------
    - Produces a list-of-lists whose order must match the `inputs` order used
      when constructing `gr.Examples` in `app.py`.
    - Current order:
      `[text, model_dropdown, speaker_dropdown, temp, top_p, rp, max_tok]`.
    - Missing optional fields default to temperature 1.4, top_p 0.95,
      repetition_penalty 1.1, max_len 1200; missing required fields become None.

    Why this exists
    ---------------
    - Keeps the row format and defaults in one place, so reordering the UI
      inputs only requires updating this class and the matching `inputs` list
      in `app.py`.
    """

    def __init__(self, exam_cfg: OmegaConf):
        self.exam_cfg = exam_cfg

    def __call__(self) -> list[list]:
        """Return one row per configured example, in `gr.Examples` input order."""
        rows = []
        for e in self.exam_cfg.examples:
            text = e.get("text")
            speaker_id = e.get("speaker_id")
            model = e.get("model")
            temperature = e.get("temperature", 1.4)
            top_p = e.get("top_p", 0.95)
            repetition_penalty = e.get("repetition_penalty", 1.1)
            max_len = e.get("max_len", 1200)
            # Order must match gr.Examples inputs: [text, model_dropdown, speaker_dropdown, temp, top_p, rp, max_tok]
            rows.append([text, model, speaker_id, temperature, top_p, repetition_penalty, max_len])
        return rows