import os
import time
from dataclasses import dataclass
from typing import Optional

import librosa
import requests
import torch
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoModelForCausalLM, AutoTokenizer


@dataclass
class Config:
    """Static configuration for the TTS LLM, the NeMo nano codec, and sampling."""

    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    device_map: str = "auto"
    # Size of the text tokenizer vocabulary; all control/audio tokens live above it.
    tokeniser_length: int = 64400
    start_of_text: int = 1
    end_of_text: int = 2
    max_new_tokens: int = 1200
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1


class NemoAudioPlayer:
    """Decodes LLM token output into a waveform via the NeMo nano codec.

    Token layout (derived from ``Config.tokeniser_length``): special control
    tokens occupy ids ``tokeniser_length + 1 .. + 7``; audio codebook tokens
    start at ``tokeniser_length + 10``, with 4 interleaved codebooks of 4032
    entries each.
    """

    def __init__(self, config: Config, text_tokenizer_name: Optional[str] = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        # Load NeMo codec model
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        # Optional text tokenizer, only needed when the caller wants the
        # transcript extracted from the generated token stream.
        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration — control tokens sit directly above the text vocab.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    def output_validation(self, out_ids: torch.Tensor) -> None:
        """Validate that output contains required speech tokens.

        Raises:
            ValueError: if either the start-of-speech or end-of-speech
                control token is absent from ``out_ids``.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Extract nano codec tokens from model output.

        Returns:
            (audio_codes, lengths): codes shaped ``(1, 4, T)`` for the codec
            decoder, plus a 1-element length tensor.

        Raises:
            ValueError: on missing/misordered speech markers or malformed codes.
        """
        try:
            # .item() assumes exactly one start/end marker — generation stops
            # at end_of_speech (it is the eos token), so this holds in practice.
            start_a_idx = (out_ids == self.start_of_speech).nonzero(as_tuple=True)[0].item()
            end_a_idx = (out_ids == self.end_of_speech).nonzero(as_tuple=True)[0].item()
        except IndexError:
            raise ValueError('Speech start/end tokens not found!')

        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        # Tokens are emitted as interleaved frames of 4 codebook entries.
        if len(audio_codes) % 4:
            raise ValueError('Audio codes length must be multiple of 4!')
        audio_codes = audio_codes.reshape(-1, 4)

        # Undo the per-codebook id offsets, then the global audio-token offset,
        # recovering raw codebook indices in [0, codebook_size).
        audio_codes = audio_codes - torch.tensor([self.codebook_size * i for i in range(4)])
        audio_codes = audio_codes - self.audio_tokens_start
        if (audio_codes < 0).sum().item() > 0:
            raise ValueError('Invalid audio tokens detected!')

        # (T, 4) -> (1, 4, T): batch of one, codebooks first, as the codec expects.
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids: torch.Tensor) -> str:
        """Extract text from model output.

        Requires ``text_tokenizer_name`` to have been supplied at init.

        Raises:
            ValueError: if the text start/end markers are missing.
        """
        try:
            start_t_idx = (out_ids == self.start_of_text).nonzero(as_tuple=True)[0].item()
            end_t_idx = (out_ids == self.end_of_text).nonzero(as_tuple=True)[0].item()
        except IndexError:
            raise ValueError('Text start/end tokens not found!')

        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids: torch.Tensor):
        """Convert model output to an audio waveform.

        Returns:
            (audio, text): a 1-D numpy waveform at the codec sample rate, and
            the decoded transcript (or ``None`` when no text tokenizer is set).
        """
        out_ids = out_ids.flatten()

        # Validate output
        self.output_validation(out_ids)

        # Extract audio codes
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)

        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes, tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None


class KaniModel:
    """Wraps the causal-LM TTS model: text in, tokens generated, audio out."""

    def __init__(self, config: Config, player: NemoAudioPlayer, token: str) -> None:
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")

        # Load model with proper configuration
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model.

        Wraps the tokenized prompt as:
        ``[START_OF_HUMAN] <text tokens> [END_OF_TEXT] [END_OF_HUMAN]``.

        Returns:
            (input_ids, attention_mask), both shaped ``(1, seq_len)``.
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        # Tokenize input text
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids

        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate tokens using the model.

        Sampling parameters come from ``Config``; generation stops at the
        end-of-speech control token. Result is moved back to CPU.
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
            )
        return generated_ids.to('cpu')

    def time_report(self, point_1: float, point_2: float, point_3: float) -> str:
        """Format a timing summary from the three pipeline timestamps."""
        model_request = point_2 - point_1
        player_time = point_3 - point_2
        total_time = point_3 - point_1
        report = f"MODEL GENERATION: {model_request:.2f}\nNANO CODEC: {player_time:.2f}\nTOTAL: {total_time:.2f}"
        return report

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            (audio, text, report): waveform ndarray, the original input text,
            and a human-readable timing report.
        """
        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text)

        # Generate tokens
        point_1 = time.time()
        model_output = self.model_request(input_ids, attention_mask)

        # Convert to audio
        point_2 = time.time()
        audio, _ = self.player.get_waveform(model_output)
        point_3 = time.time()
        return audio, text, self.time_report(point_1, point_2, point_3)


class Demo:
    """Downloads and caches reference audio examples keyed by their sentence."""

    def __init__(self) -> None:
        self.audio_dir = './audio_examples'
        os.makedirs(self.audio_dir, exist_ok=True)
        self.sentences = [
            "You make my days brighter, and my wildest dreams feel like reality. How do you do that?",
            "Anyway, um, so, um, tell me, tell me all about her. I mean, what's she like? Is she really, you know, pretty?",
            "Great, and just a couple quick questions so we can match you with the right buyer. Is your home address still 330 East Charleston Road?",
            "No, that does not make you a failure. No, sweetie, no. It just, uh, it just means that you're having a tough time...",
            "Oh, yeah. I mean did you want to get a quick snack together or maybe something before you go?",
            "I-- Oh, I am such an idiot sometimes. I'm so sorry. Um, I-I don't know where my head's at.",
            "Got it. $300,000. I can definitely help you get a very good price for your property by selecting a realtor.",
            "Holy fu- Oh my God! Don't you understand how dangerous it is, huh?"
        ]
        # Note: URL order is deliberately shuffled relative to filename order;
        # each URL is paired positionally with the sentence at the same index.
        self.urls = [
            'https://www.nineninesix.ai/examples/kani/1.wav',
            'https://www.nineninesix.ai/examples/kani/2.wav',
            'https://www.nineninesix.ai/examples/kani/5.wav',
            'https://www.nineninesix.ai/examples/kani/6.wav',
            'https://www.nineninesix.ai/examples/kani/3.wav',
            'https://www.nineninesix.ai/examples/kani/7.wav',
            'https://www.nineninesix.ai/examples/kani/4.wav',
            'https://www.nineninesix.ai/examples/kani/8.wav'
        ]

    def download_audio(self, url: str, filename: str) -> str:
        """Download ``url`` into the cache dir unless already present; return the path."""
        filepath = os.path.join(self.audio_dir, filename)
        if not os.path.exists(filepath):
            # timeout prevents an unreachable host from hanging the demo forever
            r = requests.get(url, timeout=30)
            r.raise_for_status()
            with open(filepath, 'wb') as f:
                f.write(r.content)
        return filepath

    def get_audio(self, filepath: str):
        """Load a wav file as a mono float ndarray resampled to 22.05 kHz."""
        arr, _ = librosa.load(filepath, sr=22050)
        return arr

    def __call__(self) -> dict:
        """Return a mapping of sentence -> example waveform, downloading as needed."""
        examples = {}
        for idx, (sentence, url) in enumerate(zip(self.sentences, self.urls), start=1):
            filename = f"{idx}.wav"
            filepath = self.download_audio(url, filename)
            examples[sentence] = self.get_audio(filepath)
        return examples