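"""Text-to-speech pipeline built on a causal LM and the NVIDIA NeMo nano codec.

The LM generates interleaved codec tokens from a text prompt; NemoAudioPlayer
extracts and de-offsets those tokens and decodes them back into a waveform.
A usage sketch sits at the bottom of the file under the __main__ guard.
"""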
import os
from dataclasses import dataclass

import torch
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer, AutoModelForCausalLM


@dataclass
class Config:
    """Model/codec identifiers and generation hyperparameters."""

    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    device_map: str = "auto"
    tokeniser_length: int = 64400
    start_of_text: int = 1
    end_of_text: int = 2
    max_new_tokens: int = 2000
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1


class NemoAudioPlayer:
    """Decodes LM token output into waveforms with the NeMo nano codec."""

    def __init__(self, config: Config, text_tokenizer_name: str | None = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        
        # Load NeMo codec model
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)
        
        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token layout: special tokens sit directly above the text vocabulary;
        # audio tokens start at tokeniser_length + 10 (4 codebooks x 4032 ids each)
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens"""
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')
        
        print("Output validation passed - speech tokens found")

    def get_nano_codes(self, out_ids):
        """Extract nano codec tokens from model output"""
        try:
            # nonzero(...)[0] holds the match indices for the 1-D tensor; the
            # second [0] takes the first match and raises IndexError when there
            # is none, so the except clause below actually fires on a miss
            start_a_idx = (out_ids == self.start_of_speech).nonzero(as_tuple=True)[0][0].item()
            end_a_idx = (out_ids == self.end_of_speech).nonzero(as_tuple=True)[0][0].item()
        except IndexError:
            raise ValueError('Speech start/end tokens not found!')
            
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]

        # The nano codec emits 4 codebook entries per frame, flattened in order
        if len(audio_codes) % 4:
            raise ValueError('Audio codes length must be multiple of 4!')

        audio_codes = audio_codes.reshape(-1, 4)

        # Undo the per-codebook offsets and the global audio-token offset that
        # were applied when the codes were packed into the text vocabulary
        audio_codes = audio_codes - torch.tensor([self.codebook_size * i for i in range(4)])
        audio_codes = audio_codes - self.audio_tokens_start
        
        if (audio_codes < 0).any():
            raise ValueError('Invalid audio tokens detected!')

        # (frames, 4) -> (1, 4, frames): batch, codebook, time, as decode() expects
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        
        print(f"Extracted audio codes shape: {audio_codes.shape}")
        return audio_codes, len_

    def get_text(self, out_ids):
        """Extract text from model output"""
        try:
            # Same pattern as get_nano_codes: an empty match raises IndexError
            start_t_idx = (out_ids == self.start_of_text).nonzero(as_tuple=True)[0][0].item()
            end_t_idx = (out_ids == self.end_of_text).nonzero(as_tuple=True)[0][0].item()
        except IndexError:
            raise ValueError('Text start/end tokens not found!')
            
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids):
        """Convert model output to audio waveform"""
        out_ids = out_ids.flatten()
        print("Starting waveform generation...")
        
        # Validate output
        self.output_validation(out_ids)
        
        # Extract audio codes
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
        
        print("Decoding audio with NeMo codec...")
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes, 
                tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        print(f"Generated audio shape: {output_audio.shape}")
        
        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None


class KaniModel:
    """Wraps the causal LM that turns a text prompt into codec/audio tokens."""

    def __init__(self, config: Config, player: NemoAudioPlayer, token: str) -> None:
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")
        
        # Load model with proper configuration
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name, 
            token=token,
            trust_remote_code=True
        )
        
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model"""
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        # Tokenize input text
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids
        
        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)
        
        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        
        print(f"Input sequence length: {modified_input_ids.shape[1]}")
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate tokens using the model"""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        print("Starting model generation...")
        print(f"Generation parameters: max_tokens={self.conf.max_new_tokens}, "
              f"temp={self.conf.temperature}, top_p={self.conf.top_p}")
        
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                # 'is not None' matters here: a pad_token_id of 0 is valid but falsy
                pad_token_id=(self.tokenizer.pad_token_id
                              if self.tokenizer.pad_token_id is not None
                              else self.tokenizer.eos_token_id)
            )
        
        print(f"Generated sequence length: {generated_ids.shape[1]}")
        return generated_ids.to('cpu')

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio"""
        print(f"Processing text: '{text[:50]}{'...' if len(text) > 50 else ''}'")
        
        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text)
        
        # Generate tokens
        model_output = self.model_request(input_ids, attention_mask)
        
        # Convert to audio
        audio, _ = self.player.get_waveform(model_output)
        
        print("Text-to-speech generation completed successfully!")
        return audio, text
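

# Usage sketch: a minimal example of wiring the pieces together. The HF_TOKEN
# environment variable, the soundfile dependency, and the output path are
# assumptions of this sketch, not requirements of the classes above; 22050 Hz
# matches the 22khz codec named in Config.audiocodec_name.
if __name__ == "__main__":
    import soundfile as sf  # assumed extra dependency, only used to save audio

    config = Config()
    player = NemoAudioPlayer(config)
    kani = KaniModel(config, player, token=os.getenv("HF_TOKEN"))

    audio, text = kani.run_model("Hello! This is a quick smoke test.")
    sf.write("output.wav", audio, samplerate=22050)
    print(f"Saved output.wav for prompt: {text!r}")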