Spaces: Running on L4

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +5 -0
- README.md +1 -1
- app.py +73 -314
- examples/Arabic.wav +0 -0
- examples/English.wav +0 -0
- examples/French.wav +0 -0
- examples/German.wav +0 -0
- examples/Japanese.wav +2 -2
- examples/Korean.wav +2 -2
- examples/Nice English Ref.wav +2 -2
- examples/Spanish.wav +0 -0
- fish_speech/configs/base.yaml +87 -87
- fish_speech/configs/lora/r_8_alpha_16.yaml +4 -4
- fish_speech/configs/modded_dac_vq.yaml +50 -0
- fish_speech/configs/text2semantic_finetune.yaml +86 -83
- fish_speech/content_sequence.py +367 -0
- fish_speech/i18n/README.md +27 -27
- fish_speech/i18n/__init__.py +3 -3
- fish_speech/i18n/core.py +40 -40
- fish_speech/i18n/locale/en_US.json +123 -123
- fish_speech/i18n/locale/es_ES.json +123 -123
- fish_speech/i18n/locale/ja_JP.json +123 -123
- fish_speech/i18n/locale/ko_KR.json +123 -123
- fish_speech/i18n/locale/pt_BR.json +133 -133
- fish_speech/i18n/locale/zh_CN.json +123 -123
- fish_speech/i18n/scan.py +122 -122
- fish_speech/inference_engine/__init__.py +192 -0
- fish_speech/inference_engine/reference_loader.py +130 -0
- fish_speech/inference_engine/utils.py +29 -0
- fish_speech/inference_engine/vq_manager.py +59 -0
- fish_speech/models/dac/__init__.py +0 -0
- fish_speech/models/dac/inference.py +123 -0
- fish_speech/models/dac/modded_dac.py +1024 -0
- fish_speech/models/dac/rvq.py +403 -0
- fish_speech/models/text2semantic/inference.py +716 -0
- fish_speech/models/text2semantic/lit_module.py +202 -202
- fish_speech/models/text2semantic/llama.py +903 -887
- fish_speech/models/text2semantic/lora.py +92 -92
- fish_speech/text/__init__.py +4 -4
- fish_speech/text/clean.py +37 -37
- fish_speech/text/spliter.py +130 -130
- fish_speech/tokenizer.py +179 -152
- fish_speech/utils/__init__.py +24 -24
- fish_speech/utils/braceexpand.py +217 -217
- fish_speech/utils/context.py +13 -13
- fish_speech/utils/file.py +139 -16
- fish_speech/utils/instantiators.py +50 -50
- fish_speech/utils/logger.py +55 -55
- fish_speech/utils/logging_utils.py +48 -48
- fish_speech/utils/rich_utils.py +100 -100
.gitattributes CHANGED
@@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 examples/Japanese.wav filter=lfs diff=lfs merge=lfs -text
 examples/Korean.wav filter=lfs diff=lfs merge=lfs -text
 examples/Nice[[:space:]]English[[:space:]]Ref.wav filter=lfs diff=lfs merge=lfs -text
+examples/Arabic.wav filter=lfs diff=lfs merge=lfs -text
+examples/English.wav filter=lfs diff=lfs merge=lfs -text
+examples/French.wav filter=lfs diff=lfs merge=lfs -text
+examples/German.wav filter=lfs diff=lfs merge=lfs -text
+examples/Spanish.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: OpenAudio S1
 emoji: 🏆
 colorFrom: purple
 colorTo: gray
app.py CHANGED
@@ -1,86 +1,51 @@
 import os
 import queue
 from huggingface_hub import snapshot_download
-import hydra
 import numpy as np
 import wave
 import io
-import pyrootutils
 import gc
+from typing import Callable
 
 # Download if not exists
 os.makedirs("checkpoints", exist_ok=True)
-snapshot_download(repo_id="fishaudio/
+snapshot_download(repo_id="fishaudio/openaudio-s1-mini", local_dir="./checkpoints/openaudio-s1-mini")
 
 print("All checkpoints downloaded")
 
 import html
 import os
-import threading
 from argparse import ArgumentParser
 from pathlib import Path
-from functools import partial
 
 import gradio as gr
-import librosa
 import torch
 import torchaudio
 
 torchaudio.set_audio_backend("soundfile")
 
 from loguru import logger
-from transformers import AutoTokenizer
-
 from fish_speech.i18n import i18n
-from fish_speech.
-from fish_speech.
-from
-from tools.
-from
-
-    GenerateResponse,
-    WrappedGenerateResponse,
-    launch_thread_safe_queue,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-from tools.schema import (
-    GLOBAL_NUM_SAMPLES,
-    ASRPackRequest,
-    ServeASRRequest,
-    ServeASRResponse,
-    ServeASRSegment,
-    ServeAudioPart,
-    ServeForwardMessage,
-    ServeMessage,
-    ServeRequest,
-    ServeResponse,
-    ServeStreamDelta,
-    ServeStreamResponse,
-    ServeTextPart,
-    ServeTimedASRResponse,
-    ServeTTSRequest,
-    ServeVQGANDecodeRequest,
-    ServeVQGANDecodeResponse,
-    ServeVQGANEncodeRequest,
-    ServeVQGANEncodeResponse,
-    ServeVQPart,
-    ServeReferenceAudio
-)
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.models.dac.inference import load_model as load_decoder_model
+from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
+from tools.webui.inference import get_inference_wrapper
+from fish_speech.utils.schema import ServeTTSRequest
+
 # Make einx happy
 os.environ["EINX_FILTER_TRACEBACK"] = "false"
 
 
-HEADER_MD = """#
+HEADER_MD = """# OpenAudio S1
 
-## The demo in this space is
-## 该 Demo 为
+## The demo in this space is OpenAudio S1, Please check [Fish Audio](https://fish.audio) for the best model.
+## 该 Demo 为 OpenAudio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
 
-A text-to-speech model based on
-由 [Fish Audio](https://fish.audio) 研发的基于
+A text-to-speech model based on DAC and Qwen3 developed by [Fish Audio](https://fish.audio).
+由 [Fish Audio](https://fish.audio) 研发的基于 DAC 和 Qwen3 的多语种语音合成.
 
-You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/
-你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/
+You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/openaudio-s1-mini).
+你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/openaudio-s1-mini) 找到模型.
 
 Related code and weights are released under CC BY-NC-SA 4.0 License.
 相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
@@ -88,8 +53,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
 We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
 我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
 
-The model running in this WebUI is
-在此 WebUI 中运行的模型是
+The model running in this WebUI is OpenAudio S1 Mini.
+在此 WebUI 中运行的模型是 OpenAudio S1 Mini.
 """
 
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -106,7 +71,6 @@ except ImportError:
 
     return wrapper
 
-
 def build_html_error_message(error):
     return f"""
     <div style="color: red;
@@ -115,109 +79,6 @@ def build_html_error_message(error):
     </div>
     """
 
-
-@GPU_DECORATOR
-@torch.inference_mode()
-def inference(req: ServeTTSRequest):
-    try:
-        # Parse reference audio aka prompt
-        refs = req.references
-
-        prompt_tokens = [
-            encode_reference(
-                decoder_model=decoder_model,
-                reference_audio=ref.audio,
-                enable_reference_audio=True,
-            )
-            for ref in refs
-        ]
-        prompt_texts = [ref.text for ref in refs]
-
-        if req.seed is not None:
-            set_seed(req.seed)
-            logger.warning(f"set seed: {req.seed}")
-
-        # LLAMA Inference
-        request = dict(
-            device=decoder_model.device,
-            max_new_tokens=req.max_new_tokens,
-            text=(
-                req.text
-                if not req.normalize
-                else ChnNormedText(raw_text=req.text).normalize()
-            ),
-            top_p=req.top_p,
-            repetition_penalty=req.repetition_penalty,
-            temperature=req.temperature,
-            compile=args.compile,
-            iterative_prompt=req.chunk_length > 0,
-            chunk_length=req.chunk_length,
-            max_length=4096,
-            prompt_tokens=prompt_tokens,
-            prompt_text=prompt_texts,
-        )
-
-        response_queue = queue.Queue()
-        llama_queue.put(
-            GenerateRequest(
-                request=request,
-                response_queue=response_queue,
-            )
-        )
-
-        segments = []
-
-        while True:
-            result: WrappedGenerateResponse = response_queue.get()
-            if result.status == "error":
-                yield None, None, build_html_error_message(result.response)
-                break
-
-            result: GenerateResponse = result.response
-            if result.action == "next":
-                break
-
-            with autocast_exclude_mps(
-                device_type=decoder_model.device.type, dtype=args.precision
-            ):
-                fake_audios = decode_vq_tokens(
-                    decoder_model=decoder_model,
-                    codes=result.codes,
-                )
-
-            fake_audios = fake_audios.float().cpu().numpy()
-            segments.append(fake_audios)
-
-        if len(segments) == 0:
-            return (
-                None,
-                None,
-                build_html_error_message(
-                    i18n("No audio generated, please check the input text.")
-                ),
-            )
-
-        # No matter streaming or not, we need to return the final audio
-        audio = np.concatenate(segments, axis=0)
-        yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            gc.collect()
-
-    except Exception as e:
-        er = "CUDA error: device-side assert triggered"
-        if er in str(e):
-            app.close()
-        else:
-            raise Exception(e)
-
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()
 
@@ -230,13 +91,8 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer.close()
     return wav_header_bytes
 
-def normalize_text(user_input, use_normalization):
-    if use_normalization:
-        return ChnNormedText(raw_text=user_input).normalize()
-    else:
-        return user_input
 
-def build_app():
+def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
 
@@ -245,7 +101,7 @@ def build_app():
             None,
             None,
             js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
-            %
+            % theme,
         )
 
         # Inference
@@ -254,20 +110,6 @@ def build_app():
                 text = gr.Textbox(
                     label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                 )
-                refined_text = gr.Textbox(
-                    label=i18n("Realtime Transform Text"),
-                    placeholder=i18n(
-                        "Normalization Result Preview (Currently Only Chinese)"
-                    ),
-                    lines=5,
-                    interactive=False,
-                )
-
-                with gr.Row():
-                    normalize = gr.Checkbox(
-                        label=i18n("Text Normalization"),
-                        value=False,
-                    )
 
                 with gr.Row():
                     with gr.Column():
@@ -275,45 +117,45 @@ def build_app():
                         with gr.Row():
                             chunk_length = gr.Slider(
                                 label=i18n("Iterative Prompt Length, 0 means off"),
-                                minimum=
-                                maximum=
-                                value=
+                                minimum=100,
+                                maximum=400,
+                                value=300,
                                 step=8,
                             )
 
                             max_new_tokens = gr.Slider(
                                 label=i18n(
-                                    "Maximum tokens per batch"
+                                    "Maximum tokens per batch, 0 means no limit"
                                 ),
-                                minimum=
+                                minimum=0,
                                 maximum=2048,
-                                value=
-                                step=
+                                value=0,
+                                step=8,
                             )
 
                         with gr.Row():
                             top_p = gr.Slider(
                                 label="Top-P",
-                                minimum=0.
-                                maximum=0.
-                                value=0.
+                                minimum=0.7,
+                                maximum=0.95,
+                                value=0.8,
                                 step=0.01,
                             )
 
                             repetition_penalty = gr.Slider(
                                 label=i18n("Repetition Penalty"),
                                 minimum=1,
-                                maximum=1.
-                                value=1.
+                                maximum=1.2,
+                                value=1.1,
                                 step=0.01,
                             )
 
                         with gr.Row():
                             temperature = gr.Slider(
                                 label="Temperature",
-                                minimum=0.
-                                maximum=0
-                                value=0.
+                                minimum=0.7,
+                                maximum=1.0,
+                                value=0.8,
                                 step=0.01,
                             )
                             seed = gr.Number(
@@ -326,24 +168,20 @@ def build_app():
                         with gr.Row():
                             gr.Markdown(
                                 i18n(
-                                    "
+                                    "5 to 10 seconds of reference audio, useful for specifying speaker."
                                 )
                             )
-
                         with gr.Row():
-
-
-
-                                label="Select Example Audio",
-                                choices=[""] + example_audio_files,
-                                value=""
+                            reference_id = gr.Textbox(
+                                label=i18n("Reference ID"),
+                                placeholder="Leave empty to use uploaded references",
                             )
 
                         with gr.Row():
                             use_memory_cache = gr.Radio(
                                 label=i18n("Use Memory Cache"),
-                                choices=["
-                                value="
+                                choices=["on", "off"],
+                                value="on",
                             )
 
                         with gr.Row():
@@ -351,7 +189,6 @@ def build_app():
                                 label=i18n("Reference Audio"),
                                 type="filepath",
                             )
-
                         with gr.Row():
                             reference_text = gr.Textbox(
                                 label=i18n("Reference Text"),
@@ -377,101 +214,16 @@ def build_app():
         with gr.Row():
             with gr.Column(scale=3):
                 generate = gr.Button(
-                    value="\
+                    value="\U0001f3a7 " + i18n("Generate"),
+                    variant="primary",
                 )
 
-        text.input(
-            fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]
-        )
-
-        def inference_wrapper(
-            text,
-            normalize,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            seed,
-            use_memory_cache,
-        ):
-            print(
-                "call inference wrapper",
-                text,
-                normalize,
-                reference_audio,
-                reference_text,
-                max_new_tokens,
-                chunk_length,
-                top_p,
-                repetition_penalty,
-                temperature,
-                seed,
-                use_memory_cache
-            )
-
-            references = []
-            if reference_audio:
-                # 将文件路径转换为字节
-                with open(reference_audio, 'rb') as audio_file:
-                    audio_bytes = audio_file.read()
-
-                references = [
-                    ServeReferenceAudio(audio=audio_bytes, text=reference_text)
-                ]
-
-            req = ServeTTSRequest(
-                text=text,
-                normalize=normalize,
-                reference_id=None,
-                references=references,
-                max_new_tokens=max_new_tokens,
-                chunk_length=chunk_length,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                temperature=temperature,
-                seed=int(seed) if seed else None,
-                use_memory_cache=use_memory_cache,
-            )
-
-            for result in inference(req):
-                if result[2]:  # Error message
-                    return None, result[2]
-                elif result[1]:  # Audio data
-                    return result[1], None
-
-            return None, i18n("No audio generated")
-
-        def select_example_audio(audio_file):
-            if audio_file:
-                audio_path = os.path.join("examples", audio_file)
-                lab_file = os.path.splitext(audio_file)[0] + ".lab"
-                lab_path = os.path.join("examples", lab_file)
-
-                if os.path.exists(lab_path):
-                    with open(lab_path, "r", encoding="utf-8") as f:
-                        lab_content = f.read().strip()
-                else:
-                    lab_content = ""
-
-                return audio_path, lab_content
-            return None, ""
-
-        # Connect the dropdown to update reference audio and text
-        example_audio_dropdown.change(
-            fn=select_example_audio,
-            inputs=[example_audio_dropdown],
-            outputs=[reference_audio, reference_text]
-        )
-
         # Submit
         generate.click(
-
+            inference_fct,
             [
-
-
+                text,
+                reference_id,
                 reference_audio,
                 reference_text,
                 max_new_tokens,
@@ -488,26 +240,24 @@ def build_app():
 
     return app
 
-
-
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/
+        default="checkpoints/openaudio-s1-mini",
    )
     parser.add_argument(
        "--decoder-checkpoint-path",
         type=Path,
-        default="checkpoints/
+        default="checkpoints/openaudio-s1-mini/codec.pth",
     )
-    parser.add_argument("--decoder-config-name", type=str, default="
+    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true",default=True)
     parser.add_argument("--max-gradio-length", type=int, default=0)
-    parser.add_argument("--theme", type=str, default="
+    parser.add_argument("--theme", type=str, default="dark")
 
     return parser.parse_args()
 
@@ -533,25 +283,34 @@ if __name__ == "__main__":
 
     logger.info("Decoder model loaded, warming up...")
 
+    # Create the inference engine
+    inference_engine = TTSInferenceEngine(
+        llama_queue=llama_queue,
+        decoder_model=decoder_model,
+        compile=args.compile,
+        precision=args.precision,
+    )
+
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(
-
-
-
-
-
-
-
-
-
-
-
-            format="wav",
-        )
+        inference_engine.inference(
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=1024,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                format="wav",
            )
+        )
     )
 
     logger.info("Warming up done, launching the web UI...")
 
-
-
+    inference_fct = get_inference_wrapper(inference_engine)
+
+    app = build_app(inference_fct, args.theme)
+    app.queue(api_open=True).launch(show_error=True, show_api=True, server_name="0.0.0.0", server_port=18888)
examples/Arabic.wav CHANGED
Binary files a/examples/Arabic.wav and b/examples/Arabic.wav differ

examples/English.wav CHANGED
Binary files a/examples/English.wav and b/examples/English.wav differ

examples/French.wav CHANGED
Binary files a/examples/French.wav and b/examples/French.wav differ

examples/German.wav CHANGED
Binary files a/examples/German.wav and b/examples/German.wav differ

examples/Japanese.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3034a38260884be854cb4a3f6cb648db85ebdeeb8cab74cfae2a578dc7aaedc2
+size 132

examples/Korean.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5767663f0c26f4dc94f45227f385c2be568aac065272466915d65eaa64fdda0f
+size 132

examples/Nice English Ref.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:4b707de0cfc5d2eee59dcc3fea495603fe28d95ca64d8202bcdb31537d588782
+size 132

examples/Spanish.wav CHANGED
Binary files a/examples/Spanish.wav and b/examples/Spanish.wav differ
fish_speech/configs/base.yaml CHANGED
@@ -1,87 +1,87 @@
(Every line of this file is touched, but the old and new sides render identically, so the content is shown once.)
# Base configuration for training a model
paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints

hydra:
  run:
    dir: ${paths.run_dir}

# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    process_group_backend: nccl # This should be override when training on windows

  precision: bf16-mixed

  # disable validation by epoch end
  check_val_every_n_epoch: null
  val_check_interval: 5000
  max_steps: 100_000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true

# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
    save_top_k: 5 # save 5 latest checkpoints
    monitor: step # use step to monitor checkpoints
    mode: max # save the latest checkpoint with the highest global_step
    every_n_epochs: null # don't save checkpoints by epoch end
    every_n_train_steps: 5000 # save checkpoints every 5000 steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2 # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: "" # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: False
  #   id: null # pass correct id to resume experiment!
  #   anonymous: null # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: False # upload lightning ckpts
  #   prefix: "" # a string to put at the beginning of metric keys
  #   # entity: "" # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""

# Loop
train: true
test: false
fish_speech/configs/lora/r_8_alpha_16.yaml CHANGED
@@ -1,4 +1,4 @@
(Every line is touched, but the old and new sides render identically, so the content is shown once.)
_target_: fish_speech.models.text2semantic.lora.LoraConfig
r: 8
lora_alpha: 16
lora_dropout: 0.01
fish_speech/configs/modded_dac_vq.yaml ADDED
@@ -0,0 +1,50 @@
_target_: fish_speech.models.dac.modded_dac.DAC
# Model setup
sample_rate: 44100
encoder_dim: 64
encoder_rates: [2, 4, 8, 8]
decoder_dim: 1536
decoder_rates: [8, 8, 4, 2]
encoder_transformer_layers: [0, 0, 0, 4]
decoder_transformer_layers: [4, 0, 0, 0]
transformer_general_config:
  _target_: fish_speech.models.dac.modded_dac.ModelArgs
  _partial_: true
  block_size: 16384
  n_local_heads: -1
  head_dim: 64
  rope_base: 10000
  norm_eps: 1e-5
  dropout_rate: 0.1
  attn_dropout_rate: 0.1
  channels_first: true
# Quantization
quantizer:
  _target_: fish_speech.models.dac.rvq.DownsampleResidualVectorQuantize
  input_dim: 1024
  n_codebooks: 9
  codebook_size: 1024
  codebook_dim: 8
  quantizer_dropout: 0.5
  downsample_factor: [2, 2]
  post_module: &transformer_module
    _target_: fish_speech.models.dac.modded_dac.WindowLimitedTransformer
    causal: true
    window_size: 128 # empirically this does not seem to matter
    input_dim: 1024
    config: &transformer_config
      _target_: fish_speech.models.dac.modded_dac.ModelArgs
      block_size: 4096
      n_layer: 8
      n_head: 16
      dim: 1024
      intermediate_size: 3072
      n_local_heads: -1
      head_dim: 64
      rope_base: 10000
      norm_eps: 1e-5
      dropout_rate: 0.1
      attn_dropout_rate: 0.1
      channels_first: true
  pre_module: *transformer_module
  semantic_codebook_size: 4096
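This config is consumed through the Hydra-style `_target_` convention (app.py passes --decoder-config-name modded_dac_vq to load_model). A rough sketch of how such a config can be instantiated, assuming only standard omegaconf/hydra APIs; the project's own loader in fish_speech/models/dac/inference.py also restores the codec.pth weights, which is not reproduced here:

# Sketch only: build the codec module graph from the `_target_` entries above.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
codec = instantiate(cfg)  # -> fish_speech.models.dac.modded_dac.DAC
codec.eval()
print(sum(p.numel() for p in codec.parameters()), "parameters")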
fish_speech/configs/text2semantic_finetune.yaml CHANGED
@@ -1,83 +1,86 @@
(Only the first removed lines were rendered in this view; the new file is shown in full.)
-defaults:
-  - base
-  - _self_
-
-project: text2semantic_finetune_dual_ar
-max_length: 4096
-pretrained_ckpt_path: checkpoints/
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 1
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: "norm"
-  max_steps:
-  precision: bf16-true
-  limit_val_batches: 10
-  val_check_interval: 100
-
-#
+defaults:
+  - base
+  - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 4096
+pretrained_ckpt_path: checkpoints/openaudio-s1-mini
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: "norm"
+  max_steps: 10000
+  precision: bf16-true
+  limit_val_batches: 10
+  val_check_interval: 100
+  # strategy:
+  #   find_unused_parameters: true
+  #   static_graph: true
+
+# Dataset Configuration
+tokenizer:
+  _target_: fish_speech.tokenizer.FishTokenizer
+  model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
+  proto_files:
+    - data/protos
+  tokenizer: ${tokenizer}
+  causal: true
+  max_length: ${max_length}
+  use_speaker: false
+  interactive_prob: 0.7
+
+val_dataset:
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
+  proto_files:
+    - data/protos
+  tokenizer: ${tokenizer}
+  causal: true
+  max_length: ${max_length}
+  use_speaker: false
+  interactive_prob: 0.7
+
+data:
+  _target_: fish_speech.datasets.semantic.SemanticDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 4
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+  model:
+    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+    path: ${pretrained_ckpt_path}
+    load_weights: true
+    max_length: ${max_length}
+    lora_config: null
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 10
+
+# Callbacks
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: ${trainer.val_check_interval}
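The finetune config leans on OmegaConf interpolation (${pretrained_ckpt_path}, ${tokenizer}, ${trainer.val_check_interval}) and ships with lora_config: null; the lora/r_8_alpha_16.yaml file above is the obvious candidate to merge in for LoRA finetuning. A small sketch of how those pieces resolve, using only OmegaConf; the project's actual training entry point is not shown here, and merging the `defaults: - base` layer is left to Hydra:

from omegaconf import OmegaConf

# Load the finetune config and the LoRA config shown earlier in this diff.
cfg = OmegaConf.load("fish_speech/configs/text2semantic_finetune.yaml")
lora = OmegaConf.load("fish_speech/configs/lora/r_8_alpha_16.yaml")

# Interpolations such as ${pretrained_ckpt_path} resolve against the same file:
print(OmegaConf.to_container(cfg.tokenizer, resolve=True))
# -> {'_target_': 'fish_speech.tokenizer.FishTokenizer',
#     'model_path': 'checkpoints/openaudio-s1-mini/tokenizer.tiktoken'}

# Swap the null lora_config for the LoRA settings (illustrative override).
cfg.model.model.lora_config = lora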
fish_speech/content_sequence.py ADDED
@@ -0,0 +1,367 @@
from dataclasses import dataclass, field
from typing import List, Literal, Union

import numpy as np
import torch

from fish_speech.tokenizer import (
    IM_END_TOKEN,
    MODALITY_TOKENS,
    FishTokenizer,
)


def restore_ndarray(obj, to_tensor: bool = False):
    if isinstance(obj, dict) and "__ndarray__" in obj:
        obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])

    if to_tensor and isinstance(obj, np.ndarray):
        obj = torch.from_numpy(obj.copy())

    return obj


@dataclass
class BasePart:
    type: Literal["text", "vq", "audio"] | None = None
    cal_loss: bool = False


@dataclass(kw_only=True)
class VQPart(BasePart):
    type = "vq"
    codes: torch.Tensor

    def __post_init__(self: "VQPart"):
        self.type = "vq"
        self.codes = restore_ndarray(self.codes, to_tensor=True)


@dataclass(kw_only=True)
class TextPart(BasePart):
    type = "text"
    text: str | None = None
    tokens: list[int] | None = None

    def __post_init__(self: "TextPart"):
        self.type = "text"
        if self.text is None and self.tokens is None:
            raise ValueError("Either text or tokens must be provided")


@dataclass(kw_only=True)
class AudioPart(BasePart):
    type = "audio"
    features: torch.Tensor

    def __post_init__(self: "AudioPart"):
        self.type = "audio"
        self.features = restore_ndarray(self.features, to_tensor=True)


@dataclass(kw_only=True)
class EncodedMessage:
    tokens: torch.Tensor
    labels: torch.Tensor
    vq_mask_tokens: torch.Tensor | None = None
    vq_mask_labels: torch.Tensor | None = None
    vq_parts: list[torch.Tensor]
    vq_require_losses: torch.Tensor | None = None
    audio_parts: list[torch.Tensor]
    audio_masks: torch.Tensor | None = None
    metadata: dict | None = None


@dataclass
class ContentSequence:
    """
    Flexible sequence of content parts that supports interleaved multimodal format.
    Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
    """

    parts: list[BasePart] = field(default_factory=list)
    modality: Literal["text", "voice", "interleave"] | None = None
    metadata: dict | None = None

    def __init__(
        self: "ContentSequence",
        parts: list[BasePart | dict] | None = None,
        modality: Literal["text", "voice", "interleave"] | None = None,
        metadata: dict | None = None,
    ):
        self.modality = modality
        self.metadata = metadata or {}

        fixed_parts = []
        for part in parts or []:
            if isinstance(part, dict):
                if part["type"] == "vq":
                    part = VQPart(**part)
                elif part["type"] == "audio":
                    part = AudioPart(**part)
                elif part["type"] == "text":
                    part = TextPart(**part)
                else:
                    raise ValueError(f"Unsupported part type: {part['type']}")
            fixed_parts.append(part)

        self.parts = fixed_parts

        # If modality is specified, add it at the beginning if it's not already there
        if self.modality and not (
            len(self.parts) > 0
            and isinstance(self.parts[0], dict) is False
            and isinstance(self.parts[0], TextPart)
            and self.parts[0].text is not None
            and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
        ):
            modality_token = MODALITY_TOKENS[self.modality]
            self.parts.insert(0, TextPart(text=modality_token))

    def append(
        self: "ContentSequence",
        part_or_parts: Union[BasePart, List[BasePart]],
        add_end: bool = False,
        speaker: Union[str, int] | None = None,
    ):
        """
        Append a part or list of parts to the sequence.

        Args:
            part_or_parts: A single part or list of parts to add
            add_end: Whether to add the IM_END_TOKEN after these parts
            speaker: Optional speaker identifier (name or ID) to add before the parts
        """
        # Convert single part to list
        parts_to_add = (
            [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
        )

        # Add speaker token if specified
        if speaker is not None:
            speaker_token = f"<|speaker:{speaker}|>"
            self.parts.append(TextPart(text=speaker_token))

        # Add all the parts
        self.parts.extend(parts_to_add)

        # Add end token if requested
        if add_end:
            self.parts.append(
                TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
            )

    def encode(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] = [],
    ) -> EncodedMessage:
        """
        Encode the sequence parts into tokens for the model.

        Args:
            tokenizer: The tokenizer to use
            add_shift: Whether to shift tokens for next-token prediction
            ignore_loss_tokens: List of token strings to ignore when calculating loss

        Returns:
            EncodedMessage with tensors ready for the model
        """
        all_tokens = []
        all_labels = []

        # Multi-modal elements
        vq_parts = []
        vq_masks = []
        vq_require_losses = []

        audio_parts = []
        audio_masks = []

        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]

        for part in self.parts:
            if isinstance(part, TextPart):
                if part.tokens is None:
                    assert part.text is not None
                    tokens = tokenizer.encode(part.text)
                else:
                    tokens = part.tokens

                tokens = torch.tensor(tokens, dtype=torch.int)
            elif isinstance(part, VQPart):
                curr_codes = part.codes.clone().to(torch.int)
                tokens = torch.tensor(
                    [
                        tokenizer.semantic_id_to_token_id[int(i.item())]
                        for i in curr_codes[0].int()
                    ],
                    dtype=torch.int,
                )
                vq_parts.append(curr_codes)
                vq_require_losses.append(part.cal_loss)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            all_tokens.append(tokens)

            # Set masks for different part types
            if isinstance(part, VQPart):
                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
                audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
            elif isinstance(part, AudioPart):
                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
                audio_mask = torch.ones_like(tokens, dtype=torch.bool)
                audio_mask[0] = False  # Skip start token
                audio_mask[-1] = False  # Skip end token
                audio_masks.append(audio_mask)
            else:
                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
                audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))

            # Set labels based on whether we want to calculate loss for this part
            if part.cal_loss and not isinstance(part, AudioPart):
                all_labels.append(tokens.clone())
            else:
                all_labels.append(torch.full_like(tokens, -100))

        # Concatenate all tensors
        tokens = torch.cat(all_tokens, dim=0)
        labels = torch.cat(all_labels, dim=0)
        vq_masks = torch.cat(vq_masks, dim=0)
        audio_masks = torch.cat(audio_masks, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        # Apply shift if needed for next-token prediction
        vq_mask_tokens = vq_masks
        vq_mask_labels = vq_masks

        if add_shift:
            tokens = tokens[:-1]
            labels = labels[1:]
            vq_masks = vq_masks[:-1]
            vq_mask_tokens = vq_mask_tokens[:-1]
            vq_mask_labels = vq_mask_labels[1:]
            audio_masks = audio_masks[:-1]

        # Ignore specified tokens
        for i in ignore_loss_token_ids:
            assert i != -100 and i is not None
            labels[labels == i] = -100

        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            vq_mask_tokens=vq_mask_tokens,
            vq_mask_labels=vq_mask_labels,
            vq_require_losses=vq_require_losses,
            audio_parts=audio_parts,
            audio_masks=audio_masks,
            metadata=self.metadata,
        )

    def encode_for_inference(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
            encoded.audio_parts is None or len(encoded.audio_parts) == 0
        ):
            return values

        if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
            vq_parts = encoded.vq_parts
            vq_parts = torch.cat(vq_parts, dim=1)
            values[0, encoded.vq_mask_tokens] = (
                vq_parts[0] + tokenizer.semantic_begin_id
            )
            values[1:, encoded.vq_mask_tokens] = vq_parts

        return values

    def visualize(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        ignore_loss_tokens: list[str] = [],
        merge_semantic_tokens: bool = False,
    ):
        """
        Visualize the encoded sequence with color-coded tokens.
        Blue/cyan tokens contribute to loss, green tokens do not.
        """
        encoded = self.encode(
            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
        )

        # Colors for alternating tokens
        colors = {
            "blue": "\033[94m",  # Light blue
            "cyan": "\033[96m",  # Cyan
            "green": "\033[92m",  # Light green
            "dark_green": "\033[32m",  # Dark green
        }
        blue_idx = 0
        green_idx = 0

        def print_in_blue(x):
            nonlocal blue_idx
            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
            print(f"{color}{x}\033[0m", end="")
            blue_idx += 1

        def print_in_green(x):
            nonlocal green_idx
            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
            print(f"{color}{x}\033[0m", end="")
            green_idx += 1

        def print_semantic_token(x, count):
            val = f"[<|semantic|>x{count}]"
            if x == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        count_semantic_tokens = 0
        semantic_label = None

        for tok, lab in zip(encoded.tokens, encoded.labels):
            token_id = int(tok.item())

            if merge_semantic_tokens:
                if (
                    tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
                    and (semantic_label is None or semantic_label == lab)
                ):
                    count_semantic_tokens += 1
                    semantic_label = lab
                    continue
                elif count_semantic_tokens > 0:
                    print_semantic_token(semantic_label, count_semantic_tokens)
                    count_semantic_tokens = 0
                    semantic_label = None

            val = tokenizer.decode([int(tok.item())])

            if lab == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        if merge_semantic_tokens and count_semantic_tokens > 0:
            print_semantic_token(semantic_label, count_semantic_tokens)

        print()
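For orientation, a small usage sketch of the new ContentSequence API added above. The class, part types, and method names come straight from this file; the tokenizer path, its model_path keyword, the codebook count, and the code shapes are illustrative assumptions:

# Hypothetical usage of fish_speech/content_sequence.py.
import torch

from fish_speech.content_sequence import ContentSequence, TextPart, VQPart
from fish_speech.tokenizer import FishTokenizer

tokenizer = FishTokenizer(model_path="checkpoints/openaudio-s1-mini/tokenizer.tiktoken")  # assumed path

seq = ContentSequence(modality="voice")
# Reference prompt: its text and codes do not contribute to the loss (cal_loss defaults to False).
seq.append(TextPart(text="Reference transcript."), speaker=1)
seq.append(VQPart(codes=torch.zeros(9, 120, dtype=torch.int)), add_end=True)  # 9 codebooks x 120 frames, assumed
# Target text whose semantic tokens the model should generate next.
seq.append(TextPart(text="Hello world."), speaker=1)

# Builds a (num_codebooks + 1, seq_len) prompt matrix for the text2semantic model.
prompt = seq.encode_for_inference(tokenizer, num_codebooks=9)
print(prompt.shape)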
fish_speech/i18n/README.md CHANGED
@@ -1,27 +1,27 @@
(Every line is touched, but the old and new sides render identically, so the content is shown once.)
## i18n Folder Attribution

The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:

### fish_speech/i18n/core.py

**Related code from RVC:**
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)

**Initial commit:**
add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)

**Initial author:**
[@L4Ph](https://github.com/L4Ph)

### fish_speech/i18n/scan.py

**Related code from RVC:**
[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)

**Initial commit:**
File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)

**Initial author:**
[@towzeur](https://github.com/towzeur)

We appreciate the contributions of the RVC project and its authors.
fish_speech/i18n/__init__.py
CHANGED
@@ -1,3 +1,3 @@
(removed and re-added with identical content; shown once below)

from .core import i18n

__all__ = ["i18n"]
fish_speech/i18n/core.py
CHANGED
@@ -1,40 +1,40 @@
(removed and re-added with identical content; shown once below)

import json
import locale
from pathlib import Path

I18N_FILE_PATH = Path(__file__).parent / "locale"
DEFAULT_LANGUAGE = "en_US"


def load_language_list(language):
    with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
        language_list = json.load(f)

    return language_list


class I18nAuto:
    def __init__(self):
        i18n_file = Path(".locale")

        if i18n_file.exists():
            with open(i18n_file, "r", encoding="utf-8") as f:
                language = f.read().strip()
        else:
            # getlocale can't identify the system's language ((None, None))
            language = locale.getdefaultlocale()[0]

        if (I18N_FILE_PATH / f"{language}.json").exists() is False:
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    def __call__(self, key):
        return self.language_map.get(key, key)

    def __repr__(self):
        return "Use Language: " + self.language


i18n = I18nAuto()
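A minimal usage sketch of the module above (assuming the package is importable as `fish_speech.i18n`, as the `__init__.py` above suggests): `I18nAuto` resolves the UI language once at import time (from a `.locale` file if present, otherwise from the system locale, falling back to `en_US`), and the `i18n` callable then maps English UI strings to the loaded locale, returning the key unchanged when no translation exists.

from fish_speech.i18n import i18n

print(i18n)                        # repr, e.g. "Use Language: en_US"
print(i18n("Generate"))            # localized UI string, e.g. "生成" under ja_JP
print(i18n("some unmapped text"))  # unknown keys fall back to the key itself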
fish_speech/i18n/locale/en_US.json
CHANGED
@@ -1,123 +1,123 @@
(removed and re-added with identical content; shown once below)

{
  "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
  "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
  "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
  "Accumulate Gradient Batches": "Accumulate Gradient Batches",
  "Add to Processing Area": "Add to Processing Area",
  "Added path successfully!": "Added path successfully!",
  "Advanced Config": "Advanced Config",
  "Base LLAMA Model": "Base LLAMA Model",
  "Batch Inference": "Batch Inference",
  "Batch Size": "Batch Size",
  "Changing with the Model Path": "Changing with the Model Path",
  "Chinese": "Chinese",
  "Compile Model": "Compile Model",
  "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
  "Copy": "Copy",
  "Data Preprocessing": "Data Preprocessing",
  "Data Preprocessing Path": "Data Preprocessing Path",
  "Data Source": "Data Source",
  "Decoder Model Config": "Decoder Model Config",
  "Decoder Model Path": "Decoder Model Path",
  "Disabled": "Disabled",
  "Enable Reference Audio": "Enable Reference Audio",
  "English": "English",
  "Error Message": "Error Message",
  "File Preprocessing": "File Preprocessing",
  "Generate": "Generate",
  "Generated Audio": "Generated Audio",
  "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
  "Infer interface is closed": "Infer interface is closed",
  "Inference Configuration": "Inference Configuration",
  "Inference Server Configuration": "Inference Server Configuration",
  "Inference Server Error": "Inference Server Error",
  "Inferring interface is launched at {}": "Inferring interface is launched at {}",
  "Initial Learning Rate": "Initial Learning Rate",
  "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
  "Input Text": "Input Text",
  "Invalid path: {}": "Invalid path: {}",
  "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
  "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
  "Japanese": "Japanese",
  "LLAMA Configuration": "LLAMA Configuration",
  "LLAMA Model Config": "LLAMA Model Config",
  "LLAMA Model Path": "LLAMA Model Path",
  "Labeling Device": "Labeling Device",
  "LoRA Model to be merged": "LoRA Model to be merged",
  "Maximum Audio Duration": "Maximum Audio Duration",
  "Maximum Length per Sample": "Maximum Length per Sample",
  "Maximum Training Steps": "Maximum Training Steps",
  "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
  "Merge": "Merge",
  "Merge LoRA": "Merge LoRA",
  "Merge successfully": "Merge successfully",
  "Minimum Audio Duration": "Minimum Audio Duration",
  "Model Output Path": "Model Output Path",
  "Model Size": "Model Size",
  "Move": "Move",
  "Move files successfully": "Move files successfully",
  "No audio generated, please check the input text.": "No audio generated, please check the input text.",
  "No selected options": "No selected options",
  "Number of Workers": "Number of Workers",
  "Open Inference Server": "Open Inference Server",
  "Open Labeler WebUI": "Open Labeler WebUI",
  "Open Tensorboard": "Open Tensorboard",
  "Opened labeler in browser": "Opened labeler in browser",
  "Optional Label Language": "Optional Label Language",
  "Optional online ver": "Optional online ver",
  "Output Path": "Output Path",
  "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
  "Precision": "Precision",
  "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
  "Put your text here.": "Put your text here.",
  "Reference Audio": "Reference Audio",
  "Reference Text": "Reference Text",
  "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
  "Remove Selected Data": "Remove Selected Data",
  "Removed path successfully!": "Removed path successfully!",
  "Repetition Penalty": "Repetition Penalty",
  "Save model every n steps": "Save model every n steps",
  "Select LLAMA ckpt": "Select LLAMA ckpt",
  "Select VITS ckpt": "Select VITS ckpt",
  "Select VQGAN ckpt": "Select VQGAN ckpt",
  "Select source file processing method": "Select source file processing method",
  "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
  "Selected: {}": "Selected: {}",
  "Speaker": "Speaker",
  "Speaker is identified by the folder name": "Speaker is identified by the folder name",
  "Start Training": "Start Training",
  "Streaming Audio": "Streaming Audio",
  "Streaming Generate": "Streaming Generate",
  "Tensorboard Host": "Tensorboard Host",
  "Tensorboard Log Path": "Tensorboard Log Path",
  "Tensorboard Port": "Tensorboard Port",
  "Tensorboard interface is closed": "Tensorboard interface is closed",
  "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
  "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
  "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
  "Training Configuration": "Training Configuration",
  "Training Error": "Training Error",
  "Training stopped": "Training stopped",
  "Type name of the speaker": "Type name of the speaker",
  "Type the path or select from the dropdown": "Type the path or select from the dropdown",
  "Use LoRA": "Use LoRA",
  "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
  "Use filelist": "Use filelist",
  "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
  "VITS Configuration": "VITS Configuration",
  "VQGAN Configuration": "VQGAN Configuration",
  "Validation Batch Size": "Validation Batch Size",
  "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
  "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
  "WebUI Host": "WebUI Host",
  "WebUI Port": "WebUI Port",
  "Whisper Model": "Whisper Model",
  "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
  "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
  "latest": "latest",
  "new": "new",
  "Realtime Transform Text": "Realtime Transform Text",
  "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
  "Text Normalization": "Text Normalization",
  "Select Example Audio": "Select Example Audio"
}
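Several keys in these locale files are format-string templates (for example "Selected: {}" and "Text is too long, please keep it under {} characters."): the translated template is looked up first and the placeholder is filled afterwards. A small illustrative sketch with hypothetical values, assuming the `i18n` callable shown earlier:

from fish_speech.i18n import i18n

limit = 1000
print(i18n("Text is too long, please keep it under {} characters.").format(limit))
print(i18n("Selected: {}").format("examples/Japanese.wav"))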
fish_speech/i18n/locale/es_ES.json
CHANGED
@@ -1,123 +1,123 @@
(removed and re-added with identical content; shown once below)

{
  "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
  "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
  "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
  "Accumulate Gradient Batches": "Acumular lotes de gradientes",
  "Add to Processing Area": "Agregar al Área de Procesamiento",
  "Added path successfully!": "¡Ruta agregada exitosamente!",
  "Advanced Config": "Configuración Avanzada",
  "Base LLAMA Model": "Modelo Base LLAMA",
  "Batch Inference": "Inferencia por Lote",
  "Batch Size": "Tamaño del Lote",
  "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
  "Chinese": "Chino",
  "Compile Model": "Compilar Modelo",
  "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
  "Copy": "Copiar",
  "Data Preprocessing": "Preprocesamiento de Datos",
  "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
  "Data Source": "Fuente de Datos",
  "Decoder Model Config": "Configuración del modelo decodificador",
  "Decoder Model Path": "Ruta del modelo decodificador",
  "Disabled": "Desactivado",
  "Enable Reference Audio": "Habilitar Audio de Referencia",
  "English": "Inglés",
  "Error Message": "Mensaje de Error",
  "File Preprocessing": "Preprocesamiento de Archivos",
  "Generate": "Generar",
  "Generated Audio": "Audio Generado",
  "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
  "Infer interface is closed": "La interfaz de inferencia está cerrada",
  "Inference Configuration": "Configuración de Inferencia",
  "Inference Server Configuration": "Configuración del Servidor de Inferencia",
  "Inference Server Error": "Error del Servidor de Inferencia",
  "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
  "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
  "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
  "Input Text": "Texto de Entrada",
  "Invalid path: {}": "Ruta inválida: {}",
  "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
  "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
  "Japanese": "Japonés",
  "LLAMA Configuration": "Configuración de LLAMA",
  "LLAMA Model Config": "Configuración del Modelo LLAMA",
  "LLAMA Model Path": "Ruta del Modelo LLAMA",
  "Labeling Device": "Dispositivo de Etiquetado",
  "LoRA Model to be merged": "Modelo LoRA a fusionar",
  "Maximum Audio Duration": "Duración máxima de audio",
  "Maximum Length per Sample": "Longitud Máxima por Muestra",
  "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
  "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
  "Merge": "Fusionar",
  "Merge LoRA": "Fusionar LoRA",
  "Merge successfully": "Fusionado exitosamente",
  "Minimum Audio Duration": "Duración mínima de audio",
  "Model Output Path": "Ruta de Salida del Modelo",
  "Model Size": "Tamaño del Modelo",
  "Move": "Mover",
  "Move files successfully": "Archivos movidos exitosamente",
  "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
  "No selected options": "No hay opciones seleccionadas",
  "Number of Workers": "Número de Trabajadores",
  "Open Inference Server": "Abrir Servidor de Inferencia",
  "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
  "Open Tensorboard": "Abrir Tensorboard",
  "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
  "Optional Label Language": "Idioma de Etiquetado Opcional",
  "Optional online ver": "Ver en línea opcional",
  "Output Path": "Ruta de Salida",
  "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
  "Precision": "Precisión",
  "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
  "Put your text here.": "Ponga su texto aquí.",
  "Reference Audio": "Audio de Referencia",
  "Reference Text": "Texto de Referencia",
  "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
  "Remove Selected Data": "Eliminar Datos Seleccionados",
  "Removed path successfully!": "¡Ruta eliminada exitosamente!",
  "Repetition Penalty": "Penalización por Repetición",
  "Save model every n steps": "Guardar modelo cada n pasos",
  "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
  "Select VITS ckpt": "Seleccionar punto de control VITS",
  "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
  "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
  "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
  "Selected: {}": "Seleccionado: {}",
  "Speaker": "Hablante",
  "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
  "Start Training": "Iniciar Entrenamiento",
  "Streaming Audio": "transmisión de audio",
  "Streaming Generate": "síntesis en flujo",
  "Tensorboard Host": "Host de Tensorboard",
  "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
  "Tensorboard Port": "Puerto de Tensorboard",
  "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
  "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
  "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
  "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
  "Training Configuration": "Configuración de Entrenamiento",
  "Training Error": "Error de Entrenamiento",
  "Training stopped": "Entrenamiento detenido",
  "Type name of the speaker": "Escriba el nombre del hablante",
  "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
  "Use LoRA": "Usar LoRA",
  "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
  "Use filelist": "Usar lista de archivos",
  "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
  "VITS Configuration": "Configuración de VITS",
  "VQGAN Configuration": "Configuración de VQGAN",
  "Validation Batch Size": "Tamaño del Lote de Validación",
  "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
  "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
  "WebUI Host": "Host de WebUI",
  "WebUI Port": "Puerto de WebUI",
  "Whisper Model": "Modelo Whisper",
  "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
  "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
  "latest": "más reciente",
  "new": "nuevo",
  "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
  "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
  "Text Normalization": "Normalización de Texto",
  "Select Example Audio": "Selecionar áudio de exemplo"
}
fish_speech/i18n/locale/ja_JP.json
CHANGED
@@ -1,123 +1,123 @@
(removed and re-added; the content is identical except for entries 4 and 72, whose removed and re-added values are shown with -/+ markers below)

{
  "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
  "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
-  "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlama
+  "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
  "Accumulate Gradient Batches": "勾配バッチの累積",
  "Add to Processing Area": "処理エリアに追加",
  "Added path successfully!": "パスの追加に成功しました!",
  "Advanced Config": "詳細設定",
  "Base LLAMA Model": "基本LLAMAモデル",
  "Batch Inference": "バッチ推論",
  "Batch Size": "バッチサイズ",
  "Changing with the Model Path": "モデルのパスに伴って変化する",
  "Chinese": "中国語",
  "Compile Model": "モデルのコンパイル",
  "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
  "Copy": "コピー",
  "Data Preprocessing": "データ前処理",
  "Data Preprocessing Path": "データ前処理パス",
  "Data Source": "データソース",
  "Decoder Model Config": "デコーダーモデルの構成",
  "Decoder Model Path": "デコーダーモデルのパス",
  "Disabled": "無効",
  "Enable Reference Audio": "リファレンスオーディオを有効にする",
  "English": "英語",
  "Error Message": "エラーメッセージ",
  "File Preprocessing": "文書前处理",
  "Generate": "生成",
  "Generated Audio": "生成されたオーディオ",
  "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
  "Infer interface is closed": "推論インターフェースが閉じられています",
  "Inference Configuration": "推論設定",
  "Inference Server Configuration": "推論サーバー設定",
  "Inference Server Error": "推論サーバーエラー",
  "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
  "Initial Learning Rate": "初期学習率",
  "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
  "Input Text": "入力テキスト",
  "Invalid path: {}": "無効なパス: {}",
  "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
  "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
  "Japanese": "日本語",
  "LLAMA Configuration": "LLAMA設定",
  "LLAMA Model Config": "LLAMAモデル設定",
  "LLAMA Model Path": "LLAMAモデルパス",
  "Labeling Device": "ラベリングデバイス",
  "LoRA Model to be merged": "マージするLoRAモデル",
  "Maximum Audio Duration": "最大オーディオの長さ",
  "Maximum Length per Sample": "サンプルあたりの最大長",
  "Maximum Training Steps": "最大トレーニングステップ数",
  "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
  "Merge": "マージ",
  "Merge LoRA": "LoRAのマージ",
  "Merge successfully": "マージに成功しました",
  "Minimum Audio Duration": "最小オーディオの長さ",
  "Model Output Path": "モデル出力パス",
  "Model Size": "モデルサイズ",
  "Move": "移動",
  "Move files successfully": "ファイルの移動に成功しました",
  "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
  "No selected options": "選択されたオプションはありません",
  "Number of Workers": "ワーカー数",
  "Open Inference Server": "推論サーバーを開く",
  "Open Labeler WebUI": "ラベラーWebUIを開く",
  "Open Tensorboard": "Tensorboardを開く",
  "Opened labeler in browser": "ブラウザでラベラーを開きました",
  "Optional Label Language": "オプションのラベル言語",
  "Optional online ver": "オプションのオンラインバージョン",
  "Output Path": "出力パス",
  "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
  "Precision": "精度",
  "Probability of applying Speaker Condition": "話者条件を適用する確率",
-  "Put your text here.": "
+  "Put your text here.": "ここにテキストを入力してください。",
  "Reference Audio": "リファレンスオーディオ",
  "Reference Text": "リファレンステキスト",
  "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
  "Remove Selected Data": "選択したデータを削除",
  "Removed path successfully!": "パスの削除に成功しました!",
  "Repetition Penalty": "反復ペナルティ",
  "Save model every n steps": "nステップごとにモデルを保存",
  "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
  "Select VITS ckpt": "VITS チェックポイントを選択",
  "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
  "Select source file processing method": "ソースファイルの処理方法を選択",
  "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
  "Selected: {}": "選択済み: {}",
  "Speaker": "話者",
  "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
  "Start Training": "トレーニング開始",
  "Streaming Audio": "ストリーミングオーディオ",
  "Streaming Generate": "ストリーミング合成",
  "Tensorboard Host": "Tensorboardホスト",
  "Tensorboard Log Path": "Tensorboardログパス",
  "Tensorboard Port": "Tensorboardポート",
  "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
  "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
  "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
  "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
  "Training Configuration": "トレーニング設定",
  "Training Error": "トレーニングエラー",
  "Training stopped": "トレーニングが停止しました",
  "Type name of the speaker": "話者の名前を入力",
  "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
  "Use LoRA": "LoRAを使用",
  "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
  "Use filelist": "ファイルリストを使用",
  "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
  "VITS Configuration": "VITS の構成",
  "VQGAN Configuration": "VQGAN の構成",
  "Validation Batch Size": "検証バッチサイズ",
  "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
  "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
  "WebUI Host": "WebUIホスト",
  "WebUI Port": "WebUIポート",
  "Whisper Model": "Whisperモデル",
  "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
  "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
  "latest": "最新",
  "new": "新規",
  "Realtime Transform Text": "リアルタイム変換テキスト",
  "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
  "Text Normalization": "テキスト正規化",
  "Select Example Audio": "サンプル音声を選択"
}
fish_speech/i18n/locale/ko_KR.json
CHANGED
@@ -1,123 +1,123 @@
(removed content; the listing is cut off after line 92)

{
  "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
  "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
  "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
  "Accumulate Gradient Batches": "그라디언트 배치 누적",
  "Add to Processing Area": "처리 영역에 추가",
  "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
  "Advanced Config": "고급 설정",
  "Base LLAMA Model": "기본 LLAMA 모델",
  "Batch Inference": "배치 추론",
  "Batch Size": "배치 크기",
  "Changing with the Model Path": "모델 경로에 따라 변경 중",
  "Chinese": "중국어",
  "Compile Model": "모델 컴파일",
  "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
  "Copy": "복사",
  "Data Preprocessing": "데이터 전처리",
  "Data Preprocessing Path": "데이터 전처리 경로",
  "Data Source": "데이터 소스",
  "Decoder Model Config": "디코더 모델 설정",
  "Decoder Model Path": "디코더 모델 경로",
  "Disabled": "비활성화 됨",
  "Enable Reference Audio": "참고 음성 활성화",
  "English": "영어",
  "Error Message": "오류 메시지",
  "File Preprocessing": "파일 전처리",
  "Generate": "생성",
  "Generated Audio": "생성된 오디오",
  "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
  "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
  "Inference Configuration": "추론 설정",
  "Inference Server Configuration": "추론 서버 설정",
  "Inference Server Error": "추론 서버 오류",
  "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
  "Initial Learning Rate": "초기 학습률",
  "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
  "Input Text": "입력 텍스트",
  "Invalid path: {}": "유효하지 않은 경로: {}",
  "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
  "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
  "Japanese": "일본어",
  "LLAMA Configuration": "LLAMA 설정",
  "LLAMA Model Config": "LLAMA 모델 설정",
  "LLAMA Model Path": "LLAMA 모델 경로",
  "Labeling Device": "라벨링 장치",
  "LoRA Model to be merged": "병합할 LoRA 모델",
  "Maximum Audio Duration": "최대 오디오 길이",
  "Maximum Length per Sample": "샘플당 최대 길이",
  "Maximum Training Steps": "최대 학습 단계",
  "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
  "Merge": "병합",
  "Merge LoRA": "LoRA 병합",
  "Merge successfully": "성공적으로 병합 되었습니다.",
  "Minimum Audio Duration": "최소 오디오 길이",
  "Model Output Path": "모델 출력 경로",
  "Model Size": "모델 크기",
  "Move": "이동",
  "Move files successfully": "파일이 성공적으로 이동되었습니다.",
  "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
  "No selected options": "옵션이 선택되지 않았습니다.",
  "Number of Workers": "작업자 수",
  "Open Inference Server": "추론 서버 열기",
  "Open Labeler WebUI": "라벨러 WebUI 열기",
  "Open Tensorboard": "Tensorboard 열기",
  "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
  "Optional Label Language": "선택적 라벨 언어",
  "Optional online ver": "온라인 버전 선택",
  "Output Path": "출력 경로",
  "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
  "Precision": "정밀도",
  "Probability of applying Speaker Condition": "화자 조건 적용 확률",
  "Put your text here.": "여기에 텍스트를 입력하세요.",
  "Reference Audio": "참고 오디오",
  "Reference Text": "참고 텍스트",
  "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
  "Remove Selected Data": "선택한 데이터 제거",
  "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
  "Repetition Penalty": "반복 패널티",
  "Save model every n steps": "n 단계마다 모델 저장",
  "Select LLAMA ckpt": "LLAMA ckpt 선택",
  "Select VITS ckpt": "VITS ckpt 선택",
  "Select VQGAN ckpt": "VQGAN ckpt 선택",
  "Select source file processing method": "소스 파일 처리 방법 선택",
  "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
  "Selected: {}": "선택됨: {}",
  "Speaker": "화자",
  "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
  "Start Training": "학습 시작",
  "Streaming Audio": "스트리밍 오디오",
  "Streaming Generate": "스트리밍 생성",
  "Tensorboard Host": "Tensorboard 호스트",
  "Tensorboard Log Path": "Tensorboard 로그 경로",
|
93 |
-
"Tensorboard Port": "Tensorboard 포트",
|
94 |
-
"Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
|
95 |
-
"Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
|
96 |
-
"Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
|
97 |
-
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
|
98 |
-
"Training Configuration": "학습 설정",
|
99 |
-
"Training Error": "학습 오류",
|
100 |
-
"Training stopped": "학습이 중지되었습니다.",
|
101 |
-
"Type name of the speaker": "화자의 이름을 입력하세요.",
|
102 |
-
"Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
|
103 |
-
"Use LoRA": "LoRA 사용",
|
104 |
-
"Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
|
105 |
-
"Use filelist": "파일 목록 사용",
|
106 |
-
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
|
107 |
-
"VITS Configuration": "VITS 설정",
|
108 |
-
"VQGAN Configuration": "VQGAN 설정",
|
109 |
-
"Validation Batch Size": "검증 배치 크기",
|
110 |
-
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
|
111 |
-
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
|
112 |
-
"WebUI Host": "WebUI 호스트",
|
113 |
-
"WebUI Port": "WebUI 포트",
|
114 |
-
"Whisper Model": "Whisper 모델",
|
115 |
-
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
|
116 |
-
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
|
117 |
-
"latest": "최신",
|
118 |
-
"new": "새로운",
|
119 |
-
"Realtime Transform Text": "실시간 텍스트 변환",
|
120 |
-
"Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
|
121 |
-
"Text Normalization": "텍스트 정규화",
|
122 |
-
"Select Example Audio": "예시 오디오 선택"
|
123 |
-
}
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
|
5 |
+
"Accumulate Gradient Batches": "그라디언트 배치 누적",
|
6 |
+
"Add to Processing Area": "처리 영역에 추가",
|
7 |
+
"Added path successfully!": "경로가 성공적으로 추가되었습니다!",
|
8 |
+
"Advanced Config": "고급 설정",
|
9 |
+
"Base LLAMA Model": "기본 LLAMA 모델",
|
10 |
+
"Batch Inference": "배치 추론",
|
11 |
+
"Batch Size": "배치 크기",
|
12 |
+
"Changing with the Model Path": "모델 경로에 따라 변경 중",
|
13 |
+
"Chinese": "중국어",
|
14 |
+
"Compile Model": "모델 컴파일",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
|
16 |
+
"Copy": "복사",
|
17 |
+
"Data Preprocessing": "데이터 전처리",
|
18 |
+
"Data Preprocessing Path": "데이터 전처리 경로",
|
19 |
+
"Data Source": "데이터 소스",
|
20 |
+
"Decoder Model Config": "디코더 모델 설정",
|
21 |
+
"Decoder Model Path": "디코더 모델 경로",
|
22 |
+
"Disabled": "비활성화 됨",
|
23 |
+
"Enable Reference Audio": "참고 음성 활성화",
|
24 |
+
"English": "영어",
|
25 |
+
"Error Message": "오류 메시지",
|
26 |
+
"File Preprocessing": "파일 전처리",
|
27 |
+
"Generate": "생성",
|
28 |
+
"Generated Audio": "생성된 오디오",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
|
30 |
+
"Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
|
31 |
+
"Inference Configuration": "추론 설정",
|
32 |
+
"Inference Server Configuration": "추론 서버 설정",
|
33 |
+
"Inference Server Error": "추론 서버 오류",
|
34 |
+
"Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
|
35 |
+
"Initial Learning Rate": "초기 학습률",
|
36 |
+
"Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
|
37 |
+
"Input Text": "입력 텍스트",
|
38 |
+
"Invalid path: {}": "유효하지 않은 경로: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
|
40 |
+
"Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
|
41 |
+
"Japanese": "일본어",
|
42 |
+
"LLAMA Configuration": "LLAMA 설정",
|
43 |
+
"LLAMA Model Config": "LLAMA 모델 설정",
|
44 |
+
"LLAMA Model Path": "LLAMA 모델 경로",
|
45 |
+
"Labeling Device": "라벨링 장치",
|
46 |
+
"LoRA Model to be merged": "병합할 LoRA 모델",
|
47 |
+
"Maximum Audio Duration": "최대 오디오 길이",
|
48 |
+
"Maximum Length per Sample": "샘플당 최대 길이",
|
49 |
+
"Maximum Training Steps": "최대 학습 단계",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
|
51 |
+
"Merge": "병합",
|
52 |
+
"Merge LoRA": "LoRA 병합",
|
53 |
+
"Merge successfully": "성공적으로 병합 되었습니다.",
|
54 |
+
"Minimum Audio Duration": "최소 오디오 길이",
|
55 |
+
"Model Output Path": "모델 출력 경로",
|
56 |
+
"Model Size": "모델 크기",
|
57 |
+
"Move": "이동",
|
58 |
+
"Move files successfully": "파일이 성공적으로 이동되었습니다.",
|
59 |
+
"No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
|
60 |
+
"No selected options": "옵션이 선택되지 않았습니다.",
|
61 |
+
"Number of Workers": "작업자 수",
|
62 |
+
"Open Inference Server": "추론 서버 열기",
|
63 |
+
"Open Labeler WebUI": "라벨러 WebUI 열기",
|
64 |
+
"Open Tensorboard": "Tensorboard 열기",
|
65 |
+
"Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
|
66 |
+
"Optional Label Language": "선택적 라벨 언어",
|
67 |
+
"Optional online ver": "온라인 버전 선택",
|
68 |
+
"Output Path": "출력 경로",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
|
70 |
+
"Precision": "정밀도",
|
71 |
+
"Probability of applying Speaker Condition": "화자 조건 적용 확률",
|
72 |
+
"Put your text here.": "여기에 텍스트를 입력하세요.",
|
73 |
+
"Reference Audio": "참고 오디오",
|
74 |
+
"Reference Text": "참고 텍스트",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
|
76 |
+
"Remove Selected Data": "선택한 데이터 제거",
|
77 |
+
"Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
|
78 |
+
"Repetition Penalty": "반복 패널티",
|
79 |
+
"Save model every n steps": "n 단계마다 모델 저장",
|
80 |
+
"Select LLAMA ckpt": "LLAMA ckpt 선택",
|
81 |
+
"Select VITS ckpt": "VITS ckpt 선택",
|
82 |
+
"Select VQGAN ckpt": "VQGAN ckpt 선택",
|
83 |
+
"Select source file processing method": "소스 파일 처리 방법 선택",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
|
85 |
+
"Selected: {}": "선택됨: {}",
|
86 |
+
"Speaker": "화자",
|
87 |
+
"Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
|
88 |
+
"Start Training": "학습 시작",
|
89 |
+
"Streaming Audio": "스트리밍 오디오",
|
90 |
+
"Streaming Generate": "스트리밍 생성",
|
91 |
+
"Tensorboard Host": "Tensorboard 호스트",
|
92 |
+
"Tensorboard Log Path": "Tensorboard 로그 경로",
|
93 |
+
"Tensorboard Port": "Tensorboard 포트",
|
94 |
+
"Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
|
95 |
+
"Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
|
96 |
+
"Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
|
98 |
+
"Training Configuration": "학습 설정",
|
99 |
+
"Training Error": "학습 오류",
|
100 |
+
"Training stopped": "학습이 중지되었습니다.",
|
101 |
+
"Type name of the speaker": "화자의 이름을 입력하세요.",
|
102 |
+
"Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
|
103 |
+
"Use LoRA": "LoRA 사용",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
|
105 |
+
"Use filelist": "파일 목록 사용",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
|
107 |
+
"VITS Configuration": "VITS 설정",
|
108 |
+
"VQGAN Configuration": "VQGAN 설정",
|
109 |
+
"Validation Batch Size": "검증 배치 크기",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
|
112 |
+
"WebUI Host": "WebUI 호스트",
|
113 |
+
"WebUI Port": "WebUI 포트",
|
114 |
+
"Whisper Model": "Whisper 모델",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
|
117 |
+
"latest": "최신",
|
118 |
+
"new": "새로운",
|
119 |
+
"Realtime Transform Text": "실시간 텍스트 변환",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
|
121 |
+
"Text Normalization": "텍스트 정규화",
|
122 |
+
"Select Example Audio": "예시 오디오 선택"
|
123 |
+
}
|
fish_speech/i18n/locale/pt_BR.json
CHANGED
@@ -1,133 +1,133 @@
|
|
1 |
-
{
|
2 |
-
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
|
3 |
-
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
|
4 |
-
"Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
|
5 |
-
"Add to Processing Area": "Adicionar à Área de Processamento",
|
6 |
-
"Added path successfully!": "Caminho adicionado com sucesso!",
|
7 |
-
"Advanced Config": "Configuração Avançada",
|
8 |
-
"Base LLAMA Model": "Modelo LLAMA Base",
|
9 |
-
"Batch Inference": "Inferência em Lote",
|
10 |
-
"Batch Size": "Tamanho do Lote",
|
11 |
-
"Changing with the Model Path": "Alterando com o Caminho do Modelo",
|
12 |
-
|
13 |
-
"Compile Model": "Compilar Modelo",
|
14 |
-
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
|
15 |
-
"Copy": "Copiar",
|
16 |
-
"Data Preprocessing": "Pré-processamento de Dados",
|
17 |
-
"Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
|
18 |
-
"Data Source": "Fonte de Dados",
|
19 |
-
"Decoder Model Config": "Configuração do Modelo Decodificador",
|
20 |
-
"Decoder Model Path": "Caminho do Modelo Decodificador",
|
21 |
-
"Disabled": "Desativado",
|
22 |
-
"Enable Initial Prompt": "Habilitar Prompt Inicial",
|
23 |
-
"Enable Reference Audio": "Habilitar Áudio de Referência",
|
24 |
-
"English": "Inglês",
|
25 |
-
"Japanese": "Japonês",
|
26 |
-
"Chinese": "Chinês",
|
27 |
-
"Portuguese": "Português",
|
28 |
-
"Spanish": "Espanhol",
|
29 |
-
"Error Message": "Mensagem de Erro",
|
30 |
-
"Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
|
31 |
-
"File Preprocessing": "Pré-processamento de Arquivos",
|
32 |
-
"Generate": "Gerar",
|
33 |
-
"Generated Audio": "Áudio Gerado",
|
34 |
-
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
|
35 |
-
"Infer interface is closed": "A interface de inferência foi fechada",
|
36 |
-
"Inference Configuration": "Configuração de Inferência",
|
37 |
-
"Inference Server Configuration": "Configuração do Servidor de Inferência",
|
38 |
-
"Inference Server Error": "Erro do Servidor de Inferência",
|
39 |
-
"Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
|
40 |
-
"Initial Learning Rate": "Taxa de Aprendizagem Inicial",
|
41 |
-
"Initial Prompt": "Prompt Inicial",
|
42 |
-
"Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
|
43 |
-
"Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
|
44 |
-
"Input Text": "Texto de Entrada",
|
45 |
-
"Invalid path: {}": "Caminho inválido: {}",
|
46 |
-
"It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
|
47 |
-
"Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
|
48 |
-
"LLAMA Configuration": "Configuração do LLAMA",
|
49 |
-
"LLAMA Model Config": "Configuração do Modelo LLAMA",
|
50 |
-
"LLAMA Model Path": "Caminho do Modelo LLAMA",
|
51 |
-
"Labeling Device": "Dispositivo de Rotulagem",
|
52 |
-
"LoRA Model to be merged": "Modelo LoRA para mesclagem",
|
53 |
-
"Maximum Length per Sample": "Comprimento Máximo por Amostra",
|
54 |
-
"Maximum Training Steps": "Etapas Máximas de Treinamento",
|
55 |
-
"Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
|
56 |
-
"Merge": "Mesclar",
|
57 |
-
"Merge LoRA": "Mesclar LoRA",
|
58 |
-
"Merge successfully": "Mesclado com sucesso",
|
59 |
-
"Model Output Path": "Caminho de Saída do Modelo",
|
60 |
-
"Model Quantization": "Quantização do Modelo",
|
61 |
-
"Model Size": "Tamanho do Modelo",
|
62 |
-
"Move": "Mover",
|
63 |
-
"Move files successfully": "Arquivos movidos com sucesso",
|
64 |
-
"No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
|
65 |
-
"No selected options": "Nenhuma opção selecionada",
|
66 |
-
"Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
|
67 |
-
"Number of Workers": "Número de Processos",
|
68 |
-
"Open Inference Server": "Abrir Servidor de Inferência",
|
69 |
-
"Open Labeler WebUI": "Abrir WebUI de Rotulagem",
|
70 |
-
"Open Tensorboard": "Abrir Tensorboard",
|
71 |
-
"Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
|
72 |
-
"Optional Label Language": "Idioma do Rótulo (Opcional)",
|
73 |
-
"Optional online ver": "Versão online (opcional)",
|
74 |
-
"Output Path": "Caminho de Saída",
|
75 |
-
"Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
|
76 |
-
"Post-quantification Precision": "Precisão Pós-quantização",
|
77 |
-
"Precision": "Precisão",
|
78 |
-
"Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
|
79 |
-
"Put your text here.": "Insira seu texto aqui.",
|
80 |
-
"Quantify": "Quantizar",
|
81 |
-
"Quantify successfully": "Quantizado com sucesso",
|
82 |
-
"Realtime Transform Text": "Transformar Texto em Tempo Real",
|
83 |
-
"Reference Audio": "Áudio de Referência",
|
84 |
-
"Reference Text": "Texto de Referência",
|
85 |
-
"warning": "Aviso",
|
86 |
-
"Pre-processing begins...": "O pré-processamento começou!",
|
87 |
-
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
|
88 |
-
"Remove Selected Data": "Remover Dados Selecionados",
|
89 |
-
"Removed path successfully!": "Caminho removido com sucesso!",
|
90 |
-
"Repetition Penalty": "Penalidade de Repetição",
|
91 |
-
"Save model every n steps": "Salvar modelo a cada n etapas",
|
92 |
-
"Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
|
93 |
-
"Select source file processing method": "Escolha como processar o arquivo de origem",
|
94 |
-
"Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
|
95 |
-
"Selected: {}": "Selecionado: {}",
|
96 |
-
"Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
|
97 |
-
"Start Training": "Iniciar Treinamento",
|
98 |
-
"Streaming Audio": "Áudio em Streaming",
|
99 |
-
"Streaming Generate": "Geração em Streaming",
|
100 |
-
"Tensorboard Host": "Host do Tensorboard",
|
101 |
-
"Tensorboard Log Path": "Caminho de Log do Tensorboard",
|
102 |
-
"Tensorboard Port": "Porta do Tensorboard",
|
103 |
-
"Tensorboard interface is closed": "A interface do Tensorboard está fechada",
|
104 |
-
"Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
|
105 |
-
"Text Normalization": "Normalização de Texto",
|
106 |
-
"Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
|
107 |
-
"The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
|
108 |
-
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
|
109 |
-
"Training Configuration": "Configuração de Treinamento",
|
110 |
-
"Training Error": "Erro de Treinamento",
|
111 |
-
"Training stopped": "Treinamento interrompido!",
|
112 |
-
"Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
|
113 |
-
"Use LoRA": "Usar LoRA",
|
114 |
-
"Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
|
115 |
-
"Use filelist": "Usar lista de arquivos",
|
116 |
-
"VQGAN Configuration": "Configuração do VQGAN",
|
117 |
-
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
|
118 |
-
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
|
119 |
-
"WebUI Host": "Host da WebUI",
|
120 |
-
"WebUI Port": "Porta da WebUI",
|
121 |
-
"Whisper Model": "Modelo Whisper",
|
122 |
-
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
|
123 |
-
"auto": "automático",
|
124 |
-
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
|
125 |
-
"latest": "mais recente",
|
126 |
-
"new": "novo",
|
127 |
-
"This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
|
128 |
-
"You don't need to train this model!": "Não é necessário treinar este modelo!",
|
129 |
-
"Yes": "Sim",
|
130 |
-
"No": "Não",
|
131 |
-
"version:": "versão:",
|
132 |
-
"author:": "autor:"
|
133 |
-
}
|
|
|
1 |
+
{
|
2 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
|
3 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
|
4 |
+
"Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
|
5 |
+
"Add to Processing Area": "Adicionar à Área de Processamento",
|
6 |
+
"Added path successfully!": "Caminho adicionado com sucesso!",
|
7 |
+
"Advanced Config": "Configuração Avançada",
|
8 |
+
"Base LLAMA Model": "Modelo LLAMA Base",
|
9 |
+
"Batch Inference": "Inferência em Lote",
|
10 |
+
"Batch Size": "Tamanho do Lote",
|
11 |
+
"Changing with the Model Path": "Alterando com o Caminho do Modelo",
|
12 |
+
|
13 |
+
"Compile Model": "Compilar Modelo",
|
14 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
|
15 |
+
"Copy": "Copiar",
|
16 |
+
"Data Preprocessing": "Pré-processamento de Dados",
|
17 |
+
"Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
|
18 |
+
"Data Source": "Fonte de Dados",
|
19 |
+
"Decoder Model Config": "Configuração do Modelo Decodificador",
|
20 |
+
"Decoder Model Path": "Caminho do Modelo Decodificador",
|
21 |
+
"Disabled": "Desativado",
|
22 |
+
"Enable Initial Prompt": "Habilitar Prompt Inicial",
|
23 |
+
"Enable Reference Audio": "Habilitar Áudio de Referência",
|
24 |
+
"English": "Inglês",
|
25 |
+
"Japanese": "Japonês",
|
26 |
+
"Chinese": "Chinês",
|
27 |
+
"Portuguese": "Português",
|
28 |
+
"Spanish": "Espanhol",
|
29 |
+
"Error Message": "Mensagem de Erro",
|
30 |
+
"Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
|
31 |
+
"File Preprocessing": "Pré-processamento de Arquivos",
|
32 |
+
"Generate": "Gerar",
|
33 |
+
"Generated Audio": "Áudio Gerado",
|
34 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
|
35 |
+
"Infer interface is closed": "A interface de inferência foi fechada",
|
36 |
+
"Inference Configuration": "Configuração de Inferência",
|
37 |
+
"Inference Server Configuration": "Configuração do Servidor de Inferência",
|
38 |
+
"Inference Server Error": "Erro do Servidor de Inferência",
|
39 |
+
"Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
|
40 |
+
"Initial Learning Rate": "Taxa de Aprendizagem Inicial",
|
41 |
+
"Initial Prompt": "Prompt Inicial",
|
42 |
+
"Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
|
43 |
+
"Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
|
44 |
+
"Input Text": "Texto de Entrada",
|
45 |
+
"Invalid path: {}": "Caminho inválido: {}",
|
46 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
|
47 |
+
"Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
|
48 |
+
"LLAMA Configuration": "Configuração do LLAMA",
|
49 |
+
"LLAMA Model Config": "Configuração do Modelo LLAMA",
|
50 |
+
"LLAMA Model Path": "Caminho do Modelo LLAMA",
|
51 |
+
"Labeling Device": "Dispositivo de Rotulagem",
|
52 |
+
"LoRA Model to be merged": "Modelo LoRA para mesclagem",
|
53 |
+
"Maximum Length per Sample": "Comprimento Máximo por Amostra",
|
54 |
+
"Maximum Training Steps": "Etapas Máximas de Treinamento",
|
55 |
+
"Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
|
56 |
+
"Merge": "Mesclar",
|
57 |
+
"Merge LoRA": "Mesclar LoRA",
|
58 |
+
"Merge successfully": "Mesclado com sucesso",
|
59 |
+
"Model Output Path": "Caminho de Saída do Modelo",
|
60 |
+
"Model Quantization": "Quantização do Modelo",
|
61 |
+
"Model Size": "Tamanho do Modelo",
|
62 |
+
"Move": "Mover",
|
63 |
+
"Move files successfully": "Arquivos movidos com sucesso",
|
64 |
+
"No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
|
65 |
+
"No selected options": "Nenhuma opção selecionada",
|
66 |
+
"Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
|
67 |
+
"Number of Workers": "Número de Processos",
|
68 |
+
"Open Inference Server": "Abrir Servidor de Inferência",
|
69 |
+
"Open Labeler WebUI": "Abrir WebUI de Rotulagem",
|
70 |
+
"Open Tensorboard": "Abrir Tensorboard",
|
71 |
+
"Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
|
72 |
+
"Optional Label Language": "Idioma do Rótulo (Opcional)",
|
73 |
+
"Optional online ver": "Versão online (opcional)",
|
74 |
+
"Output Path": "Caminho de Saída",
|
75 |
+
"Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
|
76 |
+
"Post-quantification Precision": "Precisão Pós-quantização",
|
77 |
+
"Precision": "Precisão",
|
78 |
+
"Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
|
79 |
+
"Put your text here.": "Insira seu texto aqui.",
|
80 |
+
"Quantify": "Quantizar",
|
81 |
+
"Quantify successfully": "Quantizado com sucesso",
|
82 |
+
"Realtime Transform Text": "Transformar Texto em Tempo Real",
|
83 |
+
"Reference Audio": "Áudio de Referência",
|
84 |
+
"Reference Text": "Texto de Referência",
|
85 |
+
"warning": "Aviso",
|
86 |
+
"Pre-processing begins...": "O pré-processamento começou!",
|
87 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
|
88 |
+
"Remove Selected Data": "Remover Dados Selecionados",
|
89 |
+
"Removed path successfully!": "Caminho removido com sucesso!",
|
90 |
+
"Repetition Penalty": "Penalidade de Repetição",
|
91 |
+
"Save model every n steps": "Salvar modelo a cada n etapas",
|
92 |
+
"Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
|
93 |
+
"Select source file processing method": "Escolha como processar o arquivo de origem",
|
94 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
|
95 |
+
"Selected: {}": "Selecionado: {}",
|
96 |
+
"Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
|
97 |
+
"Start Training": "Iniciar Treinamento",
|
98 |
+
"Streaming Audio": "Áudio em Streaming",
|
99 |
+
"Streaming Generate": "Geração em Streaming",
|
100 |
+
"Tensorboard Host": "Host do Tensorboard",
|
101 |
+
"Tensorboard Log Path": "Caminho de Log do Tensorboard",
|
102 |
+
"Tensorboard Port": "Porta do Tensorboard",
|
103 |
+
"Tensorboard interface is closed": "A interface do Tensorboard está fechada",
|
104 |
+
"Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
|
105 |
+
"Text Normalization": "Normalização de Texto",
|
106 |
+
"Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
|
107 |
+
"The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
|
108 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
|
109 |
+
"Training Configuration": "Configuração de Treinamento",
|
110 |
+
"Training Error": "Erro de Treinamento",
|
111 |
+
"Training stopped": "Treinamento interrompido!",
|
112 |
+
"Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
|
113 |
+
"Use LoRA": "Usar LoRA",
|
114 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
|
115 |
+
"Use filelist": "Usar lista de arquivos",
|
116 |
+
"VQGAN Configuration": "Configuração do VQGAN",
|
117 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
|
118 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
|
119 |
+
"WebUI Host": "Host da WebUI",
|
120 |
+
"WebUI Port": "Porta da WebUI",
|
121 |
+
"Whisper Model": "Modelo Whisper",
|
122 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
|
123 |
+
"auto": "automático",
|
124 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
|
125 |
+
"latest": "mais recente",
|
126 |
+
"new": "novo",
|
127 |
+
"This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
|
128 |
+
"You don't need to train this model!": "Não é necessário treinar este modelo!",
|
129 |
+
"Yes": "Sim",
|
130 |
+
"No": "Não",
|
131 |
+
"version:": "versão:",
|
132 |
+
"author:": "autor:"
|
133 |
+
}
|
fish_speech/i18n/locale/zh_CN.json
CHANGED
@@ -1,123 +1,123 @@
|
|
1 |
-
{
|
2 |
-
"16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
|
3 |
-
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
|
4 |
-
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
|
5 |
-
"Accumulate Gradient Batches": "梯度累积批次",
|
6 |
-
"Add to Processing Area": "加入处理区",
|
7 |
-
"Added path successfully!": "添加路径成功!",
|
8 |
-
"Advanced Config": "高级参数",
|
9 |
-
"Base LLAMA Model": "基础 LLAMA 模型",
|
10 |
-
"Batch Inference": "批量推理",
|
11 |
-
"Batch Size": "批次大小",
|
12 |
-
"Changing with the Model Path": "随模型路径变化",
|
13 |
-
"Chinese": "中文",
|
14 |
-
"Compile Model": "编译模型",
|
15 |
-
"Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
|
16 |
-
"Copy": "复制",
|
17 |
-
"Data Preprocessing": "数据预处理",
|
18 |
-
"Data Preprocessing Path": "数据预处理路径",
|
19 |
-
"Data Source": "数据源",
|
20 |
-
"Decoder Model Config": "解码器模型配置",
|
21 |
-
"Decoder Model Path": "解码器模型路径",
|
22 |
-
"Disabled": "禁用",
|
23 |
-
"Enable Reference Audio": "启用参考音频",
|
24 |
-
"English": "英文",
|
25 |
-
"Error Message": "错误信息",
|
26 |
-
"File Preprocessing": "文件预处理",
|
27 |
-
"Generate": "生成",
|
28 |
-
"Generated Audio": "音频",
|
29 |
-
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
|
30 |
-
"Infer interface is closed": "推理界面已关闭",
|
31 |
-
"Inference Configuration": "推理配置",
|
32 |
-
"Inference Server Configuration": "推理服务器配置",
|
33 |
-
"Inference Server Error": "推理服务器错误",
|
34 |
-
"Inferring interface is launched at {}": "推理界面已在 {} 上启动",
|
35 |
-
"Initial Learning Rate": "初始学习率",
|
36 |
-
"Input Audio & Source Path for Transcription": "输入音频和转录源路径",
|
37 |
-
"Input Text": "输入文本",
|
38 |
-
"Invalid path: {}": "
|
39 |
-
"It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
|
40 |
-
"Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
|
41 |
-
"Japanese": "日文",
|
42 |
-
"LLAMA Configuration": "LLAMA 配置",
|
43 |
-
"LLAMA Model Config": "LLAMA 模型配置",
|
44 |
-
"LLAMA Model Path": "LLAMA 模型路径",
|
45 |
-
"Labeling Device": "标注加速设备",
|
46 |
-
"LoRA Model to be merged": "要合并的 LoRA 模型",
|
47 |
-
"Maximum Audio Duration": "最大音频时长",
|
48 |
-
"Maximum Length per Sample": "每个样本的最大长度",
|
49 |
-
"Maximum Training Steps": "最大训练步数",
|
50 |
-
"Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
|
51 |
-
"Merge": "合并",
|
52 |
-
"Merge LoRA": "合并 LoRA",
|
53 |
-
"Merge successfully": "合并成功",
|
54 |
-
"Minimum Audio Duration": "最小音频时长",
|
55 |
-
"Model Output Path": "模型输出路径",
|
56 |
-
"Model Size": "模型规模",
|
57 |
-
"Move": "移动",
|
58 |
-
"Move files successfully": "移动文件成功",
|
59 |
-
"No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
|
60 |
-
"No selected options": "没有选择的选项",
|
61 |
-
"Number of Workers": "数据加载进程数",
|
62 |
-
"Open Inference Server": "打开推理服务器",
|
63 |
-
"Open Labeler WebUI": "打开标注工具",
|
64 |
-
"Open Tensorboard": "打开 Tensorboard",
|
65 |
-
"Opened labeler in browser": "在浏览器中打开标注工具",
|
66 |
-
"Optional Label Language": "[可选] 标注语言",
|
67 |
-
"Optional online ver": "[可选] 使用在线版",
|
68 |
-
"Output Path": "输出路径",
|
69 |
-
"Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
|
70 |
-
"Precision": "精度",
|
71 |
-
"Probability of applying Speaker Condition": "应用说话人条件的概率",
|
72 |
-
"Put your text here.": "在此处输入文本.",
|
73 |
-
"Reference Audio": "参考音频",
|
74 |
-
"Reference Text": "参考文本",
|
75 |
-
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
|
76 |
-
"Remove Selected Data": "移除选中数据",
|
77 |
-
"Removed path successfully!": "移除路径成功!",
|
78 |
-
"Repetition Penalty": "重复惩罚",
|
79 |
-
"Save model every n steps": "每 n 步保存模型",
|
80 |
-
"Select LLAMA ckpt": "选择 LLAMA 检查点",
|
81 |
-
"Select VITS ckpt": "选择 VITS 检查点",
|
82 |
-
"Select VQGAN ckpt": "选择 VQGAN 检查点",
|
83 |
-
"Select source file processing method": "选择源文件处理方法",
|
84 |
-
"Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
|
85 |
-
"Selected: {}": "已选择: {}",
|
86 |
-
"Speaker": "说话人",
|
87 |
-
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
|
88 |
-
"Start Training": "开始训练",
|
89 |
-
"Streaming Audio": "流式音频",
|
90 |
-
"Streaming Generate": "流式合成",
|
91 |
-
"Tensorboard Host": "Tensorboard 监听地址",
|
92 |
-
"Tensorboard Log Path": "Tensorboard 日志路径",
|
93 |
-
"Tensorboard Port": "Tensorboard 端口",
|
94 |
-
"Tensorboard interface is closed": "Tensorboard 界面已关闭",
|
95 |
-
"Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
|
96 |
-
"Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
|
97 |
-
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
|
98 |
-
"Training Configuration": "训练配置",
|
99 |
-
"Training Error": "训练错误",
|
100 |
-
"Training stopped": "训练已停止",
|
101 |
-
"Type name of the speaker": "输入说话人的名称",
|
102 |
-
"Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
|
103 |
-
"Use LoRA": "使用 LoRA",
|
104 |
-
"Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
|
105 |
-
"Use filelist": "使用文件列表",
|
106 |
-
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
|
107 |
-
"VITS Configuration": "VITS 配置",
|
108 |
-
"VQGAN Configuration": "VQGAN 配置",
|
109 |
-
"Validation Batch Size": "验证批次大小",
|
110 |
-
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
|
111 |
-
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
|
112 |
-
"WebUI Host": "WebUI 监听地址",
|
113 |
-
"WebUI Port": "WebUI 端口",
|
114 |
-
"Whisper Model": "Whisper 模型",
|
115 |
-
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
|
116 |
-
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
|
117 |
-
"latest": "最近的检查点",
|
118 |
-
"new": "创建新的检查点",
|
119 |
-
"Realtime Transform Text": "实时规范化文本",
|
120 |
-
"Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
|
121 |
-
"Text Normalization": "文本规范化",
|
122 |
-
"Select Example Audio": "选择参考音频"
|
123 |
-
}
|
|
|
1 |
+
{
|
2 |
+
"16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
|
3 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
|
4 |
+
"A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
|
5 |
+
"Accumulate Gradient Batches": "梯度累积批次",
|
6 |
+
"Add to Processing Area": "加入处理区",
|
7 |
+
"Added path successfully!": "添加路径成功!",
|
8 |
+
"Advanced Config": "高级参数",
|
9 |
+
"Base LLAMA Model": "基础 LLAMA 模型",
|
10 |
+
"Batch Inference": "批量推理",
|
11 |
+
"Batch Size": "批次大小",
|
12 |
+
"Changing with the Model Path": "随模型路径变化",
|
13 |
+
"Chinese": "中文",
|
14 |
+
"Compile Model": "编译模型",
|
15 |
+
"Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
|
16 |
+
"Copy": "复制",
|
17 |
+
"Data Preprocessing": "数据预处理",
|
18 |
+
"Data Preprocessing Path": "数据预处理路径",
|
19 |
+
"Data Source": "数据源",
|
20 |
+
"Decoder Model Config": "解码器模型配置",
|
21 |
+
"Decoder Model Path": "解码器模型路径",
|
22 |
+
"Disabled": "禁用",
|
23 |
+
"Enable Reference Audio": "启用参考音频",
|
24 |
+
"English": "英文",
|
25 |
+
"Error Message": "错误信息",
|
26 |
+
"File Preprocessing": "文件预处理",
|
27 |
+
"Generate": "生成",
|
28 |
+
"Generated Audio": "音频",
|
29 |
+
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
|
30 |
+
"Infer interface is closed": "推理界面已关闭",
|
31 |
+
"Inference Configuration": "推理配置",
|
32 |
+
"Inference Server Configuration": "推理服务器配置",
|
33 |
+
"Inference Server Error": "推理服务器错误",
|
34 |
+
"Inferring interface is launched at {}": "推理界面已在 {} 上启动",
|
35 |
+
"Initial Learning Rate": "初始学习率",
|
36 |
+
"Input Audio & Source Path for Transcription": "输入音频和转录源路径",
|
37 |
+
"Input Text": "输入文本",
|
38 |
+
"Invalid path: {}": "无效���径: {}",
|
39 |
+
"It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
|
40 |
+
"Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
|
41 |
+
"Japanese": "日文",
|
42 |
+
"LLAMA Configuration": "LLAMA 配置",
|
43 |
+
"LLAMA Model Config": "LLAMA 模型配置",
|
44 |
+
"LLAMA Model Path": "LLAMA 模型路径",
|
45 |
+
"Labeling Device": "标注加速设备",
|
46 |
+
"LoRA Model to be merged": "要合并的 LoRA 模型",
|
47 |
+
"Maximum Audio Duration": "最大音频时长",
|
48 |
+
"Maximum Length per Sample": "每个样本的最大长度",
|
49 |
+
"Maximum Training Steps": "最大训练步数",
|
50 |
+
"Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
|
51 |
+
"Merge": "合并",
|
52 |
+
"Merge LoRA": "合并 LoRA",
|
53 |
+
"Merge successfully": "合并成功",
|
54 |
+
"Minimum Audio Duration": "最小音频时长",
|
55 |
+
"Model Output Path": "模型输出路径",
|
56 |
+
"Model Size": "模型规模",
|
57 |
+
"Move": "移动",
|
58 |
+
"Move files successfully": "移动文件成功",
|
59 |
+
"No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
|
60 |
+
"No selected options": "没有选择的选项",
|
61 |
+
"Number of Workers": "数据加载进程数",
|
62 |
+
"Open Inference Server": "打开推理服务器",
|
63 |
+
"Open Labeler WebUI": "打开标注工具",
|
64 |
+
"Open Tensorboard": "打开 Tensorboard",
|
65 |
+
"Opened labeler in browser": "在浏览器中打开标注工具",
|
66 |
+
"Optional Label Language": "[可选] 标注语言",
|
67 |
+
"Optional online ver": "[可选] 使用在线版",
|
68 |
+
"Output Path": "输出路径",
|
69 |
+
"Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
|
70 |
+
"Precision": "精度",
|
71 |
+
"Probability of applying Speaker Condition": "应用说话人条件的概率",
|
72 |
+
"Put your text here.": "在此处输入文本.",
|
73 |
+
"Reference Audio": "参考音频",
|
74 |
+
"Reference Text": "参考文本",
|
75 |
+
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
|
76 |
+
"Remove Selected Data": "移除选中数据",
|
77 |
+
"Removed path successfully!": "移除路径成功!",
|
78 |
+
"Repetition Penalty": "重复惩罚",
|
79 |
+
"Save model every n steps": "每 n 步保存模型",
|
80 |
+
"Select LLAMA ckpt": "选择 LLAMA 检查点",
|
81 |
+
"Select VITS ckpt": "选择 VITS 检查点",
|
82 |
+
"Select VQGAN ckpt": "选择 VQGAN 检查点",
|
83 |
+
"Select source file processing method": "选择源文件处理方法",
|
84 |
+
"Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
|
85 |
+
"Selected: {}": "已选择: {}",
|
86 |
+
"Speaker": "说话人",
|
87 |
+
"Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
|
88 |
+
"Start Training": "开始训练",
|
89 |
+
"Streaming Audio": "流式音频",
|
90 |
+
"Streaming Generate": "流式合成",
|
91 |
+
"Tensorboard Host": "Tensorboard 监听地址",
|
92 |
+
"Tensorboard Log Path": "Tensorboard 日志路径",
|
93 |
+
"Tensorboard Port": "Tensorboard 端口",
|
94 |
+
"Tensorboard interface is closed": "Tensorboard 界面已关闭",
|
95 |
+
"Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
|
96 |
+
"Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
|
97 |
+
"The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
|
98 |
+
"Training Configuration": "训练配置",
|
99 |
+
"Training Error": "训练错误",
|
100 |
+
"Training stopped": "训练已停止",
|
101 |
+
"Type name of the speaker": "输入说话人的名称",
|
102 |
+
"Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
|
103 |
+
"Use LoRA": "使用 LoRA",
|
104 |
+
"Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
|
105 |
+
"Use filelist": "使用文件列表",
|
106 |
+
"Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
|
107 |
+
"VITS Configuration": "VITS 配置",
|
108 |
+
"VQGAN Configuration": "VQGAN 配置",
|
109 |
+
"Validation Batch Size": "验证批次大小",
|
110 |
+
"View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
|
111 |
+
"We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
|
112 |
+
"WebUI Host": "WebUI 监听地址",
|
113 |
+
"WebUI Port": "WebUI 端口",
|
114 |
+
"Whisper Model": "Whisper 模型",
|
115 |
+
"You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
|
116 |
+
"bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
|
117 |
+
"latest": "最近的检查点",
|
118 |
+
"new": "创建新的检查点",
|
119 |
+
"Realtime Transform Text": "实时规范化文本",
|
120 |
+
"Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
|
121 |
+
"Text Normalization": "文本规范化",
|
122 |
+
"Select Example Audio": "选择参考音频"
|
123 |
+
}
|
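The locale files above are flat JSON maps keyed by the English source string, with the translated text as the value. A minimal sketch of how such a table can be consumed is shown below; the load_locale/translate helpers and the fall-back-to-key behaviour are illustrative assumptions, not the actual fish_speech.i18n API.

import json
from pathlib import Path

# Assumed location; this upload stores the tables under fish_speech/i18n/locale/.
LOCALE_DIR = Path("fish_speech/i18n/locale")

def load_locale(language: str) -> dict:
    """Load one key -> translation map, e.g. "ja_JP" or "zh_CN"."""
    path = LOCALE_DIR / f"{language}.json"
    if not path.exists():
        return {}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def translate(table: dict, key: str) -> str:
    # Untranslated keys fall back to the English source string.
    return table.get(key, key)

table = load_locale("ja_JP")
print(translate(table, "Training Configuration"))  # -> トレーニング設定
print(translate(table, "Not a known key"))          # -> Not a known key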
fish_speech/i18n/scan.py
CHANGED
@@ -1,122 +1,122 @@
|
|
1 |
-
import ast
|
2 |
-
import glob
|
3 |
-
import json
|
4 |
-
from collections import OrderedDict
|
5 |
-
from pathlib import Path
|
6 |
-
|
7 |
-
from loguru import logger
|
8 |
-
|
9 |
-
from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
|
10 |
-
|
11 |
-
|
12 |
-
def extract_i18n_strings(node):
|
13 |
-
i18n_strings = []
|
14 |
-
|
15 |
-
if (
|
16 |
-
isinstance(node, ast.Call)
|
17 |
-
and isinstance(node.func, ast.Name)
|
18 |
-
and node.func.id == "i18n"
|
19 |
-
):
|
20 |
-
for arg in node.args:
|
21 |
-
if isinstance(arg, ast.Str):
|
22 |
-
i18n_strings.append(arg.s)
|
23 |
-
|
24 |
-
for child_node in ast.iter_child_nodes(node):
|
25 |
-
i18n_strings.extend(extract_i18n_strings(child_node))
|
26 |
-
|
27 |
-
return i18n_strings
|
28 |
-
|
29 |
-
|
30 |
-
# scan the directory for all .py files (recursively)
|
31 |
-
# for each file, parse the code into an AST
|
32 |
-
# for each AST, extract the i18n strings
|
33 |
-
|
34 |
-
strings = []
|
35 |
-
folders = ["fish_speech", "tools"]
|
36 |
-
# for filename in glob.iglob("**/*.py", recursive=True):
|
37 |
-
for folder in folders:
|
38 |
-
for f in Path(folder).rglob("*.py"):
|
39 |
-
code = f.read_text(encoding="utf-8")
|
40 |
-
if "i18n(" in code:
|
41 |
-
tree = ast.parse(code)
|
42 |
-
i18n_strings = extract_i18n_strings(tree)
|
43 |
-
logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
|
44 |
-
strings.extend(i18n_strings)
|
45 |
-
|
46 |
-
code_keys = set(strings)
|
47 |
-
logger.info(f"Total unique: {len(code_keys)}")
|
48 |
-
|
49 |
-
|
50 |
-
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
|
51 |
-
with open(standard_file, "r", encoding="utf-8") as f:
|
52 |
-
standard_data = json.load(f, object_pairs_hook=OrderedDict)
|
53 |
-
standard_keys = set(standard_data.keys())
|
54 |
-
|
55 |
-
# Define the standard file name
|
56 |
-
unused_keys = standard_keys - code_keys
|
57 |
-
logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
|
58 |
-
for unused_key in unused_keys:
|
59 |
-
logger.info(f"\t{unused_key}")
|
60 |
-
|
61 |
-
missing_keys = code_keys - standard_keys
|
62 |
-
logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
|
63 |
-
for missing_key in missing_keys:
|
64 |
-
logger.info(f"\t{missing_key}")
|
65 |
-
|
66 |
-
code_keys_dict = OrderedDict()
|
67 |
-
for s in strings:
|
68 |
-
code_keys_dict[s] = s
|
69 |
-
|
70 |
-
# write back
|
71 |
-
with open(standard_file, "w", encoding="utf-8") as f:
|
72 |
-
json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
|
73 |
-
f.write("\n")
|
74 |
-
|
75 |
-
logger.info(f"Updated {standard_file}")
|
76 |
-
|
77 |
-
|
78 |
-
# Define the standard file name
|
79 |
-
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
|
80 |
-
|
81 |
-
# Find all JSON files in the directory
|
82 |
-
dir_path = I18N_FILE_PATH
|
83 |
-
languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
|
84 |
-
|
85 |
-
# Load the standard file
|
86 |
-
with open(standard_file, "r", encoding="utf-8") as f:
|
87 |
-
standard_data = json.load(f, object_pairs_hook=OrderedDict)
|
88 |
-
|
89 |
-
# Loop through each language file
|
90 |
-
for lang_file in languages:
|
91 |
-
# Load the language file
|
92 |
-
with open(lang_file, "r", encoding="utf-8") as f:
|
93 |
-
lang_data = json.load(f, object_pairs_hook=OrderedDict)
|
94 |
-
|
95 |
-
# Find the difference between the language file and the standard file
|
96 |
-
diff = set(standard_data.keys()) - set(lang_data.keys())
|
97 |
-
|
98 |
-
miss = set(lang_data.keys()) - set(standard_data.keys())
|
99 |
-
|
100 |
-
# Add any missing keys to the language file
|
101 |
-
for key in diff:
|
102 |
-
lang_data[key] = "#!" + key
|
103 |
-
logger.info(f"Added missing key: {key} to {lang_file}")
|
104 |
-
|
105 |
-
# Del any extra keys to the language file
|
106 |
-
for key in miss:
|
107 |
-
del lang_data[key]
|
108 |
-
logger.info(f"Del extra key: {key} from {lang_file}")
|
109 |
-
|
110 |
-
# Sort the keys of the language file to match the order of the standard file
|
111 |
-
lang_data = OrderedDict(
|
112 |
-
sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
|
113 |
-
)
|
114 |
-
|
115 |
-
# Save the updated language file
|
116 |
-
with open(lang_file, "w", encoding="utf-8") as f:
|
117 |
-
json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
|
118 |
-
f.write("\n")
|
119 |
-
|
120 |
-
logger.info(f"Updated {lang_file}")
|
121 |
-
|
122 |
-
logger.info("Done")
|
|
|
1 |
+
import ast
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
from collections import OrderedDict
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
|
10 |
+
|
11 |
+
|
12 |
+
def extract_i18n_strings(node):
|
13 |
+
i18n_strings = []
|
14 |
+
|
15 |
+
if (
|
16 |
+
isinstance(node, ast.Call)
|
17 |
+
and isinstance(node.func, ast.Name)
|
18 |
+
and node.func.id == "i18n"
|
19 |
+
):
|
20 |
+
for arg in node.args:
|
            if isinstance(arg, ast.Str):
                i18n_strings.append(arg.s)

    for child_node in ast.iter_child_nodes(node):
        i18n_strings.extend(extract_i18n_strings(child_node))

    return i18n_strings


# scan the directory for all .py files (recursively)
# for each file, parse the code into an AST
# for each AST, extract the i18n strings

strings = []
folders = ["fish_speech", "tools"]
# for filename in glob.iglob("**/*.py", recursive=True):
for folder in folders:
    for f in Path(folder).rglob("*.py"):
        code = f.read_text(encoding="utf-8")
        if "i18n(" in code:
            tree = ast.parse(code)
            i18n_strings = extract_i18n_strings(tree)
            logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
            strings.extend(i18n_strings)

code_keys = set(strings)
logger.info(f"Total unique: {len(code_keys)}")


standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
with open(standard_file, "r", encoding="utf-8") as f:
    standard_data = json.load(f, object_pairs_hook=OrderedDict)
standard_keys = set(standard_data.keys())

unused_keys = standard_keys - code_keys
logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
for unused_key in unused_keys:
    logger.info(f"\t{unused_key}")

missing_keys = code_keys - standard_keys
logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
for missing_key in missing_keys:
    logger.info(f"\t{missing_key}")

code_keys_dict = OrderedDict()
for s in strings:
    code_keys_dict[s] = s

# write back
with open(standard_file, "w", encoding="utf-8") as f:
    json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
    f.write("\n")

logger.info(f"Updated {standard_file}")


# Define the standard file name
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"

# Find all JSON files in the directory
dir_path = I18N_FILE_PATH
languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]

# Load the standard file
with open(standard_file, "r", encoding="utf-8") as f:
    standard_data = json.load(f, object_pairs_hook=OrderedDict)

# Loop through each language file
for lang_file in languages:
    # Load the language file
    with open(lang_file, "r", encoding="utf-8") as f:
        lang_data = json.load(f, object_pairs_hook=OrderedDict)

    # Find keys present in the standard file but missing from the language file
    diff = set(standard_data.keys()) - set(lang_data.keys())

    # Find keys present in the language file but no longer in the standard file
    miss = set(lang_data.keys()) - set(standard_data.keys())

    # Add any missing keys to the language file
    for key in diff:
        lang_data[key] = "#!" + key
        logger.info(f"Added missing key: {key} to {lang_file}")

    # Delete any extra keys from the language file
    for key in miss:
        del lang_data[key]
        logger.info(f"Deleted extra key: {key} from {lang_file}")

    # Sort the keys of the language file to match the order of the standard file
    lang_data = OrderedDict(
        sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
    )

    # Save the updated language file
    with open(lang_file, "w", encoding="utf-8") as f:
        json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
        f.write("\n")

    logger.info(f"Updated {lang_file}")

logger.info("Done")
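For reference, the sync loop above marks every key that exists in code but not yet in a locale file with a "#!" prefix so translators can find untranslated entries at a glance. A minimal sketch of that merge step on toy data (the keys and values here are made up for illustration):

    from collections import OrderedDict

    standard_data = OrderedDict({"Start": "Start", "Stop": "Stop"})
    lang_data = OrderedDict({"Start": "Démarrer", "Obsolete": "..."})

    # Same logic as the scanner: add missing keys with a "#!" marker, drop extras
    for key in set(standard_data) - set(lang_data):
        lang_data[key] = "#!" + key
    for key in set(lang_data) - set(standard_data):
        del lang_data[key]

    print(lang_data)  # OrderedDict([('Start', 'Démarrer'), ('Stop', '#!Stop')])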
fish_speech/inference_engine/__init__.py
ADDED
@@ -0,0 +1,192 @@
import gc
import queue
from typing import Generator

import numpy as np
import torch
from loguru import logger

from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.dac.modded_dac import DAC
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
)
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest


class TTSInferenceEngine(ReferenceLoader, VQManager):

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: DAC,
        precision: torch.dtype,
        compile: bool,
    ) -> None:

        super().__init__()

        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.
        """

        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []
        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)

        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        if hasattr(self.decoder_model, "spec_transform"):
            sample_rate = self.decoder_model.spec_transform.sample_rate
        else:
            sample_rate = self.decoder_model.sample_rate

        # If streaming, send the header
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []

        while True:
            # Get the response from the LLAMA model
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                break

        # Clean up the memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.
        """

        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=req.text,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # Create a queue to get the response
        response_queue = queue.Queue()

        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """

        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to numpy
        return segment.float().cpu().numpy()
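A rough sketch of how this engine is driven, assuming a `llama_queue` and a decoder `DAC` have already been loaded elsewhere (app.py sets these up); the request fields follow ServeTTSRequest and may need adjusting to its exact schema:

    import soundfile as sf
    import torch

    from fish_speech.inference_engine import TTSInferenceEngine
    from fish_speech.utils.schema import ServeTTSRequest

    engine = TTSInferenceEngine(
        llama_queue=llama_queue,        # assumed to exist already
        decoder_model=decoder_model,    # assumed to exist already
        precision=torch.bfloat16,
        compile=False,
    )

    req = ServeTTSRequest(text="Hello from OpenAudio S1.", streaming=False)
    for result in engine.inference(req):
        if result.code == "final":
            sample_rate, audio = result.audio
            sf.write("output.wav", audio, sample_rate)
        elif result.code == "error":
            raise result.error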
fish_speech/inference_engine/reference_loader.py
ADDED
@@ -0,0 +1,130 @@
import io
from hashlib import sha256
from pathlib import Path
from typing import Callable, Literal, Tuple

import torch
import torchaudio
from loguru import logger

from fish_speech.models.dac.modded_dac import DAC
from fish_speech.utils.file import (
    AUDIO_EXTENSIONS,
    audio_to_bytes,
    list_files,
    read_ref_text,
)
from fish_speech.utils.schema import ServeReferenceAudio


class ReferenceLoader:

    def __init__(self) -> None:
        """
        Component of the TTSInferenceEngine class.
        Loads and manages the cache for the reference audio and text.
        """
        self.ref_by_id: dict = {}
        self.ref_by_hash: dict = {}

        # Make Pylance happy (attribute/method not defined...)
        self.decoder_model: DAC
        self.encode_reference: Callable

        # Define the torchaudio backend
        backends = torchaudio.list_audio_backends()
        if "ffmpeg" in backends:
            self.backend = "ffmpeg"
        else:
            self.backend = "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:

        # Load the reference audio and text by id
        ref_folder = Path("references") / id
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        if use_cache == "off" or id not in self.ref_by_id:
            # If the references are not already loaded, encode them
            prompt_tokens = [
                self.encode_reference(
                    # decoder_model=self.decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
            self.ref_by_id[id] = (prompt_tokens, prompt_texts)

        else:
            # Reuse already encoded references
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: list[ServeReferenceAudio],
        use_cache: Literal["on", "off"],
    ) -> Tuple:

        # Load the reference audio and text by hash
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for i, ref in enumerate(references):
            if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
                # If the references are not already loaded, encode them
                prompt_tokens.append(
                    self.encode_reference(
                        reference_audio=ref.audio,
                        enable_reference_audio=True,
                    )
                )
                prompt_texts.append(ref.text)
                self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)

            else:
                # Reuse already encoded references
                prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
                cache_used = True

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr):
        """
        Load the audio data from a file or bytes.
        """
        if len(reference_audio) > 255 or not Path(reference_audio).exists():
            audio_data = reference_audio
            reference_audio = io.BytesIO(audio_data)

        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)

        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        audio = waveform.squeeze().numpy()
        return audio
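load_by_id expects an on-disk layout like the following, with one plain-text transcript (.lab) next to each reference clip; the folder name is the reference_id sent in the request (the file names below are only an example):

    references/
        my_voice/
            sample_01.wav
            sample_01.lab   # transcript of sample_01.wav
            sample_02.wav
            sample_02.lab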
fish_speech/inference_engine/utils.py
ADDED
@@ -0,0 +1,29 @@
import io
import wave
from dataclasses import dataclass
from typing import Literal, Optional, Tuple

import numpy as np


@dataclass
class InferenceResult:
    code: Literal["header", "segment", "error", "final"]
    audio: Optional[Tuple[int, np.ndarray]]
    error: Optional[Exception]


def wav_chunk_header(
    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
) -> bytes:
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()

    return wav_header_bytes
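wav_chunk_header only writes the RIFF/fmt header with no frames, which is what the streaming path yields first; the PCM segments streamed afterwards can simply be appended to it. A small sketch of how a client might stitch a streamed response back into a file (the byte handling below is illustrative, not part of the repository):

    import numpy as np

    header = wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1)
    segments = [np.zeros(4410, dtype=np.float32)]  # stand-in for streamed chunks

    with open("streamed.wav", "wb") as f:
        f.write(header)
        for seg in segments:
            # Convert float samples to little-endian 16-bit PCM and append
            pcm = (np.clip(seg, -1.0, 1.0) * 32767).astype("<i2")
            f.write(pcm.tobytes())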
fish_speech/inference_engine/vq_manager.py
ADDED
@@ -0,0 +1,59 @@
from typing import Callable

import torch
from loguru import logger

from fish_speech.models.dac.modded_dac import DAC


class VQManager:

    def __init__(self):
        # Make Pylance happy (attribute/method not defined...)
        self.decoder_model: DAC
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        feature_lengths = torch.tensor(
            [codes.shape[1]], device=self.decoder_model.device
        )
        logger.info(f"VQ features: {codes.shape}")

        if isinstance(self.decoder_model, DAC):
            return self.decoder_model.decode(
                indices=codes[None],
                feature_lengths=feature_lengths,
            )[0].squeeze()

        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

    def encode_reference(self, reference_audio, enable_reference_audio):
        if enable_reference_audio and reference_audio is not None:
            # Load audios, and prepare basic info here
            if hasattr(self.decoder_model, "spec_transform"):
                sample_rate = self.decoder_model.spec_transform.sample_rate
            else:
                sample_rate = self.decoder_model.sample_rate
            reference_audio_content = self.load_audio(reference_audio, sample_rate)

            audios = torch.from_numpy(reference_audio_content).to(
                self.decoder_model.device
            )[None, None, :]
            audio_lengths = torch.tensor(
                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
            )
            logger.info(
                f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
            )

            # VQ Encoder
            if isinstance(self.decoder_model, DAC):
                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
            else:
                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
        else:
            prompt_tokens = None
            logger.info("No reference audio provided")

        return prompt_tokens
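VQManager is written as a mixin: the host class must supply decoder_model (a DAC instance) and load_audio (provided by ReferenceLoader in the engine). A rough codec round trip, assuming such a host object `engine` already exists and the example audio from this repository:

    # Encode a reference clip to discrete codes, then decode it back to audio
    codes = engine.encode_reference(
        reference_audio="examples/English.wav",
        enable_reference_audio=True,
    )  # shape: [n_codebooks, T]
    audio = engine.decode_vq_tokens(codes=codes)  # 1-D tensor of samples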
fish_speech/models/dac/__init__.py
ADDED
File without changes
|
fish_speech/models/dac/inference.py
ADDED
@@ -0,0 +1,123 @@
from pathlib import Path

import click
import hydra
import numpy as np
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from fish_speech.utils.file import AUDIO_EXTENSIONS

# register eval resolver
OmegaConf.register_new_resolver("eval", eval)


def load_model(config_name, checkpoint_path, device="cuda"):
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../configs"):
        cfg = compose(config_name=config_name)

    model = instantiate(cfg)
    state_dict = torch.load(
        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    if any("generator" in k for k in state_dict):
        state_dict = {
            k.replace("generator.", ""): v
            for k, v in state_dict.items()
            if "generator." in k
        }

    result = model.load_state_dict(state_dict, strict=False, assign=True)
    model.eval()
    model.to(device)

    logger.info(f"Loaded model: {result}")
    return model


@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="modded_dac_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/openaudio-s1-mini/codec.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        # Load audio
        audio, sr = torchaudio.load(str(input_path))
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(audio, sr, model.sample_rate)

        audios = audio[None].to(device)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
        indices, indices_lens = model.encode(audios, audio_lengths)

        if indices.ndim == 3:
            indices = indices[0]

        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
        indices_lens = torch.tensor([indices.shape[1]], device=device, dtype=torch.long)
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore
    fake_audios, audio_lengths = model.decode(indices, indices_lens)
    audio_time = fake_audios.shape[-1] / model.sample_rate

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

    # Save audio
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.sample_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    main()
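The script works either from a waveform (encode, dump indices, reconstruct) or from previously saved indices (decode only); typical invocations, with example paths, might look like:

    # Reconstruct a wav through the codec and also dump its indices to fake.npy
    python -m fish_speech.models.dac.inference -i test.wav -o fake.wav \
        --checkpoint-path checkpoints/openaudio-s1-mini/codec.pth

    # Decode previously saved indices back to audio
    python -m fish_speech.models.dac.inference -i fake.npy -o restored.wav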
fish_speech/models/dac/modded_dac.py
ADDED
@@ -0,0 +1,1024 @@
1 |
+
import math
|
2 |
+
import typing as tp
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Union
|
5 |
+
|
6 |
+
import hydra
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import soundfile as sf
|
10 |
+
import torch
|
11 |
+
from audiotools import AudioSignal
|
12 |
+
from audiotools.ml import BaseModel
|
13 |
+
from dac.model.base import CodecMixin
|
14 |
+
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from torch import Tensor, nn
|
17 |
+
from torch.nn import functional as F
|
18 |
+
from torch.nn.utils.parametrizations import weight_norm
|
19 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class VQResult:
|
24 |
+
z: torch.Tensor
|
25 |
+
codes: torch.Tensor
|
26 |
+
latents: torch.Tensor
|
27 |
+
codebook_loss: torch.Tensor
|
28 |
+
commitment_loss: torch.Tensor
|
29 |
+
semantic_distill_z: torch.Tensor | None = None
|
30 |
+
|
31 |
+
|
32 |
+
def find_multiple(n: int, k: int) -> int:
|
33 |
+
if n % k == 0:
|
34 |
+
return n
|
35 |
+
return n + k - (n % k)
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class ModelArgs:
|
40 |
+
block_size: int = 2048
|
41 |
+
n_layer: int = 8
|
42 |
+
n_head: int = 8
|
43 |
+
dim: int = 512
|
44 |
+
intermediate_size: int = 1536
|
45 |
+
n_local_heads: int = -1
|
46 |
+
head_dim: int = 64
|
47 |
+
rope_base: float = 10000
|
48 |
+
norm_eps: float = 1e-5
|
49 |
+
dropout_rate: float = 0.1
|
50 |
+
attn_dropout_rate: float = 0.1
|
51 |
+
channels_first: bool = True # to be compatible with conv1d input/output
|
52 |
+
pos_embed_type: str = "rope" # can be "rope" or "conformer"
|
53 |
+
max_relative_position: int = 128 # for conformer-style relative position embedding
|
54 |
+
|
55 |
+
def __post_init__(self):
|
56 |
+
if self.n_local_heads == -1:
|
57 |
+
self.n_local_heads = self.n_head
|
58 |
+
if self.intermediate_size is None:
|
59 |
+
hidden_dim = 4 * self.dim
|
60 |
+
n_hidden = int(2 * hidden_dim / 3)
|
61 |
+
self.intermediate_size = find_multiple(n_hidden, 256)
|
62 |
+
assert self.pos_embed_type in [
|
63 |
+
"rope",
|
64 |
+
"conformer",
|
65 |
+
], "pos_embed_type must be either 'rope' or 'conformer'"
|
66 |
+
|
67 |
+
|
68 |
+
class KVCache(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
74 |
+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
75 |
+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|
76 |
+
|
77 |
+
def update(self, input_pos, k_val, v_val):
|
78 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
79 |
+
assert input_pos.shape[0] == k_val.shape[2]
|
80 |
+
|
81 |
+
k_out = self.k_cache
|
82 |
+
v_out = self.v_cache
|
83 |
+
k_out[:, :, input_pos] = k_val
|
84 |
+
v_out[:, :, input_pos] = v_val
|
85 |
+
|
86 |
+
return (
|
87 |
+
k_out[:, :, : input_pos.max() + 1, :],
|
88 |
+
v_out[:, :, : input_pos.max() + 1, :],
|
89 |
+
)
|
90 |
+
|
91 |
+
def clear_cache(self, prompt_len):
|
92 |
+
self.k_cache[:, :, prompt_len:, :].fill_(0)
|
93 |
+
self.v_cache[:, :, prompt_len:, :].fill_(0)
|
94 |
+
|
95 |
+
|
96 |
+
class Transformer(nn.Module):
|
97 |
+
def __init__(self, config: ModelArgs) -> None:
|
98 |
+
super().__init__()
|
99 |
+
self.config = config
|
100 |
+
|
101 |
+
self.layers = nn.ModuleList(
|
102 |
+
TransformerBlock(config) for _ in range(config.n_layer)
|
103 |
+
)
|
104 |
+
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
105 |
+
|
106 |
+
# Only compute RoPE frequencies if using RoPE
|
107 |
+
if config.pos_embed_type == "rope":
|
108 |
+
freqs_cis = precompute_freqs_cis(
|
109 |
+
self.config.block_size, self.config.head_dim, self.config.rope_base
|
110 |
+
)
|
111 |
+
self.register_buffer("freqs_cis", freqs_cis)
|
112 |
+
else:
|
113 |
+
self.register_buffer("freqs_cis", None)
|
114 |
+
|
115 |
+
causal_mask = torch.tril(
|
116 |
+
torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
|
117 |
+
)
|
118 |
+
self.register_buffer("causal_mask", causal_mask)
|
119 |
+
|
120 |
+
self.max_batch_size = -1
|
121 |
+
self.max_seq_length = -1
|
122 |
+
self.use_kv_cache = False
|
123 |
+
|
124 |
+
def setup_caches(self, max_batch_size, max_seq_length):
|
125 |
+
"""
|
126 |
+
This method will only be called during inference when using KV cache.
|
127 |
+
"""
|
128 |
+
head_dim = self.config.dim // self.config.n_head
|
129 |
+
max_seq_length = find_multiple(max_seq_length, 8)
|
130 |
+
self.max_seq_length = max_seq_length
|
131 |
+
self.max_batch_size = max_batch_size
|
132 |
+
dtype = self.norm.weight.dtype
|
133 |
+
device = self.norm.weight.device
|
134 |
+
|
135 |
+
for b in self.layers:
|
136 |
+
b.attention.kv_cache = KVCache(
|
137 |
+
max_batch_size,
|
138 |
+
max_seq_length,
|
139 |
+
self.config.n_local_heads,
|
140 |
+
head_dim,
|
141 |
+
dtype,
|
142 |
+
).to(device)
|
143 |
+
|
144 |
+
self.use_kv_cache = True
|
145 |
+
|
146 |
+
def forward(
|
147 |
+
self,
|
148 |
+
x: Tensor,
|
149 |
+
input_pos: Optional[Tensor] = None,
|
150 |
+
mask: Optional[Tensor] = None,
|
151 |
+
) -> Tensor:
|
152 |
+
if self.config.pos_embed_type == "rope":
|
153 |
+
assert (
|
154 |
+
self.freqs_cis is not None
|
155 |
+
), "RoPE frequencies must be initialized for RoPE positional embedding"
|
156 |
+
freqs_cis = self.freqs_cis[input_pos]
|
157 |
+
else:
|
158 |
+
freqs_cis = None
|
159 |
+
|
160 |
+
if mask is None: # in case of non-causal model
|
161 |
+
if not self.training and self.use_kv_cache:
|
162 |
+
mask = self.causal_mask[None, None, input_pos]
|
163 |
+
mask = mask[..., : input_pos.max() + 1]
|
164 |
+
else:
|
165 |
+
mask = self.causal_mask[None, None, input_pos]
|
166 |
+
mask = mask[..., input_pos]
|
167 |
+
|
168 |
+
for i, layer in enumerate(self.layers):
|
169 |
+
x = layer(x, input_pos, freqs_cis, mask)
|
170 |
+
x = self.norm(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
class TransformerBlock(nn.Module):
|
175 |
+
def __init__(self, config: ModelArgs) -> None:
|
176 |
+
super().__init__()
|
177 |
+
self.attention = Attention(config)
|
178 |
+
self.feed_forward = FeedForward(config)
|
179 |
+
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
180 |
+
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
181 |
+
self.attention_layer_scale = LayerScale(config.dim, inplace=True)
|
182 |
+
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
183 |
+
|
184 |
+
def forward(
|
185 |
+
self,
|
186 |
+
x: Tensor,
|
187 |
+
input_pos: Tensor,
|
188 |
+
freqs_cis: Tensor,
|
189 |
+
mask: Tensor,
|
190 |
+
) -> Tensor:
|
191 |
+
h = x + self.attention_layer_scale(
|
192 |
+
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
193 |
+
)
|
194 |
+
out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
|
195 |
+
return out
|
196 |
+
|
197 |
+
|
198 |
+
class Attention(nn.Module):
|
199 |
+
def __init__(self, config: ModelArgs):
|
200 |
+
super().__init__()
|
201 |
+
assert config.dim % config.n_head == 0
|
202 |
+
|
203 |
+
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
204 |
+
# key, query, value projections for all heads, but in a batch
|
205 |
+
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
206 |
+
self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
|
207 |
+
self.kv_cache = None
|
208 |
+
|
209 |
+
self.n_head = config.n_head
|
210 |
+
self.head_dim = config.head_dim
|
211 |
+
self.n_local_heads = config.n_local_heads
|
212 |
+
self.dim = config.dim
|
213 |
+
self.attn_dropout_rate = config.attn_dropout_rate
|
214 |
+
self.pos_embed_type = config.pos_embed_type
|
215 |
+
|
216 |
+
# Add relative position embedding for conformer-style
|
217 |
+
if self.pos_embed_type == "conformer":
|
218 |
+
self.max_relative_position = config.max_relative_position
|
219 |
+
num_pos_embeddings = 2 * config.max_relative_position + 1
|
220 |
+
self.rel_pos_embeddings = nn.Parameter(
|
221 |
+
torch.zeros(num_pos_embeddings, self.head_dim)
|
222 |
+
)
|
223 |
+
nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
|
224 |
+
|
225 |
+
def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
|
226 |
+
# q: [B, H, S, D]
|
227 |
+
# Returns: [B, H, S, S]
|
228 |
+
positions = torch.arange(seqlen, device=q.device)
|
229 |
+
relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
|
230 |
+
relative_positions = torch.clamp(
|
231 |
+
relative_positions + self.max_relative_position,
|
232 |
+
0,
|
233 |
+
2 * self.max_relative_position,
|
234 |
+
)
|
235 |
+
rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
|
236 |
+
|
237 |
+
# Compute attention scores with relative position embeddings
|
238 |
+
q = q.transpose(1, 2) # [B, S, H, D]
|
239 |
+
rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
|
240 |
+
rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
|
241 |
+
return rel_logits
|
242 |
+
|
243 |
+
def forward(
|
244 |
+
self,
|
245 |
+
x: Tensor,
|
246 |
+
freqs_cis: Tensor,
|
247 |
+
mask: Tensor,
|
248 |
+
input_pos: Optional[Tensor] = None,
|
249 |
+
) -> Tensor:
|
250 |
+
bsz, seqlen, _ = x.shape
|
251 |
+
|
252 |
+
kv_size = self.n_local_heads * self.head_dim
|
253 |
+
q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
|
254 |
+
context_seqlen = seqlen
|
255 |
+
|
256 |
+
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
257 |
+
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
258 |
+
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
259 |
+
|
260 |
+
if self.pos_embed_type == "rope":
|
261 |
+
q = apply_rotary_emb(q, freqs_cis)
|
262 |
+
k = apply_rotary_emb(k, freqs_cis)
|
263 |
+
|
264 |
+
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
265 |
+
|
266 |
+
if self.kv_cache is not None:
|
267 |
+
k, v = self.kv_cache.update(input_pos, k, v)
|
268 |
+
|
269 |
+
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
270 |
+
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
271 |
+
|
272 |
+
if self.pos_embed_type == "conformer":
|
273 |
+
# Compute attention scores
|
274 |
+
scale = 1.0 / math.sqrt(self.head_dim)
|
275 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
|
276 |
+
|
277 |
+
# Add relative position embeddings for conformer-style
|
278 |
+
rel_scores = self._compute_conformer_pos_scores(q, seqlen)
|
279 |
+
scores = scores + rel_scores
|
280 |
+
|
281 |
+
# Apply attention
|
282 |
+
if mask is not None:
|
283 |
+
scores = scores.masked_fill(~mask, float("-inf"))
|
284 |
+
|
285 |
+
attn = F.softmax(scores, dim=-1)
|
286 |
+
if self.attn_dropout_rate > 0 and self.training:
|
287 |
+
attn = F.dropout(attn, p=self.attn_dropout_rate)
|
288 |
+
|
289 |
+
y = torch.matmul(attn, v)
|
290 |
+
else:
|
291 |
+
y = F.scaled_dot_product_attention(
|
292 |
+
q,
|
293 |
+
k,
|
294 |
+
v,
|
295 |
+
dropout_p=self.attn_dropout_rate if self.training else 0.0,
|
296 |
+
attn_mask=mask,
|
297 |
+
)
|
298 |
+
# is_causal=True)
|
299 |
+
y = (
|
300 |
+
y.transpose(1, 2)
|
301 |
+
.contiguous()
|
302 |
+
.view(bsz, seqlen, self.head_dim * self.n_head)
|
303 |
+
)
|
304 |
+
y = self.wo(y)
|
305 |
+
return y
|
306 |
+
|
307 |
+
|
308 |
+
class FeedForward(nn.Module):
|
309 |
+
def __init__(self, config: ModelArgs) -> None:
|
310 |
+
super().__init__()
|
311 |
+
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
312 |
+
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
313 |
+
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
|
314 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
315 |
+
|
316 |
+
def forward(self, x: Tensor) -> Tensor:
|
317 |
+
return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
|
318 |
+
|
319 |
+
|
320 |
+
class RMSNorm(nn.Module):
|
321 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
322 |
+
super().__init__()
|
323 |
+
self.eps = eps
|
324 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
325 |
+
|
326 |
+
def _norm(self, x):
|
327 |
+
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
328 |
+
|
329 |
+
def forward(self, x: Tensor) -> Tensor:
|
330 |
+
output = self._norm(x.float()).type_as(x)
|
331 |
+
return output * self.weight
|
332 |
+
|
333 |
+
|
334 |
+
class LayerScale(nn.Module):
|
335 |
+
def __init__(
|
336 |
+
self,
|
337 |
+
dim: int,
|
338 |
+
init_values: Union[float, Tensor] = 1e-2,
|
339 |
+
inplace: bool = False,
|
340 |
+
) -> None:
|
341 |
+
super().__init__()
|
342 |
+
self.inplace = inplace
|
343 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
344 |
+
|
345 |
+
def forward(self, x: Tensor) -> Tensor:
|
346 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
347 |
+
|
348 |
+
|
349 |
+
class WindowLimitedTransformer(Transformer):
|
350 |
+
"""
|
351 |
+
Transformer with window limited attention, causal.
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
config: ModelArgs,
|
357 |
+
input_dim: int = 512,
|
358 |
+
window_size: Optional[int] = None,
|
359 |
+
causal: bool = True,
|
360 |
+
look_ahead_conv: nn.Module = None,
|
361 |
+
):
|
362 |
+
super().__init__(config)
|
363 |
+
self.window_size = window_size
|
364 |
+
self.causal = causal
|
365 |
+
self.channels_first = config.channels_first
|
366 |
+
self.look_ahead_conv = (
|
367 |
+
look_ahead_conv if look_ahead_conv is not None else nn.Identity()
|
368 |
+
)
|
369 |
+
self.input_proj = (
|
370 |
+
nn.Linear(input_dim, config.dim)
|
371 |
+
if input_dim != config.dim
|
372 |
+
else nn.Identity()
|
373 |
+
)
|
374 |
+
self.output_proj = (
|
375 |
+
nn.Linear(config.dim, input_dim)
|
376 |
+
if input_dim != config.dim
|
377 |
+
else nn.Identity()
|
378 |
+
)
|
379 |
+
|
380 |
+
def make_window_limited_mask(
|
381 |
+
self,
|
382 |
+
max_length: int,
|
383 |
+
x_lens: Optional[Tensor] = None,
|
384 |
+
) -> Tensor:
|
385 |
+
"""
|
386 |
+
Make mask to form window limited attention.
|
387 |
+
"""
|
388 |
+
if self.causal:
|
389 |
+
mask = torch.tril(torch.ones(max_length, max_length))
|
390 |
+
row_indices = torch.arange(max_length).view(-1, 1)
|
391 |
+
window_size = self.window_size or max_length
|
392 |
+
valid_range = (row_indices - window_size + 1).clamp(min=0)
|
393 |
+
column_indices = torch.arange(max_length)
|
394 |
+
mask = (column_indices >= valid_range) & mask.bool()
|
395 |
+
else:
|
396 |
+
raise NotImplementedError
|
397 |
+
mask = mask.bool()[None, None]
|
398 |
+
return mask
|
399 |
+
|
400 |
+
def make_mask(
|
401 |
+
self,
|
402 |
+
max_length: int,
|
403 |
+
x_lens: Optional[Tensor] = None,
|
404 |
+
) -> Tensor:
|
405 |
+
"""
|
406 |
+
Make ordinary mask if window size is not specified.
|
407 |
+
"""
|
408 |
+
if self.causal:
|
409 |
+
mask = torch.tril(torch.ones(max_length, max_length))
|
410 |
+
else:
|
411 |
+
mask = torch.ones(max_length, max_length)
|
412 |
+
mask = mask.bool()[None, None]
|
413 |
+
for i, x_len in enumerate(x_lens):
|
414 |
+
mask[:x_len, i] = 0
|
415 |
+
mask = mask.bool()[None, None]
|
416 |
+
return mask
|
417 |
+
|
418 |
+
def forward(
|
419 |
+
self,
|
420 |
+
x: Tensor,
|
421 |
+
x_lens: Optional[Tensor] = None,
|
422 |
+
) -> Tensor:
|
423 |
+
if self.channels_first:
|
424 |
+
x = x.transpose(1, 2)
|
425 |
+
x = self.input_proj(x) # (B, T, D)
|
426 |
+
x = self.look_ahead_conv(x)
|
427 |
+
input_pos = torch.arange(x.shape[1], device=x.device)
|
428 |
+
# construct mask to form window limited attention
|
429 |
+
max_length = x.shape[1]
|
430 |
+
if self.window_size is not None:
|
431 |
+
mask = self.make_window_limited_mask(max_length, x_lens)
|
432 |
+
else:
|
433 |
+
mask = self.make_mask(max_length, x_lens)
|
434 |
+
mask = mask.to(x.device)
|
435 |
+
x = super().forward(x, input_pos, mask)
|
436 |
+
x = self.output_proj(x) # (B, T, D)
|
437 |
+
if self.channels_first:
|
438 |
+
x = x.transpose(1, 2)
|
439 |
+
return x
|
440 |
+
|
441 |
+
|
442 |
+
def precompute_freqs_cis(
|
443 |
+
seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
444 |
+
) -> Tensor:
|
445 |
+
freqs = 1.0 / (
|
446 |
+
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
447 |
+
)
|
448 |
+
t = torch.arange(seq_len, device=freqs.device)
|
449 |
+
freqs = torch.outer(t, freqs)
|
450 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
451 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
452 |
+
return cache.to(dtype=dtype)
|
453 |
+
|
454 |
+
|
455 |
+
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
456 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
457 |
+
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
|
458 |
+
x_out2 = torch.stack(
|
459 |
+
[
|
460 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
461 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
462 |
+
],
|
463 |
+
-1,
|
464 |
+
)
|
465 |
+
|
466 |
+
x_out2 = x_out2.flatten(3)
|
467 |
+
return x_out2.type_as(x)
|
468 |
+
|
469 |
+
|
470 |
+
def init_weights(m):
|
471 |
+
if isinstance(m, nn.Conv1d):
|
472 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
473 |
+
nn.init.constant_(m.bias, 0)
|
474 |
+
|
475 |
+
|
476 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
477 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
478 |
+
padding_left, padding_right = paddings
|
479 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
480 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
481 |
+
end = x.shape[-1] - padding_right
|
482 |
+
return x[..., padding_left:end]
|
483 |
+
|
484 |
+
|
485 |
+
def get_extra_padding_for_conv1d(
|
486 |
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
487 |
+
) -> int:
|
488 |
+
"""See `pad_for_conv1d`."""
|
489 |
+
length = x.shape[-1]
|
490 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
491 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
492 |
+
return ideal_length - length
|
493 |
+
|
494 |
+
|
495 |
+
def pad1d(
|
496 |
+
x: torch.Tensor,
|
497 |
+
paddings: tp.Tuple[int, int],
|
498 |
+
mode: str = "zeros",
|
499 |
+
value: float = 0.0,
|
500 |
+
):
|
501 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
502 |
+
If this is the case, we insert extra 0 padding to the right
|
503 |
+
before the reflection happen.
|
504 |
+
"""
|
505 |
+
length = x.shape[-1]
|
506 |
+
padding_left, padding_right = paddings
|
507 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
508 |
+
if mode == "reflect":
|
509 |
+
max_pad = max(padding_left, padding_right)
|
510 |
+
extra_pad = 0
|
511 |
+
if length <= max_pad:
|
512 |
+
extra_pad = max_pad - length + 1
|
513 |
+
x = F.pad(x, (0, extra_pad))
|
514 |
+
padded = F.pad(x, paddings, mode, value)
|
515 |
+
end = padded.shape[-1] - extra_pad
|
516 |
+
return padded[..., :end]
|
517 |
+
else:
|
518 |
+
return F.pad(x, paddings, mode, value)
|
519 |
+
|
520 |
+
|
521 |
+
class CausalConvNet(nn.Module):
|
522 |
+
def __init__(
|
523 |
+
self,
|
524 |
+
in_channels,
|
525 |
+
out_channels,
|
526 |
+
kernel_size,
|
527 |
+
dilation=1,
|
528 |
+
stride=1,
|
529 |
+
groups=1,
|
530 |
+
padding=None,
|
531 |
+
):
|
532 |
+
super(CausalConvNet, self).__init__()
|
533 |
+
self.conv = nn.Conv1d(
|
534 |
+
in_channels,
|
535 |
+
out_channels,
|
536 |
+
kernel_size,
|
537 |
+
stride=stride,
|
538 |
+
dilation=dilation,
|
539 |
+
groups=groups,
|
540 |
+
)
|
541 |
+
self.stride = stride
|
542 |
+
self.kernel_size = (kernel_size - 1) * dilation + 1
|
543 |
+
self.dilation = dilation
|
544 |
+
self.padding = self.kernel_size - self.stride
|
545 |
+
|
546 |
+
def forward(self, x):
|
547 |
+
pad = self.padding
|
548 |
+
extra_padding = get_extra_padding_for_conv1d(
|
549 |
+
x, self.kernel_size, self.stride, pad
|
550 |
+
)
|
551 |
+
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
|
552 |
+
return self.conv(x).contiguous()
|
553 |
+
|
554 |
+
def weight_norm(self, name="weight", dim=0):
|
555 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
556 |
+
return self
|
557 |
+
|
558 |
+
def remove_weight_norm(self):
|
559 |
+
self.conv = remove_parametrizations(self.conv, "weight")
|
560 |
+
return self
|
561 |
+
|
562 |
+
|
563 |
+
class CausalTransConvNet(nn.Module):
|
564 |
+
def __init__(
|
565 |
+
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
566 |
+
):
|
567 |
+
super(CausalTransConvNet, self).__init__()
|
568 |
+
self.conv = nn.ConvTranspose1d(
|
569 |
+
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
|
570 |
+
)
|
571 |
+
self.stride = stride
|
572 |
+
self.kernel_size = kernel_size
|
573 |
+
|
574 |
+
def forward(self, x):
|
575 |
+
x = self.conv(x)
|
576 |
+
pad = self.kernel_size - self.stride
|
577 |
+
padding_right = math.ceil(pad)
|
578 |
+
padding_left = pad - padding_right
|
579 |
+
x = unpad1d(x, (padding_left, padding_right))
|
580 |
+
return x.contiguous()
|
581 |
+
|
582 |
+
def weight_norm(self, name="weight", dim=0):
|
583 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
584 |
+
return self
|
585 |
+
|
586 |
+
def remove_weight_norm(self):
|
587 |
+
self.conv = remove_parametrizations(self.conv, "weight")
|
588 |
+
return self
|
589 |
+
|
590 |
+
|
591 |
+
def CausalWNConv1d(*args, **kwargs):
|
592 |
+
return CausalConvNet(*args, **kwargs).weight_norm()
|
593 |
+
|
594 |
+
|
595 |
+
def CausalWNConvTranspose1d(*args, **kwargs):
|
596 |
+
return CausalTransConvNet(*args, **kwargs).weight_norm()
|
597 |
+
|
598 |
+
|
599 |
+
class ResidualUnit(nn.Module):
|
600 |
+
def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
|
601 |
+
super().__init__()
|
602 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
603 |
+
pad = ((7 - 1) * dilation) // 2
|
604 |
+
self.block = nn.Sequential(
|
605 |
+
Snake1d(dim),
|
606 |
+
conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
607 |
+
Snake1d(dim),
|
608 |
+
conv_class(dim, dim, kernel_size=1),
|
609 |
+
)
|
610 |
+
self.causal = causal
|
611 |
+
|
612 |
+
def forward(self, x):
|
613 |
+
y = self.block(x)
|
614 |
+
pad = x.shape[-1] - y.shape[-1]
|
615 |
+
if pad > 0:
|
616 |
+
if self.causal:
|
617 |
+
x = x[..., :-pad]
|
618 |
+
else:
|
619 |
+
x = x[..., pad // 2 : -pad // 2]
|
620 |
+
return x + y
|
621 |
+
|
622 |
+
|
623 |
+
class EncoderBlock(nn.Module):
|
624 |
+
def __init__(
|
625 |
+
self,
|
626 |
+
dim: int = 16,
|
627 |
+
stride: int = 1,
|
628 |
+
causal: bool = False,
|
629 |
+
n_t_layer: int = 0,
|
630 |
+
transformer_general_config=None,
|
631 |
+
):
|
632 |
+
super().__init__()
|
633 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
634 |
+
transformer_module = (
|
635 |
+
nn.Identity()
|
636 |
+
if n_t_layer == 0
|
637 |
+
else (
|
638 |
+
WindowLimitedTransformer(
|
639 |
+
causal=causal,
|
640 |
+
input_dim=dim,
|
641 |
+
window_size=512,
|
642 |
+
config=transformer_general_config(
|
643 |
+
n_layer=n_t_layer,
|
644 |
+
n_head=dim // 64,
|
645 |
+
dim=dim,
|
646 |
+
intermediate_size=dim * 3,
|
647 |
+
),
|
648 |
+
)
|
649 |
+
)
|
650 |
+
)
|
651 |
+
self.block = nn.Sequential(
|
652 |
+
ResidualUnit(dim // 2, dilation=1, causal=causal),
|
653 |
+
ResidualUnit(dim // 2, dilation=3, causal=causal),
|
654 |
+
ResidualUnit(dim // 2, dilation=9, causal=causal),
|
655 |
+
Snake1d(dim // 2),
|
656 |
+
conv_class(
|
657 |
+
dim // 2,
|
658 |
+
dim,
|
659 |
+
kernel_size=2 * stride,
|
660 |
+
stride=stride,
|
661 |
+
padding=math.ceil(stride / 2),
|
662 |
+
),
|
663 |
+
transformer_module,
|
664 |
+
)
|
665 |
+
|
666 |
+
def forward(self, x):
|
667 |
+
return self.block(x)
|
668 |
+
|
669 |
+
|
670 |
+
class Encoder(nn.Module):
|
671 |
+
def __init__(
|
672 |
+
self,
|
673 |
+
d_model: int = 64,
|
674 |
+
strides: list = [2, 4, 8, 8],
|
675 |
+
d_latent: int = 64,
|
676 |
+
n_transformer_layers: list = [0, 0, 4, 4],
|
677 |
+
transformer_general_config: ModelArgs = None,
|
678 |
+
causal: bool = False,
|
679 |
+
):
|
680 |
+
super().__init__()
|
681 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
682 |
+
# Create first convolution
|
683 |
+
self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
|
684 |
+
|
685 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
686 |
+
for stride, n_t_layer in zip(strides, n_transformer_layers):
|
687 |
+
d_model *= 2
|
688 |
+
self.block += [
|
689 |
+
EncoderBlock(
|
690 |
+
d_model,
|
691 |
+
stride=stride,
|
692 |
+
causal=causal,
|
693 |
+
n_t_layer=n_t_layer,
|
694 |
+
transformer_general_config=transformer_general_config,
|
695 |
+
)
|
696 |
+
]
|
697 |
+
|
698 |
+
# Create last convolution
|
699 |
+
self.block += [
|
700 |
+
Snake1d(d_model),
|
701 |
+
conv_class(d_model, d_latent, kernel_size=3, padding=1),
|
702 |
+
]
|
703 |
+
|
704 |
+
# Wrap block into nn.Sequential
|
705 |
+
self.block = nn.Sequential(*self.block)
|
706 |
+
self.enc_dim = d_model
|
707 |
+
|
708 |
+
def forward(self, x):
|
709 |
+
return self.block(x)
|
710 |
+
|
711 |
+
|
712 |
+
class DecoderBlock(nn.Module):
|
713 |
+
def __init__(
|
714 |
+
self,
|
715 |
+
input_dim: int = 16,
|
716 |
+
output_dim: int = 8,
|
717 |
+
stride: int = 1,
|
718 |
+
causal: bool = False,
|
719 |
+
n_t_layer: int = 0,
|
720 |
+
transformer_general_config=None,
|
721 |
+
):
|
722 |
+
super().__init__()
|
723 |
+
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
724 |
+
transformer_module = (
|
725 |
+
nn.Identity()
|
726 |
+
if n_t_layer == 0
|
727 |
+
else (
|
728 |
+
WindowLimitedTransformer(
|
729 |
+
causal=causal,
|
730 |
+
input_dim=input_dim,
|
731 |
+
window_size=None,
|
732 |
+
config=transformer_general_config(
|
733 |
+
n_layer=n_t_layer,
|
734 |
+
n_head=input_dim // 64,
|
735 |
+
dim=input_dim,
|
736 |
+
intermediate_size=input_dim * 3,
|
737 |
+
),
|
738 |
+
)
|
739 |
+
)
|
740 |
+
)
|
741 |
+
self.block = nn.Sequential(
|
742 |
+
# transformer_module,
|
743 |
+
Snake1d(input_dim),
|
744 |
+
conv_trans_class(
|
745 |
+
input_dim,
|
746 |
+
output_dim,
|
747 |
+
kernel_size=2 * stride,
|
748 |
+
stride=stride,
|
749 |
+
padding=math.ceil(stride / 2),
|
750 |
+
),
|
751 |
+
ResidualUnit(output_dim, dilation=1, causal=causal),
|
752 |
+
ResidualUnit(output_dim, dilation=3, causal=causal),
|
753 |
+
ResidualUnit(output_dim, dilation=9, causal=causal),
|
754 |
+
)
|
755 |
+
|
756 |
+
def forward(self, x):
|
757 |
+
return self.block(x)
|
758 |
+
|
759 |
+
|
760 |
+
class Decoder(nn.Module):
|
761 |
+
def __init__(
|
762 |
+
self,
|
763 |
+
input_channel,
|
764 |
+
channels,
|
765 |
+
rates,
|
766 |
+
d_out: int = 1,
|
767 |
+
causal: bool = False,
|
768 |
+
n_transformer_layers: list = [0, 0, 0, 0],
|
769 |
+
transformer_general_config=None,
|
770 |
+
):
|
771 |
+
super().__init__()
|
772 |
+
conv_class = CausalWNConv1d if causal else WNConv1d
|
773 |
+
# Add first conv layer
|
774 |
+
layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
|
775 |
+
|
776 |
+
# Add upsampling + MRF blocks
|
777 |
+
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
778 |
+
input_dim = channels // 2**i
|
779 |
+
output_dim = channels // 2 ** (i + 1)
|
780 |
+
layers += [
|
781 |
+
DecoderBlock(
|
782 |
+
input_dim,
|
783 |
+
output_dim,
|
784 |
+
stride,
|
785 |
+
causal=causal,
|
786 |
+
n_t_layer=n_t_layer,
|
787 |
+
transformer_general_config=transformer_general_config,
|
788 |
+
)
|
789 |
+
]
|
790 |
+
|
791 |
+
# Add final conv layer
|
792 |
+
layers += [
|
793 |
+
Snake1d(output_dim),
|
794 |
+
conv_class(output_dim, d_out, kernel_size=7, padding=3),
|
795 |
+
nn.Tanh(),
|
796 |
+
]
|
797 |
+
|
798 |
+
self.model = nn.Sequential(*layers)
|
799 |
+
|
800 |
+
def forward(self, x):
|
801 |
+
return self.model(x)
|
802 |
+
|
803 |
+
|
804 |
+
class DAC(BaseModel, CodecMixin):
|
805 |
+
def __init__(
|
806 |
+
self,
|
807 |
+
encoder_dim: int = 64,
|
808 |
+
encoder_rates: List[int] = [2, 4, 8, 8],
|
809 |
+
latent_dim: int = None,
|
810 |
+
decoder_dim: int = 1536,
|
811 |
+
decoder_rates: List[int] = [8, 8, 4, 2],
|
812 |
+
quantizer: torch.nn.Module = None,
|
813 |
+
sample_rate: int = 44100,
|
814 |
+
causal: bool = True,
|
815 |
+
encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
816 |
+
decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
817 |
+
transformer_general_config=None,
|
818 |
+
):
|
819 |
+
super().__init__()
|
820 |
+
|
821 |
+
self.encoder_dim = encoder_dim
|
822 |
+
self.encoder_rates = encoder_rates
|
823 |
+
self.decoder_dim = decoder_dim
|
824 |
+
self.decoder_rates = decoder_rates
|
825 |
+
self.sample_rate = sample_rate
|
826 |
+
|
827 |
+
if latent_dim is None:
|
828 |
+
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
829 |
+
|
830 |
+
self.latent_dim = latent_dim
|
831 |
+
|
832 |
+
self.hop_length = np.prod(encoder_rates)
|
833 |
+
self.encoder = Encoder(
|
834 |
+
encoder_dim,
|
835 |
+
encoder_rates,
|
836 |
+
latent_dim,
|
837 |
+
causal=causal,
|
838 |
+
n_transformer_layers=encoder_transformer_layers,
|
839 |
+
transformer_general_config=transformer_general_config,
|
840 |
+
)
|
841 |
+
|
842 |
+
self.quantizer = quantizer
|
843 |
+
|
844 |
+
self.decoder = Decoder(
|
845 |
+
latent_dim,
|
846 |
+
decoder_dim,
|
847 |
+
decoder_rates,
|
848 |
+
causal=causal,
|
849 |
+
n_transformer_layers=decoder_transformer_layers,
|
850 |
+
transformer_general_config=transformer_general_config,
|
851 |
+
)
|
852 |
+
self.sample_rate = sample_rate
|
853 |
+
self.apply(init_weights)
|
854 |
+
|
855 |
+
self.delay = self.get_delay()
|
856 |
+
|
857 |
+
self.frame_length = self.hop_length * 4
|
858 |
+
|
859 |
+
def preprocess(self, audio_data, sample_rate):
|
860 |
+
if sample_rate is None:
|
861 |
+
sample_rate = self.sample_rate
|
862 |
+
assert sample_rate == self.sample_rate
|
863 |
+
|
864 |
+
length = audio_data.shape[-1]
|
865 |
+
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
|
866 |
+
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
867 |
+
|
868 |
+
return audio_data
|
869 |
+
|
870 |
+
def encode(
|
871 |
+
self,
|
872 |
+
audio_data: torch.Tensor,
|
873 |
+
audio_lengths: torch.Tensor = None,
|
874 |
+
n_quantizers: int = None,
|
875 |
+
**kwargs,
|
876 |
+
):
|
877 |
+
"""Encode given audio data and return quantized latent codes
|
878 |
+
|
879 |
+
Parameters
|
880 |
+
----------
|
881 |
+
audio_data : Tensor[B x T]
|
882 |
+
Audio data to encode
|
883 |
+
n_quantizers : int, optional
|
884 |
+
Number of quantizers to use, by default None
|
885 |
+
If None, all quantizers are used.
|
886 |
+
|
887 |
+
Returns
|
888 |
+
-------
|
889 |
+
tuple of (indices, indices_lens)
|
890 |
+
indices : Tensor[B x N x T]
|
891 |
+
Codebook indices for each codebook (quantized discrete representation of input)
|
892 |
+
indices_lens : Tensor[B]
|
893 |
+
Number of code frames for each item in the batch
|
905 |
+
"""
|
906 |
+
# pad to multiple of self.frame_length
|
907 |
+
if audio_data.ndim == 2:
|
908 |
+
audio_data = audio_data.unsqueeze(1)
|
909 |
+
# print(audio_data.shape)
|
910 |
+
length = audio_data.shape[-1]
|
911 |
+
right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
|
912 |
+
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
913 |
+
if audio_lengths is None:
|
914 |
+
audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
|
915 |
+
|
916 |
+
z = self.encoder(audio_data)
|
917 |
+
vq_results = self.quantizer(z, n_quantizers, **kwargs)
|
918 |
+
indices = vq_results.codes
|
919 |
+
indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
|
920 |
+
return indices, indices_lens
|
921 |
+
|
922 |
+
def decode(self, indices: torch.Tensor, feature_lengths):
|
923 |
+
if indices.ndim == 2:
|
924 |
+
indices = indices[None]
|
925 |
+
|
926 |
+
z = self.quantizer.decode(indices)
|
927 |
+
audio_lengths = feature_lengths * self.frame_length
|
928 |
+
return self.decoder(z), audio_lengths
|
929 |
+
|
930 |
+
def forward(
|
931 |
+
self,
|
932 |
+
audio_data: torch.Tensor,
|
933 |
+
template: torch.Tensor = None,
|
934 |
+
mask: torch.Tensor = None,
|
935 |
+
sample_rate: int = None,
|
936 |
+
n_quantizers: int = None,
|
937 |
+
**kwargs,
|
938 |
+
):
|
939 |
+
"""Model forward pass
|
940 |
+
|
941 |
+
Parameters
|
942 |
+
----------
|
943 |
+
audio_data : Tensor[B x 1 x T]
|
944 |
+
Audio data to encode
|
945 |
+
sample_rate : int, optional
|
946 |
+
Sample rate of audio data in Hz, by default None
|
947 |
+
If None, defaults to `self.sample_rate`
|
948 |
+
n_quantizers : int, optional
|
949 |
+
Number of quantizers to use, by default None.
|
950 |
+
If None, all quantizers are used.
|
951 |
+
|
952 |
+
Returns
|
953 |
+
-------
|
954 |
+
dict
|
955 |
+
A dictionary with the following keys:
|
956 |
+
"z" : Tensor[B x D x T]
|
957 |
+
Quantized continuous representation of input
|
958 |
+
"codes" : Tensor[B x N x T]
|
959 |
+
Codebook indices for each codebook
|
960 |
+
(quantized discrete representation of input)
|
961 |
+
"latents" : Tensor[B x N*D x T]
|
962 |
+
Projected latents (continuous representation of input before quantization)
|
963 |
+
"vq/commitment_loss" : Tensor[1]
|
964 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
965 |
+
entries
|
966 |
+
"vq/codebook_loss" : Tensor[1]
|
967 |
+
Codebook loss to update the codebook
|
968 |
+
"length" : int
|
969 |
+
Number of samples in input audio
|
970 |
+
"audio" : Tensor[B x 1 x length]
|
971 |
+
Decoded audio data.
|
972 |
+
"""
|
973 |
+
length = audio_data.shape[-1]
|
974 |
+
audio_data = self.preprocess(audio_data, sample_rate)
|
975 |
+
vq_results = self.encode(audio_data, n_quantizers=n_quantizers, **kwargs)
|
976 |
+
indices, indices_lens = vq_results  # encode returns (indices, indices_lens)
|
977 |
+
x, _ = self.decode(indices, indices_lens)
|
978 |
+
return x[..., :length], vq_results
|
979 |
+
|
980 |
+
|
981 |
+
if __name__ == "__main__":
|
982 |
+
|
983 |
+
def filter_state_dict_shapes(params, model):
|
984 |
+
model_state_dict = model.state_dict()
|
985 |
+
filtered_state_dict = {
|
986 |
+
k: v
|
987 |
+
for k, v in params.items()
|
988 |
+
if k in model_state_dict and v.shape == model_state_dict[k].shape
|
989 |
+
}
|
990 |
+
skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
|
991 |
+
if skipped_keys:
|
992 |
+
print(
|
993 |
+
f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
|
994 |
+
)
|
995 |
+
return filtered_state_dict, skipped_keys
|
996 |
+
|
997 |
+
model = hydra.utils.instantiate(
|
998 |
+
OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
|
999 |
+
)
|
1000 |
+
sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
|
1001 |
+
filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
|
1002 |
+
print(f"Skipped keys: {skipped_keys}")
|
1003 |
+
model.load_state_dict(filtered_sd, strict=False)
|
1004 |
+
model.eval()
|
1005 |
+
|
1006 |
+
src_audio_path = "./test.wav"
|
1007 |
+
wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
|
1008 |
+
if len(wave_np.shape) == 1:
|
1009 |
+
wave_np = wave_np[None, :]
|
1010 |
+
wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
|
1011 |
+
|
1012 |
+
with torch.no_grad():
|
1013 |
+
# encode returns (indices, indices_lens)
|
1014 |
+
indices, indices_lens = model.encode(wave_tensor)
|
1015 |
+
print(f"Indices shape: {indices.shape}")
|
1016 |
+
print(f"Indices lengths: {indices_lens}")
|
1017 |
+
|
1018 |
+
# decode requires both indices and feature_lengths
|
1019 |
+
fake_audio, audio_lengths = model.decode(indices, indices_lens)
|
1020 |
+
print(f"Decoded audio shape: {fake_audio.shape}")
|
1021 |
+
print(f"Audio lengths: {audio_lengths}")
|
1022 |
+
|
1023 |
+
# Save the reconstructed audio
|
1024 |
+
sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
|
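A minimal usage sketch for the DAC codec defined above, mirroring the file's own __main__ block; the input file name and variable names here are illustrative assumptions, not part of the upload.

import hydra
import librosa
import soundfile as sf
import torch
from omegaconf import OmegaConf

# Instantiate the codec from the shipped config and load the checkpoint
# (non-strict, since some training-only heads may be absent at inference time).
model = hydra.utils.instantiate(OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml"))
state = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth", map_location="cpu")
model.load_state_dict(state, strict=False)
model.eval()

# Round-trip a mono 44.1 kHz clip through encode/decode.
wav, _ = librosa.load("test.wav", sr=44100, mono=True)    # hypothetical input file
wav = torch.from_numpy(wav)[None, None, :]                # (B=1, C=1, T)
with torch.no_grad():
    indices, indices_lens = model.encode(wav)             # discrete codes, frame counts
    audio, audio_lens = model.decode(indices, indices_lens)
sf.write("roundtrip.wav", audio.squeeze().cpu().numpy(), 44100)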
fish_speech/models/dac/rvq.py
ADDED
@@ -0,0 +1,403 @@
1 |
+
import math
|
2 |
+
import typing as tp
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from dac.nn.quantize import ResidualVectorQuantize
|
9 |
+
from torch.nn.utils.parametrizations import weight_norm
|
10 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
11 |
+
|
12 |
+
|
13 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
14 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
15 |
+
padding_left, padding_right = paddings
|
16 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
17 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
18 |
+
end = x.shape[-1] - padding_right
|
19 |
+
return x[..., padding_left:end]
|
20 |
+
|
21 |
+
|
22 |
+
def get_extra_padding_for_conv1d(
|
23 |
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
24 |
+
) -> int:
|
25 |
+
"""See `pad_for_conv1d`."""
|
26 |
+
length = x.shape[-1]
|
27 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
28 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
29 |
+
return ideal_length - length
|
30 |
+
|
31 |
+
|
32 |
+
def pad1d(
|
33 |
+
x: torch.Tensor,
|
34 |
+
paddings: tp.Tuple[int, int],
|
35 |
+
mode: str = "zeros",
|
36 |
+
value: float = 0.0,
|
37 |
+
):
|
38 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
39 |
+
If this is the case, we insert extra 0 padding to the right
|
40 |
+
before the reflection happen.
|
41 |
+
"""
|
42 |
+
length = x.shape[-1]
|
43 |
+
padding_left, padding_right = paddings
|
44 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
45 |
+
if mode == "reflect":
|
46 |
+
max_pad = max(padding_left, padding_right)
|
47 |
+
extra_pad = 0
|
48 |
+
if length <= max_pad:
|
49 |
+
extra_pad = max_pad - length + 1
|
50 |
+
x = F.pad(x, (0, extra_pad))
|
51 |
+
padded = F.pad(x, paddings, mode, value)
|
52 |
+
end = padded.shape[-1] - extra_pad
|
53 |
+
return padded[..., :end]
|
54 |
+
else:
|
55 |
+
return F.pad(x, paddings, mode, value)
|
56 |
+
|
57 |
+
|
58 |
+
class CausalConvNet(nn.Module):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
in_channels,
|
62 |
+
out_channels,
|
63 |
+
kernel_size,
|
64 |
+
dilation=1,
|
65 |
+
stride=1,
|
66 |
+
groups=1,
|
67 |
+
padding=None,
|
68 |
+
):
|
69 |
+
super(CausalConvNet, self).__init__()
|
70 |
+
self.conv = nn.Conv1d(
|
71 |
+
in_channels,
|
72 |
+
out_channels,
|
73 |
+
kernel_size,
|
74 |
+
stride=stride,
|
75 |
+
dilation=dilation,
|
76 |
+
groups=groups,
|
77 |
+
)
|
78 |
+
self.stride = stride
|
79 |
+
self.kernel_size = (kernel_size - 1) * dilation + 1
|
80 |
+
self.dilation = dilation
|
81 |
+
self.padding = self.kernel_size - self.stride
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
pad = self.padding
|
85 |
+
extra_padding = get_extra_padding_for_conv1d(
|
86 |
+
x, self.kernel_size, self.stride, pad
|
87 |
+
)
|
88 |
+
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
|
89 |
+
return self.conv(x).contiguous()
|
90 |
+
|
91 |
+
def weight_norm(self, name="weight", dim=0):
|
92 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
93 |
+
return self
|
94 |
+
|
95 |
+
def remove_weight_norm(self):
|
96 |
+
self.conv = remove_parametrizations(self.conv, "weight")
|
97 |
+
return self
|
98 |
+
|
99 |
+
|
100 |
+
class CausalTransConvNet(nn.Module):
|
101 |
+
def __init__(
|
102 |
+
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
103 |
+
):
|
104 |
+
super(CausalTransConvNet, self).__init__()
|
105 |
+
self.conv = nn.ConvTranspose1d(
|
106 |
+
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
|
107 |
+
)
|
108 |
+
self.stride = stride
|
109 |
+
self.kernel_size = kernel_size
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
x = self.conv(x)
|
113 |
+
pad = self.kernel_size - self.stride
|
114 |
+
padding_right = math.ceil(pad)
|
115 |
+
padding_left = pad - padding_right
|
116 |
+
x = unpad1d(x, (padding_left, padding_right))
|
117 |
+
return x.contiguous()
|
118 |
+
|
119 |
+
def weight_norm(self, name="weight", dim=0):
|
120 |
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
121 |
+
return self
|
122 |
+
|
123 |
+
def remove_weight_norm(self):
|
124 |
+
self.conv = remove_parametrizations(self.conv, "weight")
|
125 |
+
return self
|
126 |
+
|
127 |
+
|
128 |
+
# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
|
129 |
+
class ConvNeXtBlock(nn.Module):
|
130 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
131 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
132 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
133 |
+
We use (2) as we find it slightly faster in PyTorch
|
134 |
+
Args:
|
135 |
+
dim (int): Number of input channels.
|
136 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
137 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
138 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
139 |
+
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
140 |
+
dilation (int): Dilation for depthwise conv. Default: 1.
|
141 |
+
""" # noqa: E501
|
142 |
+
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
dim: int,
|
146 |
+
layer_scale_init_value: float = 1e-6,
|
147 |
+
mlp_ratio: float = 4.0,
|
148 |
+
kernel_size: int = 7,
|
149 |
+
dilation: int = 1,
|
150 |
+
):
|
151 |
+
super().__init__()
|
152 |
+
convnet_type = CausalConvNet
|
153 |
+
self.dwconv = convnet_type(
|
154 |
+
dim,
|
155 |
+
dim,
|
156 |
+
kernel_size=kernel_size,
|
157 |
+
# padding=int(dilation * (kernel_size - 1) / 2),
|
158 |
+
groups=dim,
|
159 |
+
dilation=dilation,
|
160 |
+
) # depthwise conv
|
161 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
162 |
+
self.pwconv1 = nn.Linear(
|
163 |
+
dim, int(mlp_ratio * dim)
|
164 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
165 |
+
self.act = nn.GELU()
|
166 |
+
self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
|
167 |
+
self.gamma = (
|
168 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
169 |
+
if layer_scale_init_value > 0
|
170 |
+
else None
|
171 |
+
)
|
172 |
+
|
173 |
+
def forward(self, x, apply_residual: bool = True):
|
174 |
+
input = x
|
175 |
+
|
176 |
+
x = self.dwconv(x)
|
177 |
+
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
178 |
+
x = self.norm(x)
|
179 |
+
x = self.pwconv1(x)
|
180 |
+
x = self.act(x)
|
181 |
+
x = self.pwconv2(x)
|
182 |
+
|
183 |
+
if self.gamma is not None:
|
184 |
+
x = self.gamma * x
|
185 |
+
|
186 |
+
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
187 |
+
|
188 |
+
if apply_residual:
|
189 |
+
x = input + x
|
190 |
+
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
@dataclass
|
195 |
+
class VQResult:
|
196 |
+
z: torch.Tensor
|
197 |
+
codes: torch.Tensor
|
198 |
+
latents: torch.Tensor
|
199 |
+
codebook_loss: torch.Tensor
|
200 |
+
commitment_loss: torch.Tensor
|
201 |
+
semantic_distill_z: torch.Tensor | None = None
|
202 |
+
|
203 |
+
|
204 |
+
class DownsampleResidualVectorQuantize(nn.Module):
|
205 |
+
def __init__(
|
206 |
+
self,
|
207 |
+
input_dim: int = 1024,
|
208 |
+
n_codebooks: int = 9,
|
209 |
+
codebook_dim: int = 8,
|
210 |
+
quantizer_dropout: float = 0.5,
|
211 |
+
codebook_size: int = 1024,
|
212 |
+
semantic_codebook_size: int = 4096,
|
213 |
+
downsample_factor: tuple[int] = (2, 2),
|
214 |
+
downsample_dims: tuple[int] | None = None,
|
215 |
+
pre_module: nn.Module | None = None,
|
216 |
+
post_module: nn.Module | None = None,
|
217 |
+
semantic_predictor_module: nn.Module | None = None,
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
|
221 |
+
if downsample_dims is None:
|
222 |
+
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
|
223 |
+
|
224 |
+
all_dims = (input_dim,) + tuple(downsample_dims)
|
225 |
+
|
226 |
+
self.semantic_quantizer = ResidualVectorQuantize(
|
227 |
+
input_dim=input_dim,
|
228 |
+
n_codebooks=1,
|
229 |
+
codebook_size=semantic_codebook_size,
|
230 |
+
codebook_dim=codebook_dim,
|
231 |
+
quantizer_dropout=0.0,
|
232 |
+
)
|
233 |
+
|
234 |
+
self.quantizer = ResidualVectorQuantize(
|
235 |
+
input_dim=input_dim,
|
236 |
+
n_codebooks=n_codebooks,
|
237 |
+
codebook_size=codebook_size,
|
238 |
+
codebook_dim=codebook_dim,
|
239 |
+
quantizer_dropout=quantizer_dropout,
|
240 |
+
)
|
241 |
+
|
242 |
+
self.downsample_factor = downsample_factor
|
243 |
+
self.downsample_dims = downsample_dims
|
244 |
+
|
245 |
+
convnet_type = CausalConvNet
|
246 |
+
transconvnet_type = CausalTransConvNet
|
247 |
+
|
248 |
+
self.downsample = nn.Sequential(
|
249 |
+
*[
|
250 |
+
nn.Sequential(
|
251 |
+
convnet_type(
|
252 |
+
all_dims[idx],
|
253 |
+
all_dims[idx + 1],
|
254 |
+
kernel_size=factor,
|
255 |
+
stride=factor,
|
256 |
+
),
|
257 |
+
ConvNeXtBlock(dim=all_dims[idx + 1]),
|
258 |
+
)
|
259 |
+
for idx, factor in enumerate(downsample_factor)
|
260 |
+
]
|
261 |
+
)
|
262 |
+
|
263 |
+
self.upsample = nn.Sequential(
|
264 |
+
*[
|
265 |
+
nn.Sequential(
|
266 |
+
transconvnet_type(
|
267 |
+
all_dims[idx + 1],
|
268 |
+
all_dims[idx],
|
269 |
+
kernel_size=factor,
|
270 |
+
stride=factor,
|
271 |
+
),
|
272 |
+
ConvNeXtBlock(dim=all_dims[idx]),
|
273 |
+
)
|
274 |
+
for idx, factor in reversed(list(enumerate(downsample_factor)))
|
275 |
+
]
|
276 |
+
)
|
277 |
+
self.apply(self._init_weights)
|
278 |
+
self.pre_module = (
|
279 |
+
pre_module if pre_module is not None else nn.Identity()
|
280 |
+
) # leave for transformer, LSTM or Mamba or something else
|
281 |
+
self.post_module = post_module if post_module is not None else nn.Identity()
|
282 |
+
self.semantic_predictor_module = (
|
283 |
+
semantic_predictor_module
|
284 |
+
if semantic_predictor_module is not None
|
285 |
+
else nn.Identity()
|
286 |
+
)
|
287 |
+
|
288 |
+
def _init_weights(self, m):
|
289 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
290 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
291 |
+
nn.init.constant_(m.bias, 0)
|
292 |
+
|
293 |
+
def forward(
|
294 |
+
self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs
|
295 |
+
):
|
296 |
+
# z: (B, D, T)
|
297 |
+
original_shape = z.shape
|
298 |
+
if semantic_len is None:
|
299 |
+
semantic_len = torch.LongTensor([z.shape[-1]])
|
300 |
+
z = self.downsample(z)
|
301 |
+
z = self.pre_module(z) # B, T, D
|
302 |
+
(
|
303 |
+
semantic_z,
|
304 |
+
semantic_codes,
|
305 |
+
semantic_latents,
|
306 |
+
semantic_commitment_loss,
|
307 |
+
semantic_codebook_loss,
|
308 |
+
) = self.semantic_quantizer(z)
|
309 |
+
residual_z = z - semantic_z
|
310 |
+
residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
|
311 |
+
residual_z, n_quantizers=n_quantizers
|
312 |
+
)
|
313 |
+
z = semantic_z + residual_z
|
314 |
+
commitment_loss = commitment_loss + semantic_commitment_loss
|
315 |
+
codebook_loss = codebook_loss + semantic_codebook_loss
|
316 |
+
codes = torch.cat([semantic_codes, codes], dim=1)
|
317 |
+
latents = torch.cat([semantic_latents, latents], dim=1)
|
318 |
+
z = self.post_module(z)
|
319 |
+
z = self.upsample(z)
|
320 |
+
# z: (B, D, T)
|
321 |
+
|
322 |
+
# semantic distillation (disabled here since only used in training)
|
323 |
+
# semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
|
324 |
+
|
325 |
+
# Pad or crop z to match original shape
|
326 |
+
diff = original_shape[-1] - z.shape[-1]
|
327 |
+
right = 0
|
328 |
+
left = abs(diff) - right
|
329 |
+
|
330 |
+
if diff > 0:
|
331 |
+
z = F.pad(z, (left, right))
|
332 |
+
elif diff < 0:
|
333 |
+
z = z[..., left:]
|
334 |
+
|
335 |
+
results = VQResult(
|
336 |
+
z=z,
|
337 |
+
codes=codes,
|
338 |
+
latents=latents,
|
339 |
+
commitment_loss=commitment_loss,
|
340 |
+
codebook_loss=codebook_loss,
|
341 |
+
)
|
342 |
+
|
343 |
+
return results
|
344 |
+
|
345 |
+
# def encode(self, z):
|
346 |
+
# z = self.downsample(z)
|
347 |
+
# z = self.pre_module(z)
|
348 |
+
# _, indices, _, _, _ = self.quantizer(z.mT)
|
349 |
+
# indices = rearrange(indices, "g b l r -> b (g r) l")
|
350 |
+
# return indices
|
351 |
+
#
|
352 |
+
def decode(self, indices: torch.Tensor):
|
353 |
+
# indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
|
354 |
+
|
355 |
+
# print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}")
|
356 |
+
|
357 |
+
new_indices = torch.zeros_like(indices)
|
358 |
+
new_indices[:, 0] = torch.clamp(
|
359 |
+
indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
|
360 |
+
)
|
361 |
+
new_indices[:, 1:] = torch.clamp(
|
362 |
+
indices[:, 1:], max=self.quantizer.codebook_size - 1
|
363 |
+
)
|
364 |
+
|
365 |
+
z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
|
366 |
+
z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
|
367 |
+
z_q = z_q_semantic + z_q_residual
|
368 |
+
z_q = self.post_module(z_q)
|
369 |
+
z_q = self.upsample(z_q)
|
370 |
+
return z_q
|
371 |
+
|
372 |
+
# def from_latents(self, latents: torch.Tensor):
|
373 |
+
# z_q, z_p, codes = super().from_latents(latents)
|
374 |
+
# z_q = self.upsample(z_q)
|
375 |
+
# return z_q, z_p, codes
|
376 |
+
|
377 |
+
|
378 |
+
if __name__ == "__main__":
|
379 |
+
rvq = DownsampleResidualVectorQuantize(
|
380 |
+
input_dim=512,
|
381 |
+
n_codebooks=8,
|
382 |
+
codebook_dim=8,
|
383 |
+
codebook_size=1024,
|
384 |
+
quantizer_dropout=0.5,
|
385 |
+
downsample_factor=[2, 2],
|
386 |
+
)
|
387 |
+
rvq.eval()
|
388 |
+
x = torch.randn(2, 512, 442)
|
389 |
+
|
390 |
+
result = rvq(x)
|
391 |
+
print(rvq)
|
392 |
+
print(result.latents.shape, result.codes.shape, result.z.shape)
|
393 |
+
|
394 |
+
# y = rvq.from_codes(result.codes)
|
395 |
+
# print(y[0].shape)
|
396 |
+
|
397 |
+
# y = rvq.from_latents(
|
398 |
+
|
399 |
+
result1 = rvq(x[:, :, :40])
|
400 |
+
print(result1.latents.shape, result1.codes.shape, result1.z.shape)
|
401 |
+
|
402 |
+
assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
|
403 |
+
print("Success")
|
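A short sketch of how the downsampled residual VQ above is exercised, following the file's own __main__ test; the sizes and shapes are illustrative.

import torch

# One semantic codebook plus n_codebooks residual codebooks, operating at a
# 4x downsampled frame rate (downsample_factor = [2, 2]).
rvq = DownsampleResidualVectorQuantize(
    input_dim=512,
    n_codebooks=8,
    codebook_dim=8,
    codebook_size=1024,
    downsample_factor=[2, 2],
)
rvq.eval()

x = torch.randn(2, 512, 442)            # (batch, channels, frames)
with torch.no_grad():
    result = rvq(x)                     # VQResult: z, codes, latents, losses
    z_q = rvq.decode(result.codes)      # rebuild features from the discrete codes

print(result.codes.shape)               # (B, 1 + n_codebooks, T // 4)
print(z_q.shape)                        # (B, input_dim, ~T); forward() pads/crops to T exactly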
fish_speech/models/text2semantic/inference.py
ADDED
@@ -0,0 +1,716 @@
1 |
+
import os
|
2 |
+
import queue
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
from contextlib import nullcontext
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Literal, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import click
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch._dynamo.config
|
14 |
+
import torch._inductor.config
|
15 |
+
from loguru import logger
|
16 |
+
from tqdm import tqdm
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
|
19 |
+
from fish_speech.content_sequence import (
|
20 |
+
ContentSequence,
|
21 |
+
TextPart,
|
22 |
+
VQPart,
|
23 |
+
)
|
24 |
+
from fish_speech.models.text2semantic.llama import BaseModelArgs
|
25 |
+
from fish_speech.text import clean_text, split_text
|
26 |
+
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
|
27 |
+
|
28 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
29 |
+
torch._inductor.config.coordinate_descent_tuning = True
|
30 |
+
torch._inductor.config.triton.unique_kernel_names = True
|
31 |
+
|
32 |
+
if hasattr(torch._inductor.config, "fx_graph_cache"):
|
33 |
+
# Experimental feature to reduce compilation times, will be on by default in future
|
34 |
+
torch._inductor.config.fx_graph_cache = True
|
35 |
+
|
36 |
+
|
37 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
38 |
+
|
39 |
+
from fish_speech.models.text2semantic.llama import (
|
40 |
+
BaseTransformer,
|
41 |
+
DualARTransformer,
|
42 |
+
NaiveTransformer,
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def multinomial_sample_one_no_sync(
|
47 |
+
probs_sort,
|
48 |
+
): # Does multinomial sampling without a cuda synchronization
|
49 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
50 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
51 |
+
|
52 |
+
|
53 |
+
def logits_to_probs(
|
54 |
+
logits,
|
55 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
56 |
+
temperature: torch.Tensor = 1.0,
|
57 |
+
top_p: torch.Tensor = 1.0,
|
58 |
+
repetition_penalty: torch.Tensor = 1.0,
|
59 |
+
) -> torch.Tensor:
|
60 |
+
# Apply repetition penalty
|
61 |
+
if previous_tokens is not None:
|
62 |
+
previous_tokens = previous_tokens.long()
|
63 |
+
score = torch.gather(logits, dim=0, index=previous_tokens)
|
64 |
+
score = torch.where(
|
65 |
+
score < 0, score * repetition_penalty, score / repetition_penalty
|
66 |
+
)
|
67 |
+
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
68 |
+
|
69 |
+
# Apply top-p sampling
|
70 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
71 |
+
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
72 |
+
sorted_indices_to_remove = cum_probs > top_p
|
73 |
+
sorted_indices_to_remove[0] = False # keep at least one option
|
74 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
75 |
+
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
76 |
+
)
|
77 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
78 |
+
|
79 |
+
logits = logits / max(temperature, 1e-5)
|
80 |
+
|
81 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
82 |
+
return probs
|
83 |
+
|
84 |
+
|
85 |
+
def sample(
|
86 |
+
logits,
|
87 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
88 |
+
**sampling_kwargs,
|
89 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
90 |
+
probs = logits_to_probs(
|
91 |
+
logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
|
92 |
+
)
|
93 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
94 |
+
return idx_next, probs
|
95 |
+
|
96 |
+
|
97 |
+
def decode_one_token_ar(
|
98 |
+
model: DualARTransformer,
|
99 |
+
x: torch.Tensor,
|
100 |
+
input_pos: torch.Tensor,
|
101 |
+
semantic_ids: list,
|
102 |
+
previous_tokens: torch.Tensor = None,
|
103 |
+
**sampling_kwargs,
|
104 |
+
) -> torch.Tensor:
|
105 |
+
x = model.forward_generate(x, input_pos)
|
106 |
+
|
107 |
+
sampling_kwargs_main = sampling_kwargs.copy()
|
108 |
+
# sampling_kwargs_main["temperature"] = 0.1
|
109 |
+
# sampling_kwargs_main["top_p"] = 0.1
|
110 |
+
# sampling_kwargs_main["repetition_penalty"] = 1.0
|
111 |
+
|
112 |
+
codebooks = [
|
113 |
+
sample(
|
114 |
+
x.logits,
|
115 |
+
previous_tokens=(
|
116 |
+
previous_tokens[0] if previous_tokens is not None else None
|
117 |
+
), # Disable repetition penalty for the token codebook
|
118 |
+
**sampling_kwargs_main,
|
119 |
+
)[0]
|
120 |
+
]
|
121 |
+
|
122 |
+
hidden_states = x.hidden_states
|
123 |
+
|
124 |
+
# Cleanup the cache
|
125 |
+
for layer in model.fast_layers:
|
126 |
+
layer.attention.kv_cache.k_cache.fill_(0)
|
127 |
+
layer.attention.kv_cache.v_cache.fill_(0)
|
128 |
+
|
129 |
+
input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
|
130 |
+
model.forward_generate_fast(hidden_states, input_pos)
|
131 |
+
a = codebooks[0] - model.tokenizer.semantic_begin_id
|
132 |
+
a[a < 0] = 0
|
133 |
+
hidden_states = model.fast_embeddings(a)
|
134 |
+
codebooks.append(a)
|
135 |
+
|
136 |
+
for codebook_idx in range(1, model.config.num_codebooks):
|
137 |
+
input_pos = torch.tensor(
|
138 |
+
[codebook_idx], device=hidden_states.device, dtype=torch.long
|
139 |
+
)
|
140 |
+
logits = model.forward_generate_fast(hidden_states, input_pos)
|
141 |
+
chunked_logits = logits[..., :1024]
|
142 |
+
a = sample(
|
143 |
+
chunked_logits,
|
144 |
+
previous_tokens=(
|
145 |
+
previous_tokens[codebook_idx + 1]
|
146 |
+
if previous_tokens is not None
|
147 |
+
else None
|
148 |
+
),
|
149 |
+
**sampling_kwargs,
|
150 |
+
)[0]
|
151 |
+
hidden_states = model.fast_embeddings(a)
|
152 |
+
codebooks.append(a)
|
153 |
+
|
154 |
+
codebooks = torch.stack(codebooks, dim=0)
|
155 |
+
# semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
|
156 |
+
# codebooks[1:, :] = torch.masked_fill(
|
157 |
+
# codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
|
158 |
+
# )
|
159 |
+
|
160 |
+
# print(codebooks)
|
161 |
+
return codebooks
|
162 |
+
|
163 |
+
|
164 |
+
def decode_n_tokens(
|
165 |
+
model: NaiveTransformer,
|
166 |
+
cur_token: torch.Tensor,
|
167 |
+
input_pos: torch.Tensor,
|
168 |
+
num_new_tokens: int,
|
169 |
+
semantic_ids: list,
|
170 |
+
decode_one_token=decode_one_token_ar,
|
171 |
+
**sampling_kwargs,
|
172 |
+
):
|
173 |
+
previous_tokens = torch.zeros(
|
174 |
+
(model.config.num_codebooks + 1, model.config.max_seq_len),
|
175 |
+
dtype=torch.int,
|
176 |
+
device=cur_token.device,
|
177 |
+
)
|
178 |
+
|
179 |
+
for i in tqdm(range(num_new_tokens)):
|
180 |
+
# We need to get windowed repeat penalty
|
181 |
+
win_size = 16
|
182 |
+
if i < win_size:
|
183 |
+
window = previous_tokens[:, :win_size]
|
184 |
+
else:
|
185 |
+
window = previous_tokens[:, i - win_size : i]
|
186 |
+
|
187 |
+
with (
|
188 |
+
torch.backends.cuda.sdp_kernel(
|
189 |
+
enable_flash=False, enable_mem_efficient=False, enable_math=True
|
190 |
+
)
|
191 |
+
if torch.cuda.is_available()
|
192 |
+
else nullcontext()
|
193 |
+
): # Actually better for Inductor to codegen attention here
|
194 |
+
next_token = decode_one_token(
|
195 |
+
model=model,
|
196 |
+
x=cur_token,
|
197 |
+
input_pos=input_pos,
|
198 |
+
previous_tokens=window,
|
199 |
+
semantic_ids=semantic_ids,
|
200 |
+
**sampling_kwargs,
|
201 |
+
)
|
202 |
+
|
203 |
+
input_pos += 1
|
204 |
+
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
205 |
+
previous_tokens[:, i : i + 1] = next_token.view(
|
206 |
+
model.config.num_codebooks + 1, -1
|
207 |
+
)
|
208 |
+
|
209 |
+
if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
|
210 |
+
break
|
211 |
+
|
212 |
+
return previous_tokens[:, : i + 1]
|
213 |
+
|
214 |
+
|
215 |
+
@torch.no_grad()
|
216 |
+
@torch.inference_mode()
|
217 |
+
def generate(
|
218 |
+
*,
|
219 |
+
model: NaiveTransformer,
|
220 |
+
prompt: torch.Tensor,
|
221 |
+
max_new_tokens: int,
|
222 |
+
decode_one_token=decode_one_token_ar,
|
223 |
+
**sampling_kwargs,
|
224 |
+
) -> torch.Tensor:
|
225 |
+
"""
|
226 |
+
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
227 |
+
"""
|
228 |
+
|
229 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
230 |
+
T = prompt.size(1)
|
231 |
+
# semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
|
232 |
+
semantic_ids = [
|
233 |
+
model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
|
234 |
+
]
|
235 |
+
|
236 |
+
if max_new_tokens:
|
237 |
+
if T + max_new_tokens > model.config.max_seq_len:
|
238 |
+
max_new_tokens = model.config.max_seq_len - T
|
239 |
+
logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
|
240 |
+
|
241 |
+
T_new = T + max_new_tokens
|
242 |
+
else:
|
243 |
+
T_new = model.config.max_seq_len
|
244 |
+
max_new_tokens = T_new - T
|
245 |
+
|
246 |
+
device, dtype = prompt.device, prompt.dtype
|
247 |
+
|
248 |
+
codebook_dim = 1 + model.config.num_codebooks
|
249 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
250 |
+
empty = torch.empty(
|
251 |
+
(codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
|
252 |
+
)
|
253 |
+
empty[:, :T] = prompt
|
254 |
+
seq = empty
|
255 |
+
input_pos = torch.arange(0, T, device=device)
|
256 |
+
|
257 |
+
# Use non-accelerated version for now, to avoid compilation overhead
|
258 |
+
prefill_decode = decode_one_token_ar
|
259 |
+
|
260 |
+
next_token = prefill_decode(
|
261 |
+
model,
|
262 |
+
prompt.view(1, codebook_dim, -1),
|
263 |
+
input_pos,
|
264 |
+
semantic_ids=semantic_ids,
|
265 |
+
**sampling_kwargs,
|
266 |
+
)
|
267 |
+
seq[:, T : T + 1] = next_token
|
268 |
+
|
269 |
+
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
270 |
+
x = decode_n_tokens(
|
271 |
+
model,
|
272 |
+
next_token.view(1, codebook_dim, -1),
|
273 |
+
input_pos,
|
274 |
+
max_new_tokens - 1,
|
275 |
+
decode_one_token=decode_one_token,
|
276 |
+
semantic_ids=semantic_ids,
|
277 |
+
**sampling_kwargs,
|
278 |
+
)
|
279 |
+
# x = torch.cat(generated_tokens, dim=1)
|
280 |
+
seq = seq[:, : T + 1 + x.size(1)]
|
281 |
+
seq[:, T + 1 :] = x
|
282 |
+
|
283 |
+
return seq
|
284 |
+
|
285 |
+
|
286 |
+
def load_model(checkpoint_path, device, precision, compile=False):
|
287 |
+
model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
|
288 |
+
|
289 |
+
model = model.to(device=device, dtype=precision)
|
290 |
+
logger.info(f"Restored model from checkpoint")
|
291 |
+
|
292 |
+
if isinstance(model, DualARTransformer):
|
293 |
+
decode_one_token = decode_one_token_ar
|
294 |
+
logger.info("Using DualARTransformer")
|
295 |
+
else:
|
296 |
+
raise ValueError("Model is not a DualARTransformer")
|
297 |
+
|
298 |
+
if compile:
|
299 |
+
logger.info("Compiling function...")
|
300 |
+
decode_one_token = torch.compile(
|
301 |
+
decode_one_token,
|
302 |
+
fullgraph=True,
|
303 |
+
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
304 |
+
mode="reduce-overhead" if torch.cuda.is_available() else None,
|
305 |
+
)
|
306 |
+
|
307 |
+
return model.eval(), decode_one_token
|
308 |
+
|
309 |
+
|
310 |
+
@dataclass
|
311 |
+
class GenerateResponse:
|
312 |
+
action: Literal["sample", "next"]
|
313 |
+
codes: Optional[torch.Tensor] = None
|
314 |
+
text: Optional[str] = None
|
315 |
+
|
316 |
+
|
317 |
+
def generate_long(
|
318 |
+
*,
|
319 |
+
model,
|
320 |
+
device: str | torch.device,
|
321 |
+
decode_one_token: callable,
|
322 |
+
text: str,
|
323 |
+
num_samples: int = 1,
|
324 |
+
max_new_tokens: int = 0,
|
325 |
+
top_p: float = 0.8,
|
326 |
+
repetition_penalty: float = 1.1,
|
327 |
+
temperature: float = 0.8,
|
328 |
+
compile: bool = False,
|
329 |
+
iterative_prompt: bool = True,
|
330 |
+
chunk_length: int = 150,
|
331 |
+
prompt_text: Optional[str | list[str]] = None,
|
332 |
+
prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
|
333 |
+
):
|
334 |
+
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
335 |
+
assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
336 |
+
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
337 |
+
|
338 |
+
use_prompt = prompt_text is not None and prompt_tokens is not None
|
339 |
+
if use_prompt and isinstance(prompt_text, str):
|
340 |
+
prompt_text = [prompt_text]
|
341 |
+
prompt_tokens = [prompt_tokens]
|
342 |
+
|
343 |
+
assert use_prompt is False or len(prompt_text) == len(
|
344 |
+
prompt_tokens
|
345 |
+
), "Prompt text and tokens must have the same length"
|
346 |
+
|
347 |
+
prompt_tokens = [i.cpu() for i in prompt_tokens]
|
348 |
+
|
349 |
+
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
350 |
+
tokenizer = model.tokenizer
|
351 |
+
base_content_sequence = ContentSequence(modality="interleave")
|
352 |
+
|
353 |
+
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
354 |
+
max_length = model.config.max_seq_len
|
355 |
+
|
356 |
+
if use_prompt:
|
357 |
+
for t, c in zip(prompt_text, prompt_tokens):
|
358 |
+
base_content_sequence.append(
|
359 |
+
[
|
360 |
+
TextPart(text=t),
|
361 |
+
VQPart(codes=c),
|
362 |
+
],
|
363 |
+
add_end=True,
|
364 |
+
)
|
365 |
+
|
366 |
+
encoded_prompts = base_content_sequence.encode_for_inference(
|
367 |
+
tokenizer, num_codebooks=model.config.num_codebooks
|
368 |
+
)
|
369 |
+
if encoded_prompts.size(1) > max_length - 2048:
|
370 |
+
raise ValueError(
|
371 |
+
f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
|
372 |
+
)
|
373 |
+
|
374 |
+
encoded = []
|
375 |
+
for text in texts:
|
376 |
+
content_sequence = ContentSequence(modality=None)
|
377 |
+
content_sequence.append(TextPart(text=text))
|
378 |
+
encoded.append(
|
379 |
+
content_sequence.encode_for_inference(
|
380 |
+
tokenizer, num_codebooks=model.config.num_codebooks
|
381 |
+
)
|
382 |
+
)
|
383 |
+
logger.info(f"Encoded text: {text}")
|
384 |
+
|
385 |
+
# Move temperature, top_p, repetition_penalty to device
|
386 |
+
# This is important so that changing params doesn't trigger recompile
|
387 |
+
temperature = torch.tensor(temperature, device=device, dtype=torch.float)
|
388 |
+
top_p = torch.tensor(top_p, device=device, dtype=torch.float)
|
389 |
+
repetition_penalty = torch.tensor(
|
390 |
+
repetition_penalty, device=device, dtype=torch.float
|
391 |
+
)
|
392 |
+
|
393 |
+
for sample_idx in range(num_samples):
|
394 |
+
if torch.cuda.is_available():
|
395 |
+
torch.cuda.synchronize()
|
396 |
+
|
397 |
+
global_encoded = []
|
398 |
+
seg_idx = 0
|
399 |
+
|
400 |
+
while seg_idx < len(encoded):
|
401 |
+
logger.info(
|
402 |
+
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
403 |
+
)
|
404 |
+
|
405 |
+
seg = encoded[seg_idx]
|
406 |
+
global_encoded.append(seg)
|
407 |
+
|
408 |
+
# Do not use previous segments to generate current segment for now
|
409 |
+
# lengths = reversed([seg.size(1) for seg in global_encoded])
|
410 |
+
|
411 |
+
# # Pick last 2000 tokens
|
412 |
+
# count = 0
|
413 |
+
# for i, length in enumerate(lengths):
|
414 |
+
# count += length
|
415 |
+
# if count + length > max_length - 2048 - encoded_prompts.size(1):
|
416 |
+
# break
|
417 |
+
|
418 |
+
# if i != 0 and i % 2 == 0:
|
419 |
+
# i -= 1
|
420 |
+
|
421 |
+
# # Rotate the list, always make sure first segment is included to avoid drift
|
422 |
+
# if i < len(global_encoded) - 2:
|
423 |
+
# partial_encoded = global_encoded[:2] + global_encoded[-i:]
|
424 |
+
# else:
|
425 |
+
# partial_encoded = global_encoded
|
426 |
+
|
427 |
+
# cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
|
428 |
+
if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
|
429 |
+
cat_encoded = torch.cat(
|
430 |
+
[encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
|
434 |
+
|
435 |
+
cat_encoded = cat_encoded.to(device=device)
|
436 |
+
prompt_length = cat_encoded.size(1)
|
437 |
+
|
438 |
+
t0 = time.perf_counter()
|
439 |
+
y = generate(
|
440 |
+
model=model,
|
441 |
+
prompt=cat_encoded,
|
442 |
+
max_new_tokens=max_new_tokens,
|
443 |
+
decode_one_token=decode_one_token,
|
444 |
+
temperature=temperature,
|
445 |
+
top_p=top_p,
|
446 |
+
repetition_penalty=repetition_penalty,
|
447 |
+
)
|
448 |
+
|
449 |
+
if sample_idx == 0 and seg_idx == 0 and compile:
|
450 |
+
logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
451 |
+
|
452 |
+
if torch.cuda.is_available():
|
453 |
+
torch.cuda.synchronize()
|
454 |
+
|
455 |
+
t = time.perf_counter() - t0
|
456 |
+
|
457 |
+
tokens_generated = y.size(1) - prompt_length
|
458 |
+
tokens_sec = tokens_generated / t
|
459 |
+
logger.info(
|
460 |
+
f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
|
461 |
+
)
|
462 |
+
logger.info(
|
463 |
+
f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
|
464 |
+
)
|
465 |
+
|
466 |
+
if torch.cuda.is_available():
|
467 |
+
logger.info(
|
468 |
+
f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
469 |
+
)
|
470 |
+
|
471 |
+
# Put the generated tokens
|
472 |
+
# since there is <im_end>, we remove last token
|
473 |
+
codes = y[1:, prompt_length:-1].clone()
|
474 |
+
assert (codes >= 0).all(), f"Negative code found"
|
475 |
+
|
476 |
+
decoded = y[:, prompt_length:].clone()
|
477 |
+
# But for global encoding, we should keep the <im_end> token
|
478 |
+
|
479 |
+
global_encoded.append(decoded.cpu())
|
480 |
+
assert (codes >= 0).all(), f"Negative code found: {codes}"
|
481 |
+
yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
|
482 |
+
seg_idx += 1
|
483 |
+
|
484 |
+
# This indicates the end of the current sample
|
485 |
+
yield GenerateResponse(action="next")
|
486 |
+
|
487 |
+
|
488 |
+
@dataclass
|
489 |
+
class WrappedGenerateResponse:
|
490 |
+
status: Literal["success", "error"]
|
491 |
+
response: Optional[GenerateResponse | Exception] = None
|
492 |
+
|
493 |
+
|
494 |
+
@dataclass
|
495 |
+
class GenerateRequest:
|
496 |
+
request: dict
|
497 |
+
response_queue: queue.Queue
|
498 |
+
|
499 |
+
|
500 |
+
def launch_thread_safe_queue(
|
501 |
+
checkpoint_path,
|
502 |
+
device,
|
503 |
+
precision,
|
504 |
+
compile: bool = False,
|
505 |
+
):
|
506 |
+
input_queue = queue.Queue()
|
507 |
+
init_event = threading.Event()
|
508 |
+
|
509 |
+
def worker():
|
510 |
+
model, decode_one_token = load_model(
|
511 |
+
checkpoint_path, device, precision, compile=compile
|
512 |
+
)
|
513 |
+
with torch.device(device):
|
514 |
+
model.setup_caches(
|
515 |
+
max_batch_size=1,
|
516 |
+
max_seq_len=model.config.max_seq_len,
|
517 |
+
dtype=next(model.parameters()).dtype,
|
518 |
+
)
|
519 |
+
init_event.set()
|
520 |
+
|
521 |
+
while True:
|
522 |
+
item: GenerateRequest | None = input_queue.get()
|
523 |
+
if item is None:
|
524 |
+
break
|
525 |
+
|
526 |
+
kwargs = item.request
|
527 |
+
response_queue = item.response_queue
|
528 |
+
|
529 |
+
try:
|
530 |
+
for chunk in generate_long(
|
531 |
+
model=model, decode_one_token=decode_one_token, **kwargs
|
532 |
+
):
|
533 |
+
response_queue.put(
|
534 |
+
WrappedGenerateResponse(status="success", response=chunk)
|
535 |
+
)
|
536 |
+
except Exception as e:
|
537 |
+
response_queue.put(WrappedGenerateResponse(status="error", response=e))
|
538 |
+
|
539 |
+
threading.Thread(target=worker, daemon=True).start()
|
540 |
+
init_event.wait()
|
541 |
+
|
542 |
+
return input_queue
|
543 |
+
|
544 |
+
|
545 |
+
def launch_thread_safe_queue_agent(
|
546 |
+
checkpoint_path,
|
547 |
+
device,
|
548 |
+
precision,
|
549 |
+
compile: bool = False,
|
550 |
+
):
|
551 |
+
input_queue = queue.Queue()
|
552 |
+
init_event = threading.Event()
|
553 |
+
|
554 |
+
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
|
555 |
+
config = BaseModelArgs.from_pretrained(checkpoint_path)
|
556 |
+
|
557 |
+
def worker():
|
558 |
+
model, decode_one_token = load_model(
|
559 |
+
checkpoint_path, device, precision, compile=compile, is_agent=True
|
560 |
+
)
|
561 |
+
|
562 |
+
with torch.device(device):
|
563 |
+
model.setup_caches(
|
564 |
+
max_batch_size=1,
|
565 |
+
max_seq_len=model.config.max_seq_len,
|
566 |
+
dtype=next(model.parameters()).dtype,
|
567 |
+
)
|
568 |
+
init_event.set()
|
569 |
+
|
570 |
+
while True:
|
571 |
+
item: GenerateRequest | None = input_queue.get()
|
572 |
+
if item is None:
|
573 |
+
break
|
574 |
+
|
575 |
+
kwargs = item.request
|
576 |
+
response_queue = item.response_queue
|
577 |
+
|
578 |
+
try:
|
579 |
+
for token in generate_agent(
|
580 |
+
model=model,
|
581 |
+
decode_one_token=decode_one_token,
|
582 |
+
**kwargs,
|
583 |
+
):
|
584 |
+
response_queue.put(token)
|
585 |
+
|
586 |
+
response_queue.put("stop")
|
587 |
+
except Exception as e:
|
588 |
+
import traceback
|
589 |
+
|
590 |
+
logger.exception(f"Error in worker: {traceback.format_exc()}")
|
591 |
+
response_queue.put("error")
|
592 |
+
|
593 |
+
threading.Thread(target=worker, daemon=True).start()
|
594 |
+
init_event.wait()
|
595 |
+
|
596 |
+
return input_queue, tokenizer, config
|
597 |
+
|
598 |
+
|
599 |
+
@click.command()
|
600 |
+
@click.option(
|
601 |
+
"--text",
|
602 |
+
type=str,
|
603 |
+
default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
604 |
+
)
|
605 |
+
@click.option("--prompt-text", type=str, default=None, multiple=True)
|
606 |
+
@click.option(
|
607 |
+
"--prompt-tokens",
|
608 |
+
type=click.Path(path_type=Path, exists=True),
|
609 |
+
default=None,
|
610 |
+
multiple=True,
|
611 |
+
)
|
612 |
+
@click.option("--num-samples", type=int, default=1)
|
613 |
+
@click.option("--max-new-tokens", type=int, default=0)
|
614 |
+
@click.option("--top-p", type=float, default=0.8)
|
615 |
+
@click.option("--repetition-penalty", type=float, default=1.1)
|
616 |
+
@click.option("--temperature", type=float, default=0.8)
|
617 |
+
@click.option(
|
618 |
+
"--checkpoint-path",
|
619 |
+
type=click.Path(path_type=Path, exists=True),
|
620 |
+
default="checkpoints/openaudio-s1-mini",
|
621 |
+
)
|
622 |
+
@click.option("--device", type=str, default="cuda")
|
623 |
+
@click.option("--compile/--no-compile", default=False)
|
624 |
+
@click.option("--seed", type=int, default=42)
|
625 |
+
@click.option("--half/--no-half", default=False)
|
626 |
+
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
627 |
+
@click.option("--chunk-length", type=int, default=300)
|
628 |
+
@click.option("--output-dir", type=Path, default="temp")
|
629 |
+
def main(
|
630 |
+
text: str,
|
631 |
+
prompt_text: Optional[list[str]],
|
632 |
+
prompt_tokens: Optional[list[Path]],
|
633 |
+
num_samples: int,
|
634 |
+
max_new_tokens: int,
|
635 |
+
top_p: float,
|
636 |
+
repetition_penalty: float,
|
637 |
+
temperature: float,
|
638 |
+
checkpoint_path: Path,
|
639 |
+
device: str,
|
640 |
+
compile: bool,
|
641 |
+
seed: int,
|
642 |
+
half: bool,
|
643 |
+
iterative_prompt: bool,
|
644 |
+
chunk_length: int,
|
645 |
+
output_dir: Path,
|
646 |
+
) -> None:
|
647 |
+
os.makedirs(output_dir, exist_ok=True)
|
648 |
+
precision = torch.half if half else torch.bfloat16
|
649 |
+
|
650 |
+
if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
|
651 |
+
raise ValueError(
|
652 |
+
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
653 |
+
)
|
654 |
+
|
655 |
+
logger.info("Loading model ...")
|
656 |
+
t0 = time.time()
|
657 |
+
model, decode_one_token = load_model(
|
658 |
+
checkpoint_path, device, precision, compile=compile
|
659 |
+
)
|
660 |
+
with torch.device(device):
|
661 |
+
model.setup_caches(
|
662 |
+
max_batch_size=1,
|
663 |
+
max_seq_len=model.config.max_seq_len,
|
664 |
+
dtype=next(model.parameters()).dtype,
|
665 |
+
)
|
666 |
+
if torch.cuda.is_available():
|
667 |
+
torch.cuda.synchronize()
|
668 |
+
|
669 |
+
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
670 |
+
|
671 |
+
if prompt_tokens is not None:
|
672 |
+
prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
|
673 |
+
|
674 |
+
torch.manual_seed(seed)
|
675 |
+
|
676 |
+
if torch.cuda.is_available():
|
677 |
+
torch.cuda.manual_seed(seed)
|
678 |
+
|
679 |
+
generator = generate_long(
|
680 |
+
model=model,
|
681 |
+
device=device,
|
682 |
+
decode_one_token=decode_one_token,
|
683 |
+
text=text,
|
684 |
+
num_samples=num_samples,
|
685 |
+
max_new_tokens=max_new_tokens,
|
686 |
+
top_p=top_p,
|
687 |
+
repetition_penalty=repetition_penalty,
|
688 |
+
temperature=temperature,
|
689 |
+
compile=compile,
|
690 |
+
iterative_prompt=iterative_prompt,
|
691 |
+
chunk_length=chunk_length,
|
692 |
+
prompt_text=prompt_text,
|
693 |
+
prompt_tokens=prompt_tokens,
|
694 |
+
)
|
695 |
+
|
696 |
+
idx = 0
|
697 |
+
codes = []
|
698 |
+
|
699 |
+
for response in generator:
|
700 |
+
if response.action == "sample":
|
701 |
+
codes.append(response.codes)
|
702 |
+
logger.info(f"Sampled text: {response.text}")
|
703 |
+
elif response.action == "next":
|
704 |
+
if codes:
|
705 |
+
codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
|
706 |
+
np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
|
707 |
+
logger.info(f"Saved codes to {codes_npy_path}")
|
708 |
+
logger.info(f"Next sample")
|
709 |
+
codes = []
|
710 |
+
idx += 1
|
711 |
+
else:
|
712 |
+
logger.error(f"Error: {response}")
|
713 |
+
|
714 |
+
|
715 |
+
if __name__ == "__main__":
|
716 |
+
main()
|
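For context, a hedged sketch of how a caller can drive the queue-based interface above (roughly what the Gradio app does); the request fields are simply the keyword arguments of generate_long, and the checkpoint path and text are assumptions.

import queue
import torch
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)

# Start the worker thread; it loads the model once and serves requests serially.
input_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/openaudio-s1-mini",
    device="cuda",
    precision=torch.bfloat16,
    compile=False,
)

response_queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request=dict(
            device="cuda",
            text="Hello from OpenAudio S1.",
            max_new_tokens=1024,
            top_p=0.8,
            repetition_penalty=1.1,
            temperature=0.8,
            chunk_length=300,
            # prompt_text / prompt_tokens can be added here for voice cloning.
        ),
        response_queue=response_queue,
    )
)

codes = []
while True:
    wrapped: WrappedGenerateResponse = response_queue.get()
    if wrapped.status == "error":
        raise RuntimeError(wrapped.response)
    if wrapped.response.action == "next":    # end of the current sample
        break
    codes.append(wrapped.response.codes)     # (num_codebooks, T) codes for one text chunk

# The concatenated codes are then fed to the DAC decoder above to produce audio.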
fish_speech/models/text2semantic/lit_module.py
CHANGED
@@ -1,202 +1,202 @@
|
|
1 |
-
from typing import Any, Optional
|
2 |
-
|
3 |
-
import lightning as L
|
4 |
-
import torch
|
5 |
-
import torch.nn.functional as F
|
6 |
-
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
7 |
-
|
8 |
-
import fish_speech.utils as utils
|
9 |
-
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
10 |
-
from fish_speech.models.text2semantic.llama import NaiveTransformer
|
11 |
-
|
12 |
-
log = utils.RankedLogger(__name__, rank_zero_only=True)
|
13 |
-
|
14 |
-
|
15 |
-
class TextToSemantic(L.LightningModule):
|
16 |
-
def __init__(
|
17 |
-
self,
|
18 |
-
model: NaiveTransformer,
|
19 |
-
optimizer: Any,
|
20 |
-
lr_scheduler: Any,
|
21 |
-
):
|
22 |
-
super().__init__()
|
23 |
-
|
24 |
-
self.model = model
|
25 |
-
self.optimizer_builder = optimizer
|
26 |
-
self.lr_scheduler_builder = lr_scheduler
|
27 |
-
|
28 |
-
def forward(self, x):
|
29 |
-
return self.model(x)
|
30 |
-
|
31 |
-
def on_save_checkpoint(self, checkpoint):
|
32 |
-
# Save only LoRA parameters
|
33 |
-
state_dict = checkpoint["state_dict"]
|
34 |
-
use_lora = any("lora" in name for name in state_dict.keys())
|
35 |
-
if not use_lora:
|
36 |
-
return
|
37 |
-
|
38 |
-
for name in list(state_dict.keys()):
|
39 |
-
if "lora" not in name:
|
40 |
-
state_dict.pop(name)
|
41 |
-
|
42 |
-
def configure_optimizers(self) -> OptimizerLRScheduler:
|
43 |
-
# Get weight decay parameters
|
44 |
-
weight_decay_parameters, other_parameters = [], []
|
45 |
-
for name, param in self.named_parameters():
|
46 |
-
if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
|
47 |
-
other_parameters.append(param)
|
48 |
-
else:
|
49 |
-
weight_decay_parameters.append(param)
|
50 |
-
|
51 |
-
optimizer = self.optimizer_builder(
|
52 |
-
[
|
53 |
-
{"params": weight_decay_parameters},
|
54 |
-
{"params": other_parameters, "weight_decay": 0.0},
|
55 |
-
]
|
56 |
-
)
|
57 |
-
|
58 |
-
# Print the parameters and their weight decay
|
59 |
-
for i in optimizer.param_groups:
|
60 |
-
log.info(
|
61 |
-
f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
|
62 |
-
)
|
63 |
-
|
64 |
-
lr_scheduler = self.lr_scheduler_builder(optimizer)
|
65 |
-
|
66 |
-
return {
|
67 |
-
"optimizer": optimizer,
|
68 |
-
"lr_scheduler": {
|
69 |
-
"scheduler": lr_scheduler,
|
70 |
-
"interval": "step",
|
71 |
-
},
|
72 |
-
}
|
73 |
-
|
74 |
-
# Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
|
75 |
-
def get_batch_logps(
|
76 |
-
self,
|
77 |
-
logits: torch.FloatTensor,
|
78 |
-
labels: torch.LongTensor,
|
79 |
-
average_log_prob: bool = False,
|
80 |
-
) -> torch.FloatTensor:
|
81 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
82 |
-
|
83 |
-
Args:
|
84 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
|
85 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
|
86 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
87 |
-
|
88 |
-
Returns:
|
89 |
-
from typing import Any, Optional

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


class TextToSemantic(L.LightningModule):
    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return

        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Get weight decay parameters
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        return loss

    def get_accuracy(self, logits, labels):
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
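The masked log-probability trick used by get_batch_logps above (clone the labels, zero the -100 entries, gather, then mask) can be sanity-checked in isolation. The sketch below is not part of the upload; it only assumes torch and uses dummy tensor shapes.

import torch

logits = torch.randn(2, 5, 8)          # (batch, seq_len, vocab), dummy values
labels = torch.randint(0, 8, (2, 5))    # (batch, seq_len)
labels[:, 0] = -100                     # positions marked -100 are ignored

loss_mask = labels != -100
safe_labels = labels.clone()
safe_labels[safe_labels == -100] = 0    # dummy index, masked out below

per_token_logps = torch.gather(
    logits.log_softmax(-1), dim=-1, index=safe_labels.unsqueeze(-1)
).squeeze(-1)

sum_logps = (per_token_logps * loss_mask).sum(-1)   # (batch,), average_log_prob=False
avg_logps = sum_logps / loss_mask.sum(-1)           # (batch,), average_log_prob=True
print(sum_logps.shape, avg_logps.shape)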
fish_speech/models/text2semantic/llama.py
CHANGED
@@ -1,887 +1,903 @@
import dataclasses
import json
import math
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class BaseModelArgs:
    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False
    attention_o_bias: bool = False
    attention_qk_norm: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    # Dummy vars
    is_reward_model: bool = False
    scale_codebook_embeddings: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)


@dataclass
class NaiveModelArgs(BaseModelArgs):
    model_type: str = "naive"


@dataclass
class DualARModelArgs(BaseModelArgs):
    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None
    fast_attention_qk_norm: bool | None = None
    fast_attention_o_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )
        self.fast_attention_qk_norm = (
            self.fast_attention_qk_norm
            if self.fast_attention_qk_norm is not None
            else self.attention_qk_norm
        )
        self.fast_attention_o_bias = (
            self.fast_attention_o_bias
            if self.fast_attention_o_bias is not None
            else self.attention_o_bias
        )


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out


@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


@dataclass
class BaseTransformerForwardResult:
    logits: Tensor
    hidden_states: Tensor


class BaseTransformer(nn.Module):
    def __init__(
        self,
        config: BaseModelArgs,
        tokenizer: FishTokenizer,
        init_weights: bool = True,
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.semantic_token_ids = list(tokenizer.semantic_id_to_token_id.values())

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.dim,
        )
        self.codebook_embeddings = nn.Embedding(
            config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(
                config.dim,
                config.vocab_size,
                bias=False,
            )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.head_dim,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                self.config.head_dim,
                dtype=dtype,
            )

    def embed(self, inp: Tensor) -> Tensor:
        embeds = []
        semantic_token_ids_tensor = torch.tensor(
            self.semantic_token_ids, device=inp.device, dtype=inp.dtype
        )

        for i in range(self.config.num_codebooks):
            emb = self.codebook_embeddings(
                inp[:, i + 1] + i * self.config.codebook_size
            )
            embeds.append(emb)

        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
        vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0
        x = self.embeddings(inp[:, 0]) + vq_embeds_sum

        return x

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        freqs_cis = self.freqs_cis[:seq_len]

        # Not that the causal mask here follows the definition of scaled_dot_product_attention
        # That is, FALSE means masked out
        # To maintain consistency, key_padding_mask use TRUE to mask out
        mask = None
        if key_padding_mask is not None:
            causal = self.causal_mask[:seq_len, :seq_len]
            causal = rearrange(causal, "q k -> 1 1 q k")

            atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
            atten_mask = atten_mask.logical_not()
            mask = causal & atten_mask

        # return freqs_cis, mask

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def forward_generate(
        self,
        inp: Tensor,
        input_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        x = self.embed(inp)

        if input_pos is None:
            input_pos = torch.arange(inp.shape[-1], device=x.device)
            max_seq_len = inp.shape[-1]
        else:
            max_seq_len = self.max_seq_len

        mask = self.causal_mask[None, None, input_pos, :max_seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefill, we only calculate the logits of last token
        if x.size(1) > 1 and not return_all:
            x = x[:, -1:]

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.is_reward_model:
            token_logits = self.score_output(slow_out)
        elif self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @staticmethod
    def from_pretrained(
        path: str,
        load_weights: bool = False,
        max_length: int | None = None,
        lora_config: LoraConfig | None = None,
        rope_base: int | None = None,
    ) -> "BaseTransformer":
        config = BaseModelArgs.from_pretrained(str(path))
        if max_length is not None:
            config.max_seq_len = max_length
            logger.info(f"Override max_seq_len to {max_length}")

        if rope_base is not None:
            config.rope_base = rope_base
            logger.info(f"Override rope_base to {rope_base}")

        match config.model_type:
            case "naive":
                model_cls = NaiveTransformer
            case "dual_ar":
                model_cls = DualARTransformer
            case _:
                raise ValueError(f"Unknown model type: {config.model_type}")

        tokenizer = FishTokenizer.from_pretrained(path)

        logger.info(f"Loading model from {path}, config: {config}")
        model = model_cls(config, tokenizer=tokenizer)

        if lora_config is not None:
            setup_lora(model, lora_config)
            logger.info(f"LoRA setup: {lora_config}")

        if load_weights is False:
            logger.info("Randomly initialized model")
        else:

            if "int8" in str(Path(path)):
                logger.info("Using int8 weight-only quantization!")
                from tools.llama.quantize import WeightOnlyInt8QuantHandler

                simple_quantizer = WeightOnlyInt8QuantHandler(model)
                model = simple_quantizer.convert_for_runtime()

            if "int4" in str(Path(path)):
                logger.info("Using int4 quantization!")
                path_comps = path.name.split("-")
                assert path_comps[-2].startswith("g")
                groupsize = int(path_comps[-2][1:])
                from tools.llama.quantize import WeightOnlyInt4QuantHandler

                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
                model = simple_quantizer.convert_for_runtime()

            weights = torch.load(
                Path(path) / "model.pth",
                map_location="cpu",
                mmap=True,
                weights_only=True,
            )

            if "state_dict" in weights:
                logger.warning(
                    "Using a TextToSemantic LightningModule checkpoint, "
                    "please make sure it is a full model, not a LoRA model."
                )
                weights = weights["state_dict"]

            if next(iter(weights.keys())).startswith("model."):
                logger.info(
                    f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
                )
                new_weights = OrderedDict()
                for k, v in weights.items():
                    new_weights[k.replace("model.", "")] = v
                weights = new_weights

            # Remove audio related weights
            for k in list(weights.keys()):
                if "audio_" in k:
                    weights.pop(k)

            # Verify the name and shape of parameters since strict=False in load_state_dict.
            for k, v in model.named_parameters():
                if k not in weights:
                    logger.warning(f"No weight for {k}")
                elif v.shape != weights[k].shape:
                    logger.warning(
                        f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
                    )

            err = model.load_state_dict(weights, strict=False, assign=True)
            logger.info(f"Loaded weights with error: {err}")

        return model

    def save_pretrained(self, path: str, drop_lora: bool = False):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save(path / "config.json")
        state_dict = self.state_dict()

        if drop_lora:
            for key in list(state_dict.keys()):
                if "lora" not in key:
                    continue

                state_dict.pop(key)
                logger.info(f"Drop LoRA parameter: {key}")

        torch.save(state_dict, path / "model.pth")
        self.tokenizer.save_pretrained(path)


class NaiveTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward_generate(x, input_pos)
        return self.decode(result)


class DualARTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
            attention_qk_norm=config.fast_attention_qk_norm,
            attention_o_bias=config.fast_attention_o_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_head_dim,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                self.config.fast_head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        x = super().forward_generate(x, input_pos, vq_masks)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(
            config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
        )
        self.kv_cache = None

        if config.attention_qk_norm:
            self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
            self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self.attention_qk_norm = config.attention_qk_norm
        self.config = config

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.attention_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value


class FeedForward(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """
    Precomputes frequency tensors for complex exponentials (cis)

    Args:
        seq_len: Length of the sequence for which positional embeddings are needed.
        n_elem: Number of elements in the frequency tensor.
        base: Base value for the frequency scaling (default: 10000).

    Returns:
        A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
    """
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
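A small shape sanity check for the rotary-embedding helpers defined at the end of llama.py above; the batch, head, and sequence sizes are illustrative, and the import path assumes the fish_speech package from this upload is installed.

import torch

from fish_speech.models.text2semantic.llama import apply_rotary_emb, precompute_freqs_cis

seq_len, n_head, head_dim = 16, 4, 64
freqs_cis = precompute_freqs_cis(seq_len, head_dim)   # (seq_len, head_dim // 2, 2), bfloat16

q = torch.randn(2, seq_len, n_head, head_dim)         # (batch, seq, heads, head_dim)
q_rot = apply_rotary_emb(q, freqs_cis)

assert q_rot.shape == q.shape                          # rotation preserves the shape
print(freqs_cis.shape, q_rot.dtype)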
fish_speech/models/text2semantic/lora.py
CHANGED
@@ -1,92 +1,92 @@
from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


def setup_lora(model, lora_config):
    # Replace the embedding layer with a LoRA layer
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Replace output layer with a LoRA layer
    linears = [(model, "output")]

    # Replace all linear layers with LoRA layers
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        # Dual-AR model
        linears.append((model, "fast_output"))

        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    for module, layer in linears:
        updated_linear = lora.Linear(
            in_features=getattr(module, layer).in_features,
            out_features=getattr(module, layer).out_features,
            bias=getattr(module, layer).bias,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Mark only the LoRA layers as trainable
    lora.mark_only_lora_as_trainable(model, bias="none")


def get_merged_state_dict(model):
    # This line will merge the state dict of the model and the LoRA parameters
    model.eval()

    # Then we need to remove the LoRA parameters from the state dict
    state_dict = model.state_dict()
    for name in list(state_dict.keys()):
        if "lora" in name:
            state_dict.pop(name)

    return state_dict
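For reference, a minimal loralib sketch (not part of the upload) showing the same "keep only keys containing 'lora'" convention that setup_lora and the on_save_checkpoint hook rely on; the toy layer sizes are arbitrary.

import loralib as lora

# A single LoRA-augmented linear layer, analogous to what setup_lora installs per projection.
layer = lora.Linear(8, 8, r=4, lora_alpha=16)
lora.mark_only_lora_as_trainable(layer, bias="none")

state_dict = layer.state_dict()
lora_only = {k: v for k, v in state_dict.items() if "lora" in k}
print(sorted(lora_only.keys()))   # typically ['lora_A', 'lora_B']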
fish_speech/text/__init__.py
CHANGED
@@ -1,4 +1,4 @@
from .clean import clean_text
from .spliter import split_text

__all__ = ["clean_text", "split_text"]
fish_speech/text/clean.py
CHANGED
@@ -1,37 +1,37 @@
import re
|
2 |
+
|
3 |
+
SYMBOLS_MAPPING = {
|
4 |
+
"‘": "'",
|
5 |
+
"’": "'",
|
6 |
+
}
|
7 |
+
|
8 |
+
REPLACE_SYMBOL_REGEX = re.compile(
|
9 |
+
"|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
EMOJI_REGEX = re.compile(
|
14 |
+
"["
|
15 |
+
"\U0001f600-\U0001f64f" # emoticons
|
16 |
+
"\U0001f300-\U0001f5ff" # symbols & pictographs
|
17 |
+
"\U0001f680-\U0001f6ff" # transport & map symbols
|
18 |
+
"\U0001f1e0-\U0001f1ff" # flags (iOS)
|
19 |
+
"]+",
|
20 |
+
flags=re.UNICODE,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def clean_text(text):
|
25 |
+
# Clean the text
|
26 |
+
text = text.strip()
|
27 |
+
|
28 |
+
# Replace all chinese symbols with their english counterparts
|
29 |
+
text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
|
30 |
+
|
31 |
+
# Remove emojis
|
32 |
+
text = EMOJI_REGEX.sub(r"", text)
|
33 |
+
|
34 |
+
# Remove continuous periods (...) and commas (,,,)
|
35 |
+
text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
|
36 |
+
|
37 |
+
return text
|
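A quick usage sketch of clean_text as defined above (assuming the fish_speech package is importable; the input string is made up):

from fish_speech.text.clean import clean_text

raw = "  ‘Hello’,,, world 🐟🎉  "
print(clean_text(raw))
# Curly quotes are normalized to ASCII, the emojis are stripped, and the run of
# commas collapses to a single comma.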
fish_speech/text/spliter.py
CHANGED
@@ -1,130 +1,130 @@
import re
import string

from fish_speech.text.clean import clean_text


def utf_8_len(text: str):
    return len(text.encode("utf-8"))


def break_text(texts, length, splits: set):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if char in splits:
                yield curr
                curr = ""

        if curr:
            yield curr


def break_text_by_length(texts, length):
    for text in texts:
        if utf_8_len(text) <= length:
            yield text
            continue

        curr = ""
        for char in text:
            curr += char

            if utf_8_len(curr) >= length:
                yield curr
                curr = ""

        if curr:
            yield curr


def add_cleaned(curr, segments):
    curr = curr.strip()
    if curr and not all(c.isspace() or c in string.punctuation for c in curr):
        segments.append(curr)


def protect_float(text):
    # Turns 3.14 into <3_f_14> to prevent splitting
    return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)


def unprotect_float(text):
    # Turns <3_f_14> into 3.14
    return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)


def split_text(text, length):
    text = clean_text(text)

    # Break the text into pieces with following rules:
    # 1. Split the text at ".", "!", "?" if text is NOT a float
    # 2. If the text is longer than length, split at ","
    # 3. If the text is still longer than length, split at " "
    # 4. If the text is still longer than length, split at any character to length

    texts = [text]
    texts = map(protect_float, texts)
    texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
    texts = map(unprotect_float, texts)
    texts = break_text(texts, length, {",", ","})
    texts = break_text(texts, length, {" "})
    texts = list(break_text_by_length(texts, length))

    # Then, merge the texts into segments with length <= length
    segments = []
    curr = ""

    for text in texts:
        if utf_8_len(curr) + utf_8_len(text) <= length:
            curr += text
        else:
            add_cleaned(curr, segments)
            curr = text

    if curr:
        add_cleaned(curr, segments)

    return segments


if __name__ == "__main__":
    # Test the split_text function

    text = "This is a test sentence. This is another test sentence. And a third one."

    assert split_text(text, 50) == [
        "This is a test sentence.",
        "This is another test sentence. And a third one.",
    ]
    assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
    assert split_text(" ", 10) == []
    assert split_text("a", 10) == ["a"]

    text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
    assert split_text(text, 50) == [
        "This is a test sentence with only commas,",
        "and no dots, and no exclamation marks,",
        "and no question marks, and no newlines.",
    ]

    text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
    # First half split at " ", second half split at ","
    assert split_text(text, 50) == [
        "This is a test sentence This is a test sentence",
        "This is a test sentence. This is a test sentence,",
        "This is a test sentence, This is a test sentence.",
    ]

    text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
    assert split_text(text, 50) == [
        "这是一段很长的中文文本,",
        "而且没有句号,也没有感叹号,",
        "也没有问号,也没有换行符.",
    ]
fish_speech/tokenizer.py
CHANGED
@@ -1,152 +1,179 @@
import base64
import json
import logging
import re
from pathlib import Path

import tiktoken

logger = logging.getLogger(__name__)

# This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
FISH_TIKTOKEN_PATTERN = "|".join(
    [
        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
        r"\p{P}",
        r"[^\r\n\p{L}\p{N}]?\p{L}+",
        r"\p{N}",
        r" ?[^\s\p{L}\p{N}]+[\r\n]*",
        r"\s*[\r\n]+",
        r"\s+(?!\S)",
        r"\s+",
    ]
)
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|end_of_text|>"
PAD_TOKEN = "<|pad|>"
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"
TOOL_CALL_START_TOKEN = "<|tool_call_start|>"
TOOL_CALL_END_TOKEN = "<|tool_call_end|>"

MODALITY_TEXT_TOKEN = "<|text|>"
MODALITY_VOICE_TOKEN = "<|voice|>"
MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
AUDIO_START_TOKEN = "<|audio_start|>"
AUDIO_END_TOKEN = "<|audio_end|>"
AUDIO_EMBED_TOKEN = "<|audio|>"
MODALITY_TOKENS = {
    "text": MODALITY_TEXT_TOKEN,
    "voice": MODALITY_VOICE_TOKEN,
    "interleave": MODALITY_INTERLEAVE_TOKEN,
}

SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]

# Warning: when you add a new special token, you should only add it to the end of the list.
ALL_SPECIAL_TOKENS = [
    BOS_TOKEN,
    EOS_TOKEN,
    PAD_TOKEN,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
    TOOL_CALL_START_TOKEN,
    TOOL_CALL_END_TOKEN,
    MODALITY_TEXT_TOKEN,
    MODALITY_VOICE_TOKEN,
    MODALITY_INTERLEAVE_TOKEN,
    AUDIO_START_TOKEN,
    AUDIO_END_TOKEN,
    AUDIO_EMBED_TOKEN,
    *SEMANTIC_TOKENS,
]


class FishTokenizer:
    def __init__(
        self, model_path: str, special_tokens: list[str] = ALL_SPECIAL_TOKENS
    ) -> None:
        mergeable_ranks = self.load_tiktoken_bpe(model_path)
        special_token_begin = len(mergeable_ranks)
        self.all_special_tokens_with_ids = {
            token: special_token_begin + i for i, token in enumerate(special_tokens)
        }

        self.semantic_id_to_token_id = {}
        end_idx = 0
        for token in special_tokens:
            if token.startswith("<|semantic:"):
                idx = int(re.match(r"<\|semantic:(\d+)\|>", token).group(1))
                self.semantic_id_to_token_id[idx] = self.all_special_tokens_with_ids[
                    token
                ]

                if idx > end_idx:
                    end_idx = idx

        self.semantic_begin_id = self.semantic_id_to_token_id[0]
        self.semantic_end_id = self.semantic_id_to_token_id[end_idx]

        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=FISH_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.all_special_tokens_with_ids,
        )

    @property
    def vocab_size(self):
        return len(self.tkt_model._mergeable_ranks)

    @property
    def num_special_tokens(self):
        return len(self.all_special_tokens_with_ids)

    @staticmethod
    def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
        data = {}
        for line in open(tiktoken_bpe_file).read().splitlines():
            if not line:
                continue
            token, rank = line.split()
            if token == "=":
                continue
            data[base64.b64decode(token)] = int(rank)
        return data

    def get_token_id(self, token: str) -> int:
        return self.all_special_tokens_with_ids[token]

    def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
        assert isinstance(s, str)

        subs = []
        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])

        if allowed_special is True:
            allowed_special = self.tkt_model.special_tokens_set
        elif allowed_special is False:
            allowed_special = set()

        return sum(
            self.tkt_model.encode_batch(
                subs, allowed_special=allowed_special, disallowed_special=set()
            ),
            start=[],
        )

    def decode(self, tokens: list[int]) -> str:
        return self.tkt_model.decode(tokens)

    def save_pretrained(self, path: str):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        with open(path / "tokenizer.tiktoken", "w") as f:
            for token, rank in self.tkt_model._mergeable_ranks.items():
                a = base64.b64encode(token).decode()
                if a == "":
                    a = "="
                f.write(f"{a} {rank}\n")

        with open(path / "special_tokens.json", "w") as f:
            json.dump(
                self.all_special_tokens_with_ids,
                f,
                indent=2,
                ensure_ascii=False,
            )

    @staticmethod
    def from_pretrained(path: str):
        special_tokens_path = Path(path) / "special_tokens.json"
        if special_tokens_path.exists():
            with open(special_tokens_path) as f:
                all_special_tokens_with_ids = json.load(f)
        else:
            all_special_tokens_with_ids = ALL_SPECIAL_TOKENS

        return FishTokenizer(
            Path(path) / "tokenizer.tiktoken", all_special_tokens_with_ids
        )
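A hedged usage sketch of FishTokenizer: the checkpoint directory below is illustrative — any folder containing the tokenizer.tiktoken (and optional special_tokens.json) written by save_pretrained should work.

from fish_speech.tokenizer import IM_START_TOKEN, FishTokenizer

tokenizer = FishTokenizer.from_pretrained("checkpoints/openaudio-s1-mini")  # illustrative path

# Special tokens are allowed by default (allowed_special=True).
ids = tokenizer.encode(IM_START_TOKEN + "user\nHello world")
print(ids)
print(tokenizer.decode(ids))
# The <|semantic:i|> tokens occupy a contiguous id range:
print(tokenizer.semantic_begin_id, tokenizer.semantic_end_id)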
fish_speech/utils/__init__.py
CHANGED
@@ -1,24 +1,24 @@
from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, set_seed, task_wrapper

__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]
fish_speech/utils/braceexpand.py
CHANGED
@@ -1,217 +1,217 @@
(All 217 lines of this vendored module — bash-style brace expansion copied from https://github.com/trendels/braceexpand, MIT license — were re-uploaded with identical contents; no substantive change in this revision.)
fish_speech/utils/context.py
CHANGED
@@ -1,13 +1,13 @@
from contextlib import nullcontext

import torch


def autocast_exclude_mps(
    device_type: str, dtype: torch.dtype
) -> nullcontext | torch.autocast:
    return (
        nullcontext()
        if torch.backends.mps.is_available()
        else torch.autocast(device_type, dtype)
    )
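A small usage sketch of autocast_exclude_mps (the device choice and dtype are illustrative):

import torch

from fish_speech.utils.context import autocast_exclude_mps

device_type = "mps" if torch.backends.mps.is_available() else "cpu"

# On MPS this returns a nullcontext (autocast is skipped); on other backends it
# opens a regular torch.autocast region with the requested dtype.
with autocast_exclude_mps(device_type, torch.bfloat16):
    x = torch.randn(4, 4)
    y = x @ x.T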
fish_speech/utils/file.py
CHANGED
@@ -1,16 +1,139 @@
import os
from pathlib import Path
from typing import Union

from loguru import logger
from natsort import natsorted

AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}

VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
}


def get_latest_checkpoint(path: Path | str) -> Path | None:
    # Find the latest checkpoint
    ckpt_dir = Path(path)

    if ckpt_dir.exists() is False:
        return None

    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
    if len(ckpts) == 0:
        return None

    return ckpts[-1]


def audio_to_bytes(file_path):
    if not file_path or not Path(file_path).exists():
        return None
    with open(file_path, "rb") as wav_file:
        wav = wav_file.read()
    return wav


def read_ref_text(ref_text):
    path = Path(ref_text)
    if path.exists() and path.is_file():
        with path.open("r", encoding="utf-8") as file:
            return file.read()
    return ref_text


def list_files(
    path: Union[Path, str],
    extensions: set[str] = set(),
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): Extensions to filter. Defaults to None.
        recursive (bool, optional): Whether to search recursively. Defaults to False.
        sort (bool, optional): Whether to sort the files. Defaults to True.

    Returns:
        list: List of files.
    """

    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    files = [file for ext in extensions for file in path.rglob(f"*{ext}")]

    if sort:
        files = natsorted(files)

    return files


def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
    """
    Load a Bert-VITS2 style filelist.
    """

    files = set()
    results = []
    count_duplicated, count_not_found = 0, 0

    LANGUAGE_TO_LANGUAGES = {
        "zh": ["zh", "en"],
        "jp": ["jp", "en"],
        "en": ["en"],
    }

    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            splits = line.strip().split("|", maxsplit=3)
            if len(splits) != 4:
                logger.warning(f"Invalid line: {line}")
                continue

            filename, speaker, language, text = splits
            file = Path(filename)
            language = language.strip().lower()

            if language == "ja":
                language = "jp"

            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
            languages = LANGUAGE_TO_LANGUAGES[language]

            if file in files:
                logger.warning(f"Duplicated file: {file}")
                count_duplicated += 1
                continue

            if not file.exists():
                logger.warning(f"File not found: {file}")
                count_not_found += 1
                continue

            results.append((file, speaker, languages, text))

    if count_duplicated > 0:
        logger.warning(f"Total duplicated files: {count_duplicated}")

    if count_not_found > 0:
        logger.warning(f"Total files not found: {count_not_found}")

    return results
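A short usage sketch of the helpers above (the directory name is illustrative; list_files raises FileNotFoundError if the path does not exist):

from fish_speech.utils.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files

# Collect audio files under an example directory, naturally sorted.
wavs = list_files("examples", extensions=AUDIO_EXTENSIONS, sort=True)
print([p.name for p in wavs])

if wavs:
    blob = audio_to_bytes(wavs[0])
    print(f"{wavs[0].name}: {len(blob)} bytes")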
fish_speech/utils/instantiators.py
CHANGED
@@ -1,50 +1,50 @@
(All 50 lines — the Hydra helpers instantiate_callbacks and instantiate_loggers — were re-uploaded with identical contents; no substantive change in this revision.)
fish_speech/utils/logger.py
CHANGED
@@ -1,55 +1,55 @@
(All 55 lines — the multi-GPU-friendly RankedLogger adapter — were re-uploaded with identical contents; no substantive change in this revision.)
fish_speech/utils/logging_utils.py
CHANGED
@@ -1,48 +1,48 @@
(All 48 lines — the rank-zero log_hyperparameters helper — were re-uploaded with identical contents; no substantive change in this revision.)
CHANGED
@@ -1,100 +1,100 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
from typing import Sequence
|
3 |
-
|
4 |
-
import rich
|
5 |
-
import rich.syntax
|
6 |
-
import rich.tree
|
7 |
-
from hydra.core.hydra_config import HydraConfig
|
8 |
-
from lightning.pytorch.utilities import rank_zero_only
|
9 |
-
from omegaconf import DictConfig, OmegaConf, open_dict
|
10 |
-
from rich.prompt import Prompt
|
11 |
-
|
12 |
-
from fish_speech.utils import logger as log
|
13 |
-
|
14 |
-
|
15 |
-
@rank_zero_only
|
16 |
-
def print_config_tree(
|
17 |
-
cfg: DictConfig,
|
18 |
-
print_order: Sequence[str] = (
|
19 |
-
"data",
|
20 |
-
"model",
|
21 |
-
"callbacks",
|
22 |
-
"logger",
|
23 |
-
"trainer",
|
24 |
-
"paths",
|
25 |
-
"extras",
|
26 |
-
),
|
27 |
-
resolve: bool = False,
|
28 |
-
save_to_file: bool = False,
|
29 |
-
) -> None:
|
30 |
-
"""Prints content of DictConfig using Rich library and its tree structure.
|
31 |
-
|
32 |
-
Args:
|
33 |
-
cfg (DictConfig): Configuration composed by Hydra.
|
34 |
-
print_order (Sequence[str], optional): Determines in what order config components are printed.
|
35 |
-
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
|
36 |
-
save_to_file (bool, optional): Whether to export config to the hydra output folder.
|
37 |
-
""" # noqa: E501
|
38 |
-
|
39 |
-
style = "dim"
|
40 |
-
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
|
41 |
-
|
42 |
-
queue = []
|
43 |
-
|
44 |
-
# add fields from `print_order` to queue
|
45 |
-
for field in print_order:
|
46 |
-
(
|
47 |
-
queue.append(field)
|
48 |
-
if field in cfg
|
49 |
-
else log.warning(
|
50 |
-
f"Field '{field}' not found in config. "
|
51 |
-
+ f"Skipping '{field}' config printing..."
|
52 |
-
)
|
53 |
-
)
|
54 |
-
|
55 |
-
# add all the other fields to queue (not specified in `print_order`)
|
56 |
-
for field in cfg:
|
57 |
-
if field not in queue:
|
58 |
-
queue.append(field)
|
59 |
-
|
60 |
-
# generate config tree from queue
|
61 |
-
for field in queue:
|
62 |
-
branch = tree.add(field, style=style, guide_style=style)
|
63 |
-
|
64 |
-
config_group = cfg[field]
|
65 |
-
if isinstance(config_group, DictConfig):
|
66 |
-
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
|
67 |
-
else:
|
68 |
-
branch_content = str(config_group)
|
69 |
-
|
70 |
-
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
|
71 |
-
|
72 |
-
# print config tree
|
73 |
-
rich.print(tree)
|
74 |
-
|
75 |
-
# save config tree to file
|
76 |
-
if save_to_file:
|
77 |
-
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
|
78 |
-
rich.print(tree, file=file)
|
79 |
-
|
80 |
-
|
81 |
-
@rank_zero_only
|
82 |
-
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
|
83 |
-
"""Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
|
84 |
-
|
85 |
-
if not cfg.get("tags"):
|
86 |
-
if "id" in HydraConfig().cfg.hydra.job:
|
87 |
-
raise ValueError("Specify tags before launching a multirun!")
|
88 |
-
|
89 |
-
log.warning("No tags provided in config. Prompting user to input tags...")
|
90 |
-
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
|
91 |
-
tags = [t.strip() for t in tags.split(",") if t != ""]
|
92 |
-
|
93 |
-
with open_dict(cfg):
|
94 |
-
cfg.tags = tags
|
95 |
-
|
96 |
-
log.info(f"Tags: {cfg.tags}")
|
97 |
-
|
98 |
-
if save_to_file:
|
99 |
-
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
|
100 |
-
rich.print(cfg.tags, file=file)
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Sequence
|
3 |
+
|
4 |
+
import rich
|
5 |
+
import rich.syntax
|
6 |
+
import rich.tree
|
7 |
+
from hydra.core.hydra_config import HydraConfig
|
8 |
+
from lightning.pytorch.utilities import rank_zero_only
|
9 |
+
from omegaconf import DictConfig, OmegaConf, open_dict
|
10 |
+
from rich.prompt import Prompt
|
11 |
+
|
12 |
+
from fish_speech.utils import logger as log
|
13 |
+
|
14 |
+
|
15 |
+
@rank_zero_only
|
16 |
+
def print_config_tree(
|
17 |
+
cfg: DictConfig,
|
18 |
+
print_order: Sequence[str] = (
|
19 |
+
"data",
|
20 |
+
"model",
|
21 |
+
"callbacks",
|
22 |
+
"logger",
|
23 |
+
"trainer",
|
24 |
+
"paths",
|
25 |
+
"extras",
|
26 |
+
),
|
27 |
+
resolve: bool = False,
|
28 |
+
save_to_file: bool = False,
|
29 |
+
) -> None:
|
30 |
+
"""Prints content of DictConfig using Rich library and its tree structure.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
cfg (DictConfig): Configuration composed by Hydra.
|
34 |
+
print_order (Sequence[str], optional): Determines in what order config components are printed.
|
35 |
+
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
|
36 |
+
save_to_file (bool, optional): Whether to export config to the hydra output folder.
|
37 |
+
""" # noqa: E501
|
38 |
+
|
39 |
+
style = "dim"
|
40 |
+
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
|
41 |
+
|
42 |
+
queue = []
|
43 |
+
|
44 |
+
# add fields from `print_order` to queue
|
45 |
+
for field in print_order:
|
46 |
+
(
|
47 |
+
queue.append(field)
|
48 |
+
if field in cfg
|
49 |
+
else log.warning(
|
50 |
+
f"Field '{field}' not found in config. "
|
51 |
+
+ f"Skipping '{field}' config printing..."
|
52 |
+
)
|
53 |
+
)
|
54 |
+
|
55 |
+
# add all the other fields to queue (not specified in `print_order`)
|
56 |
+
for field in cfg:
|
57 |
+
if field not in queue:
|
58 |
+
queue.append(field)
|
59 |
+
|
60 |
+
# generate config tree from queue
|
61 |
+
for field in queue:
|
62 |
+
branch = tree.add(field, style=style, guide_style=style)
|
63 |
+
|
64 |
+
config_group = cfg[field]
|
65 |
+
if isinstance(config_group, DictConfig):
|
66 |
+
branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
|
67 |
+
else:
|
68 |
+
branch_content = str(config_group)
|
69 |
+
|
70 |
+
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
|
71 |
+
|
72 |
+
# print config tree
|
73 |
+
rich.print(tree)
|
74 |
+
|
75 |
+
# save config tree to file
|
76 |
+
if save_to_file:
|
77 |
+
with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
|
78 |
+
rich.print(tree, file=file)
|
79 |
+
|
80 |
+
|
81 |
+
@rank_zero_only
|
82 |
+
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
|
83 |
+
"""Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
|
84 |
+
|
85 |
+
if not cfg.get("tags"):
|
86 |
+
if "id" in HydraConfig().cfg.hydra.job:
|
87 |
+
raise ValueError("Specify tags before launching a multirun!")
|
88 |
+
|
89 |
+
log.warning("No tags provided in config. Prompting user to input tags...")
|
90 |
+
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
|
91 |
+
tags = [t.strip() for t in tags.split(",") if t != ""]
|
92 |
+
|
93 |
+
with open_dict(cfg):
|
94 |
+
cfg.tags = tags
|
95 |
+
|
96 |
+
log.info(f"Tags: {cfg.tags}")
|
97 |
+
|
98 |
+
if save_to_file:
|
99 |
+
with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
|
100 |
+
rich.print(cfg.tags, file=file)
|