Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	File size: 8,849 Bytes
			
			| 164603c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 | import torch
from nemo.collections.tts.models import AudioCodecModel
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
@dataclass
class Config:
    """Default settings for the TTS pipeline: checkpoints, token layout, sampling."""

    # Hugging Face checkpoints for the language model and the audio codec.
    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    # Placement strategy forwarded to ``from_pretrained``.
    device_map: str = "auto"
    # Base id; the speech/turn special tokens are derived as offsets from it.
    tokeniser_length: int = 64_400
    # Ids delimiting the text span inside a token sequence.
    start_of_text: int = 1
    end_of_text: int = 2
    # Sampling settings handed to ``generate``.
    max_new_tokens: int = 2_000
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1
class NemoAudioPlayer:
    """Decode NeMo nano-codec tokens produced by the TTS language model.

    Owns the special-token layout (all speech/turn markers are offsets above
    ``config.tokeniser_length``) and a loaded NeMo ``AudioCodecModel`` that
    turns the extracted codes back into a waveform.
    """

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        """Load the codec and derive the special-token ids from *config*.

        Args:
            config: Object exposing ``audiocodec_name``, ``tokeniser_length``,
                ``start_of_text`` and ``end_of_text`` (e.g. ``Config``).
            text_tokenizer_name: Optional tokenizer name; when given,
                ``get_waveform`` also decodes the text span of the output.
        """
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")

        # Load NeMo codec model and switch it to eval mode.
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        self.text_tokenizer_name = text_tokenizer_name
        # Always define the attribute so it can be inspected even when unused.
        self.tokenizer = None
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration: markers sit just above ``tokeniser_length``.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        # Audio ids: codebook i occupies
        # [audio_tokens_start + i*codebook_size, audio_tokens_start + (i+1)*codebook_size).
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032
        # Codes are interleaved one id per codebook per frame.
        self.num_codebooks = 4

    @staticmethod
    def _first_index(ids, token, start=0):
        """Return the position of the first *token* in 1-D *ids* at or after *start*, else ``None``."""
        hits = (ids[start:] == token).nonzero(as_tuple=True)[0]
        return start + int(hits[0]) if hits.numel() else None

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens.

        Raises:
            ValueError: if either ``start_of_speech`` or ``end_of_speech`` is absent.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids

        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

        print("Output validation passed - speech tokens found")

    def get_nano_codes(self, out_ids):
        """Extract nano-codec codes from a flat model output.

        Args:
            out_ids: 1-D integer tensor of token ids (prompt + generation).

        Returns:
            ``(codes, lengths)``: ``codes`` has shape ``(1, num_codebooks, frames)``
            and ``lengths`` is ``tensor([frames])`` (on CPU).

        Raises:
            ValueError: if the speech markers are missing or out of order, the
                span is empty or not a whole number of frames, or any code lies
                outside its codebook.
        """
        # First occurrence of each marker. ``.item()`` on the raw ``nonzero``
        # result raises RuntimeError when a marker is absent or repeated.
        start_a_idx = self._first_index(out_ids, self.start_of_speech)
        end_a_idx = self._first_index(out_ids, self.end_of_speech)
        if start_a_idx is None or end_a_idx is None:
            raise ValueError('Speech start/end tokens not found!')
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')
        audio_codes = out_ids[start_a_idx + 1: end_a_idx]

        if audio_codes.numel() == 0:
            raise ValueError('No audio codes found between speech tokens!')
        n_books = self.num_codebooks
        if len(audio_codes) % n_books:
            raise ValueError(f'Audio codes length must be multiple of {n_books}!')

        # One row per frame, one column per codebook.
        audio_codes = audio_codes.reshape(-1, n_books)

        # Undo the per-codebook offset and the shared audio-token base. The
        # offsets are built on the input's device so GPU tensors work as well.
        offsets = torch.arange(n_books, device=audio_codes.device) * self.codebook_size
        audio_codes = audio_codes - offsets - self.audio_tokens_start

        # Every code must fall inside [0, codebook_size) of its own codebook.
        if (audio_codes < 0).any() or (audio_codes >= self.codebook_size).any():
            raise ValueError('Invalid audio tokens detected!')
        # (frames, books) -> (1, books, frames), the layout passed to ``decode``.
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])

        print(f"Extracted audio codes shape: {audio_codes.shape}")
        return audio_codes, len_

    def get_text(self, out_ids):
        """Decode the span from the first ``start_of_text`` to the next ``end_of_text``.

        Requires a text tokenizer (``text_tokenizer_name`` at construction).

        Raises:
            ValueError: if no start marker, or no end marker after it, exists.
        """
        start_t_idx = self._first_index(out_ids, self.start_of_text)
        end_t_idx = (
            None if start_t_idx is None
            else self._first_index(out_ids, self.end_of_text, start_t_idx + 1)
        )
        if start_t_idx is None or end_t_idx is None:
            raise ValueError('Text start/end tokens not found!')

        # Markers are included in the slice; skip_special_tokens is relied on to
        # drop them (assumes the tokenizer registers them as special - verify).
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        return self.tokenizer.decode(txt_tokens, skip_special_tokens=True)

    def get_waveform(self, out_ids):
        """Convert model output to an audio waveform (and optionally its text).

        Args:
            out_ids: Token ids from ``generate``. They are flattened, so a single
                sequence (batch of one) is assumed.

        Returns:
            ``(audio, text)``: ``audio`` is the squeezed NumPy waveform from the
            codec; ``text`` is the decoded text span, or ``None`` when no text
            tokenizer was configured.
        """
        out_ids = out_ids.flatten()
        print("Starting waveform generation...")

        # Validate output
        self.output_validation(out_ids)

        # Extract audio codes and move them next to the codec.
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)

        print("Decoding audio with NeMo codec...")
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()
        print(f"Generated audio shape: {output_audio.shape}")

        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        return output_audio, None
class KaniModel:
    """Text-to-speech front end.

    Builds the prompt, samples tokens from the causal LM, and hands the result
    to a ``NemoAudioPlayer`` for decoding.
    """

    def __init__(self, config, player: "NemoAudioPlayer", token: str) -> None:
        """Load the language model and its tokenizer.

        Args:
            config: Object exposing ``model_name``, ``device_map`` and the
                generation settings (e.g. ``Config``).
            player: Supplies the special-token ids and decodes the output.
            token: Access token forwarded to ``from_pretrained``.
        """
        self.conf = config
        self.player = player
        # Kept for compatibility/logging; generation inputs are moved to
        # ``self.model.device``, since ``device_map`` decides actual placement.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")

        # Load model with proper configuration
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            # NOTE(review): runs code shipped with the checkpoint; keep only
            # for trusted repositories.
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )

        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def _pad_token_id(self):
        """Return the tokenizer's pad id, falling back to its EOS id.

        Compared with ``None`` explicitly: 0 is a valid pad id and must not be
        treated as missing.
        """
        pad_id = self.tokenizer.pad_token_id
        return pad_id if pad_id is not None else self.tokenizer.eos_token_id

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the prompt ``[start_of_human, <text ids>, end_of_text, end_of_human]``.

        Returns:
            ``(input_ids, attention_mask)``, both of shape ``(1, L)``, int64;
            the mask is all ones.
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human
        # Tokenize input text (any BOS token the tokenizer adds is kept).
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids

        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)

        print(f"Input sequence length: {modified_input_ids.shape[1]}")
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Sample tokens from the model until ``end_of_speech`` or the token budget.

        Returns:
            The full sequence (prompt + generated tokens) on CPU.
        """
        # Place inputs on the device the model reports, which honours
        # ``device_map`` instead of assuming CUDA whenever it is available.
        target_device = self.model.device
        input_ids = input_ids.to(target_device)
        attention_mask = attention_mask.to(target_device)
        print("Starting model generation...")
        print(f"Generation parameters: max_tokens={self.conf.max_new_tokens}, "
              f"temp={self.conf.temperature}, top_p={self.conf.top_p}")

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=self._pad_token_id(),
            )

        print(f"Generated sequence length: {generated_ids.shape[1]}")
        return generated_ids.to('cpu')

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            ``(audio, text)`` where ``text`` is the input prompt, echoed back.
        """
        print(f"Processing text: '{text[:50]}{'...' if len(text) > 50 else ''}'")

        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text)

        # Generate tokens
        model_output = self.model_request(input_ids, attention_mask)

        # Convert to audio
        audio, _ = self.player.get_waveform(model_output)

        print("Text-to-speech generation completed successfully!")
        return audio, text