import os
from dataclasses import dataclass

import torch
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer, AutoModelForCausalLM

@dataclass
class Config:
    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    device_map: str = "auto"
    tokeniser_length: int = 64400
    start_of_text: int = 1
    end_of_text: int = 2
    max_new_tokens: int = 2000
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1
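
# Because Config is a dataclass, individual fields can be overridden per run
# without editing the defaults above. A minimal sketch (values illustrative,
# not tuned):
#
#     config = Config(temperature=0.8, max_new_tokens=1500)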

class NemoAudioPlayer:
    def __init__(self, config, text_tokenizer_name: str | None = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        # Load NeMo codec model
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)
        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)
        # Token configuration
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032
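
    # The layout above means a flat audio token id encodes
    # (codebook index, code) as
    #
    #     token_id = audio_tokens_start + codebook_size * codebook_index + code
    #
    # so with tokeniser_length = 64400, audio_tokens_start is 64410 and the
    # inverse mapping (applied column-wise in get_nano_codes below) is
    # effectively:
    #
    #     codebook_index, code = divmod(token_id - 64410, 4032)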

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens"""
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')
        print("Output validation passed - speech tokens found")

    def get_nano_codes(self, out_ids):
        """Extract nano codec tokens from model output"""
        try:
            # Indexing [0][0] raises IndexError when the token is absent,
            # which the except below actually catches (.item() on an empty
            # tensor raises RuntimeError instead)
            start_a_idx = (out_ids == self.start_of_speech).nonzero(as_tuple=True)[0][0].item()
            end_a_idx = (out_ids == self.end_of_speech).nonzero(as_tuple=True)[0][0].item()
        except IndexError:
            raise ValueError('Speech start/end tokens not found!')
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')
        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if len(audio_codes) % 4:
            raise ValueError('Audio codes length must be multiple of 4!')
        # Each frame is 4 consecutive tokens, one per codebook
        audio_codes = audio_codes.reshape(-1, 4)
        # Undo the per-codebook offsets, then the global audio-token offset
        audio_codes = audio_codes - torch.tensor([self.codebook_size * i for i in range(4)])
        audio_codes = audio_codes - self.audio_tokens_start
        if (audio_codes < 0).sum().item() > 0:
            raise ValueError('Invalid audio tokens detected!')
        # (frames, 4) -> (1, 4, frames), the layout the codec expects
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        print(f"Extracted audio codes shape: {audio_codes.shape}")
        return audio_codes, len_
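
    # For illustration, with audio_tokens_start = 64410 and codebook_size =
    # 4032, one generated frame of four ids such as
    #
    #     [64410 + 10, 64410 + 4032 + 7, 64410 + 8064 + 0, 64410 + 12096 + 99]
    #
    # unpacks to codes [10, 7, 0, 99], one per codebook, for a single
    # 12.5 fps frame.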

    def get_text(self, out_ids):
        """Extract text from model output"""
        try:
            start_t_idx = (out_ids == self.start_of_text).nonzero(as_tuple=True)[0][0].item()
            end_t_idx = (out_ids == self.end_of_text).nonzero(as_tuple=True)[0][0].item()
        except IndexError:
            raise ValueError('Text start/end tokens not found!')
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids):
        """Convert model output to audio waveform"""
        out_ids = out_ids.flatten()
        print("Starting waveform generation...")
        # Validate output
        self.output_validation(out_ids)
        # Extract audio codes
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
        print("Decoding audio with NeMo codec...")
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
        output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()
        print(f"Generated audio shape: {output_audio.shape}")
        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None
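
    # The decoded waveform is a mono float array at the codec's 22.05 kHz
    # sample rate. A minimal sketch for saving it, assuming the optional
    # soundfile package is installed:
    #
    #     import soundfile as sf
    #     audio, _ = player.get_waveform(generated_ids)
    #     sf.write("output.wav", audio, 22050)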


class KaniModel:
    def __init__(self, config, player: NemoAudioPlayer, token: str) -> None:
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")
        # Load model with proper configuration
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model"""
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human
        # Tokenize input text
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids
        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)
        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        print(f"Input sequence length: {modified_input_ids.shape[1]}")
        return modified_input_ids, attention_mask
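
    # The resulting prompt layout is
    #
    #     [START_OF_HUMAN] [text tokens...] [END_OF_TEXT] [END_OF_HUMAN]
    #
    # i.e. with tokeniser_length = 64400 the wrapper ids are 64403, 2 and
    # 64404. The model then continues the sequence with audio tokens until it
    # emits END_OF_SPEECH, the eos used in model_request below.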

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate tokens using the model"""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        print("Starting model generation...")
        print(f"Generation parameters: max_tokens={self.conf.max_new_tokens}, "
              f"temp={self.conf.temperature}, top_p={self.conf.top_p}")
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                # Stop once the model emits the END_OF_SPEECH token
                eos_token_id=self.player.end_of_speech,
                # pad_token_id may legitimately be 0, so compare against None
                # rather than relying on truthiness
                pad_token_id=self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            )
        print(f"Generated sequence length: {generated_ids.shape[1]}")
        return generated_ids.to('cpu')

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio"""
        print(f"Processing text: '{text[:50]}{'...' if len(text) > 50 else ''}'")
        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text)
        # Generate tokens
        model_output = self.model_request(input_ids, attention_mask)
        # Convert to audio
        audio, _ = self.player.get_waveform(model_output)
        print("Text-to-speech generation completed successfully!")
        return audio, text
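

# A minimal end-to-end sketch. The HF_TOKEN environment variable name is an
# assumption (any valid Hugging Face access token works); weights download on
# first use and a GPU is strongly recommended:
if __name__ == "__main__":
    config = Config()
    player = NemoAudioPlayer(config)
    model = KaniModel(config, player, token=os.getenv("HF_TOKEN"))
    audio, text = model.run_model("Hello! This is a test of the pipeline.")
    print(f"Generated {audio.shape[0]} samples for: {text!r}")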