Stardust-minus committed on
Commit a26769d · verified · 1 parent: 39d5a3b

Upload folder using huggingface_hub
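
For reference, this commit message is the default one that `huggingface_hub` writes when a local folder is pushed with its upload helper. A minimal sketch of how such a commit is typically produced (the repo id below is a placeholder rather than the actual Space, and authentication is assumed to come from a cached login or the HF_TOKEN environment variable):

    from huggingface_hub import HfApi

    api = HfApi()  # token is read from the local cache or HF_TOKEN
    api.upload_folder(
        folder_path=".",                     # local working copy to push
        repo_id="<namespace>/<space-name>",  # placeholder target Space
        repo_type="space",
        commit_message="Upload folder using huggingface_hub",
    )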

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +5 -0
  2. README.md +1 -1
  3. app.py +73 -314
  4. examples/Arabic.wav +0 -0
  5. examples/English.wav +0 -0
  6. examples/French.wav +0 -0
  7. examples/German.wav +0 -0
  8. examples/Japanese.wav +2 -2
  9. examples/Korean.wav +2 -2
  10. examples/Nice English Ref.wav +2 -2
  11. examples/Spanish.wav +0 -0
  12. fish_speech/configs/base.yaml +87 -87
  13. fish_speech/configs/lora/r_8_alpha_16.yaml +4 -4
  14. fish_speech/configs/modded_dac_vq.yaml +50 -0
  15. fish_speech/configs/text2semantic_finetune.yaml +86 -83
  16. fish_speech/content_sequence.py +367 -0
  17. fish_speech/i18n/README.md +27 -27
  18. fish_speech/i18n/__init__.py +3 -3
  19. fish_speech/i18n/core.py +40 -40
  20. fish_speech/i18n/locale/en_US.json +123 -123
  21. fish_speech/i18n/locale/es_ES.json +123 -123
  22. fish_speech/i18n/locale/ja_JP.json +123 -123
  23. fish_speech/i18n/locale/ko_KR.json +123 -123
  24. fish_speech/i18n/locale/pt_BR.json +133 -133
  25. fish_speech/i18n/locale/zh_CN.json +123 -123
  26. fish_speech/i18n/scan.py +122 -122
  27. fish_speech/inference_engine/__init__.py +192 -0
  28. fish_speech/inference_engine/reference_loader.py +130 -0
  29. fish_speech/inference_engine/utils.py +29 -0
  30. fish_speech/inference_engine/vq_manager.py +59 -0
  31. fish_speech/models/dac/__init__.py +0 -0
  32. fish_speech/models/dac/inference.py +123 -0
  33. fish_speech/models/dac/modded_dac.py +1024 -0
  34. fish_speech/models/dac/rvq.py +403 -0
  35. fish_speech/models/text2semantic/inference.py +716 -0
  36. fish_speech/models/text2semantic/lit_module.py +202 -202
  37. fish_speech/models/text2semantic/llama.py +903 -887
  38. fish_speech/models/text2semantic/lora.py +92 -92
  39. fish_speech/text/__init__.py +4 -4
  40. fish_speech/text/clean.py +37 -37
  41. fish_speech/text/spliter.py +130 -130
  42. fish_speech/tokenizer.py +179 -152
  43. fish_speech/utils/__init__.py +24 -24
  44. fish_speech/utils/braceexpand.py +217 -217
  45. fish_speech/utils/context.py +13 -13
  46. fish_speech/utils/file.py +139 -16
  47. fish_speech/utils/instantiators.py +50 -50
  48. fish_speech/utils/logger.py +55 -55
  49. fish_speech/utils/logging_utils.py +48 -48
  50. fish_speech/utils/rich_utils.py +100 -100
.gitattributes CHANGED
@@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 examples/Japanese.wav filter=lfs diff=lfs merge=lfs -text
 examples/Korean.wav filter=lfs diff=lfs merge=lfs -text
 examples/Nice[[:space:]]English[[:space:]]Ref.wav filter=lfs diff=lfs merge=lfs -text
+examples/Arabic.wav filter=lfs diff=lfs merge=lfs -text
+examples/English.wav filter=lfs diff=lfs merge=lfs -text
+examples/French.wav filter=lfs diff=lfs merge=lfs -text
+examples/German.wav filter=lfs diff=lfs merge=lfs -text
+examples/Spanish.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Fish Speech 1
+title: OpenAudio S1
 emoji: 🏆
 colorFrom: purple
 colorTo: gray
app.py CHANGED
@@ -1,86 +1,51 @@
1
  import os
2
  import queue
3
  from huggingface_hub import snapshot_download
4
- import hydra
5
  import numpy as np
6
  import wave
7
  import io
8
- import pyrootutils
9
  import gc
 
10
 
11
  # Download if not exists
12
  os.makedirs("checkpoints", exist_ok=True)
13
- snapshot_download(repo_id="fishaudio/fish-speech-1.5", local_dir="./checkpoints/fish-speech-1.5")
14
 
15
  print("All checkpoints downloaded")
16
 
17
  import html
18
  import os
19
- import threading
20
  from argparse import ArgumentParser
21
  from pathlib import Path
22
- from functools import partial
23
 
24
  import gradio as gr
25
- import librosa
26
  import torch
27
  import torchaudio
28
 
29
  torchaudio.set_audio_backend("soundfile")
30
 
31
  from loguru import logger
32
- from transformers import AutoTokenizer
33
-
34
  from fish_speech.i18n import i18n
35
- from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
36
- from fish_speech.utils import autocast_exclude_mps, set_seed
37
- from tools.api import decode_vq_tokens, encode_reference
38
- from tools.file import AUDIO_EXTENSIONS, list_files
39
- from tools.llama.generate import (
40
- GenerateRequest,
41
- GenerateResponse,
42
- WrappedGenerateResponse,
43
- launch_thread_safe_queue,
44
- )
45
- from tools.vqgan.inference import load_model as load_decoder_model
46
-
47
- from tools.schema import (
48
- GLOBAL_NUM_SAMPLES,
49
- ASRPackRequest,
50
- ServeASRRequest,
51
- ServeASRResponse,
52
- ServeASRSegment,
53
- ServeAudioPart,
54
- ServeForwardMessage,
55
- ServeMessage,
56
- ServeRequest,
57
- ServeResponse,
58
- ServeStreamDelta,
59
- ServeStreamResponse,
60
- ServeTextPart,
61
- ServeTimedASRResponse,
62
- ServeTTSRequest,
63
- ServeVQGANDecodeRequest,
64
- ServeVQGANDecodeResponse,
65
- ServeVQGANEncodeRequest,
66
- ServeVQGANEncodeResponse,
67
- ServeVQPart,
68
- ServeReferenceAudio
69
- )
70
  # Make einx happy
71
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
72
 
73
 
74
- HEADER_MD = """# Fish Speech
75
 
76
- ## The demo in this space is version 1.5, Please check [Fish Audio](https://fish.audio) for the best model.
77
- ## 该 Demo 为 Fish Speech 1.5 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
78
 
79
- A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
80
- 由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
81
 
82
- You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).
83
- 你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型.
84
 
85
  Related code and weights are released under CC BY-NC-SA 4.0 License.
86
  相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
@@ -88,8 +53,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
88
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
89
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
90
 
91
- The model running in this WebUI is Fish Speech V1.5 Medium.
92
- 在此 WebUI 中运行的模型是 Fish Speech V1.5 Medium.
93
  """
94
 
95
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -106,7 +71,6 @@ except ImportError:
106
 
107
  return wrapper
108
 
109
-
110
  def build_html_error_message(error):
111
  return f"""
112
  <div style="color: red;
@@ -115,109 +79,6 @@ def build_html_error_message(error):
115
  </div>
116
  """
117
 
118
-
119
- @GPU_DECORATOR
120
- @torch.inference_mode()
121
- def inference(req: ServeTTSRequest):
122
- try:
123
- # Parse reference audio aka prompt
124
- refs = req.references
125
-
126
- prompt_tokens = [
127
- encode_reference(
128
- decoder_model=decoder_model,
129
- reference_audio=ref.audio,
130
- enable_reference_audio=True,
131
- )
132
- for ref in refs
133
- ]
134
- prompt_texts = [ref.text for ref in refs]
135
-
136
- if req.seed is not None:
137
- set_seed(req.seed)
138
- logger.warning(f"set seed: {req.seed}")
139
-
140
- # LLAMA Inference
141
- request = dict(
142
- device=decoder_model.device,
143
- max_new_tokens=req.max_new_tokens,
144
- text=(
145
- req.text
146
- if not req.normalize
147
- else ChnNormedText(raw_text=req.text).normalize()
148
- ),
149
- top_p=req.top_p,
150
- repetition_penalty=req.repetition_penalty,
151
- temperature=req.temperature,
152
- compile=args.compile,
153
- iterative_prompt=req.chunk_length > 0,
154
- chunk_length=req.chunk_length,
155
- max_length=4096,
156
- prompt_tokens=prompt_tokens,
157
- prompt_text=prompt_texts,
158
- )
159
-
160
- response_queue = queue.Queue()
161
- llama_queue.put(
162
- GenerateRequest(
163
- request=request,
164
- response_queue=response_queue,
165
- )
166
- )
167
-
168
- segments = []
169
-
170
- while True:
171
- result: WrappedGenerateResponse = response_queue.get()
172
- if result.status == "error":
173
- yield None, None, build_html_error_message(result.response)
174
- break
175
-
176
- result: GenerateResponse = result.response
177
- if result.action == "next":
178
- break
179
-
180
- with autocast_exclude_mps(
181
- device_type=decoder_model.device.type, dtype=args.precision
182
- ):
183
- fake_audios = decode_vq_tokens(
184
- decoder_model=decoder_model,
185
- codes=result.codes,
186
- )
187
-
188
- fake_audios = fake_audios.float().cpu().numpy()
189
- segments.append(fake_audios)
190
-
191
- if len(segments) == 0:
192
- return (
193
- None,
194
- None,
195
- build_html_error_message(
196
- i18n("No audio generated, please check the input text.")
197
- ),
198
- )
199
-
200
- # No matter streaming or not, we need to return the final audio
201
- audio = np.concatenate(segments, axis=0)
202
- yield None, (decoder_model.spec_transform.sample_rate, audio), None
203
-
204
- if torch.cuda.is_available():
205
- torch.cuda.empty_cache()
206
- gc.collect()
207
-
208
- except Exception as e:
209
- er = "CUDA error: device-side assert triggered"
210
- if er in str(e):
211
- app.close()
212
- else:
213
- raise Exception(e)
214
-
215
- n_audios = 4
216
-
217
- global_audio_list = []
218
- global_error_list = []
219
-
220
-
221
  def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
222
  buffer = io.BytesIO()
223
 
@@ -230,13 +91,8 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
230
  buffer.close()
231
  return wav_header_bytes
232
 
233
- def normalize_text(user_input, use_normalization):
234
- if use_normalization:
235
- return ChnNormedText(raw_text=user_input).normalize()
236
- else:
237
- return user_input
238
 
239
- def build_app():
240
  with gr.Blocks(theme=gr.themes.Base()) as app:
241
  gr.Markdown(HEADER_MD)
242
 
@@ -245,7 +101,7 @@ def build_app():
245
  None,
246
  None,
247
  js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
248
- % args.theme,
249
  )
250
 
251
  # Inference
@@ -254,20 +110,6 @@ def build_app():
254
  text = gr.Textbox(
255
  label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
256
  )
257
- refined_text = gr.Textbox(
258
- label=i18n("Realtime Transform Text"),
259
- placeholder=i18n(
260
- "Normalization Result Preview (Currently Only Chinese)"
261
- ),
262
- lines=5,
263
- interactive=False,
264
- )
265
-
266
- with gr.Row():
267
- normalize = gr.Checkbox(
268
- label=i18n("Text Normalization"),
269
- value=False,
270
- )
271
 
272
  with gr.Row():
273
  with gr.Column():
@@ -275,45 +117,45 @@ def build_app():
275
  with gr.Row():
276
  chunk_length = gr.Slider(
277
  label=i18n("Iterative Prompt Length, 0 means off"),
278
- minimum=0,
279
- maximum=300,
280
- value=200,
281
  step=8,
282
  )
283
 
284
  max_new_tokens = gr.Slider(
285
  label=i18n(
286
- "Maximum tokens per batch"
287
  ),
288
- minimum=512,
289
  maximum=2048,
290
- value=1024,
291
- step=64,
292
  )
293
 
294
  with gr.Row():
295
  top_p = gr.Slider(
296
  label="Top-P",
297
- minimum=0.6,
298
- maximum=0.9,
299
- value=0.7,
300
  step=0.01,
301
  )
302
 
303
  repetition_penalty = gr.Slider(
304
  label=i18n("Repetition Penalty"),
305
  minimum=1,
306
- maximum=1.5,
307
- value=1.2,
308
  step=0.01,
309
  )
310
 
311
  with gr.Row():
312
  temperature = gr.Slider(
313
  label="Temperature",
314
- minimum=0.6,
315
- maximum=0.9,
316
- value=0.7,
317
  step=0.01,
318
  )
319
  seed = gr.Number(
@@ -326,24 +168,20 @@ def build_app():
326
  with gr.Row():
327
  gr.Markdown(
328
  i18n(
329
- "15 to 60 seconds of reference audio, useful for specifying speaker."
330
  )
331
  )
332
-
333
  with gr.Row():
334
- # Add dropdown for selecting example audio files
335
- example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")]
336
- example_audio_dropdown = gr.Dropdown(
337
- label="Select Example Audio",
338
- choices=[""] + example_audio_files,
339
- value=""
340
  )
341
 
342
  with gr.Row():
343
  use_memory_cache = gr.Radio(
344
  label=i18n("Use Memory Cache"),
345
- choices=["never"],
346
- value="never",
347
  )
348
 
349
  with gr.Row():
@@ -351,7 +189,6 @@ def build_app():
351
  label=i18n("Reference Audio"),
352
  type="filepath",
353
  )
354
-
355
  with gr.Row():
356
  reference_text = gr.Textbox(
357
  label=i18n("Reference Text"),
@@ -377,101 +214,16 @@ def build_app():
377
  with gr.Row():
378
  with gr.Column(scale=3):
379
  generate = gr.Button(
380
- value="\U0001F3A7 " + i18n("Generate"), variant="primary"
 
381
  )
382
 
383
- text.input(
384
- fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]
385
- )
386
-
387
- def inference_wrapper(
388
- text,
389
- normalize,
390
- reference_audio,
391
- reference_text,
392
- max_new_tokens,
393
- chunk_length,
394
- top_p,
395
- repetition_penalty,
396
- temperature,
397
- seed,
398
- use_memory_cache,
399
- ):
400
- print(
401
- "call inference wrapper",
402
- text,
403
- normalize,
404
- reference_audio,
405
- reference_text,
406
- max_new_tokens,
407
- chunk_length,
408
- top_p,
409
- repetition_penalty,
410
- temperature,
411
- seed,
412
- use_memory_cache
413
- )
414
-
415
- references = []
416
- if reference_audio:
417
- # 将文件路径转换为字节
418
- with open(reference_audio, 'rb') as audio_file:
419
- audio_bytes = audio_file.read()
420
-
421
- references = [
422
- ServeReferenceAudio(audio=audio_bytes, text=reference_text)
423
- ]
424
-
425
- req = ServeTTSRequest(
426
- text=text,
427
- normalize=normalize,
428
- reference_id=None,
429
- references=references,
430
- max_new_tokens=max_new_tokens,
431
- chunk_length=chunk_length,
432
- top_p=top_p,
433
- repetition_penalty=repetition_penalty,
434
- temperature=temperature,
435
- seed=int(seed) if seed else None,
436
- use_memory_cache=use_memory_cache,
437
- )
438
-
439
- for result in inference(req):
440
- if result[2]: # Error message
441
- return None, result[2]
442
- elif result[1]: # Audio data
443
- return result[1], None
444
-
445
- return None, i18n("No audio generated")
446
-
447
- def select_example_audio(audio_file):
448
- if audio_file:
449
- audio_path = os.path.join("examples", audio_file)
450
- lab_file = os.path.splitext(audio_file)[0] + ".lab"
451
- lab_path = os.path.join("examples", lab_file)
452
-
453
- if os.path.exists(lab_path):
454
- with open(lab_path, "r", encoding="utf-8") as f:
455
- lab_content = f.read().strip()
456
- else:
457
- lab_content = ""
458
-
459
- return audio_path, lab_content
460
- return None, ""
461
-
462
- # Connect the dropdown to update reference audio and text
463
- example_audio_dropdown.change(
464
- fn=select_example_audio,
465
- inputs=[example_audio_dropdown],
466
- outputs=[reference_audio, reference_text]
467
- )
468
-
469
  # Submit
470
  generate.click(
471
- inference_wrapper,
472
  [
473
- refined_text,
474
- normalize,
475
  reference_audio,
476
  reference_text,
477
  max_new_tokens,
@@ -488,26 +240,24 @@ def build_app():
488
 
489
  return app
490
 
491
-
492
-
493
  def parse_args():
494
  parser = ArgumentParser()
495
  parser.add_argument(
496
  "--llama-checkpoint-path",
497
  type=Path,
498
- default="checkpoints/fish-speech-1.5",
499
  )
500
  parser.add_argument(
501
  "--decoder-checkpoint-path",
502
  type=Path,
503
- default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
504
  )
505
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
506
  parser.add_argument("--device", type=str, default="cuda")
507
  parser.add_argument("--half", action="store_true")
508
  parser.add_argument("--compile", action="store_true",default=True)
509
  parser.add_argument("--max-gradio-length", type=int, default=0)
510
- parser.add_argument("--theme", type=str, default="light")
511
 
512
  return parser.parse_args()
513
 
@@ -533,25 +283,34 @@ if __name__ == "__main__":
533
 
534
  logger.info("Decoder model loaded, warming up...")
535
 
536
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
537
  list(
538
- inference(
539
- ServeTTSRequest(
540
- text="Hello world.",
541
- references=[],
542
- reference_id=None,
543
- max_new_tokens=0,
544
- chunk_length=200,
545
- top_p=0.7,
546
- repetition_penalty=1.5,
547
- temperature=0.7,
548
- emotion=None,
549
- format="wav",
550
- )
551
  )
 
552
  )
553
 
554
  logger.info("Warming up done, launching the web UI...")
555
 
556
- app = build_app()
557
- app.queue(api_open=True).launch(show_error=True, show_api=True)
 
 
 
1
  import os
2
  import queue
3
  from huggingface_hub import snapshot_download
 
4
  import numpy as np
5
  import wave
6
  import io
 
7
  import gc
8
+ from typing import Callable
9
 
10
  # Download if not exists
11
  os.makedirs("checkpoints", exist_ok=True)
12
+ snapshot_download(repo_id="fishaudio/openaudio-s1-mini", local_dir="./checkpoints/openaudio-s1-mini")
13
 
14
  print("All checkpoints downloaded")
15
 
16
  import html
17
  import os
 
18
  from argparse import ArgumentParser
19
  from pathlib import Path
 
20
 
21
  import gradio as gr
 
22
  import torch
23
  import torchaudio
24
 
25
  torchaudio.set_audio_backend("soundfile")
26
 
27
  from loguru import logger
 
 
28
  from fish_speech.i18n import i18n
29
+ from fish_speech.inference_engine import TTSInferenceEngine
30
+ from fish_speech.models.dac.inference import load_model as load_decoder_model
31
+ from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
32
+ from tools.webui.inference import get_inference_wrapper
33
+ from fish_speech.utils.schema import ServeTTSRequest
34
+
35
  # Make einx happy
36
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
37
 
38
 
39
+ HEADER_MD = """# OpenAudio S1
40
 
41
+ ## The demo in this space is OpenAudio S1, Please check [Fish Audio](https://fish.audio) for the best model.
42
+ ## 该 Demo 为 OpenAudio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
43
 
44
+ A text-to-speech model based on DAC and Qwen3 developed by [Fish Audio](https://fish.audio).
45
+ 由 [Fish Audio](https://fish.audio) 研发的基于 DAC 和 Qwen3 的多语种语音合成.
46
 
47
+ You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/openaudio-s1-mini).
48
+ 你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/openaudio-s1-mini) 找到模型.
49
 
50
  Related code and weights are released under CC BY-NC-SA 4.0 License.
51
  相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
 
53
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
54
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
55
 
56
+ The model running in this WebUI is OpenAudio S1 Mini.
57
+ 在此 WebUI 中运行的模型是 OpenAudio S1 Mini.
58
  """
59
 
60
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
71
 
72
  return wrapper
73
 
 
74
  def build_html_error_message(error):
75
  return f"""
76
  <div style="color: red;
 
79
  </div>
80
  """
81
 
82
  def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
83
  buffer = io.BytesIO()
84
 
 
91
  buffer.close()
92
  return wav_header_bytes
93
 
 
94
 
95
+ def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
96
  with gr.Blocks(theme=gr.themes.Base()) as app:
97
  gr.Markdown(HEADER_MD)
98
 
 
101
  None,
102
  None,
103
  js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
104
+ % theme,
105
  )
106
 
107
  # Inference
 
110
  text = gr.Textbox(
111
  label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
112
  )
113
 
114
  with gr.Row():
115
  with gr.Column():
 
117
  with gr.Row():
118
  chunk_length = gr.Slider(
119
  label=i18n("Iterative Prompt Length, 0 means off"),
120
+ minimum=100,
121
+ maximum=400,
122
+ value=300,
123
  step=8,
124
  )
125
 
126
  max_new_tokens = gr.Slider(
127
  label=i18n(
128
+ "Maximum tokens per batch, 0 means no limit"
129
  ),
130
+ minimum=0,
131
  maximum=2048,
132
+ value=0,
133
+ step=8,
134
  )
135
 
136
  with gr.Row():
137
  top_p = gr.Slider(
138
  label="Top-P",
139
+ minimum=0.7,
140
+ maximum=0.95,
141
+ value=0.8,
142
  step=0.01,
143
  )
144
 
145
  repetition_penalty = gr.Slider(
146
  label=i18n("Repetition Penalty"),
147
  minimum=1,
148
+ maximum=1.2,
149
+ value=1.1,
150
  step=0.01,
151
  )
152
 
153
  with gr.Row():
154
  temperature = gr.Slider(
155
  label="Temperature",
156
+ minimum=0.7,
157
+ maximum=1.0,
158
+ value=0.8,
159
  step=0.01,
160
  )
161
  seed = gr.Number(
 
168
  with gr.Row():
169
  gr.Markdown(
170
  i18n(
171
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
172
  )
173
  )
 
174
  with gr.Row():
175
+ reference_id = gr.Textbox(
176
+ label=i18n("Reference ID"),
177
+ placeholder="Leave empty to use uploaded references",
178
  )
179
 
180
  with gr.Row():
181
  use_memory_cache = gr.Radio(
182
  label=i18n("Use Memory Cache"),
183
+ choices=["on", "off"],
184
+ value="on",
185
  )
186
 
187
  with gr.Row():
 
189
  label=i18n("Reference Audio"),
190
  type="filepath",
191
  )
 
192
  with gr.Row():
193
  reference_text = gr.Textbox(
194
  label=i18n("Reference Text"),
 
214
  with gr.Row():
215
  with gr.Column(scale=3):
216
  generate = gr.Button(
217
+ value="\U0001f3a7 " + i18n("Generate"),
218
+ variant="primary",
219
  )
220
 
 
221
  # Submit
222
  generate.click(
223
+ inference_fct,
224
  [
225
+ text,
226
+ reference_id,
227
  reference_audio,
228
  reference_text,
229
  max_new_tokens,
 
240
 
241
  return app
242
 
 
 
243
  def parse_args():
244
  parser = ArgumentParser()
245
  parser.add_argument(
246
  "--llama-checkpoint-path",
247
  type=Path,
248
+ default="checkpoints/openaudio-s1-mini",
249
  )
250
  parser.add_argument(
251
  "--decoder-checkpoint-path",
252
  type=Path,
253
+ default="checkpoints/openaudio-s1-mini/codec.pth",
254
  )
255
+ parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
256
  parser.add_argument("--device", type=str, default="cuda")
257
  parser.add_argument("--half", action="store_true")
258
  parser.add_argument("--compile", action="store_true",default=True)
259
  parser.add_argument("--max-gradio-length", type=int, default=0)
260
+ parser.add_argument("--theme", type=str, default="dark")
261
 
262
  return parser.parse_args()
263
 
 
283
 
284
  logger.info("Decoder model loaded, warming up...")
285
 
286
+ # Create the inference engine
287
+ inference_engine = TTSInferenceEngine(
288
+ llama_queue=llama_queue,
289
+ decoder_model=decoder_model,
290
+ compile=args.compile,
291
+ precision=args.precision,
292
+ )
293
+
294
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
295
  list(
296
+ inference_engine.inference(
297
+ ServeTTSRequest(
298
+ text="Hello world.",
299
+ references=[],
300
+ reference_id=None,
301
+ max_new_tokens=1024,
302
+ chunk_length=200,
303
+ top_p=0.7,
304
+ repetition_penalty=1.5,
305
+ temperature=0.7,
306
+ format="wav",
 
 
307
  )
308
+ )
309
  )
310
 
311
  logger.info("Warming up done, launching the web UI...")
312
 
313
+ inference_fct = get_inference_wrapper(inference_engine)
314
+
315
+ app = build_app(inference_fct, args.theme)
316
+ app.queue(api_open=True).launch(show_error=True, show_api=True, server_name="0.0.0.0", server_port=18888)
examples/Arabic.wav CHANGED
Binary files a/examples/Arabic.wav and b/examples/Arabic.wav differ
 
examples/English.wav CHANGED
Binary files a/examples/English.wav and b/examples/English.wav differ
 
examples/French.wav CHANGED
Binary files a/examples/French.wav and b/examples/French.wav differ
 
examples/German.wav CHANGED
Binary files a/examples/German.wav and b/examples/German.wav differ
 
examples/Japanese.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a23cffeac70f42e1cc69e2a0505e4c1fda50884dd34c509128d432aaf44565e5
-size 1148682
+oid sha256:3034a38260884be854cb4a3f6cb648db85ebdeeb8cab74cfae2a578dc7aaedc2
+size 132
examples/Korean.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c4234f119c741782e2c9c0ede4b5b864a560a355c28a23b2332e79420b69961a
-size 1632522
+oid sha256:5767663f0c26f4dc94f45227f385c2be568aac065272466915d65eaa64fdda0f
+size 132
examples/Nice English Ref.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d00ad9768c62f9821fc01ecab3e02669581ca75c18af6549690e19ce90a09f53
-size 5254482
+oid sha256:4b707de0cfc5d2eee59dcc3fea495603fe28d95ca64d8202bcdb31537d588782
+size 132
examples/Spanish.wav CHANGED
Binary files a/examples/Spanish.wav and b/examples/Spanish.wav differ
 
fish_speech/configs/base.yaml CHANGED
@@ -1,87 +1,87 @@
1
- # Base configuration for training a model
2
- paths:
3
- run_dir: results/${project}
4
- ckpt_dir: ${paths.run_dir}/checkpoints
5
-
6
- hydra:
7
- run:
8
- dir: ${paths.run_dir}
9
-
10
- # Lightning Trainer
11
- trainer:
12
- _target_: lightning.pytorch.trainer.Trainer
13
-
14
- default_root_dir: ${paths.run_dir}
15
- accelerator: gpu
16
- num_nodes: 1
17
- devices: auto
18
- strategy:
19
- _target_: lightning.pytorch.strategies.DDPStrategy
20
- process_group_backend: nccl # This should be override when training on windows
21
-
22
- precision: bf16-mixed
23
-
24
- # disable validation by epoch end
25
- check_val_every_n_epoch: null
26
- val_check_interval: 5000
27
- max_steps: 100_000
28
-
29
- # Use torch.backends.cudnn.benchmark to speed up training
30
- benchmark: true
31
-
32
- # Callbacks
33
- callbacks:
34
- model_checkpoint:
35
- _target_: lightning.pytorch.callbacks.ModelCheckpoint
36
- dirpath: ${paths.ckpt_dir}
37
- filename: "step_{step:09d}"
38
- save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
39
- save_top_k: 5 # save 5 latest checkpoints
40
- monitor: step # use step to monitor checkpoints
41
- mode: max # save the latest checkpoint with the highest global_step
42
- every_n_epochs: null # don't save checkpoints by epoch end
43
- every_n_train_steps: 5000 # save checkpoints every 5000 steps
44
- auto_insert_metric_name: false
45
-
46
- model_summary:
47
- _target_: lightning.pytorch.callbacks.ModelSummary
48
- max_depth: 2 # the maximum depth of layer nesting that the summary will include
49
-
50
- learning_rate_monitor:
51
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
52
- logging_interval: step
53
- log_momentum: false
54
-
55
- grad_norm_monitor:
56
- _target_: fish_speech.callbacks.GradNormMonitor
57
- norm_type: 2
58
- logging_interval: step
59
-
60
- # Logger
61
- logger:
62
- tensorboard:
63
- _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
64
- save_dir: "${paths.run_dir}/tensorboard/"
65
- name: null
66
- log_graph: false
67
- default_hp_metric: true
68
- prefix: ""
69
-
70
- # wandb:
71
- # _target_: lightning.pytorch.loggers.wandb.WandbLogger
72
- # # name: "" # name of the run (normally generated by wandb)
73
- # save_dir: "${paths.run_dir}"
74
- # offline: False
75
- # id: null # pass correct id to resume experiment!
76
- # anonymous: null # enable anonymous logging
77
- # project: "fish-speech"
78
- # log_model: False # upload lightning ckpts
79
- # prefix: "" # a string to put at the beginning of metric keys
80
- # # entity: "" # set to name of your wandb team
81
- # group: ""
82
- # tags: ["vq", "hq", "finetune"]
83
- # job_type: ""
84
-
85
- # Loop
86
- train: true
87
- test: false
 
1
+ # Base configuration for training a model
2
+ paths:
3
+ run_dir: results/${project}
4
+ ckpt_dir: ${paths.run_dir}/checkpoints
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${paths.run_dir}
9
+
10
+ # Lightning Trainer
11
+ trainer:
12
+ _target_: lightning.pytorch.trainer.Trainer
13
+
14
+ default_root_dir: ${paths.run_dir}
15
+ accelerator: gpu
16
+ num_nodes: 1
17
+ devices: auto
18
+ strategy:
19
+ _target_: lightning.pytorch.strategies.DDPStrategy
20
+ process_group_backend: nccl # This should be override when training on windows
21
+
22
+ precision: bf16-mixed
23
+
24
+ # disable validation by epoch end
25
+ check_val_every_n_epoch: null
26
+ val_check_interval: 5000
27
+ max_steps: 100_000
28
+
29
+ # Use torch.backends.cudnn.benchmark to speed up training
30
+ benchmark: true
31
+
32
+ # Callbacks
33
+ callbacks:
34
+ model_checkpoint:
35
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
36
+ dirpath: ${paths.ckpt_dir}
37
+ filename: "step_{step:09d}"
38
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
39
+ save_top_k: 5 # save 5 latest checkpoints
40
+ monitor: step # use step to monitor checkpoints
41
+ mode: max # save the latest checkpoint with the highest global_step
42
+ every_n_epochs: null # don't save checkpoints by epoch end
43
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
44
+ auto_insert_metric_name: false
45
+
46
+ model_summary:
47
+ _target_: lightning.pytorch.callbacks.ModelSummary
48
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
49
+
50
+ learning_rate_monitor:
51
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
52
+ logging_interval: step
53
+ log_momentum: false
54
+
55
+ grad_norm_monitor:
56
+ _target_: fish_speech.callbacks.GradNormMonitor
57
+ norm_type: 2
58
+ logging_interval: step
59
+
60
+ # Logger
61
+ logger:
62
+ tensorboard:
63
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
64
+ save_dir: "${paths.run_dir}/tensorboard/"
65
+ name: null
66
+ log_graph: false
67
+ default_hp_metric: true
68
+ prefix: ""
69
+
70
+ # wandb:
71
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
72
+ # # name: "" # name of the run (normally generated by wandb)
73
+ # save_dir: "${paths.run_dir}"
74
+ # offline: False
75
+ # id: null # pass correct id to resume experiment!
76
+ # anonymous: null # enable anonymous logging
77
+ # project: "fish-speech"
78
+ # log_model: False # upload lightning ckpts
79
+ # prefix: "" # a string to put at the beginning of metric keys
80
+ # # entity: "" # set to name of your wandb team
81
+ # group: ""
82
+ # tags: ["vq", "hq", "finetune"]
83
+ # job_type: ""
84
+
85
+ # Loop
86
+ train: true
87
+ test: false
fish_speech/configs/lora/r_8_alpha_16.yaml CHANGED
@@ -1,4 +1,4 @@
-_target_: fish_speech.models.text2semantic.lora.LoraConfig
-r: 8
-lora_alpha: 16
-lora_dropout: 0.01
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
+r: 8
+lora_alpha: 16
+lora_dropout: 0.01
fish_speech/configs/modded_dac_vq.yaml ADDED
@@ -0,0 +1,50 @@
1
+ _target_: fish_speech.models.dac.modded_dac.DAC
2
+ # Model setup
3
+ sample_rate: 44100
4
+ encoder_dim: 64
5
+ encoder_rates: [2, 4, 8, 8]
6
+ decoder_dim: 1536
7
+ decoder_rates: [8, 8, 4, 2]
8
+ encoder_transformer_layers: [0, 0, 0, 4]
9
+ decoder_transformer_layers: [4, 0, 0, 0]
10
+ transformer_general_config:
11
+ _target_: fish_speech.models.dac.modded_dac.ModelArgs
12
+ _partial_: true
13
+ block_size: 16384
14
+ n_local_heads: -1
15
+ head_dim: 64
16
+ rope_base: 10000
17
+ norm_eps: 1e-5
18
+ dropout_rate: 0.1
19
+ attn_dropout_rate: 0.1
20
+ channels_first: true
21
+ # Quantization
22
+ quantizer:
23
+ _target_: fish_speech.models.dac.rvq.DownsampleResidualVectorQuantize
24
+ input_dim: 1024
25
+ n_codebooks: 9
26
+ codebook_size: 1024
27
+ codebook_dim: 8
28
+ quantizer_dropout: 0.5
29
+ downsample_factor: [2, 2]
30
+ post_module: &transformer_module
31
+ _target_: fish_speech.models.dac.modded_dac.WindowLimitedTransformer
32
+ causal: true
33
+ window_size: 128 # empirically this does not seem to matter
34
+ input_dim: 1024
35
+ config: &transformer_config
36
+ _target_: fish_speech.models.dac.modded_dac.ModelArgs
37
+ block_size: 4096
38
+ n_layer: 8
39
+ n_head: 16
40
+ dim: 1024
41
+ intermediate_size: 3072
42
+ n_local_heads: -1
43
+ head_dim: 64
44
+ rope_base: 10000
45
+ norm_eps: 1e-5
46
+ dropout_rate: 0.1
47
+ attn_dropout_rate: 0.1
48
+ channels_first: true
49
+ pre_module: *transformer_module
50
+ semantic_codebook_size: 4096
fish_speech/configs/text2semantic_finetune.yaml CHANGED
@@ -1,83 +1,86 @@
1
- defaults:
2
- - base
3
- - _self_
4
-
5
- project: text2semantic_finetune_dual_ar
6
- max_length: 4096
7
- pretrained_ckpt_path: checkpoints/fish-speech-1.4
8
-
9
- # Lightning Trainer
10
- trainer:
11
- accumulate_grad_batches: 1
12
- gradient_clip_val: 1.0
13
- gradient_clip_algorithm: "norm"
14
- max_steps: 1000
15
- precision: bf16-true
16
- limit_val_batches: 10
17
- val_check_interval: 100
18
-
19
- # Dataset Configuration
20
- tokenizer:
21
- _target_: transformers.AutoTokenizer.from_pretrained
22
- pretrained_model_name_or_path: ${pretrained_ckpt_path}
23
-
24
- # Dataset Configuration
25
- train_dataset:
26
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
27
- proto_files:
28
- - data/protos
29
- tokenizer: ${tokenizer}
30
- causal: true
31
- max_length: ${max_length}
32
- use_speaker: false
33
- interactive_prob: 0.7
34
-
35
- val_dataset:
36
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
37
- proto_files:
38
- - data/protos
39
- tokenizer: ${tokenizer}
40
- causal: true
41
- max_length: ${max_length}
42
- use_speaker: false
43
- interactive_prob: 0.7
44
-
45
- data:
46
- _target_: fish_speech.datasets.semantic.SemanticDataModule
47
- train_dataset: ${train_dataset}
48
- val_dataset: ${val_dataset}
49
- num_workers: 4
50
- batch_size: 8
51
- tokenizer: ${tokenizer}
52
- max_length: ${max_length}
53
-
54
- # Model Configuration
55
- model:
56
- _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
57
- model:
58
- _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
59
- path: ${pretrained_ckpt_path}
60
- load_weights: true
61
- max_length: ${max_length}
62
- lora_config: null
63
-
64
- optimizer:
65
- _target_: torch.optim.AdamW
66
- _partial_: true
67
- lr: 1e-4
68
- weight_decay: 0
69
- betas: [0.9, 0.95]
70
- eps: 1e-5
71
-
72
- lr_scheduler:
73
- _target_: torch.optim.lr_scheduler.LambdaLR
74
- _partial_: true
75
- lr_lambda:
76
- _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
77
- _partial_: true
78
- num_warmup_steps: 10
79
-
80
- # Callbacks
81
- callbacks:
82
- model_checkpoint:
83
- every_n_train_steps: ${trainer.val_check_interval}
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ project: text2semantic_finetune_dual_ar
6
+ max_length: 4096
7
+ pretrained_ckpt_path: checkpoints/openaudio-s1-mini
8
+
9
+ # Lightning Trainer
10
+ trainer:
11
+ accumulate_grad_batches: 1
12
+ gradient_clip_val: 1.0
13
+ gradient_clip_algorithm: "norm"
14
+ max_steps: 10000
15
+ precision: bf16-true
16
+ limit_val_batches: 10
17
+ val_check_interval: 100
18
+ # strategy:
19
+ # find_unused_parameters: true
20
+ # static_graph: true
21
+
22
+ # Dataset Configuration
23
+ tokenizer:
24
+ _target_: fish_speech.tokenizer.FishTokenizer
25
+ model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
26
+
27
+ # Dataset Configuration
28
+ train_dataset:
29
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
30
+ proto_files:
31
+ - data/protos
32
+ tokenizer: ${tokenizer}
33
+ causal: true
34
+ max_length: ${max_length}
35
+ use_speaker: false
36
+ interactive_prob: 0.7
37
+
38
+ val_dataset:
39
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
40
+ proto_files:
41
+ - data/protos
42
+ tokenizer: ${tokenizer}
43
+ causal: true
44
+ max_length: ${max_length}
45
+ use_speaker: false
46
+ interactive_prob: 0.7
47
+
48
+ data:
49
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
50
+ train_dataset: ${train_dataset}
51
+ val_dataset: ${val_dataset}
52
+ num_workers: 4
53
+ batch_size: 4
54
+ tokenizer: ${tokenizer}
55
+ max_length: ${max_length}
56
+
57
+ # Model Configuration
58
+ model:
59
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
60
+ model:
61
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
62
+ path: ${pretrained_ckpt_path}
63
+ load_weights: true
64
+ max_length: ${max_length}
65
+ lora_config: null
66
+
67
+ optimizer:
68
+ _target_: torch.optim.AdamW
69
+ _partial_: true
70
+ lr: 1e-4
71
+ weight_decay: 0
72
+ betas: [0.9, 0.95]
73
+ eps: 1e-5
74
+
75
+ lr_scheduler:
76
+ _target_: torch.optim.lr_scheduler.LambdaLR
77
+ _partial_: true
78
+ lr_lambda:
79
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
80
+ _partial_: true
81
+ num_warmup_steps: 10
82
+
83
+ # Callbacks
84
+ callbacks:
85
+ model_checkpoint:
86
+ every_n_train_steps: ${trainer.val_check_interval}
fish_speech/content_sequence.py ADDED
@@ -0,0 +1,367 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Literal, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from fish_speech.tokenizer import (
8
+ IM_END_TOKEN,
9
+ MODALITY_TOKENS,
10
+ FishTokenizer,
11
+ )
12
+
13
+
14
+ def restore_ndarray(obj, to_tensor: bool = False):
15
+ if isinstance(obj, dict) and "__ndarray__" in obj:
16
+ obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])
17
+
18
+ if to_tensor and isinstance(obj, np.ndarray):
19
+ obj = torch.from_numpy(obj.copy())
20
+
21
+ return obj
22
+
23
+
24
+ @dataclass
25
+ class BasePart:
26
+ type: Literal["text", "vq", "audio"] | None = None
27
+ cal_loss: bool = False
28
+
29
+
30
+ @dataclass(kw_only=True)
31
+ class VQPart(BasePart):
32
+ type = "vq"
33
+ codes: torch.Tensor
34
+
35
+ def __post_init__(self: "VQPart"):
36
+ self.type = "vq"
37
+ self.codes = restore_ndarray(self.codes, to_tensor=True)
38
+
39
+
40
+ @dataclass(kw_only=True)
41
+ class TextPart(BasePart):
42
+ type = "text"
43
+ text: str | None = None
44
+ tokens: list[int] | None = None
45
+
46
+ def __post_init__(self: "TextPart"):
47
+ self.type = "text"
48
+ if self.text is None and self.tokens is None:
49
+ raise ValueError("Either text or tokens must be provided")
50
+
51
+
52
+ @dataclass(kw_only=True)
53
+ class AudioPart(BasePart):
54
+ type = "audio"
55
+ features: torch.Tensor
56
+
57
+ def __post_init__(self: "AudioPart"):
58
+ self.type = "audio"
59
+ self.features = restore_ndarray(self.features, to_tensor=True)
60
+
61
+
62
+ @dataclass(kw_only=True)
63
+ class EncodedMessage:
64
+ tokens: torch.Tensor
65
+ labels: torch.Tensor
66
+ vq_mask_tokens: torch.Tensor | None = None
67
+ vq_mask_labels: torch.Tensor | None = None
68
+ vq_parts: list[torch.Tensor]
69
+ vq_require_losses: torch.Tensor | None = None
70
+ audio_parts: list[torch.Tensor]
71
+ audio_masks: torch.Tensor | None = None
72
+ metadata: dict | None = None
73
+
74
+
75
+ @dataclass
76
+ class ContentSequence:
77
+ """
78
+ Flexible sequence of content parts that supports interleaved multimodal format.
79
+ Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
80
+ """
81
+
82
+ parts: list[BasePart] = field(default_factory=list)
83
+ modality: Literal["text", "voice", "interleave"] | None = None
84
+ metadata: dict | None = None
85
+
86
+ def __init__(
87
+ self: "ContentSequence",
88
+ parts: list[BasePart | dict] | None = None,
89
+ modality: Literal["text", "voice", "interleave"] | None = None,
90
+ metadata: dict | None = None,
91
+ ):
92
+ self.modality = modality
93
+ self.metadata = metadata or {}
94
+
95
+ fixed_parts = []
96
+ for part in parts or []:
97
+ if isinstance(part, dict):
98
+ if part["type"] == "vq":
99
+ part = VQPart(**part)
100
+ elif part["type"] == "audio":
101
+ part = AudioPart(**part)
102
+ elif part["type"] == "text":
103
+ part = TextPart(**part)
104
+ else:
105
+ raise ValueError(f"Unsupported part type: {part['type']}")
106
+ fixed_parts.append(part)
107
+
108
+ self.parts = fixed_parts
109
+
110
+ # If modality is specified, add it at the beginning if it's not already there
111
+ if self.modality and not (
112
+ len(self.parts) > 0
113
+ and isinstance(self.parts[0], dict) is False
114
+ and isinstance(self.parts[0], TextPart)
115
+ and self.parts[0].text is not None
116
+ and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
117
+ ):
118
+ modality_token = MODALITY_TOKENS[self.modality]
119
+ self.parts.insert(0, TextPart(text=modality_token))
120
+
121
+ def append(
122
+ self: "ContentSequence",
123
+ part_or_parts: Union[BasePart, List[BasePart]],
124
+ add_end: bool = False,
125
+ speaker: Union[str, int] | None = None,
126
+ ):
127
+ """
128
+ Append a part or list of parts to the sequence.
129
+
130
+ Args:
131
+ part_or_parts: A single part or list of parts to add
132
+ add_end: Whether to add the IM_END_TOKEN after these parts
133
+ speaker: Optional speaker identifier (name or ID) to add before the parts
134
+ """
135
+ # Convert single part to list
136
+ parts_to_add = (
137
+ [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
138
+ )
139
+
140
+ # Add speaker token if specified
141
+ if speaker is not None:
142
+ speaker_token = f"<|speaker:{speaker}|>"
143
+ self.parts.append(TextPart(text=speaker_token))
144
+
145
+ # Add all the parts
146
+ self.parts.extend(parts_to_add)
147
+
148
+ # Add end token if requested
149
+ if add_end:
150
+ self.parts.append(
151
+ TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
152
+ )
153
+
154
+ def encode(
155
+ self: "ContentSequence",
156
+ tokenizer: FishTokenizer,
157
+ add_shift: bool = True,
158
+ ignore_loss_tokens: list[str] = [],
159
+ ) -> EncodedMessage:
160
+ """
161
+ Encode the sequence parts into tokens for the model.
162
+
163
+ Args:
164
+ tokenizer: The tokenizer to use
165
+ add_shift: Whether to shift tokens for next-token prediction
166
+ ignore_loss_tokens: List of token strings to ignore when calculating loss
167
+
168
+ Returns:
169
+ EncodedMessage with tensors ready for the model
170
+ """
171
+ all_tokens = []
172
+ all_labels = []
173
+
174
+ # Multi-modal elements
175
+ vq_parts = []
176
+ vq_masks = []
177
+ vq_require_losses = []
178
+
179
+ audio_parts = []
180
+ audio_masks = []
181
+
182
+ ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
183
+
184
+ for part in self.parts:
185
+ if isinstance(part, TextPart):
186
+ if part.tokens is None:
187
+ assert part.text is not None
188
+ tokens = tokenizer.encode(part.text)
189
+ else:
190
+ tokens = part.tokens
191
+
192
+ tokens = torch.tensor(tokens, dtype=torch.int)
193
+ elif isinstance(part, VQPart):
194
+ curr_codes = part.codes.clone().to(torch.int)
195
+ tokens = torch.tensor(
196
+ [
197
+ tokenizer.semantic_id_to_token_id[int(i.item())]
198
+ for i in curr_codes[0].int()
199
+ ],
200
+ dtype=torch.int,
201
+ )
202
+ vq_parts.append(curr_codes)
203
+ vq_require_losses.append(part.cal_loss)
204
+ else:
205
+ raise ValueError(f"Unsupported part type: {type(part)}")
206
+
207
+ all_tokens.append(tokens)
208
+
209
+ # Set masks for different part types
210
+ if isinstance(part, VQPart):
211
+ vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
212
+ audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
213
+ elif isinstance(part, AudioPart):
214
+ vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
215
+ audio_mask = torch.ones_like(tokens, dtype=torch.bool)
216
+ audio_mask[0] = False # Skip start token
217
+ audio_mask[-1] = False # Skip end token
218
+ audio_masks.append(audio_mask)
219
+ else:
220
+ vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
221
+ audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
222
+
223
+ # Set labels based on whether we want to calculate loss for this part
224
+ if part.cal_loss and not isinstance(part, AudioPart):
225
+ all_labels.append(tokens.clone())
226
+ else:
227
+ all_labels.append(torch.full_like(tokens, -100))
228
+
229
+ # Concatenate all tensors
230
+ tokens = torch.cat(all_tokens, dim=0)
231
+ labels = torch.cat(all_labels, dim=0)
232
+ vq_masks = torch.cat(vq_masks, dim=0)
233
+ audio_masks = torch.cat(audio_masks, dim=0)
234
+ vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
235
+
236
+ # Apply shift if needed for next-token prediction
237
+ vq_mask_tokens = vq_masks
238
+ vq_mask_labels = vq_masks
239
+
240
+ if add_shift:
241
+ tokens = tokens[:-1]
242
+ labels = labels[1:]
243
+ vq_masks = vq_masks[:-1]
244
+ vq_mask_tokens = vq_mask_tokens[:-1]
245
+ vq_mask_labels = vq_mask_labels[1:]
246
+ audio_masks = audio_masks[:-1]
247
+
248
+ # Ignore specified tokens
249
+ for i in ignore_loss_token_ids:
250
+ assert i != -100 and i is not None
251
+ labels[labels == i] = -100
252
+
253
+ assert tokens.dtype in [
254
+ torch.int,
255
+ torch.long,
256
+ ], f"Invalid dtype: {tokens.dtype}"
257
+
258
+ return EncodedMessage(
259
+ tokens=tokens,
260
+ labels=labels,
261
+ vq_parts=vq_parts,
262
+ vq_mask_tokens=vq_mask_tokens,
263
+ vq_mask_labels=vq_mask_labels,
264
+ vq_require_losses=vq_require_losses,
265
+ audio_parts=audio_parts,
266
+ audio_masks=audio_masks,
267
+ metadata=self.metadata,
268
+ )
269
+
270
+ def encode_for_inference(
271
+ self: "ContentSequence",
272
+ tokenizer: FishTokenizer,
273
+ num_codebooks: int,
274
+ ) -> torch.Tensor:
275
+ encoded = self.encode(tokenizer, add_shift=False)
276
+ tokens = encoded.tokens
277
+ values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
278
+ values[0] = tokens
279
+
280
+ if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
281
+ encoded.audio_parts is None or len(encoded.audio_parts) == 0
282
+ ):
283
+ return values
284
+
285
+ if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
286
+ vq_parts = encoded.vq_parts
287
+ vq_parts = torch.cat(vq_parts, dim=1)
288
+ values[0, encoded.vq_mask_tokens] = (
289
+ vq_parts[0] + tokenizer.semantic_begin_id
290
+ )
291
+ values[1:, encoded.vq_mask_tokens] = vq_parts
292
+
293
+ return values
294
+
295
+ def visualize(
296
+ self: "ContentSequence",
297
+ tokenizer: FishTokenizer,
298
+ ignore_loss_tokens: list[str] = [],
299
+ merge_semantic_tokens: bool = False,
300
+ ):
301
+ """
302
+ Visualize the encoded sequence with color-coded tokens.
303
+ Blue/cyan tokens contribute to loss, green tokens do not.
304
+ """
305
+ encoded = self.encode(
306
+ tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
307
+ )
308
+
309
+ # Colors for alternating tokens
310
+ colors = {
311
+ "blue": "\033[94m", # Light blue
312
+ "cyan": "\033[96m", # Cyan
313
+ "green": "\033[92m", # Light green
314
+ "dark_green": "\033[32m", # Dark green
315
+ }
316
+ blue_idx = 0
317
+ green_idx = 0
318
+
319
+ def print_in_blue(x):
320
+ nonlocal blue_idx
321
+ color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
322
+ print(f"{color}{x}\033[0m", end="")
323
+ blue_idx += 1
324
+
325
+ def print_in_green(x):
326
+ nonlocal green_idx
327
+ color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
328
+ print(f"{color}{x}\033[0m", end="")
329
+ green_idx += 1
330
+
331
+ def print_semantic_token(x, count):
332
+ val = f"[<|semantic|>x{count}]"
333
+ if x == -100:
334
+ print_in_green(val)
335
+ else:
336
+ print_in_blue(val)
337
+
338
+ count_semantic_tokens = 0
339
+ semantic_label = None
340
+
341
+ for tok, lab in zip(encoded.tokens, encoded.labels):
342
+ token_id = int(tok.item())
343
+
344
+ if merge_semantic_tokens:
345
+ if (
346
+ tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
347
+ and (semantic_label is None or semantic_label == lab)
348
+ ):
349
+ count_semantic_tokens += 1
350
+ semantic_label = lab
351
+ continue
352
+ elif count_semantic_tokens > 0:
353
+ print_semantic_token(semantic_label, count_semantic_tokens)
354
+ count_semantic_tokens = 0
355
+ semantic_label = None
356
+
357
+ val = tokenizer.decode([int(tok.item())])
358
+
359
+ if lab == -100:
360
+ print_in_green(val)
361
+ else:
362
+ print_in_blue(val)
363
+
364
+ if merge_semantic_tokens and count_semantic_tokens > 0:
365
+ print_semantic_token(semantic_label, count_semantic_tokens)
366
+
367
+ print()
fish_speech/i18n/README.md CHANGED
@@ -1,27 +1,27 @@
1
- ## i18n Folder Attribution
2
-
3
- The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4
-
5
- ### fish_speech/i18n/core.py
6
-
7
- **Related code from RVC:**
8
- [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9
-
10
- **Initial commit:**
11
- add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12
-
13
- **Initial author:**
14
- [@L4Ph](https://github.com/L4Ph)
15
-
16
- ### fish_speech/i18n/scan.py
17
-
18
- **Related code from RVC:**
19
- [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20
-
21
- **Initial commit:**
22
- File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23
-
24
- **Initial author:**
25
- [@towzeur](https://github.com/towzeur)
26
-
27
- We appreciate the contributions of the RVC project and its authors.
 
1
+ ## i18n Folder Attribution
2
+
3
+ The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4
+
5
+ ### fish_speech/i18n/core.py
6
+
7
+ **Related code from RVC:**
8
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9
+
10
+ **Initial commit:**
11
+ add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12
+
13
+ **Initial author:**
14
+ [@L4Ph](https://github.com/L4Ph)
15
+
16
+ ### fish_speech/i18n/scan.py
17
+
18
+ **Related code from RVC:**
19
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20
+
21
+ **Initial commit:**
22
+ File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23
+
24
+ **Initial author:**
25
+ [@towzeur](https://github.com/towzeur)
26
+
27
+ We appreciate the contributions of the RVC project and its authors.
fish_speech/i18n/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .core import i18n
-
-__all__ = ["i18n"]
+from .core import i18n
+
+__all__ = ["i18n"]
fish_speech/i18n/core.py CHANGED
@@ -1,40 +1,40 @@
1
- import json
2
- import locale
3
- from pathlib import Path
4
-
5
- I18N_FILE_PATH = Path(__file__).parent / "locale"
6
- DEFAULT_LANGUAGE = "en_US"
7
-
8
-
9
- def load_language_list(language):
10
- with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
11
- language_list = json.load(f)
12
-
13
- return language_list
14
-
15
-
16
- class I18nAuto:
17
- def __init__(self):
18
- i18n_file = Path(".locale")
19
-
20
- if i18n_file.exists():
21
- with open(i18n_file, "r", encoding="utf-8") as f:
22
- language = f.read().strip()
23
- else:
24
- # getlocale can't identify the system's language ((None, None))
25
- language = locale.getdefaultlocale()[0]
26
-
27
- if (I18N_FILE_PATH / f"{language}.json").exists() is False:
28
- language = DEFAULT_LANGUAGE
29
-
30
- self.language = language
31
- self.language_map = load_language_list(language)
32
-
33
- def __call__(self, key):
34
- return self.language_map.get(key, key)
35
-
36
- def __repr__(self):
37
- return "Use Language: " + self.language
38
-
39
-
40
- i18n = I18nAuto()
 
1
+ import json
2
+ import locale
3
+ from pathlib import Path
4
+
5
+ I18N_FILE_PATH = Path(__file__).parent / "locale"
6
+ DEFAULT_LANGUAGE = "en_US"
7
+
8
+
9
+ def load_language_list(language):
10
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
11
+ language_list = json.load(f)
12
+
13
+ return language_list
14
+
15
+
16
+ class I18nAuto:
17
+ def __init__(self):
18
+ i18n_file = Path(".locale")
19
+
20
+ if i18n_file.exists():
21
+ with open(i18n_file, "r", encoding="utf-8") as f:
22
+ language = f.read().strip()
23
+ else:
24
+ # getlocale can't identify the system's language ((None, None))
25
+ language = locale.getdefaultlocale()[0]
26
+
27
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
28
+ language = DEFAULT_LANGUAGE
29
+
30
+ self.language = language
31
+ self.language_map = load_language_list(language)
32
+
33
+ def __call__(self, key):
34
+ return self.language_map.get(key, key)
35
+
36
+ def __repr__(self):
37
+ return "Use Language: " + self.language
38
+
39
+
40
+ i18n = I18nAuto()
fish_speech/i18n/locale/en_US.json CHANGED
@@ -1,123 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
- "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
- "Add to Processing Area": "Add to Processing Area",
7
- "Added path successfully!": "Added path successfully!",
8
- "Advanced Config": "Advanced Config",
9
- "Base LLAMA Model": "Base LLAMA Model",
10
- "Batch Inference": "Batch Inference",
11
- "Batch Size": "Batch Size",
12
- "Changing with the Model Path": "Changing with the Model Path",
13
- "Chinese": "Chinese",
14
- "Compile Model": "Compile Model",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
- "Copy": "Copy",
17
- "Data Preprocessing": "Data Preprocessing",
18
- "Data Preprocessing Path": "Data Preprocessing Path",
19
- "Data Source": "Data Source",
20
- "Decoder Model Config": "Decoder Model Config",
21
- "Decoder Model Path": "Decoder Model Path",
22
- "Disabled": "Disabled",
23
- "Enable Reference Audio": "Enable Reference Audio",
24
- "English": "English",
25
- "Error Message": "Error Message",
26
- "File Preprocessing": "File Preprocessing",
27
- "Generate": "Generate",
28
- "Generated Audio": "Generated Audio",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
- "Infer interface is closed": "Infer interface is closed",
31
- "Inference Configuration": "Inference Configuration",
32
- "Inference Server Configuration": "Inference Server Configuration",
33
- "Inference Server Error": "Inference Server Error",
34
- "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
- "Initial Learning Rate": "Initial Learning Rate",
36
- "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
- "Input Text": "Input Text",
38
- "Invalid path: {}": "Invalid path: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
- "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
- "Japanese": "Japanese",
42
- "LLAMA Configuration": "LLAMA Configuration",
43
- "LLAMA Model Config": "LLAMA Model Config",
44
- "LLAMA Model Path": "LLAMA Model Path",
45
- "Labeling Device": "Labeling Device",
46
- "LoRA Model to be merged": "LoRA Model to be merged",
47
- "Maximum Audio Duration": "Maximum Audio Duration",
48
- "Maximum Length per Sample": "Maximum Length per Sample",
49
- "Maximum Training Steps": "Maximum Training Steps",
50
- "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
- "Merge": "Merge",
52
- "Merge LoRA": "Merge LoRA",
53
- "Merge successfully": "Merge successfully",
54
- "Minimum Audio Duration": "Minimum Audio Duration",
55
- "Model Output Path": "Model Output Path",
56
- "Model Size": "Model Size",
57
- "Move": "Move",
58
- "Move files successfully": "Move files successfully",
59
- "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
- "No selected options": "No selected options",
61
- "Number of Workers": "Number of Workers",
62
- "Open Inference Server": "Open Inference Server",
63
- "Open Labeler WebUI": "Open Labeler WebUI",
64
- "Open Tensorboard": "Open Tensorboard",
65
- "Opened labeler in browser": "Opened labeler in browser",
66
- "Optional Label Language": "Optional Label Language",
67
- "Optional online ver": "Optional online ver",
68
- "Output Path": "Output Path",
69
- "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
- "Precision": "Precision",
71
- "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
- "Put your text here.": "Put your text here.",
73
- "Reference Audio": "Reference Audio",
74
- "Reference Text": "Reference Text",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
- "Remove Selected Data": "Remove Selected Data",
77
- "Removed path successfully!": "Removed path successfully!",
78
- "Repetition Penalty": "Repetition Penalty",
79
- "Save model every n steps": "Save model every n steps",
80
- "Select LLAMA ckpt": "Select LLAMA ckpt",
81
- "Select VITS ckpt": "Select VITS ckpt",
82
- "Select VQGAN ckpt": "Select VQGAN ckpt",
83
- "Select source file processing method": "Select source file processing method",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
- "Selected: {}": "Selected: {}",
86
- "Speaker": "Speaker",
87
- "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
- "Start Training": "Start Training",
89
- "Streaming Audio": "Streaming Audio",
90
- "Streaming Generate": "Streaming Generate",
91
- "Tensorboard Host": "Tensorboard Host",
92
- "Tensorboard Log Path": "Tensorboard Log Path",
93
- "Tensorboard Port": "Tensorboard Port",
94
- "Tensorboard interface is closed": "Tensorboard interface is closed",
95
- "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
- "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
- "Training Configuration": "Training Configuration",
99
- "Training Error": "Training Error",
100
- "Training stopped": "Training stopped",
101
- "Type name of the speaker": "Type name of the speaker",
102
- "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
- "Use LoRA": "Use LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
- "Use filelist": "Use filelist",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
- "VITS Configuration": "VITS Configuration",
108
- "VQGAN Configuration": "VQGAN Configuration",
109
- "Validation Batch Size": "Validation Batch Size",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
- "WebUI Host": "WebUI Host",
113
- "WebUI Port": "WebUI Port",
114
- "Whisper Model": "Whisper Model",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
- "latest": "latest",
118
- "new": "new",
119
- "Realtime Transform Text": "Realtime Transform Text",
120
- "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
- "Text Normalization": "Text Normalization",
122
- "Select Example Audio": "Select Example Audio"
123
- }
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
+ "Add to Processing Area": "Add to Processing Area",
7
+ "Added path successfully!": "Added path successfully!",
8
+ "Advanced Config": "Advanced Config",
9
+ "Base LLAMA Model": "Base LLAMA Model",
10
+ "Batch Inference": "Batch Inference",
11
+ "Batch Size": "Batch Size",
12
+ "Changing with the Model Path": "Changing with the Model Path",
13
+ "Chinese": "Chinese",
14
+ "Compile Model": "Compile Model",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
+ "Copy": "Copy",
17
+ "Data Preprocessing": "Data Preprocessing",
18
+ "Data Preprocessing Path": "Data Preprocessing Path",
19
+ "Data Source": "Data Source",
20
+ "Decoder Model Config": "Decoder Model Config",
21
+ "Decoder Model Path": "Decoder Model Path",
22
+ "Disabled": "Disabled",
23
+ "Enable Reference Audio": "Enable Reference Audio",
24
+ "English": "English",
25
+ "Error Message": "Error Message",
26
+ "File Preprocessing": "File Preprocessing",
27
+ "Generate": "Generate",
28
+ "Generated Audio": "Generated Audio",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
+ "Infer interface is closed": "Infer interface is closed",
31
+ "Inference Configuration": "Inference Configuration",
32
+ "Inference Server Configuration": "Inference Server Configuration",
33
+ "Inference Server Error": "Inference Server Error",
34
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
+ "Initial Learning Rate": "Initial Learning Rate",
36
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
+ "Input Text": "Input Text",
38
+ "Invalid path: {}": "Invalid path: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
+ "Japanese": "Japanese",
42
+ "LLAMA Configuration": "LLAMA Configuration",
43
+ "LLAMA Model Config": "LLAMA Model Config",
44
+ "LLAMA Model Path": "LLAMA Model Path",
45
+ "Labeling Device": "Labeling Device",
46
+ "LoRA Model to be merged": "LoRA Model to be merged",
47
+ "Maximum Audio Duration": "Maximum Audio Duration",
48
+ "Maximum Length per Sample": "Maximum Length per Sample",
49
+ "Maximum Training Steps": "Maximum Training Steps",
50
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
+ "Merge": "Merge",
52
+ "Merge LoRA": "Merge LoRA",
53
+ "Merge successfully": "Merge successfully",
54
+ "Minimum Audio Duration": "Minimum Audio Duration",
55
+ "Model Output Path": "Model Output Path",
56
+ "Model Size": "Model Size",
57
+ "Move": "Move",
58
+ "Move files successfully": "Move files successfully",
59
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
+ "No selected options": "No selected options",
61
+ "Number of Workers": "Number of Workers",
62
+ "Open Inference Server": "Open Inference Server",
63
+ "Open Labeler WebUI": "Open Labeler WebUI",
64
+ "Open Tensorboard": "Open Tensorboard",
65
+ "Opened labeler in browser": "Opened labeler in browser",
66
+ "Optional Label Language": "Optional Label Language",
67
+ "Optional online ver": "Optional online ver",
68
+ "Output Path": "Output Path",
69
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
+ "Precision": "Precision",
71
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
+ "Put your text here.": "Put your text here.",
73
+ "Reference Audio": "Reference Audio",
74
+ "Reference Text": "Reference Text",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
+ "Remove Selected Data": "Remove Selected Data",
77
+ "Removed path successfully!": "Removed path successfully!",
78
+ "Repetition Penalty": "Repetition Penalty",
79
+ "Save model every n steps": "Save model every n steps",
80
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
81
+ "Select VITS ckpt": "Select VITS ckpt",
82
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
83
+ "Select source file processing method": "Select source file processing method",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
+ "Selected: {}": "Selected: {}",
86
+ "Speaker": "Speaker",
87
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
+ "Start Training": "Start Training",
89
+ "Streaming Audio": "Streaming Audio",
90
+ "Streaming Generate": "Streaming Generate",
91
+ "Tensorboard Host": "Tensorboard Host",
92
+ "Tensorboard Log Path": "Tensorboard Log Path",
93
+ "Tensorboard Port": "Tensorboard Port",
94
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
95
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
+ "Training Configuration": "Training Configuration",
99
+ "Training Error": "Training Error",
100
+ "Training stopped": "Training stopped",
101
+ "Type name of the speaker": "Type name of the speaker",
102
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
+ "Use LoRA": "Use LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
+ "Use filelist": "Use filelist",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
+ "VITS Configuration": "VITS Configuration",
108
+ "VQGAN Configuration": "VQGAN Configuration",
109
+ "Validation Batch Size": "Validation Batch Size",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
+ "WebUI Host": "WebUI Host",
113
+ "WebUI Port": "WebUI Port",
114
+ "Whisper Model": "Whisper Model",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
+ "latest": "latest",
118
+ "new": "new",
119
+ "Realtime Transform Text": "Realtime Transform Text",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
+ "Text Normalization": "Text Normalization",
122
+ "Select Example Audio": "Select Example Audio"
123
+ }
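Each locale file maps the English UI strings (used verbatim as lookup keys in the WebUI code) to their translations; because `I18nAuto.__call__` falls back to the key, a missing entry simply renders in English. A hypothetical sketch of registering an extra locale under the same layout:

```python
# Hypothetical example: fr_FR and the strings below are made up for
# illustration; only the file layout and lookup behavior mirror this commit.
import json
from pathlib import Path

locale_dir = Path("fish_speech/i18n/locale")
fr_FR = {
    "Generate": "Générer",
    "Input Text": "Texte d'entrée",
    # ...remaining keys; omitted keys fall back to the English key text.
}
(locale_dir / "fr_FR.json").write_text(
    json.dumps(fr_FR, ensure_ascii=False, indent=2), encoding="utf-8"
)
# Select it for the next run (read by I18nAuto at import time).
Path(".locale").write_text("fr_FR", encoding="utf-8")
```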
fish_speech/i18n/locale/es_ES.json CHANGED
@@ -1,123 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
- "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
- "Add to Processing Area": "Agregar al Área de Procesamiento",
7
- "Added path successfully!": "¡Ruta agregada exitosamente!",
8
- "Advanced Config": "Configuración Avanzada",
9
- "Base LLAMA Model": "Modelo Base LLAMA",
10
- "Batch Inference": "Inferencia por Lote",
11
- "Batch Size": "Tamaño del Lote",
12
- "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
- "Chinese": "Chino",
14
- "Compile Model": "Compilar Modelo",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
- "Copy": "Copiar",
17
- "Data Preprocessing": "Preprocesamiento de Datos",
18
- "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
- "Data Source": "Fuente de Datos",
20
- "Decoder Model Config": "Configuración del modelo decodificador",
21
- "Decoder Model Path": "Ruta del modelo decodificador",
22
- "Disabled": "Desactivado",
23
- "Enable Reference Audio": "Habilitar Audio de Referencia",
24
- "English": "Inglés",
25
- "Error Message": "Mensaje de Error",
26
- "File Preprocessing": "Preprocesamiento de Archivos",
27
- "Generate": "Generar",
28
- "Generated Audio": "Audio Generado",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
- "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
- "Inference Configuration": "Configuración de Inferencia",
32
- "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
- "Inference Server Error": "Error del Servidor de Inferencia",
34
- "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
- "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
- "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
- "Input Text": "Texto de Entrada",
38
- "Invalid path: {}": "Ruta inválida: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
- "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
- "Japanese": "Japonés",
42
- "LLAMA Configuration": "Configuración de LLAMA",
43
- "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
- "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
- "Labeling Device": "Dispositivo de Etiquetado",
46
- "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
- "Maximum Audio Duration": "Duración máxima de audio",
48
- "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
- "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
- "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
- "Merge": "Fusionar",
52
- "Merge LoRA": "Fusionar LoRA",
53
- "Merge successfully": "Fusionado exitosamente",
54
- "Minimum Audio Duration": "Duración mínima de audio",
55
- "Model Output Path": "Ruta de Salida del Modelo",
56
- "Model Size": "Tamaño del Modelo",
57
- "Move": "Mover",
58
- "Move files successfully": "Archivos movidos exitosamente",
59
- "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
- "No selected options": "No hay opciones seleccionadas",
61
- "Number of Workers": "Número de Trabajadores",
62
- "Open Inference Server": "Abrir Servidor de Inferencia",
63
- "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
- "Open Tensorboard": "Abrir Tensorboard",
65
- "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
- "Optional Label Language": "Idioma de Etiquetado Opcional",
67
- "Optional online ver": "Ver en línea opcional",
68
- "Output Path": "Ruta de Salida",
69
- "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
70
- "Precision": "Precisión",
71
- "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
- "Put your text here.": "Ponga su texto aquí.",
73
- "Reference Audio": "Audio de Referencia",
74
- "Reference Text": "Texto de Referencia",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
- "Remove Selected Data": "Eliminar Datos Seleccionados",
77
- "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
- "Repetition Penalty": "Penalización por Repetición",
79
- "Save model every n steps": "Guardar modelo cada n pasos",
80
- "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
- "Select VITS ckpt": "Seleccionar punto de control VITS",
82
- "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
- "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
- "Selected: {}": "Seleccionado: {}",
86
- "Speaker": "Hablante",
87
- "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
- "Start Training": "Iniciar Entrenamiento",
89
- "Streaming Audio": "transmisión de audio",
90
- "Streaming Generate": "síntesis en flujo",
91
- "Tensorboard Host": "Host de Tensorboard",
92
- "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
- "Tensorboard Port": "Puerto de Tensorboard",
94
- "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
- "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
- "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
- "Training Configuration": "Configuración de Entrenamiento",
99
- "Training Error": "Error de Entrenamiento",
100
- "Training stopped": "Entrenamiento detenido",
101
- "Type name of the speaker": "Escriba el nombre del hablante",
102
- "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
- "Use LoRA": "Usar LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
- "Use filelist": "Usar lista de archivos",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
- "VITS Configuration": "Configuración de VITS",
108
- "VQGAN Configuration": "Configuración de VQGAN",
109
- "Validation Batch Size": "Tamaño del Lote de Validación",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
- "WebUI Host": "Host de WebUI",
113
- "WebUI Port": "Puerto de WebUI",
114
- "Whisper Model": "Modelo Whisper",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
- "latest": "más reciente",
118
- "new": "nuevo",
119
- "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
- "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
- "Text Normalization": "Normalización de Texto",
122
- "Select Example Audio": "Selecionar áudio de exemplo"
123
- }
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
7
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
8
+ "Advanced Config": "Configuración Avanzada",
9
+ "Base LLAMA Model": "Modelo Base LLAMA",
10
+ "Batch Inference": "Inferencia por Lote",
11
+ "Batch Size": "Tamaño del Lote",
12
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
+ "Chinese": "Chino",
14
+ "Compile Model": "Compilar Modelo",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
+ "Copy": "Copiar",
17
+ "Data Preprocessing": "Preprocesamiento de Datos",
18
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
+ "Data Source": "Fuente de Datos",
20
+ "Decoder Model Config": "Configuración del modelo decodificador",
21
+ "Decoder Model Path": "Ruta del modelo decodificador",
22
+ "Disabled": "Desactivado",
23
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
24
+ "English": "Inglés",
25
+ "Error Message": "Mensaje de Error",
26
+ "File Preprocessing": "Preprocesamiento de Archivos",
27
+ "Generate": "Generar",
28
+ "Generated Audio": "Audio Generado",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
+ "Inference Configuration": "Configuración de Inferencia",
32
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
+ "Inference Server Error": "Error del Servidor de Inferencia",
34
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
+ "Input Text": "Texto de Entrada",
38
+ "Invalid path: {}": "Ruta inválida: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
+ "Japanese": "Japonés",
42
+ "LLAMA Configuration": "Configuración de LLAMA",
43
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
+ "Labeling Device": "Dispositivo de Etiquetado",
46
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
+ "Maximum Audio Duration": "Duración máxima de audio",
48
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
+ "Merge": "Fusionar",
52
+ "Merge LoRA": "Fusionar LoRA",
53
+ "Merge successfully": "Fusionado exitosamente",
54
+ "Minimum Audio Duration": "Duración mínima de audio",
55
+ "Model Output Path": "Ruta de Salida del Modelo",
56
+ "Model Size": "Tamaño del Modelo",
57
+ "Move": "Mover",
58
+ "Move files successfully": "Archivos movidos exitosamente",
59
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
+ "No selected options": "No hay opciones seleccionadas",
61
+ "Number of Workers": "Número de Trabajadores",
62
+ "Open Inference Server": "Abrir Servidor de Inferencia",
63
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
+ "Open Tensorboard": "Abrir Tensorboard",
65
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
67
+ "Optional online ver": "Ver en línea opcional",
68
+ "Output Path": "Ruta de Salida",
69
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
70
+ "Precision": "Precisión",
71
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
+ "Put your text here.": "Ponga su texto aquí.",
73
+ "Reference Audio": "Audio de Referencia",
74
+ "Reference Text": "Texto de Referencia",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
77
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
+ "Repetition Penalty": "Penalización por Repetición",
79
+ "Save model every n steps": "Guardar modelo cada n pasos",
80
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
82
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
+ "Selected: {}": "Seleccionado: {}",
86
+ "Speaker": "Hablante",
87
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
+ "Start Training": "Iniciar Entrenamiento",
89
+ "Streaming Audio": "transmisión de audio",
90
+ "Streaming Generate": "síntesis en flujo",
91
+ "Tensorboard Host": "Host de Tensorboard",
92
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
+ "Tensorboard Port": "Puerto de Tensorboard",
94
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
+ "Training Configuration": "Configuración de Entrenamiento",
99
+ "Training Error": "Error de Entrenamiento",
100
+ "Training stopped": "Entrenamiento detenido",
101
+ "Type name of the speaker": "Escriba el nombre del hablante",
102
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
+ "Use LoRA": "Usar LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
+ "Use filelist": "Usar lista de archivos",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
+ "VITS Configuration": "Configuración de VITS",
108
+ "VQGAN Configuration": "Configuración de VQGAN",
109
+ "Validation Batch Size": "Tamaño del Lote de Validación",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
+ "WebUI Host": "Host de WebUI",
113
+ "WebUI Port": "Puerto de WebUI",
114
+ "Whisper Model": "Modelo Whisper",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
+ "latest": "más reciente",
118
+ "new": "nuevo",
119
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
+ "Text Normalization": "Normalización de Texto",
122
+ "Select Example Audio": "Selecionar áudio de exemplo"
123
+ }
fish_speech/i18n/locale/ja_JP.json CHANGED
@@ -1,123 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
5
- "Accumulate Gradient Batches": "勾配バッチの累積",
6
- "Add to Processing Area": "処理エリアに追加",
7
- "Added path successfully!": "パスの追加に成功しました!",
8
- "Advanced Config": "詳細設定",
9
- "Base LLAMA Model": "基本LLAMAモデル",
10
- "Batch Inference": "バッチ推論",
11
- "Batch Size": "バッチサイズ",
12
- "Changing with the Model Path": "モデルのパスに伴って変化する",
13
- "Chinese": "中国語",
14
- "Compile Model": "モデルのコンパイル",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
- "Copy": "コピー",
17
- "Data Preprocessing": "データ前処理",
18
- "Data Preprocessing Path": "データ前処理パス",
19
- "Data Source": "データソース",
20
- "Decoder Model Config": "デコーダーモデルの構成",
21
- "Decoder Model Path": "デコーダーモデルのパス",
22
- "Disabled": "無効",
23
- "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
- "English": "英語",
25
- "Error Message": "エラーメッセージ",
26
- "File Preprocessing": "文書前处理",
27
- "Generate": "生成",
28
- "Generated Audio": "生成されたオーディオ",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
- "Infer interface is closed": "推論インターフェースが閉じられています",
31
- "Inference Configuration": "推論設定",
32
- "Inference Server Configuration": "推論サーバー設定",
33
- "Inference Server Error": "推論サーバーエラー",
34
- "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
- "Initial Learning Rate": "初期学習率",
36
- "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
- "Input Text": "入力テキスト",
38
- "Invalid path: {}": "無効なパス: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
- "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
- "Japanese": "日本語",
42
- "LLAMA Configuration": "LLAMA設定",
43
- "LLAMA Model Config": "LLAMAモデル設定",
44
- "LLAMA Model Path": "LLAMAモデルパス",
45
- "Labeling Device": "ラベリングデバイス",
46
- "LoRA Model to be merged": "マージするLoRAモデル",
47
- "Maximum Audio Duration": "最大オーディオの長さ",
48
- "Maximum Length per Sample": "サンプルあたりの最大長",
49
- "Maximum Training Steps": "最大トレーニングステップ数",
50
- "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
- "Merge": "マージ",
52
- "Merge LoRA": "LoRAのマージ",
53
- "Merge successfully": "マージに成功しました",
54
- "Minimum Audio Duration": "最小オーディオの長さ",
55
- "Model Output Path": "モデル出力パス",
56
- "Model Size": "モデルサイズ",
57
- "Move": "移動",
58
- "Move files successfully": "ファイルの移動に成功しました",
59
- "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
- "No selected options": "選択されたオプションはありません",
61
- "Number of Workers": "ワーカー数",
62
- "Open Inference Server": "推論サーバーを開く",
63
- "Open Labeler WebUI": "ラベラーWebUIを開く",
64
- "Open Tensorboard": "Tensorboardを開く",
65
- "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
- "Optional Label Language": "オプションのラベル言語",
67
- "Optional online ver": "オプションのオンラインバージョン",
68
- "Output Path": "出力パス",
69
- "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
- "Precision": "精度",
71
- "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
- "Put your text here.": "ここにテキスト���入力してください。",
73
- "Reference Audio": "リファレンスオーディオ",
74
- "Reference Text": "リファレンステキスト",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
- "Remove Selected Data": "選択したデータを削除",
77
- "Removed path successfully!": "パスの削除に成功しました!",
78
- "Repetition Penalty": "反復ペナルティ",
79
- "Save model every n steps": "nステップごとにモデルを保存",
80
- "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
81
- "Select VITS ckpt": "VITS チェックポイントを選択",
82
- "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
- "Select source file processing method": "ソースファイルの処理方法を選択",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
- "Selected: {}": "選択済み: {}",
86
- "Speaker": "話者",
87
- "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
- "Start Training": "トレーニング開始",
89
- "Streaming Audio": "ストリーミングオーディオ",
90
- "Streaming Generate": "ストリーミング合成",
91
- "Tensorboard Host": "Tensorboardホスト",
92
- "Tensorboard Log Path": "Tensorboardログパス",
93
- "Tensorboard Port": "Tensorboardポート",
94
- "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
- "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
- "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
- "Training Configuration": "トレーニング設定",
99
- "Training Error": "トレーニングエラー",
100
- "Training stopped": "トレーニングが停止しました",
101
- "Type name of the speaker": "話者の名前を入力",
102
- "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
- "Use LoRA": "LoRAを使用",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
- "Use filelist": "ファイルリストを使用",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
- "VITS Configuration": "VITS の構成",
108
- "VQGAN Configuration": "VQGAN の構成",
109
- "Validation Batch Size": "検証バッチサイズ",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
- "WebUI Host": "WebUIホスト",
113
- "WebUI Port": "WebUIポート",
114
- "Whisper Model": "Whisperモデル",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
- "latest": "最新",
118
- "new": "新規",
119
- "Realtime Transform Text": "リアルタイム変換テキスト",
120
- "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
- "Text Normalization": "テキスト正規化",
122
- "Select Example Audio": "サンプル音声を選択"
123
- }
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成��デル。",
5
+ "Accumulate Gradient Batches": "勾配バッチの累積",
6
+ "Add to Processing Area": "処理エリアに追加",
7
+ "Added path successfully!": "パスの追加に成功しました!",
8
+ "Advanced Config": "詳細設定",
9
+ "Base LLAMA Model": "基本LLAMAモデル",
10
+ "Batch Inference": "バッチ推論",
11
+ "Batch Size": "バッチサイズ",
12
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
13
+ "Chinese": "中国語",
14
+ "Compile Model": "モデルのコンパイル",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
+ "Copy": "コピー",
17
+ "Data Preprocessing": "データ前処理",
18
+ "Data Preprocessing Path": "データ前処理パス",
19
+ "Data Source": "データソース",
20
+ "Decoder Model Config": "デコーダーモデルの構成",
21
+ "Decoder Model Path": "デコーダーモデルのパス",
22
+ "Disabled": "無効",
23
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
+ "English": "英語",
25
+ "Error Message": "エラーメッセージ",
26
+ "File Preprocessing": "文書前处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "生成されたオーディオ",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
+ "Infer interface is closed": "推論インターフェースが閉じられています",
31
+ "Inference Configuration": "推論設定",
32
+ "Inference Server Configuration": "推論サーバー設定",
33
+ "Inference Server Error": "推論サーバーエラー",
34
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
+ "Initial Learning Rate": "初期学習率",
36
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
+ "Input Text": "入力テキスト",
38
+ "Invalid path: {}": "無効なパス: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
+ "Japanese": "日本語",
42
+ "LLAMA Configuration": "LLAMA設定",
43
+ "LLAMA Model Config": "LLAMAモデル設定",
44
+ "LLAMA Model Path": "LLAMAモデルパス",
45
+ "Labeling Device": "ラベリングデバイス",
46
+ "LoRA Model to be merged": "マージするLoRAモデル",
47
+ "Maximum Audio Duration": "最大オーディオの長さ",
48
+ "Maximum Length per Sample": "サンプルあたりの最大長",
49
+ "Maximum Training Steps": "最大トレーニングステップ数",
50
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
+ "Merge": "マージ",
52
+ "Merge LoRA": "LoRAのマージ",
53
+ "Merge successfully": "マージに成功しました",
54
+ "Minimum Audio Duration": "最小オーディオの長さ",
55
+ "Model Output Path": "モデル出力パス",
56
+ "Model Size": "モデルサイズ",
57
+ "Move": "移動",
58
+ "Move files successfully": "ファイルの移動に成功しました",
59
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
+ "No selected options": "選択されたオプションはありません",
61
+ "Number of Workers": "ワーカー数",
62
+ "Open Inference Server": "推論サーバーを開く",
63
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
64
+ "Open Tensorboard": "Tensorboardを開く",
65
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
+ "Optional Label Language": "オプションのラベル言語",
67
+ "Optional online ver": "オプションのオンラインバージョン",
68
+ "Output Path": "出力パス",
69
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
+ "Put your text here.": "ここにテキストを入力してください。",
73
+ "Reference Audio": "リファレンスオーディオ",
74
+ "Reference Text": "リファレンステキスト",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
+ "Remove Selected Data": "選択したデータを削除",
77
+ "Removed path successfully!": "パスの削除に成功しました!",
78
+ "Repetition Penalty": "反復ペナルティ",
79
+ "Save model every n steps": "nステップごとにモデルを保存",
80
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
81
+ "Select VITS ckpt": "VITS チェックポイントを選択",
82
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
+ "Select source file processing method": "ソースファイルの処理方法を選択",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
+ "Selected: {}": "選択済み: {}",
86
+ "Speaker": "話者",
87
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
+ "Start Training": "トレーニング開始",
89
+ "Streaming Audio": "ストリーミングオーディオ",
90
+ "Streaming Generate": "ストリーミング合成",
91
+ "Tensorboard Host": "Tensorboardホスト",
92
+ "Tensorboard Log Path": "Tensorboardログパス",
93
+ "Tensorboard Port": "Tensorboardポート",
94
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
+ "Training Configuration": "トレーニング設定",
99
+ "Training Error": "トレーニングエラー",
100
+ "Training stopped": "トレーニングが停止しました",
101
+ "Type name of the speaker": "話者の名前を入力",
102
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
+ "Use LoRA": "LoRAを使用",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
+ "Use filelist": "ファイルリストを使用",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
+ "VITS Configuration": "VITS の構成",
108
+ "VQGAN Configuration": "VQGAN の構成",
109
+ "Validation Batch Size": "検証バッチサイズ",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
+ "WebUI Host": "WebUIホスト",
113
+ "WebUI Port": "WebUIポート",
114
+ "Whisper Model": "Whisperモデル",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
+ "latest": "最新",
118
+ "new": "新規",
119
+ "Realtime Transform Text": "リアルタイム変換テキスト",
120
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
+ "Text Normalization": "テキスト正規化",
122
+ "Select Example Audio": "サンプル音声を選択"
123
+ }
fish_speech/i18n/locale/ko_KR.json CHANGED
@@ -1,123 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
5
- "Accumulate Gradient Batches": "그라디언트 배치 누적",
6
- "Add to Processing Area": "처리 영역에 추가",
7
- "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
8
- "Advanced Config": "고급 설정",
9
- "Base LLAMA Model": "기본 LLAMA 모델",
10
- "Batch Inference": "배치 추론",
11
- "Batch Size": "배치 크기",
12
- "Changing with the Model Path": "모델 경로에 따라 변경 중",
13
- "Chinese": "중국어",
14
- "Compile Model": "모델 컴파일",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
16
- "Copy": "복사",
17
- "Data Preprocessing": "데이터 전처리",
18
- "Data Preprocessing Path": "데이터 전처리 경로",
19
- "Data Source": "데이터 소스",
20
- "Decoder Model Config": "디코더 모델 설정",
21
- "Decoder Model Path": "디코더 모델 경로",
22
- "Disabled": "비활성화 됨",
23
- "Enable Reference Audio": "참고 음성 활성화",
24
- "English": "영어",
25
- "Error Message": "오류 메시지",
26
- "File Preprocessing": "파일 전처리",
27
- "Generate": "생성",
28
- "Generated Audio": "생성된 오디오",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
30
- "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
31
- "Inference Configuration": "추론 설정",
32
- "Inference Server Configuration": "추론 서버 설정",
33
- "Inference Server Error": "추론 서버 오류",
34
- "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
35
- "Initial Learning Rate": "초기 학습률",
36
- "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
37
- "Input Text": "입력 텍스트",
38
- "Invalid path: {}": "유효하지 않은 경로: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
40
- "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
41
- "Japanese": "일본어",
42
- "LLAMA Configuration": "LLAMA 설정",
43
- "LLAMA Model Config": "LLAMA 모델 설정",
44
- "LLAMA Model Path": "LLAMA 모델 경로",
45
- "Labeling Device": "라벨링 장치",
46
- "LoRA Model to be merged": "병합할 LoRA 모델",
47
- "Maximum Audio Duration": "최대 오디오 길이",
48
- "Maximum Length per Sample": "샘플당 최대 길이",
49
- "Maximum Training Steps": "최대 학습 단계",
50
- "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
51
- "Merge": "병합",
52
- "Merge LoRA": "LoRA 병합",
53
- "Merge successfully": "성공적으로 병합 되었습니다.",
54
- "Minimum Audio Duration": "최소 오디오 길이",
55
- "Model Output Path": "모델 출력 경로",
56
- "Model Size": "모델 크기",
57
- "Move": "이동",
58
- "Move files successfully": "파일이 성공적으로 이동되었습니다.",
59
- "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
60
- "No selected options": "옵션이 선택되지 않았습니다.",
61
- "Number of Workers": "작업자 수",
62
- "Open Inference Server": "추론 서버 열기",
63
- "Open Labeler WebUI": "라벨러 WebUI 열기",
64
- "Open Tensorboard": "Tensorboard 열기",
65
- "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
66
- "Optional Label Language": "선택적 라벨 언어",
67
- "Optional online ver": "온라인 버전 선택",
68
- "Output Path": "출력 경로",
69
- "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
70
- "Precision": "정밀도",
71
- "Probability of applying Speaker Condition": "화자 조건 적용 확률",
72
- "Put your text here.": "여기에 텍스트를 입력하세요.",
73
- "Reference Audio": "참고 오디오",
74
- "Reference Text": "참고 텍스트",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
76
- "Remove Selected Data": "선택한 데이터 제거",
77
- "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
78
- "Repetition Penalty": "반복 패널티",
79
- "Save model every n steps": "n 단계마다 모델 저장",
80
- "Select LLAMA ckpt": "LLAMA ckpt 선택",
81
- "Select VITS ckpt": "VITS ckpt 선택",
82
- "Select VQGAN ckpt": "VQGAN ckpt 선택",
83
- "Select source file processing method": "소스 파일 처리 방법 선택",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
85
- "Selected: {}": "선택됨: {}",
86
- "Speaker": "화자",
87
- "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
88
- "Start Training": "학습 시작",
89
- "Streaming Audio": "스트리밍 오디오",
90
- "Streaming Generate": "스트리밍 생성",
91
- "Tensorboard Host": "Tensorboard 호스트",
92
- "Tensorboard Log Path": "Tensorboard 로그 경로",
93
- "Tensorboard Port": "Tensorboard 포트",
94
- "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
95
- "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
96
- "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
98
- "Training Configuration": "학습 설정",
99
- "Training Error": "학습 오류",
100
- "Training stopped": "학습이 중지되었습니다.",
101
- "Type name of the speaker": "화자의 이름을 입력하세요.",
102
- "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
103
- "Use LoRA": "LoRA 사용",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
105
- "Use filelist": "파일 목록 사용",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
107
- "VITS Configuration": "VITS 설정",
108
- "VQGAN Configuration": "VQGAN 설정",
109
- "Validation Batch Size": "검증 배치 크기",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
112
- "WebUI Host": "WebUI 호스트",
113
- "WebUI Port": "WebUI 포트",
114
- "Whisper Model": "Whisper 모델",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
117
- "latest": "최신",
118
- "new": "새로운",
119
- "Realtime Transform Text": "실시간 텍스트 변환",
120
- "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
121
- "Text Normalization": "텍스트 정규화",
122
- "Select Example Audio": "예시 오디오 선택"
123
- }
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
5
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
6
+ "Add to Processing Area": "처리 영역에 추가",
7
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
8
+ "Advanced Config": "고급 설정",
9
+ "Base LLAMA Model": "기본 LLAMA 모델",
10
+ "Batch Inference": "배치 추론",
11
+ "Batch Size": "배치 크기",
12
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
13
+ "Chinese": "중국어",
14
+ "Compile Model": "모델 컴파일",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
16
+ "Copy": "복사",
17
+ "Data Preprocessing": "데이터 전처리",
18
+ "Data Preprocessing Path": "데이터 전처리 경로",
19
+ "Data Source": "데이터 소스",
20
+ "Decoder Model Config": "디코더 모델 설정",
21
+ "Decoder Model Path": "디코더 모델 경로",
22
+ "Disabled": "비활성화 됨",
23
+ "Enable Reference Audio": "참고 음성 활성화",
24
+ "English": "영어",
25
+ "Error Message": "오류 메시지",
26
+ "File Preprocessing": "파일 전처리",
27
+ "Generate": "생성",
28
+ "Generated Audio": "생성된 오디오",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
30
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
31
+ "Inference Configuration": "추론 설정",
32
+ "Inference Server Configuration": "추론 서버 설정",
33
+ "Inference Server Error": "추론 서버 오류",
34
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
35
+ "Initial Learning Rate": "초기 학습률",
36
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
37
+ "Input Text": "입력 텍스트",
38
+ "Invalid path: {}": "유효하지 않은 경로: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
40
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
41
+ "Japanese": "일본어",
42
+ "LLAMA Configuration": "LLAMA 설정",
43
+ "LLAMA Model Config": "LLAMA 모델 설정",
44
+ "LLAMA Model Path": "LLAMA 모델 경로",
45
+ "Labeling Device": "라벨링 장치",
46
+ "LoRA Model to be merged": "병합할 LoRA 모델",
47
+ "Maximum Audio Duration": "최대 오디오 길이",
48
+ "Maximum Length per Sample": "샘플당 최대 길이",
49
+ "Maximum Training Steps": "최대 학습 단계",
50
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
51
+ "Merge": "병합",
52
+ "Merge LoRA": "LoRA 병합",
53
+ "Merge successfully": "성공적으로 병합 되었습니다.",
54
+ "Minimum Audio Duration": "최소 오디오 길이",
55
+ "Model Output Path": "모델 출력 경로",
56
+ "Model Size": "모델 크기",
57
+ "Move": "이동",
58
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
59
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
60
+ "No selected options": "옵션이 선택되지 않았습니다.",
61
+ "Number of Workers": "작업자 수",
62
+ "Open Inference Server": "추론 서버 열기",
63
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
64
+ "Open Tensorboard": "Tensorboard 열기",
65
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
66
+ "Optional Label Language": "선택적 라벨 언어",
67
+ "Optional online ver": "온라인 버전 선택",
68
+ "Output Path": "출력 경로",
69
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
70
+ "Precision": "정밀도",
71
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
72
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
73
+ "Reference Audio": "참고 오디오",
74
+ "Reference Text": "참고 텍스트",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
76
+ "Remove Selected Data": "선택한 데이터 제거",
77
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
78
+ "Repetition Penalty": "반복 패널티",
79
+ "Save model every n steps": "n 단계마다 모델 저장",
80
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
81
+ "Select VITS ckpt": "VITS ckpt 선택",
82
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
83
+ "Select source file processing method": "소스 파일 처리 방법 선택",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
85
+ "Selected: {}": "선택됨: {}",
86
+ "Speaker": "화자",
87
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
88
+ "Start Training": "학습 시작",
89
+ "Streaming Audio": "스트리밍 오디오",
90
+ "Streaming Generate": "스트리밍 생성",
91
+ "Tensorboard Host": "Tensorboard 호스트",
92
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
93
+ "Tensorboard Port": "Tensorboard 포트",
94
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
96
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
98
+ "Training Configuration": "학습 설정",
99
+ "Training Error": "학습 오류",
100
+ "Training stopped": "학습이 중지되었습니다.",
101
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
102
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
103
+ "Use LoRA": "LoRA 사용",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
105
+ "Use filelist": "파일 목록 사용",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
107
+ "VITS Configuration": "VITS 설정",
108
+ "VQGAN Configuration": "VQGAN 설정",
109
+ "Validation Batch Size": "검증 배치 크기",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
112
+ "WebUI Host": "WebUI 호스트",
113
+ "WebUI Port": "WebUI 포트",
114
+ "Whisper Model": "Whisper 모델",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
117
+ "latest": "최신",
118
+ "new": "새로운",
119
+ "Realtime Transform Text": "실시간 텍스트 변환",
120
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
121
+ "Text Normalization": "텍스트 정규화",
122
+ "Select Example Audio": "예시 오디오 선택"
123
+ }
fish_speech/i18n/locale/pt_BR.json CHANGED
@@ -1,133 +1,133 @@
1
- {
2
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
- "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
- "Add to Processing Area": "Adicionar à Área de Processamento",
6
- "Added path successfully!": "Caminho adicionado com sucesso!",
7
- "Advanced Config": "Configuração Avançada",
8
- "Base LLAMA Model": "Modelo LLAMA Base",
9
- "Batch Inference": "Inferência em Lote",
10
- "Batch Size": "Tamanho do Lote",
11
- "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
-
13
- "Compile Model": "Compilar Modelo",
14
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
- "Copy": "Copiar",
16
- "Data Preprocessing": "Pré-processamento de Dados",
17
- "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
- "Data Source": "Fonte de Dados",
19
- "Decoder Model Config": "Configuração do Modelo Decodificador",
20
- "Decoder Model Path": "Caminho do Modelo Decodificador",
21
- "Disabled": "Desativado",
22
- "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
- "Enable Reference Audio": "Habilitar Áudio de Referência",
24
- "English": "Inglês",
25
- "Japanese": "Japonês",
26
- "Chinese": "Chinês",
27
- "Portuguese": "Português",
28
- "Spanish": "Espanhol",
29
- "Error Message": "Mensagem de Erro",
30
- "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
- "File Preprocessing": "Pré-processamento de Arquivos",
32
- "Generate": "Gerar",
33
- "Generated Audio": "Áudio Gerado",
34
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
- "Infer interface is closed": "A interface de inferência foi fechada",
36
- "Inference Configuration": "Configuração de Inferência",
37
- "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
- "Inference Server Error": "Erro do Servidor de Inferência",
39
- "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
- "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
- "Initial Prompt": "Prompt Inicial",
42
- "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
- "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
- "Input Text": "Texto de Entrada",
45
- "Invalid path: {}": "Caminho inválido: {}",
46
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
- "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
- "LLAMA Configuration": "Configuração do LLAMA",
49
- "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
- "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
- "Labeling Device": "Dispositivo de Rotulagem",
52
- "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
- "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
- "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
- "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
- "Merge": "Mesclar",
57
- "Merge LoRA": "Mesclar LoRA",
58
- "Merge successfully": "Mesclado com sucesso",
59
- "Model Output Path": "Caminho de Saída do Modelo",
60
- "Model Quantization": "Quantização do Modelo",
61
- "Model Size": "Tamanho do Modelo",
62
- "Move": "Mover",
63
- "Move files successfully": "Arquivos movidos com sucesso",
64
- "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
- "No selected options": "Nenhuma opção selecionada",
66
- "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
- "Number of Workers": "Número de Processos",
68
- "Open Inference Server": "Abrir Servidor de Inferência",
69
- "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
- "Open Tensorboard": "Abrir Tensorboard",
71
- "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
- "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
- "Optional online ver": "Versão online (opcional)",
74
- "Output Path": "Caminho de Saída",
75
- "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
- "Post-quantification Precision": "Precisão Pós-quantização",
77
- "Precision": "Precisão",
78
- "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
- "Put your text here.": "Insira seu texto aqui.",
80
- "Quantify": "Quantizar",
81
- "Quantify successfully": "Quantizado com sucesso",
82
- "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
- "Reference Audio": "Áudio de Referência",
84
- "Reference Text": "Texto de Referência",
85
- "warning": "Aviso",
86
- "Pre-processing begins...": "O pré-processamento começou!",
87
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
- "Remove Selected Data": "Remover Dados Selecionados",
89
- "Removed path successfully!": "Caminho removido com sucesso!",
90
- "Repetition Penalty": "Penalidade de Repetição",
91
- "Save model every n steps": "Salvar modelo a cada n etapas",
92
- "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
- "Select source file processing method": "Escolha como processar o arquivo de origem",
94
- "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
- "Selected: {}": "Selecionado: {}",
96
- "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
- "Start Training": "Iniciar Treinamento",
98
- "Streaming Audio": "Áudio em Streaming",
99
- "Streaming Generate": "Geração em Streaming",
100
- "Tensorboard Host": "Host do Tensorboard",
101
- "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
- "Tensorboard Port": "Porta do Tensorboard",
103
- "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
- "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
- "Text Normalization": "Normalização de Texto",
106
- "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
- "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
- "Training Configuration": "Configuração de Treinamento",
110
- "Training Error": "Erro de Treinamento",
111
- "Training stopped": "Treinamento interrompido!",
112
- "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
- "Use LoRA": "Usar LoRA",
114
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
- "Use filelist": "Usar lista de arquivos",
116
- "VQGAN Configuration": "Configuração do VQGAN",
117
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
- "WebUI Host": "Host da WebUI",
120
- "WebUI Port": "Porta da WebUI",
121
- "Whisper Model": "Modelo Whisper",
122
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
- "auto": "automático",
124
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
- "latest": "mais recente",
126
- "new": "novo",
127
- "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
- "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
- "Yes": "Sim",
130
- "No": "Não",
131
- "version:": "versão:",
132
- "author:": "autor:"
133
- }
 
1
+ {
2
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
+ "Add to Processing Area": "Adicionar à Área de Processamento",
6
+ "Added path successfully!": "Caminho adicionado com sucesso!",
7
+ "Advanced Config": "Configuração Avançada",
8
+ "Base LLAMA Model": "Modelo LLAMA Base",
9
+ "Batch Inference": "Inferência em Lote",
10
+ "Batch Size": "Tamanho do Lote",
11
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
+
13
+ "Compile Model": "Compilar Modelo",
14
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
+ "Copy": "Copiar",
16
+ "Data Preprocessing": "Pré-processamento de Dados",
17
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
+ "Data Source": "Fonte de Dados",
19
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
20
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
21
+ "Disabled": "Desativado",
22
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
24
+ "English": "Inglês",
25
+ "Japanese": "Japonês",
26
+ "Chinese": "Chinês",
27
+ "Portuguese": "Português",
28
+ "Spanish": "Espanhol",
29
+ "Error Message": "Mensagem de Erro",
30
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
+ "File Preprocessing": "Pré-processamento de Arquivos",
32
+ "Generate": "Gerar",
33
+ "Generated Audio": "Áudio Gerado",
34
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
+ "Infer interface is closed": "A interface de inferência foi fechada",
36
+ "Inference Configuration": "Configuração de Inferência",
37
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
+ "Inference Server Error": "Erro do Servidor de Inferência",
39
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
+ "Initial Prompt": "Prompt Inicial",
42
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
+ "Input Text": "Texto de Entrada",
45
+ "Invalid path: {}": "Caminho inválido: {}",
46
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
+ "LLAMA Configuration": "Configuração do LLAMA",
49
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
+ "Labeling Device": "Dispositivo de Rotulagem",
52
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
+ "Merge": "Mesclar",
57
+ "Merge LoRA": "Mesclar LoRA",
58
+ "Merge successfully": "Mesclado com sucesso",
59
+ "Model Output Path": "Caminho de Saída do Modelo",
60
+ "Model Quantization": "Quantização do Modelo",
61
+ "Model Size": "Tamanho do Modelo",
62
+ "Move": "Mover",
63
+ "Move files successfully": "Arquivos movidos com sucesso",
64
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
+ "No selected options": "Nenhuma opção selecionada",
66
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
+ "Number of Workers": "Número de Processos",
68
+ "Open Inference Server": "Abrir Servidor de Inferência",
69
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
+ "Open Tensorboard": "Abrir Tensorboard",
71
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
+ "Optional online ver": "Versão online (opcional)",
74
+ "Output Path": "Caminho de Saída",
75
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
+ "Post-quantification Precision": "Precisão Pós-quantização",
77
+ "Precision": "Precisão",
78
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
+ "Put your text here.": "Insira seu texto aqui.",
80
+ "Quantify": "Quantizar",
81
+ "Quantify successfully": "Quantizado com sucesso",
82
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
+ "Reference Audio": "Áudio de Referência",
84
+ "Reference Text": "Texto de Referência",
85
+ "warning": "Aviso",
86
+ "Pre-processing begins...": "O pré-processamento começou!",
87
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
+ "Remove Selected Data": "Remover Dados Selecionados",
89
+ "Removed path successfully!": "Caminho removido com sucesso!",
90
+ "Repetition Penalty": "Penalidade de Repetição",
91
+ "Save model every n steps": "Salvar modelo a cada n etapas",
92
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
94
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
+ "Selected: {}": "Selecionado: {}",
96
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
+ "Start Training": "Iniciar Treinamento",
98
+ "Streaming Audio": "Áudio em Streaming",
99
+ "Streaming Generate": "Geração em Streaming",
100
+ "Tensorboard Host": "Host do Tensorboard",
101
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
+ "Tensorboard Port": "Porta do Tensorboard",
103
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
+ "Text Normalization": "Normalização de Texto",
106
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
+ "Training Configuration": "Configuração de Treinamento",
110
+ "Training Error": "Erro de Treinamento",
111
+ "Training stopped": "Treinamento interrompido!",
112
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
+ "Use LoRA": "Usar LoRA",
114
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
+ "Use filelist": "Usar lista de arquivos",
116
+ "VQGAN Configuration": "Configuração do VQGAN",
117
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
+ "WebUI Host": "Host da WebUI",
120
+ "WebUI Port": "Porta da WebUI",
121
+ "Whisper Model": "Modelo Whisper",
122
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
+ "auto": "automático",
124
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
+ "latest": "mais recente",
126
+ "new": "novo",
127
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
+ "Yes": "Sim",
130
+ "No": "Não",
131
+ "version:": "versão:",
132
+ "author:": "autor:"
133
+ }
fish_speech/i18n/locale/zh_CN.json CHANGED
@@ -1,123 +1,123 @@
1
- {
2
- "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
- "Accumulate Gradient Batches": "梯度累积批次",
6
- "Add to Processing Area": "加入处理区",
7
- "Added path successfully!": "添加路径成功!",
8
- "Advanced Config": "高级参数",
9
- "Base LLAMA Model": "基础 LLAMA 模型",
10
- "Batch Inference": "批量推理",
11
- "Batch Size": "批次大小",
12
- "Changing with the Model Path": "随模型路径变化",
13
- "Chinese": "中文",
14
- "Compile Model": "编译模型",
15
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
- "Copy": "复制",
17
- "Data Preprocessing": "数据预处理",
18
- "Data Preprocessing Path": "数据预处理路径",
19
- "Data Source": "数据源",
20
- "Decoder Model Config": "解码器模型配置",
21
- "Decoder Model Path": "解码器模型路径",
22
- "Disabled": "禁用",
23
- "Enable Reference Audio": "启用参考音频",
24
- "English": "英文",
25
- "Error Message": "错误信息",
26
- "File Preprocessing": "文件预处理",
27
- "Generate": "生成",
28
- "Generated Audio": "音频",
29
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
- "Infer interface is closed": "推理界面已关闭",
31
- "Inference Configuration": "推理配置",
32
- "Inference Server Configuration": "推理服务器配置",
33
- "Inference Server Error": "推理服务器错误",
34
- "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
- "Initial Learning Rate": "初始学习率",
36
- "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
- "Input Text": "输入文本",
38
- "Invalid path: {}": "无效路径: {}",
39
- "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
- "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
- "Japanese": "日文",
42
- "LLAMA Configuration": "LLAMA 配置",
43
- "LLAMA Model Config": "LLAMA 模型配置",
44
- "LLAMA Model Path": "LLAMA 模型路径",
45
- "Labeling Device": "标注加速设备",
46
- "LoRA Model to be merged": "要合并的 LoRA 模型",
47
- "Maximum Audio Duration": "最大音频时长",
48
- "Maximum Length per Sample": "每个样本的最大长度",
49
- "Maximum Training Steps": "最大训练步数",
50
- "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
- "Merge": "合并",
52
- "Merge LoRA": "合并 LoRA",
53
- "Merge successfully": "合并成功",
54
- "Minimum Audio Duration": "最小音频时长",
55
- "Model Output Path": "模型输出路径",
56
- "Model Size": "模型规模",
57
- "Move": "移动",
58
- "Move files successfully": "移动文件成功",
59
- "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
- "No selected options": "没有选择的选项",
61
- "Number of Workers": "数据加载进程数",
62
- "Open Inference Server": "打开推理服务器",
63
- "Open Labeler WebUI": "打开标注工具",
64
- "Open Tensorboard": "打开 Tensorboard",
65
- "Opened labeler in browser": "在浏览器中打开标注工具",
66
- "Optional Label Language": "[可选] 标注语言",
67
- "Optional online ver": "[可选] 使用在线版",
68
- "Output Path": "输出路径",
69
- "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
- "Precision": "精度",
71
- "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
- "Put your text here.": "在此处输入文本.",
73
- "Reference Audio": "参考音频",
74
- "Reference Text": "参考文本",
75
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
- "Remove Selected Data": "移除选中数据",
77
- "Removed path successfully!": "移除路径成功!",
78
- "Repetition Penalty": "重复惩罚",
79
- "Save model every n steps": "每 n 步保存模型",
80
- "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
- "Select VITS ckpt": "选择 VITS 检查点",
82
- "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
- "Select source file processing method": "选择源文件处理方法",
84
- "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
- "Selected: {}": "已选择: {}",
86
- "Speaker": "说话人",
87
- "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
- "Start Training": "开始训练",
89
- "Streaming Audio": "流式音频",
90
- "Streaming Generate": "流式合成",
91
- "Tensorboard Host": "Tensorboard 监听地址",
92
- "Tensorboard Log Path": "Tensorboard 日志路径",
93
- "Tensorboard Port": "Tensorboard 端口",
94
- "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
- "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
- "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
- "Training Configuration": "训练配置",
99
- "Training Error": "训练错误",
100
- "Training stopped": "训练已停止",
101
- "Type name of the speaker": "输入说话人的名称",
102
- "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
- "Use LoRA": "使用 LoRA",
104
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
- "Use filelist": "使用文件列表",
106
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
- "VITS Configuration": "VITS 配置",
108
- "VQGAN Configuration": "VQGAN 配置",
109
- "Validation Batch Size": "验证批次大小",
110
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
- "WebUI Host": "WebUI 监听地址",
113
- "WebUI Port": "WebUI 端口",
114
- "Whisper Model": "Whisper 模型",
115
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
- "latest": "最近的检查点",
118
- "new": "创建新的检查点",
119
- "Realtime Transform Text": "实时规范化文本",
120
- "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
- "Text Normalization": "文本规范化",
122
- "Select Example Audio": "选择参考音频"
123
- }
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
+ "Accumulate Gradient Batches": "梯度累积批次",
6
+ "Add to Processing Area": "加入处理区",
7
+ "Added path successfully!": "添加路径成功!",
8
+ "Advanced Config": "高级参数",
9
+ "Base LLAMA Model": "基础 LLAMA 模型",
10
+ "Batch Inference": "批量推理",
11
+ "Batch Size": "批次大小",
12
+ "Changing with the Model Path": "随模型路径变化",
13
+ "Chinese": "中文",
14
+ "Compile Model": "编译模型",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
+ "Copy": "复制",
17
+ "Data Preprocessing": "数据预处理",
18
+ "Data Preprocessing Path": "数据预处理路径",
19
+ "Data Source": "数据源",
20
+ "Decoder Model Config": "解码器模型配置",
21
+ "Decoder Model Path": "解码器模型路径",
22
+ "Disabled": "禁用",
23
+ "Enable Reference Audio": "启用参考音频",
24
+ "English": "英文",
25
+ "Error Message": "错误信息",
26
+ "File Preprocessing": "文件预处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "音频",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
+ "Infer interface is closed": "推理界面已关闭",
31
+ "Inference Configuration": "推理配置",
32
+ "Inference Server Configuration": "推理服务器配置",
33
+ "Inference Server Error": "推理服务器错误",
34
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
+ "Initial Learning Rate": "初始学习率",
36
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
+ "Input Text": "输入文本",
38
+ "Invalid path: {}": "无效���径: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
+ "Japanese": "日文",
42
+ "LLAMA Configuration": "LLAMA 配置",
43
+ "LLAMA Model Config": "LLAMA 模型配置",
44
+ "LLAMA Model Path": "LLAMA 模型路径",
45
+ "Labeling Device": "标注加速设备",
46
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
47
+ "Maximum Audio Duration": "最大音频时长",
48
+ "Maximum Length per Sample": "每个样本的最大长度",
49
+ "Maximum Training Steps": "最大训练步数",
50
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
+ "Merge": "合并",
52
+ "Merge LoRA": "合并 LoRA",
53
+ "Merge successfully": "合并成功",
54
+ "Minimum Audio Duration": "最小音频时长",
55
+ "Model Output Path": "模型输出路径",
56
+ "Model Size": "模型规模",
57
+ "Move": "移动",
58
+ "Move files successfully": "移动文件成功",
59
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
+ "No selected options": "没有选择的选项",
61
+ "Number of Workers": "数据加载进程数",
62
+ "Open Inference Server": "打开推理服务器",
63
+ "Open Labeler WebUI": "打开标注工具",
64
+ "Open Tensorboard": "打开 Tensorboard",
65
+ "Opened labeler in browser": "在浏览器中打开标注工具",
66
+ "Optional Label Language": "[可选] 标注语言",
67
+ "Optional online ver": "[可选] 使用在线版",
68
+ "Output Path": "输出路径",
69
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
+ "Put your text here.": "在此处输入文本.",
73
+ "Reference Audio": "参考音频",
74
+ "Reference Text": "参考文本",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
+ "Remove Selected Data": "移除选中数据",
77
+ "Removed path successfully!": "移除路径成功!",
78
+ "Repetition Penalty": "重复惩罚",
79
+ "Save model every n steps": "每 n 步保存模型",
80
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
+ "Select VITS ckpt": "选择 VITS 检查点",
82
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
+ "Select source file processing method": "选择源文件处理方法",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
+ "Selected: {}": "已选择: {}",
86
+ "Speaker": "说话人",
87
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
+ "Start Training": "开始训练",
89
+ "Streaming Audio": "流式音频",
90
+ "Streaming Generate": "流式合成",
91
+ "Tensorboard Host": "Tensorboard 监听地址",
92
+ "Tensorboard Log Path": "Tensorboard 日志路径",
93
+ "Tensorboard Port": "Tensorboard 端口",
94
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
+ "Training Configuration": "训练配置",
99
+ "Training Error": "训练错误",
100
+ "Training stopped": "训练已停止",
101
+ "Type name of the speaker": "输入说话人的名称",
102
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
+ "Use LoRA": "使用 LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
+ "Use filelist": "使用文件列表",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
+ "VITS Configuration": "VITS 配置",
108
+ "VQGAN Configuration": "VQGAN 配置",
109
+ "Validation Batch Size": "验证批次大小",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
+ "WebUI Host": "WebUI 监听地址",
113
+ "WebUI Port": "WebUI 端口",
114
+ "Whisper Model": "Whisper 模型",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
+ "latest": "最近的检查点",
118
+ "new": "创建新的检查点",
119
+ "Realtime Transform Text": "实时规范化文本",
120
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
+ "Text Normalization": "文本规范化",
122
+ "Select Example Audio": "选择参考音频"
123
+ }
fish_speech/i18n/scan.py CHANGED
@@ -1,122 +1,122 @@
1
- import ast
2
- import glob
3
- import json
4
- from collections import OrderedDict
5
- from pathlib import Path
6
-
7
- from loguru import logger
8
-
9
- from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
-
11
-
12
- def extract_i18n_strings(node):
13
- i18n_strings = []
14
-
15
- if (
16
- isinstance(node, ast.Call)
17
- and isinstance(node.func, ast.Name)
18
- and node.func.id == "i18n"
19
- ):
20
- for arg in node.args:
21
- if isinstance(arg, ast.Str):
22
- i18n_strings.append(arg.s)
23
-
24
- for child_node in ast.iter_child_nodes(node):
25
- i18n_strings.extend(extract_i18n_strings(child_node))
26
-
27
- return i18n_strings
28
-
29
-
30
- # scan the directory for all .py files (recursively)
31
- # for each file, parse the code into an AST
32
- # for each AST, extract the i18n strings
33
-
34
- strings = []
35
- folders = ["fish_speech", "tools"]
36
- # for filename in glob.iglob("**/*.py", recursive=True):
37
- for folder in folders:
38
- for f in Path(folder).rglob("*.py"):
39
- code = f.read_text(encoding="utf-8")
40
- if "i18n(" in code:
41
- tree = ast.parse(code)
42
- i18n_strings = extract_i18n_strings(tree)
43
- logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
- strings.extend(i18n_strings)
45
-
46
- code_keys = set(strings)
47
- logger.info(f"Total unique: {len(code_keys)}")
48
-
49
-
50
- standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
- with open(standard_file, "r", encoding="utf-8") as f:
52
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
- standard_keys = set(standard_data.keys())
54
-
55
- # Define the standard file name
56
- unused_keys = standard_keys - code_keys
57
- logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
- for unused_key in unused_keys:
59
- logger.info(f"\t{unused_key}")
60
-
61
- missing_keys = code_keys - standard_keys
62
- logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
- for missing_key in missing_keys:
64
- logger.info(f"\t{missing_key}")
65
-
66
- code_keys_dict = OrderedDict()
67
- for s in strings:
68
- code_keys_dict[s] = s
69
-
70
- # write back
71
- with open(standard_file, "w", encoding="utf-8") as f:
72
- json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
- f.write("\n")
74
-
75
- logger.info(f"Updated {standard_file}")
76
-
77
-
78
- # Define the standard file name
79
- standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
-
81
- # Find all JSON files in the directory
82
- dir_path = I18N_FILE_PATH
83
- languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
-
85
- # Load the standard file
86
- with open(standard_file, "r", encoding="utf-8") as f:
87
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
-
89
- # Loop through each language file
90
- for lang_file in languages:
91
- # Load the language file
92
- with open(lang_file, "r", encoding="utf-8") as f:
93
- lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
-
95
- # Find the difference between the language file and the standard file
96
- diff = set(standard_data.keys()) - set(lang_data.keys())
97
-
98
- miss = set(lang_data.keys()) - set(standard_data.keys())
99
-
100
- # Add any missing keys to the language file
101
- for key in diff:
102
- lang_data[key] = "#!" + key
103
- logger.info(f"Added missing key: {key} to {lang_file}")
104
-
105
- # Del any extra keys to the language file
106
- for key in miss:
107
- del lang_data[key]
108
- logger.info(f"Del extra key: {key} from {lang_file}")
109
-
110
- # Sort the keys of the language file to match the order of the standard file
111
- lang_data = OrderedDict(
112
- sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
- )
114
-
115
- # Save the updated language file
116
- with open(lang_file, "w", encoding="utf-8") as f:
117
- json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
- f.write("\n")
119
-
120
- logger.info(f"Updated {lang_file}")
121
-
122
- logger.info("Done")
 
1
+ import ast
2
+ import glob
3
+ import json
4
+ from collections import OrderedDict
5
+ from pathlib import Path
6
+
7
+ from loguru import logger
8
+
9
+ from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
+
11
+
12
+ def extract_i18n_strings(node):
13
+ i18n_strings = []
14
+
15
+ if (
16
+ isinstance(node, ast.Call)
17
+ and isinstance(node.func, ast.Name)
18
+ and node.func.id == "i18n"
19
+ ):
20
+ for arg in node.args:
21
+ if isinstance(arg, ast.Str):
22
+ i18n_strings.append(arg.s)
23
+
24
+ for child_node in ast.iter_child_nodes(node):
25
+ i18n_strings.extend(extract_i18n_strings(child_node))
26
+
27
+ return i18n_strings
28
+
29
+
30
+ # scan the directory for all .py files (recursively)
31
+ # for each file, parse the code into an AST
32
+ # for each AST, extract the i18n strings
33
+
34
+ strings = []
35
+ folders = ["fish_speech", "tools"]
36
+ # for filename in glob.iglob("**/*.py", recursive=True):
37
+ for folder in folders:
38
+ for f in Path(folder).rglob("*.py"):
39
+ code = f.read_text(encoding="utf-8")
40
+ if "i18n(" in code:
41
+ tree = ast.parse(code)
42
+ i18n_strings = extract_i18n_strings(tree)
43
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
+ strings.extend(i18n_strings)
45
+
46
+ code_keys = set(strings)
47
+ logger.info(f"Total unique: {len(code_keys)}")
48
+
49
+
50
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
+ with open(standard_file, "r", encoding="utf-8") as f:
52
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
+ standard_keys = set(standard_data.keys())
54
+
55
+ # Define the standard file name
56
+ unused_keys = standard_keys - code_keys
57
+ logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
+ for unused_key in unused_keys:
59
+ logger.info(f"\t{unused_key}")
60
+
61
+ missing_keys = code_keys - standard_keys
62
+ logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
+ for missing_key in missing_keys:
64
+ logger.info(f"\t{missing_key}")
65
+
66
+ code_keys_dict = OrderedDict()
67
+ for s in strings:
68
+ code_keys_dict[s] = s
69
+
70
+ # write back
71
+ with open(standard_file, "w", encoding="utf-8") as f:
72
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
+ f.write("\n")
74
+
75
+ logger.info(f"Updated {standard_file}")
76
+
77
+
78
+ # Define the standard file name
79
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
+
81
+ # Find all JSON files in the directory
82
+ dir_path = I18N_FILE_PATH
83
+ languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
+
85
+ # Load the standard file
86
+ with open(standard_file, "r", encoding="utf-8") as f:
87
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
+
89
+ # Loop through each language file
90
+ for lang_file in languages:
91
+ # Load the language file
92
+ with open(lang_file, "r", encoding="utf-8") as f:
93
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
+
95
+ # Find the difference between the language file and the standard file
96
+ diff = set(standard_data.keys()) - set(lang_data.keys())
97
+
98
+ miss = set(lang_data.keys()) - set(standard_data.keys())
99
+
100
+ # Add any missing keys to the language file
101
+ for key in diff:
102
+ lang_data[key] = "#!" + key
103
+ logger.info(f"Added missing key: {key} to {lang_file}")
104
+
105
+ # Del any extra keys to the language file
106
+ for key in miss:
107
+ del lang_data[key]
108
+ logger.info(f"Del extra key: {key} from {lang_file}")
109
+
110
+ # Sort the keys of the language file to match the order of the standard file
111
+ lang_data = OrderedDict(
112
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
+ )
114
+
115
+ # Save the updated language file
116
+ with open(lang_file, "w", encoding="utf-8") as f:
117
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
+ f.write("\n")
119
+
120
+ logger.info(f"Updated {lang_file}")
121
+
122
+ logger.info("Done")
fish_speech/inference_engine/__init__.py ADDED
@@ -0,0 +1,192 @@
1
+ import gc
2
+ import queue
3
+ from typing import Generator
4
+
5
+ import numpy as np
6
+ import torch
7
+ from loguru import logger
8
+
9
+ from fish_speech.inference_engine.reference_loader import ReferenceLoader
10
+ from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
11
+ from fish_speech.inference_engine.vq_manager import VQManager
12
+ from fish_speech.models.dac.modded_dac import DAC
13
+ from fish_speech.models.text2semantic.inference import (
14
+ GenerateRequest,
15
+ GenerateResponse,
16
+ WrappedGenerateResponse,
17
+ )
18
+ from fish_speech.utils import autocast_exclude_mps, set_seed
19
+ from fish_speech.utils.schema import ServeTTSRequest
20
+
21
+
22
+ class TTSInferenceEngine(ReferenceLoader, VQManager):
23
+
24
+ def __init__(
25
+ self,
26
+ llama_queue: queue.Queue,
27
+ decoder_model: DAC,
28
+ precision: torch.dtype,
29
+ compile: bool,
30
+ ) -> None:
31
+
32
+ super().__init__()
33
+
34
+ self.llama_queue = llama_queue
35
+ self.decoder_model = decoder_model
36
+ self.precision = precision
37
+ self.compile = compile
38
+
39
+ @torch.inference_mode()
40
+ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
41
+ """
42
+ Main inference function:
43
+ - Loads the reference audio and text.
44
+ - Calls the LLAMA model for inference.
45
+ - Decodes the VQ tokens to audio.
46
+ """
47
+
48
+ ref_id: str | None = req.reference_id
49
+ prompt_tokens, prompt_texts = [], []
50
+ # Load the reference audio and text based on id or hash
51
+ if ref_id is not None:
52
+ prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
53
+
54
+ elif req.references:
55
+ prompt_tokens, prompt_texts = self.load_by_hash(
56
+ req.references, req.use_memory_cache
57
+ )
58
+
59
+ # Set the random seed if provided
60
+ if req.seed is not None:
61
+ set_seed(req.seed)
62
+ logger.warning(f"set seed: {req.seed}")
63
+
64
+ # Get the symbolic tokens from the LLAMA model
65
+ response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
66
+
67
+ # Get the sample rate from the decoder model
68
+ if hasattr(self.decoder_model, "spec_transform"):
69
+ sample_rate = self.decoder_model.spec_transform.sample_rate
70
+ else:
71
+ sample_rate = self.decoder_model.sample_rate
72
+
73
+ # If streaming, send the header
74
+ if req.streaming:
75
+ yield InferenceResult(
76
+ code="header",
77
+ audio=(
78
+ sample_rate,
79
+ np.array(wav_chunk_header(sample_rate=sample_rate)),
80
+ ),
81
+ error=None,
82
+ )
83
+
84
+ segments = []
85
+
86
+ while True:
87
+ # Get the response from the LLAMA model
88
+ wrapped_result: WrappedGenerateResponse = response_queue.get()
89
+ if wrapped_result.status == "error":
90
+ yield InferenceResult(
91
+ code="error",
92
+ audio=None,
93
+ error=(
94
+ wrapped_result.response
95
+ if isinstance(wrapped_result.response, Exception)
96
+ else Exception("Unknown error")
97
+ ),
98
+ )
99
+ break
100
+
101
+ # Check the response type
102
+ if not isinstance(wrapped_result.response, GenerateResponse):
103
+ raise TypeError(
104
+ "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
105
+ )
106
+
107
+ result: GenerateResponse = wrapped_result.response
108
+ if result.action != "next":
109
+ segment = self.get_audio_segment(result)
110
+
111
+ if req.streaming: # Used only by the API server
112
+ yield InferenceResult(
113
+ code="segment",
114
+ audio=(sample_rate, segment),
115
+ error=None,
116
+ )
117
+ segments.append(segment)
118
+ else:
119
+ break
120
+
121
+ # Clean up the memory
122
+ if torch.cuda.is_available():
123
+ torch.cuda.empty_cache()
124
+ gc.collect()
125
+
126
+ # Edge case: no audio generated
127
+ if len(segments) == 0:
128
+ yield InferenceResult(
129
+ code="error",
130
+ audio=None,
131
+ error=RuntimeError("No audio generated, please check the input text."),
132
+ )
133
+ else:
134
+ # Streaming or not, return the final audio
135
+ audio = np.concatenate(segments, axis=0)
136
+ yield InferenceResult(
137
+ code="final",
138
+ audio=(sample_rate, audio),
139
+ error=None,
140
+ )
141
+
142
+ return None
143
+
144
+ def send_Llama_request(
145
+ self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
146
+ ) -> queue.Queue:
147
+ """
148
+ Send a request to the LLAMA model to generate the symbolic tokens.
149
+ """
150
+
151
+ # Prepare the request
152
+ request = dict(
153
+ device=self.decoder_model.device,
154
+ max_new_tokens=req.max_new_tokens,
155
+ text=req.text,
156
+ top_p=req.top_p,
157
+ repetition_penalty=req.repetition_penalty,
158
+ temperature=req.temperature,
159
+ compile=self.compile,
160
+ iterative_prompt=req.chunk_length > 0,
161
+ chunk_length=req.chunk_length,
162
+ prompt_tokens=prompt_tokens,
163
+ prompt_text=prompt_texts,
164
+ )
165
+
166
+ # Create a queue to get the response
167
+ response_queue = queue.Queue()
168
+
169
+ # Send the request to the LLAMA model
170
+ self.llama_queue.put(
171
+ GenerateRequest(
172
+ request=request,
173
+ response_queue=response_queue,
174
+ )
175
+ )
176
+
177
+ return response_queue
178
+
179
+ def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
180
+ """
181
+ Decode the VQ tokens to audio.
182
+ """
183
+
184
+ # Don't use autocast on MPS devices
185
+ with autocast_exclude_mps(
186
+ device_type=self.decoder_model.device.type, dtype=self.precision
187
+ ):
188
+ # Decode the symbolic tokens to audio
189
+ segment = self.decode_vq_tokens(codes=result.codes)
190
+
191
+ # Convert the audio to numpy
192
+ return segment.float().cpu().numpy()
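
For orientation, a minimal sketch of how this new streaming generator might be driven once a llama_queue and decoder_model exist. The queue and decoder loaders are not part of this hunk, so their wiring is assumed here, and ServeTTSRequest is built with only the fields the engine reads (the remaining fields are assumed to have defaults):

    import torch
    import soundfile as sf
    from fish_speech.inference_engine import TTSInferenceEngine
    from fish_speech.utils.schema import ServeTTSRequest

    # llama_queue and decoder_model are assumed to come from the project's own loaders
    engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        precision=torch.bfloat16,
        compile=False,
    )

    req = ServeTTSRequest(text="Hello from OpenAudio S1.", streaming=False)
    for result in engine.inference(req):
        if result.code == "error" and result.error is not None:
            raise result.error
        if result.code == "final" and result.audio is not None:
            sample_rate, audio = result.audio
            sf.write("output.wav", audio, sample_rate)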
fish_speech/inference_engine/reference_loader.py ADDED
@@ -0,0 +1,130 @@
1
+ import io
2
+ from hashlib import sha256
3
+ from pathlib import Path
4
+ from typing import Callable, Literal, Tuple
5
+
6
+ import torch
7
+ import torchaudio
8
+ from loguru import logger
9
+
10
+ from fish_speech.models.dac.modded_dac import DAC
11
+ from fish_speech.utils.file import (
12
+ AUDIO_EXTENSIONS,
13
+ audio_to_bytes,
14
+ list_files,
15
+ read_ref_text,
16
+ )
17
+ from fish_speech.utils.schema import ServeReferenceAudio
18
+
19
+
20
+ class ReferenceLoader:
21
+
22
+ def __init__(self) -> None:
23
+ """
24
+ Component of the TTSInferenceEngine class.
25
+ Loads and manages the cache for the reference audio and text.
26
+ """
27
+ self.ref_by_id: dict = {}
28
+ self.ref_by_hash: dict = {}
29
+
30
+ # Make Pylance happy (attribut/method not defined...)
31
+ self.decoder_model: DAC
32
+ self.encode_reference: Callable
33
+
34
+ # Define the torchaudio backend
35
+ backends = torchaudio.list_audio_backends()
36
+ if "ffmpeg" in backends:
37
+ self.backend = "ffmpeg"
38
+ else:
39
+ self.backend = "soundfile"
40
+
41
+ def load_by_id(
42
+ self,
43
+ id: str,
44
+ use_cache: Literal["on", "off"],
45
+ ) -> Tuple:
46
+
47
+ # Load the references audio and text by id
48
+ ref_folder = Path("references") / id
49
+ ref_folder.mkdir(parents=True, exist_ok=True)
50
+ ref_audios = list_files(
51
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
52
+ )
53
+
54
+ if use_cache == "off" or id not in self.ref_by_id:
55
+ # If the references are not already loaded, encode them
56
+ prompt_tokens = [
57
+ self.encode_reference(
58
+ # decoder_model=self.decoder_model,
59
+ reference_audio=audio_to_bytes(str(ref_audio)),
60
+ enable_reference_audio=True,
61
+ )
62
+ for ref_audio in ref_audios
63
+ ]
64
+ prompt_texts = [
65
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
66
+ for ref_audio in ref_audios
67
+ ]
68
+ self.ref_by_id[id] = (prompt_tokens, prompt_texts)
69
+
70
+ else:
71
+ # Reuse already encoded references
72
+ logger.info("Use same references")
73
+ prompt_tokens, prompt_texts = self.ref_by_id[id]
74
+
75
+ return prompt_tokens, prompt_texts
76
+
77
+ def load_by_hash(
78
+ self,
79
+ references: list[ServeReferenceAudio],
80
+ use_cache: Literal["on", "off"],
81
+ ) -> Tuple:
82
+
83
+ # Load the references audio and text by hash
84
+ audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
85
+
86
+ cache_used = False
87
+ prompt_tokens, prompt_texts = [], []
88
+ for i, ref in enumerate(references):
89
+ if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
90
+ # If the references are not already loaded, encode them
91
+ prompt_tokens.append(
92
+ self.encode_reference(
93
+ reference_audio=ref.audio,
94
+ enable_reference_audio=True,
95
+ )
96
+ )
97
+ prompt_texts.append(ref.text)
98
+ self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
99
+
100
+ else:
101
+ # Reuse already encoded references
102
+ prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
103
+ cache_used = True
104
+
105
+ if cache_used:
106
+ logger.info("Use same references")
107
+
108
+ return prompt_tokens, prompt_texts
109
+
110
+ def load_audio(self, reference_audio, sr):
111
+ """
112
+ Load the audio data from a file or bytes.
113
+ """
114
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
115
+ audio_data = reference_audio
116
+ reference_audio = io.BytesIO(audio_data)
117
+
118
+ waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
119
+
120
+ if waveform.shape[0] > 1:
121
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
122
+
123
+ if original_sr != sr:
124
+ resampler = torchaudio.transforms.Resample(
125
+ orig_freq=original_sr, new_freq=sr
126
+ )
127
+ waveform = resampler(waveform)
128
+
129
+ audio = waveform.squeeze().numpy()
130
+ return audio
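
As load_by_id above implies, references are resolved from a references/<id> folder in which every audio file is paired with a .lab transcript sharing its stem; a layout along these lines is what the loader expects (names are illustrative only):

    references/
        my_speaker/
            sample_01.wav
            sample_01.lab   # plain-text transcript of sample_01.wav
            sample_02.wav
            sample_02.lab

load_by_hash follows the same encode-then-cache pattern, but keys the cache on the SHA-256 of the raw audio bytes instead of a folder id.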
fish_speech/inference_engine/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import io
2
+ import wave
3
+ from dataclasses import dataclass
4
+ from typing import Literal, Optional, Tuple
5
+
6
+ import numpy as np
7
+
8
+
9
+ @dataclass
10
+ class InferenceResult:
11
+ code: Literal["header", "segment", "error", "final"]
12
+ audio: Optional[Tuple[int, np.ndarray]]
13
+ error: Optional[Exception]
14
+
15
+
16
+ def wav_chunk_header(
17
+ sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
18
+ ) -> bytes:
19
+ buffer = io.BytesIO()
20
+
21
+ with wave.open(buffer, "wb") as wav_file:
22
+ wav_file.setnchannels(channels)
23
+ wav_file.setsampwidth(bit_depth // 8)
24
+ wav_file.setframerate(sample_rate)
25
+
26
+ wav_header_bytes = buffer.getvalue()
27
+ buffer.close()
28
+
29
+ return wav_header_bytes
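A small sketch of how the header helper above can be used when streaming raw 16-bit PCM; the `pcm_int16` array is a hypothetical placeholder for generated audio.

    import numpy as np

    header = wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1)
    pcm_int16 = np.zeros(44100, dtype=np.int16)  # hypothetical one second of silence
    wav_bytes = header + pcm_int16.tobytes()  # WAV header first, then raw PCM frames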
fish_speech/inference_engine/vq_manager.py ADDED
@@ -0,0 +1,59 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ from loguru import logger
5
+
6
+ from fish_speech.models.dac.modded_dac import DAC
7
+
8
+
9
+ class VQManager:
10
+
11
+ def __init__(self):
12
+ # Make Pylance happy (attribute/method not defined...)
13
+ self.decoder_model: DAC
14
+ self.load_audio: Callable
15
+
16
+ def decode_vq_tokens(self, codes):
17
+ feature_lengths = torch.tensor(
18
+ [codes.shape[1]], device=self.decoder_model.device
19
+ )
20
+ logger.info(f"VQ features: {codes.shape}")
21
+
22
+ if isinstance(self.decoder_model, DAC):
23
+ return self.decoder_model.decode(
24
+ indices=codes[None],
25
+ feature_lengths=feature_lengths,
26
+ )[0].squeeze()
27
+
28
+ raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
29
+
30
+ def encode_reference(self, reference_audio, enable_reference_audio):
31
+ if enable_reference_audio and reference_audio is not None:
32
+ # Load audios, and prepare basic info here
33
+ if hasattr(self.decoder_model, "spec_transform"):
34
+ sample_rate = self.decoder_model.spec_transform.sample_rate
35
+ else:
36
+ sample_rate = self.decoder_model.sample_rate
37
+ reference_audio_content = self.load_audio(reference_audio, sample_rate)
38
+
39
+ audios = torch.from_numpy(reference_audio_content).to(
40
+ self.decoder_model.device
41
+ )[None, None, :]
42
+ audio_lengths = torch.tensor(
43
+ [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
44
+ )
45
+ logger.info(
46
+ f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
47
+ )
48
+
49
+ # VQ Encoder
50
+ if isinstance(self.decoder_model, DAC):
51
+ prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
52
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
53
+ else:
54
+ raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
55
+ else:
56
+ prompt_tokens = None
57
+ logger.info("No reference audio provided")
58
+
59
+ return prompt_tokens
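Taken together, `encode_reference` and `decode_vq_tokens` form a round trip through the DAC codec. A minimal sketch, assuming `manager` is an instance of a class mixing in `VQManager` with `decoder_model` and `load_audio` set up (the `ref.wav` path is hypothetical):

    # Encode a reference clip to VQ codes, then reconstruct it.
    codes = manager.encode_reference(reference_audio="ref.wav", enable_reference_audio=True)
    if codes is not None:
        waveform = manager.decode_vq_tokens(codes)  # 1-D tensor of reconstructed samples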
fish_speech/models/dac/__init__.py ADDED
File without changes
fish_speech/models/dac/inference.py ADDED
@@ -0,0 +1,123 @@
1
+ from pathlib import Path
2
+
3
+ import click
4
+ import hydra
5
+ import numpy as np
6
+ import pyrootutils
7
+ import soundfile as sf
8
+ import torch
9
+ import torchaudio
10
+ from hydra import compose, initialize
11
+ from hydra.utils import instantiate
12
+ from loguru import logger
13
+ from omegaconf import OmegaConf
14
+
15
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16
+
17
+ from fish_speech.utils.file import AUDIO_EXTENSIONS
18
+
19
+ # register eval resolver
20
+ OmegaConf.register_new_resolver("eval", eval)
21
+
22
+
23
+ def load_model(config_name, checkpoint_path, device="cuda"):
24
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
25
+ with initialize(version_base="1.3", config_path="../../configs"):
26
+ cfg = compose(config_name=config_name)
27
+
28
+ model = instantiate(cfg)
29
+ state_dict = torch.load(
30
+ checkpoint_path, map_location=device, mmap=True, weights_only=True
31
+ )
32
+ if "state_dict" in state_dict:
33
+ state_dict = state_dict["state_dict"]
34
+
35
+ if any("generator" in k for k in state_dict):
36
+ state_dict = {
37
+ k.replace("generator.", ""): v
38
+ for k, v in state_dict.items()
39
+ if "generator." in k
40
+ }
41
+
42
+ result = model.load_state_dict(state_dict, strict=False, assign=True)
43
+ model.eval()
44
+ model.to(device)
45
+
46
+ logger.info(f"Loaded model: {result}")
47
+ return model
48
+
49
+
50
+ @torch.no_grad()
51
+ @click.command()
52
+ @click.option(
53
+ "--input-path",
54
+ "-i",
55
+ default="test.wav",
56
+ type=click.Path(exists=True, path_type=Path),
57
+ )
58
+ @click.option(
59
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
60
+ )
61
+ @click.option("--config-name", default="modded_dac_vq")
62
+ @click.option(
63
+ "--checkpoint-path",
64
+ default="checkpoints/openaudio-s1-mini/codec.pth",
65
+ )
66
+ @click.option(
67
+ "--device",
68
+ "-d",
69
+ default="cuda",
70
+ )
71
+ def main(input_path, output_path, config_name, checkpoint_path, device):
72
+ model = load_model(config_name, checkpoint_path, device=device)
73
+
74
+ if input_path.suffix in AUDIO_EXTENSIONS:
75
+ logger.info(f"Processing in-place reconstruction of {input_path}")
76
+
77
+ # Load audio
78
+ audio, sr = torchaudio.load(str(input_path))
79
+ if audio.shape[0] > 1:
80
+ audio = audio.mean(0, keepdim=True)
81
+ audio = torchaudio.functional.resample(audio, sr, model.sample_rate)
82
+
83
+ audios = audio[None].to(device)
84
+ logger.info(
85
+ f"Loaded audio with {audios.shape[2] / model.sample_rate:.2f} seconds"
86
+ )
87
+
88
+ # VQ Encoder
89
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
90
+ indices, indices_lens = model.encode(audios, audio_lengths)
91
+
92
+ if indices.ndim == 3:
93
+ indices = indices[0]
94
+
95
+ logger.info(f"Generated indices of shape {indices.shape}")
96
+
97
+ # Save indices
98
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
99
+ elif input_path.suffix == ".npy":
100
+ logger.info(f"Processing precomputed indices from {input_path}")
101
+ indices = np.load(input_path)
102
+ indices = torch.from_numpy(indices).to(device).long()
103
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
104
+ indices_lens = torch.tensor([indices.shape[1]], device=device, dtype=torch.long)
105
+ else:
106
+ raise ValueError(f"Unknown input type: {input_path}")
107
+
108
+ # Restore
109
+ fake_audios, audio_lengths = model.decode(indices, indices_lens)
110
+ audio_time = fake_audios.shape[-1] / model.sample_rate
111
+
112
+ logger.info(
113
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
114
+ )
115
+
116
+ # Save audio
117
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
118
+ sf.write(output_path, fake_audio, model.sample_rate)
119
+ logger.info(f"Saved audio to {output_path}")
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
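With the defaults above, `main` can be invoked from the repository root as `python fish_speech/models/dac/inference.py -i test.wav -o fake.wav` (the input path is a placeholder); passing a precomputed `.npy` file of indices instead of a `.wav` skips the encoding step, as handled in `main` above.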
fish_speech/models/dac/modded_dac.py ADDED
@@ -0,0 +1,1024 @@
1
+ import math
2
+ import typing as tp
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Union
5
+
6
+ import hydra
7
+ import librosa
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+ from audiotools import AudioSignal
12
+ from audiotools.ml import BaseModel
13
+ from dac.model.base import CodecMixin
14
+ from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
15
+ from omegaconf import OmegaConf
16
+ from torch import Tensor, nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.utils.parametrizations import weight_norm
19
+ from torch.nn.utils.parametrize import remove_parametrizations
20
+
21
+
22
+ @dataclass
23
+ class VQResult:
24
+ z: torch.Tensor
25
+ codes: torch.Tensor
26
+ latents: torch.Tensor
27
+ codebook_loss: torch.Tensor
28
+ commitment_loss: torch.Tensor
29
+ semantic_distill_z: torch.Tensor | None = None
30
+
31
+
32
+ def find_multiple(n: int, k: int) -> int:
33
+ if n % k == 0:
34
+ return n
35
+ return n + k - (n % k)
36
+
37
+
38
+ @dataclass
39
+ class ModelArgs:
40
+ block_size: int = 2048
41
+ n_layer: int = 8
42
+ n_head: int = 8
43
+ dim: int = 512
44
+ intermediate_size: int = 1536
45
+ n_local_heads: int = -1
46
+ head_dim: int = 64
47
+ rope_base: float = 10000
48
+ norm_eps: float = 1e-5
49
+ dropout_rate: float = 0.1
50
+ attn_dropout_rate: float = 0.1
51
+ channels_first: bool = True # to be compatible with conv1d input/output
52
+ pos_embed_type: str = "rope" # can be "rope" or "conformer"
53
+ max_relative_position: int = 128 # for conformer-style relative position embedding
54
+
55
+ def __post_init__(self):
56
+ if self.n_local_heads == -1:
57
+ self.n_local_heads = self.n_head
58
+ if self.intermediate_size is None:
59
+ hidden_dim = 4 * self.dim
60
+ n_hidden = int(2 * hidden_dim / 3)
61
+ self.intermediate_size = find_multiple(n_hidden, 256)
62
+ assert self.pos_embed_type in [
63
+ "rope",
64
+ "conformer",
65
+ ], "pos_embed_type must be either 'rope' or 'conformer'"
66
+
67
+
68
+ class KVCache(nn.Module):
69
+ def __init__(
70
+ self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
71
+ ):
72
+ super().__init__()
73
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
74
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
75
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
76
+
77
+ def update(self, input_pos, k_val, v_val):
78
+ # input_pos: [S], k_val: [B, H, S, D]
79
+ assert input_pos.shape[0] == k_val.shape[2]
80
+
81
+ k_out = self.k_cache
82
+ v_out = self.v_cache
83
+ k_out[:, :, input_pos] = k_val
84
+ v_out[:, :, input_pos] = v_val
85
+
86
+ return (
87
+ k_out[:, :, : input_pos.max() + 1, :],
88
+ v_out[:, :, : input_pos.max() + 1, :],
89
+ )
90
+
91
+ def clear_cache(self, prompt_len):
92
+ self.k_cache[:, :, prompt_len:, :].fill_(0)
93
+ self.v_cache[:, :, prompt_len:, :].fill_(0)
94
+
95
+
96
+ class Transformer(nn.Module):
97
+ def __init__(self, config: ModelArgs) -> None:
98
+ super().__init__()
99
+ self.config = config
100
+
101
+ self.layers = nn.ModuleList(
102
+ TransformerBlock(config) for _ in range(config.n_layer)
103
+ )
104
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
105
+
106
+ # Only compute RoPE frequencies if using RoPE
107
+ if config.pos_embed_type == "rope":
108
+ freqs_cis = precompute_freqs_cis(
109
+ self.config.block_size, self.config.head_dim, self.config.rope_base
110
+ )
111
+ self.register_buffer("freqs_cis", freqs_cis)
112
+ else:
113
+ self.register_buffer("freqs_cis", None)
114
+
115
+ causal_mask = torch.tril(
116
+ torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
117
+ )
118
+ self.register_buffer("causal_mask", causal_mask)
119
+
120
+ self.max_batch_size = -1
121
+ self.max_seq_length = -1
122
+ self.use_kv_cache = False
123
+
124
+ def setup_caches(self, max_batch_size, max_seq_length):
125
+ """
126
+ This method will only be called during inference when using KV cache.
127
+ """
128
+ head_dim = self.config.dim // self.config.n_head
129
+ max_seq_length = find_multiple(max_seq_length, 8)
130
+ self.max_seq_length = max_seq_length
131
+ self.max_batch_size = max_batch_size
132
+ dtype = self.norm.weight.dtype
133
+ device = self.norm.weight.device
134
+
135
+ for b in self.layers:
136
+ b.attention.kv_cache = KVCache(
137
+ max_batch_size,
138
+ max_seq_length,
139
+ self.config.n_local_heads,
140
+ head_dim,
141
+ dtype,
142
+ ).to(device)
143
+
144
+ self.use_kv_cache = True
145
+
146
+ def forward(
147
+ self,
148
+ x: Tensor,
149
+ input_pos: Optional[Tensor] = None,
150
+ mask: Optional[Tensor] = None,
151
+ ) -> Tensor:
152
+ if self.config.pos_embed_type == "rope":
153
+ assert (
154
+ self.freqs_cis is not None
155
+ ), "RoPE frequencies must be initialized for RoPE positional embedding"
156
+ freqs_cis = self.freqs_cis[input_pos]
157
+ else:
158
+ freqs_cis = None
159
+
160
+ if mask is None: # in case of non-causal model
161
+ if not self.training and self.use_kv_cache:
162
+ mask = self.causal_mask[None, None, input_pos]
163
+ mask = mask[..., : input_pos.max() + 1]
164
+ else:
165
+ mask = self.causal_mask[None, None, input_pos]
166
+ mask = mask[..., input_pos]
167
+
168
+ for i, layer in enumerate(self.layers):
169
+ x = layer(x, input_pos, freqs_cis, mask)
170
+ x = self.norm(x)
171
+ return x
172
+
173
+
174
+ class TransformerBlock(nn.Module):
175
+ def __init__(self, config: ModelArgs) -> None:
176
+ super().__init__()
177
+ self.attention = Attention(config)
178
+ self.feed_forward = FeedForward(config)
179
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
180
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
181
+ self.attention_layer_scale = LayerScale(config.dim, inplace=True)
182
+ self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
183
+
184
+ def forward(
185
+ self,
186
+ x: Tensor,
187
+ input_pos: Tensor,
188
+ freqs_cis: Tensor,
189
+ mask: Tensor,
190
+ ) -> Tensor:
191
+ h = x + self.attention_layer_scale(
192
+ self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
193
+ )
194
+ out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
195
+ return out
196
+
197
+
198
+ class Attention(nn.Module):
199
+ def __init__(self, config: ModelArgs):
200
+ super().__init__()
201
+ assert config.dim % config.n_head == 0
202
+
203
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
204
+ # key, query, value projections for all heads, but in a batch
205
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
206
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
207
+ self.kv_cache = None
208
+
209
+ self.n_head = config.n_head
210
+ self.head_dim = config.head_dim
211
+ self.n_local_heads = config.n_local_heads
212
+ self.dim = config.dim
213
+ self.attn_dropout_rate = config.attn_dropout_rate
214
+ self.pos_embed_type = config.pos_embed_type
215
+
216
+ # Add relative position embedding for conformer-style
217
+ if self.pos_embed_type == "conformer":
218
+ self.max_relative_position = config.max_relative_position
219
+ num_pos_embeddings = 2 * config.max_relative_position + 1
220
+ self.rel_pos_embeddings = nn.Parameter(
221
+ torch.zeros(num_pos_embeddings, self.head_dim)
222
+ )
223
+ nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
224
+
225
+ def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
226
+ # q: [B, H, S, D]
227
+ # Returns: [B, H, S, S]
228
+ positions = torch.arange(seqlen, device=q.device)
229
+ relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
230
+ relative_positions = torch.clamp(
231
+ relative_positions + self.max_relative_position,
232
+ 0,
233
+ 2 * self.max_relative_position,
234
+ )
235
+ rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
236
+
237
+ # Compute attention scores with relative position embeddings
238
+ q = q.transpose(1, 2) # [B, S, H, D]
239
+ rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
240
+ rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
241
+ return rel_logits
242
+
243
+ def forward(
244
+ self,
245
+ x: Tensor,
246
+ freqs_cis: Tensor,
247
+ mask: Tensor,
248
+ input_pos: Optional[Tensor] = None,
249
+ ) -> Tensor:
250
+ bsz, seqlen, _ = x.shape
251
+
252
+ kv_size = self.n_local_heads * self.head_dim
253
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
254
+ context_seqlen = seqlen
255
+
256
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
257
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
258
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
259
+
260
+ if self.pos_embed_type == "rope":
261
+ q = apply_rotary_emb(q, freqs_cis)
262
+ k = apply_rotary_emb(k, freqs_cis)
263
+
264
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
265
+
266
+ if self.kv_cache is not None:
267
+ k, v = self.kv_cache.update(input_pos, k, v)
268
+
269
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
270
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
271
+
272
+ if self.pos_embed_type == "conformer":
273
+ # Compute attention scores
274
+ scale = 1.0 / math.sqrt(self.head_dim)
275
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
276
+
277
+ # Add relative position embeddings for conformer-style
278
+ rel_scores = self._compute_conformer_pos_scores(q, seqlen)
279
+ scores = scores + rel_scores
280
+
281
+ # Apply attention
282
+ if mask is not None:
283
+ scores = scores.masked_fill(~mask, float("-inf"))
284
+
285
+ attn = F.softmax(scores, dim=-1)
286
+ if self.attn_dropout_rate > 0 and self.training:
287
+ attn = F.dropout(attn, p=self.attn_dropout_rate)
288
+
289
+ y = torch.matmul(attn, v)
290
+ else:
291
+ y = F.scaled_dot_product_attention(
292
+ q,
293
+ k,
294
+ v,
295
+ dropout_p=self.attn_dropout_rate if self.training else 0.0,
296
+ attn_mask=mask,
297
+ )
298
+ # is_causal=True)
299
+ y = (
300
+ y.transpose(1, 2)
301
+ .contiguous()
302
+ .view(bsz, seqlen, self.head_dim * self.n_head)
303
+ )
304
+ y = self.wo(y)
305
+ return y
306
+
307
+
308
+ class FeedForward(nn.Module):
309
+ def __init__(self, config: ModelArgs) -> None:
310
+ super().__init__()
311
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
312
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
313
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
314
+ self.dropout = nn.Dropout(config.dropout_rate)
315
+
316
+ def forward(self, x: Tensor) -> Tensor:
317
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
318
+
319
+
320
+ class RMSNorm(nn.Module):
321
+ def __init__(self, dim: int, eps: float = 1e-5):
322
+ super().__init__()
323
+ self.eps = eps
324
+ self.weight = nn.Parameter(torch.ones(dim))
325
+
326
+ def _norm(self, x):
327
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
328
+
329
+ def forward(self, x: Tensor) -> Tensor:
330
+ output = self._norm(x.float()).type_as(x)
331
+ return output * self.weight
332
+
333
+
334
+ class LayerScale(nn.Module):
335
+ def __init__(
336
+ self,
337
+ dim: int,
338
+ init_values: Union[float, Tensor] = 1e-2,
339
+ inplace: bool = False,
340
+ ) -> None:
341
+ super().__init__()
342
+ self.inplace = inplace
343
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
344
+
345
+ def forward(self, x: Tensor) -> Tensor:
346
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
347
+
348
+
349
+ class WindowLimitedTransformer(Transformer):
350
+ """
351
+ Transformer with window limited attention, causal.
352
+ """
353
+
354
+ def __init__(
355
+ self,
356
+ config: ModelArgs,
357
+ input_dim: int = 512,
358
+ window_size: Optional[int] = None,
359
+ causal: bool = True,
360
+ look_ahead_conv: nn.Module = None,
361
+ ):
362
+ super().__init__(config)
363
+ self.window_size = window_size
364
+ self.causal = causal
365
+ self.channels_first = config.channels_first
366
+ self.look_ahead_conv = (
367
+ look_ahead_conv if look_ahead_conv is not None else nn.Identity()
368
+ )
369
+ self.input_proj = (
370
+ nn.Linear(input_dim, config.dim)
371
+ if input_dim != config.dim
372
+ else nn.Identity()
373
+ )
374
+ self.output_proj = (
375
+ nn.Linear(config.dim, input_dim)
376
+ if input_dim != config.dim
377
+ else nn.Identity()
378
+ )
379
+
380
+ def make_window_limited_mask(
381
+ self,
382
+ max_length: int,
383
+ x_lens: Optional[Tensor] = None,
384
+ ) -> Tensor:
385
+ """
386
+ Make mask to form window limited attention.
387
+ """
388
+ if self.causal:
389
+ mask = torch.tril(torch.ones(max_length, max_length))
390
+ row_indices = torch.arange(max_length).view(-1, 1)
391
+ window_size = self.window_size or max_length
392
+ valid_range = (row_indices - window_size + 1).clamp(min=0)
393
+ column_indices = torch.arange(max_length)
394
+ mask = (column_indices >= valid_range) & mask.bool()
395
+ else:
396
+ raise NotImplementedError
397
+ mask = mask.bool()[None, None]
398
+ return mask
399
+
400
+ def make_mask(
401
+ self,
402
+ max_length: int,
403
+ x_lens: Optional[Tensor] = None,
404
+ ) -> Tensor:
405
+ """
406
+ Make ordinary mask if window size is not specified.
407
+ """
408
+ if self.causal:
409
+ mask = torch.tril(torch.ones(max_length, max_length))
410
+ else:
411
+ mask = torch.ones(max_length, max_length)
412
+ mask = mask.bool()[None, None]
413
+ for i, x_len in enumerate(x_lens):
414
+ mask[:x_len, i] = 0
415
+ mask = mask.bool()[None, None]
416
+ return mask
417
+
418
+ def forward(
419
+ self,
420
+ x: Tensor,
421
+ x_lens: Optional[Tensor] = None,
422
+ ) -> Tensor:
423
+ if self.channels_first:
424
+ x = x.transpose(1, 2)
425
+ x = self.input_proj(x) # (B, T, D)
426
+ x = self.look_ahead_conv(x)
427
+ input_pos = torch.arange(x.shape[1], device=x.device)
428
+ # construct mask to form window limited attention
429
+ max_length = x.shape[1]
430
+ if self.window_size is not None:
431
+ mask = self.make_window_limited_mask(max_length, x_lens)
432
+ else:
433
+ mask = self.make_mask(max_length, x_lens)
434
+ mask = mask.to(x.device)
435
+ x = super().forward(x, input_pos, mask)
436
+ x = self.output_proj(x) # (B, T, D)
437
+ if self.channels_first:
438
+ x = x.transpose(1, 2)
439
+ return x
440
+
441
+
442
+ def precompute_freqs_cis(
443
+ seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
444
+ ) -> Tensor:
445
+ freqs = 1.0 / (
446
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
447
+ )
448
+ t = torch.arange(seq_len, device=freqs.device)
449
+ freqs = torch.outer(t, freqs)
450
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
451
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
452
+ return cache.to(dtype=dtype)
453
+
454
+
455
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
456
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
457
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
458
+ x_out2 = torch.stack(
459
+ [
460
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
461
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
462
+ ],
463
+ -1,
464
+ )
465
+
466
+ x_out2 = x_out2.flatten(3)
467
+ return x_out2.type_as(x)
468
+
469
+
470
+ def init_weights(m):
471
+ if isinstance(m, nn.Conv1d):
472
+ nn.init.trunc_normal_(m.weight, std=0.02)
473
+ nn.init.constant_(m.bias, 0)
474
+
475
+
476
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
477
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
478
+ padding_left, padding_right = paddings
479
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
480
+ assert (padding_left + padding_right) <= x.shape[-1]
481
+ end = x.shape[-1] - padding_right
482
+ return x[..., padding_left:end]
483
+
484
+
485
+ def get_extra_padding_for_conv1d(
486
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
487
+ ) -> int:
488
+ """See `pad_for_conv1d`."""
489
+ length = x.shape[-1]
490
+ n_frames = (length - kernel_size + padding_total) / stride + 1
491
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
492
+ return ideal_length - length
493
+
494
+
495
+ def pad1d(
496
+ x: torch.Tensor,
497
+ paddings: tp.Tuple[int, int],
498
+ mode: str = "zeros",
499
+ value: float = 0.0,
500
+ ):
501
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
502
+ If this is the case, we insert extra 0 padding to the right
503
+ before the reflection happen.
504
+ """
505
+ length = x.shape[-1]
506
+ padding_left, padding_right = paddings
507
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
508
+ if mode == "reflect":
509
+ max_pad = max(padding_left, padding_right)
510
+ extra_pad = 0
511
+ if length <= max_pad:
512
+ extra_pad = max_pad - length + 1
513
+ x = F.pad(x, (0, extra_pad))
514
+ padded = F.pad(x, paddings, mode, value)
515
+ end = padded.shape[-1] - extra_pad
516
+ return padded[..., :end]
517
+ else:
518
+ return F.pad(x, paddings, mode, value)
519
+
520
+
521
+ class CausalConvNet(nn.Module):
522
+ def __init__(
523
+ self,
524
+ in_channels,
525
+ out_channels,
526
+ kernel_size,
527
+ dilation=1,
528
+ stride=1,
529
+ groups=1,
530
+ padding=None,
531
+ ):
532
+ super(CausalConvNet, self).__init__()
533
+ self.conv = nn.Conv1d(
534
+ in_channels,
535
+ out_channels,
536
+ kernel_size,
537
+ stride=stride,
538
+ dilation=dilation,
539
+ groups=groups,
540
+ )
541
+ self.stride = stride
542
+ self.kernel_size = (kernel_size - 1) * dilation + 1
543
+ self.dilation = dilation
544
+ self.padding = self.kernel_size - self.stride
545
+
546
+ def forward(self, x):
547
+ pad = self.padding
548
+ extra_padding = get_extra_padding_for_conv1d(
549
+ x, self.kernel_size, self.stride, pad
550
+ )
551
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
552
+ return self.conv(x).contiguous()
553
+
554
+ def weight_norm(self, name="weight", dim=0):
555
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
556
+ return self
557
+
558
+ def remove_weight_norm(self):
559
+ self.conv = remove_parametrizations(self.conv)
560
+ return self
561
+
562
+
563
+ class CausalTransConvNet(nn.Module):
564
+ def __init__(
565
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
566
+ ):
567
+ super(CausalTransConvNet, self).__init__()
568
+ self.conv = nn.ConvTranspose1d(
569
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
570
+ )
571
+ self.stride = stride
572
+ self.kernel_size = kernel_size
573
+
574
+ def forward(self, x):
575
+ x = self.conv(x)
576
+ pad = self.kernel_size - self.stride
577
+ padding_right = math.ceil(pad)
578
+ padding_left = pad - padding_right
579
+ x = unpad1d(x, (padding_left, padding_right))
580
+ return x.contiguous()
581
+
582
+ def weight_norm(self, name="weight", dim=0):
583
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
584
+ return self
585
+
586
+ def remove_weight_norm(self):
587
+ self.conv = remove_parametrizations(self.conv)
588
+ return self
589
+
590
+
591
+ def CausalWNConv1d(*args, **kwargs):
592
+ return CausalConvNet(*args, **kwargs).weight_norm()
593
+
594
+
595
+ def CausalWNConvTranspose1d(*args, **kwargs):
596
+ return CausalTransConvNet(*args, **kwargs).weight_norm()
597
+
598
+
599
+ class ResidualUnit(nn.Module):
600
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
601
+ super().__init__()
602
+ conv_class = CausalWNConv1d if causal else WNConv1d
603
+ pad = ((7 - 1) * dilation) // 2
604
+ self.block = nn.Sequential(
605
+ Snake1d(dim),
606
+ conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
607
+ Snake1d(dim),
608
+ conv_class(dim, dim, kernel_size=1),
609
+ )
610
+ self.causal = causal
611
+
612
+ def forward(self, x):
613
+ y = self.block(x)
614
+ pad = x.shape[-1] - y.shape[-1]
615
+ if pad > 0:
616
+ if self.causal:
617
+ x = x[..., :-pad]
618
+ else:
619
+ x = x[..., pad // 2 : -pad // 2]
620
+ return x + y
621
+
622
+
623
+ class EncoderBlock(nn.Module):
624
+ def __init__(
625
+ self,
626
+ dim: int = 16,
627
+ stride: int = 1,
628
+ causal: bool = False,
629
+ n_t_layer: int = 0,
630
+ transformer_general_config=None,
631
+ ):
632
+ super().__init__()
633
+ conv_class = CausalWNConv1d if causal else WNConv1d
634
+ transformer_module = (
635
+ nn.Identity()
636
+ if n_t_layer == 0
637
+ else (
638
+ WindowLimitedTransformer(
639
+ causal=causal,
640
+ input_dim=dim,
641
+ window_size=512,
642
+ config=transformer_general_config(
643
+ n_layer=n_t_layer,
644
+ n_head=dim // 64,
645
+ dim=dim,
646
+ intermediate_size=dim * 3,
647
+ ),
648
+ )
649
+ )
650
+ )
651
+ self.block = nn.Sequential(
652
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
653
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
654
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
655
+ Snake1d(dim // 2),
656
+ conv_class(
657
+ dim // 2,
658
+ dim,
659
+ kernel_size=2 * stride,
660
+ stride=stride,
661
+ padding=math.ceil(stride / 2),
662
+ ),
663
+ transformer_module,
664
+ )
665
+
666
+ def forward(self, x):
667
+ return self.block(x)
668
+
669
+
670
+ class Encoder(nn.Module):
671
+ def __init__(
672
+ self,
673
+ d_model: int = 64,
674
+ strides: list = [2, 4, 8, 8],
675
+ d_latent: int = 64,
676
+ n_transformer_layers: list = [0, 0, 4, 4],
677
+ transformer_general_config: ModelArgs = None,
678
+ causal: bool = False,
679
+ ):
680
+ super().__init__()
681
+ conv_class = CausalWNConv1d if causal else WNConv1d
682
+ # Create first convolution
683
+ self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
684
+
685
+ # Create EncoderBlocks that double channels as they downsample by `stride`
686
+ for stride, n_t_layer in zip(strides, n_transformer_layers):
687
+ d_model *= 2
688
+ self.block += [
689
+ EncoderBlock(
690
+ d_model,
691
+ stride=stride,
692
+ causal=causal,
693
+ n_t_layer=n_t_layer,
694
+ transformer_general_config=transformer_general_config,
695
+ )
696
+ ]
697
+
698
+ # Create last convolution
699
+ self.block += [
700
+ Snake1d(d_model),
701
+ conv_class(d_model, d_latent, kernel_size=3, padding=1),
702
+ ]
703
+
704
+ # Wrap block list into nn.Sequential
705
+ self.block = nn.Sequential(*self.block)
706
+ self.enc_dim = d_model
707
+
708
+ def forward(self, x):
709
+ return self.block(x)
710
+
711
+
712
+ class DecoderBlock(nn.Module):
713
+ def __init__(
714
+ self,
715
+ input_dim: int = 16,
716
+ output_dim: int = 8,
717
+ stride: int = 1,
718
+ causal: bool = False,
719
+ n_t_layer: int = 0,
720
+ transformer_general_config=None,
721
+ ):
722
+ super().__init__()
723
+ conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
724
+ transformer_module = (
725
+ nn.Identity()
726
+ if n_t_layer == 0
727
+ else (
728
+ WindowLimitedTransformer(
729
+ causal=causal,
730
+ input_dim=input_dim,
731
+ window_size=None,
732
+ config=transformer_general_config(
733
+ n_layer=n_t_layer,
734
+ n_head=input_dim // 64,
735
+ dim=input_dim,
736
+ intermediate_size=input_dim * 3,
737
+ ),
738
+ )
739
+ )
740
+ )
741
+ self.block = nn.Sequential(
742
+ # transformer_module,
743
+ Snake1d(input_dim),
744
+ conv_trans_class(
745
+ input_dim,
746
+ output_dim,
747
+ kernel_size=2 * stride,
748
+ stride=stride,
749
+ padding=math.ceil(stride / 2),
750
+ ),
751
+ ResidualUnit(output_dim, dilation=1, causal=causal),
752
+ ResidualUnit(output_dim, dilation=3, causal=causal),
753
+ ResidualUnit(output_dim, dilation=9, causal=causal),
754
+ )
755
+
756
+ def forward(self, x):
757
+ return self.block(x)
758
+
759
+
760
+ class Decoder(nn.Module):
761
+ def __init__(
762
+ self,
763
+ input_channel,
764
+ channels,
765
+ rates,
766
+ d_out: int = 1,
767
+ causal: bool = False,
768
+ n_transformer_layers: list = [0, 0, 0, 0],
769
+ transformer_general_config=None,
770
+ ):
771
+ super().__init__()
772
+ conv_class = CausalWNConv1d if causal else WNConv1d
773
+ # Add first conv layer
774
+ layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
775
+
776
+ # Add upsampling + MRF blocks
777
+ for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
778
+ input_dim = channels // 2**i
779
+ output_dim = channels // 2 ** (i + 1)
780
+ layers += [
781
+ DecoderBlock(
782
+ input_dim,
783
+ output_dim,
784
+ stride,
785
+ causal=causal,
786
+ n_t_layer=n_t_layer,
787
+ transformer_general_config=transformer_general_config,
788
+ )
789
+ ]
790
+
791
+ # Add final conv layer
792
+ layers += [
793
+ Snake1d(output_dim),
794
+ conv_class(output_dim, d_out, kernel_size=7, padding=3),
795
+ nn.Tanh(),
796
+ ]
797
+
798
+ self.model = nn.Sequential(*layers)
799
+
800
+ def forward(self, x):
801
+ return self.model(x)
802
+
803
+
804
+ class DAC(BaseModel, CodecMixin):
805
+ def __init__(
806
+ self,
807
+ encoder_dim: int = 64,
808
+ encoder_rates: List[int] = [2, 4, 8, 8],
809
+ latent_dim: int = None,
810
+ decoder_dim: int = 1536,
811
+ decoder_rates: List[int] = [8, 8, 4, 2],
812
+ quantizer: torch.nn.Module = None,
813
+ sample_rate: int = 44100,
814
+ causal: bool = True,
815
+ encoder_transformer_layers: List[int] = [0, 0, 0, 0],
816
+ decoder_transformer_layers: List[int] = [0, 0, 0, 0],
817
+ transformer_general_config=None,
818
+ ):
819
+ super().__init__()
820
+
821
+ self.encoder_dim = encoder_dim
822
+ self.encoder_rates = encoder_rates
823
+ self.decoder_dim = decoder_dim
824
+ self.decoder_rates = decoder_rates
825
+ self.sample_rate = sample_rate
826
+
827
+ if latent_dim is None:
828
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
829
+
830
+ self.latent_dim = latent_dim
831
+
832
+ self.hop_length = np.prod(encoder_rates)
833
+ self.encoder = Encoder(
834
+ encoder_dim,
835
+ encoder_rates,
836
+ latent_dim,
837
+ causal=causal,
838
+ n_transformer_layers=encoder_transformer_layers,
839
+ transformer_general_config=transformer_general_config,
840
+ )
841
+
842
+ self.quantizer = quantizer
843
+
844
+ self.decoder = Decoder(
845
+ latent_dim,
846
+ decoder_dim,
847
+ decoder_rates,
848
+ causal=causal,
849
+ n_transformer_layers=decoder_transformer_layers,
850
+ transformer_general_config=transformer_general_config,
851
+ )
852
+ self.sample_rate = sample_rate
853
+ self.apply(init_weights)
854
+
855
+ self.delay = self.get_delay()
856
+
857
+ self.frame_length = self.hop_length * 4
858
+
859
+ def preprocess(self, audio_data, sample_rate):
860
+ if sample_rate is None:
861
+ sample_rate = self.sample_rate
862
+ assert sample_rate == self.sample_rate
863
+
864
+ length = audio_data.shape[-1]
865
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
866
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
867
+
868
+ return audio_data
869
+
870
+ def encode(
871
+ self,
872
+ audio_data: torch.Tensor,
873
+ audio_lengths: torch.Tensor = None,
874
+ n_quantizers: int = None,
875
+ **kwargs,
876
+ ):
877
+ """Encode given audio data and return quantized latent codes
878
+
879
+ Parameters
880
+ ----------
881
+ audio_data : Tensor[B x T]
882
+ Audio data to encode
883
+ n_quantizers : int, optional
884
+ Number of quantizers to use, by default None
885
+ If None, all quantizers are used.
886
+
887
+ Returns
888
+ -------
889
+ dict
890
+ A dictionary with the following keys:
891
+ "z" : Tensor[B x D x T]
892
+ Quantized continuous representation of input
893
+ "codes" : Tensor[B x N x T]
894
+ Codebook indices for each codebook
895
+ (quantized discrete representation of input)
896
+ "latents" : Tensor[B x N*D x T]
897
+ Projected latents (continuous representation of input before quantization)
898
+ "vq/commitment_loss" : Tensor[1]
899
+ Commitment loss to train encoder to predict vectors closer to codebook
900
+ entries
901
+ "vq/codebook_loss" : Tensor[1]
902
+ Codebook loss to update the codebook
903
+ "length" : int
904
+ Number of samples in input audio
905
+ """
906
+ # pad to multiple of self.frame_length
907
+ if audio_data.ndim == 2:
908
+ audio_data = audio_data.unsqueeze(1)
909
+ # print(audio_data.shape)
910
+ length = audio_data.shape[-1]
911
+ right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
912
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
913
+ if audio_lengths is None:
914
+ audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
915
+
916
+ z = self.encoder(audio_data)
917
+ vq_results = self.quantizer(z, n_quantizers, **kwargs)
918
+ indices = vq_results.codes
919
+ indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
920
+ return indices, indices_lens
921
+
922
+ def decode(self, indices: torch.Tensor, feature_lengths):
923
+ if indices.ndim == 2:
924
+ indices = indices[None]
925
+
926
+ z = self.quantizer.decode(indices)
927
+ audio_lengths = feature_lengths * self.frame_length
928
+ return self.decoder(z), audio_lengths
929
+
930
+ def forward(
931
+ self,
932
+ audio_data: torch.Tensor,
933
+ template: torch.Tensor = None,
934
+ mask: torch.Tensor = None,
935
+ sample_rate: int = None,
936
+ n_quantizers: int = None,
937
+ **kwargs,
938
+ ):
939
+ """Model forward pass
940
+
941
+ Parameters
942
+ ----------
943
+ audio_data : Tensor[B x 1 x T]
944
+ Audio data to encode
945
+ sample_rate : int, optional
946
+ Sample rate of audio data in Hz, by default None
947
+ If None, defaults to `self.sample_rate`
948
+ n_quantizers : int, optional
949
+ Number of quantizers to use, by default None.
950
+ If None, all quantizers are used.
951
+
952
+ Returns
953
+ -------
954
+ dict
955
+ A dictionary with the following keys:
956
+ "z" : Tensor[B x D x T]
957
+ Quantized continuous representation of input
958
+ "codes" : Tensor[B x N x T]
959
+ Codebook indices for each codebook
960
+ (quantized discrete representation of input)
961
+ "latents" : Tensor[B x N*D x T]
962
+ Projected latents (continuous representation of input before quantization)
963
+ "vq/commitment_loss" : Tensor[1]
964
+ Commitment loss to train encoder to predict vectors closer to codebook
965
+ entries
966
+ "vq/codebook_loss" : Tensor[1]
967
+ Codebook loss to update the codebook
968
+ "length" : int
969
+ Number of samples in input audio
970
+ "audio" : Tensor[B x 1 x length]
971
+ Decoded audio data.
972
+ """
973
+ length = audio_data.shape[-1]
974
+ audio_data = self.preprocess(audio_data, sample_rate)
975
+ vq_results = self.encode(audio_data, n_quantizers, **kwargs)
976
+ z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
977
+ x = self.decode(z)
978
+ return x[..., :length], vq_results
979
+
980
+
981
+ if __name__ == "__main__":
982
+
983
+ def filter_state_dict_shapes(params, model):
984
+ model_state_dict = model.state_dict()
985
+ filtered_state_dict = {
986
+ k: v
987
+ for k, v in params.items()
988
+ if k in model_state_dict and v.shape == model_state_dict[k].shape
989
+ }
990
+ skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
991
+ if skipped_keys:
992
+ print(
993
+ f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
994
+ )
995
+ return filtered_state_dict, skipped_keys
996
+
997
+ model = hydra.utils.instantiate(
998
+ OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
999
+ )
1000
+ sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
1001
+ filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
1002
+ print(f"Skipped keys: {skipped_keys}")
1003
+ model.load_state_dict(filtered_sd, strict=False)
1004
+ model.eval()
1005
+
1006
+ src_audio_path = "./test.wav"
1007
+ wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
1008
+ if len(wave_np.shape) == 1:
1009
+ wave_np = wave_np[None, :]
1010
+ wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
1011
+
1012
+ with torch.no_grad():
1013
+ # encode returns (indices, indices_lens)
1014
+ indices, indices_lens = model.encode(wave_tensor)
1015
+ print(f"Indices shape: {indices.shape}")
1016
+ print(f"Indices lengths: {indices_lens}")
1017
+
1018
+ # decode needs both indices and feature_lengths
1019
+ fake_audio, audio_lengths = model.decode(indices, indices_lens)
1020
+ print(f"Decoded audio shape: {fake_audio.shape}")
1021
+ print(f"Audio lengths: {audio_lengths}")
1022
+
1023
+ # Save the reconstructed audio
1024
+ sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
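A back-of-the-envelope check on the defaults above (not part of the diff; the extra 4x factor is assumed to correspond to the quantizer's (2, 2) downsampling):

    hop_length = 2 * 4 * 8 * 8  # prod(encoder_rates) = 512 samples per encoder frame
    frame_length = hop_length * 4  # 2048 samples per code frame, as set in DAC.__init__
    frames_per_second = 44100 / frame_length  # ~21.5 code frames per second at 44.1 kHz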
fish_speech/models/dac/rvq.py ADDED
@@ -0,0 +1,403 @@
1
+ import math
2
+ import typing as tp
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from dac.nn.quantize import ResidualVectorQuantize
9
+ from torch.nn.utils.parametrizations import weight_norm
10
+ from torch.nn.utils.parametrize import remove_parametrizations
11
+
12
+
13
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
14
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
15
+ padding_left, padding_right = paddings
16
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
17
+ assert (padding_left + padding_right) <= x.shape[-1]
18
+ end = x.shape[-1] - padding_right
19
+ return x[..., padding_left:end]
20
+
21
+
22
+ def get_extra_padding_for_conv1d(
23
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
24
+ ) -> int:
25
+ """See `pad_for_conv1d`."""
26
+ length = x.shape[-1]
27
+ n_frames = (length - kernel_size + padding_total) / stride + 1
28
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
29
+ return ideal_length - length
30
+
31
+
32
+ def pad1d(
33
+ x: torch.Tensor,
34
+ paddings: tp.Tuple[int, int],
35
+ mode: str = "zeros",
36
+ value: float = 0.0,
37
+ ):
38
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
39
+ If this is the case, we insert extra 0 padding to the right
40
+ before the reflection happen.
41
+ """
42
+ length = x.shape[-1]
43
+ padding_left, padding_right = paddings
44
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
45
+ if mode == "reflect":
46
+ max_pad = max(padding_left, padding_right)
47
+ extra_pad = 0
48
+ if length <= max_pad:
49
+ extra_pad = max_pad - length + 1
50
+ x = F.pad(x, (0, extra_pad))
51
+ padded = F.pad(x, paddings, mode, value)
52
+ end = padded.shape[-1] - extra_pad
53
+ return padded[..., :end]
54
+ else:
55
+ return F.pad(x, paddings, mode, value)
56
+
57
+
58
+ class CausalConvNet(nn.Module):
59
+ def __init__(
60
+ self,
61
+ in_channels,
62
+ out_channels,
63
+ kernel_size,
64
+ dilation=1,
65
+ stride=1,
66
+ groups=1,
67
+ padding=None,
68
+ ):
69
+ super(CausalConvNet, self).__init__()
70
+ self.conv = nn.Conv1d(
71
+ in_channels,
72
+ out_channels,
73
+ kernel_size,
74
+ stride=stride,
75
+ dilation=dilation,
76
+ groups=groups,
77
+ )
78
+ self.stride = stride
79
+ self.kernel_size = (kernel_size - 1) * dilation + 1
80
+ self.dilation = dilation
81
+ self.padding = self.kernel_size - self.stride
82
+
83
+ def forward(self, x):
84
+ pad = self.padding
85
+ extra_padding = get_extra_padding_for_conv1d(
86
+ x, self.kernel_size, self.stride, pad
87
+ )
88
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
89
+ return self.conv(x).contiguous()
90
+
91
+ def weight_norm(self, name="weight", dim=0):
92
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
93
+ return self
94
+
95
+ def remove_weight_norm(self):
96
+ self.conv = remove_parametrizations(self.conv)
97
+ return self
98
+
99
+
100
+ class CausalTransConvNet(nn.Module):
101
+ def __init__(
102
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
103
+ ):
104
+ super(CausalTransConvNet, self).__init__()
105
+ self.conv = nn.ConvTranspose1d(
106
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
107
+ )
108
+ self.stride = stride
109
+ self.kernel_size = kernel_size
110
+
111
+ def forward(self, x):
112
+ x = self.conv(x)
113
+ pad = self.kernel_size - self.stride
114
+ padding_right = math.ceil(pad)
115
+ padding_left = pad - padding_right
116
+ x = unpad1d(x, (padding_left, padding_right))
117
+ return x.contiguous()
118
+
119
+ def weight_norm(self, name="weight", dim=0):
120
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
121
+ return self
122
+
123
+ def remove_weight_norm(self):
124
+ self.conv = remove_parametrizations(self.conv)
125
+ return self
126
+
127
+
128
+ # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
129
+ class ConvNeXtBlock(nn.Module):
130
+ r"""ConvNeXt Block. There are two equivalent implementations:
131
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
132
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
133
+ We use (2) as we find it slightly faster in PyTorch
134
+ Args:
135
+ dim (int): Number of input channels.
136
+ drop_path (float): Stochastic depth rate. Default: 0.0
137
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
138
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
139
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
140
+ dilation (int): Dilation for depthwise conv. Default: 1.
141
+ """ # noqa: E501
142
+
143
+ def __init__(
144
+ self,
145
+ dim: int,
146
+ layer_scale_init_value: float = 1e-6,
147
+ mlp_ratio: float = 4.0,
148
+ kernel_size: int = 7,
149
+ dilation: int = 1,
150
+ ):
151
+ super().__init__()
152
+ convnet_type = CausalConvNet
153
+ self.dwconv = convnet_type(
154
+ dim,
155
+ dim,
156
+ kernel_size=kernel_size,
157
+ # padding=int(dilation * (kernel_size - 1) / 2),
158
+ groups=dim,
159
+ dilation=dilation,
160
+ ) # depthwise conv
161
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
162
+ self.pwconv1 = nn.Linear(
163
+ dim, int(mlp_ratio * dim)
164
+ ) # pointwise/1x1 convs, implemented with linear layers
165
+ self.act = nn.GELU()
166
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
167
+ self.gamma = (
168
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
169
+ if layer_scale_init_value > 0
170
+ else None
171
+ )
172
+
173
+ def forward(self, x, apply_residual: bool = True):
174
+ input = x
175
+
176
+ x = self.dwconv(x)
177
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
178
+ x = self.norm(x)
179
+ x = self.pwconv1(x)
180
+ x = self.act(x)
181
+ x = self.pwconv2(x)
182
+
183
+ if self.gamma is not None:
184
+ x = self.gamma * x
185
+
186
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
187
+
188
+ if apply_residual:
189
+ x = input + x
190
+
191
+ return x
192
+
193
+
194
+ @dataclass
195
+ class VQResult:
196
+ z: torch.Tensor
197
+ codes: torch.Tensor
198
+ latents: torch.Tensor
199
+ codebook_loss: torch.Tensor
200
+ commitment_loss: torch.Tensor
201
+ semantic_distill_z: torch.Tensor | None = None
202
+
203
+
204
+ class DownsampleResidualVectorQuantize(nn.Module):
205
+ def __init__(
206
+ self,
207
+ input_dim: int = 1024,
208
+ n_codebooks: int = 9,
209
+ codebook_dim: int = 8,
210
+ quantizer_dropout: float = 0.5,
211
+ codebook_size: int = 1024,
212
+ semantic_codebook_size: int = 4096,
213
+ downsample_factor: tuple[int] = (2, 2),
214
+ downsample_dims: tuple[int] | None = None,
215
+ pre_module: nn.Module | None = None,
216
+ post_module: nn.Module | None = None,
217
+ semantic_predictor_module: nn.Module | None = None,
218
+ ):
219
+ super().__init__()
220
+
221
+ if downsample_dims is None:
222
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
223
+
224
+ all_dims = (input_dim,) + tuple(downsample_dims)
225
+
226
+ self.semantic_quantizer = ResidualVectorQuantize(
227
+ input_dim=input_dim,
228
+ n_codebooks=1,
229
+ codebook_size=semantic_codebook_size,
230
+ codebook_dim=codebook_dim,
231
+ quantizer_dropout=0.0,
232
+ )
233
+
234
+ self.quantizer = ResidualVectorQuantize(
235
+ input_dim=input_dim,
236
+ n_codebooks=n_codebooks,
237
+ codebook_size=codebook_size,
238
+ codebook_dim=codebook_dim,
239
+ quantizer_dropout=quantizer_dropout,
240
+ )
241
+
242
+ self.downsample_factor = downsample_factor
243
+ self.downsample_dims = downsample_dims
244
+
245
+ convnet_type = CausalConvNet
246
+ transconvnet_type = CausalTransConvNet
247
+
248
+ self.downsample = nn.Sequential(
249
+ *[
250
+ nn.Sequential(
251
+ convnet_type(
252
+ all_dims[idx],
253
+ all_dims[idx + 1],
254
+ kernel_size=factor,
255
+ stride=factor,
256
+ ),
257
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
258
+ )
259
+ for idx, factor in enumerate(downsample_factor)
260
+ ]
261
+ )
262
+
263
+ self.upsample = nn.Sequential(
264
+ *[
265
+ nn.Sequential(
266
+ transconvnet_type(
267
+ all_dims[idx + 1],
268
+ all_dims[idx],
269
+ kernel_size=factor,
270
+ stride=factor,
271
+ ),
272
+ ConvNeXtBlock(dim=all_dims[idx]),
273
+ )
274
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
275
+ ]
276
+ )
277
+ self.apply(self._init_weights)
278
+ self.pre_module = (
279
+ pre_module if pre_module is not None else nn.Identity()
280
+ ) # leave for transformer, LSTM or Mamba or something else
281
+ self.post_module = post_module if post_module is not None else nn.Identity()
282
+ self.semantic_predictor_module = (
283
+ semantic_predictor_module
284
+ if semantic_predictor_module is not None
285
+ else nn.Identity()
286
+ )
287
+
288
+ def _init_weights(self, m):
289
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
290
+ nn.init.trunc_normal_(m.weight, std=0.02)
291
+ nn.init.constant_(m.bias, 0)
292
+
293
+ def forward(
294
+ self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs
295
+ ):
296
+ # z: (B, D, T)
297
+ original_shape = z.shape
298
+ if semantic_len is None:
299
+ semantic_len = torch.LongTensor([z.shape[-1]])
300
+ z = self.downsample(z)
301
+ z = self.pre_module(z) # B, T, D
302
+ (
303
+ semantic_z,
304
+ semantic_codes,
305
+ semantic_latents,
306
+ semantic_commitment_loss,
307
+ semantic_codebook_loss,
308
+ ) = self.semantic_quantizer(z)
309
+ residual_z = z - semantic_z
310
+ residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
311
+ residual_z, n_quantizers=n_quantizers
312
+ )
313
+ z = semantic_z + residual_z
314
+ commitment_loss = commitment_loss + semantic_commitment_loss
315
+ codebook_loss = codebook_loss + semantic_codebook_loss
316
+ codes = torch.cat([semantic_codes, codes], dim=1)
317
+ latents = torch.cat([semantic_latents, latents], dim=1)
318
+ z = self.post_module(z)
319
+ z = self.upsample(z)
320
+ # z: (B, D, T)
321
+
322
+ # semantic distillation (disabled here since only used in training)
323
+ # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
324
+
325
+ # Pad or crop z to match original shape
326
+ diff = original_shape[-1] - z.shape[-1]
327
+ right = 0
328
+ left = abs(diff) - right
329
+
330
+ if diff > 0:
331
+ z = F.pad(z, (left, right))
332
+ elif diff < 0:
333
+ z = z[..., left:]
334
+
335
+ results = VQResult(
336
+ z=z,
337
+ codes=codes,
338
+ latents=latents,
339
+ commitment_loss=commitment_loss,
340
+ codebook_loss=codebook_loss,
341
+ )
342
+
343
+ return results
344
+
345
+ # def encode(self, z):
346
+ # z = self.downsample(z)
347
+ # z = self.pre_module(z)
348
+ # _, indices, _, _, _ = self.quantizer(z.mT)
349
+ # indices = rearrange(indices, "g b l r -> b (g r) l")
350
+ # return indices
351
+ #
352
+ def decode(self, indices: torch.Tensor):
353
+ # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
354
+
355
+ # print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}")
356
+
357
+ new_indices = torch.zeros_like(indices)
358
+ new_indices[:, 0] = torch.clamp(
359
+ indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
360
+ )
361
+ new_indices[:, 1:] = torch.clamp(
362
+ indices[:, 1:], max=self.quantizer.codebook_size - 1
363
+ )
364
+
365
+ z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
366
+ z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
367
+ z_q = z_q_semantic + z_q_residual
368
+ z_q = self.post_module(z_q)
369
+ z_q = self.upsample(z_q)
370
+ return z_q
371
+
372
+ # def from_latents(self, latents: torch.Tensor):
373
+ # z_q, z_p, codes = super().from_latents(latents)
374
+ # z_q = self.upsample(z_q)
375
+ # return z_q, z_p, codes
376
+
377
+
378
+ if __name__ == "__main__":
379
+ rvq = DownsampleResidualVectorQuantize(
380
+ input_dim=512,
381
+ n_codebooks=8,
382
+ codebook_dim=8,
383
+ codebook_size=1024,
384
+ quantizer_dropout=0.5,
385
+ downsample_factor=[2, 2],
386
+ )
387
+ rvq.eval()
388
+ x = torch.randn(2, 512, 442)
389
+
390
+ result = rvq(x)
391
+ print(rvq)
392
+ print(result.latents.shape, result.codes.shape, result.z.shape)
393
+
394
+ # y = rvq.from_codes(result.codes)
395
+ # print(y[0].shape)
396
+
397
+ # y = rvq.from_latents(
398
+
399
+ result1 = rvq(x[:, :, :40])
400
+ print(result1.latents.shape, result1.codes.shape, result1.z.shape)
401
+
402
+ assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
403
+ print("Success")
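One detail worth spelling out from the forward pass above: `codes` stacks the single semantic codebook in row 0 and the residual codebooks after it, which is why `decode` clamps row 0 against `semantic_quantizer.codebook_size` and the remaining rows against `quantizer.codebook_size`. A minimal decode sketch following the `__main__` smoke test above:

    indices = result.codes  # (B, 1 + n_codebooks, T_down); row 0 is the semantic codebook
    z_hat = rvq.decode(indices)  # (B, input_dim, ~T) upsampled features for the DAC decoder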
fish_speech/models/text2semantic/inference.py ADDED
@@ -0,0 +1,716 @@
1
+ import os
2
+ import queue
3
+ import threading
4
+ import time
5
+ from contextlib import nullcontext
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Literal, Optional, Tuple, Union
9
+
10
+ import click
11
+ import numpy as np
12
+ import torch
13
+ import torch._dynamo.config
14
+ import torch._inductor.config
15
+ from loguru import logger
16
+ from tqdm import tqdm
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.content_sequence import (
20
+ ContentSequence,
21
+ TextPart,
22
+ VQPart,
23
+ )
24
+ from fish_speech.models.text2semantic.llama import BaseModelArgs
25
+ from fish_speech.text import clean_text, split_text
26
+ from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
27
+
28
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
29
+ torch._inductor.config.coordinate_descent_tuning = True
30
+ torch._inductor.config.triton.unique_kernel_names = True
31
+
32
+ if hasattr(torch._inductor.config, "fx_graph_cache"):
33
+ # Experimental feature to reduce compilation times, will be on by default in future
34
+ torch._inductor.config.fx_graph_cache = True
35
+
36
+
37
+ from torch.nn.attention import SDPBackend, sdpa_kernel
38
+
39
+ from fish_speech.models.text2semantic.llama import (
40
+ BaseTransformer,
41
+ DualARTransformer,
42
+ NaiveTransformer,
43
+ )
44
+
45
+
46
+ def multinomial_sample_one_no_sync(
47
+ probs_sort,
48
+ ): # Does multinomial sampling without a cuda synchronization
49
+ q = torch.empty_like(probs_sort).exponential_(1)
50
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
51
+
52
+
53
+ def logits_to_probs(
54
+ logits,
55
+ previous_tokens: Optional[torch.Tensor] = None,
56
+ temperature: torch.Tensor = 1.0,
57
+ top_p: torch.Tensor = 1.0,
58
+ repetition_penalty: torch.Tensor = 1.0,
59
+ ) -> torch.Tensor:
60
+ # Apply repetition penalty
61
+ if previous_tokens is not None:
62
+ previous_tokens = previous_tokens.long()
63
+ score = torch.gather(logits, dim=0, index=previous_tokens)
64
+ score = torch.where(
65
+ score < 0, score * repetition_penalty, score / repetition_penalty
66
+ )
67
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
68
+
69
+ # Apply top-p sampling
70
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
71
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
72
+ sorted_indices_to_remove = cum_probs > top_p
73
+ sorted_indices_to_remove[0] = False # keep at least one option
74
+ indices_to_remove = sorted_indices_to_remove.scatter(
75
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
76
+ )
77
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
78
+
79
+ logits = logits / max(temperature, 1e-5)
80
+
81
+ probs = torch.nn.functional.softmax(logits, dim=-1)
82
+ return probs
83
+
84
+
85
+ def sample(
86
+ logits,
87
+ previous_tokens: Optional[torch.Tensor] = None,
88
+ **sampling_kwargs,
89
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
90
+ probs = logits_to_probs(
91
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
92
+ )
93
+ idx_next = multinomial_sample_one_no_sync(probs)
94
+ return idx_next, probs
95
+
96
+
97
+ def decode_one_token_ar(
98
+ model: DualARTransformer,
99
+ x: torch.Tensor,
100
+ input_pos: torch.Tensor,
101
+ semantic_ids: list,
102
+ previous_tokens: torch.Tensor = None,
103
+ **sampling_kwargs,
104
+ ) -> torch.Tensor:
105
+ x = model.forward_generate(x, input_pos)
106
+
107
+ sampling_kwargs_main = sampling_kwargs.copy()
108
+ # sampling_kwargs_main["temperature"] = 0.1
109
+ # sampling_kwargs_main["top_p"] = 0.1
110
+ # sampling_kwargs_main["repetition_penalty"] = 1.0
111
+
112
+ codebooks = [
113
+ sample(
114
+ x.logits,
115
+ previous_tokens=(
116
+ previous_tokens[0] if previous_tokens is not None else None
117
+ ), # Disable repetition penalty for the token codebook
118
+ **sampling_kwargs_main,
119
+ )[0]
120
+ ]
121
+
122
+ hidden_states = x.hidden_states
123
+
124
+ # Cleanup the cache
125
+ for layer in model.fast_layers:
126
+ layer.attention.kv_cache.k_cache.fill_(0)
127
+ layer.attention.kv_cache.v_cache.fill_(0)
128
+
129
+ input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
130
+ model.forward_generate_fast(hidden_states, input_pos)
131
+ a = codebooks[0] - model.tokenizer.semantic_begin_id
132
+ a[a < 0] = 0
133
+ hidden_states = model.fast_embeddings(a)
134
+ codebooks.append(a)
135
+
136
+ for codebook_idx in range(1, model.config.num_codebooks):
137
+ input_pos = torch.tensor(
138
+ [codebook_idx], device=hidden_states.device, dtype=torch.long
139
+ )
140
+ logits = model.forward_generate_fast(hidden_states, input_pos)
141
+ chunked_logits = logits[..., :1024]
142
+ a = sample(
143
+ chunked_logits,
144
+ previous_tokens=(
145
+ previous_tokens[codebook_idx + 1]
146
+ if previous_tokens is not None
147
+ else None
148
+ ),
149
+ **sampling_kwargs,
150
+ )[0]
151
+ hidden_states = model.fast_embeddings(a)
152
+ codebooks.append(a)
153
+
154
+ codebooks = torch.stack(codebooks, dim=0)
155
+ # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
156
+ # codebooks[1:, :] = torch.masked_fill(
157
+ # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
158
+ # )
159
+
160
+ # print(codebooks)
161
+ return codebooks
162
+
163
+
164
+ def decode_n_tokens(
165
+ model: NaiveTransformer,
166
+ cur_token: torch.Tensor,
167
+ input_pos: torch.Tensor,
168
+ num_new_tokens: int,
169
+ semantic_ids: list,
170
+ decode_one_token=decode_one_token_ar,
171
+ **sampling_kwargs,
172
+ ):
173
+ previous_tokens = torch.zeros(
174
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
175
+ dtype=torch.int,
176
+ device=cur_token.device,
177
+ )
178
+
179
+ for i in tqdm(range(num_new_tokens)):
180
+ # We need to get windowed repeat penalty
181
+ win_size = 16
182
+ if i < win_size:
183
+ window = previous_tokens[:, :win_size]
184
+ else:
185
+ window = previous_tokens[:, i - win_size : i]
186
+
187
+ with (
188
+ torch.backends.cuda.sdp_kernel(
189
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
190
+ )
191
+ if torch.cuda.is_available()
192
+ else nullcontext()
193
+ ): # Actually better for Inductor to codegen attention here
194
+ next_token = decode_one_token(
195
+ model=model,
196
+ x=cur_token,
197
+ input_pos=input_pos,
198
+ previous_tokens=window,
199
+ semantic_ids=semantic_ids,
200
+ **sampling_kwargs,
201
+ )
202
+
203
+ input_pos += 1
204
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
205
+ previous_tokens[:, i : i + 1] = next_token.view(
206
+ model.config.num_codebooks + 1, -1
207
+ )
208
+
209
+ if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
210
+ break
211
+
212
+ return previous_tokens[:, : i + 1]
213
+
214
+
215
+ @torch.no_grad()
216
+ @torch.inference_mode()
217
+ def generate(
218
+ *,
219
+ model: NaiveTransformer,
220
+ prompt: torch.Tensor,
221
+ max_new_tokens: int,
222
+ decode_one_token=decode_one_token_ar,
223
+ **sampling_kwargs,
224
+ ) -> torch.Tensor:
225
+ """
226
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
227
+ """
228
+
229
+ # create an empty tensor of the expected final shape and fill in the current tokens
230
+ T = prompt.size(1)
231
+ # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
232
+ semantic_ids = [
233
+ model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
234
+ ]
235
+
236
+ if max_new_tokens:
237
+ if T + max_new_tokens > model.config.max_seq_len:
238
+ max_new_tokens = model.config.max_seq_len - T
239
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
240
+
241
+ T_new = T + max_new_tokens
242
+ else:
243
+ T_new = model.config.max_seq_len
244
+ max_new_tokens = T_new - T
245
+
246
+ device, dtype = prompt.device, prompt.dtype
247
+
248
+ codebook_dim = 1 + model.config.num_codebooks
249
+ # create an empty tensor of the expected final shape and fill in the current tokens
250
+ empty = torch.empty(
251
+ (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
252
+ )
253
+ empty[:, :T] = prompt
254
+ seq = empty
255
+ input_pos = torch.arange(0, T, device=device)
256
+
257
+ # Use non-accelerated version for now, to avoid compilation overhead
258
+ prefill_decode = decode_one_token_ar
259
+
260
+ next_token = prefill_decode(
261
+ model,
262
+ prompt.view(1, codebook_dim, -1),
263
+ input_pos,
264
+ semantic_ids=semantic_ids,
265
+ **sampling_kwargs,
266
+ )
267
+ seq[:, T : T + 1] = next_token
268
+
269
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
270
+ x = decode_n_tokens(
271
+ model,
272
+ next_token.view(1, codebook_dim, -1),
273
+ input_pos,
274
+ max_new_tokens - 1,
275
+ decode_one_token=decode_one_token,
276
+ semantic_ids=semantic_ids,
277
+ **sampling_kwargs,
278
+ )
279
+ # x = torch.cat(generated_tokens, dim=1)
280
+ seq = seq[:, : T + 1 + x.size(1)]
281
+ seq[:, T + 1 :] = x
282
+
283
+ return seq
284
+
285
+
286
+ def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
287
+ model = DualARTransformer.from_pretrained(
+ checkpoint_path, load_weights=True, is_agent=is_agent
+ )
288
+
289
+ model = model.to(device=device, dtype=precision)
290
+ logger.info(f"Restored model from checkpoint")
291
+
292
+ if isinstance(model, DualARTransformer):
293
+ decode_one_token = decode_one_token_ar
294
+ logger.info("Using DualARTransformer")
295
+ else:
296
+ raise ValueError("Model is not a DualARTransformer")
297
+
298
+ if compile:
299
+ logger.info("Compiling function...")
300
+ decode_one_token = torch.compile(
301
+ decode_one_token,
302
+ fullgraph=True,
303
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
304
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
305
+ )
306
+
307
+ return model.eval(), decode_one_token
308
+
309
+
310
+ @dataclass
311
+ class GenerateResponse:
312
+ action: Literal["sample", "next"]
313
+ codes: Optional[torch.Tensor] = None
314
+ text: Optional[str] = None
315
+
316
+
317
+ def generate_long(
318
+ *,
319
+ model,
320
+ device: str | torch.device,
321
+ decode_one_token: callable,
322
+ text: str,
323
+ num_samples: int = 1,
324
+ max_new_tokens: int = 0,
325
+ top_p: int = 0.8,
326
+ repetition_penalty: float = 1.1,
327
+ temperature: float = 0.8,
328
+ compile: bool = False,
329
+ iterative_prompt: bool = True,
330
+ chunk_length: int = 150,
331
+ prompt_text: Optional[str | list[str]] = None,
332
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
333
+ ):
334
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
335
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
336
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
337
+
338
+ use_prompt = prompt_text is not None and prompt_tokens is not None
339
+ if use_prompt and isinstance(prompt_text, str):
340
+ prompt_text = [prompt_text]
341
+ prompt_tokens = [prompt_tokens]
342
+
343
+ assert use_prompt is False or len(prompt_text) == len(
344
+ prompt_tokens
345
+ ), "Prompt text and tokens must have the same length"
346
+
347
+ prompt_tokens = [i.cpu() for i in prompt_tokens] if use_prompt else []  # avoid iterating None when no reference prompt is given
348
+
349
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
350
+ tokenizer = model.tokenizer
351
+ base_content_sequence = ContentSequence(modality="interleave")
352
+
353
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
354
+ max_length = model.config.max_seq_len
355
+
356
+ if use_prompt:
357
+ for t, c in zip(prompt_text, prompt_tokens):
358
+ base_content_sequence.append(
359
+ [
360
+ TextPart(text=t),
361
+ VQPart(codes=c),
362
+ ],
363
+ add_end=True,
364
+ )
365
+
366
+ encoded_prompts = base_content_sequence.encode_for_inference(
367
+ tokenizer, num_codebooks=model.config.num_codebooks
368
+ )
369
+ if encoded_prompts.size(1) > max_length - 2048:
370
+ raise ValueError(
371
+ f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
372
+ )
373
+
374
+ encoded = []
375
+ for text in texts:
376
+ content_sequence = ContentSequence(modality=None)
377
+ content_sequence.append(TextPart(text=text))
378
+ encoded.append(
379
+ content_sequence.encode_for_inference(
380
+ tokenizer, num_codebooks=model.config.num_codebooks
381
+ )
382
+ )
383
+ logger.info(f"Encoded text: {text}")
384
+
385
+ # Move temperature, top_p, repetition_penalty to device
386
+ # This is important so that changing params doesn't trigger recompile
387
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
388
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
389
+ repetition_penalty = torch.tensor(
390
+ repetition_penalty, device=device, dtype=torch.float
391
+ )
392
+
393
+ for sample_idx in range(num_samples):
394
+ if torch.cuda.is_available():
395
+ torch.cuda.synchronize()
396
+
397
+ global_encoded = []
398
+ seg_idx = 0
399
+
400
+ while seg_idx < len(encoded):
401
+ logger.info(
402
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
403
+ )
404
+
405
+ seg = encoded[seg_idx]
406
+ global_encoded.append(seg)
407
+
408
+ # Do not use previous segments to generate current segment for now
409
+ # lengths = reversed([seg.size(1) for seg in global_encoded])
410
+
411
+ # # Pick last 2000 tokens
412
+ # count = 0
413
+ # for i, length in enumerate(lengths):
414
+ # count += length
415
+ # if count + length > max_length - 2048 - encoded_prompts.size(1):
416
+ # break
417
+
418
+ # if i != 0 and i % 2 == 0:
419
+ # i -= 1
420
+
421
+ # # Rotate the list, always make sure first segment is included to avoid drift
422
+ # if i < len(global_encoded) - 2:
423
+ # partial_encoded = global_encoded[:2] + global_encoded[-i:]
424
+ # else:
425
+ # partial_encoded = global_encoded
426
+
427
+ # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
428
+ if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
429
+ cat_encoded = torch.cat(
430
+ [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
431
+ )
432
+ else:
433
+ cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
434
+
435
+ cat_encoded = cat_encoded.to(device=device)
436
+ prompt_length = cat_encoded.size(1)
437
+
438
+ t0 = time.perf_counter()
439
+ y = generate(
440
+ model=model,
441
+ prompt=cat_encoded,
442
+ max_new_tokens=max_new_tokens,
443
+ decode_one_token=decode_one_token,
444
+ temperature=temperature,
445
+ top_p=top_p,
446
+ repetition_penalty=repetition_penalty,
447
+ )
448
+
449
+ if sample_idx == 0 and seg_idx == 0 and compile:
450
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
451
+
452
+ if torch.cuda.is_available():
453
+ torch.cuda.synchronize()
454
+
455
+ t = time.perf_counter() - t0
456
+
457
+ tokens_generated = y.size(1) - prompt_length
458
+ tokens_sec = tokens_generated / t
459
+ logger.info(
460
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
461
+ )
462
+ logger.info(
463
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
464
+ )
465
+
466
+ if torch.cuda.is_available():
467
+ logger.info(
468
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
469
+ )
470
+
471
+ # Put the generated tokens
472
+ # since there is <im_end>, we remove last token
473
+ codes = y[1:, prompt_length:-1].clone()
474
+ assert (codes >= 0).all(), "Negative code found"
475
+
476
+ decoded = y[:, prompt_length:].clone()
477
+ # But for global encoding, we should keep the <im_end> token
478
+
479
+ global_encoded.append(decoded.cpu())
480
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
481
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
482
+ seg_idx += 1
483
+
484
+ # This indicates the end of the current sample
485
+ yield GenerateResponse(action="next")
486
+
487
+
488
+ @dataclass
489
+ class WrappedGenerateResponse:
490
+ status: Literal["success", "error"]
491
+ response: Optional[GenerateResponse | Exception] = None
492
+
493
+
494
+ @dataclass
495
+ class GenerateRequest:
496
+ request: dict
497
+ response_queue: queue.Queue
498
+
499
+
500
+ def launch_thread_safe_queue(
501
+ checkpoint_path,
502
+ device,
503
+ precision,
504
+ compile: bool = False,
505
+ ):
506
+ input_queue = queue.Queue()
507
+ init_event = threading.Event()
508
+
509
+ def worker():
510
+ model, decode_one_token = load_model(
511
+ checkpoint_path, device, precision, compile=compile
512
+ )
513
+ with torch.device(device):
514
+ model.setup_caches(
515
+ max_batch_size=1,
516
+ max_seq_len=model.config.max_seq_len,
517
+ dtype=next(model.parameters()).dtype,
518
+ )
519
+ init_event.set()
520
+
521
+ while True:
522
+ item: GenerateRequest | None = input_queue.get()
523
+ if item is None:
524
+ break
525
+
526
+ kwargs = item.request
527
+ response_queue = item.response_queue
528
+
529
+ try:
530
+ for chunk in generate_long(
531
+ model=model, decode_one_token=decode_one_token, **kwargs
532
+ ):
533
+ response_queue.put(
534
+ WrappedGenerateResponse(status="success", response=chunk)
535
+ )
536
+ except Exception as e:
537
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
538
+
539
+ threading.Thread(target=worker, daemon=True).start()
540
+ init_event.wait()
541
+
542
+ return input_queue
543
+
544
+
545
+ def launch_thread_safe_queue_agent(
546
+ checkpoint_path,
547
+ device,
548
+ precision,
549
+ compile: bool = False,
550
+ ):
551
+ input_queue = queue.Queue()
552
+ init_event = threading.Event()
553
+
554
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
555
+ config = BaseModelArgs.from_pretrained(checkpoint_path)
556
+
557
+ def worker():
558
+ model, decode_one_token = load_model(
559
+ checkpoint_path, device, precision, compile=compile, is_agent=True
560
+ )
561
+
562
+ with torch.device(device):
563
+ model.setup_caches(
564
+ max_batch_size=1,
565
+ max_seq_len=model.config.max_seq_len,
566
+ dtype=next(model.parameters()).dtype,
567
+ )
568
+ init_event.set()
569
+
570
+ while True:
571
+ item: GenerateRequest | None = input_queue.get()
572
+ if item is None:
573
+ break
574
+
575
+ kwargs = item.request
576
+ response_queue = item.response_queue
577
+
578
+ try:
579
+ for token in generate_agent(
580
+ model=model,
581
+ decode_one_token=decode_one_token,
582
+ **kwargs,
583
+ ):
584
+ response_queue.put(token)
585
+
586
+ response_queue.put("stop")
587
+ except Exception as e:
588
+ import traceback
589
+
590
+ logger.exception(f"Error in worker: {traceback.format_exc()}")
591
+ response_queue.put("error")
592
+
593
+ threading.Thread(target=worker, daemon=True).start()
594
+ init_event.wait()
595
+
596
+ return input_queue, tokenizer, config
597
+
598
+
599
+ @click.command()
600
+ @click.option(
601
+ "--text",
602
+ type=str,
603
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
604
+ )
605
+ @click.option("--prompt-text", type=str, default=None, multiple=True)
606
+ @click.option(
607
+ "--prompt-tokens",
608
+ type=click.Path(path_type=Path, exists=True),
609
+ default=None,
610
+ multiple=True,
611
+ )
612
+ @click.option("--num-samples", type=int, default=1)
613
+ @click.option("--max-new-tokens", type=int, default=0)
614
+ @click.option("--top-p", type=float, default=0.8)
615
+ @click.option("--repetition-penalty", type=float, default=1.1)
616
+ @click.option("--temperature", type=float, default=0.8)
617
+ @click.option(
618
+ "--checkpoint-path",
619
+ type=click.Path(path_type=Path, exists=True),
620
+ default="checkpoints/openaudio-s1-mini",
621
+ )
622
+ @click.option("--device", type=str, default="cuda")
623
+ @click.option("--compile/--no-compile", default=False)
624
+ @click.option("--seed", type=int, default=42)
625
+ @click.option("--half/--no-half", default=False)
626
+ @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
627
+ @click.option("--chunk-length", type=int, default=300)
628
+ @click.option("--output-dir", type=Path, default="temp")
629
+ def main(
630
+ text: str,
631
+ prompt_text: Optional[list[str]],
632
+ prompt_tokens: Optional[list[Path]],
633
+ num_samples: int,
634
+ max_new_tokens: int,
635
+ top_p: int,
636
+ repetition_penalty: float,
637
+ temperature: float,
638
+ checkpoint_path: Path,
639
+ device: str,
640
+ compile: bool,
641
+ seed: int,
642
+ half: bool,
643
+ iterative_prompt: bool,
644
+ chunk_length: int,
645
+ output_dir: Path,
646
+ ) -> None:
647
+ os.makedirs(output_dir, exist_ok=True)
648
+ precision = torch.half if half else torch.bfloat16
649
+
650
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
651
+ raise ValueError(
652
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
653
+ )
654
+
655
+ logger.info("Loading model ...")
656
+ t0 = time.time()
657
+ model, decode_one_token = load_model(
658
+ checkpoint_path, device, precision, compile=compile
659
+ )
660
+ with torch.device(device):
661
+ model.setup_caches(
662
+ max_batch_size=1,
663
+ max_seq_len=model.config.max_seq_len,
664
+ dtype=next(model.parameters()).dtype,
665
+ )
666
+ if torch.cuda.is_available():
667
+ torch.cuda.synchronize()
668
+
669
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
670
+
671
+ if prompt_tokens is not None:
672
+ prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
673
+
674
+ torch.manual_seed(seed)
675
+
676
+ if torch.cuda.is_available():
677
+ torch.cuda.manual_seed(seed)
678
+
679
+ generator = generate_long(
680
+ model=model,
681
+ device=device,
682
+ decode_one_token=decode_one_token,
683
+ text=text,
684
+ num_samples=num_samples,
685
+ max_new_tokens=max_new_tokens,
686
+ top_p=top_p,
687
+ repetition_penalty=repetition_penalty,
688
+ temperature=temperature,
689
+ compile=compile,
690
+ iterative_prompt=iterative_prompt,
691
+ chunk_length=chunk_length,
692
+ prompt_text=prompt_text,
693
+ prompt_tokens=prompt_tokens,
694
+ )
695
+
696
+ idx = 0
697
+ codes = []
698
+
699
+ for response in generator:
700
+ if response.action == "sample":
701
+ codes.append(response.codes)
702
+ logger.info(f"Sampled text: {response.text}")
703
+ elif response.action == "next":
704
+ if codes:
705
+ codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
706
+ np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
707
+ logger.info(f"Saved codes to {codes_npy_path}")
708
+ logger.info(f"Next sample")
709
+ codes = []
710
+ idx += 1
711
+ else:
712
+ logger.error(f"Error: {response}")
713
+
714
+
715
+ if __name__ == "__main__":
716
+ main()
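A rough usage sketch of the queue-based worker defined above (illustrative only, not part of the commit; the checkpoint path and the generation arguments are assumptions based on the CLI defaults in this file):

import queue
import torch
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    launch_thread_safe_queue,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
input_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/openaudio-s1-mini",  # assumed checkpoint layout
    device=device,
    precision=torch.bfloat16,
)

response_queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request={"device": device, "text": "Hello world.", "max_new_tokens": 0},
        response_queue=response_queue,
    )
)

while True:
    wrapped = response_queue.get()
    if wrapped.status == "error":
        raise wrapped.response
    if wrapped.response.action == "next":
        break  # end of the current sample
    print(wrapped.response.codes.shape, wrapped.response.text)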
fish_speech/models/text2semantic/lit_module.py CHANGED
@@ -1,202 +1,202 @@
1
- from typing import Any, Optional
2
-
3
- import lightning as L
4
- import torch
5
- import torch.nn.functional as F
6
- from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
-
8
- import fish_speech.utils as utils
9
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
- from fish_speech.models.text2semantic.llama import NaiveTransformer
11
-
12
- log = utils.RankedLogger(__name__, rank_zero_only=True)
13
-
14
-
15
- class TextToSemantic(L.LightningModule):
16
- def __init__(
17
- self,
18
- model: NaiveTransformer,
19
- optimizer: Any,
20
- lr_scheduler: Any,
21
- ):
22
- super().__init__()
23
-
24
- self.model = model
25
- self.optimizer_builder = optimizer
26
- self.lr_scheduler_builder = lr_scheduler
27
-
28
- def forward(self, x):
29
- return self.model(x)
30
-
31
- def on_save_checkpoint(self, checkpoint):
32
- # Save only LoRA parameters
33
- state_dict = checkpoint["state_dict"]
34
- use_lora = any("lora" in name for name in state_dict.keys())
35
- if not use_lora:
36
- return
37
-
38
- for name in list(state_dict.keys()):
39
- if "lora" not in name:
40
- state_dict.pop(name)
41
-
42
- def configure_optimizers(self) -> OptimizerLRScheduler:
43
- # Get weight decay parameters
44
- weight_decay_parameters, other_parameters = [], []
45
- for name, param in self.named_parameters():
46
- if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
47
- other_parameters.append(param)
48
- else:
49
- weight_decay_parameters.append(param)
50
-
51
- optimizer = self.optimizer_builder(
52
- [
53
- {"params": weight_decay_parameters},
54
- {"params": other_parameters, "weight_decay": 0.0},
55
- ]
56
- )
57
-
58
- # Print the parameters and their weight decay
59
- for i in optimizer.param_groups:
60
- log.info(
61
- f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
62
- )
63
-
64
- lr_scheduler = self.lr_scheduler_builder(optimizer)
65
-
66
- return {
67
- "optimizer": optimizer,
68
- "lr_scheduler": {
69
- "scheduler": lr_scheduler,
70
- "interval": "step",
71
- },
72
- }
73
-
74
- # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
75
- def get_batch_logps(
76
- self,
77
- logits: torch.FloatTensor,
78
- labels: torch.LongTensor,
79
- average_log_prob: bool = False,
80
- ) -> torch.FloatTensor:
81
- """Compute the log probabilities of the given labels under the given logits.
82
-
83
- Args:
84
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
85
- labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
86
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
87
-
88
- Returns:
89
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
90
- """
91
- assert logits.shape[:-1] == labels.shape
92
-
93
- labels = labels.clone()
94
- loss_mask = labels != -100
95
-
96
- # dummy token; we'll ignore the losses on these tokens later
97
- labels[labels == -100] = 0
98
-
99
- per_token_logps = torch.gather(
100
- logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
101
- ).squeeze(-1)
102
-
103
- if average_log_prob:
104
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
105
- else:
106
- return (per_token_logps * loss_mask).sum(-1)
107
-
108
- def _step(self, batch, batch_idx, stage: str):
109
- is_train = stage == "train"
110
-
111
- if is_train:
112
- # Key part to make lora work
113
- # Otherwise the parameters are merged, which lead to incorrect gradients
114
- self.model.train()
115
-
116
- # Do positive and negative samples in the same batch to speed up training
117
- labels = batch["labels"]
118
- outputs = self.model(
119
- inp=batch["inputs"],
120
- key_padding_mask=batch["attention_masks"],
121
- )
122
- token_logits = outputs.token_logits
123
- codebook_logits = outputs.codebook_logits
124
-
125
- # Generate labels
126
- base_loss = F.cross_entropy(
127
- token_logits.view(-1, token_logits.size(-1)),
128
- labels[:, 0].reshape(-1),
129
- ignore_index=-100,
130
- )
131
-
132
- codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
133
- semantic_loss = F.cross_entropy(
134
- codebook_logits.view(-1, codebook_logits.size(-1)),
135
- codebook_labels.reshape(-1),
136
- ignore_index=-100,
137
- )
138
-
139
- loss = base_loss + semantic_loss
140
-
141
- self.log(
142
- f"{stage}/loss",
143
- loss,
144
- on_step=is_train,
145
- on_epoch=not is_train,
146
- prog_bar=True,
147
- logger=True,
148
- sync_dist=not is_train,
149
- )
150
-
151
- self.log(
152
- f"{stage}/base_loss",
153
- base_loss,
154
- on_step=is_train,
155
- on_epoch=not is_train,
156
- prog_bar=False,
157
- logger=True,
158
- sync_dist=not is_train,
159
- )
160
-
161
- self.log(
162
- f"{stage}/semantic_loss",
163
- semantic_loss,
164
- on_step=is_train,
165
- on_epoch=not is_train,
166
- prog_bar=False,
167
- logger=True,
168
- sync_dist=not is_train,
169
- )
170
-
171
- # Top-5 accuracy
172
- accuracy = self.get_accuracy(codebook_logits, codebook_labels)
173
- self.log(
174
- f"{stage}/top_5_accuracy",
175
- accuracy,
176
- on_step=is_train,
177
- on_epoch=not is_train,
178
- prog_bar=True,
179
- logger=True,
180
- sync_dist=not is_train,
181
- )
182
-
183
- return loss
184
-
185
- def get_accuracy(self, logits, labels):
186
- mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
187
- if mask.sum() == 0:
188
- return torch.tensor(0.0, device=logits.device)
189
-
190
- _, indices = logits.topk(5, dim=-1)
191
- correct = indices.eq(labels.unsqueeze(-1))
192
- correct[~mask] = 0
193
- correct = correct.sum()
194
- accuracy = correct / mask.sum()
195
-
196
- return accuracy
197
-
198
- def training_step(self, batch, batch_idx):
199
- return self._step(batch, batch_idx, "train")
200
-
201
- def validation_step(self, batch, batch_idx):
202
- return self._step(batch, batch_idx, "val")
 
1
+ from typing import Any, Optional
2
+
3
+ import lightning as L
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
+
8
+ import fish_speech.utils as utils
9
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
+ from fish_speech.models.text2semantic.llama import NaiveTransformer
11
+
12
+ log = utils.RankedLogger(__name__, rank_zero_only=True)
13
+
14
+
15
+ class TextToSemantic(L.LightningModule):
16
+ def __init__(
17
+ self,
18
+ model: NaiveTransformer,
19
+ optimizer: Any,
20
+ lr_scheduler: Any,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.model = model
25
+ self.optimizer_builder = optimizer
26
+ self.lr_scheduler_builder = lr_scheduler
27
+
28
+ def forward(self, x):
29
+ return self.model(x)
30
+
31
+ def on_save_checkpoint(self, checkpoint):
32
+ # Save only LoRA parameters
33
+ state_dict = checkpoint["state_dict"]
34
+ use_lora = any("lora" in name for name in state_dict.keys())
35
+ if not use_lora:
36
+ return
37
+
38
+ for name in list(state_dict.keys()):
39
+ if "lora" not in name:
40
+ state_dict.pop(name)
41
+
42
+ def configure_optimizers(self) -> OptimizerLRScheduler:
43
+ # Get weight decay parameters
44
+ weight_decay_parameters, other_parameters = [], []
45
+ for name, param in self.named_parameters():
46
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
47
+ other_parameters.append(param)
48
+ else:
49
+ weight_decay_parameters.append(param)
50
+
51
+ optimizer = self.optimizer_builder(
52
+ [
53
+ {"params": weight_decay_parameters},
54
+ {"params": other_parameters, "weight_decay": 0.0},
55
+ ]
56
+ )
57
+
58
+ # Print the parameters and their weight decay
59
+ for i in optimizer.param_groups:
60
+ log.info(
61
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
62
+ )
63
+
64
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
65
+
66
+ return {
67
+ "optimizer": optimizer,
68
+ "lr_scheduler": {
69
+ "scheduler": lr_scheduler,
70
+ "interval": "step",
71
+ },
72
+ }
73
+
74
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
75
+ def get_batch_logps(
76
+ self,
77
+ logits: torch.FloatTensor,
78
+ labels: torch.LongTensor,
79
+ average_log_prob: bool = False,
80
+ ) -> torch.FloatTensor:
81
+ """Compute the log probabilities of the given labels under the given logits.
82
+
83
+ Args:
84
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
85
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
86
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
87
+
88
+ Returns:
89
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
90
+ """
91
+ assert logits.shape[:-1] == labels.shape
92
+
93
+ labels = labels.clone()
94
+ loss_mask = labels != -100
95
+
96
+ # dummy token; we'll ignore the losses on these tokens later
97
+ labels[labels == -100] = 0
98
+
99
+ per_token_logps = torch.gather(
100
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
101
+ ).squeeze(-1)
102
+
103
+ if average_log_prob:
104
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
105
+ else:
106
+ return (per_token_logps * loss_mask).sum(-1)
107
+
108
+ def _step(self, batch, batch_idx, stage: str):
109
+ is_train = stage == "train"
110
+
111
+ if is_train:
112
+ # Key part to make lora work
113
+ # Otherwise the parameters are merged, which lead to incorrect gradients
114
+ self.model.train()
115
+
116
+ # Do positive and negative samples in the same batch to speed up training
117
+ labels = batch["labels"]
118
+ outputs = self.model(
119
+ inp=batch["inputs"],
120
+ key_padding_mask=batch["attention_masks"],
121
+ )
122
+ token_logits = outputs.token_logits
123
+ codebook_logits = outputs.codebook_logits
124
+
125
+ # Generate labels
126
+ base_loss = F.cross_entropy(
127
+ token_logits.view(-1, token_logits.size(-1)),
128
+ labels[:, 0].reshape(-1),
129
+ ignore_index=-100,
130
+ )
131
+
132
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
133
+ semantic_loss = F.cross_entropy(
134
+ codebook_logits.view(-1, codebook_logits.size(-1)),
135
+ codebook_labels.reshape(-1),
136
+ ignore_index=-100,
137
+ )
138
+
139
+ loss = base_loss + semantic_loss
140
+
141
+ self.log(
142
+ f"{stage}/loss",
143
+ loss,
144
+ on_step=is_train,
145
+ on_epoch=not is_train,
146
+ prog_bar=True,
147
+ logger=True,
148
+ sync_dist=not is_train,
149
+ )
150
+
151
+ self.log(
152
+ f"{stage}/base_loss",
153
+ base_loss,
154
+ on_step=is_train,
155
+ on_epoch=not is_train,
156
+ prog_bar=False,
157
+ logger=True,
158
+ sync_dist=not is_train,
159
+ )
160
+
161
+ self.log(
162
+ f"{stage}/semantic_loss",
163
+ semantic_loss,
164
+ on_step=is_train,
165
+ on_epoch=not is_train,
166
+ prog_bar=False,
167
+ logger=True,
168
+ sync_dist=not is_train,
169
+ )
170
+
171
+ # Top-5 accuracy
172
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
173
+ self.log(
174
+ f"{stage}/top_5_accuracy",
175
+ accuracy,
176
+ on_step=is_train,
177
+ on_epoch=not is_train,
178
+ prog_bar=True,
179
+ logger=True,
180
+ sync_dist=not is_train,
181
+ )
182
+
183
+ return loss
184
+
185
+ def get_accuracy(self, logits, labels):
186
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
187
+ if mask.sum() == 0:
188
+ return torch.tensor(0.0, device=logits.device)
189
+
190
+ _, indices = logits.topk(5, dim=-1)
191
+ correct = indices.eq(labels.unsqueeze(-1))
192
+ correct[~mask] = 0
193
+ correct = correct.sum()
194
+ accuracy = correct / mask.sum()
195
+
196
+ return accuracy
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ return self._step(batch, batch_idx, "train")
200
+
201
+ def validation_step(self, batch, batch_idx):
202
+ return self._step(batch, batch_idx, "val")
fish_speech/models/text2semantic/llama.py CHANGED
@@ -1,887 +1,903 @@
1
- import dataclasses
2
- import json
3
- import math
4
- from collections import OrderedDict
5
- from dataclasses import dataclass
6
- from pathlib import Path
7
- from typing import Optional
8
-
9
- import torch
10
- import torch.nn as nn
11
- from einops import rearrange
12
- from loguru import logger
13
- from torch import Tensor
14
- from torch.nn import functional as F
15
- from torch.nn.attention import SDPBackend, sdpa_kernel
16
- from torch.utils.checkpoint import checkpoint
17
- from transformers import AutoTokenizer
18
-
19
- from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
20
- from fish_speech.utils import RankedLogger
21
-
22
- from .lora import LoraConfig, setup_lora
23
-
24
- log = RankedLogger(__name__, rank_zero_only=True)
25
-
26
-
27
- def find_multiple(n: int, k: int) -> int:
28
- if n % k == 0:
29
- return n
30
- return n + k - (n % k)
31
-
32
-
33
- @dataclass
34
- class BaseModelArgs:
35
- model_type: str = "base"
36
-
37
- vocab_size: int = 32000
38
- n_layer: int = 32
39
- n_head: int = 32
40
- dim: int = 4096
41
- intermediate_size: int = None
42
- n_local_heads: int = -1
43
- head_dim: int = 64
44
- rope_base: float = 10000
45
- norm_eps: float = 1e-5
46
- max_seq_len: int = 2048
47
- dropout: float = 0.0
48
- tie_word_embeddings: bool = True
49
- attention_qkv_bias: bool = False
50
-
51
- # Codebook configs
52
- codebook_size: int = 160
53
- num_codebooks: int = 4
54
-
55
- # Gradient checkpointing
56
- use_gradient_checkpointing: bool = True
57
-
58
- # Initialize the model
59
- initializer_range: float = 0.02
60
-
61
- # Dummy vars
62
- is_reward_model: bool = False
63
- share_codebook_embeddings: bool = True
64
- scale_codebook_embeddings: bool = False
65
-
66
- def __post_init__(self):
67
- if self.n_local_heads == -1:
68
- self.n_local_heads = self.n_head
69
- if self.intermediate_size is None:
70
- hidden_dim = 4 * self.dim
71
- n_hidden = int(2 * hidden_dim / 3)
72
- self.intermediate_size = find_multiple(n_hidden, 256)
73
- self.head_dim = self.dim // self.n_head
74
-
75
- @staticmethod
76
- def from_pretrained(path: str):
77
- path = Path(path)
78
-
79
- if path.is_dir():
80
- path = path / "config.json"
81
-
82
- with open(path, "r", encoding="utf-8") as f:
83
- data = json.load(f)
84
-
85
- match data["model_type"]:
86
- case "naive":
87
- cls = NaiveModelArgs
88
- case "dual_ar":
89
- cls = DualARModelArgs
90
- case _:
91
- raise ValueError(f"Unknown model type: {data['model_type']}")
92
-
93
- return cls(**data)
94
-
95
- def save(self, path: str):
96
- with open(path, "w") as f:
97
- json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
98
-
99
-
100
- @dataclass
101
- class NaiveModelArgs(BaseModelArgs):
102
- model_type: str = "naive"
103
-
104
-
105
- @dataclass
106
- class DualARModelArgs(BaseModelArgs):
107
- model_type: str = "dual_ar"
108
- n_fast_layer: int = 4
109
- fast_dim: int | None = None
110
- fast_n_head: int | None = None
111
- fast_n_local_heads: int | None = None
112
- fast_head_dim: int | None = None
113
- fast_intermediate_size: int | None = None
114
- fast_attention_qkv_bias: bool | None = None
115
-
116
- def __post_init__(self):
117
- super().__post_init__()
118
-
119
- self.fast_dim = self.fast_dim or self.dim
120
- self.fast_n_head = self.fast_n_head or self.n_head
121
- self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
122
- self.fast_head_dim = self.fast_head_dim or self.head_dim
123
- self.fast_intermediate_size = (
124
- self.fast_intermediate_size or self.intermediate_size
125
- )
126
- self.fast_attention_qkv_bias = (
127
- self.fast_attention_qkv_bias
128
- if self.fast_attention_qkv_bias is not None
129
- else self.attention_qkv_bias
130
- )
131
-
132
-
133
- class KVCache(nn.Module):
134
- def __init__(
135
- self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
136
- ):
137
- super().__init__()
138
- cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
139
- self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
140
- self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
141
-
142
- def update(self, input_pos, k_val, v_val):
143
- # input_pos: [S], k_val: [B, H, S, D]
144
- assert input_pos.shape[0] == k_val.shape[2]
145
-
146
- k_out = self.k_cache
147
- v_out = self.v_cache
148
- k_out[:, :, input_pos] = k_val
149
- v_out[:, :, input_pos] = v_val
150
-
151
- return k_out, v_out
152
-
153
-
154
- @dataclass
155
- class TransformerForwardResult:
156
- token_logits: Tensor
157
- codebook_logits: Tensor
158
-
159
-
160
- @dataclass
161
- class BaseTransformerForwardResult:
162
- logits: Tensor
163
- hidden_states: Tensor
164
-
165
-
166
- class BaseTransformer(nn.Module):
167
- def __init__(
168
- self,
169
- config: BaseModelArgs,
170
- tokenizer: FishTokenizer | AutoTokenizer,
171
- init_weights: bool = True,
172
- ) -> None:
173
- super().__init__()
174
- self.config = config
175
- self.tokenizer = tokenizer
176
- self.semantic_token_ids = [
177
- tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
178
- ]
179
-
180
- # Slow transformer
181
- self.embeddings = nn.Embedding(
182
- config.vocab_size,
183
- config.dim,
184
- )
185
- self.codebook_embeddings = nn.Embedding(
186
- config.codebook_size * config.num_codebooks,
187
- config.dim,
188
- )
189
- self.layers = nn.ModuleList(
190
- TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
191
- )
192
- self.norm = RMSNorm(config.dim, eps=config.norm_eps)
193
-
194
- if self.config.tie_word_embeddings is False:
195
- self.output = nn.Linear(
196
- config.dim,
197
- config.vocab_size,
198
- bias=False,
199
- )
200
-
201
- self.register_buffer(
202
- "freqs_cis",
203
- precompute_freqs_cis(
204
- config.max_seq_len,
205
- config.dim // config.n_head,
206
- config.rope_base,
207
- ),
208
- persistent=False,
209
- )
210
- self.register_buffer(
211
- "causal_mask",
212
- torch.tril(
213
- torch.ones(
214
- config.max_seq_len,
215
- config.max_seq_len,
216
- dtype=torch.bool,
217
- )
218
- ),
219
- persistent=False,
220
- )
221
-
222
- # For kv cache
223
- self.max_batch_size = -1
224
- self.max_seq_len = -1
225
-
226
- if init_weights:
227
- self.apply(self._init_weights)
228
-
229
- def setup_caches(
230
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
231
- ):
232
- if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
233
- return
234
-
235
- head_dim = self.config.dim // self.config.n_head
236
- max_seq_len = find_multiple(max_seq_len, 8)
237
- self.max_seq_len = max_seq_len
238
- self.max_batch_size = max_batch_size
239
-
240
- for b in self.layers:
241
- b.attention.kv_cache = KVCache(
242
- max_batch_size,
243
- max_seq_len,
244
- self.config.n_local_heads,
245
- head_dim,
246
- dtype=dtype,
247
- )
248
-
249
- def embed(self, x: Tensor) -> Tensor:
250
- vocab_embeds = [self.embeddings(x[:, 0])]
251
- for i in range(self.config.num_codebooks):
252
- emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
253
- semantic_token_ids_tensor = torch.tensor(
254
- self.semantic_token_ids, device=x.device
255
- )
256
- emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
257
-
258
- x = torch.stack(vocab_embeds, dim=3)
259
- x = x.sum(dim=3)
260
-
261
- return x
262
-
263
- def forward(
264
- self,
265
- inp: Tensor,
266
- key_padding_mask: Optional[Tensor] = None,
267
- ) -> BaseTransformerForwardResult:
268
- seq_len = inp.size(2)
269
-
270
- # Here we want to merge the embeddings of the codebooks
271
- x = self.embed(inp)
272
-
273
- freqs_cis = self.freqs_cis[:seq_len]
274
-
275
- # Not that the causal mask here follows the definition of scaled_dot_product_attention
276
- # That is, FALSE means masked out
277
- # To maintain consistency, key_padding_mask use TRUE to mask out
278
- mask = None
279
- if key_padding_mask is not None:
280
- mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
281
- mask = mask & key_padding_mask[:, None, None, :].logical_not()
282
-
283
- for layer in self.layers:
284
- if self.config.use_gradient_checkpointing and self.training:
285
- x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
286
- else:
287
- x = layer(x, freqs_cis, mask)
288
-
289
- # We got slow_out here
290
- slow_out = self.norm(x)
291
-
292
- if self.config.tie_word_embeddings:
293
- token_logits = F.linear(slow_out, self.embeddings.weight)
294
- else:
295
- token_logits = self.output(slow_out)
296
-
297
- return BaseTransformerForwardResult(
298
- logits=token_logits,
299
- hidden_states=x,
300
- )
301
-
302
- def forward_generate(
303
- self,
304
- inp: Tensor,
305
- input_pos: Optional[Tensor] = None,
306
- vq_masks: Optional[Tensor] = None, # this is not used in fact
307
- return_all: bool = False,
308
- ) -> BaseTransformerForwardResult:
309
- # This is used for generation, optimized for torch compile
310
- # assert (
311
- # self.max_seq_len != -1 and self.max_batch_size != -1
312
- # ), "Please call setup_caches before forward_generate"
313
-
314
- embeds = []
315
- for i in range(self.config.num_codebooks):
316
- if self.config.share_codebook_embeddings:
317
- _tokens = inp[:, i + 1] + i * self.config.codebook_size
318
- else:
319
- _tokens = inp[:, i + 1]
320
-
321
- emb = self.codebook_embeddings(_tokens)
322
- embeds.append(emb)
323
-
324
- vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
325
- # if self.config.use_codebook_mlp:
326
- # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
327
- # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
328
-
329
- vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
330
- inp[:, 0] <= self.tokenizer.semantic_end_id
331
- )
332
-
333
- vq_embeds_sum[~vq_masks] = 0
334
- x = self.embeddings(inp[:, 0]) + vq_embeds_sum
335
-
336
- if input_pos is None:
337
- input_pos = torch.arange(inp.shape[-1], device=x.device)
338
- max_seq_len = inp.shape[-1]
339
- else:
340
- max_seq_len = self.max_seq_len
341
-
342
- mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
343
- freqs_cis = self.freqs_cis[input_pos]
344
-
345
- for layer in self.layers:
346
- x = layer(x, freqs_cis, mask, input_pos=input_pos)
347
-
348
- # If prefill, we only calculate the logits of last token
349
- if x.size(1) > 1 and not return_all:
350
- x = x[:, -1:]
351
-
352
- # We got slow_out here
353
- slow_out = self.norm(x)
354
-
355
- if self.config.is_reward_model:
356
- token_logits = self.score_output(slow_out)
357
- elif self.config.tie_word_embeddings:
358
- token_logits = F.linear(slow_out, self.embeddings.weight)
359
- else:
360
- token_logits = self.output(slow_out)
361
-
362
- return BaseTransformerForwardResult(
363
- logits=token_logits,
364
- hidden_states=x,
365
- )
366
-
367
- def _init_weights(self, module):
368
- std = self.config.initializer_range
369
- if isinstance(module, nn.Linear):
370
- module.weight.data.normal_(mean=0.0, std=std)
371
- if module.bias is not None:
372
- module.bias.data.zero_()
373
- elif isinstance(module, nn.Embedding):
374
- module.weight.data.normal_(mean=0.0, std=std)
375
- if module.padding_idx is not None:
376
- module.weight.data[module.padding_idx].zero_()
377
-
378
- @staticmethod
379
- def from_pretrained(
380
- path: str,
381
- load_weights: bool = False,
382
- max_length: int | None = None,
383
- lora_config: LoraConfig | None = None,
384
- rope_base: int | None = None,
385
- is_agent: bool = False,
386
- ) -> "BaseTransformer":
387
- config = BaseModelArgs.from_pretrained(str(path))
388
- if max_length is not None:
389
- config.max_seq_len = max_length
390
- log.info(f"Override max_seq_len to {max_length}")
391
-
392
- if rope_base is not None:
393
- config.rope_base = rope_base
394
- log.info(f"Override rope_base to {rope_base}")
395
-
396
- match config.model_type:
397
- case "naive":
398
- model_cls = NaiveTransformer
399
- case "dual_ar":
400
- model_cls = DualARTransformer
401
- case _:
402
- raise ValueError(f"Unknown model type: {config.model_type}")
403
-
404
- if is_agent:
405
- tokenizer = AutoTokenizer.from_pretrained(str(path))
406
- else:
407
- tokenizer_path = str(path) + "/tokenizer.tiktoken"
408
- tokenizer = FishTokenizer(tokenizer_path)
409
-
410
- log.info(f"Loading model from {path}, config: {config}")
411
- model = model_cls(config, tokenizer=tokenizer)
412
-
413
- if lora_config is not None:
414
- setup_lora(model, lora_config)
415
- log.info(f"LoRA setup: {lora_config}")
416
-
417
- if load_weights is False:
418
- log.info("Randomly initialized model")
419
- else:
420
-
421
- if "int8" in str(Path(path)):
422
- logger.info("Using int8 weight-only quantization!")
423
- from tools.llama.quantize import WeightOnlyInt8QuantHandler
424
-
425
- simple_quantizer = WeightOnlyInt8QuantHandler(model)
426
- model = simple_quantizer.convert_for_runtime()
427
-
428
- if "int4" in str(Path(path)):
429
- logger.info("Using int4 quantization!")
430
- path_comps = path.name.split("-")
431
- assert path_comps[-2].startswith("g")
432
- groupsize = int(path_comps[-2][1:])
433
- from tools.llama.quantize import WeightOnlyInt4QuantHandler
434
-
435
- simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
436
- model = simple_quantizer.convert_for_runtime()
437
-
438
- weights = torch.load(
439
- Path(path) / "model.pth",
440
- map_location="cpu",
441
- mmap=True,
442
- weights_only=True,
443
- )
444
-
445
- if "state_dict" in weights:
446
- logger.warning(
447
- "Using a TextToSemantic LightningModule checkpoint, "
448
- "please make sure it is a full model, not a LoRA model."
449
- )
450
- weights = weights["state_dict"]
451
-
452
- if next(iter(weights.keys())).startswith("model."):
453
- logger.info(
454
- f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
455
- )
456
- new_weights = OrderedDict()
457
- for k, v in weights.items():
458
- new_weights[k.replace("model.", "")] = v
459
- weights = new_weights
460
-
461
- # Verify the name and shape of parameters since strict=False in load_state_dict.
462
- for k, v in model.named_parameters():
463
- if k not in weights:
464
- logger.warning(f"No weight for {k}")
465
- elif v.shape != weights[k].shape:
466
- logger.warning(
467
- f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
468
- )
469
-
470
- err = model.load_state_dict(weights, strict=False, assign=True)
471
- log.info(f"Loaded weights with error: {err}")
472
-
473
- return model
474
-
475
- def save_pretrained(self, path: str, drop_lora: bool = False):
476
- path = Path(path)
477
- path.mkdir(parents=True, exist_ok=True)
478
-
479
- self.config.save(path / "config.json")
480
- state_dict = self.state_dict()
481
-
482
- if drop_lora:
483
- for key in list(state_dict.keys()):
484
- if "lora" not in key:
485
- continue
486
-
487
- state_dict.pop(key)
488
- log.info(f"Drop LoRA parameter: {key}")
489
-
490
- torch.save(state_dict, path / "model.pth")
491
- self.tokenizer.save_pretrained(path)
492
-
493
-
494
- class NaiveTransformer(BaseTransformer):
495
- def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
496
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
497
-
498
- self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
499
- self.codebook_output = nn.Linear(
500
- config.dim,
501
- config.codebook_size * config.num_codebooks,
502
- bias=False,
503
- )
504
-
505
- self.apply(self._init_weights)
506
-
507
- def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
508
- token_logits = result.logits
509
- x = result.hidden_states
510
-
511
- # Codebook
512
- codebook_logits = self.codebook_output(self.codebook_norm(x))
513
- codebook_logits = rearrange(
514
- codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
515
- )
516
-
517
- return TransformerForwardResult(
518
- token_logits=token_logits,
519
- codebook_logits=codebook_logits,
520
- )
521
-
522
- def forward(
523
- self,
524
- inp: Tensor,
525
- key_padding_mask: Optional[Tensor] = None,
526
- ) -> TransformerForwardResult:
527
- result = super().forward(
528
- inp=inp,
529
- key_padding_mask=key_padding_mask,
530
- )
531
- return self.decode(result)
532
-
533
- def forward_generate(
534
- self, x: Tensor, input_pos: Optional[Tensor] = None
535
- ) -> TransformerForwardResult:
536
- result = super().forward_generate(x, input_pos)
537
- return self.decode(result)
538
-
539
-
540
- class DualARTransformer(BaseTransformer):
541
- def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
542
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
543
-
544
- # Project to fast dim if needed
545
- if config.fast_dim is not None and config.fast_dim != config.dim:
546
- self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
547
- else:
548
- self.fast_project_in = nn.Identity()
549
-
550
- # Fast transformer
551
- self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
552
-
553
- # The equivalent bs is so large that sdpa doesn't work
554
- override_config = dataclasses.replace(
555
- config,
556
- dim=config.fast_dim,
557
- n_head=config.fast_n_head,
558
- n_local_heads=config.fast_n_local_heads,
559
- head_dim=config.fast_head_dim,
560
- intermediate_size=config.fast_intermediate_size,
561
- attention_qkv_bias=config.fast_attention_qkv_bias,
562
- )
563
-
564
- self.fast_layers = nn.ModuleList(
565
- TransformerBlock(override_config, use_sdpa=False)
566
- for _ in range(config.n_fast_layer)
567
- )
568
- self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
569
- self.fast_output = nn.Linear(
570
- config.fast_dim,
571
- config.codebook_size,
572
- bias=False,
573
- )
574
-
575
- self.register_buffer(
576
- "fast_freqs_cis",
577
- precompute_freqs_cis(
578
- config.num_codebooks,
579
- config.fast_dim // config.fast_n_head,
580
- config.rope_base,
581
- ),
582
- persistent=False,
583
- )
584
- self.apply(self._init_weights)
585
-
586
- def setup_caches(
587
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
588
- ):
589
- super().setup_caches(max_batch_size, max_seq_len, dtype)
590
-
591
- head_dim = self.config.fast_dim // self.config.fast_n_head
592
-
593
- # Fast transformer
594
- # The max seq len here is the number of codebooks
595
- for b in self.fast_layers:
596
- b.attention.kv_cache = KVCache(
597
- max_batch_size,
598
- self.config.num_codebooks,
599
- self.config.fast_n_local_heads,
600
- head_dim,
601
- dtype=dtype,
602
- )
603
-
604
- def forward(
605
- self,
606
- inp: Tensor,
607
- key_padding_mask: Optional[Tensor] = None,
608
- ) -> TransformerForwardResult:
609
- parent_result = super().forward(inp, key_padding_mask)
610
- token_logits = parent_result.logits
611
- x = parent_result.hidden_states
612
- x = self.fast_project_in(x)
613
-
614
- # Fast transformer
615
- fast_seq_len = self.config.num_codebooks
616
- fast_mask = self.causal_mask[
617
- None, None, :fast_seq_len, :fast_seq_len
618
- ] # (B, N, Q, K)
619
-
620
- # Drop the last token and rotate left
621
- codebooks = inp[:, 1:-1, 1:]
622
- codebooks = F.pad(codebooks, (0, 1), value=0)
623
- codebook_embeddings = self.fast_embeddings(codebooks)
624
- x = torch.cat([x[:, None], codebook_embeddings], dim=1)
625
- b, s = x.size(0), x.size(2)
626
- x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
627
-
628
- # Remove padded part
629
- codebooks = rearrange(codebooks, "b n s -> (b s) n")
630
- codebook_mask = (codebooks == 0).all(dim=-1)
631
-
632
- if torch.all(codebook_mask):
633
- # If all codebooks are padded, we keep first 8 to make sure the model runs
634
- codebook_mask[:8] = False
635
-
636
- x_bs, x_len = x.size(0), x.size(1)
637
- x = x[~codebook_mask]
638
-
639
- for layer in self.fast_layers:
640
- if self.config.use_gradient_checkpointing and self.training:
641
- x = checkpoint(
642
- layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
643
- )
644
- else:
645
- x = layer(x, self.fast_freqs_cis, fast_mask)
646
-
647
- # unflatten the batch and num_codebooks
648
- fast_out = self.fast_norm(x)
649
- codebook_logits = self.fast_output(fast_out)
650
-
651
- # Re-pad the codebook_logits
652
- buffer = torch.zeros(
653
- x_bs,
654
- x_len,
655
- codebook_logits.size(-1),
656
- device=codebook_logits.device,
657
- dtype=codebook_logits.dtype,
658
- )
659
- buffer[~codebook_mask] = codebook_logits
660
- codebook_logits = buffer
661
-
662
- assert codebook_logits.shape[1] == self.config.num_codebooks
663
- codebook_logits = rearrange(
664
- codebook_logits,
665
- "(b s) n d -> b s n d",
666
- b=b,
667
- s=s,
668
- n=self.config.num_codebooks,
669
- )
670
-
671
- return TransformerForwardResult(
672
- token_logits=token_logits,
673
- codebook_logits=codebook_logits,
674
- )
675
-
676
- def forward_generate_fast(
677
- self, x: Tensor, input_pos: Optional[Tensor] = None
678
- ) -> Tensor:
679
- # Fast transformer
680
- x = x.view(1, 1, -1)
681
-
682
- fast_mask = self.causal_mask[
683
- None, None, input_pos, : self.config.num_codebooks
684
- ] # (B, N, Q, K)
685
- fast_freqs_cis = self.fast_freqs_cis[input_pos]
686
-
687
- for layer in self.fast_layers:
688
- x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
689
-
690
- # unflatten the batch and num_codebooks
691
- fast_out = self.fast_norm(x) # only take the last token
692
- codebook_logits = self.fast_output(fast_out)
693
-
694
- return codebook_logits
695
-
696
- def forward_generate(
697
- self,
698
- x: Tensor,
699
- input_pos: Optional[Tensor] = None,
700
- vq_masks: Optional[Tensor] = None,
701
- ) -> TransformerForwardResult:
702
- x = super().forward_generate(x, input_pos, vq_masks)
703
- x.hidden_states = self.fast_project_in(x.hidden_states)
704
- return x
705
-
706
-
707
- class TransformerBlock(nn.Module):
708
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
709
- super().__init__()
710
- self.attention = Attention(config, use_sdpa=use_sdpa)
711
- self.feed_forward = FeedForward(config)
712
- self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
713
- self.attention_norm = RMSNorm(config.dim, config.norm_eps)
714
-
715
- def forward(
716
- self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
717
- ) -> Tensor:
718
- h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
719
- out = h + self.feed_forward(self.ffn_norm(h))
720
- return out
721
-
722
-
723
- class Attention(nn.Module):
724
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
725
- super().__init__()
726
- assert config.dim % config.n_head == 0
727
-
728
- total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
729
- # key, query, value projections for all heads, but in a batch
730
- self.wqkv = nn.Linear(
731
- config.dim, total_head_dim, bias=config.attention_qkv_bias
732
- )
733
- self.wo = nn.Linear(config.dim, config.dim, bias=False)
734
- self.kv_cache = None
735
-
736
- self.dropout = config.dropout
737
- self.n_head = config.n_head
738
- self.head_dim = config.head_dim
739
- self.n_local_heads = config.n_local_heads
740
- self.dim = config.dim
741
- self.use_sdpa = use_sdpa
742
- self._register_load_state_dict_pre_hook(self.load_hook)
743
-
744
- def load_hook(self, state_dict, prefix, *args):
745
- if prefix + "wq.weight" in state_dict:
746
- wq = state_dict.pop(prefix + "wq.weight")
747
- wk = state_dict.pop(prefix + "wk.weight")
748
- wv = state_dict.pop(prefix + "wv.weight")
749
- state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
750
-
751
- def forward(
752
- self,
753
- x: Tensor,
754
- freqs_cis: Tensor,
755
- mask: Tensor,
756
- input_pos: Optional[Tensor] = None,
757
- ) -> Tensor:
758
- bsz, seqlen, _ = x.shape
759
-
760
- kv_size = self.n_local_heads * self.head_dim
761
- q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
762
-
763
- q = q.view(bsz, seqlen, self.n_head, self.head_dim)
764
- k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
765
- v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
766
-
767
- q = apply_rotary_emb(q, freqs_cis)
768
- k = apply_rotary_emb(k, freqs_cis)
769
-
770
- q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
771
-
772
- if self.kv_cache is not None:
773
- k, v = self.kv_cache.update(input_pos, k, v)
774
-
775
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
776
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
777
-
778
- if self.use_sdpa:
779
- if mask is None:
780
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
781
- y = F.scaled_dot_product_attention(
782
- q,
783
- k,
784
- v,
785
- dropout_p=self.dropout if self.training else 0.0,
786
- is_causal=True,
787
- # No third party attn_mask here to use flash_attention
788
- )
789
- else:
790
- y = F.scaled_dot_product_attention(
791
- q,
792
- k,
793
- v,
794
- attn_mask=mask,
795
- dropout_p=self.dropout if self.training else 0.0,
796
- )
797
- else:
798
- y = self.eq_scaled_dot_product_attention(
799
- q,
800
- k,
801
- v,
802
- attn_mask=mask,
803
- dropout_p=self.dropout if self.training else 0.0,
804
- )
805
-
806
- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
807
-
808
- return self.wo(y)
809
-
810
- def eq_scaled_dot_product_attention(
811
- self,
812
- query,
813
- key,
814
- value,
815
- attn_mask=None,
816
- dropout_p=0.0,
817
- ) -> torch.Tensor:
818
- # This is a standard scaled dot product attention
819
- # It's low efficient, but it doesn't raise cuda error
820
-
821
- L, S = query.size(-2), key.size(-2)
822
- scale_factor = 1 / math.sqrt(query.size(-1))
823
- attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
824
-
825
- if attn_mask is not None:
826
- if attn_mask.dtype == torch.bool:
827
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
828
- else:
829
- attn_bias += attn_mask
830
-
831
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
832
- attn_weight += attn_bias
833
- attn_weight = torch.softmax(attn_weight, dim=-1)
834
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
835
-
836
- return attn_weight @ value
837
-
838
-
839
- class FeedForward(nn.Module):
840
- def __init__(self, config: BaseModelArgs) -> None:
841
- super().__init__()
842
- self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
843
- self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
844
- self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
845
-
846
- def forward(self, x: Tensor) -> Tensor:
847
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
848
-
849
-
850
- class RMSNorm(nn.Module):
851
- def __init__(self, dim: int, eps: float = 1e-5):
852
- super().__init__()
853
- self.eps = eps
854
- self.weight = nn.Parameter(torch.ones(dim))
855
-
856
- def _norm(self, x):
857
- return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
858
-
859
- def forward(self, x: Tensor) -> Tensor:
860
- output = self._norm(x.float()).type_as(x)
861
- return output * self.weight
862
-
863
-
864
- def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
865
- freqs = 1.0 / (
866
- base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
867
- )
868
- t = torch.arange(seq_len, device=freqs.device)
869
- freqs = torch.outer(t, freqs)
870
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
871
- cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
872
- return cache.to(dtype=torch.bfloat16)
873
-
874
-
875
- def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
876
- xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
877
- freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
878
- x_out2 = torch.stack(
879
- [
880
- xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
881
- xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
882
- ],
883
- -1,
884
- )
885
-
886
- x_out2 = x_out2.flatten(3)
887
- return x_out2.type_as(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import json
3
+ import math
4
+ from collections import OrderedDict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from loguru import logger
13
+ from torch import Tensor
14
+ from torch.nn import functional as F
15
+ from torch.nn.attention import SDPBackend, sdpa_kernel
16
+ from torch.utils.checkpoint import checkpoint
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
20
+ from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
21
+
22
+
23
+ def find_multiple(n: int, k: int) -> int:
24
+ if n % k == 0:
25
+ return n
26
+ return n + k - (n % k)
27
+
28
+
29
+ @dataclass
30
+ class BaseModelArgs:
31
+ model_type: str = "base"
32
+
33
+ vocab_size: int = 32000
34
+ n_layer: int = 32
35
+ n_head: int = 32
36
+ dim: int = 4096
37
+ intermediate_size: int = None
38
+ n_local_heads: int = -1
39
+ head_dim: int = 64
40
+ rope_base: float = 10000
41
+ norm_eps: float = 1e-5
42
+ max_seq_len: int = 2048
43
+ dropout: float = 0.0
44
+ tie_word_embeddings: bool = True
45
+ attention_qkv_bias: bool = False
46
+ attention_o_bias: bool = False
47
+ attention_qk_norm: bool = False
48
+
49
+ # Codebook configs
50
+ codebook_size: int = 160
51
+ num_codebooks: int = 4
52
+
53
+ # Gradient checkpointing
54
+ use_gradient_checkpointing: bool = True
55
+
56
+ # Initialize the model
57
+ initializer_range: float = 0.02
58
+
59
+ # Dummy vars
60
+ is_reward_model: bool = False
61
+ scale_codebook_embeddings: bool = False
62
+
63
+ def __post_init__(self):
64
+ if self.n_local_heads == -1:
65
+ self.n_local_heads = self.n_head
66
+ if self.intermediate_size is None:
67
+ hidden_dim = 4 * self.dim
68
+ n_hidden = int(2 * hidden_dim / 3)
69
+ self.intermediate_size = find_multiple(n_hidden, 256)
70
+ if self.head_dim is None:
71
+ self.head_dim = self.dim // self.n_head
72
+
73
+ @staticmethod
74
+ def from_pretrained(path: str):
75
+ path = Path(path)
76
+
77
+ if path.is_dir():
78
+ path = path / "config.json"
79
+
80
+ with open(path, "r", encoding="utf-8") as f:
81
+ data = json.load(f)
82
+
83
+ match data["model_type"]:
84
+ case "naive":
85
+ cls = NaiveModelArgs
86
+ case "dual_ar":
87
+ cls = DualARModelArgs
88
+ case _:
89
+ raise ValueError(f"Unknown model type: {data['model_type']}")
90
+
91
+ return cls(**data)
92
+
93
+ def save(self, path: str):
94
+ with open(path, "w") as f:
95
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
96
+
97
+
98
+ @dataclass
99
+ class NaiveModelArgs(BaseModelArgs):
100
+ model_type: str = "naive"
101
+
102
+
103
+ @dataclass
104
+ class DualARModelArgs(BaseModelArgs):
105
+ model_type: str = "dual_ar"
106
+ n_fast_layer: int = 4
107
+ fast_dim: int | None = None
108
+ fast_n_head: int | None = None
109
+ fast_n_local_heads: int | None = None
110
+ fast_head_dim: int | None = None
111
+ fast_intermediate_size: int | None = None
112
+ fast_attention_qkv_bias: bool | None = None
113
+ fast_attention_qk_norm: bool | None = None
114
+ fast_attention_o_bias: bool | None = None
115
+
116
+ def __post_init__(self):
117
+ super().__post_init__()
118
+
119
+ self.fast_dim = self.fast_dim or self.dim
120
+ self.fast_n_head = self.fast_n_head or self.n_head
121
+ self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
122
+ self.fast_head_dim = self.fast_head_dim or self.head_dim
123
+ self.fast_intermediate_size = (
124
+ self.fast_intermediate_size or self.intermediate_size
125
+ )
126
+ self.fast_attention_qkv_bias = (
127
+ self.fast_attention_qkv_bias
128
+ if self.fast_attention_qkv_bias is not None
129
+ else self.attention_qkv_bias
130
+ )
131
+ self.fast_attention_qk_norm = (
132
+ self.fast_attention_qk_norm
133
+ if self.fast_attention_qk_norm is not None
134
+ else self.attention_qk_norm
135
+ )
136
+ self.fast_attention_o_bias = (
137
+ self.fast_attention_o_bias
138
+ if self.fast_attention_o_bias is not None
139
+ else self.attention_o_bias
140
+ )
141
+
142
+
143
+ class KVCache(nn.Module):
144
+ def __init__(
145
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
146
+ ):
147
+ super().__init__()
148
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
149
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
150
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
151
+
152
+ def update(self, input_pos, k_val, v_val):
153
+ # input_pos: [S], k_val: [B, H, S, D]
154
+ assert input_pos.shape[0] == k_val.shape[2]
155
+
156
+ k_out = self.k_cache
157
+ v_out = self.v_cache
158
+ k_out[:, :, input_pos] = k_val
159
+ v_out[:, :, input_pos] = v_val
160
+
161
+ return k_out, v_out
162
+
163
+
164
+ @dataclass
165
+ class TransformerForwardResult:
166
+ token_logits: Tensor
167
+ codebook_logits: Tensor
168
+
169
+
170
+ @dataclass
171
+ class BaseTransformerForwardResult:
172
+ logits: Tensor
173
+ hidden_states: Tensor
174
+
175
+
176
+ class BaseTransformer(nn.Module):
177
+ def __init__(
178
+ self,
179
+ config: BaseModelArgs,
180
+ tokenizer: FishTokenizer,
181
+ init_weights: bool = True,
182
+ ) -> None:
183
+ super().__init__()
184
+ self.config = config
185
+ self.tokenizer = tokenizer
186
+ self.semantic_token_ids = list(tokenizer.semantic_id_to_token_id.values())
187
+
188
+ # Slow transformer
189
+ self.embeddings = nn.Embedding(
190
+ config.vocab_size,
191
+ config.dim,
192
+ )
193
+ self.codebook_embeddings = nn.Embedding(
194
+ config.codebook_size * config.num_codebooks,
195
+ config.dim,
196
+ )
197
+ self.layers = nn.ModuleList(
198
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
199
+ )
200
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
201
+
202
+ if self.config.tie_word_embeddings is False:
203
+ self.output = nn.Linear(
204
+ config.dim,
205
+ config.vocab_size,
206
+ bias=False,
207
+ )
208
+
209
+ self.register_buffer(
210
+ "freqs_cis",
211
+ precompute_freqs_cis(
212
+ config.max_seq_len,
213
+ config.head_dim,
214
+ config.rope_base,
215
+ ),
216
+ persistent=False,
217
+ )
218
+ self.register_buffer(
219
+ "causal_mask",
220
+ torch.tril(
221
+ torch.ones(
222
+ config.max_seq_len,
223
+ config.max_seq_len,
224
+ dtype=torch.bool,
225
+ )
226
+ ),
227
+ persistent=False,
228
+ )
229
+
230
+ # For kv cache
231
+ self.max_batch_size = -1
232
+ self.max_seq_len = -1
233
+
234
+ if init_weights:
235
+ self.apply(self._init_weights)
236
+
237
+ def setup_caches(
238
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
239
+ ):
240
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
241
+ return
242
+
243
+ max_seq_len = find_multiple(max_seq_len, 8)
244
+ self.max_seq_len = max_seq_len
245
+ self.max_batch_size = max_batch_size
246
+
247
+ for b in self.layers:
248
+ b.attention.kv_cache = KVCache(
249
+ max_batch_size,
250
+ max_seq_len,
251
+ self.config.n_local_heads,
252
+ self.config.head_dim,
253
+ dtype=dtype,
254
+ )
255
+
256
+ def embed(self, inp: Tensor) -> Tensor:
257
+ embeds = []
258
+ semantic_token_ids_tensor = torch.tensor(
259
+ self.semantic_token_ids, device=inp.device, dtype=inp.dtype
260
+ )
261
+
262
+ for i in range(self.config.num_codebooks):
263
+ emb = self.codebook_embeddings(
264
+ inp[:, i + 1] + i * self.config.codebook_size
265
+ )
266
+ embeds.append(emb)
267
+
268
+ vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
269
+ vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0
270
+ x = self.embeddings(inp[:, 0]) + vq_embeds_sum
271
+
272
+ return x
273
+
274
+ def forward(
275
+ self,
276
+ inp: Tensor,
277
+ key_padding_mask: Optional[Tensor] = None,
278
+ ) -> BaseTransformerForwardResult:
279
+ seq_len = inp.size(2)
280
+
281
+ # Here we want to merge the embeddings of the codebooks
282
+ x = self.embed(inp)
283
+
284
+ freqs_cis = self.freqs_cis[:seq_len]
285
+
286
+ # Note that the causal mask here follows the definition of scaled_dot_product_attention
287
+ # That is, FALSE means masked out
288
+ # To maintain consistency, key_padding_mask uses TRUE to mask out
289
+ mask = None
290
+ if key_padding_mask is not None:
291
+ causal = self.causal_mask[:seq_len, :seq_len]
292
+ causal = rearrange(causal, "q k -> 1 1 q k")
293
+
294
+ atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
295
+ atten_mask = atten_mask.logical_not()
296
+ mask = causal & atten_mask
297
+
298
+ # return freqs_cis, mask
299
+
300
+ for layer in self.layers:
301
+ if self.config.use_gradient_checkpointing and self.training:
302
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
303
+ else:
304
+ x = layer(x, freqs_cis, mask)
305
+
306
+ # We got slow_out here
307
+ slow_out = self.norm(x)
308
+
309
+ if self.config.tie_word_embeddings:
310
+ token_logits = F.linear(slow_out, self.embeddings.weight)
311
+ else:
312
+ token_logits = self.output(slow_out)
313
+
314
+ return BaseTransformerForwardResult(
315
+ logits=token_logits,
316
+ hidden_states=x,
317
+ )
318
+
319
+ def forward_generate(
320
+ self,
321
+ inp: Tensor,
322
+ input_pos: Optional[Tensor] = None,
323
+ return_all: bool = False,
324
+ ) -> BaseTransformerForwardResult:
325
+ x = self.embed(inp)
326
+
327
+ if input_pos is None:
328
+ input_pos = torch.arange(inp.shape[-1], device=x.device)
329
+ max_seq_len = inp.shape[-1]
330
+ else:
331
+ max_seq_len = self.max_seq_len
332
+
333
+ mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
334
+ freqs_cis = self.freqs_cis[input_pos]
335
+
336
+ for layer in self.layers:
337
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
338
+
339
+ # If prefill, we only calculate the logits of last token
340
+ if x.size(1) > 1 and not return_all:
341
+ x = x[:, -1:]
342
+
343
+ # We got slow_out here
344
+ slow_out = self.norm(x)
345
+
346
+ if self.config.is_reward_model:
347
+ token_logits = self.score_output(slow_out)
348
+ elif self.config.tie_word_embeddings:
349
+ token_logits = F.linear(slow_out, self.embeddings.weight)
350
+ else:
351
+ token_logits = self.output(slow_out)
352
+
353
+ return BaseTransformerForwardResult(
354
+ logits=token_logits,
355
+ hidden_states=x,
356
+ )
357
+
358
+ def _init_weights(self, module):
359
+ std = self.config.initializer_range
360
+ if isinstance(module, nn.Linear):
361
+ module.weight.data.normal_(mean=0.0, std=std)
362
+ if module.bias is not None:
363
+ module.bias.data.zero_()
364
+ elif isinstance(module, nn.Embedding):
365
+ module.weight.data.normal_(mean=0.0, std=std)
366
+ if module.padding_idx is not None:
367
+ module.weight.data[module.padding_idx].zero_()
368
+
369
+ @staticmethod
370
+ def from_pretrained(
371
+ path: str,
372
+ load_weights: bool = False,
373
+ max_length: int | None = None,
374
+ lora_config: LoraConfig | None = None,
375
+ rope_base: int | None = None,
376
+ ) -> "BaseTransformer":
377
+ config = BaseModelArgs.from_pretrained(str(path))
378
+ if max_length is not None:
379
+ config.max_seq_len = max_length
380
+ logger.info(f"Override max_seq_len to {max_length}")
381
+
382
+ if rope_base is not None:
383
+ config.rope_base = rope_base
384
+ logger.info(f"Override rope_base to {rope_base}")
385
+
386
+ match config.model_type:
387
+ case "naive":
388
+ model_cls = NaiveTransformer
389
+ case "dual_ar":
390
+ model_cls = DualARTransformer
391
+ case _:
392
+ raise ValueError(f"Unknown model type: {config.model_type}")
393
+
394
+ tokenizer = FishTokenizer.from_pretrained(path)
395
+
396
+ logger.info(f"Loading model from {path}, config: {config}")
397
+ model = model_cls(config, tokenizer=tokenizer)
398
+
399
+ if lora_config is not None:
400
+ setup_lora(model, lora_config)
401
+ logger.info(f"LoRA setup: {lora_config}")
402
+
403
+ if load_weights is False:
404
+ logger.info("Randomly initialized model")
405
+ else:
406
+
407
+ if "int8" in str(Path(path)):
408
+ logger.info("Using int8 weight-only quantization!")
409
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
410
+
411
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
412
+ model = simple_quantizer.convert_for_runtime()
413
+
414
+ if "int4" in str(Path(path)):
415
+ logger.info("Using int4 quantization!")
416
+ path_comps = path.name.split("-")
417
+ assert path_comps[-2].startswith("g")
418
+ groupsize = int(path_comps[-2][1:])
419
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
420
+
421
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
422
+ model = simple_quantizer.convert_for_runtime()
423
+
424
+ weights = torch.load(
425
+ Path(path) / "model.pth",
426
+ map_location="cpu",
427
+ mmap=True,
428
+ weights_only=True,
429
+ )
430
+
431
+ if "state_dict" in weights:
432
+ logger.warning(
433
+ "Using a TextToSemantic LightningModule checkpoint, "
434
+ "please make sure it is a full model, not a LoRA model."
435
+ )
436
+ weights = weights["state_dict"]
437
+
438
+ if next(iter(weights.keys())).startswith("model."):
439
+ logger.info(
440
+ f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
441
+ )
442
+ new_weights = OrderedDict()
443
+ for k, v in weights.items():
444
+ new_weights[k.replace("model.", "")] = v
445
+ weights = new_weights
446
+
447
+ # Remove audio related weights
448
+ for k in list(weights.keys()):
449
+ if "audio_" in k:
450
+ weights.pop(k)
451
+
452
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
453
+ for k, v in model.named_parameters():
454
+ if k not in weights:
455
+ logger.warning(f"No weight for {k}")
456
+ elif v.shape != weights[k].shape:
457
+ logger.warning(
458
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
459
+ )
460
+
461
+ err = model.load_state_dict(weights, strict=False, assign=True)
462
+ logger.info(f"Loaded weights with error: {err}")
463
+
464
+ return model
465
+
466
+ def save_pretrained(self, path: str, drop_lora: bool = False):
467
+ path = Path(path)
468
+ path.mkdir(parents=True, exist_ok=True)
469
+
470
+ self.config.save(path / "config.json")
471
+ state_dict = self.state_dict()
472
+
473
+ if drop_lora:
474
+ for key in list(state_dict.keys()):
475
+ if "lora" not in key:
476
+ continue
477
+
478
+ state_dict.pop(key)
479
+ logger.info(f"Drop LoRA parameter: {key}")
480
+
481
+ torch.save(state_dict, path / "model.pth")
482
+ self.tokenizer.save_pretrained(path)
483
+
484
+
485
+ class NaiveTransformer(BaseTransformer):
486
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
487
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
488
+
489
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
490
+ self.codebook_output = nn.Linear(
491
+ config.dim,
492
+ config.codebook_size * config.num_codebooks,
493
+ bias=False,
494
+ )
495
+
496
+ self.apply(self._init_weights)
497
+
498
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
499
+ token_logits = result.logits
500
+ x = result.hidden_states
501
+
502
+ # Codebook
503
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
504
+ codebook_logits = rearrange(
505
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
506
+ )
507
+
508
+ return TransformerForwardResult(
509
+ token_logits=token_logits,
510
+ codebook_logits=codebook_logits,
511
+ )
512
+
513
+ def forward(
514
+ self,
515
+ inp: Tensor,
516
+ key_padding_mask: Optional[Tensor] = None,
517
+ ) -> TransformerForwardResult:
518
+ result = super().forward(
519
+ inp=inp,
520
+ key_padding_mask=key_padding_mask,
521
+ )
522
+ return self.decode(result)
523
+
524
+ def forward_generate(
525
+ self, x: Tensor, input_pos: Optional[Tensor] = None
526
+ ) -> TransformerForwardResult:
527
+ result = super().forward_generate(x, input_pos)
528
+ return self.decode(result)
529
+
530
+
531
+ class DualARTransformer(BaseTransformer):
532
+ def __init__(self, config: DualARModelArgs, tokenizer: FishTokenizer) -> None:
533
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
534
+
535
+ # Project to fast dim if needed
536
+ if config.fast_dim is not None and config.fast_dim != config.dim:
537
+ self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
538
+ else:
539
+ self.fast_project_in = nn.Identity()
540
+
541
+ # Fast transformer
542
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
543
+
544
+ # The equivalent batch size (b * s) is so large that sdpa doesn't work here
545
+ override_config = dataclasses.replace(
546
+ config,
547
+ dim=config.fast_dim,
548
+ n_head=config.fast_n_head,
549
+ n_local_heads=config.fast_n_local_heads,
550
+ head_dim=config.fast_head_dim,
551
+ intermediate_size=config.fast_intermediate_size,
552
+ attention_qkv_bias=config.fast_attention_qkv_bias,
553
+ attention_qk_norm=config.fast_attention_qk_norm,
554
+ attention_o_bias=config.fast_attention_o_bias,
555
+ )
556
+
557
+ self.fast_layers = nn.ModuleList(
558
+ TransformerBlock(override_config, use_sdpa=False)
559
+ for _ in range(config.n_fast_layer)
560
+ )
561
+ self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
562
+ self.fast_output = nn.Linear(
563
+ config.fast_dim,
564
+ config.codebook_size,
565
+ bias=False,
566
+ )
567
+
568
+ self.register_buffer(
569
+ "fast_freqs_cis",
570
+ precompute_freqs_cis(
571
+ config.num_codebooks,
572
+ config.fast_head_dim,
573
+ config.rope_base,
574
+ ),
575
+ persistent=False,
576
+ )
577
+ self.apply(self._init_weights)
578
+
579
+ def setup_caches(
580
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
581
+ ):
582
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
583
+
584
+ # Fast transformer
585
+ # The max seq len here is the number of codebooks
586
+ for b in self.fast_layers:
587
+ b.attention.kv_cache = KVCache(
588
+ max_batch_size,
589
+ self.config.num_codebooks,
590
+ self.config.fast_n_local_heads,
591
+ self.config.fast_head_dim,
592
+ dtype=dtype,
593
+ )
594
+
595
+ def forward(
596
+ self,
597
+ inp: Tensor,
598
+ key_padding_mask: Optional[Tensor] = None,
599
+ ) -> TransformerForwardResult:
600
+ parent_result = super().forward(inp, key_padding_mask)
601
+ token_logits = parent_result.logits
602
+ x = parent_result.hidden_states
603
+ x = self.fast_project_in(x)
604
+
605
+ # Fast transformer
606
+ fast_seq_len = self.config.num_codebooks
607
+ fast_mask = self.causal_mask[
608
+ None, None, :fast_seq_len, :fast_seq_len
609
+ ] # (B, N, Q, K)
610
+
611
+ # Drop the last token and rotate left
612
+ codebooks = inp[:, 1:-1, 1:]
613
+ codebooks = F.pad(codebooks, (0, 1), value=0)
614
+ codebook_embeddings = self.fast_embeddings(codebooks)
615
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
616
+ b, s = x.size(0), x.size(2)
617
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
618
+
619
+ # Remove padded part
620
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
621
+ codebook_mask = (codebooks == 0).all(dim=-1)
622
+
623
+ if torch.all(codebook_mask):
624
+ # If all codebooks are padded, we keep the first 8 to make sure the model runs
625
+ codebook_mask[:8] = False
626
+
627
+ x_bs, x_len = x.size(0), x.size(1)
628
+ x = x[~codebook_mask]
629
+
630
+ for layer in self.fast_layers:
631
+ if self.config.use_gradient_checkpointing and self.training:
632
+ x = checkpoint(
633
+ layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
634
+ )
635
+ else:
636
+ x = layer(x, self.fast_freqs_cis, fast_mask)
637
+
638
+ # unflatten the batch and num_codebooks
639
+ fast_out = self.fast_norm(x)
640
+ codebook_logits = self.fast_output(fast_out)
641
+
642
+ # Re-pad the codebook_logits
643
+ buffer = torch.zeros(
644
+ x_bs,
645
+ x_len,
646
+ codebook_logits.size(-1),
647
+ device=codebook_logits.device,
648
+ dtype=codebook_logits.dtype,
649
+ )
650
+ buffer[~codebook_mask] = codebook_logits
651
+ codebook_logits = buffer
652
+
653
+ assert codebook_logits.shape[1] == self.config.num_codebooks
654
+ codebook_logits = rearrange(
655
+ codebook_logits,
656
+ "(b s) n d -> b s n d",
657
+ b=b,
658
+ s=s,
659
+ n=self.config.num_codebooks,
660
+ )
661
+
662
+ return TransformerForwardResult(
663
+ token_logits=token_logits,
664
+ codebook_logits=codebook_logits,
665
+ )
666
+
667
+ def forward_generate_fast(
668
+ self, x: Tensor, input_pos: Optional[Tensor] = None
669
+ ) -> Tensor:
670
+ # Fast transformer
671
+ x = x.view(1, 1, -1)
672
+
673
+ fast_mask = self.causal_mask[
674
+ None, None, input_pos, : self.config.num_codebooks
675
+ ] # (B, N, Q, K)
676
+ fast_freqs_cis = self.fast_freqs_cis[input_pos]
677
+
678
+ for layer in self.fast_layers:
679
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
680
+
681
+ # unflatten the batch and num_codebooks
682
+ fast_out = self.fast_norm(x) # only take the last token
683
+ codebook_logits = self.fast_output(fast_out)
684
+
685
+ return codebook_logits
686
+
687
+ def forward_generate(
688
+ self,
689
+ x: Tensor,
690
+ input_pos: Optional[Tensor] = None,
691
+ vq_masks: Optional[Tensor] = None,
692
+ ) -> TransformerForwardResult:
693
+ x = super().forward_generate(x, input_pos, vq_masks)
694
+ x.hidden_states = self.fast_project_in(x.hidden_states)
695
+ return x
696
+
697
+
698
+ class TransformerBlock(nn.Module):
699
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
700
+ super().__init__()
701
+ self.attention = Attention(config, use_sdpa=use_sdpa)
702
+ self.feed_forward = FeedForward(config)
703
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
704
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
705
+
706
+ def forward(
707
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
708
+ ) -> Tensor:
709
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
710
+ out = h + self.feed_forward(self.ffn_norm(h))
711
+ return out
712
+
713
+
714
+ class Attention(nn.Module):
715
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
716
+ super().__init__()
717
+ assert config.dim % config.n_head == 0
718
+
719
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
720
+ # key, query, value projections for all heads, but in a batch
721
+ self.wqkv = nn.Linear(
722
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
723
+ )
724
+ self.wo = nn.Linear(
725
+ config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
726
+ )
727
+ self.kv_cache = None
728
+
729
+ if config.attention_qk_norm:
730
+ self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
731
+ self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
732
+
733
+ self.dropout = config.dropout
734
+ self.n_head = config.n_head
735
+ self.head_dim = config.head_dim
736
+ self.n_local_heads = config.n_local_heads
737
+ self.dim = config.dim
738
+ self.use_sdpa = use_sdpa
739
+ self.attention_qk_norm = config.attention_qk_norm
740
+ self.config = config
741
+
742
+ self._register_load_state_dict_pre_hook(self.load_hook)
743
+
744
+ def load_hook(self, state_dict, prefix, *args):
745
+ if prefix + "wq.weight" in state_dict:
746
+ wq = state_dict.pop(prefix + "wq.weight")
747
+ wk = state_dict.pop(prefix + "wk.weight")
748
+ wv = state_dict.pop(prefix + "wv.weight")
749
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
750
+
751
+ def forward(
752
+ self,
753
+ x: Tensor,
754
+ freqs_cis: Tensor,
755
+ mask: Tensor,
756
+ input_pos: Optional[Tensor] = None,
757
+ ) -> Tensor:
758
+ bsz, seqlen, _ = x.shape
759
+
760
+ q_size = self.n_head * self.head_dim
761
+ kv_size = self.n_local_heads * self.head_dim
762
+ q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
763
+
764
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
765
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
766
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
767
+
768
+ if self.attention_qk_norm:
769
+ q = self.q_norm(q)
770
+ k = self.k_norm(k)
771
+
772
+ q = apply_rotary_emb(q, freqs_cis)
773
+ k = apply_rotary_emb(k, freqs_cis)
774
+
775
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
776
+
777
+ if self.kv_cache is not None:
778
+ k, v = self.kv_cache.update(input_pos, k, v)
779
+
780
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
781
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
782
+
783
+ if self.use_sdpa:
784
+ if mask is None:
785
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
786
+ y = F.scaled_dot_product_attention(
787
+ q,
788
+ k,
789
+ v,
790
+ dropout_p=self.dropout if self.training else 0.0,
791
+ is_causal=True,
792
+ # No attn_mask is passed here so that the flash attention kernel can be used
793
+ )
794
+ else:
795
+ y = F.scaled_dot_product_attention(
796
+ q,
797
+ k,
798
+ v,
799
+ attn_mask=mask,
800
+ dropout_p=self.dropout if self.training else 0.0,
801
+ )
802
+ else:
803
+ y = self.eq_scaled_dot_product_attention(
804
+ q,
805
+ k,
806
+ v,
807
+ attn_mask=mask,
808
+ dropout_p=self.dropout if self.training else 0.0,
809
+ )
810
+
811
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
812
+
813
+ return self.wo(y)
814
+
815
+ def eq_scaled_dot_product_attention(
816
+ self,
817
+ query,
818
+ key,
819
+ value,
820
+ attn_mask=None,
821
+ dropout_p=0.0,
822
+ ) -> torch.Tensor:
823
+ # This is a standard scaled dot product attention
824
+ # It's less efficient, but it doesn't raise CUDA errors
825
+
826
+ L, S = query.size(-2), key.size(-2)
827
+ scale_factor = 1 / math.sqrt(query.size(-1))
828
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
829
+
830
+ if attn_mask is not None:
831
+ if attn_mask.dtype == torch.bool:
832
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
833
+ else:
834
+ attn_bias += attn_mask
835
+
836
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
837
+ attn_weight += attn_bias
838
+ attn_weight = torch.softmax(attn_weight, dim=-1)
839
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
840
+
841
+ return attn_weight @ value
842
+
843
+
844
+ class FeedForward(nn.Module):
845
+ def __init__(self, config: BaseModelArgs) -> None:
846
+ super().__init__()
847
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
848
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
849
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
850
+
851
+ def forward(self, x: Tensor) -> Tensor:
852
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
853
+
854
+
855
+ class RMSNorm(nn.Module):
856
+ def __init__(self, dim: int, eps: float = 1e-5):
857
+ super().__init__()
858
+ self.eps = eps
859
+ self.weight = nn.Parameter(torch.ones(dim))
860
+
861
+ def _norm(self, x):
862
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
863
+
864
+ def forward(self, x: Tensor) -> Tensor:
865
+ output = self._norm(x.float()).type_as(x)
866
+ return output * self.weight
867
+
868
+
869
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
870
+ """
871
+ Precomputes frequency tensors for complex exponentials (cis)
872
+
873
+ Args:
874
+ seq_len: Length of the sequence for which positional embeddings are needed.
875
+ n_elem: Number of elements in the frequency tensor.
876
+ base: Base value for the frequency scaling (default: 10000).
877
+
878
+ Returns:
879
+ A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
880
+ """
881
+ freqs = 1.0 / (
882
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
883
+ )
884
+ t = torch.arange(seq_len, device=freqs.device)
885
+ freqs = torch.outer(t, freqs)
886
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
887
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
888
+ return cache.to(dtype=torch.bfloat16)
889
+
890
+
891
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
892
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
893
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
894
+ x_out2 = torch.stack(
895
+ [
896
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
897
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
898
+ ],
899
+ -1,
900
+ )
901
+
902
+ x_out2 = x_out2.flatten(3)
903
+ return x_out2.type_as(x)
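
As a quick sanity check on the rotary-embedding helpers added above, the minimal sketch below precomputes the cache and applies it to a query tensor with the shapes used by Attention.forward. It is illustrative only and assumes the new module is importable as fish_speech.models.text2semantic.llama; the tensor sizes are arbitrary.

import torch

from fish_speech.models.text2semantic.llama import apply_rotary_emb, precompute_freqs_cis

seq_len, n_head, head_dim = 16, 2, 64
# Cache shape is (seq_len, head_dim // 2, 2): real and imaginary parts stacked last.
freqs_cis = precompute_freqs_cis(seq_len, head_dim)
# Queries as produced in Attention.forward before the transpose: (bsz, seqlen, n_head, head_dim).
q = torch.randn(1, seq_len, n_head, head_dim, dtype=torch.bfloat16)
q_rot = apply_rotary_emb(q, freqs_cis[:seq_len])
assert q_rot.shape == q.shape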
fish_speech/models/text2semantic/lora.py CHANGED
@@ -1,92 +1,92 @@
1
- from dataclasses import dataclass
2
-
3
- import loralib as lora
4
-
5
-
6
- @dataclass
7
- class LoraConfig:
8
- r: int
9
- lora_alpha: float
10
- lora_dropout: float = 0.0
11
-
12
-
13
- def setup_lora(model, lora_config):
14
- # Replace the embedding layer with a LoRA layer
15
- model.embeddings = lora.Embedding(
16
- num_embeddings=model.embeddings.num_embeddings,
17
- embedding_dim=model.embeddings.embedding_dim,
18
- padding_idx=model.embeddings.padding_idx,
19
- r=lora_config.r,
20
- lora_alpha=lora_config.lora_alpha,
21
- )
22
-
23
- model.codebook_embeddings = lora.Embedding(
24
- num_embeddings=model.codebook_embeddings.num_embeddings,
25
- embedding_dim=model.codebook_embeddings.embedding_dim,
26
- padding_idx=model.codebook_embeddings.padding_idx,
27
- r=lora_config.r,
28
- lora_alpha=lora_config.lora_alpha,
29
- )
30
-
31
- # Replace output layer with a LoRA layer
32
- linears = [(model, "output")]
33
-
34
- # Replace all linear layers with LoRA layers
35
- for layer in model.layers:
36
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
37
- linears.extend(
38
- [
39
- (layer.feed_forward, "w1"),
40
- (layer.feed_forward, "w2"),
41
- (layer.feed_forward, "w3"),
42
- ]
43
- )
44
-
45
- if hasattr(model, "fast_layers"):
46
- model.fast_embeddings = lora.Embedding(
47
- num_embeddings=model.fast_embeddings.num_embeddings,
48
- embedding_dim=model.fast_embeddings.embedding_dim,
49
- padding_idx=model.fast_embeddings.padding_idx,
50
- r=lora_config.r,
51
- lora_alpha=lora_config.lora_alpha,
52
- )
53
-
54
- # Dual-AR model
55
- linears.append((model, "fast_output"))
56
-
57
- for layer in model.fast_layers:
58
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
59
- linears.extend(
60
- [
61
- (layer.feed_forward, "w1"),
62
- (layer.feed_forward, "w2"),
63
- (layer.feed_forward, "w3"),
64
- ]
65
- )
66
-
67
- for module, layer in linears:
68
- updated_linear = lora.Linear(
69
- in_features=getattr(module, layer).in_features,
70
- out_features=getattr(module, layer).out_features,
71
- bias=getattr(module, layer).bias,
72
- r=lora_config.r,
73
- lora_alpha=lora_config.lora_alpha,
74
- lora_dropout=lora_config.lora_dropout,
75
- )
76
- setattr(module, layer, updated_linear)
77
-
78
- # Mark only the LoRA layers as trainable
79
- lora.mark_only_lora_as_trainable(model, bias="none")
80
-
81
-
82
- def get_merged_state_dict(model):
83
- # This line will merge the state dict of the model and the LoRA parameters
84
- model.eval()
85
-
86
- # Then we need to remove the LoRA parameters from the state dict
87
- state_dict = model.state_dict()
88
- for name in list(state_dict.keys()):
89
- if "lora" in name:
90
- state_dict.pop(name)
91
-
92
- return state_dict
 
1
+ from dataclasses import dataclass
2
+
3
+ import loralib as lora
4
+
5
+
6
+ @dataclass
7
+ class LoraConfig:
8
+ r: int
9
+ lora_alpha: float
10
+ lora_dropout: float = 0.0
11
+
12
+
13
+ def setup_lora(model, lora_config):
14
+ # Replace the embedding layer with a LoRA layer
15
+ model.embeddings = lora.Embedding(
16
+ num_embeddings=model.embeddings.num_embeddings,
17
+ embedding_dim=model.embeddings.embedding_dim,
18
+ padding_idx=model.embeddings.padding_idx,
19
+ r=lora_config.r,
20
+ lora_alpha=lora_config.lora_alpha,
21
+ )
22
+
23
+ model.codebook_embeddings = lora.Embedding(
24
+ num_embeddings=model.codebook_embeddings.num_embeddings,
25
+ embedding_dim=model.codebook_embeddings.embedding_dim,
26
+ padding_idx=model.codebook_embeddings.padding_idx,
27
+ r=lora_config.r,
28
+ lora_alpha=lora_config.lora_alpha,
29
+ )
30
+
31
+ # Replace output layer with a LoRA layer
32
+ linears = [(model, "output")]
33
+
34
+ # Replace all linear layers with LoRA layers
35
+ for layer in model.layers:
36
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
37
+ linears.extend(
38
+ [
39
+ (layer.feed_forward, "w1"),
40
+ (layer.feed_forward, "w2"),
41
+ (layer.feed_forward, "w3"),
42
+ ]
43
+ )
44
+
45
+ if hasattr(model, "fast_layers"):
46
+ model.fast_embeddings = lora.Embedding(
47
+ num_embeddings=model.fast_embeddings.num_embeddings,
48
+ embedding_dim=model.fast_embeddings.embedding_dim,
49
+ padding_idx=model.fast_embeddings.padding_idx,
50
+ r=lora_config.r,
51
+ lora_alpha=lora_config.lora_alpha,
52
+ )
53
+
54
+ # Dual-AR model
55
+ linears.append((model, "fast_output"))
56
+
57
+ for layer in model.fast_layers:
58
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
59
+ linears.extend(
60
+ [
61
+ (layer.feed_forward, "w1"),
62
+ (layer.feed_forward, "w2"),
63
+ (layer.feed_forward, "w3"),
64
+ ]
65
+ )
66
+
67
+ for module, layer in linears:
68
+ updated_linear = lora.Linear(
69
+ in_features=getattr(module, layer).in_features,
70
+ out_features=getattr(module, layer).out_features,
71
+ bias=getattr(module, layer).bias,
72
+ r=lora_config.r,
73
+ lora_alpha=lora_config.lora_alpha,
74
+ lora_dropout=lora_config.lora_dropout,
75
+ )
76
+ setattr(module, layer, updated_linear)
77
+
78
+ # Mark only the LoRA layers as trainable
79
+ lora.mark_only_lora_as_trainable(model, bias="none")
80
+
81
+
82
+ def get_merged_state_dict(model):
83
+ # This line will merge the state dict of the model and the LoRA parameters
84
+ model.eval()
85
+
86
+ # Then we need to remove the LoRA parameters from the state dict
87
+ state_dict = model.state_dict()
88
+ for name in list(state_dict.keys()):
89
+ if "lora" in name:
90
+ state_dict.pop(name)
91
+
92
+ return state_dict
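
For orientation, the sketch below shows how these helpers are typically combined with the transformer defined earlier in this commit: wrap the model's linear and embedding layers with LoRA, fine-tune, then export a merged state dict. It is a hedged example, not part of the upload; the checkpoint path and the r/alpha values are placeholders.

import torch

from fish_speech.models.text2semantic.llama import BaseTransformer
from fish_speech.models.text2semantic.lora import LoraConfig, get_merged_state_dict, setup_lora

# Placeholder checkpoint directory containing config.json, model.pth and the tokenizer files.
model = BaseTransformer.from_pretrained("checkpoints/your-checkpoint", load_weights=True)
setup_lora(model, LoraConfig(r=8, lora_alpha=16))  # only LoRA parameters remain trainable
# ... run the fine-tuning loop here ...
merged = get_merged_state_dict(model)  # eval() folds LoRA weights in, lora_* keys are dropped
torch.save(merged, "model.merged.pth")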
fish_speech/text/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .clean import clean_text
2
- from .spliter import split_text
3
-
4
- __all__ = ["clean_text", "split_text"]
 
1
+ from .clean import clean_text
2
+ from .spliter import split_text
3
+
4
+ __all__ = ["clean_text", "split_text"]
fish_speech/text/clean.py CHANGED
@@ -1,37 +1,37 @@
1
- import re
2
-
3
- SYMBOLS_MAPPING = {
4
- "‘": "'",
5
- "’": "'",
6
- }
7
-
8
- REPLACE_SYMBOL_REGEX = re.compile(
9
- "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
10
- )
11
-
12
-
13
- EMOJI_REGEX = re.compile(
14
- "["
15
- "\U0001F600-\U0001F64F" # emoticons
16
- "\U0001F300-\U0001F5FF" # symbols & pictographs
17
- "\U0001F680-\U0001F6FF" # transport & map symbols
18
- "\U0001F1E0-\U0001F1FF" # flags (iOS)
19
- "]+",
20
- flags=re.UNICODE,
21
- )
22
-
23
-
24
- def clean_text(text):
25
- # Clean the text
26
- text = text.strip()
27
-
28
- # Replace all chinese symbols with their english counterparts
29
- text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
30
-
31
- # Remove emojis
32
- text = EMOJI_REGEX.sub(r"", text)
33
-
34
- # Remove continuous periods (...) and commas (,,,)
35
- text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
36
-
37
- return text
 
1
+ import re
2
+
3
+ SYMBOLS_MAPPING = {
4
+ "‘": "'",
5
+ "’": "'",
6
+ }
7
+
8
+ REPLACE_SYMBOL_REGEX = re.compile(
9
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
10
+ )
11
+
12
+
13
+ EMOJI_REGEX = re.compile(
14
+ "["
15
+ "\U0001f600-\U0001f64f" # emoticons
16
+ "\U0001f300-\U0001f5ff" # symbols & pictographs
17
+ "\U0001f680-\U0001f6ff" # transport & map symbols
18
+ "\U0001f1e0-\U0001f1ff" # flags (iOS)
19
+ "]+",
20
+ flags=re.UNICODE,
21
+ )
22
+
23
+
24
+ def clean_text(text):
25
+ # Clean the text
26
+ text = text.strip()
27
+
28
+ # Normalize curly apostrophes to their ASCII counterparts
29
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
30
+
31
+ # Remove emojis
32
+ text = EMOJI_REGEX.sub(r"", text)
33
+
34
+ # Collapse repeated commas (e.g. ",,," or ",,,") into a single comma
35
+ text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
36
+
37
+ return text
fish_speech/text/spliter.py CHANGED
@@ -1,130 +1,130 @@
1
- import re
2
- import string
3
-
4
- from fish_speech.text.clean import clean_text
5
-
6
-
7
- def utf_8_len(text: str):
8
- return len(text.encode("utf-8"))
9
-
10
-
11
- def break_text(texts, length, splits: set):
12
- for text in texts:
13
- if utf_8_len(text) <= length:
14
- yield text
15
- continue
16
-
17
- curr = ""
18
- for char in text:
19
- curr += char
20
-
21
- if char in splits:
22
- yield curr
23
- curr = ""
24
-
25
- if curr:
26
- yield curr
27
-
28
-
29
- def break_text_by_length(texts, length):
30
- for text in texts:
31
- if utf_8_len(text) <= length:
32
- yield text
33
- continue
34
-
35
- curr = ""
36
- for char in text:
37
- curr += char
38
-
39
- if utf_8_len(curr) >= length:
40
- yield curr
41
- curr = ""
42
-
43
- if curr:
44
- yield curr
45
-
46
-
47
- def add_cleaned(curr, segments):
48
- curr = curr.strip()
49
- if curr and not all(c.isspace() or c in string.punctuation for c in curr):
50
- segments.append(curr)
51
-
52
-
53
- def protect_float(text):
54
- # Turns 3.14 into <3_f_14> to prevent splitting
55
- return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
56
-
57
-
58
- def unprotect_float(text):
59
- # Turns <3_f_14> into 3.14
60
- return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
61
-
62
-
63
- def split_text(text, length):
64
- text = clean_text(text)
65
-
66
- # Break the text into pieces with following rules:
67
- # 1. Split the text at ".", "!", "?" if text is NOT a float
68
- # 2. If the text is longer than length, split at ","
69
- # 3. If the text is still longer than length, split at " "
70
- # 4. If the text is still longer than length, split at any character to length
71
-
72
- texts = [text]
73
- texts = map(protect_float, texts)
74
- texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
75
- texts = map(unprotect_float, texts)
76
- texts = break_text(texts, length, {",", ","})
77
- texts = break_text(texts, length, {" "})
78
- texts = list(break_text_by_length(texts, length))
79
-
80
- # Then, merge the texts into segments with length <= length
81
- segments = []
82
- curr = ""
83
-
84
- for text in texts:
85
- if utf_8_len(curr) + utf_8_len(text) <= length:
86
- curr += text
87
- else:
88
- add_cleaned(curr, segments)
89
- curr = text
90
-
91
- if curr:
92
- add_cleaned(curr, segments)
93
-
94
- return segments
95
-
96
-
97
- if __name__ == "__main__":
98
- # Test the split_text function
99
-
100
- text = "This is a test sentence. This is another test sentence. And a third one."
101
-
102
- assert split_text(text, 50) == [
103
- "This is a test sentence.",
104
- "This is another test sentence. And a third one.",
105
- ]
106
- assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
107
- assert split_text(" ", 10) == []
108
- assert split_text("a", 10) == ["a"]
109
-
110
- text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
111
- assert split_text(text, 50) == [
112
- "This is a test sentence with only commas,",
113
- "and no dots, and no exclamation marks,",
114
- "and no question marks, and no newlines.",
115
- ]
116
-
117
- text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
118
- # First half split at " ", second half split at ","
119
- assert split_text(text, 50) == [
120
- "This is a test sentence This is a test sentence",
121
- "This is a test sentence. This is a test sentence,",
122
- "This is a test sentence, This is a test sentence.",
123
- ]
124
-
125
- text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
126
- assert split_text(text, 50) == [
127
- "这是一段很长的中文文本,",
128
- "而且没有句号,也没有感叹号,",
129
- "也没有问号,也没有换行符.",
130
- ]
 
1
+ import re
2
+ import string
3
+
4
+ from fish_speech.text.clean import clean_text
5
+
6
+
7
+ def utf_8_len(text: str):
8
+ return len(text.encode("utf-8"))
9
+
10
+
11
+ def break_text(texts, length, splits: set):
12
+ for text in texts:
13
+ if utf_8_len(text) <= length:
14
+ yield text
15
+ continue
16
+
17
+ curr = ""
18
+ for char in text:
19
+ curr += char
20
+
21
+ if char in splits:
22
+ yield curr
23
+ curr = ""
24
+
25
+ if curr:
26
+ yield curr
27
+
28
+
29
+ def break_text_by_length(texts, length):
30
+ for text in texts:
31
+ if utf_8_len(text) <= length:
32
+ yield text
33
+ continue
34
+
35
+ curr = ""
36
+ for char in text:
37
+ curr += char
38
+
39
+ if utf_8_len(curr) >= length:
40
+ yield curr
41
+ curr = ""
42
+
43
+ if curr:
44
+ yield curr
45
+
46
+
47
+ def add_cleaned(curr, segments):
48
+ curr = curr.strip()
49
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
50
+ segments.append(curr)
51
+
52
+
53
+ def protect_float(text):
54
+ # Turns 3.14 into <3_f_14> to prevent splitting
55
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
56
+
57
+
58
+ def unprotect_float(text):
59
+ # Turns <3_f_14> into 3.14
60
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
61
+
62
+
63
+ def split_text(text, length):
64
+ text = clean_text(text)
65
+
66
+ # Break the text into pieces with following rules:
67
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
68
+ # 2. If the text is longer than length, split at ","
69
+ # 3. If the text is still longer than length, split at " "
70
+ # 4. If the text is still longer than length, split at any character to length
71
+
72
+ texts = [text]
73
+ texts = map(protect_float, texts)
74
+ texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
75
+ texts = map(unprotect_float, texts)
76
+ texts = break_text(texts, length, {",", ","})
77
+ texts = break_text(texts, length, {" "})
78
+ texts = list(break_text_by_length(texts, length))
79
+
80
+ # Then, merge the texts into segments with length <= length
81
+ segments = []
82
+ curr = ""
83
+
84
+ for text in texts:
85
+ if utf_8_len(curr) + utf_8_len(text) <= length:
86
+ curr += text
87
+ else:
88
+ add_cleaned(curr, segments)
89
+ curr = text
90
+
91
+ if curr:
92
+ add_cleaned(curr, segments)
93
+
94
+ return segments
95
+
96
+
97
+ if __name__ == "__main__":
98
+ # Test the split_text function
99
+
100
+ text = "This is a test sentence. This is another test sentence. And a third one."
101
+
102
+ assert split_text(text, 50) == [
103
+ "This is a test sentence.",
104
+ "This is another test sentence. And a third one.",
105
+ ]
106
+ assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
107
+ assert split_text(" ", 10) == []
108
+ assert split_text("a", 10) == ["a"]
109
+
110
+ text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
111
+ assert split_text(text, 50) == [
112
+ "This is a test sentence with only commas,",
113
+ "and no dots, and no exclamation marks,",
114
+ "and no question marks, and no newlines.",
115
+ ]
116
+
117
+ text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
118
+ # First half split at " ", second half split at ","
119
+ assert split_text(text, 50) == [
120
+ "This is a test sentence This is a test sentence",
121
+ "This is a test sentence. This is a test sentence,",
122
+ "This is a test sentence, This is a test sentence.",
123
+ ]
124
+
125
+ text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
126
+ assert split_text(text, 50) == [
127
+ "这是一段很长的中文文本,",
128
+ "而且没有句号,也没有感叹号,",
129
+ "也没有问号,也没有换行符.",
130
+ ]
fish_speech/tokenizer.py CHANGED
@@ -1,152 +1,179 @@
1
- import base64
2
- import json
3
- import logging
4
- from pathlib import Path
5
-
6
- import tiktoken
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
- # This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
11
- FISH_TIKTOKEN_PATTERN = "|".join(
12
- [
13
- r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
14
- r"\p{P}",
15
- r"[^\r\n\p{L}\p{N}]?\p{L}+",
16
- r"\p{N}",
17
- r" ?[^\s\p{L}\p{N}]+[\r\n]*",
18
- r"\s*[\r\n]+",
19
- r"\s+(\?!\S)",
20
- r"\s+",
21
- ]
22
- )
23
- TIKTOKEN_MAX_ENCODE_CHARS = 400_000
24
-
25
- BOS_TOKEN = "<|begin_of_text|>"
26
- EOS_TOKEN = "<|end_of_text|>"
27
- PAD_TOKEN = "<|pad|>"
28
- IM_START_TOKEN = "<|im_start|>"
29
- IM_END_TOKEN = "<|im_end|>"
30
-
31
- MODALITY_TEXT_TOKEN = "<|text|>"
32
- MODALITY_VOICE_TOKEN = "<|voice|>"
33
- MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
34
- MODALITY_TOKENS = {
35
- "text": MODALITY_TEXT_TOKEN,
36
- "voice": MODALITY_VOICE_TOKEN,
37
- "interleave": MODALITY_INTERLEAVE_TOKEN,
38
- }
39
-
40
- PLACEHOLDER_TOKEN = [""] * 4
41
- for i in range(4):
42
- PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
43
-
44
- SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
45
- SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
46
-
47
- # Warning: when you add a new special token, you should only add it to the end of the list.
48
- ALL_SPECIAL_TOKENS = [
49
- BOS_TOKEN,
50
- EOS_TOKEN,
51
- PAD_TOKEN,
52
- IM_START_TOKEN,
53
- IM_END_TOKEN,
54
- PLACEHOLDER_TOKEN[0],
55
- PLACEHOLDER_TOKEN[1],
56
- PLACEHOLDER_TOKEN[2],
57
- PLACEHOLDER_TOKEN[3],
58
- MODALITY_TEXT_TOKEN,
59
- MODALITY_VOICE_TOKEN,
60
- MODALITY_INTERLEAVE_TOKEN,
61
- *SEMANTIC_TOKENS,
62
- ]
63
-
64
-
65
- class FishTokenizer:
66
- def __init__(self, model_path: str) -> None:
67
- mergeable_ranks = self.load_tiktoken_bpe(model_path)
68
- special_token_begin = len(mergeable_ranks)
69
- self.all_special_tokens_with_ids = {
70
- token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
71
- }
72
- self.semantic_id_to_token_id = {
73
- i: self.all_special_tokens_with_ids[token]
74
- for i, token in enumerate(SEMANTIC_TOKENS)
75
- }
76
- self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
77
- self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
78
-
79
- self.tkt_model = tiktoken.core.Encoding(
80
- name=Path(model_path).stem,
81
- pat_str=FISH_TIKTOKEN_PATTERN,
82
- mergeable_ranks=mergeable_ranks,
83
- special_tokens=self.all_special_tokens_with_ids,
84
- )
85
-
86
- @staticmethod
87
- def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
88
- data = {}
89
- for line in open(tiktoken_bpe_file).read().splitlines():
90
- if not line:
91
- continue
92
- token, rank = line.split()
93
- data[base64.b64decode(token)] = int(rank)
94
- return data
95
-
96
- def get_token_id(self, token: str) -> int:
97
- return self.all_special_tokens_with_ids[token]
98
-
99
- def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
100
- assert isinstance(s, str)
101
-
102
- subs = []
103
- for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
104
- subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
105
-
106
- if allowed_special is True:
107
- allowed_special = self.tkt_model.special_tokens_set
108
- elif allowed_special is False:
109
- allowed_special = set()
110
-
111
- return sum(
112
- self.tkt_model.encode_batch(
113
- subs, allowed_special=allowed_special, disallowed_special=set()
114
- ),
115
- start=[],
116
- )
117
-
118
- def decode(self, tokens: list[int]) -> str:
119
- return self.tkt_model.decode(tokens)
120
-
121
- def save_pretrained(self, path: str):
122
- path = Path(path)
123
- path.mkdir(parents=True, exist_ok=True)
124
-
125
- with open(path / "tokenizer.tiktoken", "w") as f:
126
- for token, rank in self.tkt_model._mergeable_ranks.items():
127
- f.write(f"{base64.b64encode(token).decode()} {rank}\n")
128
-
129
- with open(path / "special_tokens.json", "w") as f:
130
- json.dump(
131
- self.all_special_tokens_with_ids,
132
- f,
133
- indent=2,
134
- ensure_ascii=False,
135
- )
136
-
137
- @staticmethod
138
- def from_pretrained(path: str):
139
- return FishTokenizer(Path(path) / "tokenizer.tiktoken")
140
-
141
-
142
- if __name__ == "__main__":
143
- tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
144
- tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
145
- tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
146
-
147
- print(
148
- [
149
- tokenizer.decode([i])
150
- for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
151
- ]
152
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import logging
4
+ import re
5
+ from pathlib import Path
6
+
7
+ import tiktoken
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
12
+ FISH_TIKTOKEN_PATTERN = "|".join(
13
+ [
14
+ r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
15
+ r"\p{P}",
16
+ r"[^\r\n\p{L}\p{N}]?\p{L}+",
17
+ r"\p{N}",
18
+ r" ?[^\s\p{L}\p{N}]+[\r\n]*",
19
+ r"\s*[\r\n]+",
20
+ r"\s+(\?!\S)",
21
+ r"\s+",
22
+ ]
23
+ )
24
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
25
+
26
+ BOS_TOKEN = "<|begin_of_text|>"
27
+ EOS_TOKEN = "<|end_of_text|>"
28
+ PAD_TOKEN = "<|pad|>"
29
+ IM_START_TOKEN = "<|im_start|>"
30
+ IM_END_TOKEN = "<|im_end|>"
31
+ PHONEME_START_TOKEN = "<|phoneme_start|>"
32
+ PHONEME_END_TOKEN = "<|phoneme_end|>"
33
+ TOOL_CALL_START_TOKEN = "<|tool_call_start|>"
34
+ TOOL_CALL_END_TOKEN = "<|tool_call_end|>"
35
+
36
+ MODALITY_TEXT_TOKEN = "<|text|>"
37
+ MODALITY_VOICE_TOKEN = "<|voice|>"
38
+ MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
39
+ AUDIO_START_TOKEN = "<|audio_start|>"
40
+ AUDIO_END_TOKEN = "<|audio_end|>"
41
+ AUDIO_EMBED_TOKEN = "<|audio|>"
42
+ MODALITY_TOKENS = {
43
+ "text": MODALITY_TEXT_TOKEN,
44
+ "voice": MODALITY_VOICE_TOKEN,
45
+ "interleave": MODALITY_INTERLEAVE_TOKEN,
46
+ }
47
+
48
+ SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
49
+ SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
50
+
51
+ # Warning: when you add a new special token, you should only add it to the end of the list.
52
+ ALL_SPECIAL_TOKENS = [
53
+ BOS_TOKEN,
54
+ EOS_TOKEN,
55
+ PAD_TOKEN,
56
+ IM_START_TOKEN,
57
+ IM_END_TOKEN,
58
+ PHONEME_START_TOKEN,
59
+ PHONEME_END_TOKEN,
60
+ TOOL_CALL_START_TOKEN,
61
+ TOOL_CALL_END_TOKEN,
62
+ MODALITY_TEXT_TOKEN,
63
+ MODALITY_VOICE_TOKEN,
64
+ MODALITY_INTERLEAVE_TOKEN,
65
+ AUDIO_START_TOKEN,
66
+ AUDIO_END_TOKEN,
67
+ AUDIO_EMBED_TOKEN,
68
+ *SEMANTIC_TOKENS,
69
+ ]
70
+
71
+
72
+ class FishTokenizer:
73
+ def __init__(
74
+ self, model_path: str, special_tokens: list[str] = ALL_SPECIAL_TOKENS
75
+ ) -> None:
76
+ mergeable_ranks = self.load_tiktoken_bpe(model_path)
77
+ special_token_begin = len(mergeable_ranks)
78
+ self.all_special_tokens_with_ids = {
79
+ token: special_token_begin + i for i, token in enumerate(special_tokens)
80
+ }
81
+
82
+ self.semantic_id_to_token_id = {}
83
+ end_idx = 0
84
+ for token in special_tokens:
85
+ if token.startswith("<|semantic:"):
86
+ idx = int(re.match(r"<\|semantic:(\d+)\|>", token).group(1))
87
+ self.semantic_id_to_token_id[idx] = self.all_special_tokens_with_ids[
88
+ token
89
+ ]
90
+
91
+ if idx > end_idx:
92
+ end_idx = idx
93
+
94
+ self.semantic_begin_id = self.semantic_id_to_token_id[0]
95
+ self.semantic_end_id = self.semantic_id_to_token_id[end_idx]
96
+
97
+ self.tkt_model = tiktoken.core.Encoding(
98
+ name=Path(model_path).stem,
99
+ pat_str=FISH_TIKTOKEN_PATTERN,
100
+ mergeable_ranks=mergeable_ranks,
101
+ special_tokens=self.all_special_tokens_with_ids,
102
+ )
103
+
104
+ @property
105
+ def vocab_size(self):
106
+ return len(self.tkt_model._mergeable_ranks)
107
+
108
+ @property
109
+ def num_special_tokens(self):
110
+ return len(self.all_special_tokens_with_ids)
111
+
112
+ @staticmethod
113
+ def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
114
+ data = {}
115
+ for line in open(tiktoken_bpe_file).read().splitlines():
116
+ if not line:
117
+ continue
118
+ token, rank = line.split()
119
+ if token == "=":
120
+ continue
121
+ data[base64.b64decode(token)] = int(rank)
122
+ return data
123
+
124
+ def get_token_id(self, token: str) -> int:
125
+ return self.all_special_tokens_with_ids[token]
126
+
127
+ def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
128
+ assert isinstance(s, str)
129
+
130
+ subs = []
131
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
132
+ subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
133
+
134
+ if allowed_special is True:
135
+ allowed_special = self.tkt_model.special_tokens_set
136
+ elif allowed_special is False:
137
+ allowed_special = set()
138
+
139
+ return sum(
140
+ self.tkt_model.encode_batch(
141
+ subs, allowed_special=allowed_special, disallowed_special=set()
142
+ ),
143
+ start=[],
144
+ )
145
+
146
+ def decode(self, tokens: list[int]) -> str:
147
+ return self.tkt_model.decode(tokens)
148
+
149
+ def save_pretrained(self, path: str):
150
+ path = Path(path)
151
+ path.mkdir(parents=True, exist_ok=True)
152
+
153
+ with open(path / "tokenizer.tiktoken", "w") as f:
154
+ for token, rank in self.tkt_model._mergeable_ranks.items():
155
+ a = base64.b64encode(token).decode()
156
+ if a == "":
157
+ a = "="
158
+ f.write(f"{a} {rank}\n")
159
+
160
+ with open(path / "special_tokens.json", "w") as f:
161
+ json.dump(
162
+ self.all_special_tokens_with_ids,
163
+ f,
164
+ indent=2,
165
+ ensure_ascii=False,
166
+ )
167
+
168
+ @staticmethod
169
+ def from_pretrained(path: str):
170
+ special_tokens_path = Path(path) / "special_tokens.json"
171
+ if special_tokens_path.exists():
172
+ with open(special_tokens_path) as f:
173
+ all_special_tokens_with_ids = json.load(f)
174
+ else:
175
+ all_special_tokens_with_ids = ALL_SPECIAL_TOKENS
176
+
177
+ return FishTokenizer(
178
+ Path(path) / "tokenizer.tiktoken", all_special_tokens_with_ids
179
+ )
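For context, here is a minimal usage sketch of the tokenizer API above. It is not part of this commit, and the checkpoint directory name is only an assumption for illustration.

from fish_speech.tokenizer import FishTokenizer, IM_END_TOKEN, IM_START_TOKEN

# Load a tokenizer previously written by save_pretrained(); the path is hypothetical.
tokenizer = FishTokenizer.from_pretrained("checkpoints/openaudio-s1-mini")

ids = tokenizer.encode(f"{IM_START_TOKEN}user\nHello, world!{IM_END_TOKEN}")
print(ids)
print(tokenizer.decode(ids))

# Semantic codebook tokens occupy a contiguous id range, so VQ codes can be
# mapped to token ids with a simple offset from semantic_begin_id.
print(tokenizer.semantic_begin_id, tokenizer.semantic_end_id)
print(tokenizer.get_token_id("<|semantic:0|>") == tokenizer.semantic_begin_id)  # True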
fish_speech/utils/__init__.py CHANGED
@@ -1,24 +1,24 @@
+from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, set_seed, task_wrapper
+
+__all__ = [
+    "enforce_tags",
+    "extras",
+    "get_metric_value",
+    "RankedLogger",
+    "instantiate_callbacks",
+    "instantiate_loggers",
+    "log_hyperparameters",
+    "print_config_tree",
+    "task_wrapper",
+    "braceexpand",
+    "get_latest_checkpoint",
+    "autocast_exclude_mps",
+    "set_seed",
+]
fish_speech/utils/braceexpand.py CHANGED
@@ -1,217 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+    pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+    """braceexpand(pattern) -> iterator over generated strings
+
+    Returns an iterator over the strings resulting from brace expansion
+    of pattern. This function implements Brace Expansion as described in
+    bash(1), with the following limitations:
+
+    * A pattern containing unbalanced braces will raise an
+      UnbalancedBracesError exception. In bash, unbalanced braces will either
+      be partly expanded or ignored.
+
+    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+      include the characters '[]^_`' between 'Z' and 'a'.
+
+    When escape is True (the default), characters in pattern can be
+    prefixed with a backslash to cause them not to be interpreted as
+    special characters for brace expansion (such as '{', '}', ',').
+    To pass through a literal backslash, double it ('\\\\').
+
+    When escape is False, backslashes in pattern have no special
+    meaning and will be preserved in the output.
+
+    Examples:
+
+    >>> from braceexpand import braceexpand
+
+    # Integer range
+    >>> list(braceexpand('item{1..3}'))
+    ['item1', 'item2', 'item3']
+
+    # Character range
+    >>> list(braceexpand('{a..c}'))
+    ['a', 'b', 'c']
+
+    # Sequence
+    >>> list(braceexpand('index.html{,.backup}'))
+    ['index.html', 'index.html.backup']
+
+    # Nested patterns
+    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+    # Prefixing an integer with zero causes all numbers to be padded to
+    # the same width.
+    >>> list(braceexpand('{07..10}'))
+    ['07', '08', '09', '10']
+
+    # An optional increment can be specified for ranges.
+    >>> list(braceexpand('{a..g..2}'))
+    ['a', 'c', 'e', 'g']
+
+    # Ranges can go in both directions.
+    >>> list(braceexpand('{4..1}'))
+    ['4', '3', '2', '1']
+
+    # Numbers can be negative
+    >>> list(braceexpand('{2..-1}'))
+    ['2', '1', '0', '-1']
+
+    # Unbalanced braces raise an exception.
+    >>> list(braceexpand('{1{2,3}'))
+    Traceback (most recent call last):
+        ...
+    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+    # By default, the backslash is the escape character.
+    >>> list(braceexpand(r'{1\\{2,3}'))
+    ['1{2', '3']
+
+    # Setting 'escape' to False disables backslash escaping.
+    >>> list(braceexpand(r'\\{1,2}', escape=False))
+    ['\\\\1', '\\\\2']
+
+    """
+    return (
+        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+    )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'pattern:', pattern
+    while pos < len(pattern):
+        if escape and pattern[pos] == "\\":
+            pos += 2
+            continue
+        elif pattern[pos] == "{":
+            if bracketdepth == 0 and pos > start:
+                # print 'literal:', pattern[start:pos]
+                items.append([pattern[start:pos]])
+                start = pos
+            bracketdepth += 1
+        elif pattern[pos] == "}":
+            bracketdepth -= 1
+            if bracketdepth == 0:
+                # print 'expression:', pattern[start+1:pos]
+                expr = pattern[start + 1 : pos]
+                item = parse_expression(expr, escape)
+                if item is None:  # not a range or sequence
+                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+                else:
+                    items.append(item)
+                start = pos + 1  # skip the closing brace
+        pos += 1
+
+    if bracketdepth != 0:  # unbalanced braces
+        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+    if start < pos:
+        items.append([pattern[start:]])
+
+    return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+    int_range_match = int_range_re.match(expr)
+    if int_range_match:
+        return make_int_range(*int_range_match.groups())
+
+    char_range_match = char_range_re.match(expr)
+    if char_range_match:
+        return make_char_range(*char_range_match.groups())
+
+    return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+    # sequence -> chain(*sequence_items)
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'sequence:', seq
+    while pos < len(seq):
+        if escape and seq[pos] == "\\":
+            pos += 2
+            continue
+        elif seq[pos] == "{":
+            bracketdepth += 1
+        elif seq[pos] == "}":
+            bracketdepth -= 1
+        elif seq[pos] == "," and bracketdepth == 0:
+            items.append(parse_pattern(seq[start:pos], escape))
+            start = pos + 1  # skip the comma
+        pos += 1
+
+    if bracketdepth != 0:
+        raise UnbalancedBracesError
+    if not items:
+        return None
+
+    # part after the last comma (may be the empty string)
+    items.append(parse_pattern(seq[start:], escape))
+    return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+    if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+        padding = max(len(left), len(right))
+    else:
+        padding = 0
+    step = (int(incr) or 1) if incr else 1
+    start = int(left)
+    end = int(right)
+    r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+    fmt = "%0{}d".format(padding)
+    return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+    step = (int(incr) or 1) if incr else 1
+    start = alphabet.index(left)
+    end = alphabet.index(right)
+    if start < end:
+        return alphabet[start : end + 1 : step]
+    else:
+        end = end or -len(alphabet)
+        return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+    import doctest
+    import sys
+
+    failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+    if failed:
+        sys.exit(1)
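As an aside, a small illustration of how this vendored helper is typically used to enumerate sharded file patterns; the paths below are made up.

from fish_speech.utils.braceexpand import braceexpand

# Zero-padded integer ranges keep a fixed width, which suits shard naming.
shards = list(braceexpand("data/protos/train-{000..019}.tar"))
print(len(shards))   # 20
print(shards[0])     # data/protos/train-000.tar
print(shards[-1])    # data/protos/train-019.tar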
fish_speech/utils/context.py CHANGED
@@ -1,13 +1,13 @@
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+    device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
+    return (
+        nullcontext()
+        if torch.backends.mps.is_available()
+        else torch.autocast(device_type, dtype)
+    )
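A short sketch of how this helper is meant to be used; the device and dtype values are illustrative only.

import torch

from fish_speech.utils import autocast_exclude_mps

device = "cuda" if torch.cuda.is_available() else "cpu"

# On Apple MPS this degrades to a nullcontext, so the matmul runs in full
# precision instead of failing under an unsupported autocast backend.
with autocast_exclude_mps(device_type=device, dtype=torch.bfloat16):
    x = torch.randn(8, 8, device=device)
    y = x @ x.T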
fish_speech/utils/file.py CHANGED
@@ -1,16 +1,139 @@
-import os
-from pathlib import Path
-
-
-def get_latest_checkpoint(path: Path | str) -> Path | None:
-    # Find the latest checkpoint
-    ckpt_dir = Path(path)
-
-    if ckpt_dir.exists() is False:
-        return None
-
-    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
-    if len(ckpts) == 0:
-        return None
-
-    return ckpts[-1]
+import os
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+    ".mp3",
+    ".wav",
+    ".flac",
+    ".ogg",
+    ".m4a",
+    ".wma",
+    ".aac",
+    ".aiff",
+    ".aif",
+    ".aifc",
+}
+
+VIDEO_EXTENSIONS = {
+    ".mp4",
+    ".avi",
+}
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+    # Find the latest checkpoint
+    ckpt_dir = Path(path)
+
+    if ckpt_dir.exists() is False:
+        return None
+
+    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+    if len(ckpts) == 0:
+        return None
+
+    return ckpts[-1]
+
+
+def audio_to_bytes(file_path):
+    if not file_path or not Path(file_path).exists():
+        return None
+    with open(file_path, "rb") as wav_file:
+        wav = wav_file.read()
+    return wav
+
+
+def read_ref_text(ref_text):
+    path = Path(ref_text)
+    if path.exists() and path.is_file():
+        with path.open("r", encoding="utf-8") as file:
+            return file.read()
+    return ref_text
+
+
+def list_files(
+    path: Union[Path, str],
+    extensions: set[str] = set(),
+    recursive: bool = False,
+    sort: bool = True,
+) -> list[Path]:
+    """List files in a directory.
+
+    Args:
+        path (Path): Path to the directory.
+        extensions (set, optional): Extensions to filter. Defaults to an empty set.
+        recursive (bool, optional): Whether to search recursively. Defaults to False.
+        sort (bool, optional): Whether to sort the files. Defaults to True.
+
+    Returns:
+        list: List of files.
+    """
+
+    if isinstance(path, str):
+        path = Path(path)
+
+    if not path.exists():
+        raise FileNotFoundError(f"Directory {path} does not exist.")
+
+    files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
+
+    if sort:
+        files = natsorted(files)
+
+    return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+    """
+    Load a Bert-VITS2 style filelist.
+    """
+
+    files = set()
+    results = []
+    count_duplicated, count_not_found = 0, 0
+
+    LANGUAGE_TO_LANGUAGES = {
+        "zh": ["zh", "en"],
+        "jp": ["jp", "en"],
+        "en": ["en"],
+    }
+
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f.readlines():
+            splits = line.strip().split("|", maxsplit=3)
+            if len(splits) != 4:
+                logger.warning(f"Invalid line: {line}")
+                continue
+
+            filename, speaker, language, text = splits
+            file = Path(filename)
+            language = language.strip().lower()
+
+            if language == "ja":
+                language = "jp"
+
+            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+            languages = LANGUAGE_TO_LANGUAGES[language]
+
+            if file in files:
+                logger.warning(f"Duplicated file: {file}")
+                count_duplicated += 1
+                continue
+
+            if not file.exists():
+                logger.warning(f"File not found: {file}")
+                count_not_found += 1
+                continue
+
+            results.append((file, speaker, languages, text))
+
+    if count_duplicated > 0:
+        logger.warning(f"Total duplicated files: {count_duplicated}")
+
+    if count_not_found > 0:
+        logger.warning(f"Total files not found: {count_not_found}")
+
+    return results
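A brief sketch of the new file helpers in use; the directory and filelist names are hypothetical.

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist

# Gather audio for preprocessing, naturally sorted (1, 2, ..., 10 rather than 1, 10, 2).
wavs = list_files("data/raw", extensions=AUDIO_EXTENSIONS, recursive=True)
print(f"found {len(wavs)} audio files")

# Parse a Bert-VITS2 style filelist: one "path|speaker|language|text" entry per line.
entries = load_filelist("data/filelist.txt")
for file, speaker, languages, text in entries[:3]:
    print(file, speaker, languages, text)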
fish_speech/utils/instantiators.py CHANGED
@@ -1,50 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+    """Instantiates callbacks from config."""
+
+    callbacks: List[Callback] = []
+
+    if not callbacks_cfg:
+        log.warning("No callback configs found! Skipping..")
+        return callbacks
+
+    if not isinstance(callbacks_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callbacks_cfg.items():
+        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+            log.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiates loggers from config."""
+
+    logger: List[Logger] = []
+
+    if not logger_cfg:
+        log.warning("No logger configs found! Skipping...")
+        return logger
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+            log.info(f"Instantiating logger <{lg_conf._target_}>")
+            logger.append(hydra.utils.instantiate(lg_conf))
+
+    return logger
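For reference, a hedged sketch of the config shape these helpers expect; the callback settings below are illustrative, not the project defaults.

from omegaconf import OmegaConf

from fish_speech.utils import instantiate_callbacks

callbacks_cfg = OmegaConf.create(
    {
        "model_checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "save_top_k": 3,
            "monitor": "train/loss",
        }
    }
)

# Every entry carrying a _target_ key is handed to hydra.utils.instantiate().
callbacks = instantiate_callbacks(callbacks_cfg)
print(callbacks)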
fish_speech/utils/logger.py CHANGED
@@ -1,55 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = True,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `True`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(
+        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+    ) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `'rank'` is provided, then the log will only
+        occur on that rank/process.
+
+        :param level: The level to log at. Look at `logging.__init__.py` for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError(
+                    "The `rank_zero_only.rank` needs to be set before use"
+                )
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
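A minimal sketch of the logger in use; note that `rank_zero_only.rank` must already be set (pytorch_lightning normally does this when distributed training starts).

from fish_speech.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)

# Emitted once, from rank 0, with the rank prefixed to the message.
log.info("Preparing dataloaders...")

# With rank_zero_only=False, a specific rank can be targeted per call:
# RankedLogger(__name__, rank_zero_only=False).info("only on rank 1", rank=1)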
fish_speech/utils/logging_utils.py CHANGED
@@ -1,48 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+    """Controls which config parts are saved by lightning loggers.
+
+    Additionally saves:
+    - Number of model parameters
+    """
+
+    hparams = {}
+
+    cfg = object_dict["cfg"]
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)
fish_speech/utils/rich_utils.py CHANGED
@@ -1,100 +1,100 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+    cfg: DictConfig,
+    print_order: Sequence[str] = (
+        "data",
+        "model",
+        "callbacks",
+        "logger",
+        "trainer",
+        "paths",
+        "extras",
+    ),
+    resolve: bool = False,
+    save_to_file: bool = False,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    """  # noqa: E501
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    queue = []
+
+    # add fields from `print_order` to queue
+    for field in print_order:
+        (
+            queue.append(field)
+            if field in cfg
+            else log.warning(
+                f"Field '{field}' not found in config. "
+                + f"Skipping '{field}' config printing..."
+            )
+        )
+
+    # add all the other fields to queue (not specified in `print_order`)
+    for field in cfg:
+        if field not in queue:
+            queue.append(field)
+
+    # generate config tree from queue
+    for field in queue:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_group = cfg[field]
+        if isinstance(config_group, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+        else:
+            branch_content = str(config_group)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    # print config tree
+    rich.print(tree)
+
+    # save config tree to file
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+            rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+    """Prompts user to input tags from command line if no tags are provided in config."""  # noqa: E501
+
+    if not cfg.get("tags"):
+        if "id" in HydraConfig().cfg.hydra.job:
+            raise ValueError("Specify tags before launching a multirun!")
+
+        log.warning("No tags provided in config. Prompting user to input tags...")
+        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t != ""]
+
+        with open_dict(cfg):
+            cfg.tags = tags
+
+        log.info(f"Tags: {cfg.tags}")
+
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+            rich.print(cfg.tags, file=file)
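Finally, a small sketch of print_config_tree on a hand-built config; in training runs Hydra composes cfg instead, and the values below are placeholders.

from omegaconf import OmegaConf

from fish_speech.utils import print_config_tree

cfg = OmegaConf.create(
    {
        "data": {"num_workers": 4},
        "model": {"optimizer": {"lr": 1e-4}},
        "trainer": {"max_steps": 10_000},
    }
)

# Fields listed in print_order but missing from cfg only trigger warnings.
print_config_tree(cfg, resolve=True, save_to_file=False)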