Stardust-minus committed
Commit 7a4bfa1 · verified · 1 Parent(s): b232d49

Upload folder using huggingface_hub

app.py CHANGED
@@ -313,4 +313,4 @@ if __name__ == "__main__":
     inference_fct = get_inference_wrapper(inference_engine)
 
     app = build_app(inference_fct, args.theme)
-    app.queue(api_open=True).launch(show_error=True, show_api=True)
+    app.queue(api_open=True).launch(show_error=True, show_api=True, server_name="0.0.0.0", server_port=18888)
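Note: the only change in app.py is the launch call, which now binds the Gradio server to all interfaces on port 18888 while keeping the API schema exposed. A minimal smoke test, assuming the app is running locally and that Gradio's usual /config route is served (the URL below is illustrative, not part of the commit):

import requests

# Hypothetical check that the relaunched app answers on the new port.
resp = requests.get("http://127.0.0.1:18888/config", timeout=5)
resp.raise_for_status()
print(resp.status_code)  # expect 200 once app.py is running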
fish_speech/content_sequence.py CHANGED
@@ -271,7 +271,7 @@ class ContentSequence:
         self: "ContentSequence",
         tokenizer: FishTokenizer,
         num_codebooks: int,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         encoded = self.encode(tokenizer, add_shift=False)
         tokens = encoded.tokens
         values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
@@ -280,8 +280,9 @@ class ContentSequence:
         if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
             encoded.audio_parts is None or len(encoded.audio_parts) == 0
         ):
-            return values
+            return values, None, None
 
+        audio_parts = audio_masks = None
         if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
             vq_parts = encoded.vq_parts
             vq_parts = torch.cat(vq_parts, dim=1)
@@ -290,7 +291,11 @@ class ContentSequence:
             )
             values[1:, encoded.vq_mask_tokens] = vq_parts
 
-        return values
+        if encoded.audio_parts is not None and len(encoded.audio_parts) > 0:
+            audio_parts = torch.cat(encoded.audio_parts, dim=0)
+            audio_masks = encoded.audio_masks[None, :]
+
+        return values, audio_masks, audio_parts
 
     def visualize(
         self: "ContentSequence",
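Note: encode_for_inference now returns a (values, audio_masks, audio_parts) triple instead of a single tensor, so callers must unpack three values (the updated call site in inference.py below does exactly this). A minimal sketch of the new contract, assuming a tokenizer and model config are already loaded and that TextPart lives alongside ContentSequence:

from fish_speech.content_sequence import ContentSequence, TextPart

# Sketch of the new return signature; names follow the diff, the text is illustrative.
seq = ContentSequence(modality="interleave")
seq.append([TextPart(text="Hello world")], add_end=False, speaker=0)

values, audio_masks, audio_parts = seq.encode_for_inference(
    tokenizer, num_codebooks=model.config.num_codebooks
)
# values: (num_codebooks + 1, seq_len) int tensor of token ids
# audio_masks / audio_parts: None unless the sequence carries audio parts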
fish_speech/models/text2semantic/inference.py CHANGED
@@ -2,6 +2,7 @@ import os
 import queue
 import threading
 import time
+import traceback
 from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
@@ -35,6 +36,7 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
     DualARTransformer,
     NaiveTransformer,
 )
@@ -49,19 +51,19 @@ def multinomial_sample_one_no_sync(
 
 def logits_to_probs(
     logits,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
 ) -> torch.Tensor:
     # Apply repetition penalty
     if previous_tokens is not None:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
 
     # Apply top-p sampling
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -69,11 +71,10 @@ def logits_to_probs(
     sorted_indices_to_remove = cum_probs > top_p
     sorted_indices_to_remove[0] = False  # keep at least one option
     indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
     )
     logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
+    logits = logits / torch.clip(temperature, min=1e-5)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
@@ -81,11 +82,17 @@ def logits_to_probs(
 
 def sample(
     logits,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
     previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     probs = logits_to_probs(
-        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+        logits=logits[0, -1],
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        previous_tokens=previous_tokens,
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -95,40 +102,35 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
    input_pos: torch.Tensor,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
 ) -> torch.Tensor:
-    """
-    Generate one token using dual autoregressive transformer for text-to-speech.
-
-    First generates semantic tokens, then generates acoustic codebook tokens sequentially.
-
-    Args:
-        x: Input token tensor (1, num_codebooks+1, seq_len)
-        input_pos: Position indices for input tokens (seq_len,)
-        temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
-        previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
-        audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
-
-    Returns:
-        Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
-    """
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
+    # print(x, torch.count_nonzero(vq_masks))
+    x = model.forward_generate(
+        x,
+        input_pos,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
+    )
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]
 
     codebooks = [
         sample(
-            x.logits,
+            logits,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             previous_tokens=(
-                previous_tokens[0] if previous_tokens is not None else None
-            ),  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
+                previous_tokens[:, 0] if previous_tokens is not None else None
+            ),
         )[0]
     ]
 
-    hidden_states = x.hidden_states
-
     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
@@ -146,22 +148,27 @@ def decode_one_token_ar(
             [codebook_idx], device=hidden_states.device, dtype=torch.long
         )
         logits = model.forward_generate_fast(hidden_states, input_pos)
-        chunked_logits = logits[..., :1024]
+
+        short_logits = logits[:, :, :1024]
+
+        # Convert logits to probs
         a = sample(
-            chunked_logits,
+            short_logits,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             previous_tokens=(
                 previous_tokens[codebook_idx + 1]
                 if previous_tokens is not None
                 else None
            ),
-            **sampling_kwargs,
         )[0]
+
         hidden_states = model.fast_embeddings(a)
         codebooks.append(a)
 
-    codebooks = torch.stack(codebooks, dim=0)
-
-    return codebooks
+    codebooks = torch.stack(codebooks, dim=1)
+    return codebooks.T
 
 
 def decode_n_tokens(
@@ -169,24 +176,13 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     decode_one_token=decode_one_token_ar,
-    **sampling_kwargs,
 ):
-    """
-    Generate n tokens iteratively using the model.
-
-    Args:
-        model: The transformer model
-        cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
-        input_pos: Current input position tensor
-        num_new_tokens: Number of new tokens to generate
-        semantic_ids: List of semantic token IDs
-        decode_one_token: Function to decode one token
-        **sampling_kwargs: Additional sampling parameters
-
-    Returns:
-        Generated tokens tensor of shape (num_codebooks+1, generated_len)
-    """
     previous_tokens = torch.zeros(
         (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
@@ -201,13 +197,19 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with sdpa_kernel(SDPBackend.MATH):
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                **sampling_kwargs,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                audio_masks=audio_masks,
+                audio_parts=audio_parts,
             ).clone()
 
         input_pos += 1
@@ -226,33 +228,31 @@ def decode_n_tokens(
 @torch.inference_mode()
 def generate(
     *,
-    model: NaiveTransformer,
+    model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     decode_one_token=decode_one_token_ar,
+    num_samples: int = 1,
     **sampling_kwargs,
-) -> torch.Tensor:
+):
     """
-    Generate tokens from text prompt using the transformer model.
-
-    Args:
-        model: The transformer model for generation
-        prompt: Input token tensor of shape (num_codebooks+1, seq_len)
-        max_new_tokens: Maximum number of new tokens to generate
-        decode_one_token: Function to decode one token at a time
-        **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)
-
-    Returns:
-        Generated sequence tensor of shape (num_codebooks+1, total_seq_len)
-        where total_seq_len = original_seq_len + generated_tokens_len
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
 
+    # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
             max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
 
         T_new = T + max_new_tokens
     else:
@@ -260,23 +260,40 @@ def generate(
         max_new_tokens = T_new - T
 
     device, dtype = prompt.device, prompt.dtype
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=num_samples,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
 
     codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
     )
     empty[:, :T] = prompt
     seq = empty
-    input_pos = torch.arange(0, T, device=device)
 
-    # Use non-accelerated version for now, to avoid compilation overhead
+    temperature = torch.tensor(
+        sampling_kwargs["temperature"], device=device, dtype=torch.bfloat16
+    )
+    top_p = torch.tensor(sampling_kwargs["top_p"], device=device, dtype=torch.bfloat16)
+    repetition_penalty = torch.tensor(
+        sampling_kwargs["repetition_penalty"], device=device, dtype=torch.bfloat16
+    )
+
     prefill_decode = decode_one_token_ar
 
     first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        **sampling_kwargs,
+        temperature,
+        top_p,
+        repetition_penalty,
+        audio_masks,
+        audio_parts,
     )
     seq[:, T : T + 1] = first_token
 
@@ -286,12 +303,15 @@ def generate(
         first_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
         decode_one_token=decode_one_token,
-        **sampling_kwargs,
     )
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
-
     return seq
 
 
@@ -303,17 +323,27 @@ def init_model(checkpoint_path, device, precision, compile=False):
 
     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
+        prefill_n_tokens = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
-        raise ValueError("Model is not a DualARTransformer")
+        raise ValueError("Unsupported model type")
+
+    # Initialize cache
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
 
     if compile:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token,
-            fullgraph=True,
+            # mode="max-autotune-no-cudagraphs",
             backend="inductor" if torch.cuda.is_available() else "aot_eager",
             mode="reduce-overhead" if torch.cuda.is_available() else None,
+            fullgraph=True,
         )
 
     return model.eval(), decode_one_token
@@ -362,27 +392,7 @@ def generate_long(
     tokenizer = model.tokenizer
     base_content_sequence = ContentSequence(modality="interleave")
 
-    texts = split_text(text, chunk_length) if iterative_prompt else [text]
     max_length = model.config.max_seq_len
-
-    # if use_prompt:
-    #     base_content_sequence.append(
-    #         [
-    #             TextPart(text=prompt_text[0]),
-    #             VQPart(codes=prompt_tokens[0]),
-    #         ],
-    #         add_end=True,
-    #     )
-
-    # for text in texts:
-    #     content_sequence = ContentSequence(modality=None)
-    #     base_content_sequence.append(
-    #         [
-    #             TextPart(text=text),
-    #         ],
-    #         add_end=True,
-    #     )
-
     if use_prompt:
         for t, c in zip(prompt_text, prompt_tokens):
             base_content_sequence.append(
@@ -391,26 +401,24 @@ def generate_long(
                     VQPart(codes=c),
                 ],
                 add_end=True,
+                speaker=0,
             )
+    base_content_sequence.append(
+        [
+            TextPart(text=text),
+        ],
+        add_end=False,
+        speaker=0,
+    )
 
-    encoded_prompts = base_content_sequence.encode_for_inference(
+    encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
         tokenizer, num_codebooks=model.config.num_codebooks
     )
-    if encoded_prompts.size(1) > max_length - 2048:
-        raise ValueError(
-            f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
-        )
+    if encoded.size(1) > max_length - 2048:
+        raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")
 
-    encoded = []
-    for text in texts:
-        content_sequence = ContentSequence(modality="text")
-        content_sequence.append(TextPart(text=text))
-        encoded.append(
-            content_sequence.encode_for_inference(
-                tokenizer, num_codebooks=model.config.num_codebooks
-            )
-        )
-        logger.info(f"Encoded text: {text}")
+    encoded = encoded.to(device=device)
+    logger.info(f"Encoded text: {text}")
 
     # Move temperature, top_p, repetition_penalty to device
     # This is important so that changing params doesn't trigger recompile
@@ -426,70 +434,53 @@ def generate_long(
 
     global_encoded = []
     seg_idx = 0
+    prompt_length = encoded.size(1)
 
-    while seg_idx < len(encoded):
-        logger.info(
-            f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
-        )
-
-        seg = encoded[seg_idx]
-        global_encoded.append(seg)
-
-        if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
-            cat_encoded = torch.cat(
-                [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
-            )
-        else:
-            cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
-
-        cat_encoded = cat_encoded.to(device=device)
-        prompt_length = cat_encoded.size(1)
-
-        t0 = time.perf_counter()
-        y = generate(
-            model=model,
-            prompt=cat_encoded,
-            max_new_tokens=max_new_tokens,
-            decode_one_token=decode_one_token,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-        )
-
-        if sample_idx == 0 and seg_idx == 0 and compile:
-            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
-
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
-
-        t = time.perf_counter() - t0
-
-        tokens_generated = y.size(1) - prompt_length
-        tokens_sec = tokens_generated / t
-        logger.info(
-            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
-        )
-        logger.info(
-            f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
-        )
-
-        if torch.cuda.is_available():
-            logger.info(
-                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
-            )
-
-        # Put the generated tokens
-        # since there is <im_end>, we remove last token
-        codes = y[1:, prompt_length:-1].clone()
-        assert (codes >= 0).all(), f"Negative code found"
-
-        decoded = y[:, prompt_length:].clone()
-        # But for global encoding, we should keep the <im_end> token
-
-        global_encoded.append(decoded.cpu())
-        assert (codes >= 0).all(), f"Negative code found: {codes}"
-        yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
-        seg_idx += 1
+    t0 = time.perf_counter()
+    y = generate(
+        model=model,
+        prompt=encoded,
+        max_new_tokens=max_new_tokens,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
+        decode_one_token=decode_one_token,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+    )
+
+    if sample_idx == 0 and seg_idx == 0 and compile:
+        logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    t = time.perf_counter() - t0
+
+    tokens_generated = y.size(1) - prompt_length
+    tokens_sec = tokens_generated / t
+    logger.info(
+        f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+    )
+    logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+
+    if torch.cuda.is_available():
+        logger.info(
+            f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+        )
+
+    # Put the generated tokens
+    # since there is <im_end>, we remove last token
+    codes = y[1:, prompt_length:-1].clone()
+    assert (codes >= 0).all(), f"Negative code found"
+
+    decoded = y[:, prompt_length:].clone()
+    # But for global encoding, we should keep the <im_end> token
+
+    global_encoded.append(decoded.cpu())
+    assert (codes >= 0).all(), f"Negative code found: {codes}"
+    yield GenerateResponse(action="sample", codes=codes, text=text)
+    seg_idx += 1
 
     # This indicates the end of the current sample
     yield GenerateResponse(action="next")
@@ -544,6 +535,7 @@ def launch_thread_safe_queue(
                     WrappedGenerateResponse(status="success", response=chunk)
                 )
             except Exception as e:
+                logger.error(traceback.format_exc())
                 response_queue.put(WrappedGenerateResponse(status="error", response=e))
 
     threading.Thread(target=worker, daemon=True).start()
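Note: the sampling path no longer threads **sampling_kwargs through the compiled decode functions; temperature, top_p and repetition_penalty are passed as explicit tensors (created once on the device as bfloat16) so that changing their values does not retrigger torch.compile. A minimal sketch of the new sample() call, assuming the module is importable and using a made-up vocabulary size:

import torch

from fish_speech.models.text2semantic.inference import sample

# Placeholder logits in the (batch, seq, vocab) layout used above; 32000 is illustrative.
logits = torch.randn(1, 1, 32000)

idx_next, probs = sample(
    logits,
    temperature=torch.tensor(0.7, dtype=torch.bfloat16),
    top_p=torch.tensor(0.9, dtype=torch.bfloat16),
    repetition_penalty=torch.tensor(1.2, dtype=torch.bfloat16),
    previous_tokens=None,  # no repetition-penalty window for this single step
)
print(idx_next.shape)  # one sampled token id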
fish_speech/models/text2semantic/llama.py CHANGED
@@ -320,9 +320,45 @@ class BaseTransformer(nn.Module):
         self,
         inp: Tensor,
         input_pos: Optional[Tensor] = None,
+        audio_masks: Optional[Tensor] = None,
+        audio_parts: Optional[Tensor] = None,
         return_all: bool = False,
     ) -> BaseTransformerForwardResult:
-        x = self.embed(inp)
+        # This is used for generation, optimized for torch compile
+        # assert (
+        #     self.max_seq_len != -1 and self.max_batch_size != -1
+        # ), "Please call setup_caches before forward_generate"
+
+        embeds = []
+        for i in range(self.config.num_codebooks):
+            emb = self.codebook_embeddings(
+                inp[:, i + 1] + i * self.config.codebook_size
+            )
+            embeds.append(emb)
+
+        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+
+        vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
+            inp[:, 0] <= self.tokenizer.semantic_end_id
+        )
+
+        vq_embeds_sum[~vq_masks] = 0
+        x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+        if self.config.scale_codebook_embeddings:
+            # Expand vq_masks to match x's shape
+            vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
+            x = torch.where(
+                vq_masks_expanded, x / math.sqrt(self.config.num_codebooks + 1), x
+            )
+
+        # Audio embeddings
+        if audio_parts is not None:
+            audio_embeds = self.audio_projector(audio_parts)
+            if self.config.scale_codebook_embeddings:
+                x[audio_masks] = audio_embeds / math.sqrt(2)
+            else:
+                x[audio_masks] = audio_embeds
 
         if input_pos is None:
             input_pos = torch.arange(inp.shape[-1], device=x.device)
@@ -595,69 +631,69 @@ class DualARTransformer(BaseTransformer):
     def forward(
         self,
         inp: Tensor,
+        labels: Optional[Tensor] = None,
         key_padding_mask: Optional[Tensor] = None,
+        vq_parts: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,
+        vq_require_losses: Optional[Tensor] = None,
+        mel_parts: Optional[Tensor] = None,
+        mel_masks: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        parent_result = super().forward(inp, key_padding_mask)
+        parent_result = super().forward(
+            inp=inp,
+            key_padding_mask=key_padding_mask,
+            vq_parts=vq_parts,
+            vq_masks=vq_masks,
+            mel_parts=mel_parts,
+            mel_masks=mel_masks,
+        )
         token_logits = parent_result.logits
         x = parent_result.hidden_states
-        x = self.fast_project_in(x)
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
+        fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]
+
+        # Extract corresponding parts with labels
+        codebook_mask = labels == self.semantic_token_id
+        # This gives where input token is <|semantic|>
+        x = x[codebook_mask]
+
+        if x.shape[0] == 0:
+            # Use dummy input when no vq is required
+            x = torch.zeros(
+                (4, self.config.dim),
+                device=x.device,
+                dtype=x.dtype,
+            )
+            codebooks = torch.zeros(
+                (x.shape[0], self.config.num_codebooks - 1),
+                device=x.device,
+                dtype=torch.int,
+            )
+        else:
+            codebooks = vq_parts[..., :-1][vq_require_losses][
+                vq_masks[vq_require_losses]
+            ]
 
-        # Drop the last token and rotate left
-        codebooks = inp[:, 1:-1, 1:]
-        codebooks = F.pad(codebooks, (0, 1), value=0)
+        x = self.fast_project_in(x)
         codebook_embeddings = self.fast_embeddings(codebooks)
         x = torch.cat([x[:, None], codebook_embeddings], dim=1)
-        b, s = x.size(0), x.size(2)
-        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
-
-        # Remove padded part
-        codebooks = rearrange(codebooks, "b n s -> (b s) n")
-        codebook_mask = (codebooks == 0).all(dim=-1)
-
-        if torch.all(codebook_mask):
-            # If all codebooks are padded, we keep first 8 to make sure the model runs
-            codebook_mask[:8] = False
-
-        x_bs, x_len = x.size(0), x.size(1)
-        x = x[~codebook_mask]
 
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(
-                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
-                )
+                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
-                x = layer(x, self.fast_freqs_cis, fast_mask)
+                x = layer(x, fast_freqs_cis, fast_mask)
 
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
         codebook_logits = self.fast_output(fast_out)
 
-        # Re-pad the codebook_logits
-        buffer = torch.zeros(
-            x_bs,
-            x_len,
-            codebook_logits.size(-1),
-            device=codebook_logits.device,
-            dtype=codebook_logits.dtype,
-        )
-        buffer[~codebook_mask] = codebook_logits
-        codebook_logits = buffer
-
         assert codebook_logits.shape[1] == self.config.num_codebooks
-        codebook_logits = rearrange(
-            codebook_logits,
-            "(b s) n d -> b s n d",
-            b=b,
-            s=s,
-            n=self.config.num_codebooks,
-        )
 
         return TransformerForwardResult(
             token_logits=token_logits,
@@ -668,7 +704,7 @@ class DualARTransformer(BaseTransformer):
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:
         # Fast transformer
-        x = x.view(1, 1, -1)
+        x = x.view(x.shape[0], 1, -1)
 
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
@@ -688,9 +724,10 @@ class DualARTransformer(BaseTransformer):
         self,
         x: Tensor,
         input_pos: Optional[Tensor] = None,
-        vq_masks: Optional[Tensor] = None,
+        audio_masks: Optional[Tensor] = None,
+        audio_parts: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        x = super().forward_generate(x, input_pos, vq_masks)
+        x = super().forward_generate(x, input_pos, audio_masks, audio_parts)
         x.hidden_states = self.fast_project_in(x.hidden_states)
         return x
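Note: forward_generate now accepts audio_masks and audio_parts; BaseTransformer builds the input embedding from the first codebook row plus the summed VQ codebook embeddings, then overwrites the masked positions with features projected by audio_projector. A hedged sketch of the call as used by decode_one_token_ar above; all tensors here are placeholders assumed to come from encode_for_inference and the generation loop:

# Mirrors the call in decode_one_token_ar; x, input_pos, audio_masks and audio_parts
# are assumed to be prepared elsewhere (see inference.py above).
result = model.forward_generate(
    x,                        # (batch, num_codebooks + 1, seq_len) token ids
    input_pos,                # positions into the KV cache
    audio_masks=audio_masks,  # bool mask of positions backed by raw audio features
    audio_parts=audio_parts,  # features fed through model.audio_projector
)
logits = result.logits
hidden_states = result.hidden_states  # already passed through fast_project_in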