Simonlob committed on
Commit
41423b2
·
verified ·
1 Parent(s): 26adaf1

Update util.py

Browse files
Files changed (1) hide show
  1. util.py +51 -22
util.py CHANGED
@@ -1,4 +1,6 @@
1
  import torch
 
 
2
  from nemo.collections.tts.models import AudioCodecModel
3
  from dataclasses import dataclass
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -59,7 +61,6 @@ class NemoAudioPlayer:
59
  if not (start_of_speech_flag and end_of_speech_flag):
60
  raise ValueError('Special speech tokens not found in output!')
61
 
62
- print("Output validation passed - speech tokens found")
63
 
64
  def get_nano_codes(self, out_ids):
65
  """Extract nano codec tokens from model output"""
@@ -88,8 +89,6 @@ class NemoAudioPlayer:
88
 
89
  audio_codes = audio_codes.T.unsqueeze(0)
90
  len_ = torch.tensor([audio_codes.shape[-1]])
91
-
92
- print(f"Extracted audio codes shape: {audio_codes.shape}")
93
  return audio_codes, len_
94
 
95
  def get_text(self, out_ids):
@@ -107,8 +106,7 @@ class NemoAudioPlayer:
107
  def get_waveform(self, out_ids):
108
  """Convert model output to audio waveform"""
109
  out_ids = out_ids.flatten()
110
- print("Starting waveform generation...")
111
-
112
  # Validate output
113
  self.output_validation(out_ids)
114
 
@@ -116,15 +114,12 @@ class NemoAudioPlayer:
116
  audio_codes, len_ = self.get_nano_codes(out_ids)
117
  audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
118
 
119
- print("Decoding audio with NeMo codec...")
120
  with torch.inference_mode():
121
  reconstructed_audio, _ = self.nemo_codec_model.decode(
122
  tokens=audio_codes,
123
  tokens_len=len_
124
  )
125
  output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()
126
-
127
- print(f"Generated audio shape: {output_audio.shape}")
128
 
129
  if self.text_tokenizer_name:
130
  text = self.get_text(out_ids)
@@ -175,18 +170,12 @@ class KaniModel:
175
  # Concatenate tokens
176
  modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
177
  attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
178
-
179
- print(f"Input sequence length: {modified_input_ids.shape[1]}")
180
  return modified_input_ids, attention_mask
181
 
182
  def model_request(self, input_ids: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
183
  """Generate tokens using the model"""
184
  input_ids = input_ids.to(self.device)
185
  attention_mask = attention_mask.to(self.device)
186
-
187
- print("Starting model generation...")
188
- print(f"Generation parameters: max_tokens={self.conf.max_new_tokens}, "
189
- f"temp={self.conf.temperature}, top_p={self.conf.top_p}")
190
 
191
  with torch.no_grad():
192
  generated_ids = self.model.generate(
@@ -201,14 +190,10 @@ class KaniModel:
201
  eos_token_id=self.player.end_of_speech,
202
  pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
203
  )
204
-
205
- print(f"Generated sequence length: {generated_ids.shape[1]}")
206
  return generated_ids.to('cpu')
207
 
208
  def run_model(self, text: str):
209
- """Complete pipeline: text -> tokens -> generation -> audio"""
210
- print(f"Processing text: '{text[:50]}{'...' if len(text) > 50 else ''}'")
211
-
212
  # Prepare input
213
  input_ids, attention_mask = self.get_input_ids(text)
214
 
@@ -217,6 +202,50 @@ class KaniModel:
217
 
218
  # Convert to audio
219
  audio, _ = self.player.get_waveform(model_output)
220
-
221
- print("Text-to-speech generation completed successfully!")
222
- return audio, text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import librosa
3
+ import requests
4
  from nemo.collections.tts.models import AudioCodecModel
5
  from dataclasses import dataclass
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
61
  if not (start_of_speech_flag and end_of_speech_flag):
62
  raise ValueError('Special speech tokens not found in output!')
63
 
 
64
 
65
  def get_nano_codes(self, out_ids):
66
  """Extract nano codec tokens from model output"""
 
89
 
90
  audio_codes = audio_codes.T.unsqueeze(0)
91
  len_ = torch.tensor([audio_codes.shape[-1]])
 
 
92
  return audio_codes, len_
93
 
94
  def get_text(self, out_ids):
 
106
  def get_waveform(self, out_ids):
107
  """Convert model output to audio waveform"""
108
  out_ids = out_ids.flatten()
109
+
 
110
  # Validate output
111
  self.output_validation(out_ids)
112
 
 
114
  audio_codes, len_ = self.get_nano_codes(out_ids)
115
  audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
116
 
 
117
  with torch.inference_mode():
118
  reconstructed_audio, _ = self.nemo_codec_model.decode(
119
  tokens=audio_codes,
120
  tokens_len=len_
121
  )
122
  output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()
 
 
123
 
124
  if self.text_tokenizer_name:
125
  text = self.get_text(out_ids)
 
170
  # Concatenate tokens
171
  modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
172
  attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
 
 
173
  return modified_input_ids, attention_mask
174
 
175
  def model_request(self, input_ids: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
176
  """Generate tokens using the model"""
177
  input_ids = input_ids.to(self.device)
178
  attention_mask = attention_mask.to(self.device)
 
 
 
 
179
 
180
  with torch.no_grad():
181
  generated_ids = self.model.generate(
 
190
  eos_token_id=self.player.end_of_speech,
191
  pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
192
  )
 
 
193
  return generated_ids.to('cpu')
194
 
195
  def run_model(self, text: str):
196
+ """Complete pipeline: text -> tokens -> generation -> audio"""
 
 
197
  # Prepare input
198
  input_ids, attention_mask = self.get_input_ids(text)
199
 
 
202
 
203
  # Convert to audio
204
  audio, _ = self.player.get_waveform(model_output)
205
+ return audio, text
206
+
207
+
208
class Demo:
    """Provide pre-recorded example sentences with their reference audio.

    Each entry of ``sentences`` is paired positionally with the entry of
    ``urls`` at the same index. Calling the instance downloads any audio
    file that is not yet cached in ``audio_dir`` and returns the loaded
    audio keyed by sentence.
    """

    def __init__(self):
        """Create the local cache directory and define the example set."""
        # Imported locally so the class does not depend on a module-level
        # `import os` (none is visible in the import lines of this view).
        import os

        self.audio_dir = './audio_examples'
        os.makedirs(self.audio_dir, exist_ok=True)
        self.sentences = [
            "You make my days brighter, and my wildest dreams feel like reality. How do you do that?",
            "Anyway, um, so, um, tell me, tell me all about her. I mean, what's she like? Is she really, you know, pretty?",
            "Great, and just a couple quick questions so we can match you with the right buyer. Is your home address still 330 East Charleston Road?",
            "No, that does not make you a failure. No, sweetie, no. It just, uh, it just means that you're having a tough time...",
            "Oh, yeah. I mean did you want to get a quick snack together or maybe something before you go?",
            "I-- Oh, I am such an idiot sometimes. I'm so sorry. Um, I-I don't know where my head's at.",
            "Got it. $300,000. I can definitely help you get a very good price for your property by selecting a realtor.",
            "Holy fu- Oh my God! Don't you understand how dangerous it is, huh?",
        ]
        # NOTE(review): the remote file numbers are deliberately not in
        # ascending order here; presumably this is what matches each clip to
        # its sentence above -- confirm before "fixing" the order.
        self.urls = [
            'https://www.nineninesix.ai/examples/kani/1.wav',
            'https://www.nineninesix.ai/examples/kani/2.wav',
            'https://www.nineninesix.ai/examples/kani/5.wav',
            'https://www.nineninesix.ai/examples/kani/6.wav',
            'https://www.nineninesix.ai/examples/kani/3.wav',
            'https://www.nineninesix.ai/examples/kani/7.wav',
            'https://www.nineninesix.ai/examples/kani/4.wav',
            'https://www.nineninesix.ai/examples/kani/8.wav',
        ]

    def download_audio(self, url: str, filename: str, timeout: float = 30.0):
        """Return the cached path for ``filename``, downloading it if absent.

        Args:
            url: Remote location of the audio file.
            filename: Name of the file inside ``self.audio_dir``.
            timeout: Seconds to wait for the server before giving up.

        Returns:
            Path of the file inside ``self.audio_dir``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection failure or timeout.
        """
        import os

        filepath = os.path.join(self.audio_dir, filename)
        if not os.path.exists(filepath):
            # Without a timeout a stalled server would block forever.
            r = requests.get(url, timeout=timeout)
            r.raise_for_status()
            # Write to a temporary sibling and rename it into place so an
            # interrupted write never leaves a truncated file that the
            # existence check above would later accept as a valid cache hit.
            tmp_path = filepath + '.part'
            try:
                with open(tmp_path, 'wb') as f:
                    f.write(r.content)
                os.replace(tmp_path, filepath)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        return filepath

    def get_audio(self, filepath: str):
        """Load ``filepath`` with librosa at a 22050 Hz sampling rate.

        Returns whatever ``librosa.load`` returns for that call.
        """
        return librosa.load(filepath, sr=22050)

    def __call__(self):
        """Download (if needed) and load every example.

        Returns:
            Dict mapping each sentence to the result of ``get_audio`` for
            its clip. Local files are named ``1.wav``, ``2.wav``, ... by
            position in the list, not by the number in the remote URL.
        """
        examples = {}
        for idx, (sentence, url) in enumerate(zip(self.sentences, self.urls), start=1):
            filename = f"{idx}.wav"
            filepath = self.download_audio(url, filename)
            examples[sentence] = self.get_audio(filepath)
        return examples