diff --git a/.gitattributes b/.gitattributes
index a56018dcef8a9d2bbdf0bae70eb19e573133741a..d0958581de2691a0397f245b105e854a8660a466 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
examples/Japanese.wav filter=lfs diff=lfs merge=lfs -text
examples/Korean.wav filter=lfs diff=lfs merge=lfs -text
examples/Nice[[:space:]]English[[:space:]]Ref.wav filter=lfs diff=lfs merge=lfs -text
+examples/Arabic.wav filter=lfs diff=lfs merge=lfs -text
+examples/English.wav filter=lfs diff=lfs merge=lfs -text
+examples/French.wav filter=lfs diff=lfs merge=lfs -text
+examples/German.wav filter=lfs diff=lfs merge=lfs -text
+examples/Spanish.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index e87fb1fa90473c1ffce53ab39f9a33993ee861b6..91c4a63cb1a38b778237397a69aac079c5ac8ab6 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
---
-title: Fish Speech 1
+title: OpenAudio S1
emoji: 🏆
colorFrom: purple
colorTo: gray
diff --git a/app.py b/app.py
index 7919d4335ed4a6daf739c752e117e539c3a987c8..85c9d20a70fe709751b53c43af3316f9b92fb850 100644
--- a/app.py
+++ b/app.py
@@ -1,86 +1,51 @@
import os
import queue
from huggingface_hub import snapshot_download
-import hydra
import numpy as np
import wave
import io
-import pyrootutils
import gc
+from typing import Callable
# Download if not exists
os.makedirs("checkpoints", exist_ok=True)
-snapshot_download(repo_id="fishaudio/fish-speech-1.5", local_dir="./checkpoints/fish-speech-1.5")
+snapshot_download(repo_id="fishaudio/openaudio-s1-mini", local_dir="./checkpoints/openaudio-s1-mini")
print("All checkpoints downloaded")
import html
import os
-import threading
from argparse import ArgumentParser
from pathlib import Path
-from functools import partial
import gradio as gr
-import librosa
import torch
import torchaudio
torchaudio.set_audio_backend("soundfile")
from loguru import logger
-from transformers import AutoTokenizer
-
from fish_speech.i18n import i18n
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps, set_seed
-from tools.api import decode_vq_tokens, encode_reference
-from tools.file import AUDIO_EXTENSIONS, list_files
-from tools.llama.generate import (
- GenerateRequest,
- GenerateResponse,
- WrappedGenerateResponse,
- launch_thread_safe_queue,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-from tools.schema import (
- GLOBAL_NUM_SAMPLES,
- ASRPackRequest,
- ServeASRRequest,
- ServeASRResponse,
- ServeASRSegment,
- ServeAudioPart,
- ServeForwardMessage,
- ServeMessage,
- ServeRequest,
- ServeResponse,
- ServeStreamDelta,
- ServeStreamResponse,
- ServeTextPart,
- ServeTimedASRResponse,
- ServeTTSRequest,
- ServeVQGANDecodeRequest,
- ServeVQGANDecodeResponse,
- ServeVQGANEncodeRequest,
- ServeVQGANEncodeResponse,
- ServeVQPart,
- ServeReferenceAudio
-)
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.models.dac.inference import load_model as load_decoder_model
+from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
+from tools.webui.inference import get_inference_wrapper
+from fish_speech.utils.schema import ServeTTSRequest
+
# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"
-HEADER_MD = """# Fish Speech
+HEADER_MD = """# OpenAudio S1
-## The demo in this space is version 1.5, Please check [Fish Audio](https://fish.audio) for the best model.
-## 该 Demo 为 Fish Speech 1.5 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
+## The demo in this space is OpenAudio S1. Please check [Fish Audio](https://fish.audio) for the best model.
+## 该 Demo 为 OpenAudio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
-A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
-由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
+A text-to-speech model based on DAC and Qwen3 developed by [Fish Audio](https://fish.audio).
+由 [Fish Audio](https://fish.audio) 研发的基于 DAC 和 Qwen3 的多语种语音合成.
-You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).
-你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型.
+You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/openaudio-s1-mini).
+你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/openaudio-s1-mini) 找到模型.
Related code and weights are released under CC BY-NC-SA 4.0 License.
相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
@@ -88,8 +53,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
-The model running in this WebUI is Fish Speech V1.5 Medium.
-在此 WebUI 中运行的模型是 Fish Speech V1.5 Medium.
+The model running in this WebUI is OpenAudio S1 Mini.
+在此 WebUI 中运行的模型是 OpenAudio S1 Mini.
"""
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -106,7 +71,6 @@ except ImportError:
return wrapper
-
def build_html_error_message(error):
return f"""
0,
- chunk_length=req.chunk_length,
- max_length=4096,
- prompt_tokens=prompt_tokens,
- prompt_text=prompt_texts,
- )
-
- response_queue = queue.Queue()
- llama_queue.put(
- GenerateRequest(
- request=request,
- response_queue=response_queue,
- )
- )
-
- segments = []
-
- while True:
- result: WrappedGenerateResponse = response_queue.get()
- if result.status == "error":
- yield None, None, build_html_error_message(result.response)
- break
-
- result: GenerateResponse = result.response
- if result.action == "next":
- break
-
- with autocast_exclude_mps(
- device_type=decoder_model.device.type, dtype=args.precision
- ):
- fake_audios = decode_vq_tokens(
- decoder_model=decoder_model,
- codes=result.codes,
- )
-
- fake_audios = fake_audios.float().cpu().numpy()
- segments.append(fake_audios)
-
- if len(segments) == 0:
- return (
- None,
- None,
- build_html_error_message(
- i18n("No audio generated, please check the input text.")
- ),
- )
-
- # No matter streaming or not, we need to return the final audio
- audio = np.concatenate(segments, axis=0)
- yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- gc.collect()
-
- except Exception as e:
- er = "CUDA error: device-side assert triggered"
- if er in str(e):
- app.close()
- else:
- raise Exception(e)
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
buffer = io.BytesIO()
@@ -230,13 +91,8 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
buffer.close()
return wav_header_bytes
-def normalize_text(user_input, use_normalization):
- if use_normalization:
- return ChnNormedText(raw_text=user_input).normalize()
- else:
- return user_input
-def build_app():
+def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
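+    """Assemble the Gradio UI. `inference_fct` is the callable wired to the
+    Generate button below, and `theme` is applied on load through the
+    `__theme` URL query parameter."""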
with gr.Blocks(theme=gr.themes.Base()) as app:
gr.Markdown(HEADER_MD)
@@ -245,7 +101,7 @@ def build_app():
None,
None,
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
- % args.theme,
+ % theme,
)
# Inference
@@ -254,20 +110,6 @@ def build_app():
text = gr.Textbox(
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
)
- refined_text = gr.Textbox(
- label=i18n("Realtime Transform Text"),
- placeholder=i18n(
- "Normalization Result Preview (Currently Only Chinese)"
- ),
- lines=5,
- interactive=False,
- )
-
- with gr.Row():
- normalize = gr.Checkbox(
- label=i18n("Text Normalization"),
- value=False,
- )
with gr.Row():
with gr.Column():
@@ -275,45 +117,45 @@ def build_app():
with gr.Row():
chunk_length = gr.Slider(
label=i18n("Iterative Prompt Length, 0 means off"),
- minimum=0,
- maximum=300,
- value=200,
+ minimum=100,
+ maximum=400,
+ value=300,
step=8,
)
max_new_tokens = gr.Slider(
label=i18n(
- "Maximum tokens per batch"
+ "Maximum tokens per batch, 0 means no limit"
),
- minimum=512,
+ minimum=0,
maximum=2048,
- value=1024,
- step=64,
+ value=0,
+ step=8,
)
with gr.Row():
top_p = gr.Slider(
label="Top-P",
- minimum=0.6,
- maximum=0.9,
- value=0.7,
+ minimum=0.7,
+ maximum=0.95,
+ value=0.8,
step=0.01,
)
repetition_penalty = gr.Slider(
label=i18n("Repetition Penalty"),
minimum=1,
- maximum=1.5,
- value=1.2,
+ maximum=1.2,
+ value=1.1,
step=0.01,
)
with gr.Row():
temperature = gr.Slider(
label="Temperature",
- minimum=0.6,
- maximum=0.9,
- value=0.7,
+ minimum=0.7,
+ maximum=1.0,
+ value=0.8,
step=0.01,
)
seed = gr.Number(
@@ -326,24 +168,20 @@ def build_app():
with gr.Row():
gr.Markdown(
i18n(
- "15 to 60 seconds of reference audio, useful for specifying speaker."
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
)
)
-
with gr.Row():
- # Add dropdown for selecting example audio files
- example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")]
- example_audio_dropdown = gr.Dropdown(
- label="Select Example Audio",
- choices=[""] + example_audio_files,
- value=""
+ reference_id = gr.Textbox(
+ label=i18n("Reference ID"),
+ placeholder="Leave empty to use uploaded references",
)
with gr.Row():
use_memory_cache = gr.Radio(
label=i18n("Use Memory Cache"),
- choices=["never"],
- value="never",
+ choices=["on", "off"],
+ value="on",
)
with gr.Row():
@@ -351,7 +189,6 @@ def build_app():
label=i18n("Reference Audio"),
type="filepath",
)
-
with gr.Row():
reference_text = gr.Textbox(
label=i18n("Reference Text"),
@@ -377,101 +214,16 @@ def build_app():
with gr.Row():
with gr.Column(scale=3):
generate = gr.Button(
- value="\U0001F3A7 " + i18n("Generate"), variant="primary"
+ value="\U0001f3a7 " + i18n("Generate"),
+ variant="primary",
)
- text.input(
- fn=normalize_text, inputs=[text, normalize], outputs=[refined_text]
- )
-
- def inference_wrapper(
- text,
- normalize,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- seed,
- use_memory_cache,
- ):
- print(
- "call inference wrapper",
- text,
- normalize,
- reference_audio,
- reference_text,
- max_new_tokens,
- chunk_length,
- top_p,
- repetition_penalty,
- temperature,
- seed,
- use_memory_cache
- )
-
- references = []
- if reference_audio:
- # 将文件路径转换为字节
- with open(reference_audio, 'rb') as audio_file:
- audio_bytes = audio_file.read()
-
- references = [
- ServeReferenceAudio(audio=audio_bytes, text=reference_text)
- ]
-
- req = ServeTTSRequest(
- text=text,
- normalize=normalize,
- reference_id=None,
- references=references,
- max_new_tokens=max_new_tokens,
- chunk_length=chunk_length,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- temperature=temperature,
- seed=int(seed) if seed else None,
- use_memory_cache=use_memory_cache,
- )
-
- for result in inference(req):
- if result[2]: # Error message
- return None, result[2]
- elif result[1]: # Audio data
- return result[1], None
-
- return None, i18n("No audio generated")
-
- def select_example_audio(audio_file):
- if audio_file:
- audio_path = os.path.join("examples", audio_file)
- lab_file = os.path.splitext(audio_file)[0] + ".lab"
- lab_path = os.path.join("examples", lab_file)
-
- if os.path.exists(lab_path):
- with open(lab_path, "r", encoding="utf-8") as f:
- lab_content = f.read().strip()
- else:
- lab_content = ""
-
- return audio_path, lab_content
- return None, ""
-
- # Connect the dropdown to update reference audio and text
- example_audio_dropdown.change(
- fn=select_example_audio,
- inputs=[example_audio_dropdown],
- outputs=[reference_audio, reference_text]
- )
-
# Submit
generate.click(
- inference_wrapper,
+ inference_fct,
[
- refined_text,
- normalize,
+ text,
+ reference_id,
reference_audio,
reference_text,
max_new_tokens,
@@ -488,26 +240,24 @@ def build_app():
return app
-
-
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--llama-checkpoint-path",
type=Path,
- default="checkpoints/fish-speech-1.5",
+ default="checkpoints/openaudio-s1-mini",
)
parser.add_argument(
"--decoder-checkpoint-path",
type=Path,
- default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ default="checkpoints/openaudio-s1-mini/codec.pth",
)
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true",default=True)
parser.add_argument("--max-gradio-length", type=int, default=0)
- parser.add_argument("--theme", type=str, default="light")
+ parser.add_argument("--theme", type=str, default="dark")
return parser.parse_args()
@@ -533,25 +283,34 @@ if __name__ == "__main__":
logger.info("Decoder model loaded, warming up...")
+ # Create the inference engine
+ inference_engine = TTSInferenceEngine(
+ llama_queue=llama_queue,
+ decoder_model=decoder_model,
+ compile=args.compile,
+ precision=args.precision,
+ )
+
# Dry run to check if the model is loaded correctly and avoid the first-time latency
list(
- inference(
- ServeTTSRequest(
- text="Hello world.",
- references=[],
- reference_id=None,
- max_new_tokens=0,
- chunk_length=200,
- top_p=0.7,
- repetition_penalty=1.5,
- temperature=0.7,
- emotion=None,
- format="wav",
- )
+ inference_engine.inference(
+ ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.5,
+ temperature=0.7,
+ format="wav",
)
+ )
)
logger.info("Warming up done, launching the web UI...")
- app = build_app()
- app.queue(api_open=True).launch(show_error=True, show_api=True)
+ inference_fct = get_inference_wrapper(inference_engine)
+
+ app = build_app(inference_fct, args.theme)
+ app.queue(api_open=True).launch(show_error=True, show_api=True, server_name="0.0.0.0", server_port=18888)
diff --git a/examples/Arabic.wav b/examples/Arabic.wav
index 0e5cecd8112b805841714e01550b63b98a8b92be..2e2acdf221d68c1ff00d4c50e86959669dcd6b94 100644
Binary files a/examples/Arabic.wav and b/examples/Arabic.wav differ
diff --git a/examples/English.wav b/examples/English.wav
index c41d923cb56b33ac51f63ab1ceab1643976bf6d0..5061c0e2658b9238b883c054d6c3b08f90501e99 100644
Binary files a/examples/English.wav and b/examples/English.wav differ
diff --git a/examples/French.wav b/examples/French.wav
index a36ca2193b9b391701178e378423375897191e91..5292499755d86e57263c4ee00160864b5988197a 100644
Binary files a/examples/French.wav and b/examples/French.wav differ
diff --git a/examples/German.wav b/examples/German.wav
index 13ef8df7a782da5a6f863f57f72abaecaf7a9180..1c3a10598b0c6345dfdd7e3989aaa155f72ea114 100644
Binary files a/examples/German.wav and b/examples/German.wav differ
diff --git a/examples/Japanese.wav b/examples/Japanese.wav
index d601bc9d57b518e6b4646de0c0c9d063fb7df23c..8e41057b85229d69a196945d146109edd1eea2b2 100644
--- a/examples/Japanese.wav
+++ b/examples/Japanese.wav
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:a23cffeac70f42e1cc69e2a0505e4c1fda50884dd34c509128d432aaf44565e5
-size 1148682
+oid sha256:3034a38260884be854cb4a3f6cb648db85ebdeeb8cab74cfae2a578dc7aaedc2
+size 132
diff --git a/examples/Korean.wav b/examples/Korean.wav
index 66c9661830a819c99d88eef921a351969150651e..78b1faf981694a3353b842c2d66a912c17eb504f 100644
--- a/examples/Korean.wav
+++ b/examples/Korean.wav
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:c4234f119c741782e2c9c0ede4b5b864a560a355c28a23b2332e79420b69961a
-size 1632522
+oid sha256:5767663f0c26f4dc94f45227f385c2be568aac065272466915d65eaa64fdda0f
+size 132
diff --git a/examples/Nice English Ref.wav b/examples/Nice English Ref.wav
index 018def3da6421b1f1ec4a57870bb3a2d235efeb4..828534a0c3042d87cfeeb3a81a1adf367a8b45c0 100644
--- a/examples/Nice English Ref.wav
+++ b/examples/Nice English Ref.wav
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:d00ad9768c62f9821fc01ecab3e02669581ca75c18af6549690e19ce90a09f53
-size 5254482
+oid sha256:4b707de0cfc5d2eee59dcc3fea495603fe28d95ca64d8202bcdb31537d588782
+size 132
diff --git a/examples/Spanish.wav b/examples/Spanish.wav
index 31f7fe5875c9dcc81db9c87d98c8c68a4df22382..98c26d0c4173f2f067af00e882915a2a9123e9ed 100644
Binary files a/examples/Spanish.wav and b/examples/Spanish.wav differ
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
index b6bf1c8cf5a79e7ddb58ebc725f2c002c26c9486..99e6dab54d3f57bce4f6d29a9129a19a523cad75 100644
--- a/fish_speech/configs/base.yaml
+++ b/fish_speech/configs/base.yaml
@@ -1,87 +1,87 @@
-# Base configuration for training a model
-paths:
- run_dir: results/${project}
- ckpt_dir: ${paths.run_dir}/checkpoints
-
-hydra:
- run:
- dir: ${paths.run_dir}
-
-# Lightning Trainer
-trainer:
- _target_: lightning.pytorch.trainer.Trainer
-
- default_root_dir: ${paths.run_dir}
- accelerator: gpu
- num_nodes: 1
- devices: auto
- strategy:
- _target_: lightning.pytorch.strategies.DDPStrategy
- process_group_backend: nccl # This should be override when training on windows
-
- precision: bf16-mixed
-
- # disable validation by epoch end
- check_val_every_n_epoch: null
- val_check_interval: 5000
- max_steps: 100_000
-
- # Use torch.backends.cudnn.benchmark to speed up training
- benchmark: true
-
-# Callbacks
-callbacks:
- model_checkpoint:
- _target_: lightning.pytorch.callbacks.ModelCheckpoint
- dirpath: ${paths.ckpt_dir}
- filename: "step_{step:09d}"
- save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
- save_top_k: 5 # save 5 latest checkpoints
- monitor: step # use step to monitor checkpoints
- mode: max # save the latest checkpoint with the highest global_step
- every_n_epochs: null # don't save checkpoints by epoch end
- every_n_train_steps: 5000 # save checkpoints every 5000 steps
- auto_insert_metric_name: false
-
- model_summary:
- _target_: lightning.pytorch.callbacks.ModelSummary
- max_depth: 2 # the maximum depth of layer nesting that the summary will include
-
- learning_rate_monitor:
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
- logging_interval: step
- log_momentum: false
-
- grad_norm_monitor:
- _target_: fish_speech.callbacks.GradNormMonitor
- norm_type: 2
- logging_interval: step
-
-# Logger
-logger:
- tensorboard:
- _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
- save_dir: "${paths.run_dir}/tensorboard/"
- name: null
- log_graph: false
- default_hp_metric: true
- prefix: ""
-
- # wandb:
- # _target_: lightning.pytorch.loggers.wandb.WandbLogger
- # # name: "" # name of the run (normally generated by wandb)
- # save_dir: "${paths.run_dir}"
- # offline: False
- # id: null # pass correct id to resume experiment!
- # anonymous: null # enable anonymous logging
- # project: "fish-speech"
- # log_model: False # upload lightning ckpts
- # prefix: "" # a string to put at the beginning of metric keys
- # # entity: "" # set to name of your wandb team
- # group: ""
- # tags: ["vq", "hq", "finetune"]
- # job_type: ""
-
-# Loop
-train: true
-test: false
+# Base configuration for training a model
+paths:
+ run_dir: results/${project}
+ ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+ run:
+ dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.run_dir}
+ accelerator: gpu
+ num_nodes: 1
+ devices: auto
+ strategy:
+ _target_: lightning.pytorch.strategies.DDPStrategy
+    process_group_backend: nccl # This should be overridden when training on Windows
+
+ precision: bf16-mixed
+
+ # disable validation by epoch end
+ check_val_every_n_epoch: null
+ val_check_interval: 5000
+ max_steps: 100_000
+
+ # Use torch.backends.cudnn.benchmark to speed up training
+ benchmark: true
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: ${paths.ckpt_dir}
+ filename: "step_{step:09d}"
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 5 # save 5 latest checkpoints
+ monitor: step # use step to monitor checkpoints
+ mode: max # save the latest checkpoint with the highest global_step
+ every_n_epochs: null # don't save checkpoints by epoch end
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
+ auto_insert_metric_name: false
+
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+ learning_rate_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
+
+ grad_norm_monitor:
+ _target_: fish_speech.callbacks.GradNormMonitor
+ norm_type: 2
+ logging_interval: step
+
+# Logger
+logger:
+ tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.run_dir}/tensorboard/"
+ name: null
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+
+ # wandb:
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # # name: "" # name of the run (normally generated by wandb)
+ # save_dir: "${paths.run_dir}"
+ # offline: False
+ # id: null # pass correct id to resume experiment!
+ # anonymous: null # enable anonymous logging
+ # project: "fish-speech"
+ # log_model: False # upload lightning ckpts
+ # prefix: "" # a string to put at the beginning of metric keys
+ # # entity: "" # set to name of your wandb team
+ # group: ""
+ # tags: ["vq", "hq", "finetune"]
+ # job_type: ""
+
+# Loop
+train: true
+test: false
diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml
index 0bb13622fc9dcf3ea59fd9db215536946fb346fa..aecc4d9766a18fe31c55941e01b1f590c95e77c9 100644
--- a/fish_speech/configs/lora/r_8_alpha_16.yaml
+++ b/fish_speech/configs/lora/r_8_alpha_16.yaml
@@ -1,4 +1,4 @@
-_target_: fish_speech.models.text2semantic.lora.LoraConfig
-r: 8
-lora_alpha: 16
-lora_dropout: 0.01
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
+r: 8
+lora_alpha: 16
+lora_dropout: 0.01
diff --git a/fish_speech/configs/modded_dac_vq.yaml b/fish_speech/configs/modded_dac_vq.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..68a926deff44228fe757a22bc27f4f5b76f34408
--- /dev/null
+++ b/fish_speech/configs/modded_dac_vq.yaml
@@ -0,0 +1,50 @@
+_target_: fish_speech.models.dac.modded_dac.DAC
+# Model setup
+sample_rate: 44100
+encoder_dim: 64
+encoder_rates: [2, 4, 8, 8]
+decoder_dim: 1536
+decoder_rates: [8, 8, 4, 2]
+encoder_transformer_layers: [0, 0, 0, 4]
+decoder_transformer_layers: [4, 0, 0, 0]
+transformer_general_config:
+ _target_: fish_speech.models.dac.modded_dac.ModelArgs
+ _partial_: true
+ block_size: 16384
+ n_local_heads: -1
+ head_dim: 64
+ rope_base: 10000
+ norm_eps: 1e-5
+ dropout_rate: 0.1
+ attn_dropout_rate: 0.1
+ channels_first: true
+# Quantization
+quantizer:
+ _target_: fish_speech.models.dac.rvq.DownsampleResidualVectorQuantize
+ input_dim: 1024
+ n_codebooks: 9
+ codebook_size: 1024
+ codebook_dim: 8
+ quantizer_dropout: 0.5
+ downsample_factor: [2, 2]
+ post_module: &transformer_module
+ _target_: fish_speech.models.dac.modded_dac.WindowLimitedTransformer
+ causal: true
+ window_size: 128 # empirically this does not seem to matter
+ input_dim: 1024
+ config: &transformer_config
+ _target_: fish_speech.models.dac.modded_dac.ModelArgs
+ block_size: 4096
+ n_layer: 8
+ n_head: 16
+ dim: 1024
+ intermediate_size: 3072
+ n_local_heads: -1
+ head_dim: 64
+ rope_base: 10000
+ norm_eps: 1e-5
+ dropout_rate: 0.1
+ attn_dropout_rate: 0.1
+ channels_first: true
+ pre_module: *transformer_module
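+  # pre_module reuses the transformer settings anchored at post_module above;
+  # each key is still instantiated as its own module.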
+ semantic_codebook_size: 4096
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
index 8c999411703e85c7ad5972134f7a33f42b279571..00f69051c05cbe428c8ef51fa8c467f7fc708bef 100644
--- a/fish_speech/configs/text2semantic_finetune.yaml
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -1,83 +1,86 @@
-defaults:
- - base
- - _self_
-
-project: text2semantic_finetune_dual_ar
-max_length: 4096
-pretrained_ckpt_path: checkpoints/fish-speech-1.4
-
-# Lightning Trainer
-trainer:
- accumulate_grad_batches: 1
- gradient_clip_val: 1.0
- gradient_clip_algorithm: "norm"
- max_steps: 1000
- precision: bf16-true
- limit_val_batches: 10
- val_check_interval: 100
-
-# Dataset Configuration
-tokenizer:
- _target_: transformers.AutoTokenizer.from_pretrained
- pretrained_model_name_or_path: ${pretrained_ckpt_path}
-
-# Dataset Configuration
-train_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
- proto_files:
- - data/protos
- tokenizer: ${tokenizer}
- causal: true
- max_length: ${max_length}
- use_speaker: false
- interactive_prob: 0.7
-
-val_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
- proto_files:
- - data/protos
- tokenizer: ${tokenizer}
- causal: true
- max_length: ${max_length}
- use_speaker: false
- interactive_prob: 0.7
-
-data:
- _target_: fish_speech.datasets.semantic.SemanticDataModule
- train_dataset: ${train_dataset}
- val_dataset: ${val_dataset}
- num_workers: 4
- batch_size: 8
- tokenizer: ${tokenizer}
- max_length: ${max_length}
-
-# Model Configuration
-model:
- _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
- model:
- _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
- path: ${pretrained_ckpt_path}
- load_weights: true
- max_length: ${max_length}
- lora_config: null
-
- optimizer:
- _target_: torch.optim.AdamW
- _partial_: true
- lr: 1e-4
- weight_decay: 0
- betas: [0.9, 0.95]
- eps: 1e-5
-
- lr_scheduler:
- _target_: torch.optim.lr_scheduler.LambdaLR
- _partial_: true
- lr_lambda:
- _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
- _partial_: true
- num_warmup_steps: 10
-
-# Callbacks
-callbacks:
- model_checkpoint:
- every_n_train_steps: ${trainer.val_check_interval}
+defaults:
+ - base
+ - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 4096
+pretrained_ckpt_path: checkpoints/openaudio-s1-mini
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: "norm"
+ max_steps: 10000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 100
+ # strategy:
+ # find_unused_parameters: true
+ # static_graph: true
+
+# Dataset Configuration
+tokenizer:
+ _target_: fish_speech.tokenizer.FishTokenizer
+ model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+val_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+data:
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 4
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+ model:
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+ path: ${pretrained_ckpt_path}
+ load_weights: true
+ max_length: ${max_length}
+ lora_config: null
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 10
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
diff --git a/fish_speech/content_sequence.py b/fish_speech/content_sequence.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5885c65e5e103d2df0f5b9c27adefa5cdb637c2
--- /dev/null
+++ b/fish_speech/content_sequence.py
@@ -0,0 +1,367 @@
+from dataclasses import dataclass, field
+from typing import List, Literal, Union
+
+import numpy as np
+import torch
+
+from fish_speech.tokenizer import (
+ IM_END_TOKEN,
+ MODALITY_TOKENS,
+ FishTokenizer,
+)
+
+
+def restore_ndarray(obj, to_tensor: bool = False):
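+    # Rebuild an ndarray that was serialized as a dict carrying "data" (bytes),
+    # "dtype" and "shape" fields; optionally convert the result to a torch.Tensor.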
+ if isinstance(obj, dict) and "__ndarray__" in obj:
+ obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])
+
+ if to_tensor and isinstance(obj, np.ndarray):
+ obj = torch.from_numpy(obj.copy())
+
+ return obj
+
+
+@dataclass
+class BasePart:
+ type: Literal["text", "vq", "audio"] | None = None
+ cal_loss: bool = False
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+ type = "vq"
+ codes: torch.Tensor
+
+ def __post_init__(self: "VQPart"):
+ self.type = "vq"
+ self.codes = restore_ndarray(self.codes, to_tensor=True)
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+ type = "text"
+ text: str | None = None
+ tokens: list[int] | None = None
+
+ def __post_init__(self: "TextPart"):
+ self.type = "text"
+ if self.text is None and self.tokens is None:
+ raise ValueError("Either text or tokens must be provided")
+
+
+@dataclass(kw_only=True)
+class AudioPart(BasePart):
+ type = "audio"
+ features: torch.Tensor
+
+ def __post_init__(self: "AudioPart"):
+ self.type = "audio"
+ self.features = restore_ndarray(self.features, to_tensor=True)
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+ tokens: torch.Tensor
+ labels: torch.Tensor
+ vq_mask_tokens: torch.Tensor | None = None
+ vq_mask_labels: torch.Tensor | None = None
+ vq_parts: list[torch.Tensor]
+ vq_require_losses: torch.Tensor | None = None
+ audio_parts: list[torch.Tensor]
+ audio_masks: torch.Tensor | None = None
+ metadata: dict | None = None
+
+
+@dataclass
+class ContentSequence:
+ """
+ Flexible sequence of content parts that supports interleaved multimodal format.
+ Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
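+
+    Illustrative usage sketch (assumes a loaded FishTokenizer and pre-computed
+    VQ codes of shape (num_codebooks, T); values are examples, not defaults):
+
+        seq = ContentSequence(modality="interleave")
+        seq.append(TextPart(text="Hello world."), speaker=1)
+        seq.append(VQPart(codes=codes), add_end=True)
+        prompt = seq.encode_for_inference(tokenizer, num_codebooks=codes.shape[0])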
+ """
+
+ parts: list[BasePart] = field(default_factory=list)
+ modality: Literal["text", "voice", "interleave"] | None = None
+ metadata: dict | None = None
+
+ def __init__(
+ self: "ContentSequence",
+ parts: list[BasePart | dict] | None = None,
+ modality: Literal["text", "voice", "interleave"] | None = None,
+ metadata: dict | None = None,
+ ):
+ self.modality = modality
+ self.metadata = metadata or {}
+
+ fixed_parts = []
+ for part in parts or []:
+ if isinstance(part, dict):
+ if part["type"] == "vq":
+ part = VQPart(**part)
+ elif part["type"] == "audio":
+ part = AudioPart(**part)
+ elif part["type"] == "text":
+ part = TextPart(**part)
+ else:
+ raise ValueError(f"Unsupported part type: {part['type']}")
+ fixed_parts.append(part)
+
+ self.parts = fixed_parts
+
+ # If modality is specified, add it at the beginning if it's not already there
+ if self.modality and not (
+ len(self.parts) > 0
+ and isinstance(self.parts[0], dict) is False
+ and isinstance(self.parts[0], TextPart)
+ and self.parts[0].text is not None
+ and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
+ ):
+ modality_token = MODALITY_TOKENS[self.modality]
+ self.parts.insert(0, TextPart(text=modality_token))
+
+ def append(
+ self: "ContentSequence",
+ part_or_parts: Union[BasePart, List[BasePart]],
+ add_end: bool = False,
+ speaker: Union[str, int] | None = None,
+ ):
+ """
+ Append a part or list of parts to the sequence.
+
+ Args:
+ part_or_parts: A single part or list of parts to add
+ add_end: Whether to add the IM_END_TOKEN after these parts
+ speaker: Optional speaker identifier (name or ID) to add before the parts
+ """
+ # Convert single part to list
+ parts_to_add = (
+ [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
+ )
+
+ # Add speaker token if specified
+ if speaker is not None:
+ speaker_token = f"<|speaker:{speaker}|>"
+ self.parts.append(TextPart(text=speaker_token))
+
+ # Add all the parts
+ self.parts.extend(parts_to_add)
+
+ # Add end token if requested
+ if add_end:
+ self.parts.append(
+ TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
+ )
+
+ def encode(
+ self: "ContentSequence",
+ tokenizer: FishTokenizer,
+ add_shift: bool = True,
+ ignore_loss_tokens: list[str] = [],
+ ) -> EncodedMessage:
+ """
+ Encode the sequence parts into tokens for the model.
+
+ Args:
+ tokenizer: The tokenizer to use
+ add_shift: Whether to shift tokens for next-token prediction
+ ignore_loss_tokens: List of token strings to ignore when calculating loss
+
+ Returns:
+ EncodedMessage with tensors ready for the model
+ """
+ all_tokens = []
+ all_labels = []
+
+ # Multi-modal elements
+ vq_parts = []
+ vq_masks = []
+ vq_require_losses = []
+
+ audio_parts = []
+ audio_masks = []
+
+ ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
+
+ for part in self.parts:
+ if isinstance(part, TextPart):
+ if part.tokens is None:
+ assert part.text is not None
+ tokens = tokenizer.encode(part.text)
+ else:
+ tokens = part.tokens
+
+ tokens = torch.tensor(tokens, dtype=torch.int)
+ elif isinstance(part, VQPart):
+ curr_codes = part.codes.clone().to(torch.int)
+ tokens = torch.tensor(
+ [
+ tokenizer.semantic_id_to_token_id[int(i.item())]
+ for i in curr_codes[0].int()
+ ],
+ dtype=torch.int,
+ )
+ vq_parts.append(curr_codes)
+ vq_require_losses.append(part.cal_loss)
+ else:
+ raise ValueError(f"Unsupported part type: {type(part)}")
+
+ all_tokens.append(tokens)
+
+ # Set masks for different part types
+ if isinstance(part, VQPart):
+ vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+ audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+ elif isinstance(part, AudioPart):
+ vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+ audio_mask = torch.ones_like(tokens, dtype=torch.bool)
+ audio_mask[0] = False # Skip start token
+ audio_mask[-1] = False # Skip end token
+ audio_masks.append(audio_mask)
+ else:
+ vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+ audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
+ # Set labels based on whether we want to calculate loss for this part
+ if part.cal_loss and not isinstance(part, AudioPart):
+ all_labels.append(tokens.clone())
+ else:
+ all_labels.append(torch.full_like(tokens, -100))
+
+ # Concatenate all tensors
+ tokens = torch.cat(all_tokens, dim=0)
+ labels = torch.cat(all_labels, dim=0)
+ vq_masks = torch.cat(vq_masks, dim=0)
+ audio_masks = torch.cat(audio_masks, dim=0)
+ vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+ # Apply shift if needed for next-token prediction
+ vq_mask_tokens = vq_masks
+ vq_mask_labels = vq_masks
+
+ if add_shift:
+ tokens = tokens[:-1]
+ labels = labels[1:]
+ vq_masks = vq_masks[:-1]
+ vq_mask_tokens = vq_mask_tokens[:-1]
+ vq_mask_labels = vq_mask_labels[1:]
+ audio_masks = audio_masks[:-1]
+
+ # Ignore specified tokens
+ for i in ignore_loss_token_ids:
+ assert i != -100 and i is not None
+ labels[labels == i] = -100
+
+ assert tokens.dtype in [
+ torch.int,
+ torch.long,
+ ], f"Invalid dtype: {tokens.dtype}"
+
+ return EncodedMessage(
+ tokens=tokens,
+ labels=labels,
+ vq_parts=vq_parts,
+ vq_mask_tokens=vq_mask_tokens,
+ vq_mask_labels=vq_mask_labels,
+ vq_require_losses=vq_require_losses,
+ audio_parts=audio_parts,
+ audio_masks=audio_masks,
+ metadata=self.metadata,
+ )
+
+ def encode_for_inference(
+ self: "ContentSequence",
+ tokenizer: FishTokenizer,
+ num_codebooks: int,
+ ) -> torch.Tensor:
+ encoded = self.encode(tokenizer, add_shift=False)
+ tokens = encoded.tokens
+ values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+ values[0] = tokens
+
+ if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
+ encoded.audio_parts is None or len(encoded.audio_parts) == 0
+ ):
+ return values
+
+ if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
+ vq_parts = encoded.vq_parts
+ vq_parts = torch.cat(vq_parts, dim=1)
+ values[0, encoded.vq_mask_tokens] = (
+ vq_parts[0] + tokenizer.semantic_begin_id
+ )
+ values[1:, encoded.vq_mask_tokens] = vq_parts
+
+ return values
+
+ def visualize(
+ self: "ContentSequence",
+ tokenizer: FishTokenizer,
+ ignore_loss_tokens: list[str] = [],
+ merge_semantic_tokens: bool = False,
+ ):
+ """
+ Visualize the encoded sequence with color-coded tokens.
+ Blue/cyan tokens contribute to loss, green tokens do not.
+ """
+ encoded = self.encode(
+ tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+ )
+
+ # Colors for alternating tokens
+ colors = {
+ "blue": "\033[94m", # Light blue
+ "cyan": "\033[96m", # Cyan
+ "green": "\033[92m", # Light green
+ "dark_green": "\033[32m", # Dark green
+ }
+ blue_idx = 0
+ green_idx = 0
+
+ def print_in_blue(x):
+ nonlocal blue_idx
+ color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+ print(f"{color}{x}\033[0m", end="")
+ blue_idx += 1
+
+ def print_in_green(x):
+ nonlocal green_idx
+ color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+ print(f"{color}{x}\033[0m", end="")
+ green_idx += 1
+
+ def print_semantic_token(x, count):
+ val = f"[<|semantic|>x{count}]"
+ if x == -100:
+ print_in_green(val)
+ else:
+ print_in_blue(val)
+
+ count_semantic_tokens = 0
+ semantic_label = None
+
+ for tok, lab in zip(encoded.tokens, encoded.labels):
+ token_id = int(tok.item())
+
+ if merge_semantic_tokens:
+ if (
+ tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
+ and (semantic_label is None or semantic_label == lab)
+ ):
+ count_semantic_tokens += 1
+ semantic_label = lab
+ continue
+ elif count_semantic_tokens > 0:
+ print_semantic_token(semantic_label, count_semantic_tokens)
+ count_semantic_tokens = 0
+ semantic_label = None
+
+ val = tokenizer.decode([int(tok.item())])
+
+ if lab == -100:
+ print_in_green(val)
+ else:
+ print_in_blue(val)
+
+ if merge_semantic_tokens and count_semantic_tokens > 0:
+ print_semantic_token(semantic_label, count_semantic_tokens)
+
+ print()
diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md
index 7d4612883379a16fd2d0945c431d9fb8b04b249a..700902b09db20911ef1ad678cbdce5644b84aea2 100644
--- a/fish_speech/i18n/README.md
+++ b/fish_speech/i18n/README.md
@@ -1,27 +1,27 @@
-## i18n Folder Attribution
-
-The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
-
-### fish_speech/i18n/core.py
-
-**Related code from RVC:**
-[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
-
-**Initial commit:**
-add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
-
-**Initial author:**
-[@L4Ph](https://github.com/L4Ph)
-
-### fish_speech/i18n/scan.py
-
-**Related code from RVC:**
-[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
-
-**Initial commit:**
-File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
-
-**Initial author:**
-[@towzeur](https://github.com/towzeur)
-
-We appreciate the contributions of the RVC project and its authors.
+## i18n Folder Attribution
+
+The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
+
+### fish_speech/i18n/core.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
+
+**Initial commit:**
+add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
+
+**Initial author:**
+[@L4Ph](https://github.com/L4Ph)
+
+### fish_speech/i18n/scan.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
+
+**Initial commit:**
+File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
+
+**Initial author:**
+[@towzeur](https://github.com/towzeur)
+
+We appreciate the contributions of the RVC project and its authors.
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py
index 0ac9702a707223257997e283ebc259b78996ad5c..981dbb3b3ecf28043ec9ff5757f947182821a246 100644
--- a/fish_speech/i18n/__init__.py
+++ b/fish_speech/i18n/__init__.py
@@ -1,3 +1,3 @@
-from .core import i18n
-
-__all__ = ["i18n"]
+from .core import i18n
+
+__all__ = ["i18n"]
diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py
index 8375d9ddb4e3c2b3ec25c426d2786f2a7506a0ab..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd 100644
--- a/fish_speech/i18n/core.py
+++ b/fish_speech/i18n/core.py
@@ -1,40 +1,40 @@
-import json
-import locale
-from pathlib import Path
-
-I18N_FILE_PATH = Path(__file__).parent / "locale"
-DEFAULT_LANGUAGE = "en_US"
-
-
-def load_language_list(language):
- with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
- language_list = json.load(f)
-
- return language_list
-
-
-class I18nAuto:
- def __init__(self):
- i18n_file = Path(".locale")
-
- if i18n_file.exists():
- with open(i18n_file, "r", encoding="utf-8") as f:
- language = f.read().strip()
- else:
- # getlocale can't identify the system's language ((None, None))
- language = locale.getdefaultlocale()[0]
-
- if (I18N_FILE_PATH / f"{language}.json").exists() is False:
- language = DEFAULT_LANGUAGE
-
- self.language = language
- self.language_map = load_language_list(language)
-
- def __call__(self, key):
- return self.language_map.get(key, key)
-
- def __repr__(self):
- return "Use Language: " + self.language
-
-
-i18n = I18nAuto()
+import json
+import locale
+from pathlib import Path
+
+I18N_FILE_PATH = Path(__file__).parent / "locale"
+DEFAULT_LANGUAGE = "en_US"
+
+
+def load_language_list(language):
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
+ language_list = json.load(f)
+
+ return language_list
+
+
+class I18nAuto:
+ def __init__(self):
+ i18n_file = Path(".locale")
+
+ if i18n_file.exists():
+ with open(i18n_file, "r", encoding="utf-8") as f:
+ language = f.read().strip()
+ else:
+ # getlocale can't identify the system's language ((None, None))
+ language = locale.getdefaultlocale()[0]
+
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
+ language = DEFAULT_LANGUAGE
+
+ self.language = language
+ self.language_map = load_language_list(language)
+
+ def __call__(self, key):
+ return self.language_map.get(key, key)
+
+ def __repr__(self):
+ return "Use Language: " + self.language
+
+
+i18n = I18nAuto()
diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json
index 32d58983a5df52d032411fd50ef1a9e3ecdeb859..d36c774313628fe9d4ee60e816f404c09935e655 100644
--- a/fish_speech/i18n/locale/en_US.json
+++ b/fish_speech/i18n/locale/en_US.json
@@ -1,123 +1,123 @@
-{
- "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
- "Accumulate Gradient Batches": "Accumulate Gradient Batches",
- "Add to Processing Area": "Add to Processing Area",
- "Added path successfully!": "Added path successfully!",
- "Advanced Config": "Advanced Config",
- "Base LLAMA Model": "Base LLAMA Model",
- "Batch Inference": "Batch Inference",
- "Batch Size": "Batch Size",
- "Changing with the Model Path": "Changing with the Model Path",
- "Chinese": "Chinese",
- "Compile Model": "Compile Model",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
- "Copy": "Copy",
- "Data Preprocessing": "Data Preprocessing",
- "Data Preprocessing Path": "Data Preprocessing Path",
- "Data Source": "Data Source",
- "Decoder Model Config": "Decoder Model Config",
- "Decoder Model Path": "Decoder Model Path",
- "Disabled": "Disabled",
- "Enable Reference Audio": "Enable Reference Audio",
- "English": "English",
- "Error Message": "Error Message",
- "File Preprocessing": "File Preprocessing",
- "Generate": "Generate",
- "Generated Audio": "Generated Audio",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
- "Infer interface is closed": "Infer interface is closed",
- "Inference Configuration": "Inference Configuration",
- "Inference Server Configuration": "Inference Server Configuration",
- "Inference Server Error": "Inference Server Error",
- "Inferring interface is launched at {}": "Inferring interface is launched at {}",
- "Initial Learning Rate": "Initial Learning Rate",
- "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
- "Input Text": "Input Text",
- "Invalid path: {}": "Invalid path: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
- "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
- "Japanese": "Japanese",
- "LLAMA Configuration": "LLAMA Configuration",
- "LLAMA Model Config": "LLAMA Model Config",
- "LLAMA Model Path": "LLAMA Model Path",
- "Labeling Device": "Labeling Device",
- "LoRA Model to be merged": "LoRA Model to be merged",
- "Maximum Audio Duration": "Maximum Audio Duration",
- "Maximum Length per Sample": "Maximum Length per Sample",
- "Maximum Training Steps": "Maximum Training Steps",
- "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
- "Merge": "Merge",
- "Merge LoRA": "Merge LoRA",
- "Merge successfully": "Merge successfully",
- "Minimum Audio Duration": "Minimum Audio Duration",
- "Model Output Path": "Model Output Path",
- "Model Size": "Model Size",
- "Move": "Move",
- "Move files successfully": "Move files successfully",
- "No audio generated, please check the input text.": "No audio generated, please check the input text.",
- "No selected options": "No selected options",
- "Number of Workers": "Number of Workers",
- "Open Inference Server": "Open Inference Server",
- "Open Labeler WebUI": "Open Labeler WebUI",
- "Open Tensorboard": "Open Tensorboard",
- "Opened labeler in browser": "Opened labeler in browser",
- "Optional Label Language": "Optional Label Language",
- "Optional online ver": "Optional online ver",
- "Output Path": "Output Path",
- "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
- "Precision": "Precision",
- "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
- "Put your text here.": "Put your text here.",
- "Reference Audio": "Reference Audio",
- "Reference Text": "Reference Text",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
- "Remove Selected Data": "Remove Selected Data",
- "Removed path successfully!": "Removed path successfully!",
- "Repetition Penalty": "Repetition Penalty",
- "Save model every n steps": "Save model every n steps",
- "Select LLAMA ckpt": "Select LLAMA ckpt",
- "Select VITS ckpt": "Select VITS ckpt",
- "Select VQGAN ckpt": "Select VQGAN ckpt",
- "Select source file processing method": "Select source file processing method",
- "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
- "Selected: {}": "Selected: {}",
- "Speaker": "Speaker",
- "Speaker is identified by the folder name": "Speaker is identified by the folder name",
- "Start Training": "Start Training",
- "Streaming Audio": "Streaming Audio",
- "Streaming Generate": "Streaming Generate",
- "Tensorboard Host": "Tensorboard Host",
- "Tensorboard Log Path": "Tensorboard Log Path",
- "Tensorboard Port": "Tensorboard Port",
- "Tensorboard interface is closed": "Tensorboard interface is closed",
- "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
- "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
- "Training Configuration": "Training Configuration",
- "Training Error": "Training Error",
- "Training stopped": "Training stopped",
- "Type name of the speaker": "Type name of the speaker",
- "Type the path or select from the dropdown": "Type the path or select from the dropdown",
- "Use LoRA": "Use LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
- "Use filelist": "Use filelist",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
- "VITS Configuration": "VITS Configuration",
- "VQGAN Configuration": "VQGAN Configuration",
- "Validation Batch Size": "Validation Batch Size",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
- "WebUI Host": "WebUI Host",
- "WebUI Port": "WebUI Port",
- "Whisper Model": "Whisper Model",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
- "latest": "latest",
- "new": "new",
- "Realtime Transform Text": "Realtime Transform Text",
- "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
- "Text Normalization": "Text Normalization",
- "Select Example Audio": "Select Example Audio"
-}
+{
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
+ "Add to Processing Area": "Add to Processing Area",
+ "Added path successfully!": "Added path successfully!",
+ "Advanced Config": "Advanced Config",
+ "Base LLAMA Model": "Base LLAMA Model",
+ "Batch Inference": "Batch Inference",
+ "Batch Size": "Batch Size",
+ "Changing with the Model Path": "Changing with the Model Path",
+ "Chinese": "Chinese",
+ "Compile Model": "Compile Model",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
+ "Copy": "Copy",
+ "Data Preprocessing": "Data Preprocessing",
+ "Data Preprocessing Path": "Data Preprocessing Path",
+ "Data Source": "Data Source",
+ "Decoder Model Config": "Decoder Model Config",
+ "Decoder Model Path": "Decoder Model Path",
+ "Disabled": "Disabled",
+ "Enable Reference Audio": "Enable Reference Audio",
+ "English": "English",
+ "Error Message": "Error Message",
+ "File Preprocessing": "File Preprocessing",
+ "Generate": "Generate",
+ "Generated Audio": "Generated Audio",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
+ "Infer interface is closed": "Infer interface is closed",
+ "Inference Configuration": "Inference Configuration",
+ "Inference Server Configuration": "Inference Server Configuration",
+ "Inference Server Error": "Inference Server Error",
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
+ "Initial Learning Rate": "Initial Learning Rate",
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
+ "Input Text": "Input Text",
+ "Invalid path: {}": "Invalid path: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
+ "Japanese": "Japanese",
+ "LLAMA Configuration": "LLAMA Configuration",
+ "LLAMA Model Config": "LLAMA Model Config",
+ "LLAMA Model Path": "LLAMA Model Path",
+ "Labeling Device": "Labeling Device",
+ "LoRA Model to be merged": "LoRA Model to be merged",
+ "Maximum Audio Duration": "Maximum Audio Duration",
+ "Maximum Length per Sample": "Maximum Length per Sample",
+ "Maximum Training Steps": "Maximum Training Steps",
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
+ "Merge": "Merge",
+ "Merge LoRA": "Merge LoRA",
+ "Merge successfully": "Merge successfully",
+ "Minimum Audio Duration": "Minimum Audio Duration",
+ "Model Output Path": "Model Output Path",
+ "Model Size": "Model Size",
+ "Move": "Move",
+ "Move files successfully": "Move files successfully",
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
+ "No selected options": "No selected options",
+ "Number of Workers": "Number of Workers",
+ "Open Inference Server": "Open Inference Server",
+ "Open Labeler WebUI": "Open Labeler WebUI",
+ "Open Tensorboard": "Open Tensorboard",
+ "Opened labeler in browser": "Opened labeler in browser",
+ "Optional Label Language": "Optional Label Language",
+ "Optional online ver": "Optional online ver",
+ "Output Path": "Output Path",
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
+ "Precision": "Precision",
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
+ "Put your text here.": "Put your text here.",
+ "Reference Audio": "Reference Audio",
+ "Reference Text": "Reference Text",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
+ "Remove Selected Data": "Remove Selected Data",
+ "Removed path successfully!": "Removed path successfully!",
+ "Repetition Penalty": "Repetition Penalty",
+ "Save model every n steps": "Save model every n steps",
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
+ "Select VITS ckpt": "Select VITS ckpt",
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
+ "Select source file processing method": "Select source file processing method",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
+ "Selected: {}": "Selected: {}",
+ "Speaker": "Speaker",
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
+ "Start Training": "Start Training",
+ "Streaming Audio": "Streaming Audio",
+ "Streaming Generate": "Streaming Generate",
+ "Tensorboard Host": "Tensorboard Host",
+ "Tensorboard Log Path": "Tensorboard Log Path",
+ "Tensorboard Port": "Tensorboard Port",
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
+ "Training Configuration": "Training Configuration",
+ "Training Error": "Training Error",
+ "Training stopped": "Training stopped",
+ "Type name of the speaker": "Type name of the speaker",
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
+ "Use LoRA": "Use LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
+ "Use filelist": "Use filelist",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
+ "VITS Configuration": "VITS Configuration",
+ "VQGAN Configuration": "VQGAN Configuration",
+ "Validation Batch Size": "Validation Batch Size",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
+ "WebUI Host": "WebUI Host",
+ "WebUI Port": "WebUI Port",
+ "Whisper Model": "Whisper Model",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
+ "latest": "latest",
+ "new": "new",
+ "Realtime Transform Text": "Realtime Transform Text",
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
+ "Text Normalization": "Text Normalization",
+ "Select Example Audio": "Select Example Audio"
+}
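
The locale files in this patch are flat JSON maps keyed by the English UI strings. As a rough illustration only (not the actual `fish_speech.i18n` implementation), a lookup over these files could work like the sketch below; the loader behaviour and the fallback-to-key rule are assumptions, while the `fish_speech/i18n/locale/<lang>.json` layout comes from the paths in this diff.

```python
# Illustrative sketch only -- the real fish_speech.i18n loader may differ.
# Assumption: locale files live under fish_speech/i18n/locale/<lang>.json,
# matching the paths shown in this diff.
import json
from pathlib import Path

LOCALE_DIR = Path("fish_speech/i18n/locale")


def load_translations(lang: str = "en_US") -> dict[str, str]:
    """Return the key -> translated-string map for one locale, or {} if absent."""
    path = LOCALE_DIR / f"{lang}.json"
    return json.loads(path.read_text(encoding="utf-8")) if path.is_file() else {}


_TRANSLATIONS = load_translations("es_ES")


def i18n(key: str) -> str:
    # Missing entries fall back to the English key itself, which is consistent
    # with the English locale mapping every key to an identical value.
    return _TRANSLATIONS.get(key, key)
```
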
diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json
index 0bde8404cdaff94e7c49be580cbba99b8f41ce29..7a4757967dd0fe3807ba4d354e75ad7a88eb510e 100644
--- a/fish_speech/i18n/locale/es_ES.json
+++ b/fish_speech/i18n/locale/es_ES.json
@@ -1,123 +1,123 @@
-{
- "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
- "Accumulate Gradient Batches": "Acumular lotes de gradientes",
- "Add to Processing Area": "Agregar al Área de Procesamiento",
- "Added path successfully!": "¡Ruta agregada exitosamente!",
- "Advanced Config": "Configuración Avanzada",
- "Base LLAMA Model": "Modelo Base LLAMA",
- "Batch Inference": "Inferencia por Lote",
- "Batch Size": "Tamaño del Lote",
- "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
- "Chinese": "Chino",
- "Compile Model": "Compilar Modelo",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
- "Copy": "Copiar",
- "Data Preprocessing": "Preprocesamiento de Datos",
- "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
- "Data Source": "Fuente de Datos",
- "Decoder Model Config": "Configuración del modelo decodificador",
- "Decoder Model Path": "Ruta del modelo decodificador",
- "Disabled": "Desactivado",
- "Enable Reference Audio": "Habilitar Audio de Referencia",
- "English": "Inglés",
- "Error Message": "Mensaje de Error",
- "File Preprocessing": "Preprocesamiento de Archivos",
- "Generate": "Generar",
- "Generated Audio": "Audio Generado",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
- "Infer interface is closed": "La interfaz de inferencia está cerrada",
- "Inference Configuration": "Configuración de Inferencia",
- "Inference Server Configuration": "Configuración del Servidor de Inferencia",
- "Inference Server Error": "Error del Servidor de Inferencia",
- "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
- "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
- "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
- "Input Text": "Texto de Entrada",
- "Invalid path: {}": "Ruta inválida: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
- "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
- "Japanese": "Japonés",
- "LLAMA Configuration": "Configuración de LLAMA",
- "LLAMA Model Config": "Configuración del Modelo LLAMA",
- "LLAMA Model Path": "Ruta del Modelo LLAMA",
- "Labeling Device": "Dispositivo de Etiquetado",
- "LoRA Model to be merged": "Modelo LoRA a fusionar",
- "Maximum Audio Duration": "Duración máxima de audio",
- "Maximum Length per Sample": "Longitud Máxima por Muestra",
- "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
- "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
- "Merge": "Fusionar",
- "Merge LoRA": "Fusionar LoRA",
- "Merge successfully": "Fusionado exitosamente",
- "Minimum Audio Duration": "Duración mínima de audio",
- "Model Output Path": "Ruta de Salida del Modelo",
- "Model Size": "Tamaño del Modelo",
- "Move": "Mover",
- "Move files successfully": "Archivos movidos exitosamente",
- "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
- "No selected options": "No hay opciones seleccionadas",
- "Number of Workers": "Número de Trabajadores",
- "Open Inference Server": "Abrir Servidor de Inferencia",
- "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
- "Open Tensorboard": "Abrir Tensorboard",
- "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
- "Optional Label Language": "Idioma de Etiquetado Opcional",
- "Optional online ver": "Ver en línea opcional",
- "Output Path": "Ruta de Salida",
- "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
- "Precision": "Precisión",
- "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
- "Put your text here.": "Ponga su texto aquí.",
- "Reference Audio": "Audio de Referencia",
- "Reference Text": "Texto de Referencia",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
- "Remove Selected Data": "Eliminar Datos Seleccionados",
- "Removed path successfully!": "¡Ruta eliminada exitosamente!",
- "Repetition Penalty": "Penalización por Repetición",
- "Save model every n steps": "Guardar modelo cada n pasos",
- "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
- "Select VITS ckpt": "Seleccionar punto de control VITS",
- "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
- "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
- "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
- "Selected: {}": "Seleccionado: {}",
- "Speaker": "Hablante",
- "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
- "Start Training": "Iniciar Entrenamiento",
- "Streaming Audio": "transmisión de audio",
- "Streaming Generate": "síntesis en flujo",
- "Tensorboard Host": "Host de Tensorboard",
- "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
- "Tensorboard Port": "Puerto de Tensorboard",
- "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
- "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
- "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
- "Training Configuration": "Configuración de Entrenamiento",
- "Training Error": "Error de Entrenamiento",
- "Training stopped": "Entrenamiento detenido",
- "Type name of the speaker": "Escriba el nombre del hablante",
- "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
- "Use LoRA": "Usar LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
- "Use filelist": "Usar lista de archivos",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
- "VITS Configuration": "Configuración de VITS",
- "VQGAN Configuration": "Configuración de VQGAN",
- "Validation Batch Size": "Tamaño del Lote de Validación",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
- "WebUI Host": "Host de WebUI",
- "WebUI Port": "Puerto de WebUI",
- "Whisper Model": "Modelo Whisper",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
- "latest": "más reciente",
- "new": "nuevo",
- "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
- "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
- "Text Normalization": "Normalización de Texto",
- "Select Example Audio": "Selecionar áudio de exemplo"
-}
+{
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
+ "Advanced Config": "Configuración Avanzada",
+ "Base LLAMA Model": "Modelo Base LLAMA",
+ "Batch Inference": "Inferencia por Lote",
+ "Batch Size": "Tamaño del Lote",
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
+ "Chinese": "Chino",
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Preprocesamiento de Datos",
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
+ "Data Source": "Fuente de Datos",
+ "Decoder Model Config": "Configuración del modelo decodificador",
+ "Decoder Model Path": "Ruta del modelo decodificador",
+ "Disabled": "Desactivado",
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
+ "English": "Inglés",
+ "Error Message": "Mensaje de Error",
+ "File Preprocessing": "Preprocesamiento de Archivos",
+ "Generate": "Generar",
+ "Generated Audio": "Audio Generado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
+ "Inference Configuration": "Configuración de Inferencia",
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
+ "Inference Server Error": "Error del Servidor de Inferencia",
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Ruta inválida: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
+ "Japanese": "Japonés",
+ "LLAMA Configuration": "Configuración de LLAMA",
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Etiquetado",
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
+ "Maximum Audio Duration": "Duración máxima de audio",
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
+ "Merge": "Fusionar",
+ "Merge LoRA": "Fusionar LoRA",
+ "Merge successfully": "Fusionado exitosamente",
+ "Minimum Audio Duration": "Duración mínima de audio",
+ "Model Output Path": "Ruta de Salida del Modelo",
+ "Model Size": "Tamaño del Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Archivos movidos exitosamente",
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
+ "No selected options": "No hay opciones seleccionadas",
+ "Number of Workers": "Número de Trabajadores",
+ "Open Inference Server": "Abrir Servidor de Inferencia",
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
+ "Optional online ver": "Ver en línea opcional",
+ "Output Path": "Ruta de Salida",
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
+ "Precision": "Precisión",
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
+ "Put your text here.": "Ponga su texto aquí.",
+ "Reference Audio": "Audio de Referencia",
+ "Reference Text": "Texto de Referencia",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
+ "Repetition Penalty": "Penalización por Repetición",
+ "Save model every n steps": "Guardar modelo cada n pasos",
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
+ "Selected: {}": "Seleccionado: {}",
+ "Speaker": "Hablante",
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
+ "Start Training": "Iniciar Entrenamiento",
+ "Streaming Audio": "transmisión de audio",
+ "Streaming Generate": "síntesis en flujo",
+ "Tensorboard Host": "Host de Tensorboard",
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
+ "Tensorboard Port": "Puerto de Tensorboard",
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
+ "Training Configuration": "Configuración de Entrenamiento",
+ "Training Error": "Error de Entrenamiento",
+ "Training stopped": "Entrenamiento detenido",
+ "Type name of the speaker": "Escriba el nombre del hablante",
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
+ "Use filelist": "Usar lista de archivos",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
+ "VITS Configuration": "Configuración de VITS",
+ "VQGAN Configuration": "Configuración de VQGAN",
+ "Validation Batch Size": "Tamaño del Lote de Validación",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
+ "WebUI Host": "Host de WebUI",
+ "WebUI Port": "Puerto de WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
+ "latest": "más reciente",
+ "new": "nuevo",
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
+ "Text Normalization": "Normalización de Texto",
+ "Select Example Audio": "Selecionar áudio de exemplo"
+}
diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json
index 9d0baeb73ffd9cc1af7570ef0ac7e6018ce9527b..863b8b0b41da7e504ac0dcc4abf707f1f71a53fa 100644
--- a/fish_speech/i18n/locale/ja_JP.json
+++ b/fish_speech/i18n/locale/ja_JP.json
@@ -1,123 +1,123 @@
-{
- "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
- "Accumulate Gradient Batches": "勾配バッチの累積",
- "Add to Processing Area": "処理エリアに追加",
- "Added path successfully!": "パスの追加に成功しました!",
- "Advanced Config": "詳細設定",
- "Base LLAMA Model": "基本LLAMAモデル",
- "Batch Inference": "バッチ推論",
- "Batch Size": "バッチサイズ",
- "Changing with the Model Path": "モデルのパスに伴って変化する",
- "Chinese": "中国語",
- "Compile Model": "モデルのコンパイル",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
- "Copy": "コピー",
- "Data Preprocessing": "データ前処理",
- "Data Preprocessing Path": "データ前処理パス",
- "Data Source": "データソース",
- "Decoder Model Config": "デコーダーモデルの構成",
- "Decoder Model Path": "デコーダーモデルのパス",
- "Disabled": "無効",
- "Enable Reference Audio": "リファレンスオーディオを有効にする",
- "English": "英語",
- "Error Message": "エラーメッセージ",
- "File Preprocessing": "文書前处理",
- "Generate": "生成",
- "Generated Audio": "生成されたオーディオ",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
- "Infer interface is closed": "推論インターフェースが閉じられています",
- "Inference Configuration": "推論設定",
- "Inference Server Configuration": "推論サーバー設定",
- "Inference Server Error": "推論サーバーエラー",
- "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
- "Initial Learning Rate": "初期学習率",
- "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
- "Input Text": "入力テキスト",
- "Invalid path: {}": "無効なパス: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
- "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
- "Japanese": "日本語",
- "LLAMA Configuration": "LLAMA設定",
- "LLAMA Model Config": "LLAMAモデル設定",
- "LLAMA Model Path": "LLAMAモデルパス",
- "Labeling Device": "ラベリングデバイス",
- "LoRA Model to be merged": "マージするLoRAモデル",
- "Maximum Audio Duration": "最大オーディオの長さ",
- "Maximum Length per Sample": "サンプルあたりの最大長",
- "Maximum Training Steps": "最大トレーニングステップ数",
- "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
- "Merge": "マージ",
- "Merge LoRA": "LoRAのマージ",
- "Merge successfully": "マージに成功しました",
- "Minimum Audio Duration": "最小オーディオの長さ",
- "Model Output Path": "モデル出力パス",
- "Model Size": "モデルサイズ",
- "Move": "移動",
- "Move files successfully": "ファイルの移動に成功しました",
- "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
- "No selected options": "選択されたオプションはありません",
- "Number of Workers": "ワーカー数",
- "Open Inference Server": "推論サーバーを開く",
- "Open Labeler WebUI": "ラベラーWebUIを開く",
- "Open Tensorboard": "Tensorboardを開く",
- "Opened labeler in browser": "ブラウザでラベラーを開きました",
- "Optional Label Language": "オプションのラベル言語",
- "Optional online ver": "オプションのオンラインバージョン",
- "Output Path": "出力パス",
- "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
- "Precision": "精度",
- "Probability of applying Speaker Condition": "話者条件を適用する確率",
- "Put your text here.": "ここにテキストを入力してください。",
- "Reference Audio": "リファレンスオーディオ",
- "Reference Text": "リファレンステキスト",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
- "Remove Selected Data": "選択したデータを削除",
- "Removed path successfully!": "パスの削除に成功しました!",
- "Repetition Penalty": "反復ペナルティ",
- "Save model every n steps": "nステップごとにモデルを保存",
- "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
- "Select VITS ckpt": "VITS チェックポイントを選択",
- "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
- "Select source file processing method": "ソースファイルの処理方法を選択",
- "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
- "Selected: {}": "選択済み: {}",
- "Speaker": "話者",
- "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
- "Start Training": "トレーニング開始",
- "Streaming Audio": "ストリーミングオーディオ",
- "Streaming Generate": "ストリーミング合成",
- "Tensorboard Host": "Tensorboardホスト",
- "Tensorboard Log Path": "Tensorboardログパス",
- "Tensorboard Port": "Tensorboardポート",
- "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
- "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
- "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
- "Training Configuration": "トレーニング設定",
- "Training Error": "トレーニングエラー",
- "Training stopped": "トレーニングが停止しました",
- "Type name of the speaker": "話者の名前を入力",
- "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
- "Use LoRA": "LoRAを使用",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
- "Use filelist": "ファイルリストを使用",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
- "VITS Configuration": "VITS の構成",
- "VQGAN Configuration": "VQGAN の構成",
- "Validation Batch Size": "検証バッチサイズ",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
- "WebUI Host": "WebUIホスト",
- "WebUI Port": "WebUIポート",
- "Whisper Model": "Whisperモデル",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
- "latest": "最新",
- "new": "新規",
- "Realtime Transform Text": "リアルタイム変換テキスト",
- "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
- "Text Normalization": "テキスト正規化",
- "Select Example Audio": "サンプル音声を選択"
-}
+{
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
+ "Accumulate Gradient Batches": "勾配バッチの累積",
+ "Add to Processing Area": "処理エリアに追加",
+ "Added path successfully!": "パスの追加に成功しました!",
+ "Advanced Config": "詳細設定",
+ "Base LLAMA Model": "基本LLAMAモデル",
+ "Batch Inference": "バッチ推論",
+ "Batch Size": "バッチサイズ",
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
+ "Chinese": "中国語",
+ "Compile Model": "モデルのコンパイル",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
+ "Copy": "コピー",
+ "Data Preprocessing": "データ前処理",
+ "Data Preprocessing Path": "データ前処理パス",
+ "Data Source": "データソース",
+ "Decoder Model Config": "デコーダーモデルの構成",
+ "Decoder Model Path": "デコーダーモデルのパス",
+ "Disabled": "無効",
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
+ "English": "英語",
+ "Error Message": "エラーメッセージ",
+ "File Preprocessing": "文書前处理",
+ "Generate": "生成",
+ "Generated Audio": "生成されたオーディオ",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
+ "Infer interface is closed": "推論インターフェースが閉じられています",
+ "Inference Configuration": "推論設定",
+ "Inference Server Configuration": "推論サーバー設定",
+ "Inference Server Error": "推論サーバーエラー",
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
+ "Initial Learning Rate": "初期学習率",
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
+ "Input Text": "入力テキスト",
+ "Invalid path: {}": "無効なパス: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
+ "Japanese": "日本語",
+ "LLAMA Configuration": "LLAMA設定",
+ "LLAMA Model Config": "LLAMAモデル設定",
+ "LLAMA Model Path": "LLAMAモデルパス",
+ "Labeling Device": "ラベリングデバイス",
+ "LoRA Model to be merged": "マージするLoRAモデル",
+ "Maximum Audio Duration": "最大オーディオの長さ",
+ "Maximum Length per Sample": "サンプルあたりの最大長",
+ "Maximum Training Steps": "最大トレーニングステップ数",
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
+ "Merge": "マージ",
+ "Merge LoRA": "LoRAのマージ",
+ "Merge successfully": "マージに成功しました",
+ "Minimum Audio Duration": "最小オーディオの長さ",
+ "Model Output Path": "モデル出力パス",
+ "Model Size": "モデルサイズ",
+ "Move": "移動",
+ "Move files successfully": "ファイルの移動に成功しました",
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
+ "No selected options": "選択されたオプションはありません",
+ "Number of Workers": "ワーカー数",
+ "Open Inference Server": "推論サーバーを開く",
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
+ "Open Tensorboard": "Tensorboardを開く",
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
+ "Optional Label Language": "オプションのラベル言語",
+ "Optional online ver": "オプションのオンラインバージョン",
+ "Output Path": "出力パス",
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
+ "Put your text here.": "ここにテキストを入力してください。",
+ "Reference Audio": "リファレンスオーディオ",
+ "Reference Text": "リファレンステキスト",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
+ "Remove Selected Data": "選択したデータを削除",
+ "Removed path successfully!": "パスの削除に成功しました!",
+ "Repetition Penalty": "反復ペナルティ",
+ "Save model every n steps": "nステップごとにモデルを保存",
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
+ "Select VITS ckpt": "VITS チェックポイントを選択",
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
+ "Select source file processing method": "ソースファイルの処理方法を選択",
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
+ "Selected: {}": "選択済み: {}",
+ "Speaker": "話者",
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
+ "Start Training": "トレーニング開始",
+ "Streaming Audio": "ストリーミングオーディオ",
+ "Streaming Generate": "ストリーミング合成",
+ "Tensorboard Host": "Tensorboardホスト",
+ "Tensorboard Log Path": "Tensorboardログパス",
+ "Tensorboard Port": "Tensorboardポート",
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
+ "Training Configuration": "トレーニング設定",
+ "Training Error": "トレーニングエラー",
+ "Training stopped": "トレーニングが停止しました",
+ "Type name of the speaker": "話者の名前を入力",
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
+ "Use LoRA": "LoRAを使用",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
+ "Use filelist": "ファイルリストを使用",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
+ "VITS Configuration": "VITS の構成",
+ "VQGAN Configuration": "VQGAN の構成",
+ "Validation Batch Size": "検証バッチサイズ",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
+ "WebUI Host": "WebUIホスト",
+ "WebUI Port": "WebUIポート",
+ "Whisper Model": "Whisperモデル",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
+ "latest": "最新",
+ "new": "新規",
+ "Realtime Transform Text": "リアルタイム変換テキスト",
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
+ "Text Normalization": "テキスト正規化",
+ "Select Example Audio": "サンプル音声を選択"
+}
diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json
index f4bf1841b7c847993707ec3b8e32f5174de77214..180263874b476059870035d4c2b74ce5fa553a8a 100644
--- a/fish_speech/i18n/locale/ko_KR.json
+++ b/fish_speech/i18n/locale/ko_KR.json
@@ -1,123 +1,123 @@
-{
- "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
- "Accumulate Gradient Batches": "그라디언트 배치 누적",
- "Add to Processing Area": "처리 영역에 추가",
- "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
- "Advanced Config": "고급 설정",
- "Base LLAMA Model": "기본 LLAMA 모델",
- "Batch Inference": "배치 추론",
- "Batch Size": "배치 크기",
- "Changing with the Model Path": "모델 경로에 따라 변경 중",
- "Chinese": "중국어",
- "Compile Model": "모델 컴파일",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
- "Copy": "복사",
- "Data Preprocessing": "데이터 전처리",
- "Data Preprocessing Path": "데이터 전처리 경로",
- "Data Source": "데이터 소스",
- "Decoder Model Config": "디코더 모델 설정",
- "Decoder Model Path": "디코더 모델 경로",
- "Disabled": "비활성화 됨",
- "Enable Reference Audio": "참고 음성 활성화",
- "English": "영어",
- "Error Message": "오류 메시지",
- "File Preprocessing": "파일 전처리",
- "Generate": "생성",
- "Generated Audio": "생성된 오디오",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
- "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
- "Inference Configuration": "추론 설정",
- "Inference Server Configuration": "추론 서버 설정",
- "Inference Server Error": "추론 서버 오류",
- "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
- "Initial Learning Rate": "초기 학습률",
- "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
- "Input Text": "입력 텍스트",
- "Invalid path: {}": "유효하지 않은 경로: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
- "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
- "Japanese": "일본어",
- "LLAMA Configuration": "LLAMA 설정",
- "LLAMA Model Config": "LLAMA 모델 설정",
- "LLAMA Model Path": "LLAMA 모델 경로",
- "Labeling Device": "라벨링 장치",
- "LoRA Model to be merged": "병합할 LoRA 모델",
- "Maximum Audio Duration": "최대 오디오 길이",
- "Maximum Length per Sample": "샘플당 최대 길이",
- "Maximum Training Steps": "최대 학습 단계",
- "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
- "Merge": "병합",
- "Merge LoRA": "LoRA 병합",
- "Merge successfully": "성공적으로 병합 되었습니다.",
- "Minimum Audio Duration": "최소 오디오 길이",
- "Model Output Path": "모델 출력 경로",
- "Model Size": "모델 크기",
- "Move": "이동",
- "Move files successfully": "파일이 성공적으로 이동되었습니다.",
- "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
- "No selected options": "옵션이 선택되지 않았습니다.",
- "Number of Workers": "작업자 수",
- "Open Inference Server": "추론 서버 열기",
- "Open Labeler WebUI": "라벨러 WebUI 열기",
- "Open Tensorboard": "Tensorboard 열기",
- "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
- "Optional Label Language": "선택적 라벨 언어",
- "Optional online ver": "온라인 버전 선택",
- "Output Path": "출력 경로",
- "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
- "Precision": "정밀도",
- "Probability of applying Speaker Condition": "화자 조건 적용 확률",
- "Put your text here.": "여기에 텍스트를 입력하세요.",
- "Reference Audio": "참고 오디오",
- "Reference Text": "참고 텍스트",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
- "Remove Selected Data": "선택한 데이터 제거",
- "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
- "Repetition Penalty": "반복 패널티",
- "Save model every n steps": "n 단계마다 모델 저장",
- "Select LLAMA ckpt": "LLAMA ckpt 선택",
- "Select VITS ckpt": "VITS ckpt 선택",
- "Select VQGAN ckpt": "VQGAN ckpt 선택",
- "Select source file processing method": "소스 파일 처리 방법 선택",
- "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
- "Selected: {}": "선택됨: {}",
- "Speaker": "화자",
- "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
- "Start Training": "학습 시작",
- "Streaming Audio": "스트리밍 오디오",
- "Streaming Generate": "스트리밍 생성",
- "Tensorboard Host": "Tensorboard 호스트",
- "Tensorboard Log Path": "Tensorboard 로그 경로",
- "Tensorboard Port": "Tensorboard 포트",
- "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
- "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
- "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
- "Training Configuration": "학습 설정",
- "Training Error": "학습 오류",
- "Training stopped": "학습이 중지되었습니다.",
- "Type name of the speaker": "화자의 이름을 입력하세요.",
- "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
- "Use LoRA": "LoRA 사용",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
- "Use filelist": "파일 목록 사용",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
- "VITS Configuration": "VITS 설정",
- "VQGAN Configuration": "VQGAN 설정",
- "Validation Batch Size": "검증 배치 크기",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
- "WebUI Host": "WebUI 호스트",
- "WebUI Port": "WebUI 포트",
- "Whisper Model": "Whisper 모델",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
- "latest": "최신",
- "new": "새로운",
- "Realtime Transform Text": "실시간 텍스트 변환",
- "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
- "Text Normalization": "텍스트 정규화",
- "Select Example Audio": "예시 오디오 선택"
-}
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
+ "Add to Processing Area": "처리 영역에 추가",
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+ "Advanced Config": "고급 설정",
+ "Base LLAMA Model": "기본 LLAMA 모델",
+ "Batch Inference": "배치 추론",
+ "Batch Size": "배치 크기",
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
+ "Chinese": "중국어",
+ "Compile Model": "모델 컴파일",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+ "Copy": "복사",
+ "Data Preprocessing": "데이터 전처리",
+ "Data Preprocessing Path": "데이터 전처리 경로",
+ "Data Source": "데이터 소스",
+ "Decoder Model Config": "디코더 모델 설정",
+ "Decoder Model Path": "디코더 모델 경로",
+ "Disabled": "비활성화 됨",
+ "Enable Reference Audio": "참고 음성 활성화",
+ "English": "영어",
+ "Error Message": "오류 메시지",
+ "File Preprocessing": "파일 전처리",
+ "Generate": "생성",
+ "Generated Audio": "생성된 오디오",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+ "Inference Configuration": "추론 설정",
+ "Inference Server Configuration": "추론 서버 설정",
+ "Inference Server Error": "추론 서버 오류",
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+ "Initial Learning Rate": "초기 학습률",
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+ "Input Text": "입력 텍스트",
+ "Invalid path: {}": "유효하지 않은 경로: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+ "Japanese": "일본어",
+ "LLAMA Configuration": "LLAMA 설정",
+ "LLAMA Model Config": "LLAMA 모델 설정",
+ "LLAMA Model Path": "LLAMA 모델 경로",
+ "Labeling Device": "라벨링 장치",
+ "LoRA Model to be merged": "병합할 LoRA 모델",
+ "Maximum Audio Duration": "최대 오디오 길이",
+ "Maximum Length per Sample": "샘플당 최대 길이",
+ "Maximum Training Steps": "최대 학습 단계",
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+ "Merge": "병합",
+ "Merge LoRA": "LoRA 병합",
+ "Merge successfully": "성공적으로 병합 되었습니다.",
+ "Minimum Audio Duration": "최소 오디오 길이",
+ "Model Output Path": "모델 출력 경로",
+ "Model Size": "모델 크기",
+ "Move": "이동",
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+ "No selected options": "옵션이 선택되지 않았습니다.",
+ "Number of Workers": "작업자 수",
+ "Open Inference Server": "추론 서버 열기",
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
+ "Open Tensorboard": "Tensorboard 열기",
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+ "Optional Label Language": "선택적 라벨 언어",
+ "Optional online ver": "온라인 버전 선택",
+ "Output Path": "출력 경로",
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+ "Precision": "정밀도",
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
+ "Reference Audio": "참고 오디오",
+ "Reference Text": "참고 텍스트",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+ "Remove Selected Data": "선택한 데이터 제거",
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+ "Repetition Penalty": "반복 패널티",
+ "Save model every n steps": "n 단계마다 모델 저장",
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
+ "Select VITS ckpt": "VITS ckpt 선택",
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
+ "Select source file processing method": "소스 파일 처리 방법 선택",
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+ "Selected: {}": "선택됨: {}",
+ "Speaker": "화자",
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+ "Start Training": "학습 시작",
+ "Streaming Audio": "스트리밍 오디오",
+ "Streaming Generate": "스트리밍 생성",
+ "Tensorboard Host": "Tensorboard 호스트",
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
+ "Tensorboard Port": "Tensorboard 포트",
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+ "Training Configuration": "학습 설정",
+ "Training Error": "학습 오류",
+ "Training stopped": "학습이 중지되었습니다.",
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+ "Use LoRA": "LoRA 사용",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+ "Use filelist": "파일 목록 사용",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+ "VITS Configuration": "VITS 설정",
+ "VQGAN Configuration": "VQGAN 설정",
+ "Validation Batch Size": "검증 배치 크기",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+ "WebUI Host": "WebUI 호스트",
+ "WebUI Port": "WebUI 포트",
+ "Whisper Model": "Whisper 모델",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+ "latest": "최신",
+ "new": "새로운",
+ "Realtime Transform Text": "실시간 텍스트 변환",
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+ "Text Normalization": "텍스트 정규화",
+ "Select Example Audio": "예시 오디오 선택"
+}
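
Because every WebUI string has to appear in each of these locale maps, drift between files (missing keys, extra keys, or values left in the wrong language) is easy to introduce. A small, repository-external check along the lines sketched below could flag such drift; the `en_US.json` reference filename is an assumption, since only the translated locales are named in this diff.

```python
# Hypothetical QA helper, not part of the repository: compare every locale file's
# key set against a reference locale (assumed here to be en_US.json).
import json
import sys
from pathlib import Path

LOCALE_DIR = Path("fish_speech/i18n/locale")
REFERENCE = "en_US.json"  # assumed filename of the English locale


def main() -> int:
    ref_keys = set(json.loads((LOCALE_DIR / REFERENCE).read_text(encoding="utf-8")))
    exit_code = 0
    for path in sorted(LOCALE_DIR.glob("*.json")):
        if path.name == REFERENCE:
            continue
        keys = set(json.loads(path.read_text(encoding="utf-8")))
        missing, extra = sorted(ref_keys - keys), sorted(keys - ref_keys)
        if missing or extra:
            exit_code = 1
            print(f"{path.name}: {len(missing)} missing, {len(extra)} extra keys")
    return exit_code


if __name__ == "__main__":
    sys.exit(main())
```
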
diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json
index a5278e29fac737d9fd3da3f8e82e49ac22a96ac3..385f20272e19053ab9b6cf6463a84c8ece768c68 100644
--- a/fish_speech/i18n/locale/pt_BR.json
+++ b/fish_speech/i18n/locale/pt_BR.json
@@ -1,133 +1,133 @@
-{
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
- "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
- "Add to Processing Area": "Adicionar à Área de Processamento",
- "Added path successfully!": "Caminho adicionado com sucesso!",
- "Advanced Config": "Configuração Avançada",
- "Base LLAMA Model": "Modelo LLAMA Base",
- "Batch Inference": "Inferência em Lote",
- "Batch Size": "Tamanho do Lote",
- "Changing with the Model Path": "Alterando com o Caminho do Modelo",
-
- "Compile Model": "Compilar Modelo",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
- "Copy": "Copiar",
- "Data Preprocessing": "Pré-processamento de Dados",
- "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
- "Data Source": "Fonte de Dados",
- "Decoder Model Config": "Configuração do Modelo Decodificador",
- "Decoder Model Path": "Caminho do Modelo Decodificador",
- "Disabled": "Desativado",
- "Enable Initial Prompt": "Habilitar Prompt Inicial",
- "Enable Reference Audio": "Habilitar Áudio de Referência",
- "English": "Inglês",
- "Japanese": "Japonês",
- "Chinese": "Chinês",
- "Portuguese": "Português",
- "Spanish": "Espanhol",
- "Error Message": "Mensagem de Erro",
- "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
- "File Preprocessing": "Pré-processamento de Arquivos",
- "Generate": "Gerar",
- "Generated Audio": "Áudio Gerado",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
- "Infer interface is closed": "A interface de inferência foi fechada",
- "Inference Configuration": "Configuração de Inferência",
- "Inference Server Configuration": "Configuração do Servidor de Inferência",
- "Inference Server Error": "Erro do Servidor de Inferência",
- "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
- "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
- "Initial Prompt": "Prompt Inicial",
- "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
- "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
- "Input Text": "Texto de Entrada",
- "Invalid path: {}": "Caminho inválido: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
- "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
- "LLAMA Configuration": "Configuração do LLAMA",
- "LLAMA Model Config": "Configuração do Modelo LLAMA",
- "LLAMA Model Path": "Caminho do Modelo LLAMA",
- "Labeling Device": "Dispositivo de Rotulagem",
- "LoRA Model to be merged": "Modelo LoRA para mesclagem",
- "Maximum Length per Sample": "Comprimento Máximo por Amostra",
- "Maximum Training Steps": "Etapas Máximas de Treinamento",
- "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
- "Merge": "Mesclar",
- "Merge LoRA": "Mesclar LoRA",
- "Merge successfully": "Mesclado com sucesso",
- "Model Output Path": "Caminho de Saída do Modelo",
- "Model Quantization": "Quantização do Modelo",
- "Model Size": "Tamanho do Modelo",
- "Move": "Mover",
- "Move files successfully": "Arquivos movidos com sucesso",
- "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
- "No selected options": "Nenhuma opção selecionada",
- "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
- "Number of Workers": "Número de Processos",
- "Open Inference Server": "Abrir Servidor de Inferência",
- "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
- "Open Tensorboard": "Abrir Tensorboard",
- "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
- "Optional Label Language": "Idioma do Rótulo (Opcional)",
- "Optional online ver": "Versão online (opcional)",
- "Output Path": "Caminho de Saída",
- "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
- "Post-quantification Precision": "Precisão Pós-quantização",
- "Precision": "Precisão",
- "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
- "Put your text here.": "Insira seu texto aqui.",
- "Quantify": "Quantizar",
- "Quantify successfully": "Quantizado com sucesso",
- "Realtime Transform Text": "Transformar Texto em Tempo Real",
- "Reference Audio": "Áudio de Referência",
- "Reference Text": "Texto de Referência",
- "warning": "Aviso",
- "Pre-processing begins...": "O pré-processamento começou!",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
- "Remove Selected Data": "Remover Dados Selecionados",
- "Removed path successfully!": "Caminho removido com sucesso!",
- "Repetition Penalty": "Penalidade de Repetição",
- "Save model every n steps": "Salvar modelo a cada n etapas",
- "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
- "Select source file processing method": "Escolha como processar o arquivo de origem",
- "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
- "Selected: {}": "Selecionado: {}",
- "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
- "Start Training": "Iniciar Treinamento",
- "Streaming Audio": "Áudio em Streaming",
- "Streaming Generate": "Geração em Streaming",
- "Tensorboard Host": "Host do Tensorboard",
- "Tensorboard Log Path": "Caminho de Log do Tensorboard",
- "Tensorboard Port": "Porta do Tensorboard",
- "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
- "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
- "Text Normalization": "Normalização de Texto",
- "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
- "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
- "Training Configuration": "Configuração de Treinamento",
- "Training Error": "Erro de Treinamento",
- "Training stopped": "Treinamento interrompido!",
- "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
- "Use LoRA": "Usar LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
- "Use filelist": "Usar lista de arquivos",
- "VQGAN Configuration": "Configuração do VQGAN",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
- "WebUI Host": "Host da WebUI",
- "WebUI Port": "Porta da WebUI",
- "Whisper Model": "Modelo Whisper",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
- "auto": "automático",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
- "latest": "mais recente",
- "new": "novo",
- "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
- "You don't need to train this model!": "Não é necessário treinar este modelo!",
- "Yes": "Sim",
- "No": "Não",
- "version:": "versão:",
- "author:": "autor:"
-}
+{
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
+ "Add to Processing Area": "Adicionar à Área de Processamento",
+ "Added path successfully!": "Caminho adicionado com sucesso!",
+ "Advanced Config": "Configuração Avançada",
+ "Base LLAMA Model": "Modelo LLAMA Base",
+ "Batch Inference": "Inferência em Lote",
+ "Batch Size": "Tamanho do Lote",
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
+
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Pré-processamento de Dados",
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
+ "Data Source": "Fonte de Dados",
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
+ "Disabled": "Desativado",
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
+ "English": "Inglês",
+ "Japanese": "Japonês",
+ "Chinese": "Chinês",
+ "Portuguese": "Português",
+ "Spanish": "Espanhol",
+ "Error Message": "Mensagem de Erro",
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
+ "File Preprocessing": "Pré-processamento de Arquivos",
+ "Generate": "Gerar",
+ "Generated Audio": "Áudio Gerado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
+ "Infer interface is closed": "A interface de inferência foi fechada",
+ "Inference Configuration": "Configuração de Inferência",
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
+ "Inference Server Error": "Erro do Servidor de Inferência",
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
+ "Initial Prompt": "Prompt Inicial",
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Caminho inválido: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
+ "LLAMA Configuration": "Configuração do LLAMA",
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Rotulagem",
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
+ "Merge": "Mesclar",
+ "Merge LoRA": "Mesclar LoRA",
+ "Merge successfully": "Mesclado com sucesso",
+ "Model Output Path": "Caminho de Saída do Modelo",
+ "Model Quantization": "Quantização do Modelo",
+ "Model Size": "Tamanho do Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Arquivos movidos com sucesso",
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
+ "No selected options": "Nenhuma opção selecionada",
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
+ "Number of Workers": "Número de Processos",
+ "Open Inference Server": "Abrir Servidor de Inferência",
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
+ "Optional online ver": "Versão online (opcional)",
+ "Output Path": "Caminho de Saída",
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
+ "Post-quantification Precision": "Precisão Pós-quantização",
+ "Precision": "Precisão",
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
+ "Put your text here.": "Insira seu texto aqui.",
+ "Quantify": "Quantizar",
+ "Quantify successfully": "Quantizado com sucesso",
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
+ "Reference Audio": "Áudio de Referência",
+ "Reference Text": "Texto de Referência",
+ "warning": "Aviso",
+ "Pre-processing begins...": "O pré-processamento começou!",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Remover Dados Selecionados",
+ "Removed path successfully!": "Caminho removido com sucesso!",
+ "Repetition Penalty": "Penalidade de Repetição",
+ "Save model every n steps": "Salvar modelo a cada n etapas",
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
+ "Selected: {}": "Selecionado: {}",
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
+ "Start Training": "Iniciar Treinamento",
+ "Streaming Audio": "Áudio em Streaming",
+ "Streaming Generate": "Geração em Streaming",
+ "Tensorboard Host": "Host do Tensorboard",
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
+ "Tensorboard Port": "Porta do Tensorboard",
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
+ "Text Normalization": "Normalização de Texto",
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
+ "Training Configuration": "Configuração de Treinamento",
+ "Training Error": "Erro de Treinamento",
+ "Training stopped": "Treinamento interrompido!",
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
+ "Use filelist": "Usar lista de arquivos",
+ "VQGAN Configuration": "Configuração do VQGAN",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
+ "WebUI Host": "Host da WebUI",
+ "WebUI Port": "Porta da WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
+ "auto": "automático",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
+ "latest": "mais recente",
+ "new": "novo",
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
+ "Yes": "Sim",
+ "No": "Não",
+ "version:": "versão:",
+ "author:": "autor:"
+}
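These locale files are flat key→translation maps under `fish_speech/i18n/locale/`, looked up by the `i18n()` helper the WebUI imports. As a minimal stand-alone sketch (not the repo's actual `i18n()` implementation), a lookup with an English-key fallback can be done like this:

```python
import json
from pathlib import Path

LOCALE_DIR = Path("fish_speech/i18n/locale")  # location of the JSON files in this diff

def load_locale(lang: str = "pt_BR") -> dict:
    """Load a flat key -> translation map such as pt_BR.json or zh_CN.json."""
    return json.loads((LOCALE_DIR / f"{lang}.json").read_text(encoding="utf-8"))

def translate(key: str, table: dict) -> str:
    # Missing keys fall back to the English source string itself.
    return table.get(key, key)

table = load_locale("pt_BR")
print(translate("Start Training", table))  # -> "Iniciar Treinamento"
```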
diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json
index df7cd5477bfa035ea66ae7322dad46f5d054d9b0..9068ef0b9a41b9941b37644c6a4c96ec6a5d836e 100644
--- a/fish_speech/i18n/locale/zh_CN.json
+++ b/fish_speech/i18n/locale/zh_CN.json
@@ -1,123 +1,123 @@
-{
- "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
- "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
- "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
- "Accumulate Gradient Batches": "梯度累积批次",
- "Add to Processing Area": "加入处理区",
- "Added path successfully!": "添加路径成功!",
- "Advanced Config": "高级参数",
- "Base LLAMA Model": "基础 LLAMA 模型",
- "Batch Inference": "批量推理",
- "Batch Size": "批次大小",
- "Changing with the Model Path": "随模型路径变化",
- "Chinese": "中文",
- "Compile Model": "编译模型",
- "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
- "Copy": "复制",
- "Data Preprocessing": "数据预处理",
- "Data Preprocessing Path": "数据预处理路径",
- "Data Source": "数据源",
- "Decoder Model Config": "解码器模型配置",
- "Decoder Model Path": "解码器模型路径",
- "Disabled": "禁用",
- "Enable Reference Audio": "启用参考音频",
- "English": "英文",
- "Error Message": "错误信息",
- "File Preprocessing": "文件预处理",
- "Generate": "生成",
- "Generated Audio": "音频",
- "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
- "Infer interface is closed": "推理界面已关闭",
- "Inference Configuration": "推理配置",
- "Inference Server Configuration": "推理服务器配置",
- "Inference Server Error": "推理服务器错误",
- "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
- "Initial Learning Rate": "初始学习率",
- "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
- "Input Text": "输入文本",
- "Invalid path: {}": "无效路径: {}",
- "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
- "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
- "Japanese": "日文",
- "LLAMA Configuration": "LLAMA 配置",
- "LLAMA Model Config": "LLAMA 模型配置",
- "LLAMA Model Path": "LLAMA 模型路径",
- "Labeling Device": "标注加速设备",
- "LoRA Model to be merged": "要合并的 LoRA 模型",
- "Maximum Audio Duration": "最大音频时长",
- "Maximum Length per Sample": "每个样本的最大长度",
- "Maximum Training Steps": "最大训练步数",
- "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
- "Merge": "合并",
- "Merge LoRA": "合并 LoRA",
- "Merge successfully": "合并成功",
- "Minimum Audio Duration": "最小音频时长",
- "Model Output Path": "模型输出路径",
- "Model Size": "模型规模",
- "Move": "移动",
- "Move files successfully": "移动文件成功",
- "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
- "No selected options": "没有选择的选项",
- "Number of Workers": "数据加载进程数",
- "Open Inference Server": "打开推理服务器",
- "Open Labeler WebUI": "打开标注工具",
- "Open Tensorboard": "打开 Tensorboard",
- "Opened labeler in browser": "在浏览器中打开标注工具",
- "Optional Label Language": "[可选] 标注语言",
- "Optional online ver": "[可选] 使用在线版",
- "Output Path": "输出路径",
- "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
- "Precision": "精度",
- "Probability of applying Speaker Condition": "应用说话人条件的概率",
- "Put your text here.": "在此处输入文本.",
- "Reference Audio": "参考音频",
- "Reference Text": "参考文本",
- "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
- "Remove Selected Data": "移除选中数据",
- "Removed path successfully!": "移除路径成功!",
- "Repetition Penalty": "重复惩罚",
- "Save model every n steps": "每 n 步保存模型",
- "Select LLAMA ckpt": "选择 LLAMA 检查点",
- "Select VITS ckpt": "选择 VITS 检查点",
- "Select VQGAN ckpt": "选择 VQGAN 检查点",
- "Select source file processing method": "选择源文件处理方法",
- "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
- "Selected: {}": "已选择: {}",
- "Speaker": "说话人",
- "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
- "Start Training": "开始训练",
- "Streaming Audio": "流式音频",
- "Streaming Generate": "流式合成",
- "Tensorboard Host": "Tensorboard 监听地址",
- "Tensorboard Log Path": "Tensorboard 日志路径",
- "Tensorboard Port": "Tensorboard 端口",
- "Tensorboard interface is closed": "Tensorboard 界面已关闭",
- "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
- "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
- "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
- "Training Configuration": "训练配置",
- "Training Error": "训练错误",
- "Training stopped": "训练已停止",
- "Type name of the speaker": "输入说话人的名称",
- "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
- "Use LoRA": "使用 LoRA",
- "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
- "Use filelist": "使用文件列表",
- "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
- "VITS Configuration": "VITS 配置",
- "VQGAN Configuration": "VQGAN 配置",
- "Validation Batch Size": "验证批次大小",
- "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
- "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
- "WebUI Host": "WebUI 监听地址",
- "WebUI Port": "WebUI 端口",
- "Whisper Model": "Whisper 模型",
- "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
- "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
- "latest": "最近的检查点",
- "new": "创建新的检查点",
- "Realtime Transform Text": "实时规范化文本",
- "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
- "Text Normalization": "文本规范化",
- "Select Example Audio": "选择参考音频"
-}
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
+ "Accumulate Gradient Batches": "梯度累积批次",
+ "Add to Processing Area": "加入处理区",
+ "Added path successfully!": "添加路径成功!",
+ "Advanced Config": "高级参数",
+ "Base LLAMA Model": "基础 LLAMA 模型",
+ "Batch Inference": "批量推理",
+ "Batch Size": "批次大小",
+ "Changing with the Model Path": "随模型路径变化",
+ "Chinese": "中文",
+ "Compile Model": "编译模型",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
+ "Copy": "复制",
+ "Data Preprocessing": "数据预处理",
+ "Data Preprocessing Path": "数据预处理路径",
+ "Data Source": "数据源",
+ "Decoder Model Config": "解码器模型配置",
+ "Decoder Model Path": "解码器模型路径",
+ "Disabled": "禁用",
+ "Enable Reference Audio": "启用参考音频",
+ "English": "英文",
+ "Error Message": "错误信息",
+ "File Preprocessing": "文件预处理",
+ "Generate": "生成",
+ "Generated Audio": "音频",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
+ "Infer interface is closed": "推理界面已关闭",
+ "Inference Configuration": "推理配置",
+ "Inference Server Configuration": "推理服务器配置",
+ "Inference Server Error": "推理服务器错误",
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
+ "Initial Learning Rate": "初始学习率",
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
+ "Input Text": "输入文本",
+ "Invalid path: {}": "无效路径: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
+ "Japanese": "日文",
+ "LLAMA Configuration": "LLAMA 配置",
+ "LLAMA Model Config": "LLAMA 模型配置",
+ "LLAMA Model Path": "LLAMA 模型路径",
+ "Labeling Device": "标注加速设备",
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
+ "Maximum Audio Duration": "最大音频时长",
+ "Maximum Length per Sample": "每个样本的最大长度",
+ "Maximum Training Steps": "最大训练步数",
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
+ "Merge": "合并",
+ "Merge LoRA": "合并 LoRA",
+ "Merge successfully": "合并成功",
+ "Minimum Audio Duration": "最小音频时长",
+ "Model Output Path": "模型输出路径",
+ "Model Size": "模型规模",
+ "Move": "移动",
+ "Move files successfully": "移动文件成功",
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
+ "No selected options": "没有选择的选项",
+ "Number of Workers": "数据加载进程数",
+ "Open Inference Server": "打开推理服务器",
+ "Open Labeler WebUI": "打开标注工具",
+ "Open Tensorboard": "打开 Tensorboard",
+ "Opened labeler in browser": "在浏览器中打开标注工具",
+ "Optional Label Language": "[可选] 标注语言",
+ "Optional online ver": "[可选] 使用在线版",
+ "Output Path": "输出路径",
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
+ "Put your text here.": "在此处输入文本.",
+ "Reference Audio": "参考音频",
+ "Reference Text": "参考文本",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
+ "Remove Selected Data": "移除选中数据",
+ "Removed path successfully!": "移除路径成功!",
+ "Repetition Penalty": "重复惩罚",
+ "Save model every n steps": "每 n 步保存模型",
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
+ "Select VITS ckpt": "选择 VITS 检查点",
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
+ "Select source file processing method": "选择源文件处理方法",
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
+ "Selected: {}": "已选择: {}",
+ "Speaker": "说话人",
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
+ "Start Training": "开始训练",
+ "Streaming Audio": "流式音频",
+ "Streaming Generate": "流式合成",
+ "Tensorboard Host": "Tensorboard 监听地址",
+ "Tensorboard Log Path": "Tensorboard 日志路径",
+ "Tensorboard Port": "Tensorboard 端口",
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
+ "Training Configuration": "训练配置",
+ "Training Error": "训练错误",
+ "Training stopped": "训练已停止",
+ "Type name of the speaker": "输入说话人的名称",
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
+ "Use LoRA": "使用 LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
+ "Use filelist": "使用文件列表",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
+ "VITS Configuration": "VITS 配置",
+ "VQGAN Configuration": "VQGAN 配置",
+ "Validation Batch Size": "验证批次大小",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
+ "WebUI Host": "WebUI 监听地址",
+ "WebUI Port": "WebUI 端口",
+ "Whisper Model": "Whisper 模型",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
+ "latest": "最近的检查点",
+ "new": "创建新的检查点",
+ "Realtime Transform Text": "实时规范化文本",
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
+ "Text Normalization": "文本规范化",
+ "Select Example Audio": "选择参考音频"
+}
diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py
index 00a39a4d08b8a19c91a8518e90bafe6ceea2231c..d0194c0f1a31dc95309c64626d13f04751a44ba1 100644
--- a/fish_speech/i18n/scan.py
+++ b/fish_speech/i18n/scan.py
@@ -1,122 +1,122 @@
-import ast
-import glob
-import json
-from collections import OrderedDict
-from pathlib import Path
-
-from loguru import logger
-
-from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
-
-
-def extract_i18n_strings(node):
- i18n_strings = []
-
- if (
- isinstance(node, ast.Call)
- and isinstance(node.func, ast.Name)
- and node.func.id == "i18n"
- ):
- for arg in node.args:
- if isinstance(arg, ast.Str):
- i18n_strings.append(arg.s)
-
- for child_node in ast.iter_child_nodes(node):
- i18n_strings.extend(extract_i18n_strings(child_node))
-
- return i18n_strings
-
-
-# scan the directory for all .py files (recursively)
-# for each file, parse the code into an AST
-# for each AST, extract the i18n strings
-
-strings = []
-folders = ["fish_speech", "tools"]
-# for filename in glob.iglob("**/*.py", recursive=True):
-for folder in folders:
- for f in Path(folder).rglob("*.py"):
- code = f.read_text(encoding="utf-8")
- if "i18n(" in code:
- tree = ast.parse(code)
- i18n_strings = extract_i18n_strings(tree)
- logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
- strings.extend(i18n_strings)
-
-code_keys = set(strings)
-logger.info(f"Total unique: {len(code_keys)}")
-
-
-standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
-with open(standard_file, "r", encoding="utf-8") as f:
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
-standard_keys = set(standard_data.keys())
-
-# Define the standard file name
-unused_keys = standard_keys - code_keys
-logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
-for unused_key in unused_keys:
- logger.info(f"\t{unused_key}")
-
-missing_keys = code_keys - standard_keys
-logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
-for missing_key in missing_keys:
- logger.info(f"\t{missing_key}")
-
-code_keys_dict = OrderedDict()
-for s in strings:
- code_keys_dict[s] = s
-
-# write back
-with open(standard_file, "w", encoding="utf-8") as f:
- json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
- f.write("\n")
-
-logger.info(f"Updated {standard_file}")
-
-
-# Define the standard file name
-standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
-
-# Find all JSON files in the directory
-dir_path = I18N_FILE_PATH
-languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
-
-# Load the standard file
-with open(standard_file, "r", encoding="utf-8") as f:
- standard_data = json.load(f, object_pairs_hook=OrderedDict)
-
-# Loop through each language file
-for lang_file in languages:
- # Load the language file
- with open(lang_file, "r", encoding="utf-8") as f:
- lang_data = json.load(f, object_pairs_hook=OrderedDict)
-
- # Find the difference between the language file and the standard file
- diff = set(standard_data.keys()) - set(lang_data.keys())
-
- miss = set(lang_data.keys()) - set(standard_data.keys())
-
- # Add any missing keys to the language file
- for key in diff:
- lang_data[key] = "#!" + key
- logger.info(f"Added missing key: {key} to {lang_file}")
-
- # Del any extra keys to the language file
- for key in miss:
- del lang_data[key]
- logger.info(f"Del extra key: {key} from {lang_file}")
-
- # Sort the keys of the language file to match the order of the standard file
- lang_data = OrderedDict(
- sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
- )
-
- # Save the updated language file
- with open(lang_file, "w", encoding="utf-8") as f:
- json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
- f.write("\n")
-
- logger.info(f"Updated {lang_file}")
-
-logger.info("Done")
+import ast
+import glob
+import json
+from collections import OrderedDict
+from pathlib import Path
+
+from loguru import logger
+
+from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
+
+
+def extract_i18n_strings(node):
+ i18n_strings = []
+
+ if (
+ isinstance(node, ast.Call)
+ and isinstance(node.func, ast.Name)
+ and node.func.id == "i18n"
+ ):
+ for arg in node.args:
+ if isinstance(arg, ast.Str):
+ i18n_strings.append(arg.s)
+
+ for child_node in ast.iter_child_nodes(node):
+ i18n_strings.extend(extract_i18n_strings(child_node))
+
+ return i18n_strings
+
+
+# scan the directory for all .py files (recursively)
+# for each file, parse the code into an AST
+# for each AST, extract the i18n strings
+
+strings = []
+folders = ["fish_speech", "tools"]
+# for filename in glob.iglob("**/*.py", recursive=True):
+for folder in folders:
+ for f in Path(folder).rglob("*.py"):
+ code = f.read_text(encoding="utf-8")
+ if "i18n(" in code:
+ tree = ast.parse(code)
+ i18n_strings = extract_i18n_strings(tree)
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
+ strings.extend(i18n_strings)
+
+code_keys = set(strings)
+logger.info(f"Total unique: {len(code_keys)}")
+
+
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+standard_keys = set(standard_data.keys())
+
+# Keys present in the standard file that are no longer referenced in the code
+unused_keys = standard_keys - code_keys
+logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
+for unused_key in unused_keys:
+ logger.info(f"\t{unused_key}")
+
+missing_keys = code_keys - standard_keys
+logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
+for missing_key in missing_keys:
+ logger.info(f"\t{missing_key}")
+
+code_keys_dict = OrderedDict()
+for s in strings:
+ code_keys_dict[s] = s
+
+# write back
+with open(standard_file, "w", encoding="utf-8") as f:
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+logger.info(f"Updated {standard_file}")
+
+
+# Define the standard file name
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+
+# Find all JSON files in the directory
+dir_path = I18N_FILE_PATH
+languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
+
+# Load the standard file
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+
+# Loop through each language file
+for lang_file in languages:
+ # Load the language file
+ with open(lang_file, "r", encoding="utf-8") as f:
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
+
+ # Find the difference between the language file and the standard file
+ diff = set(standard_data.keys()) - set(lang_data.keys())
+
+ miss = set(lang_data.keys()) - set(standard_data.keys())
+
+ # Add any missing keys to the language file
+ for key in diff:
+ lang_data[key] = "#!" + key
+ logger.info(f"Added missing key: {key} to {lang_file}")
+
+    # Delete any extra keys from the language file
+ for key in miss:
+ del lang_data[key]
+ logger.info(f"Del extra key: {key} from {lang_file}")
+
+ # Sort the keys of the language file to match the order of the standard file
+ lang_data = OrderedDict(
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
+ )
+
+ # Save the updated language file
+ with open(lang_file, "w", encoding="utf-8") as f:
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+ logger.info(f"Updated {lang_file}")
+
+logger.info("Done")
diff --git a/fish_speech/inference_engine/__init__.py b/fish_speech/inference_engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f54de5bf7f4d70dbfcfaff10423ac75d4c7d07b0
--- /dev/null
+++ b/fish_speech/inference_engine/__init__.py
@@ -0,0 +1,192 @@
+import gc
+import queue
+from typing import Generator
+
+import numpy as np
+import torch
+from loguru import logger
+
+from fish_speech.inference_engine.reference_loader import ReferenceLoader
+from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
+from fish_speech.inference_engine.vq_manager import VQManager
+from fish_speech.models.dac.modded_dac import DAC
+from fish_speech.models.text2semantic.inference import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+)
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from fish_speech.utils.schema import ServeTTSRequest
+
+
+class TTSInferenceEngine(ReferenceLoader, VQManager):
+
+ def __init__(
+ self,
+ llama_queue: queue.Queue,
+ decoder_model: DAC,
+ precision: torch.dtype,
+ compile: bool,
+ ) -> None:
+
+ super().__init__()
+
+ self.llama_queue = llama_queue
+ self.decoder_model = decoder_model
+ self.precision = precision
+ self.compile = compile
+
+ @torch.inference_mode()
+ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
+ """
+ Main inference function:
+ - Loads the reference audio and text.
+ - Calls the LLAMA model for inference.
+ - Decodes the VQ tokens to audio.
+ """
+
+ ref_id: str | None = req.reference_id
+ prompt_tokens, prompt_texts = [], []
+ # Load the reference audio and text based on id or hash
+ if ref_id is not None:
+ prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
+
+ elif req.references:
+ prompt_tokens, prompt_texts = self.load_by_hash(
+ req.references, req.use_memory_cache
+ )
+
+ # Set the random seed if provided
+ if req.seed is not None:
+ set_seed(req.seed)
+ logger.warning(f"set seed: {req.seed}")
+
+ # Get the symbolic tokens from the LLAMA model
+ response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
+
+ # Get the sample rate from the decoder model
+ if hasattr(self.decoder_model, "spec_transform"):
+ sample_rate = self.decoder_model.spec_transform.sample_rate
+ else:
+ sample_rate = self.decoder_model.sample_rate
+
+ # If streaming, send the header
+ if req.streaming:
+ yield InferenceResult(
+ code="header",
+ audio=(
+ sample_rate,
+ np.array(wav_chunk_header(sample_rate=sample_rate)),
+ ),
+ error=None,
+ )
+
+ segments = []
+
+ while True:
+ # Get the response from the LLAMA model
+ wrapped_result: WrappedGenerateResponse = response_queue.get()
+ if wrapped_result.status == "error":
+ yield InferenceResult(
+ code="error",
+ audio=None,
+ error=(
+ wrapped_result.response
+ if isinstance(wrapped_result.response, Exception)
+ else Exception("Unknown error")
+ ),
+ )
+ break
+
+ # Check the response type
+ if not isinstance(wrapped_result.response, GenerateResponse):
+ raise TypeError(
+                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
+ )
+
+ result: GenerateResponse = wrapped_result.response
+ if result.action != "next":
+ segment = self.get_audio_segment(result)
+
+ if req.streaming: # Used only by the API server
+ yield InferenceResult(
+ code="segment",
+ audio=(sample_rate, segment),
+ error=None,
+ )
+ segments.append(segment)
+ else:
+ break
+
+ # Clean up the memory
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Edge case: no audio generated
+ if len(segments) == 0:
+ yield InferenceResult(
+ code="error",
+ audio=None,
+ error=RuntimeError("No audio generated, please check the input text."),
+ )
+ else:
+ # Streaming or not, return the final audio
+ audio = np.concatenate(segments, axis=0)
+ yield InferenceResult(
+ code="final",
+ audio=(sample_rate, audio),
+ error=None,
+ )
+
+ return None
+
+ def send_Llama_request(
+ self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
+ ) -> queue.Queue:
+ """
+ Send a request to the LLAMA model to generate the symbolic tokens.
+ """
+
+ # Prepare the request
+ request = dict(
+ device=self.decoder_model.device,
+ max_new_tokens=req.max_new_tokens,
+ text=req.text,
+ top_p=req.top_p,
+ repetition_penalty=req.repetition_penalty,
+ temperature=req.temperature,
+ compile=self.compile,
+ iterative_prompt=req.chunk_length > 0,
+ chunk_length=req.chunk_length,
+ prompt_tokens=prompt_tokens,
+ prompt_text=prompt_texts,
+ )
+
+ # Create a queue to get the response
+ response_queue = queue.Queue()
+
+ # Send the request to the LLAMA model
+ self.llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ return response_queue
+
+ def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
+ """
+ Decode the VQ tokens to audio.
+ """
+
+ # Don't use autocast on MPS devices
+ with autocast_exclude_mps(
+ device_type=self.decoder_model.device.type, dtype=self.precision
+ ):
+ # Decode the symbolic tokens to audio
+ segment = self.decode_vq_tokens(codes=result.codes)
+
+ # Convert the audio to numpy
+ return segment.float().cpu().numpy()
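`TTSInferenceEngine.inference()` is a generator of `InferenceResult` objects: a `"header"` chunk first when streaming, then zero or more `"segment"` chunks, and finally either a `"final"` result carrying the concatenated audio or an `"error"`. A minimal consumer sketch, assuming the engine has already been constructed as in `app.py` and that `ServeTTSRequest` provides defaults for the fields not shown:

```python
import soundfile as sf

from fish_speech.utils.schema import ServeTTSRequest

def synthesize_to_file(engine, text: str, out_path: str = "out.wav") -> None:
    """Drain the InferenceResult generator and write the final audio to disk."""
    req = ServeTTSRequest(text=text, streaming=False)  # assumed: remaining fields default
    for result in engine.inference(req):
        if result.code == "error":
            raise result.error
        if result.code == "final":
            sample_rate, audio = result.audio
            sf.write(out_path, audio, sample_rate)

# synthesize_to_file(engine, "Hello from OpenAudio S1.")
```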
diff --git a/fish_speech/inference_engine/reference_loader.py b/fish_speech/inference_engine/reference_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc3a89bb4010efb659cbf6b09e301bb0aac45d1
--- /dev/null
+++ b/fish_speech/inference_engine/reference_loader.py
@@ -0,0 +1,130 @@
+import io
+from hashlib import sha256
+from pathlib import Path
+from typing import Callable, Literal, Tuple
+
+import torch
+import torchaudio
+from loguru import logger
+
+from fish_speech.models.dac.modded_dac import DAC
+from fish_speech.utils.file import (
+ AUDIO_EXTENSIONS,
+ audio_to_bytes,
+ list_files,
+ read_ref_text,
+)
+from fish_speech.utils.schema import ServeReferenceAudio
+
+
+class ReferenceLoader:
+
+ def __init__(self) -> None:
+ """
+ Component of the TTSInferenceEngine class.
+ Loads and manages the cache for the reference audio and text.
+ """
+ self.ref_by_id: dict = {}
+ self.ref_by_hash: dict = {}
+
+        # Make Pylance happy (attribute/method not defined...)
+ self.decoder_model: DAC
+ self.encode_reference: Callable
+
+ # Define the torchaudio backend
+ backends = torchaudio.list_audio_backends()
+ if "ffmpeg" in backends:
+ self.backend = "ffmpeg"
+ else:
+ self.backend = "soundfile"
+
+ def load_by_id(
+ self,
+ id: str,
+ use_cache: Literal["on", "off"],
+ ) -> Tuple:
+
+ # Load the references audio and text by id
+ ref_folder = Path("references") / id
+ ref_folder.mkdir(parents=True, exist_ok=True)
+ ref_audios = list_files(
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+ )
+
+ if use_cache == "off" or id not in self.ref_by_id:
+ # If the references are not already loaded, encode them
+ prompt_tokens = [
+ self.encode_reference(
+ # decoder_model=self.decoder_model,
+ reference_audio=audio_to_bytes(str(ref_audio)),
+ enable_reference_audio=True,
+ )
+ for ref_audio in ref_audios
+ ]
+ prompt_texts = [
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
+ for ref_audio in ref_audios
+ ]
+ self.ref_by_id[id] = (prompt_tokens, prompt_texts)
+
+ else:
+ # Reuse already encoded references
+            logger.info("Reusing cached references")
+ prompt_tokens, prompt_texts = self.ref_by_id[id]
+
+ return prompt_tokens, prompt_texts
+
+ def load_by_hash(
+ self,
+ references: list[ServeReferenceAudio],
+ use_cache: Literal["on", "off"],
+ ) -> Tuple:
+
+ # Load the references audio and text by hash
+ audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
+
+ cache_used = False
+ prompt_tokens, prompt_texts = [], []
+ for i, ref in enumerate(references):
+ if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
+ # If the references are not already loaded, encode them
+ prompt_tokens.append(
+ self.encode_reference(
+ reference_audio=ref.audio,
+ enable_reference_audio=True,
+ )
+ )
+ prompt_texts.append(ref.text)
+ self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
+
+ else:
+ # Reuse already encoded references
+ prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
+ cache_used = True
+
+ if cache_used:
+            logger.info("Reusing cached references")
+
+ return prompt_tokens, prompt_texts
+
+ def load_audio(self, reference_audio, sr):
+ """
+ Load the audio data from a file or bytes.
+ """
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
+ audio_data = reference_audio
+ reference_audio = io.BytesIO(audio_data)
+
+ waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
+
+ if waveform.shape[0] > 1:
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+ if original_sr != sr:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=original_sr, new_freq=sr
+ )
+ waveform = resampler(waveform)
+
+ audio = waveform.squeeze().numpy()
+ return audio
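`load_by_hash` keys its cache on the SHA-256 of the raw reference audio bytes, so a given clip only goes through the codec once per process no matter how many requests reuse it. The same idea in isolation, with a toy encoder standing in for `encode_reference`:

```python
from hashlib import sha256

class TinyRefCache:
    """Toy version of the hash-keyed cache in ReferenceLoader.load_by_hash."""

    def __init__(self, encode):
        self.encode = encode        # callable: audio bytes -> prompt tokens
        self._by_hash: dict = {}

    def get(self, audio: bytes, text: str):
        key = sha256(audio).hexdigest()
        if key not in self._by_hash:
            self._by_hash[key] = (self.encode(audio), text)  # encode only on a miss
        return self._by_hash[key]

cache = TinyRefCache(encode=len)    # stand-in encoder: just the byte length
print(cache.get(b"\x00\x01", "reference text"))  # encoded on first use
print(cache.get(b"\x00\x01", "reference text"))  # served from the cache
```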
diff --git a/fish_speech/inference_engine/utils.py b/fish_speech/inference_engine/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c11af150dc4a0985b11ba7675d93cf165490f35
--- /dev/null
+++ b/fish_speech/inference_engine/utils.py
@@ -0,0 +1,29 @@
+import io
+import wave
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+
+
+@dataclass
+class InferenceResult:
+ code: Literal["header", "segment", "error", "final"]
+ audio: Optional[Tuple[int, np.ndarray]]
+ error: Optional[Exception]
+
+
+def wav_chunk_header(
+ sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
+) -> bytes:
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+
+ return wav_header_bytes
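`wav_chunk_header()` serializes an empty WAV container and returns only the header bytes; the streaming path sends those first so a client can start decoding before any PCM arrives. A rough sketch of pairing that header with raw 16-bit PCM chunks (the API server's actual framing may differ):

```python
import numpy as np

from fish_speech.inference_engine.utils import wav_chunk_header

def stream_chunks(segments, sample_rate: int = 44100):
    """Yield a WAV header, then each float32 segment converted to int16 PCM bytes."""
    yield wav_chunk_header(sample_rate=sample_rate, bit_depth=16, channels=1)
    for seg in segments:
        pcm = (np.clip(seg, -1.0, 1.0) * 32767.0).astype(np.int16)
        yield pcm.tobytes()

chunks = list(stream_chunks([np.zeros(1024, dtype=np.float32)] * 2))
print(len(chunks))  # 3: one header followed by two PCM chunks
```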
diff --git a/fish_speech/inference_engine/vq_manager.py b/fish_speech/inference_engine/vq_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ef2b55271016e87f7f3d3f0b6d79177fdb27dd4
--- /dev/null
+++ b/fish_speech/inference_engine/vq_manager.py
@@ -0,0 +1,59 @@
+from typing import Callable
+
+import torch
+from loguru import logger
+
+from fish_speech.models.dac.modded_dac import DAC
+
+
+class VQManager:
+
+ def __init__(self):
+        # Make Pylance happy (attribute/method not defined...)
+ self.decoder_model: DAC
+ self.load_audio: Callable
+
+ def decode_vq_tokens(self, codes):
+ feature_lengths = torch.tensor(
+ [codes.shape[1]], device=self.decoder_model.device
+ )
+ logger.info(f"VQ features: {codes.shape}")
+
+ if isinstance(self.decoder_model, DAC):
+ return self.decoder_model.decode(
+ indices=codes[None],
+ feature_lengths=feature_lengths,
+ )[0].squeeze()
+
+ raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+
+ def encode_reference(self, reference_audio, enable_reference_audio):
+ if enable_reference_audio and reference_audio is not None:
+ # Load audios, and prepare basic info here
+ if hasattr(self.decoder_model, "spec_transform"):
+ sample_rate = self.decoder_model.spec_transform.sample_rate
+ else:
+ sample_rate = self.decoder_model.sample_rate
+ reference_audio_content = self.load_audio(reference_audio, sample_rate)
+
+ audios = torch.from_numpy(reference_audio_content).to(
+ self.decoder_model.device
+ )[None, None, :]
+ audio_lengths = torch.tensor(
+ [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+ )
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ if isinstance(self.decoder_model, DAC):
+ prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+ else:
+ raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+ else:
+ prompt_tokens = None
+ logger.info("No reference audio provided")
+
+ return prompt_tokens
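`ReferenceLoader` and `VQManager` are written as cooperating mixins: each one only annotates `self.decoder_model` (the "make Pylance happy" lines) and relies on `TTSInferenceEngine` to actually assign the decoder and queue. A toy illustration of that pattern, using purely hypothetical names:

```python
class Decoder:                      # stand-in for the DAC codec
    def decode(self, codes):
        return f"audio({codes})"

class VQMixin:
    decoder_model: Decoder          # assigned later by the concrete engine

    def decode_codes(self, codes):
        return self.decoder_model.decode(codes)

class CacheMixin:
    def __init__(self):
        self.cache: dict = {}

class Engine(CacheMixin, VQMixin):
    def __init__(self, decoder: Decoder):
        super().__init__()          # initializes CacheMixin's state
        self.decoder_model = decoder

print(Engine(Decoder()).decode_codes([1, 2, 3]))  # -> audio([1, 2, 3])
```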
diff --git a/fish_speech/models/dac/__init__.py b/fish_speech/models/dac/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/dac/inference.py b/fish_speech/models/dac/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd0452757ec50c8a89934f24ea9cc196a45de31
--- /dev/null
+++ b/fish_speech/models/dac/inference.py
@@ -0,0 +1,123 @@
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import pyrootutils
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ with initialize(version_base="1.3", config_path="../../configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path, map_location=device, mmap=True, weights_only=True
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ result = model.load_state_dict(state_dict, strict=False, assign=True)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model: {result}")
+ return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+ "--input-path",
+ "-i",
+ default="test.wav",
+ type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", default="modded_dac_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/openaudio-s1-mini/codec.pth",
+)
+@click.option(
+ "--device",
+ "-d",
+ default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+ model = load_model(config_name, checkpoint_path, device=device)
+
+ if input_path.suffix in AUDIO_EXTENSIONS:
+ logger.info(f"Processing in-place reconstruction of {input_path}")
+
+ # Load audio
+ audio, sr = torchaudio.load(str(input_path))
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(audio, sr, model.sample_rate)
+
+ audios = audio[None].to(device)
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / model.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
+ indices, indices_lens = model.encode(audios, audio_lengths)
+
+ if indices.ndim == 3:
+ indices = indices[0]
+
+ logger.info(f"Generated indices of shape {indices.shape}")
+
+ # Save indices
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+ elif input_path.suffix == ".npy":
+ logger.info(f"Processing precomputed indices from {input_path}")
+ indices = np.load(input_path)
+ indices = torch.from_numpy(indices).to(device).long()
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+ indices_lens = torch.tensor([indices.shape[1]], device=device, dtype=torch.long)
+ else:
+ raise ValueError(f"Unknown input type: {input_path}")
+
+ # Restore
+ fake_audios, audio_lengths = model.decode(indices, indices_lens)
+ audio_time = fake_audios.shape[-1] / model.sample_rate
+
+ logger.info(
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+ )
+
+ # Save audio
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
+ sf.write(output_path, fake_audio, model.sample_rate)
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
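The script above also documents how to use the codec programmatically: `python fish_speech/models/dac/inference.py -i test.wav -o fake.wav` encodes and reconstructs a clip, and the same flow fits in a few lines. A sketch mirroring that flow (paths are illustrative and assume the `openaudio-s1-mini` checkpoint has been downloaded as in `app.py`):

```python
import soundfile as sf
import torch
import torchaudio

from fish_speech.models.dac.inference import load_model

model = load_model("modded_dac_vq", "checkpoints/openaudio-s1-mini/codec.pth", device="cuda")

with torch.no_grad():
    audio, sr = torchaudio.load("test.wav")                 # any short clip
    audio = audio.mean(0, keepdim=True)                     # downmix to mono
    audio = torchaudio.functional.resample(audio, sr, model.sample_rate)

    audios = audio[None].to("cuda")                         # (B=1, C=1, T)
    lengths = torch.tensor([audios.shape[-1]], device="cuda", dtype=torch.long)

    indices, indices_lens = model.encode(audios, lengths)   # waveform -> VQ indices
    if indices.ndim == 3:
        indices = indices[0]

    fake, _ = model.decode(indices, indices_lens)           # VQ indices -> waveform
    sf.write("reconstructed.wav", fake[0, 0].float().cpu().numpy(), model.sample_rate)
```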
diff --git a/fish_speech/models/dac/modded_dac.py b/fish_speech/models/dac/modded_dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..db0a7a0ccece45c2b175a6ef1d06e7193bd15a07
--- /dev/null
+++ b/fish_speech/models/dac/modded_dac.py
@@ -0,0 +1,1024 @@
+import math
+import typing as tp
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import hydra
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from dac.model.base import CodecMixin
+from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
+from omegaconf import OmegaConf
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
+@dataclass
+class VQResult:
+ z: torch.Tensor
+ codes: torch.Tensor
+ latents: torch.Tensor
+ codebook_loss: torch.Tensor
+ commitment_loss: torch.Tensor
+ semantic_distill_z: torch.Tensor | None = None
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+ block_size: int = 2048
+ n_layer: int = 8
+ n_head: int = 8
+ dim: int = 512
+ intermediate_size: int = 1536
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ dropout_rate: float = 0.1
+ attn_dropout_rate: float = 0.1
+ channels_first: bool = True # to be compatible with conv1d input/output
+ pos_embed_type: str = "rope" # can be "rope" or "conformer"
+ max_relative_position: int = 128 # for conformer-style relative position embedding
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ assert self.pos_embed_type in [
+ "rope",
+ "conformer",
+ ], "pos_embed_type must be either 'rope' or 'conformer'"
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return (
+ k_out[:, :, : input_pos.max() + 1, :],
+ v_out[:, :, : input_pos.max() + 1, :],
+ )
+
+ def clear_cache(self, prompt_len):
+ self.k_cache[:, :, prompt_len:, :].fill_(0)
+ self.v_cache[:, :, prompt_len:, :].fill_(0)
+
+
+class Transformer(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList(
+ TransformerBlock(config) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ # Only compute RoPE frequencies if using RoPE
+ if config.pos_embed_type == "rope":
+ freqs_cis = precompute_freqs_cis(
+ self.config.block_size, self.config.head_dim, self.config.rope_base
+ )
+ self.register_buffer("freqs_cis", freqs_cis)
+ else:
+ self.register_buffer("freqs_cis", None)
+
+ causal_mask = torch.tril(
+ torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
+ )
+ self.register_buffer("causal_mask", causal_mask)
+
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+ self.use_kv_cache = False
+
+ def setup_caches(self, max_batch_size, max_seq_length):
+ """
+ This method will only be called during inference when using KV cache.
+ """
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+ dtype = self.norm.weight.dtype
+ device = self.norm.weight.device
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_length,
+ self.config.n_local_heads,
+ head_dim,
+ dtype,
+ ).to(device)
+
+ self.use_kv_cache = True
+
+ def forward(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ if self.config.pos_embed_type == "rope":
+ assert (
+ self.freqs_cis is not None
+ ), "RoPE frequencies must be initialized for RoPE positional embedding"
+ freqs_cis = self.freqs_cis[input_pos]
+ else:
+ freqs_cis = None
+
+ if mask is None: # in case of non-causal model
+ if not self.training and self.use_kv_cache:
+ mask = self.causal_mask[None, None, input_pos]
+ mask = mask[..., : input_pos.max() + 1]
+ else:
+ mask = self.causal_mask[None, None, input_pos]
+ mask = mask[..., input_pos]
+
+ for i, layer in enumerate(self.layers):
+ x = layer(x, input_pos, freqs_cis, mask)
+ x = self.norm(x)
+ return x
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.attention_layer_scale = LayerScale(config.dim, inplace=True)
+ self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
+
+ def forward(
+ self,
+ x: Tensor,
+ input_pos: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ ) -> Tensor:
+ h = x + self.attention_layer_scale(
+ self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ )
+ out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.attn_dropout_rate = config.attn_dropout_rate
+ self.pos_embed_type = config.pos_embed_type
+
+ # Add relative position embedding for conformer-style
+ if self.pos_embed_type == "conformer":
+ self.max_relative_position = config.max_relative_position
+ num_pos_embeddings = 2 * config.max_relative_position + 1
+ self.rel_pos_embeddings = nn.Parameter(
+ torch.zeros(num_pos_embeddings, self.head_dim)
+ )
+ nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
+
+ def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
+ # q: [B, H, S, D]
+ # Returns: [B, H, S, S]
+ positions = torch.arange(seqlen, device=q.device)
+ relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
+ relative_positions = torch.clamp(
+ relative_positions + self.max_relative_position,
+ 0,
+ 2 * self.max_relative_position,
+ )
+ rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
+
+ # Compute attention scores with relative position embeddings
+ q = q.transpose(1, 2) # [B, S, H, D]
+ rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
+ rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
+ return rel_logits
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+ context_seqlen = seqlen
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+ if self.pos_embed_type == "rope":
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.pos_embed_type == "conformer":
+ # Compute attention scores
+ scale = 1.0 / math.sqrt(self.head_dim)
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+ # Add relative position embeddings for conformer-style
+ rel_scores = self._compute_conformer_pos_scores(q, seqlen)
+ scores = scores + rel_scores
+
+ # Apply attention
+ if mask is not None:
+ scores = scores.masked_fill(~mask, float("-inf"))
+
+ attn = F.softmax(scores, dim=-1)
+ if self.attn_dropout_rate > 0 and self.training:
+ attn = F.dropout(attn, p=self.attn_dropout_rate)
+
+ y = torch.matmul(attn, v)
+ else:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_dropout_rate if self.training else 0.0,
+ attn_mask=mask,
+ )
+ # is_causal=True)
+ y = (
+ y.transpose(1, 2)
+ .contiguous()
+ .view(bsz, seqlen, self.head_dim * self.n_head)
+ )
+ y = self.wo(y)
+ return y
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-2,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class WindowLimitedTransformer(Transformer):
+ """
+ Transformer with window limited attention, causal.
+ """
+
+ def __init__(
+ self,
+ config: ModelArgs,
+ input_dim: int = 512,
+ window_size: Optional[int] = None,
+ causal: bool = True,
+ look_ahead_conv: nn.Module = None,
+ ):
+ super().__init__(config)
+ self.window_size = window_size
+ self.causal = causal
+ self.channels_first = config.channels_first
+ self.look_ahead_conv = (
+ look_ahead_conv if look_ahead_conv is not None else nn.Identity()
+ )
+ self.input_proj = (
+ nn.Linear(input_dim, config.dim)
+ if input_dim != config.dim
+ else nn.Identity()
+ )
+ self.output_proj = (
+ nn.Linear(config.dim, input_dim)
+ if input_dim != config.dim
+ else nn.Identity()
+ )
+
+ def make_window_limited_mask(
+ self,
+ max_length: int,
+ x_lens: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Make mask to form window limited attention.
+ """
+ if self.causal:
+ mask = torch.tril(torch.ones(max_length, max_length))
+ row_indices = torch.arange(max_length).view(-1, 1)
+ window_size = self.window_size or max_length
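+            # Each query position i may only attend to keys in [i - window_size + 1, i]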
+ valid_range = (row_indices - window_size + 1).clamp(min=0)
+ column_indices = torch.arange(max_length)
+ mask = (column_indices >= valid_range) & mask.bool()
+ else:
+ raise NotImplementedError
+ mask = mask.bool()[None, None]
+ return mask
+
+    def make_mask(
+        self,
+        max_length: int,
+        x_lens: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Make an ordinary (full or causal) mask when no window size is specified.
+        """
+        if self.causal:
+            mask = torch.tril(torch.ones(max_length, max_length))
+        else:
+            mask = torch.ones(max_length, max_length)
+        mask = mask.bool()
+        if x_lens is not None:
+            # Build a per-sample mask and forbid attending to padded positions
+            mask = mask[None].repeat(len(x_lens), 1, 1)
+            for i, x_len in enumerate(x_lens):
+                mask[i, :, x_len:] = False
+            mask = mask[:, None]
+        else:
+            mask = mask[None, None]
+        return mask
+
+ def forward(
+ self,
+ x: Tensor,
+ x_lens: Optional[Tensor] = None,
+ ) -> Tensor:
+ if self.channels_first:
+ x = x.transpose(1, 2)
+ x = self.input_proj(x) # (B, T, D)
+ x = self.look_ahead_conv(x)
+ input_pos = torch.arange(x.shape[1], device=x.device)
+ # construct mask to form window limited attention
+ max_length = x.shape[1]
+ if self.window_size is not None:
+ mask = self.make_window_limited_mask(max_length, x_lens)
+ else:
+ mask = self.make_mask(max_length, x_lens)
+ mask = mask.to(x.device)
+ x = super().forward(x, input_pos, mask)
+ x = self.output_proj(x) # (B, T, D)
+ if self.channels_first:
+ x = x.transpose(1, 2)
+ return x
+
+
+def precompute_freqs_cis(
+ seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
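+    # cache[s, j] holds the (cos, sin) pair of the rotary angle for position s and frequency j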
+ return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, properly handling zero padding. Only for 1D!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left:end]
+
+
+def get_extra_padding_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
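+    # Extra right padding so the last (possibly partial) frame is fully covered by the convolution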
+ return ideal_length - length
+
+
+def pad1d(
+ x: torch.Tensor,
+ paddings: tp.Tuple[int, int],
+ mode: str = "zeros",
+ value: float = 0.0,
+):
+    """Tiny wrapper around F.pad that allows reflect padding on small inputs.
+    If the input is shorter than the maximum padding, extra zero padding is
+    added to the right before the reflection happens.
+    """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == "reflect":
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+class CausalConvNet(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ dilation=1,
+ stride=1,
+ groups=1,
+ padding=None,
+ ):
+ super(CausalConvNet, self).__init__()
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ )
+ self.stride = stride
+ self.kernel_size = (kernel_size - 1) * dilation + 1
+ self.dilation = dilation
+ self.padding = self.kernel_size - self.stride
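+        # forward() pads this many samples on the left (plus zeros on the right to
+        # complete the final frame), keeping the convolution causal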
+
+ def forward(self, x):
+ pad = self.padding
+ extra_padding = get_extra_padding_for_conv1d(
+ x, self.kernel_size, self.stride, pad
+ )
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
+ return self.conv(x).contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_weight_norm(self):
+ self.conv = remove_parametrizations(self.conv)
+ return self
+
+
+class CausalTransConvNet(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
+ ):
+ super(CausalTransConvNet, self).__init__()
+ self.conv = nn.ConvTranspose1d(
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
+ )
+ self.stride = stride
+ self.kernel_size = kernel_size
+
+ def forward(self, x):
+ x = self.conv(x)
+ pad = self.kernel_size - self.stride
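+        # Remove the transposed-conv padding tail from the right to keep the operation causal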
+ padding_right = math.ceil(pad)
+ padding_left = pad - padding_right
+ x = unpad1d(x, (padding_left, padding_right))
+ return x.contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_weight_norm(self):
+ self.conv = remove_parametrizations(self.conv)
+ return self
+
+
+def CausalWNConv1d(*args, **kwargs):
+ return CausalConvNet(*args, **kwargs).weight_norm()
+
+
+def CausalWNConvTranspose1d(*args, **kwargs):
+ return CausalTransConvNet(*args, **kwargs).weight_norm()
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
+ super().__init__()
+ conv_class = CausalWNConv1d if causal else WNConv1d
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Snake1d(dim),
+ conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Snake1d(dim),
+ conv_class(dim, dim, kernel_size=1),
+ )
+ self.causal = causal
+
+ def forward(self, x):
+ y = self.block(x)
+ pad = x.shape[-1] - y.shape[-1]
+ if pad > 0:
+ if self.causal:
+ x = x[..., :-pad]
+ else:
+ x = x[..., pad // 2 : -pad // 2]
+ return x + y
+
+
+class EncoderBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int = 16,
+ stride: int = 1,
+ causal: bool = False,
+ n_t_layer: int = 0,
+ transformer_general_config=None,
+ ):
+ super().__init__()
+ conv_class = CausalWNConv1d if causal else WNConv1d
+ transformer_module = (
+ nn.Identity()
+ if n_t_layer == 0
+ else (
+ WindowLimitedTransformer(
+ causal=causal,
+ input_dim=dim,
+ window_size=512,
+ config=transformer_general_config(
+ n_layer=n_t_layer,
+ n_head=dim // 64,
+ dim=dim,
+ intermediate_size=dim * 3,
+ ),
+ )
+ )
+ )
+ self.block = nn.Sequential(
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
+ Snake1d(dim // 2),
+ conv_class(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ transformer_module,
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ d_model: int = 64,
+ strides: list = [2, 4, 8, 8],
+ d_latent: int = 64,
+ n_transformer_layers: list = [0, 0, 4, 4],
+ transformer_general_config: ModelArgs = None,
+ causal: bool = False,
+ ):
+ super().__init__()
+ conv_class = CausalWNConv1d if causal else WNConv1d
+ # Create first convolution
+ self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride, n_t_layer in zip(strides, n_transformer_layers):
+ d_model *= 2
+ self.block += [
+ EncoderBlock(
+ d_model,
+ stride=stride,
+ causal=causal,
+ n_t_layer=n_t_layer,
+ transformer_general_config=transformer_general_config,
+ )
+ ]
+
+ # Create last convolution
+ self.block += [
+ Snake1d(d_model),
+ conv_class(d_model, d_latent, kernel_size=3, padding=1),
+ ]
+
+        # Wrap block into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 16,
+ output_dim: int = 8,
+ stride: int = 1,
+ causal: bool = False,
+ n_t_layer: int = 0,
+ transformer_general_config=None,
+ ):
+ super().__init__()
+ conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
+ transformer_module = (
+ nn.Identity()
+ if n_t_layer == 0
+ else (
+ WindowLimitedTransformer(
+ causal=causal,
+ input_dim=input_dim,
+ window_size=None,
+ config=transformer_general_config(
+ n_layer=n_t_layer,
+ n_head=input_dim // 64,
+ dim=input_dim,
+ intermediate_size=input_dim * 3,
+ ),
+ )
+ )
+ )
+ self.block = nn.Sequential(
+ # transformer_module,
+ Snake1d(input_dim),
+ conv_trans_class(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ ResidualUnit(output_dim, dilation=1, causal=causal),
+ ResidualUnit(output_dim, dilation=3, causal=causal),
+ ResidualUnit(output_dim, dilation=9, causal=causal),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ input_channel,
+ channels,
+ rates,
+ d_out: int = 1,
+ causal: bool = False,
+ n_transformer_layers: list = [0, 0, 0, 0],
+ transformer_general_config=None,
+ ):
+ super().__init__()
+ conv_class = CausalWNConv1d if causal else WNConv1d
+ # Add first conv layer
+ layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [
+ DecoderBlock(
+ input_dim,
+ output_dim,
+ stride,
+ causal=causal,
+ n_t_layer=n_t_layer,
+ transformer_general_config=transformer_general_config,
+ )
+ ]
+
+ # Add final conv layer
+ layers += [
+ Snake1d(output_dim),
+ conv_class(output_dim, d_out, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+ def __init__(
+ self,
+ encoder_dim: int = 64,
+ encoder_rates: List[int] = [2, 4, 8, 8],
+ latent_dim: int = None,
+ decoder_dim: int = 1536,
+ decoder_rates: List[int] = [8, 8, 4, 2],
+ quantizer: torch.nn.Module = None,
+ sample_rate: int = 44100,
+ causal: bool = True,
+ encoder_transformer_layers: List[int] = [0, 0, 0, 0],
+ decoder_transformer_layers: List[int] = [0, 0, 0, 0],
+ transformer_general_config=None,
+ ):
+ super().__init__()
+
+ self.encoder_dim = encoder_dim
+ self.encoder_rates = encoder_rates
+ self.decoder_dim = decoder_dim
+ self.decoder_rates = decoder_rates
+ self.sample_rate = sample_rate
+
+ if latent_dim is None:
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+ self.latent_dim = latent_dim
+
+ self.hop_length = np.prod(encoder_rates)
+ self.encoder = Encoder(
+ encoder_dim,
+ encoder_rates,
+ latent_dim,
+ causal=causal,
+ n_transformer_layers=encoder_transformer_layers,
+ transformer_general_config=transformer_general_config,
+ )
+
+ self.quantizer = quantizer
+
+ self.decoder = Decoder(
+ latent_dim,
+ decoder_dim,
+ decoder_rates,
+ causal=causal,
+ n_transformer_layers=decoder_transformer_layers,
+ transformer_general_config=transformer_general_config,
+ )
+ self.sample_rate = sample_rate
+ self.apply(init_weights)
+
+ self.delay = self.get_delay()
+
+ self.frame_length = self.hop_length * 4
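+        # One codebook frame corresponds to frame_length input samples
+        # (used below for padding and length bookkeeping in encode/decode)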
+
+ def preprocess(self, audio_data, sample_rate):
+ if sample_rate is None:
+ sample_rate = self.sample_rate
+ assert sample_rate == self.sample_rate
+
+ length = audio_data.shape[-1]
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+ return audio_data
+
+ def encode(
+ self,
+ audio_data: torch.Tensor,
+ audio_lengths: torch.Tensor = None,
+ n_quantizers: int = None,
+ **kwargs,
+ ):
+ """Encode given audio data and return quantized latent codes
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x T]
+ Audio data to encode
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None
+ If None, all quantizers are used.
+
+        Returns
+        -------
+        indices : Tensor[B x N x T]
+            Codebook indices for each codebook
+            (quantized discrete representation of input)
+        indices_lens : Tensor[B]
+            Number of quantized frames for each item in the batch
+        """
+ # pad to multiple of self.frame_length
+ if audio_data.ndim == 2:
+ audio_data = audio_data.unsqueeze(1)
+ # print(audio_data.shape)
+ length = audio_data.shape[-1]
+ right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
+ if audio_lengths is None:
+ audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
+
+ z = self.encoder(audio_data)
+ vq_results = self.quantizer(z, n_quantizers, **kwargs)
+ indices = vq_results.codes
+ indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
+ return indices, indices_lens
+
+ def decode(self, indices: torch.Tensor, feature_lengths):
+ if indices.ndim == 2:
+ indices = indices[None]
+
+ z = self.quantizer.decode(indices)
+ audio_lengths = feature_lengths * self.frame_length
+ return self.decoder(z), audio_lengths
+
+ def forward(
+ self,
+ audio_data: torch.Tensor,
+ template: torch.Tensor = None,
+ mask: torch.Tensor = None,
+ sample_rate: int = None,
+ n_quantizers: int = None,
+ **kwargs,
+ ):
+ """Model forward pass
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ sample_rate : int, optional
+ Sample rate of audio data in Hz, by default None
+ If None, defaults to `self.sample_rate`
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None.
+ If None, all quantizers are used.
+
+        Returns
+        -------
+        audio : Tensor[B x 1 x length]
+            Decoded audio, cropped to the original input length
+        (indices, indices_lens) : Tuple[Tensor, Tensor]
+            Codebook indices and the number of quantized frames per item
+        """
+ length = audio_data.shape[-1]
+ audio_data = self.preprocess(audio_data, sample_rate)
+        indices, indices_lens = self.encode(
+            audio_data, n_quantizers=n_quantizers, **kwargs
+        )
+        audio, _ = self.decode(indices, indices_lens)
+        return audio[..., :length], (indices, indices_lens)
+
+
+if __name__ == "__main__":
+
+ def filter_state_dict_shapes(params, model):
+ model_state_dict = model.state_dict()
+ filtered_state_dict = {
+ k: v
+ for k, v in params.items()
+ if k in model_state_dict and v.shape == model_state_dict[k].shape
+ }
+ skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
+ if skipped_keys:
+ print(
+ f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
+ )
+ return filtered_state_dict, skipped_keys
+
+ model = hydra.utils.instantiate(
+ OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
+ )
+ sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
+ filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
+ print(f"Skipped keys: {skipped_keys}")
+ model.load_state_dict(filtered_sd, strict=False)
+ model.eval()
+
+ src_audio_path = "./test.wav"
+ wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
+ if len(wave_np.shape) == 1:
+ wave_np = wave_np[None, :]
+ wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
+
+ with torch.no_grad():
+        # encode returns (indices, indices_lens)
+ indices, indices_lens = model.encode(wave_tensor)
+ print(f"Indices shape: {indices.shape}")
+ print(f"Indices lengths: {indices_lens}")
+
+        # decode requires both indices and feature_lengths
+ fake_audio, audio_lengths = model.decode(indices, indices_lens)
+ print(f"Decoded audio shape: {fake_audio.shape}")
+ print(f"Audio lengths: {audio_lengths}")
+
+    # Save the reconstructed audio
+ sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
diff --git a/fish_speech/models/dac/rvq.py b/fish_speech/models/dac/rvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b4647be5f507b45ce00d57296785a128f5606f
--- /dev/null
+++ b/fish_speech/models/dac/rvq.py
@@ -0,0 +1,403 @@
+import math
+import typing as tp
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dac.nn.quantize import ResidualVectorQuantize
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, properly handling zero padding. Only for 1D!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left:end]
+
+
+def get_extra_padding_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad1d(
+ x: torch.Tensor,
+ paddings: tp.Tuple[int, int],
+ mode: str = "zeros",
+ value: float = 0.0,
+):
+    """Tiny wrapper around F.pad that allows reflect padding on small inputs.
+    If the input is shorter than the maximum padding, extra zero padding is
+    added to the right before the reflection happens.
+    """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == "reflect":
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+class CausalConvNet(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ dilation=1,
+ stride=1,
+ groups=1,
+ padding=None,
+ ):
+ super(CausalConvNet, self).__init__()
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ )
+ self.stride = stride
+ self.kernel_size = (kernel_size - 1) * dilation + 1
+ self.dilation = dilation
+ self.padding = self.kernel_size - self.stride
+
+ def forward(self, x):
+ pad = self.padding
+ extra_padding = get_extra_padding_for_conv1d(
+ x, self.kernel_size, self.stride, pad
+ )
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
+ return self.conv(x).contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_weight_norm(self):
+ self.conv = remove_parametrizations(self.conv)
+ return self
+
+
+class CausalTransConvNet(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
+ ):
+ super(CausalTransConvNet, self).__init__()
+ self.conv = nn.ConvTranspose1d(
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
+ )
+ self.stride = stride
+ self.kernel_size = kernel_size
+
+ def forward(self, x):
+ x = self.conv(x)
+ pad = self.kernel_size - self.stride
+ padding_right = math.ceil(pad)
+ padding_left = pad - padding_right
+ x = unpad1d(x, (padding_left, padding_right))
+ return x.contiguous()
+
+ def weight_norm(self, name="weight", dim=0):
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
+ return self
+
+ def remove_weight_norm(self):
+ self.conv = remove_parametrizations(self.conv)
+ return self
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
+class ConvNeXtBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
+ dilation (int): Dilation for depthwise conv. Default: 1.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim: int,
+ layer_scale_init_value: float = 1e-6,
+ mlp_ratio: float = 4.0,
+ kernel_size: int = 7,
+ dilation: int = 1,
+ ):
+ super().__init__()
+ convnet_type = CausalConvNet
+ self.dwconv = convnet_type(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ # padding=int(dilation * (kernel_size - 1) / 2),
+ groups=dim,
+ dilation=dilation,
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, int(mlp_ratio * dim)
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(self, x, apply_residual: bool = True):
+ input = x
+
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
+
+ if apply_residual:
+ x = input + x
+
+ return x
+
+
+@dataclass
+class VQResult:
+ z: torch.Tensor
+ codes: torch.Tensor
+ latents: torch.Tensor
+ codebook_loss: torch.Tensor
+ commitment_loss: torch.Tensor
+ semantic_distill_z: torch.Tensor | None = None
+
+
+class DownsampleResidualVectorQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 1024,
+ n_codebooks: int = 9,
+ codebook_dim: int = 8,
+ quantizer_dropout: float = 0.5,
+ codebook_size: int = 1024,
+ semantic_codebook_size: int = 4096,
+ downsample_factor: tuple[int] = (2, 2),
+ downsample_dims: tuple[int] | None = None,
+ pre_module: nn.Module | None = None,
+ post_module: nn.Module | None = None,
+ semantic_predictor_module: nn.Module | None = None,
+ ):
+ super().__init__()
+
+ if downsample_dims is None:
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+ all_dims = (input_dim,) + tuple(downsample_dims)
+
+ self.semantic_quantizer = ResidualVectorQuantize(
+ input_dim=input_dim,
+ n_codebooks=1,
+ codebook_size=semantic_codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=0.0,
+ )
+
+ self.quantizer = ResidualVectorQuantize(
+ input_dim=input_dim,
+ n_codebooks=n_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+
+ self.downsample_factor = downsample_factor
+ self.downsample_dims = downsample_dims
+
+ convnet_type = CausalConvNet
+ transconvnet_type = CausalTransConvNet
+
+ self.downsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ convnet_type(
+ all_dims[idx],
+ all_dims[idx + 1],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
+ )
+ for idx, factor in enumerate(downsample_factor)
+ ]
+ )
+
+ self.upsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ transconvnet_type(
+ all_dims[idx + 1],
+ all_dims[idx],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx]),
+ )
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
+ ]
+ )
+ self.apply(self._init_weights)
+ self.pre_module = (
+ pre_module if pre_module is not None else nn.Identity()
+ ) # leave for transformer, LSTM or Mamba or something else
+ self.post_module = post_module if post_module is not None else nn.Identity()
+ self.semantic_predictor_module = (
+ semantic_predictor_module
+ if semantic_predictor_module is not None
+ else nn.Identity()
+ )
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs
+ ):
+ # z: (B, D, T)
+ original_shape = z.shape
+ if semantic_len is None:
+ semantic_len = torch.LongTensor([z.shape[-1]])
+ z = self.downsample(z)
+ z = self.pre_module(z) # B, T, D
+ (
+ semantic_z,
+ semantic_codes,
+ semantic_latents,
+ semantic_commitment_loss,
+ semantic_codebook_loss,
+ ) = self.semantic_quantizer(z)
+ residual_z = z - semantic_z
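+        # Quantize what the single semantic codebook could not capture; the
+        # final latent is the sum of the semantic and residual reconstructions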
+ residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
+ residual_z, n_quantizers=n_quantizers
+ )
+ z = semantic_z + residual_z
+ commitment_loss = commitment_loss + semantic_commitment_loss
+ codebook_loss = codebook_loss + semantic_codebook_loss
+ codes = torch.cat([semantic_codes, codes], dim=1)
+ latents = torch.cat([semantic_latents, latents], dim=1)
+ z = self.post_module(z)
+ z = self.upsample(z)
+ # z: (B, D, T)
+
+ # semantic distillation (disabled here since only used in training)
+ # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
+
+ # Pad or crop z to match original shape
+ diff = original_shape[-1] - z.shape[-1]
+ right = 0
+ left = abs(diff) - right
+
+ if diff > 0:
+ z = F.pad(z, (left, right))
+ elif diff < 0:
+ z = z[..., left:]
+
+ results = VQResult(
+ z=z,
+ codes=codes,
+ latents=latents,
+ commitment_loss=commitment_loss,
+ codebook_loss=codebook_loss,
+ )
+
+ return results
+
+ # def encode(self, z):
+ # z = self.downsample(z)
+ # z = self.pre_module(z)
+ # _, indices, _, _, _ = self.quantizer(z.mT)
+ # indices = rearrange(indices, "g b l r -> b (g r) l")
+ # return indices
+ #
+ def decode(self, indices: torch.Tensor):
+ # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+
+ # print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}")
+
+ new_indices = torch.zeros_like(indices)
+ new_indices[:, 0] = torch.clamp(
+ indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
+ )
+ new_indices[:, 1:] = torch.clamp(
+ indices[:, 1:], max=self.quantizer.codebook_size - 1
+ )
+
+ z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0]
+ z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0]
+ z_q = z_q_semantic + z_q_residual
+ z_q = self.post_module(z_q)
+ z_q = self.upsample(z_q)
+ return z_q
+
+ # def from_latents(self, latents: torch.Tensor):
+ # z_q, z_p, codes = super().from_latents(latents)
+ # z_q = self.upsample(z_q)
+ # return z_q, z_p, codes
+
+
+if __name__ == "__main__":
+ rvq = DownsampleResidualVectorQuantize(
+ input_dim=512,
+ n_codebooks=8,
+ codebook_dim=8,
+ codebook_size=1024,
+ quantizer_dropout=0.5,
+ downsample_factor=[2, 2],
+ )
+ rvq.eval()
+ x = torch.randn(2, 512, 442)
+
+ result = rvq(x)
+ print(rvq)
+ print(result.latents.shape, result.codes.shape, result.z.shape)
+
+ # y = rvq.from_codes(result.codes)
+ # print(y[0].shape)
+
+ # y = rvq.from_latents(
+
+ result1 = rvq(x[:, :, :40])
+ print(result1.latents.shape, result1.codes.shape, result1.z.shape)
+
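+    # Causality check: quantizing only the first 40 frames must match the
+    # first 40 frames of the full-length pass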
+ assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
+ print("Success")
diff --git a/fish_speech/models/text2semantic/inference.py b/fish_speech/models/text2semantic/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..9442c7cc731f52db523b46434d4b90c3e2885c4f
--- /dev/null
+++ b/fish_speech/models/text2semantic/inference.py
@@ -0,0 +1,716 @@
+import os
+import queue
+import threading
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Union
+
+import click
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from loguru import logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from fish_speech.content_sequence import (
+ ContentSequence,
+ TextPart,
+ VQPart,
+)
+from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+ # Experimental feature to reduce compilation times, will be on by default in future
+ torch._inductor.config.fx_graph_cache = True
+
+
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from fish_speech.models.text2semantic.llama import (
+ BaseTransformer,
+ DualARTransformer,
+ NaiveTransformer,
+)
+
+
+def multinomial_sample_one_no_sync(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ temperature: torch.Tensor = 1.0,
+ top_p: torch.Tensor = 1.0,
+ repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+ # Apply repetition penalty
+ if previous_tokens is not None:
+ previous_tokens = previous_tokens.long()
+ score = torch.gather(logits, dim=0, index=previous_tokens)
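+        # Push previously generated tokens away from being selected again:
+        # negative logits grow more negative, positive logits shrink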
+ score = torch.where(
+ score < 0, score * repetition_penalty, score / repetition_penalty
+ )
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
+
+ # Apply top-p sampling
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
+ )
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+ logits = logits / max(temperature, 1e-5)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
+
+
+def sample(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ probs = logits_to_probs(
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+ )
+ idx_next = multinomial_sample_one_no_sync(probs)
+ return idx_next, probs
+
+
+def decode_one_token_ar(
+ model: DualARTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ semantic_ids: list,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+
+ sampling_kwargs_main = sampling_kwargs.copy()
+ # sampling_kwargs_main["temperature"] = 0.1
+ # sampling_kwargs_main["top_p"] = 0.1
+ # sampling_kwargs_main["repetition_penalty"] = 1.0
+
+ codebooks = [
+ sample(
+ x.logits,
+ previous_tokens=(
+ previous_tokens[0] if previous_tokens is not None else None
+            ),  # Row 0 of the window holds the text-token history used for the repetition penalty
+ **sampling_kwargs_main,
+ )[0]
+ ]
+
+ hidden_states = x.hidden_states
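+    # Dual-AR decoding: the slow transformer has produced the hidden state for this
+    # position; the fast layers below predict the remaining codebooks from it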
+
+ # Cleanup the cache
+ for layer in model.fast_layers:
+ layer.attention.kv_cache.k_cache.fill_(0)
+ layer.attention.kv_cache.v_cache.fill_(0)
+
+ input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+ model.forward_generate_fast(hidden_states, input_pos)
+ a = codebooks[0] - model.tokenizer.semantic_begin_id
+ a[a < 0] = 0
+ hidden_states = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ for codebook_idx in range(1, model.config.num_codebooks):
+ input_pos = torch.tensor(
+ [codebook_idx], device=hidden_states.device, dtype=torch.long
+ )
+ logits = model.forward_generate_fast(hidden_states, input_pos)
+ chunked_logits = logits[..., :1024]
+ a = sample(
+ chunked_logits,
+ previous_tokens=(
+ previous_tokens[codebook_idx + 1]
+ if previous_tokens is not None
+ else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ hidden_states = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ codebooks = torch.stack(codebooks, dim=0)
+ # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+ # codebooks[1:, :] = torch.masked_fill(
+ # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+ # )
+
+ # print(codebooks)
+ return codebooks
+
+
+def decode_n_tokens(
+ model: NaiveTransformer,
+ cur_token: torch.Tensor,
+ input_pos: torch.Tensor,
+ num_new_tokens: int,
+ semantic_ids: list,
+ decode_one_token=decode_one_token_ar,
+ **sampling_kwargs,
+):
+ previous_tokens = torch.zeros(
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
+ dtype=torch.int,
+ device=cur_token.device,
+ )
+
+ for i in tqdm(range(num_new_tokens)):
+ # We need to get windowed repeat penalty
+ win_size = 16
+ if i < win_size:
+ window = previous_tokens[:, :win_size]
+ else:
+ window = previous_tokens[:, i - win_size : i]
+
+ with (
+ torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ )
+ if torch.cuda.is_available()
+ else nullcontext()
+ ): # Actually better for Inductor to codegen attention here
+ next_token = decode_one_token(
+ model=model,
+ x=cur_token,
+ input_pos=input_pos,
+ previous_tokens=window,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+
+ input_pos += 1
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+ previous_tokens[:, i : i + 1] = next_token.view(
+ model.config.num_codebooks + 1, -1
+ )
+
+ if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
+ break
+
+ return previous_tokens[:, : i + 1]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate(
+ *,
+ model: NaiveTransformer,
+ prompt: torch.Tensor,
+ max_new_tokens: int,
+ decode_one_token=decode_one_token_ar,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ """
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+ """
+
+ T = prompt.size(1)
+ # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+ semantic_ids = [
+ model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+ ]
+
+ if max_new_tokens:
+ if T + max_new_tokens > model.config.max_seq_len:
+ max_new_tokens = model.config.max_seq_len - T
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+ T_new = T + max_new_tokens
+ else:
+ T_new = model.config.max_seq_len
+ max_new_tokens = T_new - T
+
+ device, dtype = prompt.device, prompt.dtype
+
+ codebook_dim = 1 + model.config.num_codebooks
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ empty = torch.empty(
+ (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
+ )
+ empty[:, :T] = prompt
+ seq = empty
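+    # Row 0 of seq holds text tokens; rows 1..num_codebooks hold the codebook ids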
+ input_pos = torch.arange(0, T, device=device)
+
+ # Use non-accelerated version for now, to avoid compilation overhead
+ prefill_decode = decode_one_token_ar
+
+ next_token = prefill_decode(
+ model,
+ prompt.view(1, codebook_dim, -1),
+ input_pos,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+ seq[:, T : T + 1] = next_token
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+ x = decode_n_tokens(
+ model,
+ next_token.view(1, codebook_dim, -1),
+ input_pos,
+ max_new_tokens - 1,
+ decode_one_token=decode_one_token,
+ semantic_ids=semantic_ids,
+ **sampling_kwargs,
+ )
+ # x = torch.cat(generated_tokens, dim=1)
+ seq = seq[:, : T + 1 + x.size(1)]
+ seq[:, T + 1 :] = x
+
+ return seq
+
+
+def load_model(checkpoint_path, device, precision, compile=False):
+ model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
+
+ model = model.to(device=device, dtype=precision)
+ logger.info(f"Restored model from checkpoint")
+
+ if isinstance(model, DualARTransformer):
+ decode_one_token = decode_one_token_ar
+ logger.info("Using DualARTransformer")
+ else:
+ raise ValueError("Model is not a DualARTransformer")
+
+ if compile:
+ logger.info("Compiling function...")
+ decode_one_token = torch.compile(
+ decode_one_token,
+ fullgraph=True,
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
+ )
+
+ return model.eval(), decode_one_token
+
+
+@dataclass
+class GenerateResponse:
+ action: Literal["sample", "next"]
+ codes: Optional[torch.Tensor] = None
+ text: Optional[str] = None
+
+
+def generate_long(
+ *,
+ model,
+ device: str | torch.device,
+ decode_one_token: callable,
+ text: str,
+ num_samples: int = 1,
+ max_new_tokens: int = 0,
+    top_p: float = 0.8,
+ repetition_penalty: float = 1.1,
+ temperature: float = 0.8,
+ compile: bool = False,
+ iterative_prompt: bool = True,
+ chunk_length: int = 150,
+ prompt_text: Optional[str | list[str]] = None,
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
+):
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
+ use_prompt = prompt_text is not None and prompt_tokens is not None
+ if use_prompt and isinstance(prompt_text, str):
+ prompt_text = [prompt_text]
+ prompt_tokens = [prompt_tokens]
+
+ assert use_prompt is False or len(prompt_text) == len(
+ prompt_tokens
+ ), "Prompt text and tokens must have the same length"
+
+    if prompt_tokens is not None:
+        prompt_tokens = [i.cpu() for i in prompt_tokens]
+
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ tokenizer = model.tokenizer
+ base_content_sequence = ContentSequence(modality="interleave")
+
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
+ max_length = model.config.max_seq_len
+
+ if use_prompt:
+ for t, c in zip(prompt_text, prompt_tokens):
+ base_content_sequence.append(
+ [
+ TextPart(text=t),
+ VQPart(codes=c),
+ ],
+ add_end=True,
+ )
+
+ encoded_prompts = base_content_sequence.encode_for_inference(
+ tokenizer, num_codebooks=model.config.num_codebooks
+ )
+ if encoded_prompts.size(1) > max_length - 2048:
+ raise ValueError(
+ f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
+ )
+
+ encoded = []
+ for text in texts:
+ content_sequence = ContentSequence(modality=None)
+ content_sequence.append(TextPart(text=text))
+ encoded.append(
+ content_sequence.encode_for_inference(
+ tokenizer, num_codebooks=model.config.num_codebooks
+ )
+ )
+ logger.info(f"Encoded text: {text}")
+
+ # Move temperature, top_p, repetition_penalty to device
+ # This is important so that changing params doesn't trigger recompile
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+ repetition_penalty = torch.tensor(
+ repetition_penalty, device=device, dtype=torch.float
+ )
+
+ for sample_idx in range(num_samples):
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ global_encoded = []
+ seg_idx = 0
+
+ while seg_idx < len(encoded):
+ logger.info(
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+ )
+
+ seg = encoded[seg_idx]
+ global_encoded.append(seg)
+
+ # Do not use previous segments to generate current segment for now
+ # lengths = reversed([seg.size(1) for seg in global_encoded])
+
+ # # Pick last 2000 tokens
+ # count = 0
+ # for i, length in enumerate(lengths):
+ # count += length
+ # if count + length > max_length - 2048 - encoded_prompts.size(1):
+ # break
+
+ # if i != 0 and i % 2 == 0:
+ # i -= 1
+
+ # # Rotate the list, always make sure first segment is included to avoid drift
+ # if i < len(global_encoded) - 2:
+ # partial_encoded = global_encoded[:2] + global_encoded[-i:]
+ # else:
+ # partial_encoded = global_encoded
+
+ # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
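+            # When there is no reference prompt, also condition on the first generated
+            # segments so later sentences keep a consistent voice and avoid drift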
+ if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
+ cat_encoded = torch.cat(
+ [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
+ )
+ else:
+ cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
+
+ cat_encoded = cat_encoded.to(device=device)
+ prompt_length = cat_encoded.size(1)
+
+ t0 = time.perf_counter()
+ y = generate(
+ model=model,
+ prompt=cat_encoded,
+ max_new_tokens=max_new_tokens,
+ decode_one_token=decode_one_token,
+ temperature=temperature,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ )
+
+ if sample_idx == 0 and seg_idx == 0 and compile:
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ t = time.perf_counter() - t0
+
+ tokens_generated = y.size(1) - prompt_length
+ tokens_sec = tokens_generated / t
+ logger.info(
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+ )
+ logger.info(
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+ )
+
+ if torch.cuda.is_available():
+ logger.info(
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+ )
+
+ # Put the generated tokens
+            # since there is an <|im_end|> token at the end, we remove the last token
+ codes = y[1:, prompt_length:-1].clone()
+ assert (codes >= 0).all(), f"Negative code found"
+
+ decoded = y[:, prompt_length:].clone()
+ # But for global encoding, we should keep the token
+
+ global_encoded.append(decoded.cpu())
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
+ seg_idx += 1
+
+ # This indicates the end of the current sample
+ yield GenerateResponse(action="next")
+
+
+@dataclass
+class WrappedGenerateResponse:
+ status: Literal["success", "error"]
+ response: Optional[GenerateResponse | Exception] = None
+
+
+@dataclass
+class GenerateRequest:
+ request: dict
+ response_queue: queue.Queue
+
+
+def launch_thread_safe_queue(
+ checkpoint_path,
+ device,
+ precision,
+ compile: bool = False,
+):
+ input_queue = queue.Queue()
+ init_event = threading.Event()
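+    # A single background thread owns the model; requests arrive on input_queue and
+    # results are pushed to each request's own response_queue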
+
+ def worker():
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile
+ )
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1,
+ max_seq_len=model.config.max_seq_len,
+ dtype=next(model.parameters()).dtype,
+ )
+ init_event.set()
+
+ while True:
+ item: GenerateRequest | None = input_queue.get()
+ if item is None:
+ break
+
+ kwargs = item.request
+ response_queue = item.response_queue
+
+ try:
+ for chunk in generate_long(
+ model=model, decode_one_token=decode_one_token, **kwargs
+ ):
+ response_queue.put(
+ WrappedGenerateResponse(status="success", response=chunk)
+ )
+ except Exception as e:
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
+
+ threading.Thread(target=worker, daemon=True).start()
+ init_event.wait()
+
+ return input_queue
+
+
+def launch_thread_safe_queue_agent(
+ checkpoint_path,
+ device,
+ precision,
+ compile: bool = False,
+):
+ input_queue = queue.Queue()
+ init_event = threading.Event()
+
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+ config = BaseModelArgs.from_pretrained(checkpoint_path)
+
+ def worker():
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile, is_agent=True
+ )
+
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1,
+ max_seq_len=model.config.max_seq_len,
+ dtype=next(model.parameters()).dtype,
+ )
+ init_event.set()
+
+ while True:
+ item: GenerateRequest | None = input_queue.get()
+ if item is None:
+ break
+
+ kwargs = item.request
+ response_queue = item.response_queue
+
+ try:
+ for token in generate_agent(
+ model=model,
+ decode_one_token=decode_one_token,
+ **kwargs,
+ ):
+ response_queue.put(token)
+
+ response_queue.put("stop")
+ except Exception as e:
+ import traceback
+
+ logger.exception(f"Error in worker: {traceback.format_exc()}")
+ response_queue.put("error")
+
+ threading.Thread(target=worker, daemon=True).start()
+ init_event.wait()
+
+ return input_queue, tokenizer, config
+
+
+@click.command()
+@click.option(
+ "--text",
+ type=str,
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None, multiple=True)
+@click.option(
+ "--prompt-tokens",
+ type=click.Path(path_type=Path, exists=True),
+ default=None,
+ multiple=True,
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-p", type=float, default=0.8)
+@click.option("--repetition-penalty", type=float, default=1.1)
+@click.option("--temperature", type=float, default=0.8)
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="checkpoints/openaudio-s1-mini",
+)
+@click.option("--device", type=str, default="cuda")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--chunk-length", type=int, default=300)
+@click.option("--output-dir", type=Path, default="temp")
+def main(
+ text: str,
+ prompt_text: Optional[list[str]],
+ prompt_tokens: Optional[list[Path]],
+ num_samples: int,
+ max_new_tokens: int,
+    top_p: float,
+ repetition_penalty: float,
+ temperature: float,
+ checkpoint_path: Path,
+ device: str,
+ compile: bool,
+ seed: int,
+ half: bool,
+ iterative_prompt: bool,
+ chunk_length: int,
+ output_dir: Path,
+) -> None:
+ os.makedirs(output_dir, exist_ok=True)
+ precision = torch.half if half else torch.bfloat16
+
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+ raise ValueError(
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
+ )
+
+ logger.info("Loading model ...")
+ t0 = time.time()
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile
+ )
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1,
+ max_seq_len=model.config.max_seq_len,
+ dtype=next(model.parameters()).dtype,
+ )
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+ if prompt_tokens is not None:
+ prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
+
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+
+ generator = generate_long(
+ model=model,
+ device=device,
+ decode_one_token=decode_one_token,
+ text=text,
+ num_samples=num_samples,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=compile,
+ iterative_prompt=iterative_prompt,
+ chunk_length=chunk_length,
+ prompt_text=prompt_text,
+ prompt_tokens=prompt_tokens,
+ )
+
+ idx = 0
+ codes = []
+
+ for response in generator:
+ if response.action == "sample":
+ codes.append(response.codes)
+ logger.info(f"Sampled text: {response.text}")
+ elif response.action == "next":
+ if codes:
+ codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
+ np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
+ logger.info(f"Saved codes to {codes_npy_path}")
+ logger.info(f"Next sample")
+ codes = []
+ idx += 1
+ else:
+ logger.error(f"Error: {response}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
index 7b26793532c9e6b42de189c628aac59a477d0f66..df970400f8a073be4c4166a697245fabdf6b09b0 100644
--- a/fish_speech/models/text2semantic/lit_module.py
+++ b/fish_speech/models/text2semantic/lit_module.py
@@ -1,202 +1,202 @@
-from typing import Any, Optional
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-from lightning.pytorch.utilities.types import OptimizerLRScheduler
-
-import fish_speech.utils as utils
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
-from fish_speech.models.text2semantic.llama import NaiveTransformer
-
-log = utils.RankedLogger(__name__, rank_zero_only=True)
-
-
-class TextToSemantic(L.LightningModule):
- def __init__(
- self,
- model: NaiveTransformer,
- optimizer: Any,
- lr_scheduler: Any,
- ):
- super().__init__()
-
- self.model = model
- self.optimizer_builder = optimizer
- self.lr_scheduler_builder = lr_scheduler
-
- def forward(self, x):
- return self.model(x)
-
- def on_save_checkpoint(self, checkpoint):
- # Save only LoRA parameters
- state_dict = checkpoint["state_dict"]
- use_lora = any("lora" in name for name in state_dict.keys())
- if not use_lora:
- return
-
- for name in list(state_dict.keys()):
- if "lora" not in name:
- state_dict.pop(name)
-
- def configure_optimizers(self) -> OptimizerLRScheduler:
- # Get weight decay parameters
- weight_decay_parameters, other_parameters = [], []
- for name, param in self.named_parameters():
- if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
- other_parameters.append(param)
- else:
- weight_decay_parameters.append(param)
-
- optimizer = self.optimizer_builder(
- [
- {"params": weight_decay_parameters},
- {"params": other_parameters, "weight_decay": 0.0},
- ]
- )
-
- # Print the parameters and their weight decay
- for i in optimizer.param_groups:
- log.info(
- f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
- )
-
- lr_scheduler = self.lr_scheduler_builder(optimizer)
-
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": lr_scheduler,
- "interval": "step",
- },
- }
-
- # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
- def get_batch_logps(
- self,
- logits: torch.FloatTensor,
- labels: torch.LongTensor,
- average_log_prob: bool = False,
- ) -> torch.FloatTensor:
- """Compute the log probabilities of the given labels under the given logits.
-
- Args:
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
- labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-
- Returns:
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
- """
- assert logits.shape[:-1] == labels.shape
-
- labels = labels.clone()
- loss_mask = labels != -100
-
- # dummy token; we'll ignore the losses on these tokens later
- labels[labels == -100] = 0
-
- per_token_logps = torch.gather(
- logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
- ).squeeze(-1)
-
- if average_log_prob:
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
- else:
- return (per_token_logps * loss_mask).sum(-1)
-
- def _step(self, batch, batch_idx, stage: str):
- is_train = stage == "train"
-
- if is_train:
- # Key part to make lora work
- # Otherwise the parameters are merged, which lead to incorrect gradients
- self.model.train()
-
- # Do positive and negative samples in the same batch to speed up training
- labels = batch["labels"]
- outputs = self.model(
- inp=batch["inputs"],
- key_padding_mask=batch["attention_masks"],
- )
- token_logits = outputs.token_logits
- codebook_logits = outputs.codebook_logits
-
- # Generate labels
- base_loss = F.cross_entropy(
- token_logits.view(-1, token_logits.size(-1)),
- labels[:, 0].reshape(-1),
- ignore_index=-100,
- )
-
- codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
- semantic_loss = F.cross_entropy(
- codebook_logits.view(-1, codebook_logits.size(-1)),
- codebook_labels.reshape(-1),
- ignore_index=-100,
- )
-
- loss = base_loss + semantic_loss
-
- self.log(
- f"{stage}/loss",
- loss,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=True,
- logger=True,
- sync_dist=not is_train,
- )
-
- self.log(
- f"{stage}/base_loss",
- base_loss,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=False,
- logger=True,
- sync_dist=not is_train,
- )
-
- self.log(
- f"{stage}/semantic_loss",
- semantic_loss,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=False,
- logger=True,
- sync_dist=not is_train,
- )
-
- # Top-5 accuracy
- accuracy = self.get_accuracy(codebook_logits, codebook_labels)
- self.log(
- f"{stage}/top_5_accuracy",
- accuracy,
- on_step=is_train,
- on_epoch=not is_train,
- prog_bar=True,
- logger=True,
- sync_dist=not is_train,
- )
-
- return loss
-
- def get_accuracy(self, logits, labels):
- mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
- if mask.sum() == 0:
- return torch.tensor(0.0, device=logits.device)
-
- _, indices = logits.topk(5, dim=-1)
- correct = indices.eq(labels.unsqueeze(-1))
- correct[~mask] = 0
- correct = correct.sum()
- accuracy = correct / mask.sum()
-
- return accuracy
-
- def training_step(self, batch, batch_idx):
- return self._step(batch, batch_idx, "train")
-
- def validation_step(self, batch, batch_idx):
- return self._step(batch, batch_idx, "val")
+from typing import Any, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+class TextToSemantic(L.LightningModule):
+ def __init__(
+ self,
+ model: NaiveTransformer,
+ optimizer: Any,
+ lr_scheduler: Any,
+ ):
+ super().__init__()
+
+ self.model = model
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+
+ def forward(self, x):
+ return self.model(x)
+
+ def on_save_checkpoint(self, checkpoint):
+ # Save only LoRA parameters
+ state_dict = checkpoint["state_dict"]
+ use_lora = any("lora" in name for name in state_dict.keys())
+ if not use_lora:
+ return
+
+ for name in list(state_dict.keys()):
+ if "lora" not in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ # Get weight decay parameters
+ weight_decay_parameters, other_parameters = [], []
+ for name, param in self.named_parameters():
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+ other_parameters.append(param)
+ else:
+ weight_decay_parameters.append(param)
+
+ optimizer = self.optimizer_builder(
+ [
+ {"params": weight_decay_parameters},
+ {"params": other_parameters, "weight_decay": 0.0},
+ ]
+ )
+
+ # Print the parameters and their weight decay
+ for i in optimizer.param_groups:
+ log.info(
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+ )
+
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ },
+ }
+
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
+ def get_batch_logps(
+ self,
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ assert logits.shape[:-1] == labels.shape
+
+ labels = labels.clone()
+ loss_mask = labels != -100
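+        # Note: masked (-100) positions contribute nothing; with average_log_prob=True
+        # the sum is divided by loss_mask.sum(-1), i.e. the count of non-masked
+        # positions (one masked token out of three gives a mean over two).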
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == -100] = 0
+
+ per_token_logps = torch.gather(
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
+
+ def _step(self, batch, batch_idx, stage: str):
+ is_train = stage == "train"
+
+ if is_train:
+            # Key part to make LoRA work:
+            # otherwise the parameters are merged, which leads to incorrect gradients
+ self.model.train()
+
+        # Run the full batch through the model in a single forward pass
+ labels = batch["labels"]
+ outputs = self.model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+        # Base (text-token) loss over the first row of the labels
+ base_loss = F.cross_entropy(
+ token_logits.view(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ )
+
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.view(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ )
+
+ loss = base_loss + semantic_loss
+
+ self.log(
+ f"{stage}/loss",
+ loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/base_loss",
+ base_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/semantic_loss",
+ semantic_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ # Top-5 accuracy
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
+ self.log(
+ f"{stage}/top_5_accuracy",
+ accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ return loss
+
+ def get_accuracy(self, logits, labels):
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+ if mask.sum() == 0:
+ return torch.tensor(0.0, device=logits.device)
+
+ _, indices = logits.topk(5, dim=-1)
+ correct = indices.eq(labels.unsqueeze(-1))
+ correct[~mask] = 0
+ correct = correct.sum()
+ accuracy = correct / mask.sum()
+
+ return accuracy
+
+ def training_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "train")
+
+ def validation_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "val")
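
The optimizer and lr_scheduler arguments of TextToSemantic are builders, i.e. callables that receive the parameter groups (resp. the optimizer), not ready-made objects. A minimal sketch of wiring them up by hand, assuming functools.partial stands in for whatever config-driven instantiation the training scripts use:

    from functools import partial

    import torch

    # Hypothetical values; the real ones come from the training config.
    optimizer_builder = partial(
        torch.optim.AdamW, lr=1e-4, betas=(0.9, 0.95), weight_decay=0.01
    )
    lr_scheduler_builder = partial(
        torch.optim.lr_scheduler.LambdaLR,
        lr_lambda=lambda step: min(1.0, (step + 1) / 1000),  # linear warmup
    )

    # module = TextToSemantic(model=model, optimizer=optimizer_builder,
    #                         lr_scheduler=lr_scheduler_builder)
    # configure_optimizers() then splits the parameters into a weight-decay group
    # and a no-decay group (biases, norm weights, embeddings) before calling them.
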
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
index dcd0f0fa96a3bf43768bbb8087f976068e48a8e0..d34882cdeb554989920a3e7feadc9facdc4bdd95 100644
--- a/fish_speech/models/text2semantic/llama.py
+++ b/fish_speech/models/text2semantic/llama.py
@@ -1,887 +1,903 @@
-import dataclasses
-import json
-import math
-from collections import OrderedDict
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
-
-import torch
-import torch.nn as nn
-from einops import rearrange
-from loguru import logger
-from torch import Tensor
-from torch.nn import functional as F
-from torch.nn.attention import SDPBackend, sdpa_kernel
-from torch.utils.checkpoint import checkpoint
-from transformers import AutoTokenizer
-
-from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
-from fish_speech.utils import RankedLogger
-
-from .lora import LoraConfig, setup_lora
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def find_multiple(n: int, k: int) -> int:
- if n % k == 0:
- return n
- return n + k - (n % k)
-
-
-@dataclass
-class BaseModelArgs:
- model_type: str = "base"
-
- vocab_size: int = 32000
- n_layer: int = 32
- n_head: int = 32
- dim: int = 4096
- intermediate_size: int = None
- n_local_heads: int = -1
- head_dim: int = 64
- rope_base: float = 10000
- norm_eps: float = 1e-5
- max_seq_len: int = 2048
- dropout: float = 0.0
- tie_word_embeddings: bool = True
- attention_qkv_bias: bool = False
-
- # Codebook configs
- codebook_size: int = 160
- num_codebooks: int = 4
-
- # Gradient checkpointing
- use_gradient_checkpointing: bool = True
-
- # Initialize the model
- initializer_range: float = 0.02
-
- # Dummy vars
- is_reward_model: bool = False
- share_codebook_embeddings: bool = True
- scale_codebook_embeddings: bool = False
-
- def __post_init__(self):
- if self.n_local_heads == -1:
- self.n_local_heads = self.n_head
- if self.intermediate_size is None:
- hidden_dim = 4 * self.dim
- n_hidden = int(2 * hidden_dim / 3)
- self.intermediate_size = find_multiple(n_hidden, 256)
- self.head_dim = self.dim // self.n_head
-
- @staticmethod
- def from_pretrained(path: str):
- path = Path(path)
-
- if path.is_dir():
- path = path / "config.json"
-
- with open(path, "r", encoding="utf-8") as f:
- data = json.load(f)
-
- match data["model_type"]:
- case "naive":
- cls = NaiveModelArgs
- case "dual_ar":
- cls = DualARModelArgs
- case _:
- raise ValueError(f"Unknown model type: {data['model_type']}")
-
- return cls(**data)
-
- def save(self, path: str):
- with open(path, "w") as f:
- json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
-
-
-@dataclass
-class NaiveModelArgs(BaseModelArgs):
- model_type: str = "naive"
-
-
-@dataclass
-class DualARModelArgs(BaseModelArgs):
- model_type: str = "dual_ar"
- n_fast_layer: int = 4
- fast_dim: int | None = None
- fast_n_head: int | None = None
- fast_n_local_heads: int | None = None
- fast_head_dim: int | None = None
- fast_intermediate_size: int | None = None
- fast_attention_qkv_bias: bool | None = None
-
- def __post_init__(self):
- super().__post_init__()
-
- self.fast_dim = self.fast_dim or self.dim
- self.fast_n_head = self.fast_n_head or self.n_head
- self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
- self.fast_head_dim = self.fast_head_dim or self.head_dim
- self.fast_intermediate_size = (
- self.fast_intermediate_size or self.intermediate_size
- )
- self.fast_attention_qkv_bias = (
- self.fast_attention_qkv_bias
- if self.fast_attention_qkv_bias is not None
- else self.attention_qkv_bias
- )
-
-
-class KVCache(nn.Module):
- def __init__(
- self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
- ):
- super().__init__()
- cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
- self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
- self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
-
- def update(self, input_pos, k_val, v_val):
- # input_pos: [S], k_val: [B, H, S, D]
- assert input_pos.shape[0] == k_val.shape[2]
-
- k_out = self.k_cache
- v_out = self.v_cache
- k_out[:, :, input_pos] = k_val
- v_out[:, :, input_pos] = v_val
-
- return k_out, v_out
-
-
-@dataclass
-class TransformerForwardResult:
- token_logits: Tensor
- codebook_logits: Tensor
-
-
-@dataclass
-class BaseTransformerForwardResult:
- logits: Tensor
- hidden_states: Tensor
-
-
-class BaseTransformer(nn.Module):
- def __init__(
- self,
- config: BaseModelArgs,
- tokenizer: FishTokenizer | AutoTokenizer,
- init_weights: bool = True,
- ) -> None:
- super().__init__()
- self.config = config
- self.tokenizer = tokenizer
- self.semantic_token_ids = [
- tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
- ]
-
- # Slow transformer
- self.embeddings = nn.Embedding(
- config.vocab_size,
- config.dim,
- )
- self.codebook_embeddings = nn.Embedding(
- config.codebook_size * config.num_codebooks,
- config.dim,
- )
- self.layers = nn.ModuleList(
- TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
- )
- self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-
- if self.config.tie_word_embeddings is False:
- self.output = nn.Linear(
- config.dim,
- config.vocab_size,
- bias=False,
- )
-
- self.register_buffer(
- "freqs_cis",
- precompute_freqs_cis(
- config.max_seq_len,
- config.dim // config.n_head,
- config.rope_base,
- ),
- persistent=False,
- )
- self.register_buffer(
- "causal_mask",
- torch.tril(
- torch.ones(
- config.max_seq_len,
- config.max_seq_len,
- dtype=torch.bool,
- )
- ),
- persistent=False,
- )
-
- # For kv cache
- self.max_batch_size = -1
- self.max_seq_len = -1
-
- if init_weights:
- self.apply(self._init_weights)
-
- def setup_caches(
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
- ):
- if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
- return
-
- head_dim = self.config.dim // self.config.n_head
- max_seq_len = find_multiple(max_seq_len, 8)
- self.max_seq_len = max_seq_len
- self.max_batch_size = max_batch_size
-
- for b in self.layers:
- b.attention.kv_cache = KVCache(
- max_batch_size,
- max_seq_len,
- self.config.n_local_heads,
- head_dim,
- dtype=dtype,
- )
-
- def embed(self, x: Tensor) -> Tensor:
- vocab_embeds = [self.embeddings(x[:, 0])]
- for i in range(self.config.num_codebooks):
- emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
- semantic_token_ids_tensor = torch.tensor(
- self.semantic_token_ids, device=x.device
- )
- emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
-
- x = torch.stack(vocab_embeds, dim=3)
- x = x.sum(dim=3)
-
- return x
-
- def forward(
- self,
- inp: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- ) -> BaseTransformerForwardResult:
- seq_len = inp.size(2)
-
- # Here we want to merge the embeddings of the codebooks
- x = self.embed(inp)
-
- freqs_cis = self.freqs_cis[:seq_len]
-
- # Not that the causal mask here follows the definition of scaled_dot_product_attention
- # That is, FALSE means masked out
- # To maintain consistency, key_padding_mask use TRUE to mask out
- mask = None
- if key_padding_mask is not None:
- mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
- mask = mask & key_padding_mask[:, None, None, :].logical_not()
-
- for layer in self.layers:
- if self.config.use_gradient_checkpointing and self.training:
- x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
- else:
- x = layer(x, freqs_cis, mask)
-
- # We got slow_out here
- slow_out = self.norm(x)
-
- if self.config.tie_word_embeddings:
- token_logits = F.linear(slow_out, self.embeddings.weight)
- else:
- token_logits = self.output(slow_out)
-
- return BaseTransformerForwardResult(
- logits=token_logits,
- hidden_states=x,
- )
-
- def forward_generate(
- self,
- inp: Tensor,
- input_pos: Optional[Tensor] = None,
- vq_masks: Optional[Tensor] = None, # this is not used in fact
- return_all: bool = False,
- ) -> BaseTransformerForwardResult:
- # This is used for generation, optimized for torch compile
- # assert (
- # self.max_seq_len != -1 and self.max_batch_size != -1
- # ), "Please call setup_caches before forward_generate"
-
- embeds = []
- for i in range(self.config.num_codebooks):
- if self.config.share_codebook_embeddings:
- _tokens = inp[:, i + 1] + i * self.config.codebook_size
- else:
- _tokens = inp[:, i + 1]
-
- emb = self.codebook_embeddings(_tokens)
- embeds.append(emb)
-
- vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
- # if self.config.use_codebook_mlp:
- # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
- # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
-
- vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
- inp[:, 0] <= self.tokenizer.semantic_end_id
- )
-
- vq_embeds_sum[~vq_masks] = 0
- x = self.embeddings(inp[:, 0]) + vq_embeds_sum
-
- if input_pos is None:
- input_pos = torch.arange(inp.shape[-1], device=x.device)
- max_seq_len = inp.shape[-1]
- else:
- max_seq_len = self.max_seq_len
-
- mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
- freqs_cis = self.freqs_cis[input_pos]
-
- for layer in self.layers:
- x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
- # If prefill, we only calculate the logits of last token
- if x.size(1) > 1 and not return_all:
- x = x[:, -1:]
-
- # We got slow_out here
- slow_out = self.norm(x)
-
- if self.config.is_reward_model:
- token_logits = self.score_output(slow_out)
- elif self.config.tie_word_embeddings:
- token_logits = F.linear(slow_out, self.embeddings.weight)
- else:
- token_logits = self.output(slow_out)
-
- return BaseTransformerForwardResult(
- logits=token_logits,
- hidden_states=x,
- )
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- @staticmethod
- def from_pretrained(
- path: str,
- load_weights: bool = False,
- max_length: int | None = None,
- lora_config: LoraConfig | None = None,
- rope_base: int | None = None,
- is_agent: bool = False,
- ) -> "BaseTransformer":
- config = BaseModelArgs.from_pretrained(str(path))
- if max_length is not None:
- config.max_seq_len = max_length
- log.info(f"Override max_seq_len to {max_length}")
-
- if rope_base is not None:
- config.rope_base = rope_base
- log.info(f"Override rope_base to {rope_base}")
-
- match config.model_type:
- case "naive":
- model_cls = NaiveTransformer
- case "dual_ar":
- model_cls = DualARTransformer
- case _:
- raise ValueError(f"Unknown model type: {config.model_type}")
-
- if is_agent:
- tokenizer = AutoTokenizer.from_pretrained(str(path))
- else:
- tokenizer_path = str(path) + "/tokenizer.tiktoken"
- tokenizer = FishTokenizer(tokenizer_path)
-
- log.info(f"Loading model from {path}, config: {config}")
- model = model_cls(config, tokenizer=tokenizer)
-
- if lora_config is not None:
- setup_lora(model, lora_config)
- log.info(f"LoRA setup: {lora_config}")
-
- if load_weights is False:
- log.info("Randomly initialized model")
- else:
-
- if "int8" in str(Path(path)):
- logger.info("Using int8 weight-only quantization!")
- from tools.llama.quantize import WeightOnlyInt8QuantHandler
-
- simple_quantizer = WeightOnlyInt8QuantHandler(model)
- model = simple_quantizer.convert_for_runtime()
-
- if "int4" in str(Path(path)):
- logger.info("Using int4 quantization!")
- path_comps = path.name.split("-")
- assert path_comps[-2].startswith("g")
- groupsize = int(path_comps[-2][1:])
- from tools.llama.quantize import WeightOnlyInt4QuantHandler
-
- simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
- model = simple_quantizer.convert_for_runtime()
-
- weights = torch.load(
- Path(path) / "model.pth",
- map_location="cpu",
- mmap=True,
- weights_only=True,
- )
-
- if "state_dict" in weights:
- logger.warning(
- "Using a TextToSemantic LightningModule checkpoint, "
- "please make sure it is a full model, not a LoRA model."
- )
- weights = weights["state_dict"]
-
- if next(iter(weights.keys())).startswith("model."):
- logger.info(
- f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
- )
- new_weights = OrderedDict()
- for k, v in weights.items():
- new_weights[k.replace("model.", "")] = v
- weights = new_weights
-
- # Verify the name and shape of parameters since strict=False in load_state_dict.
- for k, v in model.named_parameters():
- if k not in weights:
- logger.warning(f"No weight for {k}")
- elif v.shape != weights[k].shape:
- logger.warning(
- f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
- )
-
- err = model.load_state_dict(weights, strict=False, assign=True)
- log.info(f"Loaded weights with error: {err}")
-
- return model
-
- def save_pretrained(self, path: str, drop_lora: bool = False):
- path = Path(path)
- path.mkdir(parents=True, exist_ok=True)
-
- self.config.save(path / "config.json")
- state_dict = self.state_dict()
-
- if drop_lora:
- for key in list(state_dict.keys()):
- if "lora" not in key:
- continue
-
- state_dict.pop(key)
- log.info(f"Drop LoRA parameter: {key}")
-
- torch.save(state_dict, path / "model.pth")
- self.tokenizer.save_pretrained(path)
-
-
-class NaiveTransformer(BaseTransformer):
- def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
-
- self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
- self.codebook_output = nn.Linear(
- config.dim,
- config.codebook_size * config.num_codebooks,
- bias=False,
- )
-
- self.apply(self._init_weights)
-
- def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
- token_logits = result.logits
- x = result.hidden_states
-
- # Codebook
- codebook_logits = self.codebook_output(self.codebook_norm(x))
- codebook_logits = rearrange(
- codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
- )
-
- return TransformerForwardResult(
- token_logits=token_logits,
- codebook_logits=codebook_logits,
- )
-
- def forward(
- self,
- inp: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- ) -> TransformerForwardResult:
- result = super().forward(
- inp=inp,
- key_padding_mask=key_padding_mask,
- )
- return self.decode(result)
-
- def forward_generate(
- self, x: Tensor, input_pos: Optional[Tensor] = None
- ) -> TransformerForwardResult:
- result = super().forward_generate(x, input_pos)
- return self.decode(result)
-
-
-class DualARTransformer(BaseTransformer):
- def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
- super().__init__(config, init_weights=False, tokenizer=tokenizer)
-
- # Project to fast dim if needed
- if config.fast_dim is not None and config.fast_dim != config.dim:
- self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
- else:
- self.fast_project_in = nn.Identity()
-
- # Fast transformer
- self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
-
- # The equivalent bs is so large that sdpa doesn't work
- override_config = dataclasses.replace(
- config,
- dim=config.fast_dim,
- n_head=config.fast_n_head,
- n_local_heads=config.fast_n_local_heads,
- head_dim=config.fast_head_dim,
- intermediate_size=config.fast_intermediate_size,
- attention_qkv_bias=config.fast_attention_qkv_bias,
- )
-
- self.fast_layers = nn.ModuleList(
- TransformerBlock(override_config, use_sdpa=False)
- for _ in range(config.n_fast_layer)
- )
- self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
- self.fast_output = nn.Linear(
- config.fast_dim,
- config.codebook_size,
- bias=False,
- )
-
- self.register_buffer(
- "fast_freqs_cis",
- precompute_freqs_cis(
- config.num_codebooks,
- config.fast_dim // config.fast_n_head,
- config.rope_base,
- ),
- persistent=False,
- )
- self.apply(self._init_weights)
-
- def setup_caches(
- self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
- ):
- super().setup_caches(max_batch_size, max_seq_len, dtype)
-
- head_dim = self.config.fast_dim // self.config.fast_n_head
-
- # Fast transformer
- # The max seq len here is the number of codebooks
- for b in self.fast_layers:
- b.attention.kv_cache = KVCache(
- max_batch_size,
- self.config.num_codebooks,
- self.config.fast_n_local_heads,
- head_dim,
- dtype=dtype,
- )
-
- def forward(
- self,
- inp: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- ) -> TransformerForwardResult:
- parent_result = super().forward(inp, key_padding_mask)
- token_logits = parent_result.logits
- x = parent_result.hidden_states
- x = self.fast_project_in(x)
-
- # Fast transformer
- fast_seq_len = self.config.num_codebooks
- fast_mask = self.causal_mask[
- None, None, :fast_seq_len, :fast_seq_len
- ] # (B, N, Q, K)
-
- # Drop the last token and rotate left
- codebooks = inp[:, 1:-1, 1:]
- codebooks = F.pad(codebooks, (0, 1), value=0)
- codebook_embeddings = self.fast_embeddings(codebooks)
- x = torch.cat([x[:, None], codebook_embeddings], dim=1)
- b, s = x.size(0), x.size(2)
- x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
-
- # Remove padded part
- codebooks = rearrange(codebooks, "b n s -> (b s) n")
- codebook_mask = (codebooks == 0).all(dim=-1)
-
- if torch.all(codebook_mask):
- # If all codebooks are padded, we keep first 8 to make sure the model runs
- codebook_mask[:8] = False
-
- x_bs, x_len = x.size(0), x.size(1)
- x = x[~codebook_mask]
-
- for layer in self.fast_layers:
- if self.config.use_gradient_checkpointing and self.training:
- x = checkpoint(
- layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
- )
- else:
- x = layer(x, self.fast_freqs_cis, fast_mask)
-
- # unflatten the batch and num_codebooks
- fast_out = self.fast_norm(x)
- codebook_logits = self.fast_output(fast_out)
-
- # Re-pad the codebook_logits
- buffer = torch.zeros(
- x_bs,
- x_len,
- codebook_logits.size(-1),
- device=codebook_logits.device,
- dtype=codebook_logits.dtype,
- )
- buffer[~codebook_mask] = codebook_logits
- codebook_logits = buffer
-
- assert codebook_logits.shape[1] == self.config.num_codebooks
- codebook_logits = rearrange(
- codebook_logits,
- "(b s) n d -> b s n d",
- b=b,
- s=s,
- n=self.config.num_codebooks,
- )
-
- return TransformerForwardResult(
- token_logits=token_logits,
- codebook_logits=codebook_logits,
- )
-
- def forward_generate_fast(
- self, x: Tensor, input_pos: Optional[Tensor] = None
- ) -> Tensor:
- # Fast transformer
- x = x.view(1, 1, -1)
-
- fast_mask = self.causal_mask[
- None, None, input_pos, : self.config.num_codebooks
- ] # (B, N, Q, K)
- fast_freqs_cis = self.fast_freqs_cis[input_pos]
-
- for layer in self.fast_layers:
- x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
-
- # unflatten the batch and num_codebooks
- fast_out = self.fast_norm(x) # only take the last token
- codebook_logits = self.fast_output(fast_out)
-
- return codebook_logits
-
- def forward_generate(
- self,
- x: Tensor,
- input_pos: Optional[Tensor] = None,
- vq_masks: Optional[Tensor] = None,
- ) -> TransformerForwardResult:
- x = super().forward_generate(x, input_pos, vq_masks)
- x.hidden_states = self.fast_project_in(x.hidden_states)
- return x
-
-
-class TransformerBlock(nn.Module):
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
- super().__init__()
- self.attention = Attention(config, use_sdpa=use_sdpa)
- self.feed_forward = FeedForward(config)
- self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
- self.attention_norm = RMSNorm(config.dim, config.norm_eps)
-
- def forward(
- self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
- ) -> Tensor:
- h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
- out = h + self.feed_forward(self.ffn_norm(h))
- return out
-
-
-class Attention(nn.Module):
- def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
- super().__init__()
- assert config.dim % config.n_head == 0
-
- total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
- # key, query, value projections for all heads, but in a batch
- self.wqkv = nn.Linear(
- config.dim, total_head_dim, bias=config.attention_qkv_bias
- )
- self.wo = nn.Linear(config.dim, config.dim, bias=False)
- self.kv_cache = None
-
- self.dropout = config.dropout
- self.n_head = config.n_head
- self.head_dim = config.head_dim
- self.n_local_heads = config.n_local_heads
- self.dim = config.dim
- self.use_sdpa = use_sdpa
- self._register_load_state_dict_pre_hook(self.load_hook)
-
- def load_hook(self, state_dict, prefix, *args):
- if prefix + "wq.weight" in state_dict:
- wq = state_dict.pop(prefix + "wq.weight")
- wk = state_dict.pop(prefix + "wk.weight")
- wv = state_dict.pop(prefix + "wv.weight")
- state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
-
- def forward(
- self,
- x: Tensor,
- freqs_cis: Tensor,
- mask: Tensor,
- input_pos: Optional[Tensor] = None,
- ) -> Tensor:
- bsz, seqlen, _ = x.shape
-
- kv_size = self.n_local_heads * self.head_dim
- q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
-
- q = q.view(bsz, seqlen, self.n_head, self.head_dim)
- k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
- v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-
- q = apply_rotary_emb(q, freqs_cis)
- k = apply_rotary_emb(k, freqs_cis)
-
- q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
-
- if self.kv_cache is not None:
- k, v = self.kv_cache.update(input_pos, k, v)
-
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-
- if self.use_sdpa:
- if mask is None:
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
- y = F.scaled_dot_product_attention(
- q,
- k,
- v,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=True,
- # No third party attn_mask here to use flash_attention
- )
- else:
- y = F.scaled_dot_product_attention(
- q,
- k,
- v,
- attn_mask=mask,
- dropout_p=self.dropout if self.training else 0.0,
- )
- else:
- y = self.eq_scaled_dot_product_attention(
- q,
- k,
- v,
- attn_mask=mask,
- dropout_p=self.dropout if self.training else 0.0,
- )
-
- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
- return self.wo(y)
-
- def eq_scaled_dot_product_attention(
- self,
- query,
- key,
- value,
- attn_mask=None,
- dropout_p=0.0,
- ) -> torch.Tensor:
- # This is a standard scaled dot product attention
- # It's low efficient, but it doesn't raise cuda error
-
- L, S = query.size(-2), key.size(-2)
- scale_factor = 1 / math.sqrt(query.size(-1))
- attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
-
- if attn_mask is not None:
- if attn_mask.dtype == torch.bool:
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
- else:
- attn_bias += attn_mask
-
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
- attn_weight += attn_bias
- attn_weight = torch.softmax(attn_weight, dim=-1)
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-
- return attn_weight @ value
-
-
-class FeedForward(nn.Module):
- def __init__(self, config: BaseModelArgs) -> None:
- super().__init__()
- self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
- self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
- self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
-
- def forward(self, x: Tensor) -> Tensor:
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
-class RMSNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1e-5):
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(dim))
-
- def _norm(self, x):
- return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
-
- def forward(self, x: Tensor) -> Tensor:
- output = self._norm(x.float()).type_as(x)
- return output * self.weight
-
-
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
- freqs = 1.0 / (
- base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
- )
- t = torch.arange(seq_len, device=freqs.device)
- freqs = torch.outer(t, freqs)
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
- cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
- return cache.to(dtype=torch.bfloat16)
-
-
-def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
- xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
- freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
- x_out2 = torch.stack(
- [
- xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
- xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
- ],
- -1,
- )
-
- x_out2 = x_out2.flatten(3)
- return x_out2.type_as(x)
+import dataclasses
+import json
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
+from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class BaseModelArgs:
+ model_type: str = "base"
+
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ max_seq_len: int = 2048
+ dropout: float = 0.0
+ tie_word_embeddings: bool = True
+ attention_qkv_bias: bool = False
+ attention_o_bias: bool = False
+ attention_qk_norm: bool = False
+
+ # Codebook configs
+ codebook_size: int = 160
+ num_codebooks: int = 4
+
+ # Gradient checkpointing
+ use_gradient_checkpointing: bool = True
+
+ # Initialize the model
+ initializer_range: float = 0.02
+
+ # Dummy vars
+ is_reward_model: bool = False
+ scale_codebook_embeddings: bool = False
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ if self.head_dim is None:
+ self.head_dim = self.dim // self.n_head
+
+ @staticmethod
+ def from_pretrained(path: str):
+ path = Path(path)
+
+ if path.is_dir():
+ path = path / "config.json"
+
+ with open(path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ match data["model_type"]:
+ case "naive":
+ cls = NaiveModelArgs
+ case "dual_ar":
+ cls = DualARModelArgs
+ case _:
+ raise ValueError(f"Unknown model type: {data['model_type']}")
+
+ return cls(**data)
+
+ def save(self, path: str):
+ with open(path, "w") as f:
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+ model_type: str = "naive"
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+ model_type: str = "dual_ar"
+ n_fast_layer: int = 4
+ fast_dim: int | None = None
+ fast_n_head: int | None = None
+ fast_n_local_heads: int | None = None
+ fast_head_dim: int | None = None
+ fast_intermediate_size: int | None = None
+ fast_attention_qkv_bias: bool | None = None
+ fast_attention_qk_norm: bool | None = None
+ fast_attention_o_bias: bool | None = None
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ self.fast_dim = self.fast_dim or self.dim
+ self.fast_n_head = self.fast_n_head or self.n_head
+ self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+ self.fast_head_dim = self.fast_head_dim or self.head_dim
+ self.fast_intermediate_size = (
+ self.fast_intermediate_size or self.intermediate_size
+ )
+ self.fast_attention_qkv_bias = (
+ self.fast_attention_qkv_bias
+ if self.fast_attention_qkv_bias is not None
+ else self.attention_qkv_bias
+ )
+ self.fast_attention_qk_norm = (
+ self.fast_attention_qk_norm
+ if self.fast_attention_qk_norm is not None
+ else self.attention_qk_norm
+ )
+ self.fast_attention_o_bias = (
+ self.fast_attention_o_bias
+ if self.fast_attention_o_bias is not None
+ else self.attention_o_bias
+ )
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+ token_logits: Tensor
+ codebook_logits: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+ logits: Tensor
+ hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(
+ self,
+ config: BaseModelArgs,
+ tokenizer: FishTokenizer,
+ init_weights: bool = True,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.tokenizer = tokenizer
+ self.semantic_token_ids = list(tokenizer.semantic_id_to_token_id.values())
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size,
+ config.dim,
+ )
+ self.codebook_embeddings = nn.Embedding(
+ config.codebook_size * config.num_codebooks,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ if self.config.tie_word_embeddings is False:
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.head_dim,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ if init_weights:
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ self.config.head_dim,
+ dtype=dtype,
+ )
+
+ def embed(self, inp: Tensor) -> Tensor:
+ embeds = []
+ semantic_token_ids_tensor = torch.tensor(
+ self.semantic_token_ids, device=inp.device, dtype=inp.dtype
+ )
+
+ for i in range(self.config.num_codebooks):
+ emb = self.codebook_embeddings(
+ inp[:, i + 1] + i * self.config.codebook_size
+ )
+ embeds.append(emb)
+
+ vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
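+        # Zero the codebook contribution wherever the leading token is not a
+        # semantic token, so text-only positions use the token embedding alone.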
+ vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0
+ x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+ return x
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ seq_len = inp.size(2)
+
+ # Here we want to merge the embeddings of the codebooks
+ x = self.embed(inp)
+
+ freqs_cis = self.freqs_cis[:seq_len]
+
+        # Note that the causal mask here follows the definition of scaled_dot_product_attention:
+        # FALSE means masked out.
+        # For consistency, key_padding_mask uses TRUE to mask out.
+ mask = None
+ if key_padding_mask is not None:
+ causal = self.causal_mask[:seq_len, :seq_len]
+ causal = rearrange(causal, "q k -> 1 1 q k")
+
+ atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
+ atten_mask = atten_mask.logical_not()
+ mask = causal & atten_mask
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self,
+ inp: Tensor,
+ input_pos: Optional[Tensor] = None,
+ return_all: bool = False,
+ ) -> BaseTransformerForwardResult:
+ x = self.embed(inp)
+
+ if input_pos is None:
+ input_pos = torch.arange(inp.shape[-1], device=x.device)
+ max_seq_len = inp.shape[-1]
+ else:
+ max_seq_len = self.max_seq_len
+
+ mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+ # If prefill, we only calculate the logits of last token
+ if x.size(1) > 1 and not return_all:
+ x = x[:, -1:]
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.is_reward_model:
+ token_logits = self.score_output(slow_out)
+ elif self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @staticmethod
+ def from_pretrained(
+ path: str,
+ load_weights: bool = False,
+ max_length: int | None = None,
+ lora_config: LoraConfig | None = None,
+ rope_base: int | None = None,
+ ) -> "BaseTransformer":
+ config = BaseModelArgs.from_pretrained(str(path))
+ if max_length is not None:
+ config.max_seq_len = max_length
+ logger.info(f"Override max_seq_len to {max_length}")
+
+ if rope_base is not None:
+ config.rope_base = rope_base
+ logger.info(f"Override rope_base to {rope_base}")
+
+ match config.model_type:
+ case "naive":
+ model_cls = NaiveTransformer
+ case "dual_ar":
+ model_cls = DualARTransformer
+ case _:
+ raise ValueError(f"Unknown model type: {config.model_type}")
+
+ tokenizer = FishTokenizer.from_pretrained(path)
+
+ logger.info(f"Loading model from {path}, config: {config}")
+ model = model_cls(config, tokenizer=tokenizer)
+
+ if lora_config is not None:
+ setup_lora(model, lora_config)
+ logger.info(f"LoRA setup: {lora_config}")
+
+ if load_weights is False:
+ logger.info("Randomly initialized model")
+        else:
+ if "int8" in str(Path(path)):
+ logger.info("Using int8 weight-only quantization!")
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ model = simple_quantizer.convert_for_runtime()
+
+ if "int4" in str(Path(path)):
+ logger.info("Using int4 quantization!")
+ path_comps = path.name.split("-")
+ assert path_comps[-2].startswith("g")
+ groupsize = int(path_comps[-2][1:])
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ model = simple_quantizer.convert_for_runtime()
+
+ weights = torch.load(
+ Path(path) / "model.pth",
+ map_location="cpu",
+ mmap=True,
+ weights_only=True,
+ )
+
+ if "state_dict" in weights:
+ logger.warning(
+ "Using a TextToSemantic LightningModule checkpoint, "
+ "please make sure it is a full model, not a LoRA model."
+ )
+ weights = weights["state_dict"]
+
+ if next(iter(weights.keys())).startswith("model."):
+ logger.info(
+                    "Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
+ )
+ new_weights = OrderedDict()
+ for k, v in weights.items():
+ new_weights[k.replace("model.", "")] = v
+ weights = new_weights
+
+ # Remove audio related weights
+ for k in list(weights.keys()):
+ if "audio_" in k:
+ weights.pop(k)
+
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
+ for k, v in model.named_parameters():
+ if k not in weights:
+ logger.warning(f"No weight for {k}")
+ elif v.shape != weights[k].shape:
+ logger.warning(
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
+ )
+
+ err = model.load_state_dict(weights, strict=False, assign=True)
+ logger.info(f"Loaded weights with error: {err}")
+
+ return model
+
+ def save_pretrained(self, path: str, drop_lora: bool = False):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ self.config.save(path / "config.json")
+ state_dict = self.state_dict()
+
+ if drop_lora:
+ for key in list(state_dict.keys()):
+ if "lora" not in key:
+ continue
+
+ state_dict.pop(key)
+ logger.info(f"Drop LoRA parameter: {key}")
+
+ torch.save(state_dict, path / "model.pth")
+ self.tokenizer.save_pretrained(path)
+
+
+class NaiveTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.codebook_output = nn.Linear(
+ config.dim,
+ config.codebook_size * config.num_codebooks,
+ bias=False,
+ )
+
+ self.apply(self._init_weights)
+
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+ token_logits = result.logits
+ x = result.hidden_states
+
+ # Codebook
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
+ codebook_logits = rearrange(
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ result = super().forward(
+ inp=inp,
+ key_padding_mask=key_padding_mask,
+ )
+ return self.decode(result)
+
+ def forward_generate(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ result = super().forward_generate(x, input_pos)
+ return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ # Project to fast dim if needed
+ if config.fast_dim is not None and config.fast_dim != config.dim:
+ self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+ else:
+ self.fast_project_in = nn.Identity()
+
+ # Fast transformer
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
+
+        # The equivalent batch size is so large that SDPA doesn't work
+ override_config = dataclasses.replace(
+ config,
+ dim=config.fast_dim,
+ n_head=config.fast_n_head,
+ n_local_heads=config.fast_n_local_heads,
+ head_dim=config.fast_head_dim,
+ intermediate_size=config.fast_intermediate_size,
+ attention_qkv_bias=config.fast_attention_qkv_bias,
+ attention_qk_norm=config.fast_attention_qk_norm,
+ attention_o_bias=config.fast_attention_o_bias,
+ )
+
+ self.fast_layers = nn.ModuleList(
+ TransformerBlock(override_config, use_sdpa=False)
+ for _ in range(config.n_fast_layer)
+ )
+ self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
+ self.fast_output = nn.Linear(
+ config.fast_dim,
+ config.codebook_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "fast_freqs_cis",
+ precompute_freqs_cis(
+ config.num_codebooks,
+ config.fast_head_dim,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+ # Fast transformer
+ # The max seq len here is the number of codebooks
+ for b in self.fast_layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ self.config.num_codebooks,
+ self.config.fast_n_local_heads,
+ self.config.fast_head_dim,
+ dtype=dtype,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ parent_result = super().forward(inp, key_padding_mask)
+ token_logits = parent_result.logits
+ x = parent_result.hidden_states
+ x = self.fast_project_in(x)
+
+ # Fast transformer
+ fast_seq_len = self.config.num_codebooks
+ fast_mask = self.causal_mask[
+ None, None, :fast_seq_len, :fast_seq_len
+ ] # (B, N, Q, K)
+
+ # Drop the last token and rotate left
+ codebooks = inp[:, 1:-1, 1:]
+ codebooks = F.pad(codebooks, (0, 1), value=0)
+ codebook_embeddings = self.fast_embeddings(codebooks)
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
+ b, s = x.size(0), x.size(2)
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
+
+ # Remove padded part
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
+ codebook_mask = (codebooks == 0).all(dim=-1)
+
+ if torch.all(codebook_mask):
+            # If all codebooks are padded, keep the first 8 so the model still runs
+ codebook_mask[:8] = False
+
+ x_bs, x_len = x.size(0), x.size(1)
+ x = x[~codebook_mask]
+
+ for layer in self.fast_layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(
+ layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+ )
+ else:
+ x = layer(x, self.fast_freqs_cis, fast_mask)
+
+        # Normalize and project the fast hidden states to codebook logits
+ fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ # Re-pad the codebook_logits
+ buffer = torch.zeros(
+ x_bs,
+ x_len,
+ codebook_logits.size(-1),
+ device=codebook_logits.device,
+ dtype=codebook_logits.dtype,
+ )
+ buffer[~codebook_mask] = codebook_logits
+ codebook_logits = buffer
+
+ assert codebook_logits.shape[1] == self.config.num_codebooks
+ codebook_logits = rearrange(
+ codebook_logits,
+ "(b s) n d -> b s n d",
+ b=b,
+ s=s,
+ n=self.config.num_codebooks,
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward_generate_fast(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> Tensor:
+ # Fast transformer
+ x = x.view(1, 1, -1)
+
+ fast_mask = self.causal_mask[
+ None, None, input_pos, : self.config.num_codebooks
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.fast_freqs_cis[input_pos]
+
+ for layer in self.fast_layers:
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
+
+        # Normalize and project to codebook logits for the current codebook step
+        fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ return codebook_logits
+
+ def forward_generate(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ vq_masks: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ x = super().forward_generate(x, input_pos, vq_masks)
+ x.hidden_states = self.fast_project_in(x.hidden_states)
+ return x
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+ super().__init__()
+ self.attention = Attention(config, use_sdpa=use_sdpa)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
+ )
+ self.wo = nn.Linear(
+ config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
+ )
+ self.kv_cache = None
+
+ if config.attention_qk_norm:
+ self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
+ self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
+
+ self.dropout = config.dropout
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.use_sdpa = use_sdpa
+ self.attention_qk_norm = config.attention_qk_norm
+ self.config = config
+
+ self._register_load_state_dict_pre_hook(self.load_hook)
+
+ def load_hook(self, state_dict, prefix, *args):
+ if prefix + "wq.weight" in state_dict:
+ wq = state_dict.pop(prefix + "wq.weight")
+ wk = state_dict.pop(prefix + "wk.weight")
+ wv = state_dict.pop(prefix + "wv.weight")
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ q_size = self.n_head * self.head_dim
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ if self.attention_qk_norm:
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.use_sdpa:
+ if mask is None:
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=True,
+                        # No attn_mask here so that the flash-attention kernel can be used
+ )
+ else:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+ else:
+ y = self.eq_scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
+
+ return self.wo(y)
+
+ def eq_scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ ) -> torch.Tensor:
+        # This is a standard scaled dot-product attention implementation.
+        # It is less efficient, but it doesn't raise CUDA errors.
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+ """
+ Precomputes frequency tensors for complex exponentials (cis)
+
+ Args:
+ seq_len: Length of the sequence for which positional embeddings are needed.
+ n_elem: Number of elements in the frequency tensor.
+ base: Base value for the frequency scaling (default: 10000).
+
+ Returns:
+ A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
+ """
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
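
As a quick sanity check on the RoPE helpers above, a minimal sketch (assuming the module is importable as fish_speech.models.text2semantic.llama): precompute_freqs_cis is built once per maximum sequence length and indexed by position, and apply_rotary_emb expects tensors laid out as (batch, seq, heads, head_dim), i.e. before the transpose to (batch, heads, seq, head_dim) in Attention.forward.

    import torch

    from fish_speech.models.text2semantic.llama import (
        apply_rotary_emb,
        precompute_freqs_cis,
    )

    seq_len, n_head, head_dim = 16, 4, 64
    freqs_cis = precompute_freqs_cis(seq_len, head_dim)  # (seq_len, head_dim // 2, 2), bfloat16
    q = torch.randn(1, seq_len, n_head, head_dim)        # (B, S, H, D), pre-transpose layout
    q_rot = apply_rotary_emb(q, freqs_cis[:seq_len])
    assert q_rot.shape == q.shape                        # the rotation preserves the shape
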
diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py
index bb4a6192c469bce1535b5f93e147f89ce05cca04..647ca6fcccf038e17d2cf91a2874281dff3e0938 100644
--- a/fish_speech/models/text2semantic/lora.py
+++ b/fish_speech/models/text2semantic/lora.py
@@ -1,92 +1,92 @@
-from dataclasses import dataclass
-
-import loralib as lora
-
-
-@dataclass
-class LoraConfig:
- r: int
- lora_alpha: float
- lora_dropout: float = 0.0
-
-
-def setup_lora(model, lora_config):
- # Replace the embedding layer with a LoRA layer
- model.embeddings = lora.Embedding(
- num_embeddings=model.embeddings.num_embeddings,
- embedding_dim=model.embeddings.embedding_dim,
- padding_idx=model.embeddings.padding_idx,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- )
-
- model.codebook_embeddings = lora.Embedding(
- num_embeddings=model.codebook_embeddings.num_embeddings,
- embedding_dim=model.codebook_embeddings.embedding_dim,
- padding_idx=model.codebook_embeddings.padding_idx,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- )
-
- # Replace output layer with a LoRA layer
- linears = [(model, "output")]
-
- # Replace all linear layers with LoRA layers
- for layer in model.layers:
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
- linears.extend(
- [
- (layer.feed_forward, "w1"),
- (layer.feed_forward, "w2"),
- (layer.feed_forward, "w3"),
- ]
- )
-
- if hasattr(model, "fast_layers"):
- model.fast_embeddings = lora.Embedding(
- num_embeddings=model.fast_embeddings.num_embeddings,
- embedding_dim=model.fast_embeddings.embedding_dim,
- padding_idx=model.fast_embeddings.padding_idx,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- )
-
- # Dual-AR model
- linears.append((model, "fast_output"))
-
- for layer in model.fast_layers:
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
- linears.extend(
- [
- (layer.feed_forward, "w1"),
- (layer.feed_forward, "w2"),
- (layer.feed_forward, "w3"),
- ]
- )
-
- for module, layer in linears:
- updated_linear = lora.Linear(
- in_features=getattr(module, layer).in_features,
- out_features=getattr(module, layer).out_features,
- bias=getattr(module, layer).bias,
- r=lora_config.r,
- lora_alpha=lora_config.lora_alpha,
- lora_dropout=lora_config.lora_dropout,
- )
- setattr(module, layer, updated_linear)
-
- # Mark only the LoRA layers as trainable
- lora.mark_only_lora_as_trainable(model, bias="none")
-
-
-def get_merged_state_dict(model):
- # This line will merge the state dict of the model and the LoRA parameters
- model.eval()
-
- # Then we need to remove the LoRA parameters from the state dict
- state_dict = model.state_dict()
- for name in list(state_dict.keys()):
- if "lora" in name:
- state_dict.pop(name)
-
- return state_dict
+from dataclasses import dataclass
+
+import loralib as lora
+
+
+@dataclass
+class LoraConfig:
+ r: int
+ lora_alpha: float
+ lora_dropout: float = 0.0
+
+
+def setup_lora(model, lora_config):
+ # Replace the embedding layer with a LoRA layer
+ model.embeddings = lora.Embedding(
+ num_embeddings=model.embeddings.num_embeddings,
+ embedding_dim=model.embeddings.embedding_dim,
+ padding_idx=model.embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ model.codebook_embeddings = lora.Embedding(
+ num_embeddings=model.codebook_embeddings.num_embeddings,
+ embedding_dim=model.codebook_embeddings.embedding_dim,
+ padding_idx=model.codebook_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Replace output layer with a LoRA layer
+ linears = [(model, "output")]
+
+ # Replace all linear layers with LoRA layers
+ for layer in model.layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ if hasattr(model, "fast_layers"):
+ model.fast_embeddings = lora.Embedding(
+ num_embeddings=model.fast_embeddings.num_embeddings,
+ embedding_dim=model.fast_embeddings.embedding_dim,
+ padding_idx=model.fast_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Dual-AR model
+ linears.append((model, "fast_output"))
+
+ for layer in model.fast_layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ for module, layer in linears:
+ updated_linear = lora.Linear(
+ in_features=getattr(module, layer).in_features,
+ out_features=getattr(module, layer).out_features,
+ bias=getattr(module, layer).bias,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ lora_dropout=lora_config.lora_dropout,
+ )
+ setattr(module, layer, updated_linear)
+
+ # Mark only the LoRA layers as trainable
+ lora.mark_only_lora_as_trainable(model, bias="none")
+
+
+def get_merged_state_dict(model):
+ # Switching to eval() makes loralib merge the LoRA weights into the base weights
+ model.eval()
+
+ # Then we need to remove the LoRA parameters from the state dict
+ state_dict = model.state_dict()
+ for name in list(state_dict.keys()):
+ if "lora" in name:
+ state_dict.pop(name)
+
+ return state_dict
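A sketch of how these helpers fit together, assuming model is an already-constructed text2semantic model exposing the attributes setup_lora touches (embeddings, layers, output, and optionally the fast_* counterparts); the checkpoint path is hypothetical.

    import torch

    from fish_speech.models.text2semantic.lora import (
        LoraConfig,
        get_merged_state_dict,
        setup_lora,
    )

    # model = ...  # load the text2semantic model here (assumed to exist already)

    lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.01)
    setup_lora(model, lora_config)  # swap embeddings/linears for loralib layers, freeze the rest

    # ... fine-tune; only the LoRA parameters receive gradients ...

    # eval() folds the LoRA deltas into the base weights, then the lora_* keys are dropped.
    merged = get_merged_state_dict(model)
    torch.save(merged, "checkpoints/lora-merged.ckpt")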
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
index c35b4e0dbd174b36229350f27a21b5acf0e9825b..d740bd8eed447d162e55b165965dec17130377ce 100644
--- a/fish_speech/text/__init__.py
+++ b/fish_speech/text/__init__.py
@@ -1,4 +1,4 @@
-from .clean import clean_text
-from .spliter import split_text
-
-__all__ = ["clean_text", "split_text"]
+from .clean import clean_text
+from .spliter import split_text
+
+__all__ = ["clean_text", "split_text"]
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
index 2aba28fc1bc7fc6054e37534ece06c743bff9f6c..68428c406c018a5bb156908b80341429a78c0301 100644
--- a/fish_speech/text/clean.py
+++ b/fish_speech/text/clean.py
@@ -1,37 +1,37 @@
-import re
-
-SYMBOLS_MAPPING = {
- "‘": "'",
- "’": "'",
-}
-
-REPLACE_SYMBOL_REGEX = re.compile(
- "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
-)
-
-
-EMOJI_REGEX = re.compile(
- "["
- "\U0001F600-\U0001F64F" # emoticons
- "\U0001F300-\U0001F5FF" # symbols & pictographs
- "\U0001F680-\U0001F6FF" # transport & map symbols
- "\U0001F1E0-\U0001F1FF" # flags (iOS)
- "]+",
- flags=re.UNICODE,
-)
-
-
-def clean_text(text):
- # Clean the text
- text = text.strip()
-
- # Replace all chinese symbols with their english counterparts
- text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
-
- # Remove emojis
- text = EMOJI_REGEX.sub(r"", text)
-
- # Remove continuous periods (...) and commas (,,,)
- text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
-
- return text
+import re
+
+SYMBOLS_MAPPING = {
+ "‘": "'",
+ "’": "'",
+}
+
+REPLACE_SYMBOL_REGEX = re.compile(
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+
+
+EMOJI_REGEX = re.compile(
+ "["
+ "\U0001f600-\U0001f64f" # emoticons
+ "\U0001f300-\U0001f5ff" # symbols & pictographs
+ "\U0001f680-\U0001f6ff" # transport & map symbols
+ "\U0001f1e0-\U0001f1ff" # flags (iOS)
+ "]+",
+ flags=re.UNICODE,
+)
+
+
+def clean_text(text):
+ # Clean the text
+ text = text.strip()
+
+ # Replace mapped symbols (curly quotes) with their ASCII counterparts
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+
+ # Remove emojis
+ text = EMOJI_REGEX.sub(r"", text)
+
+ # Collapse runs of commas (,,,) into a single comma
+ text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
+
+ return text
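A quick behavioural check of clean_text; the expected output follows directly from the rules above.

    from fish_speech.text.clean import clean_text

    # Curly quotes map to ASCII, emoji are stripped, and runs of commas collapse to one.
    print(clean_text("‘Hello’ 🚀,,, world"))
    # -> 'Hello' , world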
diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py
index 30661e4ef3796250e539aa367467bac22ecbbfb8..df079addb81cd91145f0c68f70b0da0d7251f036 100644
--- a/fish_speech/text/spliter.py
+++ b/fish_speech/text/spliter.py
@@ -1,130 +1,130 @@
-import re
-import string
-
-from fish_speech.text.clean import clean_text
-
-
-def utf_8_len(text: str):
- return len(text.encode("utf-8"))
-
-
-def break_text(texts, length, splits: set):
- for text in texts:
- if utf_8_len(text) <= length:
- yield text
- continue
-
- curr = ""
- for char in text:
- curr += char
-
- if char in splits:
- yield curr
- curr = ""
-
- if curr:
- yield curr
-
-
-def break_text_by_length(texts, length):
- for text in texts:
- if utf_8_len(text) <= length:
- yield text
- continue
-
- curr = ""
- for char in text:
- curr += char
-
- if utf_8_len(curr) >= length:
- yield curr
- curr = ""
-
- if curr:
- yield curr
-
-
-def add_cleaned(curr, segments):
- curr = curr.strip()
- if curr and not all(c.isspace() or c in string.punctuation for c in curr):
- segments.append(curr)
-
-
-def protect_float(text):
- # Turns 3.14 into <3_f_14> to prevent splitting
- return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
-
-
-def unprotect_float(text):
- # Turns <3_f_14> into 3.14
- return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
-
-
-def split_text(text, length):
- text = clean_text(text)
-
- # Break the text into pieces with following rules:
- # 1. Split the text at ".", "!", "?" if text is NOT a float
- # 2. If the text is longer than length, split at ","
- # 3. If the text is still longer than length, split at " "
- # 4. If the text is still longer than length, split at any character to length
-
- texts = [text]
- texts = map(protect_float, texts)
- texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
- texts = map(unprotect_float, texts)
- texts = break_text(texts, length, {",", ","})
- texts = break_text(texts, length, {" "})
- texts = list(break_text_by_length(texts, length))
-
- # Then, merge the texts into segments with length <= length
- segments = []
- curr = ""
-
- for text in texts:
- if utf_8_len(curr) + utf_8_len(text) <= length:
- curr += text
- else:
- add_cleaned(curr, segments)
- curr = text
-
- if curr:
- add_cleaned(curr, segments)
-
- return segments
-
-
-if __name__ == "__main__":
- # Test the split_text function
-
- text = "This is a test sentence. This is another test sentence. And a third one."
-
- assert split_text(text, 50) == [
- "This is a test sentence.",
- "This is another test sentence. And a third one.",
- ]
- assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
- assert split_text(" ", 10) == []
- assert split_text("a", 10) == ["a"]
-
- text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
- assert split_text(text, 50) == [
- "This is a test sentence with only commas,",
- "and no dots, and no exclamation marks,",
- "and no question marks, and no newlines.",
- ]
-
- text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
- # First half split at " ", second half split at ","
- assert split_text(text, 50) == [
- "This is a test sentence This is a test sentence",
- "This is a test sentence. This is a test sentence,",
- "This is a test sentence, This is a test sentence.",
- ]
-
- text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
- assert split_text(text, 50) == [
- "这是一段很长的中文文本,",
- "而且没有句号,也没有感叹号,",
- "也没有问号,也没有换行符.",
- ]
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
+def utf_8_len(text: str):
+ return len(text.encode("utf-8"))
+
+
+def break_text(texts, length, splits: set):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if char in splits:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def break_text_by_length(texts, length):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if utf_8_len(curr) >= length:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def add_cleaned(curr, segments):
+ curr = curr.strip()
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+ segments.append(curr)
+
+
+def protect_float(text):
+ # Turns 3.14 into <3_f_14> to prevent splitting
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+def unprotect_float(text):
+ # Turns <3_f_14> into 3.14
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+def split_text(text, length):
+ text = clean_text(text)
+
+ # Break the text into pieces with following rules:
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
+ # 2. If the text is longer than length, split at ","
+ # 3. If the text is still longer than length, split at " "
+ # 4. If the text is still longer than length, split at any character to length
+
+ texts = [text]
+ texts = map(protect_float, texts)
+ texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
+ texts = map(unprotect_float, texts)
+ texts = break_text(texts, length, {",", ","})
+ texts = break_text(texts, length, {" "})
+ texts = list(break_text_by_length(texts, length))
+
+ # Then, merge the texts into segments with length <= length
+ segments = []
+ curr = ""
+
+ for text in texts:
+ if utf_8_len(curr) + utf_8_len(text) <= length:
+ curr += text
+ else:
+ add_cleaned(curr, segments)
+ curr = text
+
+ if curr:
+ add_cleaned(curr, segments)
+
+ return segments
+
+
+if __name__ == "__main__":
+ # Test the split_text function
+
+ text = "This is a test sentence. This is another test sentence. And a third one."
+
+ assert split_text(text, 50) == [
+ "This is a test sentence.",
+ "This is another test sentence. And a third one.",
+ ]
+ assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+ assert split_text(" ", 10) == []
+ assert split_text("a", 10) == ["a"]
+
+ text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+ assert split_text(text, 50) == [
+ "This is a test sentence with only commas,",
+ "and no dots, and no exclamation marks,",
+ "and no question marks, and no newlines.",
+ ]
+
+ text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+ # First half split at " ", second half split at ","
+ assert split_text(text, 50) == [
+ "This is a test sentence This is a test sentence",
+ "This is a test sentence. This is a test sentence,",
+ "This is a test sentence, This is a test sentence.",
+ ]
+
+ text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+ assert split_text(text, 50) == [
+ "这是一段很长的中文文本,",
+ "而且没有句号,也没有感叹号,",
+ "也没有问号,也没有换行符.",
+ ]
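One detail worth noting: length is measured in UTF-8 bytes (see utf_8_len), so CJK text packs roughly three bytes per character. A short sketch:

    from fish_speech.text.spliter import split_text, utf_8_len

    print(utf_8_len("你好"))  # 6 bytes, not 2 characters

    # With a 30-byte budget, sentences are cut at 。 and then re-merged greedily.
    print(split_text("你好。世界。" * 6, 30))
    # -> four segments of nine characters (27 bytes) each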
diff --git a/fish_speech/tokenizer.py b/fish_speech/tokenizer.py
index f4d512d31263dcb2abc95c3a7bf3cd4bde8c4830..9a140f9db98269b9abac8f36d5b613500f2cb881 100644
--- a/fish_speech/tokenizer.py
+++ b/fish_speech/tokenizer.py
@@ -1,152 +1,179 @@
-import base64
-import json
-import logging
-from pathlib import Path
-
-import tiktoken
-
-logger = logging.getLogger(__name__)
-
-# This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
-FISH_TIKTOKEN_PATTERN = "|".join(
- [
- r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
- r"\p{P}",
- r"[^\r\n\p{L}\p{N}]?\p{L}+",
- r"\p{N}",
- r" ?[^\s\p{L}\p{N}]+[\r\n]*",
- r"\s*[\r\n]+",
- r"\s+(\?!\S)",
- r"\s+",
- ]
-)
-TIKTOKEN_MAX_ENCODE_CHARS = 400_000
-
-BOS_TOKEN = "<|begin_of_text|>"
-EOS_TOKEN = "<|end_of_text|>"
-PAD_TOKEN = "<|pad|>"
-IM_START_TOKEN = "<|im_start|>"
-IM_END_TOKEN = "<|im_end|>"
-
-MODALITY_TEXT_TOKEN = "<|text|>"
-MODALITY_VOICE_TOKEN = "<|voice|>"
-MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
-MODALITY_TOKENS = {
- "text": MODALITY_TEXT_TOKEN,
- "voice": MODALITY_VOICE_TOKEN,
- "interleave": MODALITY_INTERLEAVE_TOKEN,
-}
-
-PLACEHOLDER_TOKEN = [""] * 4
-for i in range(4):
- PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
-
-SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
-SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
-
-# Warning: when you add a new special token, you should only add it to the end of the list.
-ALL_SPECIAL_TOKENS = [
- BOS_TOKEN,
- EOS_TOKEN,
- PAD_TOKEN,
- IM_START_TOKEN,
- IM_END_TOKEN,
- PLACEHOLDER_TOKEN[0],
- PLACEHOLDER_TOKEN[1],
- PLACEHOLDER_TOKEN[2],
- PLACEHOLDER_TOKEN[3],
- MODALITY_TEXT_TOKEN,
- MODALITY_VOICE_TOKEN,
- MODALITY_INTERLEAVE_TOKEN,
- *SEMANTIC_TOKENS,
-]
-
-
-class FishTokenizer:
- def __init__(self, model_path: str) -> None:
- mergeable_ranks = self.load_tiktoken_bpe(model_path)
- special_token_begin = len(mergeable_ranks)
- self.all_special_tokens_with_ids = {
- token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
- }
- self.semantic_id_to_token_id = {
- i: self.all_special_tokens_with_ids[token]
- for i, token in enumerate(SEMANTIC_TOKENS)
- }
- self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
- self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
-
- self.tkt_model = tiktoken.core.Encoding(
- name=Path(model_path).stem,
- pat_str=FISH_TIKTOKEN_PATTERN,
- mergeable_ranks=mergeable_ranks,
- special_tokens=self.all_special_tokens_with_ids,
- )
-
- @staticmethod
- def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
- data = {}
- for line in open(tiktoken_bpe_file).read().splitlines():
- if not line:
- continue
- token, rank = line.split()
- data[base64.b64decode(token)] = int(rank)
- return data
-
- def get_token_id(self, token: str) -> int:
- return self.all_special_tokens_with_ids[token]
-
- def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
- assert isinstance(s, str)
-
- subs = []
- for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
- subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
-
- if allowed_special is True:
- allowed_special = self.tkt_model.special_tokens_set
- elif allowed_special is False:
- allowed_special = set()
-
- return sum(
- self.tkt_model.encode_batch(
- subs, allowed_special=allowed_special, disallowed_special=set()
- ),
- start=[],
- )
-
- def decode(self, tokens: list[int]) -> str:
- return self.tkt_model.decode(tokens)
-
- def save_pretrained(self, path: str):
- path = Path(path)
- path.mkdir(parents=True, exist_ok=True)
-
- with open(path / "tokenizer.tiktoken", "w") as f:
- for token, rank in self.tkt_model._mergeable_ranks.items():
- f.write(f"{base64.b64encode(token).decode()} {rank}\n")
-
- with open(path / "special_tokens.json", "w") as f:
- json.dump(
- self.all_special_tokens_with_ids,
- f,
- indent=2,
- ensure_ascii=False,
- )
-
- @staticmethod
- def from_pretrained(path: str):
- return FishTokenizer(Path(path) / "tokenizer.tiktoken")
-
-
-if __name__ == "__main__":
- tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
- tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
- tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
-
- print(
- [
- tokenizer.decode([i])
- for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
- ]
- )
+import base64
+import json
+import logging
+import re
+from pathlib import Path
+
+import tiktoken
+
+logger = logging.getLogger(__name__)
+
+# This is a modified version of the default pattern from GPT-4o that better handles punctuation.
+FISH_TIKTOKEN_PATTERN = "|".join(
+ [
+ r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+ r"\p{P}",
+ r"[^\r\n\p{L}\p{N}]?\p{L}+",
+ r"\p{N}",
+ r" ?[^\s\p{L}\p{N}]+[\r\n]*",
+ r"\s*[\r\n]+",
+ r"\s+(\?!\S)",
+ r"\s+",
+ ]
+)
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+BOS_TOKEN = "<|begin_of_text|>"
+EOS_TOKEN = "<|end_of_text|>"
+PAD_TOKEN = "<|pad|>"
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+TOOL_CALL_START_TOKEN = "<|tool_call_start|>"
+TOOL_CALL_END_TOKEN = "<|tool_call_end|>"
+
+MODALITY_TEXT_TOKEN = "<|text|>"
+MODALITY_VOICE_TOKEN = "<|voice|>"
+MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
+AUDIO_START_TOKEN = "<|audio_start|>"
+AUDIO_END_TOKEN = "<|audio_end|>"
+AUDIO_EMBED_TOKEN = "<|audio|>"
+MODALITY_TOKENS = {
+ "text": MODALITY_TEXT_TOKEN,
+ "voice": MODALITY_VOICE_TOKEN,
+ "interleave": MODALITY_INTERLEAVE_TOKEN,
+}
+
+SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
+SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
+
+# Warning: when you add a new special token, you should only add it to the end of the list.
+ALL_SPECIAL_TOKENS = [
+ BOS_TOKEN,
+ EOS_TOKEN,
+ PAD_TOKEN,
+ IM_START_TOKEN,
+ IM_END_TOKEN,
+ PHONEME_START_TOKEN,
+ PHONEME_END_TOKEN,
+ TOOL_CALL_START_TOKEN,
+ TOOL_CALL_END_TOKEN,
+ MODALITY_TEXT_TOKEN,
+ MODALITY_VOICE_TOKEN,
+ MODALITY_INTERLEAVE_TOKEN,
+ AUDIO_START_TOKEN,
+ AUDIO_END_TOKEN,
+ AUDIO_EMBED_TOKEN,
+ *SEMANTIC_TOKENS,
+]
+
+
+class FishTokenizer:
+ def __init__(
+ self, model_path: str, special_tokens: list[str] = ALL_SPECIAL_TOKENS
+ ) -> None:
+ mergeable_ranks = self.load_tiktoken_bpe(model_path)
+ special_token_begin = len(mergeable_ranks)
+ self.all_special_tokens_with_ids = {
+ token: special_token_begin + i for i, token in enumerate(special_tokens)
+ }
+
+ self.semantic_id_to_token_id = {}
+ end_idx = 0
+ for token in special_tokens:
+ if token.startswith("<|semantic:"):
+ idx = int(re.match(r"<\|semantic:(\d+)\|>", token).group(1))
+ self.semantic_id_to_token_id[idx] = self.all_special_tokens_with_ids[
+ token
+ ]
+
+ if idx > end_idx:
+ end_idx = idx
+
+ self.semantic_begin_id = self.semantic_id_to_token_id[0]
+ self.semantic_end_id = self.semantic_id_to_token_id[end_idx]
+
+ self.tkt_model = tiktoken.core.Encoding(
+ name=Path(model_path).stem,
+ pat_str=FISH_TIKTOKEN_PATTERN,
+ mergeable_ranks=mergeable_ranks,
+ special_tokens=self.all_special_tokens_with_ids,
+ )
+
+ @property
+ def vocab_size(self):
+ return len(self.tkt_model._mergeable_ranks)
+
+ @property
+ def num_special_tokens(self):
+ return len(self.all_special_tokens_with_ids)
+
+ @staticmethod
+ def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+ data = {}
+ for line in open(tiktoken_bpe_file).read().splitlines():
+ if not line:
+ continue
+ token, rank = line.split()
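+ # "=" is the sentinel save_pretrained writes for an empty token; skip it rather than base64-decoding it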
+ if token == "=":
+ continue
+ data[base64.b64decode(token)] = int(rank)
+ return data
+
+ def get_token_id(self, token: str) -> int:
+ return self.all_special_tokens_with_ids[token]
+
+ def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
+ assert isinstance(s, str)
+
+ subs = []
+ for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
+ subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
+
+ if allowed_special is True:
+ allowed_special = self.tkt_model.special_tokens_set
+ elif allowed_special is False:
+ allowed_special = set()
+
+ return sum(
+ self.tkt_model.encode_batch(
+ subs, allowed_special=allowed_special, disallowed_special=set()
+ ),
+ start=[],
+ )
+
+ def decode(self, tokens: list[int]) -> str:
+ return self.tkt_model.decode(tokens)
+
+ def save_pretrained(self, path: str):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ with open(path / "tokenizer.tiktoken", "w") as f:
+ for token, rank in self.tkt_model._mergeable_ranks.items():
+ a = base64.b64encode(token).decode()
+ if a == "":
+ a = "="
+ f.write(f"{a} {rank}\n")
+
+ with open(path / "special_tokens.json", "w") as f:
+ json.dump(
+ self.all_special_tokens_with_ids,
+ f,
+ indent=2,
+ ensure_ascii=False,
+ )
+
+ @staticmethod
+ def from_pretrained(path: str):
+ special_tokens_path = Path(path) / "special_tokens.json"
+ if special_tokens_path.exists():
+ with open(special_tokens_path) as f:
+ all_special_tokens_with_ids = json.load(f)
+ else:
+ all_special_tokens_with_ids = ALL_SPECIAL_TOKENS
+
+ return FishTokenizer(
+ Path(path) / "tokenizer.tiktoken", all_special_tokens_with_ids
+ )
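A usage sketch, assuming a checkpoint directory containing tokenizer.tiktoken (and, if present, special_tokens.json); the path mirrors the snapshot downloaded in app.py but is otherwise an assumption.

    from fish_speech.tokenizer import MODALITY_TOKENS, FishTokenizer

    tokenizer = FishTokenizer.from_pretrained("checkpoints/openaudio-s1-mini")

    ids = tokenizer.encode(f"{MODALITY_TOKENS['voice']}Hello world")
    print(ids)
    print(tokenizer.decode(ids))

    # Semantic codebook ids occupy a contiguous block of special-token ids.
    print(tokenizer.semantic_begin_id, tokenizer.semantic_end_id)
    print(tokenizer.get_token_id("<|semantic:0|>") == tokenizer.semantic_begin_id)  # True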
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
index 185517110af6cacfbda3c3d9561e6081d825fca4..53cf2f23174ddac9bf523730aca2f6a9965d134a 100644
--- a/fish_speech/utils/__init__.py
+++ b/fish_speech/utils/__init__.py
@@ -1,24 +1,24 @@
-from .braceexpand import braceexpand
-from .context import autocast_exclude_mps
-from .file import get_latest_checkpoint
-from .instantiators import instantiate_callbacks, instantiate_loggers
-from .logger import RankedLogger
-from .logging_utils import log_hyperparameters
-from .rich_utils import enforce_tags, print_config_tree
-from .utils import extras, get_metric_value, set_seed, task_wrapper
-
-__all__ = [
- "enforce_tags",
- "extras",
- "get_metric_value",
- "RankedLogger",
- "instantiate_callbacks",
- "instantiate_loggers",
- "log_hyperparameters",
- "print_config_tree",
- "task_wrapper",
- "braceexpand",
- "get_latest_checkpoint",
- "autocast_exclude_mps",
- "set_seed",
-]
+from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, set_seed, task_wrapper
+
+__all__ = [
+ "enforce_tags",
+ "extras",
+ "get_metric_value",
+ "RankedLogger",
+ "instantiate_callbacks",
+ "instantiate_loggers",
+ "log_hyperparameters",
+ "print_config_tree",
+ "task_wrapper",
+ "braceexpand",
+ "get_latest_checkpoint",
+ "autocast_exclude_mps",
+ "set_seed",
+]
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
index 8888977ce194fc5caa9e85bcf548e3bc42a3c52c..f3ac739f01f7e10e039c68c1157d6c761064f974 100644
--- a/fish_speech/utils/braceexpand.py
+++ b/fish_speech/utils/braceexpand.py
@@ -1,217 +1,217 @@
-"""
-Bash-style brace expansion
-Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
-License: MIT
-"""
-
-import re
-import string
-from itertools import chain, product
-from typing import Iterable, Iterator, Optional
-
-__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
-
-
-class UnbalancedBracesError(ValueError):
- pass
-
-
-alphabet = string.ascii_uppercase + string.ascii_lowercase
-
-int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
-char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
-escape_re = re.compile(r"\\(.)")
-
-
-def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
- """braceexpand(pattern) -> iterator over generated strings
-
- Returns an iterator over the strings resulting from brace expansion
- of pattern. This function implements Brace Expansion as described in
- bash(1), with the following limitations:
-
- * A pattern containing unbalanced braces will raise an
- UnbalancedBracesError exception. In bash, unbalanced braces will either
- be partly expanded or ignored.
-
- * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
- include the characters '[]^_`' between 'Z' and 'a'.
-
- When escape is True (the default), characters in pattern can be
- prefixed with a backslash to cause them not to be interpreted as
- special characters for brace expansion (such as '{', '}', ',').
- To pass through a a literal backslash, double it ('\\\\').
-
- When escape is False, backslashes in pattern have no special
- meaning and will be preserved in the output.
-
- Examples:
-
- >>> from braceexpand import braceexpand
-
- # Integer range
- >>> list(braceexpand('item{1..3}'))
- ['item1', 'item2', 'item3']
-
- # Character range
- >>> list(braceexpand('{a..c}'))
- ['a', 'b', 'c']
-
- # Sequence
- >>> list(braceexpand('index.html{,.backup}'))
- ['index.html', 'index.html.backup']
-
- # Nested patterns
- >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
- ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
-
- # Prefixing an integer with zero causes all numbers to be padded to
- # the same width.
- >>> list(braceexpand('{07..10}'))
- ['07', '08', '09', '10']
-
- # An optional increment can be specified for ranges.
- >>> list(braceexpand('{a..g..2}'))
- ['a', 'c', 'e', 'g']
-
- # Ranges can go in both directions.
- >>> list(braceexpand('{4..1}'))
- ['4', '3', '2', '1']
-
- # Numbers can be negative
- >>> list(braceexpand('{2..-1}'))
- ['2', '1', '0', '-1']
-
- # Unbalanced braces raise an exception.
- >>> list(braceexpand('{1{2,3}'))
- Traceback (most recent call last):
- ...
- UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
-
- # By default, the backslash is the escape character.
- >>> list(braceexpand(r'{1\\{2,3}'))
- ['1{2', '3']
-
- # Setting 'escape' to False disables backslash escaping.
- >>> list(braceexpand(r'\\{1,2}', escape=False))
- ['\\\\1', '\\\\2']
-
- """
- return (
- escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
- )
-
-
-def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
- start = 0
- pos = 0
- bracketdepth = 0
- items: list[Iterable[str]] = []
-
- # print 'pattern:', pattern
- while pos < len(pattern):
- if escape and pattern[pos] == "\\":
- pos += 2
- continue
- elif pattern[pos] == "{":
- if bracketdepth == 0 and pos > start:
- # print 'literal:', pattern[start:pos]
- items.append([pattern[start:pos]])
- start = pos
- bracketdepth += 1
- elif pattern[pos] == "}":
- bracketdepth -= 1
- if bracketdepth == 0:
- # print 'expression:', pattern[start+1:pos]
- expr = pattern[start + 1 : pos]
- item = parse_expression(expr, escape)
- if item is None: # not a range or sequence
- items.extend([["{"], parse_pattern(expr, escape), ["}"]])
- else:
- items.append(item)
- start = pos + 1 # skip the closing brace
- pos += 1
-
- if bracketdepth != 0: # unbalanced braces
- raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
-
- if start < pos:
- items.append([pattern[start:]])
-
- return ("".join(item) for item in product(*items))
-
-
-def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
- int_range_match = int_range_re.match(expr)
- if int_range_match:
- return make_int_range(*int_range_match.groups())
-
- char_range_match = char_range_re.match(expr)
- if char_range_match:
- return make_char_range(*char_range_match.groups())
-
- return parse_sequence(expr, escape)
-
-
-def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
- # sequence -> chain(*sequence_items)
- start = 0
- pos = 0
- bracketdepth = 0
- items: list[Iterable[str]] = []
-
- # print 'sequence:', seq
- while pos < len(seq):
- if escape and seq[pos] == "\\":
- pos += 2
- continue
- elif seq[pos] == "{":
- bracketdepth += 1
- elif seq[pos] == "}":
- bracketdepth -= 1
- elif seq[pos] == "," and bracketdepth == 0:
- items.append(parse_pattern(seq[start:pos], escape))
- start = pos + 1 # skip the comma
- pos += 1
-
- if bracketdepth != 0:
- raise UnbalancedBracesError
- if not items:
- return None
-
- # part after the last comma (may be the empty string)
- items.append(parse_pattern(seq[start:], escape))
- return chain(*items)
-
-
-def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
- if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
- padding = max(len(left), len(right))
- else:
- padding = 0
- step = (int(incr) or 1) if incr else 1
- start = int(left)
- end = int(right)
- r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
- fmt = "%0{}d".format(padding)
- return (fmt % i for i in r)
-
-
-def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
- step = (int(incr) or 1) if incr else 1
- start = alphabet.index(left)
- end = alphabet.index(right)
- if start < end:
- return alphabet[start : end + 1 : step]
- else:
- end = end or -len(alphabet)
- return alphabet[start : end - 1 : -step]
-
-
-if __name__ == "__main__":
- import doctest
- import sys
-
- failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
- if failed:
- sys.exit(1)
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+ pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+ """braceexpand(pattern) -> iterator over generated strings
+
+ Returns an iterator over the strings resulting from brace expansion
+ of pattern. This function implements Brace Expansion as described in
+ bash(1), with the following limitations:
+
+ * A pattern containing unbalanced braces will raise an
+ UnbalancedBracesError exception. In bash, unbalanced braces will either
+ be partly expanded or ignored.
+
+ * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+ include the characters '[]^_`' between 'Z' and 'a'.
+
+ When escape is True (the default), characters in pattern can be
+ prefixed with a backslash to cause them not to be interpreted as
+ special characters for brace expansion (such as '{', '}', ',').
+ To pass through a literal backslash, double it ('\\\\').
+
+ When escape is False, backslashes in pattern have no special
+ meaning and will be preserved in the output.
+
+ Examples:
+
+ >>> from braceexpand import braceexpand
+
+ # Integer range
+ >>> list(braceexpand('item{1..3}'))
+ ['item1', 'item2', 'item3']
+
+ # Character range
+ >>> list(braceexpand('{a..c}'))
+ ['a', 'b', 'c']
+
+ # Sequence
+ >>> list(braceexpand('index.html{,.backup}'))
+ ['index.html', 'index.html.backup']
+
+ # Nested patterns
+ >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+ ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+ # Prefixing an integer with zero causes all numbers to be padded to
+ # the same width.
+ >>> list(braceexpand('{07..10}'))
+ ['07', '08', '09', '10']
+
+ # An optional increment can be specified for ranges.
+ >>> list(braceexpand('{a..g..2}'))
+ ['a', 'c', 'e', 'g']
+
+ # Ranges can go in both directions.
+ >>> list(braceexpand('{4..1}'))
+ ['4', '3', '2', '1']
+
+ # Numbers can be negative
+ >>> list(braceexpand('{2..-1}'))
+ ['2', '1', '0', '-1']
+
+ # Unbalanced braces raise an exception.
+ >>> list(braceexpand('{1{2,3}'))
+ Traceback (most recent call last):
+ ...
+ UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+ # By default, the backslash is the escape character.
+ >>> list(braceexpand(r'{1\\{2,3}'))
+ ['1{2', '3']
+
+ # Setting 'escape' to False disables backslash escaping.
+ >>> list(braceexpand(r'\\{1,2}', escape=False))
+ ['\\\\1', '\\\\2']
+
+ """
+ return (
+ escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+ )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'pattern:', pattern
+ while pos < len(pattern):
+ if escape and pattern[pos] == "\\":
+ pos += 2
+ continue
+ elif pattern[pos] == "{":
+ if bracketdepth == 0 and pos > start:
+ # print 'literal:', pattern[start:pos]
+ items.append([pattern[start:pos]])
+ start = pos
+ bracketdepth += 1
+ elif pattern[pos] == "}":
+ bracketdepth -= 1
+ if bracketdepth == 0:
+ # print 'expression:', pattern[start+1:pos]
+ expr = pattern[start + 1 : pos]
+ item = parse_expression(expr, escape)
+ if item is None: # not a range or sequence
+ items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+ else:
+ items.append(item)
+ start = pos + 1 # skip the closing brace
+ pos += 1
+
+ if bracketdepth != 0: # unbalanced braces
+ raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+ if start < pos:
+ items.append([pattern[start:]])
+
+ return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+ int_range_match = int_range_re.match(expr)
+ if int_range_match:
+ return make_int_range(*int_range_match.groups())
+
+ char_range_match = char_range_re.match(expr)
+ if char_range_match:
+ return make_char_range(*char_range_match.groups())
+
+ return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+ # sequence -> chain(*sequence_items)
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'sequence:', seq
+ while pos < len(seq):
+ if escape and seq[pos] == "\\":
+ pos += 2
+ continue
+ elif seq[pos] == "{":
+ bracketdepth += 1
+ elif seq[pos] == "}":
+ bracketdepth -= 1
+ elif seq[pos] == "," and bracketdepth == 0:
+ items.append(parse_pattern(seq[start:pos], escape))
+ start = pos + 1 # skip the comma
+ pos += 1
+
+ if bracketdepth != 0:
+ raise UnbalancedBracesError
+ if not items:
+ return None
+
+ # part after the last comma (may be the empty string)
+ items.append(parse_pattern(seq[start:], escape))
+ return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+ if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+ padding = max(len(left), len(right))
+ else:
+ padding = 0
+ step = (int(incr) or 1) if incr else 1
+ start = int(left)
+ end = int(right)
+ r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+ fmt = "%0{}d".format(padding)
+ return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+ step = (int(incr) or 1) if incr else 1
+ start = alphabet.index(left)
+ end = alphabet.index(right)
+ if start < end:
+ return alphabet[start : end + 1 : step]
+ else:
+ end = end or -len(alphabet)
+ return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+ import doctest
+ import sys
+
+ failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+ if failed:
+ sys.exit(1)
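Beyond the doctests above, a typical use here is expanding shard-style dataset paths; the file naming below is hypothetical.

    from fish_speech.utils.braceexpand import braceexpand

    # Zero-padded integer ranges keep their width, so this yields 000..003.
    print(list(braceexpand("data/protos/train-{000..003}.protos")))
    # -> ['data/protos/train-000.protos', 'data/protos/train-001.protos',
    #     'data/protos/train-002.protos', 'data/protos/train-003.protos']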
diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py
index 618c4deceaa2578cd9f0672a65d1dd55430c7dcc..f04a99290ab32f7fe5b60656075a2d03af8468d6 100644
--- a/fish_speech/utils/context.py
+++ b/fish_speech/utils/context.py
@@ -1,13 +1,13 @@
-from contextlib import nullcontext
-
-import torch
-
-
-def autocast_exclude_mps(
- device_type: str, dtype: torch.dtype
-) -> nullcontext | torch.autocast:
- return (
- nullcontext()
- if torch.backends.mps.is_available()
- else torch.autocast(device_type, dtype)
- )
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+ device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
+ return (
+ nullcontext()
+ if torch.backends.mps.is_available()
+ else torch.autocast(device_type, dtype)
+ )
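A minimal sketch of the helper: it returns a plain nullcontext on machines where the MPS backend is available and a regular torch.autocast block everywhere else.

    import torch

    from fish_speech.utils.context import autocast_exclude_mps

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(8, 8, device=device)

    # nullcontext on MPS machines, torch.autocast elsewhere.
    with autocast_exclude_mps(device_type=device, dtype=torch.bfloat16):
        y = x @ x
    print(y.dtype)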
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
index 7516ad28b31e9836e6c02a991690d3e466d3ea62..a54c22655fc52d55db8932f7c6edabe017b965f4 100644
--- a/fish_speech/utils/file.py
+++ b/fish_speech/utils/file.py
@@ -1,16 +1,139 @@
-import os
-from pathlib import Path
-
-
-def get_latest_checkpoint(path: Path | str) -> Path | None:
- # Find the latest checkpoint
- ckpt_dir = Path(path)
-
- if ckpt_dir.exists() is False:
- return None
-
- ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
- if len(ckpts) == 0:
- return None
-
- return ckpts[-1]
+import os
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+ ".mp3",
+ ".wav",
+ ".flac",
+ ".ogg",
+ ".m4a",
+ ".wma",
+ ".aac",
+ ".aiff",
+ ".aif",
+ ".aifc",
+}
+
+VIDEO_EXTENSIONS = {
+ ".mp4",
+ ".avi",
+}
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+ # Find the latest checkpoint
+ ckpt_dir = Path(path)
+
+ if ckpt_dir.exists() is False:
+ return None
+
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+ if len(ckpts) == 0:
+ return None
+
+ return ckpts[-1]
+
+
+def audio_to_bytes(file_path):
+ if not file_path or not Path(file_path).exists():
+ return None
+ with open(file_path, "rb") as wav_file:
+ wav = wav_file.read()
+ return wav
+
+
+def read_ref_text(ref_text):
+ path = Path(ref_text)
+ if path.exists() and path.is_file():
+ with path.open("r", encoding="utf-8") as file:
+ return file.read()
+ return ref_text
+
+
+def list_files(
+ path: Union[Path, str],
+ extensions: set[str] = set(),
+ recursive: bool = False,
+ sort: bool = True,
+) -> list[Path]:
+ """List files in a directory.
+
+ Args:
+ path (Path): Path to the directory.
+ extensions (set, optional): Extensions to filter. Defaults to an empty set.
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
+ sort (bool, optional): Whether to sort the files. Defaults to True.
+
+ Returns:
+ list: List of files.
+ """
+
+ if isinstance(path, str):
+ path = Path(path)
+
+ if not path.exists():
+ raise FileNotFoundError(f"Directory {path} does not exist.")
+
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
+
+ if sort:
+ files = natsorted(files)
+
+ return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+ """
+ Load a Bert-VITS2 style filelist.
+ """
+
+ files = set()
+ results = []
+ count_duplicated, count_not_found = 0, 0
+
+ LANGUAGE_TO_LANGUAGES = {
+ "zh": ["zh", "en"],
+ "jp": ["jp", "en"],
+ "en": ["en"],
+ }
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ splits = line.strip().split("|", maxsplit=3)
+ if len(splits) != 4:
+ logger.warning(f"Invalid line: {line}")
+ continue
+
+ filename, speaker, language, text = splits
+ file = Path(filename)
+ language = language.strip().lower()
+
+ if language == "ja":
+ language = "jp"
+
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+ languages = LANGUAGE_TO_LANGUAGES[language]
+
+ if file in files:
+ logger.warning(f"Duplicated file: {file}")
+ count_duplicated += 1
+ continue
+
+ if not file.exists():
+ logger.warning(f"File not found: {file}")
+ count_not_found += 1
+ continue
+
+ results.append((file, speaker, languages, text))
+
+ if count_duplicated > 0:
+ logger.warning(f"Total duplicated files: {count_duplicated}")
+
+ if count_not_found > 0:
+ logger.warning(f"Total files not found: {count_not_found}")
+
+ return results
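A sketch of the two common entry points, assuming a local folder of reference clips; the path is hypothetical.

    from fish_speech.utils.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files

    # Raises FileNotFoundError if the directory is missing; results are natsorted.
    refs = list_files("references/my_speaker", extensions=AUDIO_EXTENSIONS, recursive=True)
    for path in refs:
        wav_bytes = audio_to_bytes(path)  # None if the file disappeared in the meantime
        print(path.name, 0 if wav_bytes is None else len(wav_bytes))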
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
index d1a08fe2fd76bedc5b4ad40f8dddfa40e6951c58..f6ee463924f588a35477937fbe3c3364043bdf3e 100644
--- a/fish_speech/utils/instantiators.py
+++ b/fish_speech/utils/instantiators.py
@@ -1,50 +1,50 @@
-from typing import List
-
-import hydra
-from omegaconf import DictConfig
-from pytorch_lightning import Callback
-from pytorch_lightning.loggers import Logger
-
-from .logger import RankedLogger
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
- """Instantiates callbacks from config."""
-
- callbacks: List[Callback] = []
-
- if not callbacks_cfg:
- log.warning("No callback configs found! Skipping..")
- return callbacks
-
- if not isinstance(callbacks_cfg, DictConfig):
- raise TypeError("Callbacks config must be a DictConfig!")
-
- for _, cb_conf in callbacks_cfg.items():
- if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
- log.info(f"Instantiating callback <{cb_conf._target_}>")
- callbacks.append(hydra.utils.instantiate(cb_conf))
-
- return callbacks
-
-
-def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
- """Instantiates loggers from config."""
-
- logger: List[Logger] = []
-
- if not logger_cfg:
- log.warning("No logger configs found! Skipping...")
- return logger
-
- if not isinstance(logger_cfg, DictConfig):
- raise TypeError("Logger config must be a DictConfig!")
-
- for _, lg_conf in logger_cfg.items():
- if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
- log.info(f"Instantiating logger <{lg_conf._target_}>")
- logger.append(hydra.utils.instantiate(lg_conf))
-
- return logger
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
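A sketch with a hand-built DictConfig; the ModelCheckpoint target only illustrates the Hydra _target_ convention and is not something this patch configures.

    from omegaconf import OmegaConf

    from fish_speech.utils.instantiators import instantiate_callbacks

    callbacks_cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
                "dirpath": "results/checkpoints",
                "save_last": True,
            }
        }
    )

    callbacks = instantiate_callbacks(callbacks_cfg)
    print(callbacks)  # [ModelCheckpoint(...)]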
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
index 5e909c26380affa14ec2e8e92ce5ecb37dc0777e..94f94f738d1d87404354d086c30ef0ad9ab04cdc 100644
--- a/fish_speech/utils/logger.py
+++ b/fish_speech/utils/logger.py
@@ -1,55 +1,55 @@
-import logging
-from typing import Mapping, Optional
-
-from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
-
-
-class RankedLogger(logging.LoggerAdapter):
- """A multi-GPU-friendly python command line logger."""
-
- def __init__(
- self,
- name: str = __name__,
- rank_zero_only: bool = True,
- extra: Optional[Mapping[str, object]] = None,
- ) -> None:
- """Initializes a multi-GPU-friendly python command line logger that logs on all processes
- with their rank prefixed in the log message.
-
- :param name: The name of the logger. Default is ``__name__``.
- :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
- :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
- """
- logger = logging.getLogger(name)
- super().__init__(logger=logger, extra=extra)
- self.rank_zero_only = rank_zero_only
-
- def log(
- self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
- ) -> None:
- """Delegate a log call to the underlying logger, after prefixing its message with the rank
- of the process it's being logged from. If `'rank'` is provided, then the log will only
- occur on that rank/process.
-
- :param level: The level to log at. Look at `logging.__init__.py` for more information.
- :param msg: The message to log.
- :param rank: The rank to log at.
- :param args: Additional args to pass to the underlying logging function.
- :param kwargs: Any additional keyword args to pass to the underlying logging function.
- """
- if self.isEnabledFor(level):
- msg, kwargs = self.process(msg, kwargs)
- current_rank = getattr(rank_zero_only, "rank", None)
- if current_rank is None:
- raise RuntimeError(
- "The `rank_zero_only.rank` needs to be set before use"
- )
- msg = rank_prefixed_message(msg, current_rank)
- if self.rank_zero_only:
- if current_rank == 0:
- self.logger.log(level, msg, *args, **kwargs)
- else:
- if rank is None:
- self.logger.log(level, msg, *args, **kwargs)
- elif current_rank == rank:
- self.logger.log(level, msg, *args, **kwargs)
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = True,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+ :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `True`.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(
+ self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+ ) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError(
+ "The `rank_zero_only.rank` needs to be set before use"
+ )
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
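A sketch outside of a Lightning run; rank_zero_only.rank must be set first (Lightning normally populates it during setup).

    import logging

    from lightning_utilities.core.rank_zero import rank_zero_only

    from fish_speech.utils.logger import RankedLogger

    logging.basicConfig(level=logging.INFO)
    rank_zero_only.rank = 0  # normally derived from the environment by Lightning

    log = RankedLogger(__name__, rank_zero_only=True)
    log.info("model loaded")  # emitted on rank 0 only
    log.log(logging.INFO, "rank-targeted message", rank=0)  # rank only matters when rank_zero_only=False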
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
index ead61c20564687585e945bf6e88f13d803851bd2..8e3b0a2519e12845f09e5fbe86dfccbf5b345429 100644
--- a/fish_speech/utils/logging_utils.py
+++ b/fish_speech/utils/logging_utils.py
@@ -1,48 +1,48 @@
-from lightning.pytorch.utilities import rank_zero_only
-
-from fish_speech.utils import logger as log
-
-
-@rank_zero_only
-def log_hyperparameters(object_dict: dict) -> None:
- """Controls which config parts are saved by lightning loggers.
-
- Additionally saves:
- - Number of model parameters
- """
-
- hparams = {}
-
- cfg = object_dict["cfg"]
- model = object_dict["model"]
- trainer = object_dict["trainer"]
-
- if not trainer.logger:
- log.warning("Logger not found! Skipping hyperparameter logging...")
- return
-
- hparams["model"] = cfg["model"]
-
- # save number of model parameters
- hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
- hparams["model/params/trainable"] = sum(
- p.numel() for p in model.parameters() if p.requires_grad
- )
- hparams["model/params/non_trainable"] = sum(
- p.numel() for p in model.parameters() if not p.requires_grad
- )
-
- hparams["data"] = cfg["data"]
- hparams["trainer"] = cfg["trainer"]
-
- hparams["callbacks"] = cfg.get("callbacks")
- hparams["extras"] = cfg.get("extras")
-
- hparams["task_name"] = cfg.get("task_name")
- hparams["tags"] = cfg.get("tags")
- hparams["ckpt_path"] = cfg.get("ckpt_path")
- hparams["seed"] = cfg.get("seed")
-
- # send hparams to all loggers
- for logger in trainer.loggers:
- logger.log_hyperparams(hparams)
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
index 5a11ba95b6e54461e9f4faba9ca1f98de6e194ab..6a465f54d610779766d51e3d1a020a3b1517fd1f 100644
--- a/fish_speech/utils/rich_utils.py
+++ b/fish_speech/utils/rich_utils.py
@@ -1,100 +1,100 @@
-from pathlib import Path
-from typing import Sequence
-
-import rich
-import rich.syntax
-import rich.tree
-from hydra.core.hydra_config import HydraConfig
-from lightning.pytorch.utilities import rank_zero_only
-from omegaconf import DictConfig, OmegaConf, open_dict
-from rich.prompt import Prompt
-
-from fish_speech.utils import logger as log
-
-
-@rank_zero_only
-def print_config_tree(
- cfg: DictConfig,
- print_order: Sequence[str] = (
- "data",
- "model",
- "callbacks",
- "logger",
- "trainer",
- "paths",
- "extras",
- ),
- resolve: bool = False,
- save_to_file: bool = False,
-) -> None:
- """Prints content of DictConfig using Rich library and its tree structure.
-
- Args:
- cfg (DictConfig): Configuration composed by Hydra.
- print_order (Sequence[str], optional): Determines in what order config components are printed.
- resolve (bool, optional): Whether to resolve reference fields of DictConfig.
- save_to_file (bool, optional): Whether to export config to the hydra output folder.
- """ # noqa: E501
-
- style = "dim"
- tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
-
- queue = []
-
- # add fields from `print_order` to queue
- for field in print_order:
- (
- queue.append(field)
- if field in cfg
- else log.warning(
- f"Field '{field}' not found in config. "
- + f"Skipping '{field}' config printing..."
- )
- )
-
- # add all the other fields to queue (not specified in `print_order`)
- for field in cfg:
- if field not in queue:
- queue.append(field)
-
- # generate config tree from queue
- for field in queue:
- branch = tree.add(field, style=style, guide_style=style)
-
- config_group = cfg[field]
- if isinstance(config_group, DictConfig):
- branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
- else:
- branch_content = str(config_group)
-
- branch.add(rich.syntax.Syntax(branch_content, "yaml"))
-
- # print config tree
- rich.print(tree)
-
- # save config tree to file
- if save_to_file:
- with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
- rich.print(tree, file=file)
-
-
-@rank_zero_only
-def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
- """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
-
- if not cfg.get("tags"):
- if "id" in HydraConfig().cfg.hydra.job:
- raise ValueError("Specify tags before launching a multirun!")
-
- log.warning("No tags provided in config. Prompting user to input tags...")
- tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
- tags = [t.strip() for t in tags.split(",") if t != ""]
-
- with open_dict(cfg):
- cfg.tags = tags
-
- log.info(f"Tags: {cfg.tags}")
-
- if save_to_file:
- with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
- rich.print(cfg.tags, file=file)
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """ # noqa: E501
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ (
+ queue.append(field)
+ if field in cfg
+ else log.warning(
+ f"Field '{field}' not found in config. "
+ + f"Skipping '{field}' config printing..."
+ )
+ )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/fish_speech/utils/schema.py b/fish_speech/utils/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8fe27b6c859cd13a98bc72c6eacf6c48bbe8706
--- /dev/null
+++ b/fish_speech/utils/schema.py
@@ -0,0 +1,146 @@
+import base64
+import os
+import queue
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from pydantic import BaseModel, Field, conint, model_validator
+from pydantic.functional_validators import SkipValidation
+from typing_extensions import Annotated
+
+from fish_speech.content_sequence import TextPart, VQPart
+
+
+class ServeVQPart(BaseModel):
+ type: Literal["vq"] = "vq"
+ codes: SkipValidation[list[list[int]]]
+
+
+class ServeTextPart(BaseModel):
+ type: Literal["text"] = "text"
+ text: str
+
+
+class ServeAudioPart(BaseModel):
+ type: Literal["audio"] = "audio"
+ audio: bytes
+
+
+class ServeASRRequest(BaseModel):
+    # Each audio should be uncompressed PCM float16
+ audios: list[bytes]
+ sample_rate: int = 44100
+ language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
+class ServeASRTranscription(BaseModel):
+ text: str
+ duration: float
+ huge_gap: bool
+
+
+class ServeASRSegment(BaseModel):
+ text: str
+ start: float
+ end: float
+
+
+class ServeTimedASRResponse(BaseModel):
+ text: str
+ segments: list[ServeASRSegment]
+ duration: float
+
+
+class ServeASRResponse(BaseModel):
+ transcriptions: list[ServeASRTranscription]
+
+
+class ServeRequest(BaseModel):
+ # Raw content sequence dict that we can use with ContentSequence(**content)
+ content: dict
+ max_new_tokens: int = 600
+ top_p: float = 0.7
+ repetition_penalty: float = 1.2
+ temperature: float = 0.7
+ streaming: bool = False
+ num_samples: int = 1
+ early_stop_threshold: float = 1.0
+
+
+class ServeVQGANEncodeRequest(BaseModel):
+    # The audios here can be in wav, mp3, etc.
+ audios: list[bytes]
+
+
+class ServeVQGANEncodeResponse(BaseModel):
+ tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeRequest(BaseModel):
+ tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeResponse(BaseModel):
+ # The audio here should be in PCM float16 format
+ audios: list[bytes]
+
+
+class ServeStreamDelta(BaseModel):
+ role: Literal["system", "assistant", "user"] | None = None
+ part: ServeVQPart | ServeTextPart | None = None
+
+
+class ServeStreamResponse(BaseModel):
+ sample_id: int = 0
+ delta: ServeStreamDelta | None = None
+ finish_reason: Literal["stop", "error"] | None = None
+ stats: dict[str, int | float | str] | None = None
+
+
+class ServeReferenceAudio(BaseModel):
+ audio: bytes
+ text: str
+
+ @model_validator(mode="before")
+ def decode_audio(cls, values):
+ audio = values.get("audio")
+ if (
+ isinstance(audio, str) and len(audio) > 255
+ ): # Check if audio is a string (Base64)
+ try:
+ values["audio"] = base64.b64decode(audio)
+            except Exception:
+                # Not valid base64; leave the value as-is and let the server handle it
+                pass
+ return values
+
+ def __repr__(self) -> str:
+ return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
+class ServeTTSRequest(BaseModel):
+ text: str
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+ # Audio format
+ format: Literal["wav", "pcm", "mp3"] = "wav"
+    # Reference audios for in-context learning
+ references: list[ServeReferenceAudio] = []
+ # Reference id
+    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+ reference_id: str | None = None
+ seed: int | None = None
+ use_memory_cache: Literal["on", "off"] = "off"
+    # Normalize text for en & zh; this increases stability for numbers
+ normalize: bool = True
+ # not usually used below
+ streaming: bool = False
+ max_new_tokens: int = 1024
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
+
+ class Config:
+ # Allow arbitrary types for pytorch related types
+ arbitrary_types_allowed = True
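
For context, a minimal sketch (not part of the patch) of the request schema added here; the text, audio bytes, and parameter values are illustrative. It shows the `ServeReferenceAudio` validator decoding long base64 strings back to bytes, plus the range-checked sampling fields of `ServeTTSRequest`:

```python
import base64

from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

raw_audio = b"\x00\x01" * 1000  # stand-in for real WAV bytes

ref = ServeReferenceAudio(
    # Strings longer than 255 characters are treated as base64 and decoded back to bytes.
    audio=base64.b64encode(raw_audio).decode(),
    text="Reference transcript",
)
assert isinstance(ref.audio, bytes)

req = ServeTTSRequest(
    text="Hello from OpenAudio S1",
    references=[ref],
    format="wav",
    chunk_length=200,  # constrained to 100..300
    temperature=0.8,   # constrained to 0.1..1.0
)
print(req.max_new_tokens)  # defaults to 1024
```
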
diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
index 81ce022e2e62781cc62016d70de33916c736f85d..19ea435c996f6862da8885c3e8c9e8ca2b291e32 100644
--- a/fish_speech/utils/spectrogram.py
+++ b/fish_speech/utils/spectrogram.py
@@ -1,122 +1,124 @@
-import torch
-import torchaudio.functional as F
-from torch import Tensor, nn
-from torchaudio.transforms import MelScale
-
-
-class LinearSpectrogram(nn.Module):
- def __init__(
- self,
- n_fft=2048,
- win_length=2048,
- hop_length=512,
- center=False,
- mode="pow2_sqrt",
- ):
- super().__init__()
-
- self.n_fft = n_fft
- self.win_length = win_length
- self.hop_length = hop_length
- self.center = center
- self.mode = mode
-
- self.register_buffer("window", torch.hann_window(win_length), persistent=False)
-
- def forward(self, y: Tensor) -> Tensor:
- if y.ndim == 3:
- y = y.squeeze(1)
-
- y = torch.nn.functional.pad(
- y.unsqueeze(1),
- (
- (self.win_length - self.hop_length) // 2,
- (self.win_length - self.hop_length + 1) // 2,
- ),
- mode="reflect",
- ).squeeze(1)
-
- spec = torch.stft(
- y,
- self.n_fft,
- hop_length=self.hop_length,
- win_length=self.win_length,
- window=self.window,
- center=self.center,
- pad_mode="reflect",
- normalized=False,
- onesided=True,
- return_complex=True,
- )
-
- spec = torch.view_as_real(spec)
-
- if self.mode == "pow2_sqrt":
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-
- return spec
-
-
-class LogMelSpectrogram(nn.Module):
- def __init__(
- self,
- sample_rate=44100,
- n_fft=2048,
- win_length=2048,
- hop_length=512,
- n_mels=128,
- center=False,
- f_min=0.0,
- f_max=None,
- ):
- super().__init__()
-
- self.sample_rate = sample_rate
- self.n_fft = n_fft
- self.win_length = win_length
- self.hop_length = hop_length
- self.center = center
- self.n_mels = n_mels
- self.f_min = f_min
- self.f_max = f_max or float(sample_rate // 2)
-
- self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
-
- fb = F.melscale_fbanks(
- n_freqs=self.n_fft // 2 + 1,
- f_min=self.f_min,
- f_max=self.f_max,
- n_mels=self.n_mels,
- sample_rate=self.sample_rate,
- norm="slaney",
- mel_scale="slaney",
- )
- self.register_buffer(
- "fb",
- fb,
- persistent=False,
- )
-
- def compress(self, x: Tensor) -> Tensor:
- return torch.log(torch.clamp(x, min=1e-5))
-
- def decompress(self, x: Tensor) -> Tensor:
- return torch.exp(x)
-
- def apply_mel_scale(self, x: Tensor) -> Tensor:
- return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
-
- def forward(
- self, x: Tensor, return_linear: bool = False, sample_rate: int = None
- ) -> Tensor:
- if sample_rate is not None and sample_rate != self.sample_rate:
- x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
-
- linear = self.spectrogram(x)
- x = self.apply_mel_scale(linear)
- x = self.compress(x)
-
- if return_linear:
- return x, self.compress(linear)
-
- return x
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ center=False,
+ mode="pow2_sqrt",
+ ):
+ super().__init__()
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.mode = mode
+ self.return_complex = True
+
+ self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+ def forward(self, y: Tensor) -> Tensor:
+ if y.ndim == 3:
+ y = y.squeeze(1)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ (self.win_length - self.hop_length) // 2,
+ (self.win_length - self.hop_length + 1) // 2,
+ ),
+ mode="reflect",
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=self.return_complex,
+ )
+
+ if self.return_complex:
+ spec = torch.view_as_real(spec)
+
+ if self.mode == "pow2_sqrt":
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ return spec
+
+
+class LogMelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ sample_rate=44100,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ n_mels=128,
+ center=False,
+ f_min=0.0,
+ f_max=None,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max or float(sample_rate // 2)
+
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+ fb = F.melscale_fbanks(
+ n_freqs=self.n_fft // 2 + 1,
+ f_min=self.f_min,
+ f_max=self.f_max,
+ n_mels=self.n_mels,
+ sample_rate=self.sample_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ )
+ self.register_buffer(
+ "fb",
+ fb,
+ persistent=False,
+ )
+
+ def compress(self, x: Tensor) -> Tensor:
+ return torch.log(torch.clamp(x, min=1e-5))
+
+ def decompress(self, x: Tensor) -> Tensor:
+ return torch.exp(x)
+
+ def apply_mel_scale(self, x: Tensor) -> Tensor:
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ def forward(
+ self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+ ) -> Tensor:
+ if sample_rate is not None and sample_rate != self.sample_rate:
+ x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+ linear = self.spectrogram(x)
+ x = self.apply_mel_scale(linear)
+ x = self.compress(x)
+
+ if return_linear:
+ return x, self.compress(linear)
+
+ return x
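
For context, a minimal sketch (not part of the patch) of the spectrogram front-end on a dummy waveform; shapes are printed rather than asserted:

```python
import torch

from fish_speech.utils.spectrogram import LogMelSpectrogram

mel_transform = LogMelSpectrogram(sample_rate=44100, n_mels=128, hop_length=512)

waveform = torch.randn(1, 44100)  # (batch, samples), one second at 44.1 kHz
mel = mel_transform(waveform)
print(mel.shape)  # (1, 128, n_frames), n_frames ~= samples / hop_length

# Inputs at other sample rates are resampled internally.
mel_48k = mel_transform(torch.randn(1, 48000), sample_rate=48000)

# return_linear=True also returns the compressed linear spectrogram.
mel_again, linear = mel_transform(waveform, return_linear=True)
print(linear.shape)  # (1, n_fft // 2 + 1, n_frames)
```
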
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
index f5e02cd0c8f8dec1d002e7d634b00434e70873f9..5a34bdcfedff76c333f50ed8be050d0dd5a8f98a 100644
--- a/fish_speech/utils/utils.py
+++ b/fish_speech/utils/utils.py
@@ -1,136 +1,136 @@
-import random
-import warnings
-from importlib.util import find_spec
-from typing import Callable
-
-import numpy as np
-import torch
-from omegaconf import DictConfig
-
-from .logger import RankedLogger
-from .rich_utils import enforce_tags, print_config_tree
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def extras(cfg: DictConfig) -> None:
- """Applies optional utilities before the task is started.
-
- Utilities:
- - Ignoring python warnings
- - Setting tags from command line
- - Rich config printing
- """
-
- # return if no `extras` config
- if not cfg.get("extras"):
- log.warning("Extras config not found! ")
- return
-
- # disable python warnings
- if cfg.extras.get("ignore_warnings"):
- log.info("Disabling python warnings! ")
- warnings.filterwarnings("ignore")
-
- # prompt user to input tags from command line if none are provided in the config
- if cfg.extras.get("enforce_tags"):
- log.info("Enforcing tags! ")
- enforce_tags(cfg, save_to_file=True)
-
- # pretty print config tree using Rich library
- if cfg.extras.get("print_config"):
- log.info("Printing config tree with Rich! ")
- print_config_tree(cfg, resolve=True, save_to_file=True)
-
-
-def task_wrapper(task_func: Callable) -> Callable:
- """Optional decorator that controls the failure behavior when executing the task function.
-
- This wrapper can be used to:
- - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- - save the exception to a `.log` file
- - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- - etc. (adjust depending on your needs)
-
- Example:
- ```
- @utils.task_wrapper
- def train(cfg: DictConfig) -> Tuple[dict, dict]:
-
- ...
-
- return metric_dict, object_dict
- ```
- """ # noqa: E501
-
- def wrap(cfg: DictConfig):
- # execute the task
- try:
- metric_dict, object_dict = task_func(cfg=cfg)
-
- # things to do if exception occurs
- except Exception as ex:
- # save exception to `.log` file
- log.exception("")
-
- # some hyperparameter combinations might be invalid or
- # cause out-of-memory errors so when using hparam search
- # plugins like Optuna, you might want to disable
- # raising the below exception to avoid multirun failure
- raise ex
-
- # things to always do after either success or exception
- finally:
- # display output dir path in terminal
- log.info(f"Output dir: {cfg.paths.run_dir}")
-
- # always close wandb run (even if exception occurs so multirun won't fail)
- if find_spec("wandb"): # check if wandb is installed
- import wandb
-
- if wandb.run:
- log.info("Closing wandb!")
- wandb.finish()
-
- return metric_dict, object_dict
-
- return wrap
-
-
-def get_metric_value(metric_dict: dict, metric_name: str) -> float:
- """Safely retrieves value of the metric logged in LightningModule."""
-
- if not metric_name:
- log.info("Metric name is None! Skipping metric value retrieval...")
- return None
-
- if metric_name not in metric_dict:
- raise Exception(
- f"Metric value not found! \n"
- "Make sure metric name logged in LightningModule is correct!\n"
- "Make sure `optimized_metric` name in `hparams_search` config is correct!"
- )
-
- metric_value = metric_dict[metric_name].item()
- log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
-
- return metric_value
-
-
-def set_seed(seed: int):
- if seed < 0:
- seed = -seed
- if seed > (1 << 31):
- seed = 1 << 31
-
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
-
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
-
- if torch.backends.cudnn.is_available():
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
+import random
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+ ...
+
+ return metric_dict, object_dict
+ ```
+ """ # noqa: E501
+
+ def wrap(cfg: DictConfig):
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or
+ # cause out-of-memory errors so when using hparam search
+ # plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.run_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
+
+
+def set_seed(seed: int):
+ if seed < 0:
+ seed = -seed
+ if seed > (1 << 31):
+ seed = 1 << 31
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ if torch.backends.cudnn.is_available():
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
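
For context, a minimal sketch (not part of the patch) of the `set_seed` helper kept in this module; the seed values are arbitrary:

```python
import torch

from fish_speech.utils.utils import set_seed

set_seed(-1234)    # negative seeds are negated before use
set_seed(1 << 40)  # oversized seeds are clamped to 1 << 31

set_seed(1234)
a = torch.rand(3)
set_seed(1234)
b = torch.rand(3)
assert torch.equal(a, b)  # identical seeds give identical streams
```
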
diff --git a/requirements.txt b/requirements.txt
index d730ec85cdfe0996143386126b3f21307d1db116..b1eebf25ed100cbb9b6ccc30c434ba1fe433a3be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch==2.3.0
+torch
torchaudio
transformers>=4.35.2
datasets>=2.14.5
@@ -18,8 +18,8 @@ loguru>=0.6.0
loralib>=0.1.2
natsort>=8.4.0
pyrootutils>=1.0.4
+descript-audiotools
vector_quantize_pytorch==1.14.24
-samplerate>=0.2.1
resampy>=0.4.3
spaces>=0.26.1
einx[torch]==0.2.2
@@ -31,4 +31,7 @@ soundfile
cachetools
funasr
silero-vad
-tiktoken
\ No newline at end of file
+tiktoken
+numpy
+huggingface_hub
+git+https://github.com/descriptinc/descript-audio-codec
diff --git a/tools/api_client.py b/tools/api_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee69cd28a05ce09a671f877b4a2a5d3c4dcc6d79
--- /dev/null
+++ b/tools/api_client.py
@@ -0,0 +1,225 @@
+import argparse
+import base64
+import wave
+
+import ormsgpack
+import pyaudio
+import requests
+from pydub import AudioSegment
+from pydub.playback import play
+
+from fish_speech.utils.file import audio_to_bytes, read_ref_text
+from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
+
+
+def parse_args():
+
+ parser = argparse.ArgumentParser(
+ description="Send a WAV file and text to a server and receive synthesized audio.",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--url",
+ "-u",
+ type=str,
+ default="http://127.0.0.1:8080/v1/tts",
+ help="URL of the server",
+ )
+ parser.add_argument(
+ "--text", "-t", type=str, required=True, help="Text to be synthesized"
+ )
+ parser.add_argument(
+ "--reference_id",
+ "-id",
+ type=str,
+ default=None,
+ help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
+ )
+ parser.add_argument(
+ "--reference_audio",
+ "-ra",
+ type=str,
+ nargs="+",
+ default=None,
+ help="Path to the audio file",
+ )
+ parser.add_argument(
+ "--reference_text",
+ "-rt",
+ type=str,
+ nargs="+",
+ default=None,
+ help="Reference text for voice synthesis",
+ )
+ parser.add_argument(
+ "--output",
+ "-o",
+ type=str,
+ default="generated_audio",
+ help="Output audio file name",
+ )
+ parser.add_argument(
+ "--play",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Whether to play audio after receiving data",
+ )
+ parser.add_argument(
+ "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
+ )
+ parser.add_argument(
+ "--latency",
+ type=str,
+ default="normal",
+ choices=["normal", "balanced"],
+ help="Used in api.fish.audio/v1/tts",
+ )
+ parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ default=1024,
+ help="Maximum new tokens to generate. \n0 means no limit.",
+ )
+ parser.add_argument(
+ "--chunk_length", type=int, default=300, help="Chunk length for synthesis"
+ )
+ parser.add_argument(
+ "--top_p", type=float, default=0.8, help="Top-p sampling for synthesis"
+ )
+ parser.add_argument(
+ "--repetition_penalty",
+ type=float,
+ default=1.1,
+ help="Repetition penalty for synthesis",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0.8, help="Temperature for sampling"
+ )
+
+ parser.add_argument(
+ "--streaming", type=bool, default=False, help="Enable streaming response"
+ )
+ parser.add_argument(
+ "--channels", type=int, default=1, help="Number of audio channels"
+ )
+ parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
+ parser.add_argument(
+ "--use_memory_cache",
+ type=str,
+ default="off",
+ choices=["on", "off"],
+ help="Cache encoded references codes in memory.\n",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="`None` means randomized inference, otherwise deterministic.\n"
+ "It can't be used for fixing a timbre.",
+ )
+ parser.add_argument(
+ "--api_key",
+ type=str,
+ default="YOUR_API_KEY",
+ help="API key for authentication",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+
+ args = parse_args()
+
+ idstr: str | None = args.reference_id
+ # priority: ref_id > [{text, audio},...]
+ if idstr is None:
+ ref_audios = args.reference_audio
+ ref_texts = args.reference_text
+ if ref_audios is None:
+ byte_audios = []
+ else:
+ byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
+ if ref_texts is None:
+ ref_texts = []
+ else:
+ ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
+ else:
+ byte_audios = []
+ ref_texts = []
+        # references are resolved server-side from the reference_id (see the API server)
+
+ data = {
+ "text": args.text,
+ "references": [
+ ServeReferenceAudio(
+ audio=ref_audio if ref_audio is not None else b"", text=ref_text
+ )
+ for ref_text, ref_audio in zip(ref_texts, byte_audios)
+ ],
+ "reference_id": idstr,
+ "format": args.format,
+ "max_new_tokens": args.max_new_tokens,
+ "chunk_length": args.chunk_length,
+ "top_p": args.top_p,
+ "repetition_penalty": args.repetition_penalty,
+ "temperature": args.temperature,
+ "streaming": args.streaming,
+ "use_memory_cache": args.use_memory_cache,
+ "seed": args.seed,
+ }
+
+ pydantic_data = ServeTTSRequest(**data)
+
+ response = requests.post(
+ args.url,
+ data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+ stream=args.streaming,
+ headers={
+ "authorization": f"Bearer {args.api_key}",
+ "content-type": "application/msgpack",
+ },
+ )
+
+ if response.status_code == 200:
+ if args.streaming:
+ p = pyaudio.PyAudio()
+ audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
+ stream = p.open(
+ format=audio_format, channels=args.channels, rate=args.rate, output=True
+ )
+
+ wf = wave.open(f"{args.output}.wav", "wb")
+ wf.setnchannels(args.channels)
+ wf.setsampwidth(p.get_sample_size(audio_format))
+ wf.setframerate(args.rate)
+
+ stream_stopped_flag = False
+
+ try:
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk:
+ stream.write(chunk)
+ wf.writeframesraw(chunk)
+ else:
+ if not stream_stopped_flag:
+ stream.stop_stream()
+ stream_stopped_flag = True
+ finally:
+ stream.close()
+ p.terminate()
+ wf.close()
+ else:
+ audio_content = response.content
+ audio_path = f"{args.output}.{args.format}"
+ with open(audio_path, "wb") as audio_file:
+ audio_file.write(audio_content)
+
+ audio = AudioSegment.from_file(audio_path, format=args.format)
+ if args.play:
+ play(audio)
+ print(f"Audio has been saved to '{audio_path}'.")
+ else:
+ print(f"Request failed with status code {response.status_code}")
+ print(response.json())
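
For context, the request flow of the new client collapsed into a minimal non-streaming sketch (not part of the patch). It assumes a server is reachable at the client's default URL and that no API key is enforced:

```python
import ormsgpack
import requests

from fish_speech.utils.schema import ServeTTSRequest

req = ServeTTSRequest(text="Hello world", format="wav", streaming=False)

response = requests.post(
    "http://127.0.0.1:8080/v1/tts",
    data=ormsgpack.packb(req, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
    headers={"content-type": "application/msgpack"},
)
response.raise_for_status()

with open("generated_audio.wav", "wb") as audio_file:
    audio_file.write(response.content)
```
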
diff --git a/tools/api_server.py b/tools/api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..f918199a39e8982930e5505c11e909b6f96940bc
--- /dev/null
+++ b/tools/api_server.py
@@ -0,0 +1,122 @@
+import re
+from threading import Lock
+
+import pyrootutils
+import uvicorn
+from kui.asgi import (
+ Depends,
+ FactoryClass,
+ HTTPException,
+ HttpRoute,
+ Kui,
+ OpenAPI,
+ Routes,
+)
+from kui.cors import CORSConfig
+from kui.openapi.specification import Info
+from kui.security import bearer_auth
+from loguru import logger
+from typing_extensions import Annotated
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.server.api_utils import MsgPackRequest, parse_args
+from tools.server.exception_handler import ExceptionHandler
+from tools.server.model_manager import ModelManager
+from tools.server.views import routes
+
+
+class API(ExceptionHandler):
+ def __init__(self):
+ self.args = parse_args()
+
+ def api_auth(endpoint):
+ async def verify(token: Annotated[str, Depends(bearer_auth)]):
+ if token != self.args.api_key:
+ raise HTTPException(401, None, "Invalid token")
+ return await endpoint()
+
+ async def passthrough():
+ return await endpoint()
+
+ if self.args.api_key is not None:
+ return verify
+ else:
+ return passthrough
+
+ self.routes = Routes(
+ routes, # keep existing routes
+ http_middlewares=[api_auth], # apply api_auth middleware
+ )
+
+        # OpenAPI configuration
+ self.openapi = OpenAPI(
+ Info(
+ {
+ "title": "Fish Speech API",
+ "version": "1.5.0",
+ }
+ ),
+ ).routes
+
+ # Initialize the app
+ self.app = Kui(
+ routes=self.routes + self.openapi[1:], # Remove the default route
+ exception_handlers={
+ HTTPException: self.http_exception_handler,
+ Exception: self.other_exception_handler,
+ },
+ factory_class=FactoryClass(http=MsgPackRequest),
+ cors_config=CORSConfig(),
+ )
+
+ # Add the state variables
+ self.app.state.lock = Lock()
+ self.app.state.device = self.args.device
+ self.app.state.max_text_length = self.args.max_text_length
+
+ # Associate the app with the model manager
+ self.app.on_startup(self.initialize_app)
+
+ async def initialize_app(self, app: Kui):
+ # Make the ModelManager available to the views
+ app.state.model_manager = ModelManager(
+ mode=self.args.mode,
+ device=self.args.device,
+ half=self.args.half,
+ compile=self.args.compile,
+ asr_enabled=self.args.load_asr_model,
+ llama_checkpoint_path=self.args.llama_checkpoint_path,
+ decoder_checkpoint_path=self.args.decoder_checkpoint_path,
+ decoder_config_name=self.args.decoder_config_name,
+ )
+
+ logger.info(f"Startup done, listening server at http://{self.args.listen}")
+
+
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.
+
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+
+if __name__ == "__main__":
+ api = API()
+
+ # IPv6 address format is [xxxx:xxxx::xxxx]:port
+ match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
+ if match:
+ host, port = match.groups() # IPv6
+ else:
+ host, port = api.args.listen.split(":") # IPv4
+
+ uvicorn.run(
+ api.app,
+ host=host,
+ port=int(port),
+ workers=api.args.workers,
+ log_level="info",
+ )
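
For context, a minimal sketch (not part of the patch) of the listen-address parsing used in the `__main__` block above, shown on two illustrative endpoints:

```python
import re


def split_listen(listen: str) -> tuple[str, int]:
    # IPv6 endpoints are written as [xxxx::xxxx]:port, IPv4 as host:port.
    match = re.search(r"\[([^\]]+)\]:(\d+)$", listen)
    if match:
        host, port = match.groups()
    else:
        host, port = listen.split(":")
    return host, int(port)


print(split_listen("127.0.0.1:8080"))  # ('127.0.0.1', 8080)
print(split_listen("[::1]:8080"))      # ('::1', 8080)
```
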
diff --git a/tools/download_models.py b/tools/download_models.py
index fc735d36e5e07645d46faa035cd5cd3ad88ebdb3..79e23b0daaef125de26c654b09fbb9eeb8fe43cb 100644
--- a/tools/download_models.py
+++ b/tools/download_models.py
@@ -1,55 +1,55 @@
-import os
-
-from huggingface_hub import hf_hub_download
-
-
-# Download
-def check_and_download_files(repo_id, file_list, local_dir):
- os.makedirs(local_dir, exist_ok=True)
- for file in file_list:
- file_path = os.path.join(local_dir, file)
- if not os.path.exists(file_path):
- print(f"{file} 不存在,从 Hugging Face 仓库下载...")
- hf_hub_download(
- repo_id=repo_id,
- filename=file,
- resume_download=True,
- local_dir=local_dir,
- local_dir_use_symlinks=False,
- )
- else:
- print(f"{file} 已存在,跳过下载。")
-
-
-# 1st
-repo_id_1 = "fishaudio/fish-speech-1.4"
-local_dir_1 = "./checkpoints/fish-speech-1.4"
-files_1 = [
- "model.pth",
- "README.md",
- "special_tokens_map.json",
- "tokenizer_config.json",
- "tokenizer.json",
- "config.json",
- "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-]
-
-# 3rd
-repo_id_3 = "fishaudio/fish-speech-1"
-local_dir_3 = "./"
-files_3 = [
- "ffmpeg.exe",
- "ffprobe.exe",
-]
-
-# 4th
-repo_id_4 = "SpicyqSama007/fish-speech-packed"
-local_dir_4 = "./"
-files_4 = [
- "asr-label-win-x64.exe",
-]
-
-check_and_download_files(repo_id_1, files_1, local_dir_1)
-
-check_and_download_files(repo_id_3, files_3, local_dir_3)
-check_and_download_files(repo_id_4, files_4, local_dir_4)
+import os
+
+from huggingface_hub import hf_hub_download
+
+
+# Download
+def check_and_download_files(repo_id, file_list, local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+ for file in file_list:
+ file_path = os.path.join(local_dir, file)
+ if not os.path.exists(file_path):
+ print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=file,
+ resume_download=True,
+ local_dir=local_dir,
+ local_dir_use_symlinks=False,
+ )
+ else:
+ print(f"{file} 已存在,跳过下载。")
+
+
+# 1st
+repo_id_1 = "fishaudio/fish-speech-1.5"
+local_dir_1 = "./checkpoints/openaudio-s1-mini"
+files_1 = [
+ ".gitattributes",
+ "model.pth",
+ "README.md",
+ "special_tokens.json",
+ "tokenizer.tiktoken",
+ "config.json",
+ "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+]
+
+# 3rd
+repo_id_3 = "fishaudio/fish-speech-1"
+local_dir_3 = "./"
+files_3 = [
+ "ffmpeg.exe",
+ "ffprobe.exe",
+]
+
+# 4th
+repo_id_4 = "SpicyqSama007/fish-speech-packed"
+local_dir_4 = "./"
+files_4 = [
+ "asr-label-win-x64.exe",
+]
+
+check_and_download_files(repo_id_1, files_1, local_dir_1)
+
+check_and_download_files(repo_id_3, files_3, local_dir_3)
+check_and_download_files(repo_id_4, files_4, local_dir_4)
diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py
index 19f8d992312fd98bfb9e1e8e200ab6b5e8153337..20e2219956adc419aba91cde5d9097fad4288315 100644
--- a/tools/llama/build_dataset.py
+++ b/tools/llama/build_dataset.py
@@ -1,169 +1,169 @@
-import itertools
-import os
-import re
-from collections import defaultdict
-from functools import partial
-from multiprocessing import Pool
-from pathlib import Path
-
-import click
-import numpy as np
-from loguru import logger
-from tqdm import tqdm
-
-from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
-from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
-from tools.file import load_filelist
-
-# To avoid CPU overload
-os.environ["MKL_NUM_THREADS"] = "1"
-os.environ["OMP_NUM_THREADS"] = "1"
-
-
-def task_generator_folder(root: Path, text_extension: str):
- files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
- files = sorted(files)
-
- grouped_files = defaultdict(list)
- for file in tqdm(files, desc=f"Grouping {root}"):
- p = str(file.parent)
- speaker = file.parent.name
-
- try:
- if isinstance(text_extension, str):
- texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
- else:
- texts = [
- file.with_suffix(ext).read_text(encoding="utf-8")
- for ext in text_extension
- ]
- except Exception as e:
- logger.error(f"Failed to read text {file}: {e}")
- continue
-
- grouped_files[p].append((speaker, file, texts))
-
- logger.info(
- f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
- )
-
- for i in grouped_files.values():
- subset = [(f, t) for _, f, t in i]
- yield i[0][0], subset, "folder"
-
-
-def task_generator_filelist(filelist):
- grouped_files = defaultdict(list)
- for filename, speaker, _, text in load_filelist(filelist):
- grouped_files[speaker].append((Path(filename), [text]))
-
- logger.info(f"Found {len(grouped_files)} groups in {filelist}")
- for speaker, values in grouped_files.items():
- yield speaker, values, "filelist"
-
-
-def run_task(task):
- name, subset, source = task
-
- # Parse the files
- sentences = []
- for file, texts in subset:
- np_file = file.with_suffix(".npy")
- if np_file.exists() is False:
- logger.warning(f"Can't find {np_file}")
- continue
-
- new_texts = []
-
- for text in texts:
- # Simple cleaning: replace { xxx } and < xxx > with space
- text = re.sub(r"\{.*?\}", " ", text)
- text = re.sub(r"<.*?>", " ", text)
- text = re.sub(r"\s+", " ", text)
- new_texts.append(text)
-
- try:
- semantics = np.load(np_file)
- except Exception as e:
- logger.error(f"Failed to parse {file}: {e}")
- continue
-
- if isinstance(semantics, np.ndarray):
- semantics = semantics.tolist()
-
- sentences.append(
- Sentence(
- texts=new_texts,
- semantics=[Semantics(values=s) for s in semantics],
- )
- )
-
- # Pack the sentences
- return pack_pb_stream(
- TextData(
- source=source,
- name=name,
- sentences=sentences,
- )
- )
-
-
-@click.command()
-@click.option(
- "--input",
- type=click.Path(path_type=Path),
- required=True,
- help="A folder containing the dataset or a filelist",
- multiple=True,
-)
-@click.option(
- "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
-)
-@click.option("--num-workers", type=int, default=16)
-@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
-@click.option(
- "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
-)
-def main(input, output, num_workers, text_extension, shard_size):
- generator_fns = []
-
- for f in input:
- assert f.exists(), f"{f} not found"
-
- if f.is_dir():
- generator_fn = task_generator_folder(f, text_extension)
- else:
- generator_fn = task_generator_filelist(f)
-
- generator_fns.append(generator_fn)
-
- generator_fn = itertools.chain(*generator_fns)
- output.mkdir(parents=True, exist_ok=True)
-
- dataset_fp = None
- tar_idx = 0
- written_size = 0
-
- with Pool(num_workers) as p:
- for result in tqdm(p.imap_unordered(run_task, generator_fn)):
- if dataset_fp is None:
- dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
-
- dataset_fp.write(result)
- written_size += len(result)
-
- if written_size > shard_size * 1024 * 1024:
- logger.info(f"Finished writing {tar_idx} shards to {output}")
- dataset_fp.close()
- dataset_fp = None
- written_size = 0
- tar_idx += 1
-
- if dataset_fp is not None:
- dataset_fp.close()
-
- logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
-
-
-if __name__ == "__main__":
- main()
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from fish_speech.utils.file import load_filelist
+
+# To avoid CPU overload
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+ files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+ files = sorted(files)
+
+ grouped_files = defaultdict(list)
+ for file in tqdm(files, desc=f"Grouping {root}"):
+ p = str(file.parent)
+ speaker = file.parent.name
+
+ try:
+ if isinstance(text_extension, str):
+ texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
+ else:
+ texts = [
+ file.with_suffix(ext).read_text(encoding="utf-8")
+ for ext in text_extension
+ ]
+ except Exception as e:
+ logger.error(f"Failed to read text {file}: {e}")
+ continue
+
+ grouped_files[p].append((speaker, file, texts))
+
+ logger.info(
+ f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+ )
+
+ for i in grouped_files.values():
+ subset = [(f, t) for _, f, t in i]
+ yield i[0][0], subset, "folder"
+
+
+def task_generator_filelist(filelist):
+ grouped_files = defaultdict(list)
+ for filename, speaker, _, text in load_filelist(filelist):
+ grouped_files[speaker].append((Path(filename), [text]))
+
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+ for speaker, values in grouped_files.items():
+ yield speaker, values, "filelist"
+
+
+def run_task(task):
+ name, subset, source = task
+
+ # Parse the files
+ sentences = []
+ for file, texts in subset:
+ np_file = file.with_suffix(".npy")
+ if np_file.exists() is False:
+ logger.warning(f"Can't find {np_file}")
+ continue
+
+ new_texts = []
+
+ for text in texts:
+ # Simple cleaning: replace { xxx } and < xxx > with space
+ text = re.sub(r"\{.*?\}", " ", text)
+ text = re.sub(r"<.*?>", " ", text)
+ text = re.sub(r"\s+", " ", text)
+ new_texts.append(text)
+
+ try:
+ semantics = np.load(np_file)
+ except Exception as e:
+ logger.error(f"Failed to parse {file}: {e}")
+ continue
+
+ if isinstance(semantics, np.ndarray):
+ semantics = semantics.tolist()
+
+ sentences.append(
+ Sentence(
+ texts=new_texts,
+ semantics=[Semantics(values=s) for s in semantics],
+ )
+ )
+
+ # Pack the sentences
+ return pack_pb_stream(
+ TextData(
+ source=source,
+ name=name,
+ sentences=sentences,
+ )
+ )
+
+
+@click.command()
+@click.option(
+ "--input",
+ type=click.Path(path_type=Path),
+ required=True,
+ help="A folder containing the dataset or a filelist",
+ multiple=True,
+)
+@click.option(
+ "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+ "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+ generator_fns = []
+
+ for f in input:
+ assert f.exists(), f"{f} not found"
+
+ if f.is_dir():
+ generator_fn = task_generator_folder(f, text_extension)
+ else:
+ generator_fn = task_generator_filelist(f)
+
+ generator_fns.append(generator_fn)
+
+ generator_fn = itertools.chain(*generator_fns)
+ output.mkdir(parents=True, exist_ok=True)
+
+ dataset_fp = None
+ tar_idx = 0
+ written_size = 0
+
+ with Pool(num_workers) as p:
+ for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+ if dataset_fp is None:
+ dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+ dataset_fp.write(result)
+ written_size += len(result)
+
+ if written_size > shard_size * 1024 * 1024:
+ logger.info(f"Finished writing {tar_idx} shards to {output}")
+ dataset_fp.close()
+ dataset_fp = None
+ written_size = 0
+ tar_idx += 1
+
+ if dataset_fp is not None:
+ dataset_fp.close()
+
+ logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+ main()
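
For context, a minimal sketch (not part of the patch) of the transcript-cleaning rules applied inside `run_task`, on an illustrative string:

```python
import re


def clean_text(text: str) -> str:
    text = re.sub(r"\{.*?\}", " ", text)  # drop {...} annotations
    text = re.sub(r"<.*?>", " ", text)    # drop <...> tags
    return re.sub(r"\s+", " ", text)      # collapse whitespace


print(clean_text("Hello {laugh} <b>world</b>  !"))  # "Hello world !"
```
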
diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py
index a62f006ba443e14b2450bf9e15927a41556b0068..41d6397472e712d796a6668aa21e84835b87d899 100644
--- a/tools/llama/eval_in_context.py
+++ b/tools/llama/eval_in_context.py
@@ -1,171 +1,171 @@
-import pyrootutils
-import torch
-import torch.nn.functional as F
-from matplotlib import pyplot as plt
-from transformers import AutoTokenizer
-
-# register eval resolver and root
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
-from torch.utils.data import DataLoader
-
-from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
-from tools.llama.generate import load_model
-
-
-def smooth(
- scalars: list[float], weight: float
-) -> list[float]: # Weight between 0 and 1
- last = scalars[0] # First value in the plot (first timestep)
- smoothed = list()
- for point in scalars:
- smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
- smoothed.append(smoothed_val) # Save it
- last = smoothed_val # Anchor the last smoothed value
-
- return smoothed
-
-
-@torch.inference_mode()
-def analyze_one_model(loader, config, weight, max_length):
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model = load_model(
- config,
- weight,
- device,
- torch.bfloat16,
- max_length,
- compile=False,
- )[0]
-
- current_step = 0
- model.eval()
-
- semantic_loss_sum = torch.zeros(
- max_length,
- dtype=torch.float32,
- device=device,
- )
- counter = torch.zeros(
- max_length,
- dtype=torch.long,
- device=device,
- )
-
- for batch in loader:
- batch = {k: v.to(device) for k, v in batch.items()}
-
- labels = batch["labels"]
- outputs = model(
- inp=batch["inputs"],
- key_padding_mask=batch["attention_masks"],
- )
-
- token_logits = outputs.token_logits
- codebook_logits = outputs.codebook_logits
-
- # Generate labels
- base_loss = F.cross_entropy(
- token_logits.reshape(-1, token_logits.size(-1)),
- labels[:, 0].reshape(-1),
- ignore_index=-100,
- reduction="none",
- )
-
- codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
- semantic_loss = F.cross_entropy(
- codebook_logits.reshape(-1, codebook_logits.size(-1)),
- codebook_labels.reshape(-1),
- ignore_index=-100,
- reduction="none",
- )
-
- base_loss = base_loss.reshape(labels[:, 0].shape)
- semantic_loss = semantic_loss.reshape(codebook_labels.shape)
-
- semantic_loss_frame = semantic_loss.mean(-1)
- pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
-
- for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
- semantic_loss_sum[~pad] += loss_sample[~pad]
- counter[~pad] += 1
-
- current_step += 1
- if current_step == 10:
- break
-
- semantic_loss = semantic_loss.cpu()
- counter = counter.cpu()
- xs, ys = [], []
-
- for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
- if count > 0:
- xs.append(i)
- ys.append((loss / count).item()) # for better loss visualization
-
- smoothed_ys = smooth(ys, 0.95)
-
- # Unload model
- del model
- torch.cuda.empty_cache()
-
- return xs, ys, smoothed_ys
-
-
-def main():
- tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
- max_length = 4096
-
- ds = AutoAugTextDataset(
- ["data/protos/sft/云天河"],
- tokenizer=tokenizer,
- use_speaker=False,
- interactive_prob=1.0,
- max_length=max_length,
- )
-
- loader = DataLoader(
- ds,
- batch_size=8,
- collate_fn=TextDataCollator(tokenizer, max_length=max_length),
- num_workers=0,
- shuffle=False,
- )
-
- plt.figure(figsize=(10, 5), dpi=200)
-
- plt.xlabel("Frame")
- plt.ylabel("Loss")
- plt.yscale("log")
- plt.title("Semantic Loss")
- plt.grid(which="both", axis="both")
- plt.xlim(0, max_length)
-
- tests = [
- (
- "pertrain-medium",
- "dual_ar_2_codebook_medium",
- "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
- ),
- (
- "sft-medium",
- "dual_ar_2_codebook_medium",
- "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
- ),
- (
- "sft-large",
- "dual_ar_2_codebook_large",
- "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
- ),
- ]
-
- for name, config, weight in tests:
- xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
- plt.plot(xs, smoothed_ys, label=name)
-
- plt.legend()
- plt.savefig("semantic_loss.png")
-
-
-if __name__ == "__main__":
- main()
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
+from fish_speech.models.text2semantic.inference import load_model
+
+
+def smooth(
+ scalars: list[float], weight: float
+) -> list[float]: # Weight between 0 and 1
+ last = scalars[0] # First value in the plot (first timestep)
+ smoothed = list()
+ for point in scalars:
+ smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
+ smoothed.append(smoothed_val) # Save it
+ last = smoothed_val # Anchor the last smoothed value
+
+ return smoothed
+
+
+@torch.inference_mode()
+def analyze_one_model(loader, config, weight, max_length):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = load_model(
+ config,
+ weight,
+ device,
+ torch.bfloat16,
+ max_length,
+ compile=False,
+ )[0]
+
+ current_step = 0
+ model.eval()
+
+ semantic_loss_sum = torch.zeros(
+ max_length,
+ dtype=torch.float32,
+ device=device,
+ )
+ counter = torch.zeros(
+ max_length,
+ dtype=torch.long,
+ device=device,
+ )
+
+ for batch in loader:
+ batch = {k: v.to(device) for k, v in batch.items()}
+
+ labels = batch["labels"]
+ outputs = model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+ # Generate labels
+ base_loss = F.cross_entropy(
+ token_logits.reshape(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ reduction="none",
+ )
+
+ codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ reduction="none",
+ )
+
+ base_loss = base_loss.reshape(labels[:, 0].shape)
+ semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+ semantic_loss_frame = semantic_loss.mean(-1)
+ pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+ for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+ semantic_loss_sum[~pad] += loss_sample[~pad]
+ counter[~pad] += 1
+
+ current_step += 1
+ if current_step == 10:
+ break
+
+ semantic_loss = semantic_loss.cpu()
+ counter = counter.cpu()
+ xs, ys = [], []
+
+ for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+ if count > 0:
+ xs.append(i)
+ ys.append((loss / count).item()) # for better loss visualization
+
+ smoothed_ys = smooth(ys, 0.95)
+
+ # Unload model
+ del model
+ torch.cuda.empty_cache()
+
+ return xs, ys, smoothed_ys
+
+
+def main():
+ tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+ max_length = 4096
+
+ ds = AutoAugTextDataset(
+ ["data/protos/sft/云天河"],
+ tokenizer=tokenizer,
+ use_speaker=False,
+ interactive_prob=1.0,
+ max_length=max_length,
+ )
+
+ loader = DataLoader(
+ ds,
+ batch_size=8,
+ collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+ num_workers=0,
+ shuffle=False,
+ )
+
+ plt.figure(figsize=(10, 5), dpi=200)
+
+ plt.xlabel("Frame")
+ plt.ylabel("Loss")
+ plt.yscale("log")
+ plt.title("Semantic Loss")
+ plt.grid(which="both", axis="both")
+ plt.xlim(0, max_length)
+
+ tests = [
+ (
+ "pertrain-medium",
+ "dual_ar_2_codebook_medium",
+ "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+ ),
+ (
+ "sft-medium",
+ "dual_ar_2_codebook_medium",
+ "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+ ),
+ (
+ "sft-large",
+ "dual_ar_2_codebook_large",
+ "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+ ),
+ ]
+
+ for name, config, weight in tests:
+ xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+ plt.plot(xs, smoothed_ys, label=name)
+
+ plt.legend()
+ plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+ main()
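
For context, a minimal sketch (not part of the patch) of the exponential smoothing performed by `smooth()` above, inlined on a toy loss curve (the script uses `weight=0.95`; `0.5` makes the effect easy to see):

```python
values = [4.0, 2.0, 1.0, 0.5]
weight = 0.5

last = values[0]
smoothed = []
for point in values:
    last = last * weight + (1 - weight) * point
    smoothed.append(last)

print(smoothed)  # [4.0, 3.0, 2.0, 1.25]
```
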
diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py
index 510dd4fc25ac1a51593cb189ea79d8a7f429548a..1080ff5668f6712a7bd51d28476369c49806775d 100644
--- a/tools/llama/merge_lora.py
+++ b/tools/llama/merge_lora.py
@@ -1,95 +1,96 @@
-import shutil
-from copy import deepcopy
-from pathlib import Path
-
-import click
-import hydra
-import torch
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from loguru import logger
-
-from fish_speech.models.text2semantic.llama import BaseTransformer
-from fish_speech.models.text2semantic.lora import get_merged_state_dict
-
-
-@click.command()
-@click.option("--lora-config", type=str, default="r_8_alpha_16")
-@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
-@click.option("--lora-weight", type=str, required=True)
-@click.option("--output", type=str, required=True)
-def merge(lora_config, base_weight, lora_weight, output):
- output = Path(output)
- logger.info(
- f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
- )
-
- with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
- cfg = compose(config_name=lora_config)
-
- lora_config = instantiate(cfg)
- logger.info(f"Loaded lora model with config {lora_config}")
-
- llama_model = BaseTransformer.from_pretrained(
- path=base_weight,
- load_weights=True,
- lora_config=lora_config,
- )
- logger.info(f"Loaded llama model")
-
- llama_state_dict = llama_model.state_dict()
- llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
- llama_state_dict_copy = deepcopy(llama_state_dict)
- lora_state_dict = torch.load(lora_weight, map_location="cpu")
-
- if "state_dict" in llama_state_dict:
- llama_state_dict = llama_state_dict["state_dict"]
-
- if "state_dict" in lora_state_dict:
- lora_state_dict = lora_state_dict["state_dict"]
-
- # remove prefix model.
- if any(k.startswith("model.") for k in llama_state_dict.keys()):
- llama_state_dict = {
- k.replace("model.", ""): v
- for k, v in llama_state_dict.items()
- if k.startswith("model.")
- }
- if any(k.startswith("model.") for k in lora_state_dict.keys()):
- lora_state_dict = {
- k.replace("model.", ""): v
- for k, v in lora_state_dict.items()
- if k.startswith("model.")
- }
-
- logger.info(f"Found {len(llama_state_dict)} keys in llama model")
- logger.info(f"Found {len(lora_state_dict)} keys in lora model")
-
- merged_state_dict = llama_state_dict | lora_state_dict
- llama_model.load_state_dict(merged_state_dict, strict=True)
- logger.info(f"Merged model loaded")
-
- # Trigger eval mode to merge lora
- llama_model.eval()
- llama_model.save_pretrained(output, drop_lora=True)
- logger.info(f"Saved merged model to {output}, validating")
-
- new_state_dict = torch.load(output / "model.pth", map_location="cpu")
- original_keys = set(llama_state_dict_copy.keys())
- merged_keys = set(new_state_dict.keys())
-
- assert original_keys == merged_keys, "Keys should be same"
-
- for key in original_keys:
- diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
- if diff_l1 != 0:
- break
- else:
- logger.error("Merged model is same as the original model")
- exit(1)
-
- logger.info("Merged model is different from the original model, check passed")
-
-
-if __name__ == "__main__":
- merge()
+import shutil
+from copy import deepcopy
+from pathlib import Path
+
+import click
+import hydra
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+
+from fish_speech.models.text2semantic.llama import BaseTransformer
+from fish_speech.models.text2semantic.lora import get_merged_state_dict
+
+
+@click.command()
+@click.option("--lora-config", type=str, default="r_8_alpha_16")
+@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
+@click.option("--lora-weight", type=str, required=True)
+@click.option("--output", type=str, required=True)
+def merge(lora_config, base_weight, lora_weight, output):
+ output = Path(output)
+ logger.info(
+ f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
+ )
+
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
+ cfg = compose(config_name=lora_config)
+
+ lora_config = instantiate(cfg)
+ logger.info(f"Loaded lora model with config {lora_config}")
+
+ llama_model = BaseTransformer.from_pretrained(
+ path=base_weight,
+ load_weights=True,
+ lora_config=lora_config,
+ )
+ logger.info(f"Loaded llama model")
+
+ llama_state_dict = llama_model.state_dict()
+ llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
+ llama_state_dict_copy = deepcopy(llama_state_dict)
+ lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False)
+
+ if "state_dict" in llama_state_dict:
+ llama_state_dict = llama_state_dict["state_dict"]
+
+ if "state_dict" in lora_state_dict:
+ lora_state_dict = lora_state_dict["state_dict"]
+
+ # remove prefix model.
+ if any(k.startswith("model.") for k in llama_state_dict.keys()):
+ llama_state_dict = {
+ k.replace("model.", ""): v
+ for k, v in llama_state_dict.items()
+ if k.startswith("model.")
+ }
+ if any(k.startswith("model.") for k in lora_state_dict.keys()):
+ lora_state_dict = {
+ k.replace("model.", ""): v
+ for k, v in lora_state_dict.items()
+ if k.startswith("model.")
+ }
+
+ logger.info(f"Found {len(llama_state_dict)} keys in llama model")
+ logger.info(f"Found {len(lora_state_dict)} keys in lora model")
+
+ merged_state_dict = llama_state_dict | lora_state_dict
+ llama_model.load_state_dict(merged_state_dict, strict=True)
+ logger.info(f"Merged model loaded")
+
+ # Trigger eval mode to merge lora
+ llama_model.eval()
+ llama_model.save_pretrained(output, drop_lora=True)
+ logger.info(f"Saved merged model to {output}, validating")
+
+ new_state_dict = torch.load(output / "model.pth", map_location="cpu")
+ original_keys = set(llama_state_dict_copy.keys())
+
+ tolerance = 1e-5
+ for key in original_keys:
+ diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
+ if diff_l1 > tolerance:
+ logger.info(f"Significant difference found in key: {key}")
+ break
+
+ if diff_l1 <= tolerance:
+ logger.warning(
+ "Merged model seems identical to the original model. Further validation might be needed."
+ )
+ else:
+ logger.info("Merged model is different from the original model, check passed")
+
+
+if __name__ == "__main__":
+ merge()
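
For context, a minimal sketch (not part of the patch) of the post-merge validation added here: it reduces to an L1 comparison per parameter tensor, illustrated on toy state dicts:

```python
import torch

original = {"w": torch.zeros(4), "b": torch.ones(2)}
merged = {"w": torch.tensor([0.0, 0.1, 0.0, 0.0]), "b": torch.ones(2)}

tolerance = 1e-5
changed = [
    key
    for key in original
    if (merged[key] - original[key]).abs().sum().item() > tolerance
]
print(changed)  # ['w']: at least one tensor moved, so the merge changed something
```
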
diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py
index 2a8124da4d3634bb57e8fa5368228d51bb712f77..7cd29891829432a263fcd6a6d58bd247a7d7d587 100644
--- a/tools/llama/quantize.py
+++ b/tools/llama/quantize.py
@@ -1,497 +1,497 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-import datetime
-import shutil
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import time
-from pathlib import Path
-
-import click
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from fish_speech.models.text2semantic.llama import find_multiple
-from tools.llama.generate import load_model
-
-##### Quantization Primitives ######
-
-
-def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
- # assumes symmetric quantization
- # assumes axis == 0
- # assumes dense memory format
- # TODO(future): relax ^ as needed
-
- # default setup for affine quantization of activations
- eps = torch.finfo(torch.float32).eps
-
- # get min and max
- min_val, max_val = torch.aminmax(x, dim=1)
-
- # calculate scales and zero_points based on min and max
- # reference: https://fburl.com/code/srbiybme
- min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
- max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
- device = min_val_neg.device
-
- # reference: https://fburl.com/code/4wll53rk
- max_val_pos = torch.max(-min_val_neg, max_val_pos)
- scales = max_val_pos / (float(quant_max - quant_min) / 2)
- # ensure scales is the same dtype as the original tensor
- scales = torch.clamp(scales, min=eps).to(x.dtype)
- zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
-
- # quantize based on qmin/qmax/scales/zp
- # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
- x_div = x / scales.unsqueeze(-1)
- x_round = torch.round(x_div)
- x_zp = x_round + zero_points.unsqueeze(-1)
- quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
-
- return quant, scales, zero_points
-
-
-def get_group_qparams(w, n_bit=4, groupsize=128):
- # needed for GPTQ with padding
- if groupsize > w.shape[-1]:
- groupsize = w.shape[-1]
- assert groupsize > 1
- assert w.shape[-1] % groupsize == 0
- assert w.dim() == 2
-
- to_quant = w.reshape(-1, groupsize)
- assert torch.isnan(to_quant).sum() == 0
-
- max_val = to_quant.amax(dim=1, keepdim=True)
- min_val = to_quant.amin(dim=1, keepdim=True)
- max_int = 2**n_bit - 1
- scales = (max_val - min_val).clamp(min=1e-6) / max_int
- zeros = min_val + scales * (2 ** (n_bit - 1))
- return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
- torch.bfloat16
- ).reshape(w.shape[0], -1)
-
-
-def pack_scales_and_zeros(scales, zeros):
- assert scales.shape == zeros.shape
- assert scales.dtype == torch.bfloat16
- assert zeros.dtype == torch.bfloat16
- return (
- torch.cat(
- [
- scales.reshape(scales.size(0), scales.size(1), 1),
- zeros.reshape(zeros.size(0), zeros.size(1), 1),
- ],
- 2,
- )
- .transpose(0, 1)
- .contiguous()
- )
-
-
-def unpack_scales_and_zeros(scales_and_zeros):
- assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
- assert scales_and_zeros.dtype == torch.float
- return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
-
-
-def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
- assert groupsize > 1
- # needed for GPTQ single column quantize
- if groupsize > w.shape[-1] and scales.shape[-1] == 1:
- groupsize = w.shape[-1]
-
- assert w.shape[-1] % groupsize == 0
- assert w.dim() == 2
-
- to_quant = w.reshape(-1, groupsize)
- assert torch.isnan(to_quant).sum() == 0
-
- scales = scales.reshape(-1, 1)
- zeros = zeros.reshape(-1, 1)
- min_val = zeros - scales * (2 ** (n_bit - 1))
- max_int = 2**n_bit - 1
- min_int = 0
- w_int32 = (
- to_quant.sub(min_val)
- .div(scales)
- .round()
- .clamp_(min_int, max_int)
- .to(torch.int32)
- .reshape_as(w)
- )
-
- return w_int32
-
-
-def group_quantize_tensor(w, n_bit=4, groupsize=128):
- scales, zeros = get_group_qparams(w, n_bit, groupsize)
- w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
- scales_and_zeros = pack_scales_and_zeros(scales, zeros)
- return w_int32, scales_and_zeros
-
-
-def group_dequantize_tensor_from_qparams(
- w_int32, scales, zeros, n_bit=4, groupsize=128
-):
- assert groupsize > 1
- # needed for GPTQ single column dequantize
- if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
- groupsize = w_int32.shape[-1]
- assert w_int32.shape[-1] % groupsize == 0
- assert w_int32.dim() == 2
-
- w_int32_grouped = w_int32.reshape(-1, groupsize)
- scales = scales.reshape(-1, 1)
- zeros = zeros.reshape(-1, 1)
-
- w_dq = (
- w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
- )
- return w_dq
-
-
-def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
- scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
- return group_dequantize_tensor_from_qparams(
- w_int32, scales, zeros, n_bit, groupsize
- )
-
-
-class QuantHandler:
- def __init__(self, mod):
- self.mod = mod
-
- def create_quantized_state_dict(self) -> "StateDict":
- pass
-
- def convert_for_runtime(self) -> "nn.Module":
- pass
-
-
-##### Weight-only int8 per-channel quantized code ######
-
-
-def replace_linear_weight_only_int8_per_channel(module):
- for name, child in module.named_children():
- if isinstance(child, nn.Linear):
- setattr(
- module,
- name,
- WeightOnlyInt8Linear(child.in_features, child.out_features),
- )
- else:
- replace_linear_weight_only_int8_per_channel(child)
-
-
-class WeightOnlyInt8QuantHandler:
- def __init__(self, mod):
- self.mod = mod
-
- @torch.no_grad()
- def create_quantized_state_dict(self):
- cur_state_dict = self.mod.state_dict()
- for fqn, mod in self.mod.named_modules():
- if isinstance(mod, torch.nn.Linear):
- int8_weight, scales, _ = dynamically_quantize_per_channel(
- mod.weight.float(), -128, 127, torch.int8
- )
- cur_state_dict[f"{fqn}.weight"] = int8_weight
- cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
-
- return cur_state_dict
-
- def convert_for_runtime(self):
- replace_linear_weight_only_int8_per_channel(self.mod)
- return self.mod
-
-
-class WeightOnlyInt8Linear(torch.nn.Module):
- __constants__ = ["in_features", "out_features"]
- in_features: int
- out_features: int
- weight: torch.Tensor
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.in_features = in_features
- self.out_features = out_features
- self.register_buffer(
- "weight", torch.empty((out_features, in_features), dtype=torch.int8)
- )
- self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
-
-
-##### weight only int4 per channel groupwise quantized code ######
-
-
-def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
- weight_int32, scales_and_zeros = group_quantize_tensor(
- weight_bf16, n_bit=4, groupsize=groupsize
- )
- weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
- weight_int32, inner_k_tiles
- )
- return weight_int4pack, scales_and_zeros
-
-
-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
- origin_x_size = x.size()
- x = x.reshape(-1, origin_x_size[-1])
- c = torch.ops.aten._weight_int4pack_mm(
- x, weight_int4pack, groupsize, scales_and_zeros
- )
- new_shape = origin_x_size[:-1] + (out_features,)
- c = c.reshape(new_shape)
- return c
-
-
-def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
- return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
-
-
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
- for name, child in module.named_children():
- if isinstance(child, nn.Linear):
- if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
- setattr(
- module,
- name,
- WeightOnlyInt4Linear(
- child.in_features,
- child.out_features,
- bias=False,
- groupsize=groupsize,
- inner_k_tiles=inner_k_tiles,
- padding=False,
- ),
- )
- elif padding:
- setattr(
- module,
- name,
- WeightOnlyInt4Linear(
- child.in_features,
- child.out_features,
- bias=False,
- groupsize=groupsize,
- inner_k_tiles=inner_k_tiles,
- padding=True,
- ),
- )
- else:
- replace_linear_int4(child, groupsize, inner_k_tiles, padding)
-
-
-class WeightOnlyInt4QuantHandler:
- def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
- self.mod = mod
- self.groupsize = groupsize
- self.inner_k_tiles = inner_k_tiles
- self.padding = padding
- assert groupsize in [32, 64, 128, 256]
- assert inner_k_tiles in [2, 4, 8]
-
- @torch.no_grad()
- def create_quantized_state_dict(self):
- cur_state_dict = self.mod.state_dict()
- for fqn, mod in self.mod.named_modules():
- if isinstance(mod, torch.nn.Linear):
- assert not mod.bias
- out_features = mod.out_features
- in_features = mod.in_features
- assert out_features % 8 == 0, "require out_features % 8 == 0"
- print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
- weight = mod.weight.data
- if not _check_linear_int4_k(
- in_features, self.groupsize, self.inner_k_tiles
- ):
- if self.padding:
- import torch.nn.functional as F
-
- print(
- f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
- )
- padded_in_features = find_multiple(in_features, 1024)
- weight = F.pad(
- weight, pad=(0, padded_in_features - in_features)
- )
- else:
- print(
- f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
- + "and that groupsize and inner_k_tiles*16 evenly divide into it"
- )
- continue
- (
- weight_int4pack,
- scales_and_zeros,
- ) = prepare_int4_weight_and_scales_and_zeros(
- weight.to(torch.bfloat16).to("cuda"),
- self.groupsize,
- self.inner_k_tiles,
- )
- cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
- cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
-
- return cur_state_dict
-
- def convert_for_runtime(self):
- replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
- return self.mod
-
-
-class WeightOnlyInt4Linear(torch.nn.Module):
- __constants__ = ["in_features", "out_features"]
- in_features: int
- out_features: int
- weight: torch.Tensor
-
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias=True,
- device=None,
- dtype=None,
- groupsize: int = 128,
- inner_k_tiles: int = 8,
- padding: bool = True,
- ) -> None:
- super().__init__()
- self.padding = padding
- if padding:
- self.origin_in_features = in_features
- in_features = find_multiple(in_features, 1024)
-
- self.in_features = in_features
- self.out_features = out_features
- assert not bias, "require bias=False"
- self.groupsize = groupsize
- self.inner_k_tiles = inner_k_tiles
-
- assert out_features % 8 == 0, "require out_features % 8 == 0"
- assert (
- in_features % (inner_k_tiles * 16) == 0
- ), "require in_features % (innerKTiles * 16) == 0"
- self.register_buffer(
- "weight",
- torch.empty(
- (
- out_features // 8,
- in_features // (inner_k_tiles * 16),
- 32,
- inner_k_tiles // 2,
- ),
- dtype=torch.int32,
- ),
- )
- self.register_buffer(
- "scales_and_zeros",
- torch.empty(
- (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
- ),
- )
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- input = input.to(torch.bfloat16)
- if self.padding:
- import torch.nn.functional as F
-
- input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
- return linear_forward_int4(
- input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
- )
-
-
-def generate_folder_name():
- now = datetime.datetime.now()
- folder_name = now.strftime("%Y%m%d_%H%M%S")
- return folder_name
-
-
-@click.command()
-@click.option(
- "--checkpoint-path",
- type=click.Path(path_type=Path, exists=True),
- default="checkpoints/fish-speech-1.4",
-)
-@click.option(
- "--mode", type=str, default="int8", help="type of quantization to perform"
-)
-@click.option(
- "--groupsize", type=int, default=128, help="Group size for int4 quantization."
-)
-@click.option("--timestamp", type=str, default="None", help="When to do quantization")
-def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
-
- device = "cpu"
- precision = torch.bfloat16
-
- print("Loading model ...")
- t0 = time.time()
-
- model, _ = load_model(
- checkpoint_path=checkpoint_path,
- device=device,
- precision=precision,
- compile=False,
- )
- vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
- now = timestamp if timestamp != "None" else generate_folder_name()
-
- if mode == "int8":
- print(
- "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
- )
- quant_handler = WeightOnlyInt8QuantHandler(model)
- quantized_state_dict = quant_handler.create_quantized_state_dict()
-
- dir_name = checkpoint_path
- dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
- shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
- if (dst_name / vq_model).exists():
- (dst_name / vq_model).unlink()
- quantize_path = dst_name / "model.pth"
-
- elif mode == "int4":
- print(
- "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
- )
- quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
- quantized_state_dict = quant_handler.create_quantized_state_dict()
-
- dir_name = checkpoint_path
- dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
- shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
- if (dst_name / vq_model).exists():
- (dst_name / vq_model).unlink()
- quantize_path = dst_name / "model.pth"
-
- else:
- raise ValueError(
- f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
- )
-
- print(f"Writing quantized weights to {quantize_path}")
- quantize_path.unlink(missing_ok=True) # remove existing file if one already there
- torch.save(quantized_state_dict, quantize_path)
- print(f"Quantization complete took {time.time() - t0:.02f} seconds")
-
-
-if __name__ == "__main__":
- quantize()
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+import datetime
+import shutil
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import time
+from pathlib import Path
+
+import click
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fish_speech.models.text2semantic.inference import load_model
+from fish_speech.models.text2semantic.llama import find_multiple
+
+##### Quantization Primitives ######
+
+
+def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+ # assumes symmetric quantization
+ # assumes axis == 0
+ # assumes dense memory format
+ # TODO(future): relax ^ as needed
+
+ # default setup for affine quantization of activations
+ eps = torch.finfo(torch.float32).eps
+
+ # get min and max
+ min_val, max_val = torch.aminmax(x, dim=1)
+
+ # calculate scales and zero_points based on min and max
+ # reference: https://fburl.com/code/srbiybme
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+ device = min_val_neg.device
+
+ # reference: https://fburl.com/code/4wll53rk
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
+ # ensure scales is the same dtype as the original tensor
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+ # quantize based on qmin/qmax/scales/zp
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
+ x_div = x / scales.unsqueeze(-1)
+ x_round = torch.round(x_div)
+ x_zp = x_round + zero_points.unsqueeze(-1)
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+ return quant, scales, zero_points
+
+
+def get_group_qparams(w, n_bit=4, groupsize=128):
+ # needed for GPTQ with padding
+ if groupsize > w.shape[-1]:
+ groupsize = w.shape[-1]
+ assert groupsize > 1
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ max_val = to_quant.amax(dim=1, keepdim=True)
+ min_val = to_quant.amin(dim=1, keepdim=True)
+ max_int = 2**n_bit - 1
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
+ zeros = min_val + scales * (2 ** (n_bit - 1))
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
+ torch.bfloat16
+ ).reshape(w.shape[0], -1)
+
+
+def pack_scales_and_zeros(scales, zeros):
+ assert scales.shape == zeros.shape
+ assert scales.dtype == torch.bfloat16
+ assert zeros.dtype == torch.bfloat16
+ return (
+ torch.cat(
+ [
+ scales.reshape(scales.size(0), scales.size(1), 1),
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
+ ],
+ 2,
+ )
+ .transpose(0, 1)
+ .contiguous()
+ )
+
+
+def unpack_scales_and_zeros(scales_and_zeros):
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+ assert scales_and_zeros.dtype == torch.float
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+ assert groupsize > 1
+ # needed for GPTQ single column quantize
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w.shape[-1]
+
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+ min_val = zeros - scales * (2 ** (n_bit - 1))
+ max_int = 2**n_bit - 1
+ min_int = 0
+ w_int32 = (
+ to_quant.sub(min_val)
+ .div(scales)
+ .round()
+ .clamp_(min_int, max_int)
+ .to(torch.int32)
+ .reshape_as(w)
+ )
+
+ return w_int32
+
+
+def group_quantize_tensor(w, n_bit=4, groupsize=128):
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+ return w_int32, scales_and_zeros
+
+
+def group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit=4, groupsize=128
+):
+ assert groupsize > 1
+ # needed for GPTQ single column dequantize
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w_int32.shape[-1]
+ assert w_int32.shape[-1] % groupsize == 0
+ assert w_int32.dim() == 2
+
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+
+ w_dq = (
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+ )
+ return w_dq
+
+
+def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+ return group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit, groupsize
+ )
+
+
+class QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ def create_quantized_state_dict(self) -> "StateDict":
+ pass
+
+ def convert_for_runtime(self) -> "nn.Module":
+ pass
+
+
+##### Weight-only int8 per-channel quantized code ######
+
+
+def replace_linear_weight_only_int8_per_channel(module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt8Linear(child.in_features, child.out_features),
+ )
+ else:
+ replace_linear_weight_only_int8_per_channel(child)
+
+
+class WeightOnlyInt8QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ int8_weight, scales, _ = dynamically_quantize_per_channel(
+ mod.weight.float(), -128, 127, torch.int8
+ )
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_weight_only_int8_per_channel(self.mod)
+ return self.mod
+
+
+class WeightOnlyInt8Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.register_buffer(
+ "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+ )
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+
+
+##### Weight-only int4 per-channel groupwise quantized code ######
+
+
+def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
+ weight_int32, scales_and_zeros = group_quantize_tensor(
+ weight_bf16, n_bit=4, groupsize=groupsize
+ )
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+ weight_int32, inner_k_tiles
+ )
+ return weight_int4pack, scales_and_zeros
+
+
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+ origin_x_size = x.size()
+ x = x.reshape(-1, origin_x_size[-1])
+ c = torch.ops.aten._weight_int4pack_mm(
+ x, weight_int4pack, groupsize, scales_and_zeros
+ )
+ new_shape = origin_x_size[:-1] + (out_features,)
+ c = c.reshape(new_shape)
+ return c
+
+
+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=False,
+ ),
+ )
+ elif padding:
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=True,
+ ),
+ )
+ else:
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+
+
+class WeightOnlyInt4QuantHandler:
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+ self.mod = mod
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+ self.padding = padding
+ assert groupsize in [32, 64, 128, 256]
+ assert inner_k_tiles in [2, 4, 8]
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ assert not mod.bias
+ out_features = mod.out_features
+ in_features = mod.in_features
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+ weight = mod.weight.data
+ if not _check_linear_int4_k(
+ in_features, self.groupsize, self.inner_k_tiles
+ ):
+ if self.padding:
+ import torch.nn.functional as F
+
+ print(
+ f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+ )
+ padded_in_features = find_multiple(in_features, 1024)
+ weight = F.pad(
+ weight, pad=(0, padded_in_features - in_features)
+ )
+ else:
+ print(
+ f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+ )
+ continue
+ (
+ weight_int4pack,
+ scales_and_zeros,
+ ) = prepare_int4_weight_and_scales_and_zeros(
+ weight.to(torch.bfloat16).to("cuda"),
+ self.groupsize,
+ self.inner_k_tiles,
+ )
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+ return self.mod
+
+
+class WeightOnlyInt4Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias=True,
+ device=None,
+ dtype=None,
+ groupsize: int = 128,
+ inner_k_tiles: int = 8,
+ padding: bool = True,
+ ) -> None:
+ super().__init__()
+ self.padding = padding
+ if padding:
+ self.origin_in_features = in_features
+ in_features = find_multiple(in_features, 1024)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ assert not bias, "require bias=False"
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ assert (
+ in_features % (inner_k_tiles * 16) == 0
+ ), "require in_features % (innerKTiles * 16) == 0"
+ self.register_buffer(
+ "weight",
+ torch.empty(
+ (
+ out_features // 8,
+ in_features // (inner_k_tiles * 16),
+ 32,
+ inner_k_tiles // 2,
+ ),
+ dtype=torch.int32,
+ ),
+ )
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty(
+ (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
+ ),
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ input = input.to(torch.bfloat16)
+ if self.padding:
+ import torch.nn.functional as F
+
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+ return linear_forward_int4(
+ input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+ )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+@click.command()
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="checkpoints/fish-speech-1.4",
+)
+@click.option(
+ "--mode", type=str, default="int8", help="type of quantization to perform"
+)
+@click.option(
+ "--groupsize", type=int, default=128, help="Group size for int4 quantization."
+)
+@click.option("--timestamp", type=str, default="None", help="When to do quantization")
+def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
+
+ device = "cpu"
+ precision = torch.bfloat16
+
+ print("Loading model ...")
+ t0 = time.time()
+
+ model, _ = load_model(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=False,
+ )
+ vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+ now = timestamp if timestamp != "None" else generate_folder_name()
+
+ if mode == "int8":
+ print(
+ "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
+ )
+ quant_handler = WeightOnlyInt8QuantHandler(model)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path
+ dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+ if (dst_name / vq_model).exists():
+ (dst_name / vq_model).unlink()
+ quantize_path = dst_name / "model.pth"
+
+ elif mode == "int4":
+ print(
+ "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
+ )
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path
+ dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+ if (dst_name / vq_model).exists():
+ (dst_name / vq_model).unlink()
+ quantize_path = dst_name / "model.pth"
+
+ else:
+ raise ValueError(
+ f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
+ )
+
+ print(f"Writing quantized weights to {quantize_path}")
+ quantize_path.unlink(missing_ok=True) # remove existing file if one already there
+ torch.save(quantized_state_dict, quantize_path)
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
+
+
+if __name__ == "__main__":
+ quantize()
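
Aside (not part of the patch): the int4 helpers above implement a standard group-wise affine scheme, where each group of `groupsize` weights shares one scale and zero point. Below is a minimal round-trip sketch, assuming the repository root is on `PYTHONPATH` so the module imports as `tools.llama.quantize`; the tensor shapes are arbitrary.

```python
# Round trip a random weight matrix through the int4 group quantize/dequantize
# helpers defined above, to show the expected (small) reconstruction error.
import torch

from tools.llama.quantize import group_dequantize_tensor, group_quantize_tensor

w = torch.randn(256, 512, dtype=torch.bfloat16)  # (out_features, in_features)

# Quantize to int32 codes plus packed per-group scales and zero points
w_int32, scales_and_zeros = group_quantize_tensor(w, n_bit=4, groupsize=128)

# unpack_scales_and_zeros expects float32, so cast before dequantizing
w_hat = group_dequantize_tensor(
    w_int32, scales_and_zeros.float(), n_bit=4, groupsize=128
)

print((w.float() - w_hat).abs().mean())  # mean reconstruction error
```

The error should be small relative to the weight magnitude, which is what makes weight-only int4 viable at inference time.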
diff --git a/tools/run_webui.py b/tools/run_webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc004ef51c9443cf100796485c0e83b856738be
--- /dev/null
+++ b/tools/run_webui.py
@@ -0,0 +1,104 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.models.dac.inference import load_model as load_decoder_model
+from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
+from fish_speech.utils.schema import ServeTTSRequest
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/openaudio-s1-mini",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/openaudio-s1-mini/codec.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+ parser.add_argument("--theme", type=str, default="light")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ # Check if MPS or CUDA is available
+ if torch.backends.mps.is_available():
+ args.device = "mps"
+ logger.info("mps is available, running on mps.")
+ elif not torch.cuda.is_available():
+ logger.info("CUDA is not available, running on CPU.")
+ args.device = "cpu"
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+
+ logger.info("Loading VQ-GAN model...")
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Create the inference engine
+ inference_engine = TTSInferenceEngine(
+ llama_queue=llama_queue,
+ decoder_model=decoder_model,
+ compile=args.compile,
+ precision=args.precision,
+ )
+
+    # Dry run to check that the model is loaded correctly and to avoid first-request latency
+ list(
+ inference_engine.inference(
+ ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.5,
+ temperature=0.7,
+ format="wav",
+ )
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ # Get the inference function with the immutable arguments
+ inference_fct = get_inference_wrapper(inference_engine)
+
+ app = build_app(inference_fct, args.theme)
+ app.launch(show_api=True)
diff --git a/tools/server/api_utils.py b/tools/server/api_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe990041bceac39a6b85d97fc019673af3d9013
--- /dev/null
+++ b/tools/server/api_utils.py
@@ -0,0 +1,76 @@
+from argparse import ArgumentParser
+from http import HTTPStatus
+from typing import Annotated, Any
+
+import ormsgpack
+from baize.datastructures import ContentType
+from kui.asgi import HTTPException, HttpRequest
+
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.utils.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+ parser.add_argument("--load-asr-model", action="store_true")
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=str,
+ default="checkpoints/openaudio-s1-mini",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=str,
+ default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-text-length", type=int, default=0)
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
+ parser.add_argument("--workers", type=int, default=1)
+ parser.add_argument("--api-key", type=str, default=None)
+
+ return parser.parse_args()
+
+
+class MsgPackRequest(HttpRequest):
+ async def data(
+ self,
+ ) -> Annotated[
+ Any, ContentType("application/msgpack"), ContentType("application/json")
+ ]:
+ if self.content_type == "application/msgpack":
+ return ormsgpack.unpackb(await self.body)
+
+ elif self.content_type == "application/json":
+ return await self.json
+
+ raise HTTPException(
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+ headers={"Accept": "application/msgpack, application/json"},
+ )
+
+
+async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
+ for chunk in inference(req, engine):
+ if isinstance(chunk, bytes):
+ yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+ yield buffer
+
+
+def get_content_type(audio_format):
+ if audio_format == "wav":
+ return "audio/wav"
+ elif audio_format == "flac":
+ return "audio/flac"
+ elif audio_format == "mp3":
+ return "audio/mpeg"
+ else:
+ return "application/octet-stream"
diff --git a/tools/server/exception_handler.py b/tools/server/exception_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d595fabb7af4e00a1fb67a78b466fea0c2c0f4
--- /dev/null
+++ b/tools/server/exception_handler.py
@@ -0,0 +1,27 @@
+import traceback
+from http import HTTPStatus
+
+from kui.asgi import HTTPException, JSONResponse
+
+
+class ExceptionHandler:
+
+ async def http_exception_handler(self, exc: HTTPException):
+ return JSONResponse(
+ dict(
+ statusCode=exc.status_code,
+ message=exc.content,
+ error=HTTPStatus(exc.status_code).phrase,
+ ),
+ exc.status_code,
+ exc.headers,
+ )
+
+ async def other_exception_handler(self, exc: Exception):
+ traceback.print_exc()
+
+ status = HTTPStatus.INTERNAL_SERVER_ERROR
+ return JSONResponse(
+ dict(statusCode=status, message=str(exc), error=status.phrase),
+ status,
+ )
diff --git a/tools/server/inference.py b/tools/server/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..060e24b968d7c7b04f8f99f302cff0773b1ecdd8
--- /dev/null
+++ b/tools/server/inference.py
@@ -0,0 +1,45 @@
+from http import HTTPStatus
+
+import numpy as np
+from kui.asgi import HTTPException
+
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.utils.schema import ServeTTSRequest
+
+AMPLITUDE = 32768  # int16 full-scale factor: float audio in [-1, 1] is scaled to signed 16-bit PCM
+
+
+def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
+ """
+ Wrapper for the inference function.
+ Used in the API server.
+ """
+ count = 0
+ for result in engine.inference(req):
+ match result.code:
+ case "header":
+ if isinstance(result.audio, tuple):
+ yield result.audio[1]
+
+ case "error":
+ raise HTTPException(
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ content=str(result.error),
+ )
+
+ case "segment":
+ count += 1
+ if isinstance(result.audio, tuple):
+ yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
+
+ case "final":
+ count += 1
+ if isinstance(result.audio, tuple):
+ yield result.audio[1]
+ return None # Stop the generator
+
+ if count == 0:
+ raise HTTPException(
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ content="No audio generated, please check the input text.",
+ )
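
Aside (not part of the patch): `AMPLITUDE` is simply the int16 full-scale factor, so each streamed "segment" is float audio in [-1, 1] rescaled to signed 16-bit PCM. A minimal sketch of that conversion and its round trip:

```python
# Float -> 16-bit PCM conversion as used for "segment" chunks above.
import numpy as np

AMPLITUDE = 32768  # 2**15, full scale of signed 16-bit PCM

audio = np.array([0.0, 0.5, -0.5, 0.999], dtype=np.float32)  # float audio in [-1, 1]
pcm_bytes = (audio * AMPLITUDE).astype(np.int16).tobytes()   # what the wrapper yields

# Round trip back to float for comparison
restored = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / AMPLITUDE
print(restored)  # approximately [0.0, 0.5, -0.5, 0.999]
```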
diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..77540a21e5559d38da0fcc92f68445a118f210af
--- /dev/null
+++ b/tools/server/model_manager.py
@@ -0,0 +1,122 @@
+import torch
+from funasr import AutoModel
+from loguru import logger
+
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.models.dac.inference import load_model as load_decoder_model
+from fish_speech.models.text2semantic.inference import (
+ launch_thread_safe_queue,
+ launch_thread_safe_queue_agent,
+)
+from fish_speech.utils.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+
+ASR_MODEL_NAME = "iic/SenseVoiceSmall"
+
+
+class ModelManager:
+ def __init__(
+ self,
+ mode: str,
+ device: str,
+ half: bool,
+ compile: bool,
+ asr_enabled: bool,
+ llama_checkpoint_path: str,
+ decoder_checkpoint_path: str,
+ decoder_config_name: str,
+ ) -> None:
+
+ self.mode = mode
+ self.device = device
+ self.half = half
+ self.compile = compile
+
+ self.precision = torch.half if half else torch.bfloat16
+
+ # Check if MPS or CUDA is available
+ if torch.backends.mps.is_available():
+ self.device = "mps"
+ logger.info("mps is available, running on mps.")
+ elif not torch.cuda.is_available():
+ self.device = "cpu"
+ logger.info("CUDA is not available, running on CPU.")
+
+ # Load the ASR model if enabled
+ if asr_enabled:
+ self.load_asr_model(self.device)
+
+ # Load the TTS models
+ self.load_llama_model(
+ llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
+ )
+ self.load_decoder_model(
+ decoder_config_name, decoder_checkpoint_path, self.device
+ )
+ self.tts_inference_engine = TTSInferenceEngine(
+ llama_queue=self.llama_queue,
+ decoder_model=self.decoder_model,
+ precision=self.precision,
+ compile=self.compile,
+ )
+
+ # Warm up the models
+ if self.mode == "tts":
+ self.warm_up(self.tts_inference_engine)
+
+ def load_asr_model(self, device, hub="ms") -> None:
+ self.asr_model = AutoModel(
+ model=ASR_MODEL_NAME,
+ device=device,
+ disable_pbar=True,
+ hub=hub,
+ )
+ logger.info("ASR model loaded.")
+
+ def load_llama_model(
+ self, checkpoint_path, device, precision, compile, mode
+ ) -> None:
+
+ if mode == "tts":
+ self.llama_queue = launch_thread_safe_queue(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=compile,
+ )
+ elif mode == "agent":
+ self.llama_queue, self.tokenizer, self.config = (
+ launch_thread_safe_queue_agent(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=compile,
+ )
+ )
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+
+ logger.info("LLAMA model loaded.")
+
+ def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
+ self.decoder_model = load_decoder_model(
+ config_name=config_name,
+ checkpoint_path=checkpoint_path,
+ device=device,
+ )
+ logger.info("Decoder model loaded.")
+
+ def warm_up(self, tts_inference_engine) -> None:
+ request = ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ format="wav",
+ )
+ list(inference(request, tts_inference_engine))
+ logger.info("Models warmed up.")
diff --git a/tools/server/model_utils.py b/tools/server/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..822e2f984b2fd6b6ec8d8ed4d40d5f73a05c1039
--- /dev/null
+++ b/tools/server/model_utils.py
@@ -0,0 +1,129 @@
+import io
+import re
+
+import librosa
+import torch
+import torchaudio
+from cachetools import LRUCache, cached
+
+CACHE_MAXSIZE = 10000
+MICRO_BATCH_SIZE = 8
+ASR_SAMPLE_RATE = 16000
+HUGE_GAP_THRESHOLD = 4000
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios_list: list[bytes]):
+ audios: list[torch.Tensor] = [
+ (
+ torch.from_numpy(
+ librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+ )[None]
+ if isinstance(audio, bytes)
+ else audio
+ )
+ for audio in audios_list
+ ]
+
+ lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+ max_length = lengths.max().item()
+
+ print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+ padded = torch.stack(
+ [
+ torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
+ for audio in audios
+ ]
+ ).to(model.device)
+
+ features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+ features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+ return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+ cache=LRUCache(maxsize=CACHE_MAXSIZE),
+ key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+ return batch_encode(model, audios)
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_vqgan_decode(model, features):
+ lengths = torch.tensor(
+ [feature.shape[-1] for feature in features], device=model.device
+ )
+ max_length = lengths.max().item()
+ padded = torch.stack(
+ [
+ torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+ for feature in features
+ ]
+ ).to(model.device)
+
+    # If the batch is too large, decode in micro-batches of MICRO_BATCH_SIZE
+ audios, audio_lengths = [], []
+ for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
+ audio, audio_length = model.decode(
+ padded[i : i + MICRO_BATCH_SIZE],
+ feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
+ )
+ audios.append(audio)
+ audio_lengths.append(audio_length)
+ audios = torch.cat(audios, dim=0)
+ audio_lengths = torch.cat(audio_lengths, dim=0)
+ audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+ return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@torch.no_grad()
+def batch_asr(model, lock, audios, sr, language="auto"):
+ resampled_audios = []
+ for audio in audios:
+ audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
+ assert audio.ndim == 1
+ resampled_audios.append(audio)
+
+ with lock:
+ res = model.generate(
+ input=resampled_audios,
+ batch_size=len(resampled_audios),
+ language=language,
+ use_itn=True,
+ )
+
+ results = []
+ for r, audio in zip(res, audios):
+ text = r["text"]
+ text = re.sub(r"<\|.*?\|>", "", text)
+ duration = len(audio) / sr * 1000
+ huge_gap = False
+
+ if "timestamp" in r and len(r["timestamp"]) > 2:
+ for timestamp_a, timestamp_b in zip(
+ r["timestamp"][:-1], r["timestamp"][1:]
+ ):
+            # A gap of more than 4 seconds (4000 ms) between consecutive timestamps counts as a huge gap
+ if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
+ huge_gap = True
+ break
+
+            # A long silent stretch at the end of the audio also counts as a huge gap
+ if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
+ huge_gap = True
+
+ results.append(
+ {
+ "text": text,
+ "duration": duration,
+ "huge_gap": huge_gap,
+ }
+ )
+
+ return results
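
Aside (not part of the patch): `cached_vqgan_batch_encode` works because `bytes` payloads are hashable, so `(model.device, tuple(audios))` is a valid LRU key and repeated reference audio skips a second VQ-GAN encode. A toy sketch of the same `cachetools` pattern, with a string standing in for the model:

```python
# Toy illustration of the caching pattern used above (string instead of a real model).
from cachetools import LRUCache, cached

calls = 0


@cached(cache=LRUCache(maxsize=4), key=lambda model, audios: (model, tuple(audios)))
def fake_encode(model, audios: list[bytes]):
    global calls
    calls += 1
    return [len(a) for a in audios]


fake_encode("cuda:0", [b"ref-audio-bytes"])
fake_encode("cuda:0", [b"ref-audio-bytes"])  # cache hit, encoder not called again
print(calls)  # 1
```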
diff --git a/tools/server/views.py b/tools/server/views.py
new file mode 100644
index 0000000000000000000000000000000000000000..905f6adbe0d96a12fac0170fdda5362bdb688eb8
--- /dev/null
+++ b/tools/server/views.py
@@ -0,0 +1,213 @@
+import io
+import os
+import time
+from http import HTTPStatus
+
+import numpy as np
+import ormsgpack
+import soundfile as sf
+import torch
+from kui.asgi import (
+ Body,
+ HTTPException,
+ HttpView,
+ JSONResponse,
+ Routes,
+ StreamResponse,
+ request,
+)
+from loguru import logger
+from typing_extensions import Annotated
+
+from fish_speech.utils.schema import (
+ ServeASRRequest,
+ ServeASRResponse,
+ ServeChatRequest,
+ ServeTTSRequest,
+ ServeVQGANDecodeRequest,
+ ServeVQGANDecodeResponse,
+ ServeVQGANEncodeRequest,
+ ServeVQGANEncodeResponse,
+)
+from tools.server.agent import get_response_generator
+from tools.server.api_utils import (
+ buffer_to_async_generator,
+ get_content_type,
+ inference_async,
+)
+from tools.server.inference import inference_wrapper as inference
+from tools.server.model_manager import ModelManager
+from tools.server.model_utils import (
+ batch_asr,
+ batch_vqgan_decode,
+ cached_vqgan_batch_encode,
+)
+
+MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
+
+routes = Routes()
+
+
+@routes.http("/v1/health")
+class Health(HttpView):
+ @classmethod
+ async def get(cls):
+ return JSONResponse({"status": "ok"})
+
+ @classmethod
+ async def post(cls):
+ return JSONResponse({"status": "ok"})
+
+
+@routes.http.post("/v1/vqgan/encode")
+async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
+ # Get the model from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ decoder_model = model_manager.decoder_model
+
+ # Encode the audio
+ start_time = time.time()
+ tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
+ logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+
+ # Return the response
+ return ormsgpack.packb(
+ ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+@routes.http.post("/v1/vqgan/decode")
+async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
+ # Get the model from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ decoder_model = model_manager.decoder_model
+
+ # Decode the audio
+ tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
+ start_time = time.time()
+ audios = batch_vqgan_decode(decoder_model, tokens)
+ logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
+ audios = [audio.astype(np.float16).tobytes() for audio in audios]
+
+ # Return the response
+ return ormsgpack.packb(
+ ServeVQGANDecodeResponse(audios=audios),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+@routes.http.post("/v1/asr")
+async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
+ # Get the model from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ asr_model = model_manager.asr_model
+ lock = request.app.state.lock
+
+ # Perform ASR
+ start_time = time.time()
+ audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
+ audios = [torch.from_numpy(audio).float() for audio in audios]
+
+    if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
+ raise HTTPException(status_code=400, content="Audio length is too long")
+
+ transcriptions = batch_asr(
+ asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
+ )
+ logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+ # Return the response
+ return ormsgpack.packb(
+ ServeASRResponse(transcriptions=transcriptions),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+@routes.http.post("/v1/tts")
+async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
+ # Get the model from the app
+ app_state = request.app.state
+ model_manager: ModelManager = app_state.model_manager
+ engine = model_manager.tts_inference_engine
+ sample_rate = engine.decoder_model.spec_transform.sample_rate
+
+ # Check if the text is too long
+ if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content=f"Text is too long, max length is {app_state.max_text_length}",
+ )
+
+ # Check if streaming is enabled
+ if req.streaming and req.format != "wav":
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content="Streaming only supports WAV format",
+ )
+
+ # Perform TTS
+ if req.streaming:
+ return StreamResponse(
+ iterable=inference_async(req, engine),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+ else:
+ fake_audios = next(inference(req, engine))
+ buffer = io.BytesIO()
+ sf.write(
+ buffer,
+ fake_audios,
+ sample_rate,
+ format=req.format,
+ )
+
+ return StreamResponse(
+ iterable=buffer_to_async_generator(buffer.getvalue()),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+
+
+@routes.http.post("/v1/chat")
+async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
+ # Check that the number of samples requested is correct
+ if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
+ )
+
+ # Get the type of content provided
+ content_type = request.headers.get("Content-Type", "application/json")
+ json_mode = "application/json" in content_type
+
+ # Get the models from the app
+ model_manager: ModelManager = request.app.state.model_manager
+ llama_queue = model_manager.llama_queue
+ tokenizer = model_manager.tokenizer
+ config = model_manager.config
+
+ device = request.app.state.device
+
+ # Get the response generators
+ response_generator = get_response_generator(
+ llama_queue, tokenizer, config, req, device, json_mode
+ )
+
+ # Return the response in the correct format
+ if req.streaming is False:
+ result = response_generator()
+ if json_mode:
+ return JSONResponse(result.model_dump())
+ else:
+ return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+ return StreamResponse(
+ iterable=response_generator(), content_type="text/event-stream"
+ )
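
Aside (not part of the patch): a hypothetical client round trip through the two VQ-GAN endpoints above, assuming the server is on `127.0.0.1:8080` and `reference.wav` is any local audio file; note that decoded audio comes back as raw float16 samples.

```python
# Hypothetical VQ-GAN encode/decode round trip against the routes defined above.
import numpy as np
import ormsgpack
import requests

BASE = "http://127.0.0.1:8080"
HEADERS = {"Content-Type": "application/msgpack"}

with open("reference.wav", "rb") as f:  # any local audio file
    audio_bytes = f.read()

# Encode raw audio bytes into VQ token sequences
r = requests.post(
    f"{BASE}/v1/vqgan/encode",
    data=ormsgpack.packb({"audios": [audio_bytes]}),
    headers=HEADERS,
)
tokens = ormsgpack.unpackb(r.content)["tokens"]

# Decode the tokens back into audio (returned as raw float16 samples)
r = requests.post(
    f"{BASE}/v1/vqgan/decode",
    data=ormsgpack.packb({"tokens": tokens}),
    headers=HEADERS,
)
audio = np.frombuffer(ormsgpack.unpackb(r.content)["audios"][0], dtype=np.float16)
print(audio.shape)
```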
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
index 6ce8c4d8dd0fd63e8039822adb4424a38d8e80fe..2d52323affb189659e6418b2f6e69eb1c4268864 100644
--- a/tools/smart_pad.py
+++ b/tools/smart_pad.py
@@ -1,60 +1,60 @@
-import random
-from multiprocessing import Pool
-from pathlib import Path
-
-import click
-import librosa
-import torch.nn.functional as F
-import torchaudio
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files
-
-threshold = 10 ** (-50 / 20.0)
-
-
-def process(file):
- waveform, sample_rate = torchaudio.load(str(file), backend="sox")
- if waveform.size(0) > 1:
- waveform = waveform.mean(dim=0, keepdim=True)
-
- loudness = librosa.feature.rms(
- y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
- )[0]
-
- for i in range(len(loudness) - 1, 0, -1):
- if loudness[i] > threshold:
- break
-
- end_silent_time = (len(loudness) - i) * 512 / sample_rate
-
- if end_silent_time <= 0.3:
- random_time = random.uniform(0.3, 0.7) - end_silent_time
- waveform = F.pad(
- waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
- )
-
- for i in range(len(loudness)):
- if loudness[i] > threshold:
- break
-
- start_silent_time = i * 512 / sample_rate
-
- if start_silent_time > 0.02:
- waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
-
- torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
-
-
-@click.command()
-@click.argument("source", type=Path)
-@click.option("--num-workers", type=int, default=12)
-def main(source, num_workers):
- files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
-
- with Pool(num_workers) as p:
- list(tqdm(p.imap_unordered(process, files), total=len(files)))
-
-
-if __name__ == "__main__":
- main()
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+threshold = 10 ** (-50 / 20.0)
+
+
+def process(file):
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
+ if waveform.size(0) > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+
+ loudness = librosa.feature.rms(
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
+ )[0]
+
+ for i in range(len(loudness) - 1, 0, -1):
+ if loudness[i] > threshold:
+ break
+
+ end_silent_time = (len(loudness) - i) * 512 / sample_rate
+
+ if end_silent_time <= 0.3:
+ random_time = random.uniform(0.3, 0.7) - end_silent_time
+ waveform = F.pad(
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
+ )
+
+ for i in range(len(loudness)):
+ if loudness[i] > threshold:
+ break
+
+ start_silent_time = i * 512 / sample_rate
+
+ if start_silent_time > 0.02:
+ waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
+
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
+@click.command()
+@click.argument("source", type=Path)
+@click.option("--num-workers", type=int, default=12)
+def main(source, num_workers):
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
+
+ with Pool(num_workers) as p:
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
+
+
+if __name__ == "__main__":
+ main()
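
Aside (not part of the patch): the `threshold` constant in `smart_pad.py` is just -50 dBFS converted to linear RMS amplitude via `10 ** (dB / 20)`; a one-line check:

```python
# -50 dBFS expressed as a linear amplitude, and converted back for verification.
import math

threshold = 10 ** (-50 / 20.0)
print(threshold)                     # ~0.00316
print(20 * math.log10(threshold))    # -50.0
```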
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
index bdb914c2a3b13d13a79882be3e7b33027ce5109a..977afdf3260994ef31d2189a5973a2628b26c0c5 100644
--- a/tools/vqgan/create_train_split.py
+++ b/tools/vqgan/create_train_split.py
@@ -1,83 +1,83 @@
-import math
-from pathlib import Path
-from random import Random
-
-import click
-from loguru import logger
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-
-@click.command()
-@click.argument("root", type=click.Path(exists=True, path_type=Path))
-@click.option("--val-ratio", type=float, default=None)
-@click.option("--val-count", type=int, default=None)
-@click.option("--filelist", default=None, type=Path)
-@click.option("--min-duration", default=None, type=float)
-@click.option("--max-duration", default=None, type=float)
-def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
- if filelist:
- files = [i[0] for i in load_filelist(filelist)]
- else:
- files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
-
- if min_duration is None and max_duration is None:
- filtered_files = list(map(str, [file.relative_to(root) for file in files]))
- else:
- filtered_files = []
- for file in tqdm(files):
- try:
- audio = AudioSegment.from_file(str(file))
- duration = len(audio) / 1000.0
-
- if min_duration is not None and duration < min_duration:
- logger.info(
- f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
- )
- continue
-
- if max_duration is not None and duration > max_duration:
- logger.info(
- f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
- )
- continue
-
- filtered_files.append(str(file.relative_to(root)))
- except Exception as e:
- logger.info(f"Error processing {file}: {e}")
-
- logger.info(
- f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
- )
-
- Random(42).shuffle(filtered_files)
-
- if val_count is None and val_ratio is None:
- logger.info("Validation ratio and count not specified, using min(20%, 100)")
- val_size = min(100, math.ceil(len(filtered_files) * 0.2))
- elif val_count is not None and val_ratio is not None:
- logger.error("Cannot specify both val_count and val_ratio")
- return
- elif val_count is not None:
- if val_count < 1 or val_count > len(filtered_files):
- logger.error("val_count must be between 1 and number of files")
- return
- val_size = val_count
- else:
- val_size = math.ceil(len(filtered_files) * val_ratio)
-
- logger.info(f"Using {val_size} files for validation")
-
- with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
- f.write("\n".join(filtered_files[val_size:]))
-
- with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
- f.write("\n".join(filtered_files[:val_size]))
-
- logger.info("Done")
-
-
-if __name__ == "__main__":
- main()
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+@click.option("--val-ratio", type=float, default=None)
+@click.option("--val-count", type=int, default=None)
+@click.option("--filelist", default=None, type=Path)
+@click.option("--min-duration", default=None, type=float)
+@click.option("--max-duration", default=None, type=float)
+def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+ if min_duration is None and max_duration is None:
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
+ else:
+ filtered_files = []
+ for file in tqdm(files):
+ try:
+ audio = AudioSegment.from_file(str(file))
+ duration = len(audio) / 1000.0
+
+ if min_duration is not None and duration < min_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
+ )
+ continue
+
+ if max_duration is not None and duration > max_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
+ )
+ continue
+
+ filtered_files.append(str(file.relative_to(root)))
+ except Exception as e:
+ logger.info(f"Error processing {file}: {e}")
+
+ logger.info(
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
+ )
+
+ Random(42).shuffle(filtered_files)
+
+ if val_count is None and val_ratio is None:
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
+ elif val_count is not None and val_ratio is not None:
+ logger.error("Cannot specify both val_count and val_ratio")
+ return
+ elif val_count is not None:
+ if val_count < 1 or val_count > len(filtered_files):
+ logger.error("val_count must be between 1 and number of files")
+ return
+ val_size = val_count
+ else:
+ val_size = math.ceil(len(filtered_files) * val_ratio)
+
+ logger.info(f"Using {val_size} files for validation")
+
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[val_size:]))
+
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[:val_size]))
+
+ logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
index bccc721a7f5d5cb68596df62d7b4c629814538c5..e621eaddac7215cb68767d138f748bc40d29d67d 100644
--- a/tools/vqgan/extract_vq.py
+++ b/tools/vqgan/extract_vq.py
@@ -1,233 +1,232 @@
-import os
-import subprocess as sp
-import sys
-import time
-from datetime import timedelta
-from functools import lru_cache
-from pathlib import Path
-from random import Random
-
-import click
-import numpy as np
-import torch
-import torchaudio
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from lightning import LightningModule
-from loguru import logger
-from omegaconf import OmegaConf
-
-from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-# This file is used to convert the audio files to text files using the Whisper model.
-# It's mainly used to generate the training data for the VQ model.
-
-backends = torchaudio.list_audio_backends()
-
-if "ffmpeg" in backends:
- backend = "ffmpeg"
-else:
- backend = "soundfile"
-
-RANK = int(os.environ.get("SLURM_PROCID", 0))
-WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
-
-logger_format = (
- "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
- "{level: <8} | "
- "{name}:{function}:{line} | "
- "{extra[rank]} - {message}"
-)
-logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
-logger.remove()
-logger.add(sys.stderr, format=logger_format)
-
-
-@lru_cache(maxsize=1)
-def get_model(
- config_name: str = "firefly_gan_vq",
- checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- device: str | torch.device = "cuda",
-):
- with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
- cfg = compose(config_name=config_name)
-
- model = instantiate(cfg)
- state_dict = torch.load(
- checkpoint_path,
- map_location=device,
- )
- if "state_dict" in state_dict:
- state_dict = state_dict["state_dict"]
-
- if any("generator" in k for k in state_dict):
- state_dict = {
- k.replace("generator.", ""): v
- for k, v in state_dict.items()
- if "generator." in k
- }
-
- model.load_state_dict(state_dict, strict=False)
- model.eval()
- model.to(device)
-
- logger.info(f"Loaded model")
- return model
-
-
-@torch.inference_mode()
-def process_batch(files: list[Path], model) -> float:
- wavs = []
- audio_lengths = []
- new_files = []
- max_length = total_time = 0
-
- for file in files:
- try:
- wav, sr = torchaudio.load(
- str(file), backend=backend
- ) # Need to install libsox-dev
- except Exception as e:
- logger.error(f"Error reading {file}: {e}")
- continue
-
- if wav.shape[0] > 1:
- wav = wav.mean(dim=0, keepdim=True)
-
- wav = torchaudio.functional.resample(
- wav.cuda(), sr, model.spec_transform.sample_rate
- )[0]
- total_time += len(wav) / model.spec_transform.sample_rate
- max_length = max(max_length, len(wav))
-
- wavs.append(wav)
- audio_lengths.append(len(wav))
- new_files.append(file)
-
- files = new_files
-
- # Pad to max length
- for i, wav in enumerate(wavs):
- wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
-
- audios = torch.stack(wavs, dim=0)[:, None]
- audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
-
- # Calculate lengths
- indices, feature_lengths = model.encode(audios, audio_lengths)
-
- # Save to disk
- outputs = indices.cpu().numpy()
-
- for file, length, feature, audio_length in zip(
- files, feature_lengths, outputs, audio_lengths
- ):
- feature = feature[:, :length]
-
- # (T,)
- with open(file.with_suffix(".npy"), "wb") as f:
- np.save(f, feature)
-
- return total_time
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--num-workers", default=1)
-@click.option("--config-name", default="firefly_gan_vq")
-@click.option(
- "--checkpoint-path",
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-)
-@click.option("--batch-size", default=64)
-@click.option("--filelist", default=None, type=Path)
-def main(
- folder: str,
- num_workers: int,
- config_name: str,
- checkpoint_path: str,
- batch_size: int,
- filelist: Path,
-):
- if num_workers > 1 and WORLD_SIZE != num_workers:
- assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
-
- logger.info(f"Spawning {num_workers} workers")
-
- if torch.cuda.is_available():
- visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
- if visible_devices is None:
- visible_devices = list(range(torch.cuda.device_count()))
- else:
- visible_devices = visible_devices.split(",")
- else:
- # Set to empty string to avoid using GPU
- visible_devices = [""]
-
- processes = []
- for i in range(num_workers):
- env = os.environ.copy()
- env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
- env["SLURM_PROCID"] = str(i)
- env["SLURM_NTASKS"] = str(num_workers)
-
- processes.append(
- sp.Popen(
- [sys.executable] + sys.argv.copy(),
- env=env,
- )
- )
-
- for p in processes:
- p.wait()
-
- logger.info(f"All workers finished")
- return
-
- # This is a worker
- logger.info(f"Starting worker")
- if filelist:
- files = [i[0] for i in load_filelist(filelist)]
- else:
- files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
-
- print(f"Found {len(files)} files")
- files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
-
- total_files = len(files)
- files = files[RANK::WORLD_SIZE]
- logger.info(f"Processing {len(files)}/{total_files} files")
-
- # Batch processing
- total_time = 0
- begin_time = time.time()
- processed_files = 0
- model = get_model(config_name, checkpoint_path)
-
- for n_batch, idx in enumerate(range(0, len(files), batch_size)):
- batch = files[idx : idx + batch_size]
- batch_time = process_batch(batch, model)
-
- total_time += batch_time
- processed_files += len(batch)
-
- if (n_batch + 1) % 10 == 0:
- eta = (
- (time.time() - begin_time)
- / processed_files
- * (len(files) - processed_files)
- )
- logger.info(
- f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
- + f"ETA: {timedelta(seconds=round(eta))}s"
- )
-
- logger.info(
- f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
- )
-
-
-if __name__ == "__main__":
- main()
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file extracts VQ codec tokens from audio files and saves them as .npy files
+# next to the audio. It's mainly used to generate the training data for fine-tuning
+# the text2semantic model.
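+#
+# Example invocation (illustrative dataset path; flags match the click options below):
+#   python tools/vqgan/extract_vq.py data/demo \
+#       --num-workers 2 --batch-size 16 \
+#       --config-name modded_dac_vq \
+#       --checkpoint-path checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth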
+
+backends = torchaudio.list_audio_backends()
+
+if "ffmpeg" in backends:
+ backend = "ffmpeg"
+else:
+ backend = "soundfile"
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+ "{level: <8} | "
+ "{name}:{function}:{line} | "
+ "{extra[rank]} - {message}"
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model(
+ config_name: str = "modded_dac_vq",
+ checkpoint_path: str = "checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ device: str | torch.device = "cuda",
+):
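+    # Cached (maxsize=1) so repeated calls within a worker reuse one loaded codec model.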
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+    logger.info("Loaded model")
+ return model
+
+
+@torch.inference_mode()
+def process_batch(files: list[Path], model) -> float:
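+    """
+    Load a batch of audio files, resample them to the codec sample rate, pad to the
+    longest clip, encode the batch to VQ indices, and save one .npy per input file.
+    Returns the total duration (in seconds) of audio processed in this batch.
+    """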
+ wavs = []
+ audio_lengths = []
+ new_files = []
+ max_length = total_time = 0
+
+ for file in files:
+ try:
+ wav, sr = torchaudio.load(
+ str(file), backend=backend
+ ) # Need to install libsox-dev
+ except Exception as e:
+ logger.error(f"Error reading {file}: {e}")
+ continue
+
+ if wav.shape[0] > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ wav = torchaudio.functional.resample(
+ wav.cuda(), sr, model.spec_transform.sample_rate
+ )[0]
+ total_time += len(wav) / model.spec_transform.sample_rate
+ max_length = max(max_length, len(wav))
+
+ wavs.append(wav)
+ audio_lengths.append(len(wav))
+ new_files.append(file)
+
+ files = new_files
+
+ # Pad to max length
+ for i, wav in enumerate(wavs):
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+ audios = torch.stack(wavs, dim=0)[:, None]
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+ # Calculate lengths
+ indices, feature_lengths = model.encode(audios, audio_lengths)
+
+ # Save to disk
+ outputs = indices.cpu().numpy()
+
+ for file, length, feature, audio_length in zip(
+ files, feature_lengths, outputs, audio_lengths
+ ):
+ feature = feature[:, :length]
+
+ # (T,)
+ with open(file.with_suffix(".npy"), "wb") as f:
+ np.save(f, feature)
+
+ return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+@click.option("--config-name", default="modded_dac_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
+def main(
+ folder: str,
+ num_workers: int,
+ config_name: str,
+ checkpoint_path: str,
+ batch_size: int,
+ filelist: Path,
+):
+ if num_workers > 1 and WORLD_SIZE != num_workers:
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+ logger.info(f"Spawning {num_workers} workers")
+
+ if torch.cuda.is_available():
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is None:
+ visible_devices = list(range(torch.cuda.device_count()))
+ else:
+ visible_devices = visible_devices.split(",")
+ else:
+ # Set to empty string to avoid using GPU
+ visible_devices = [""]
+
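+        # Re-launch this script once per worker with emulated SLURM env vars so each
+        # child gets a unique RANK and processes its files[RANK::WORLD_SIZE] slice,
+        # round-robining workers across the visible CUDA devices.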
+ processes = []
+ for i in range(num_workers):
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+ env["SLURM_PROCID"] = str(i)
+ env["SLURM_NTASKS"] = str(num_workers)
+
+ processes.append(
+ sp.Popen(
+ [sys.executable] + sys.argv.copy(),
+ env=env,
+ )
+ )
+
+ for p in processes:
+ p.wait()
+
+        logger.info("All workers finished")
+        return
+
+    # This is a worker
+    logger.info("Starting worker")
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+    logger.info(f"Found {len(files)} files")
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+ total_files = len(files)
+ files = files[RANK::WORLD_SIZE]
+ logger.info(f"Processing {len(files)}/{total_files} files")
+
+ # Batch processing
+ total_time = 0
+ begin_time = time.time()
+ processed_files = 0
+ model = get_model(config_name, checkpoint_path)
+
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+ batch = files[idx : idx + batch_size]
+ batch_time = process_batch(batch, model)
+
+ total_time += batch_time
+ processed_files += len(batch)
+
+ if (n_batch + 1) % 10 == 0:
+ eta = (
+ (time.time() - begin_time)
+ / processed_files
+ * (len(files) - processed_files)
+ )
+ logger.info(
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+                + f"ETA: {timedelta(seconds=round(eta))}"
+ )
+
+ logger.info(
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/webui/__init__.py b/tools/webui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9b9a0259aaf77778709d81994bf9885e1ec4b2f
--- /dev/null
+++ b/tools/webui/__init__.py
@@ -0,0 +1,155 @@
+from typing import Callable
+
+import gradio as gr
+
+from fish_speech.i18n import i18n
+from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
+
+
+def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
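+    """
+    Build the Gradio Blocks UI.
+
+    `inference_fct` must accept the eleven inputs wired to the Generate button
+    below (text, reference settings, sampling parameters, seed, memory-cache flag)
+    and return an (audio, error_html) pair; it is typically produced by
+    tools.webui.inference.get_inference_wrapper.
+    """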
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
+ % theme,
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+ )
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab(label=i18n("Advanced Config")):
+ with gr.Row():
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=100,
+ maximum=400,
+ value=300,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n(
+ "Maximum tokens per batch, 0 means no limit"
+ ),
+ minimum=0,
+ maximum=2048,
+ value=0,
+ step=8,
+ )
+
+ with gr.Row():
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.7,
+ maximum=0.95,
+ value=0.8,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.2,
+ value=1.1,
+ step=0.01,
+ )
+
+ with gr.Row():
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.7,
+ maximum=1.0,
+ value=0.8,
+ step=0.01,
+ )
+ seed = gr.Number(
+ label="Seed",
+ info="0 means randomized inference, otherwise deterministic",
+ value=0,
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ with gr.Row():
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+ with gr.Row():
+ reference_id = gr.Textbox(
+ label=i18n("Reference ID"),
+ placeholder="Leave empty to use uploaded references",
+ )
+
+ with gr.Row():
+ use_memory_cache = gr.Radio(
+ label=i18n("Use Memory Cache"),
+ choices=["on", "off"],
+ value="on",
+ )
+
+ with gr.Row():
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ with gr.Row():
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ lines=1,
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ value="",
+ )
+
+ with gr.Column(scale=3):
+ with gr.Row():
+ error = gr.HTML(
+ label=i18n("Error Message"),
+ visible=True,
+ )
+ with gr.Row():
+ audio = gr.Audio(
+ label=i18n("Generated Audio"),
+ type="numpy",
+ interactive=False,
+ visible=True,
+ )
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001f3a7 " + i18n("Generate"),
+ variant="primary",
+ )
+
+ # Submit
+ generate.click(
+ inference_fct,
+ [
+ text,
+ reference_id,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ use_memory_cache,
+ ],
+ [audio, error],
+ concurrency_limit=1,
+ )
+
+ return app
diff --git a/tools/webui/inference.py b/tools/webui/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6cd1d75599ad38b757b57aaf33ee32b62936338
--- /dev/null
+++ b/tools/webui/inference.py
@@ -0,0 +1,89 @@
+import html
+from functools import partial
+from typing import Any, Callable
+
+from fish_speech.i18n import i18n
+from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
+
+
+def inference_wrapper(
+ text,
+ reference_id,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ seed,
+ use_memory_cache,
+ engine,
+):
+ """
+ Wrapper for the inference function.
+ Used in the Gradio interface.
+ """
+
+ if reference_audio:
+ references = get_reference_audio(reference_audio, reference_text)
+ else:
+ references = []
+
+ req = ServeTTSRequest(
+ text=text,
+ reference_id=reference_id if reference_id else None,
+ references=references,
+ max_new_tokens=max_new_tokens,
+ chunk_length=chunk_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ seed=int(seed) if seed else None,
+ use_memory_cache=use_memory_cache,
+ )
+
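+    # The engine streams results; only a "final" result carries the audio, while an
+    # "error" result short-circuits with an HTML-formatted message for the UI.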
+ for result in engine.inference(req):
+ match result.code:
+ case "final":
+ return result.audio, None
+ case "error":
+ return None, build_html_error_message(i18n(result.error))
+ case _:
+ pass
+
+    return None, build_html_error_message(i18n("No audio generated"))
+
+
+def get_reference_audio(reference_audio: str, reference_text: str) -> list:
+    """
+    Read the reference audio file and pair its bytes with the reference transcript.
+    """
+
+ with open(reference_audio, "rb") as audio_file:
+ audio_bytes = audio_file.read()
+
+ return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
+
+
+def build_html_error_message(error: Any) -> str:
+
+    # Normalise the error to an Exception, preserving the message when a plain
+    # string is passed in.
+    if not isinstance(error, Exception):
+        error = Exception(str(error)) if error else Exception("Unknown error")
+
+    return f"""
+    <div style="color: red; font-weight: bold;">
+        {html.escape(str(error))}
+    </div>
+    """
+
+
+def get_inference_wrapper(engine) -> Callable:
+    """
+    Bind the engine to inference_wrapper and return the resulting callable for the UI.
+    """
+
+ return partial(
+ inference_wrapper,
+ engine=engine,
+ )
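+
+
+# Usage sketch (assumes an already-initialised TTSInferenceEngine named `engine`):
+#   inference_fct = get_inference_wrapper(engine)
+#   app = build_app(inference_fct)  # build_app comes from tools.webui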
diff --git a/tools/webui/variables.py b/tools/webui/variables.py
new file mode 100644
index 0000000000000000000000000000000000000000..db42d5d797e821e9a34832bea4344ecb726c97db
--- /dev/null
+++ b/tools/webui/variables.py
@@ -0,0 +1,14 @@
+from fish_speech.i18n import i18n
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}
+
+{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
index 8ac720d2dad94fca5087852928f7e6ce57d1b92d..1c34efbe44928def767305d77882f1fcb0c7eccb 100644
--- a/tools/whisper_asr.py
+++ b/tools/whisper_asr.py
@@ -1,176 +1,176 @@
-"""
-Used to transcribe all audio files in one folder into another folder.
-e.g.
-Directory structure:
---pre_data_root
-----SP_1
-------01.wav
-------02.wav
-------......
-----SP_2
-------01.wav
-------02.wav
-------......
-Use
-python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
-to transcribe the first speaker.
-
-Use
-python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
-to transcribe the second speaker.
-
-Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
-"""
-
-import re
-from pathlib import Path
-
-import click
-import soundfile as sf
-from faster_whisper import WhisperModel
-from loguru import logger
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files
-
-
-@click.command()
-@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
-@click.option(
- "--compute-type",
- default="float16",
- help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
-)
-@click.option("--audio-dir", required=True, help="Directory containing audio files")
-@click.option(
- "--save-dir", required=True, help="Directory to save processed audio files"
-)
-@click.option(
- "--sample-rate",
- default=44100,
- type=int,
- help="Output sample rate, default to input sample rate",
-)
-@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
-@click.option("--language", default="auto", help="Language of the transcription")
-@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
-def main(
- model_size,
- compute_type,
- audio_dir,
- save_dir,
- sample_rate,
- device,
- language,
- initial_prompt,
-):
- logger.info("Loading / Downloading Faster Whisper model...")
-
- model = WhisperModel(
- model_size,
- device=device,
- compute_type=compute_type,
- download_root="faster_whisper",
- )
-
- logger.info("Model loaded.")
-
- save_path = Path(save_dir)
- save_path.mkdir(parents=True, exist_ok=True)
-
- audio_files = list_files(
- path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
- )
-
- for file_path in tqdm(audio_files, desc="Processing audio file"):
- file_stem = file_path.stem
- file_suffix = file_path.suffix
-
- rel_path = Path(file_path).relative_to(audio_dir)
- (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-
- audio = AudioSegment.from_file(file_path)
-
- segments, info = model.transcribe(
- file_path,
- beam_size=5,
- language=None if language == "auto" else language,
- initial_prompt=initial_prompt,
- )
-
- print(
- "Detected language '%s' with probability %f"
- % (info.language, info.language_probability)
- )
- print("Total len(ms): ", len(audio))
-
- whole_text = None
- for segment in segments:
- id, start, end, text = (
- segment.id,
- segment.start,
- segment.end,
- segment.text,
- )
- print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
- if not whole_text:
- whole_text = text
- else:
- whole_text += ", " + text
-
- whole_text += "."
-
- audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
- audio.export(audio_save_path, format=file_suffix[1:])
- print(f"Exported {audio_save_path}")
-
- transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
- with open(
- transcript_save_path,
- "w",
- encoding="utf-8",
- ) as f:
- f.write(whole_text)
-
-
-if __name__ == "__main__":
- main()
- exit(0)
-
- audio = AudioSegment.from_wav(
- r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
- )
-
- model_size = "large-v3"
-
- model = WhisperModel(
- model_size,
- device="cuda",
- compute_type="float16",
- download_root="faster_whisper",
- )
-
- segments, info = model.transcribe(
- r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
- beam_size=5,
- )
-
- print(
- "Detected language '%s' with probability %f"
- % (info.language, info.language_probability)
- )
- print("Total len(ms): ", len(audio))
-
- for i, segment in enumerate(segments):
- print(
- "Segment %03d [%.2fs -> %.2fs] %s"
- % (i, segment.start, segment.end, segment.text)
- )
- start_ms = int(segment.start * 1000)
- end_ms = int(segment.end * 1000)
- segment_audio = audio[start_ms:end_ms]
- segment_audio.export(f"segment_{i:03d}.wav", format="wav")
- print(f"Exported segment_{i:03d}.wav")
-
- print("All segments have been exported.")
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+from pathlib import Path
+
+import click
+from faster_whisper import WhisperModel
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
+@click.option(
+ "--compute-type",
+ default="float16",
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
+)
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option(
+ "--sample-rate",
+ default=44100,
+ type=int,
+    help="Output sample rate in Hz for the exported audio (default: 44100)",
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+ model_size,
+ compute_type,
+ audio_dir,
+ save_dir,
+ sample_rate,
+ device,
+ language,
+ initial_prompt,
+):
+ logger.info("Loading / Downloading Faster Whisper model...")
+
+ model = WhisperModel(
+ model_size,
+ device=device,
+ compute_type=compute_type,
+ download_root="faster_whisper",
+ )
+
+ logger.info("Model loaded.")
+
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+        audio = AudioSegment.from_file(file_path)
+
+        # Resample the exported copy to the requested --sample-rate.
+        if sample_rate and audio.frame_rate != sample_rate:
+            audio = audio.set_frame_rate(sample_rate)
+
+ segments, info = model.transcribe(
+ file_path,
+ beam_size=5,
+ language=None if language == "auto" else language,
+ initial_prompt=initial_prompt,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+        whole_text = ""
+        for segment in segments:
+            seg_id, start, end, text = (
+                segment.id,
+                segment.start,
+                segment.end,
+                segment.text,
+            )
+            print("Segment %03d [%.2fs -> %.2fs] %s" % (seg_id, start, end, text))
+            if not whole_text:
+                whole_text = text
+            else:
+                whole_text += ", " + text
+
+        whole_text += "."
+
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
+ audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}")
+
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(whole_text)
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+
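+    # NOTE: the block below is unreachable (main() exits above); it only sketches a
+    # manual per-segment export workflow with faster-whisper and pydub.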
+ audio = AudioSegment.from_wav(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
+ )
+
+ model_size = "large-v3"
+
+ model = WhisperModel(
+ model_size,
+ device="cuda",
+ compute_type="float16",
+ download_root="faster_whisper",
+ )
+
+ segments, info = model.transcribe(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
+ beam_size=5,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ for i, segment in enumerate(segments):
+ print(
+ "Segment %03d [%.2fs -> %.2fs] %s"
+ % (i, segment.start, segment.end, segment.text)
+ )
+ start_ms = int(segment.start * 1000)
+ end_ms = int(segment.end * 1000)
+ segment_audio = audio[start_ms:end_ms]
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
+ print(f"Exported segment_{i:03d}.wav")
+
+ print("All segments have been exported.")