diff --git a/.gitattributes b/.gitattributes index a56018dcef8a9d2bbdf0bae70eb19e573133741a..d0958581de2691a0397f245b105e854a8660a466 100644 --- a/.gitattributes +++ b/.gitattributes @@ -36,3 +36,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text examples/Japanese.wav filter=lfs diff=lfs merge=lfs -text examples/Korean.wav filter=lfs diff=lfs merge=lfs -text examples/Nice[[:space:]]English[[:space:]]Ref.wav filter=lfs diff=lfs merge=lfs -text +examples/Arabic.wav filter=lfs diff=lfs merge=lfs -text +examples/English.wav filter=lfs diff=lfs merge=lfs -text +examples/French.wav filter=lfs diff=lfs merge=lfs -text +examples/German.wav filter=lfs diff=lfs merge=lfs -text +examples/Spanish.wav filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index e87fb1fa90473c1ffce53ab39f9a33993ee861b6..91c4a63cb1a38b778237397a69aac079c5ac8ab6 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ --- -title: Fish Speech 1 +title: OpenAudio S1 emoji: 🏆 colorFrom: purple colorTo: gray diff --git a/app.py b/app.py index 7919d4335ed4a6daf739c752e117e539c3a987c8..85c9d20a70fe709751b53c43af3316f9b92fb850 100644 --- a/app.py +++ b/app.py @@ -1,86 +1,51 @@ import os import queue from huggingface_hub import snapshot_download -import hydra import numpy as np import wave import io -import pyrootutils import gc +from typing import Callable # Download if not exists os.makedirs("checkpoints", exist_ok=True) -snapshot_download(repo_id="fishaudio/fish-speech-1.5", local_dir="./checkpoints/fish-speech-1.5") +snapshot_download(repo_id="fishaudio/openaudio-s1-mini", local_dir="./checkpoints/openaudio-s1-mini") print("All checkpoints downloaded") import html import os -import threading from argparse import ArgumentParser from pathlib import Path -from functools import partial import gradio as gr -import librosa import torch import torchaudio torchaudio.set_audio_backend("soundfile") from loguru import logger -from transformers import AutoTokenizer - from fish_speech.i18n import i18n -from fish_speech.text.chn_text_norm.text import Text as ChnNormedText -from fish_speech.utils import autocast_exclude_mps, set_seed -from tools.api import decode_vq_tokens, encode_reference -from tools.file import AUDIO_EXTENSIONS, list_files -from tools.llama.generate import ( - GenerateRequest, - GenerateResponse, - WrappedGenerateResponse, - launch_thread_safe_queue, -) -from tools.vqgan.inference import load_model as load_decoder_model - -from tools.schema import ( - GLOBAL_NUM_SAMPLES, - ASRPackRequest, - ServeASRRequest, - ServeASRResponse, - ServeASRSegment, - ServeAudioPart, - ServeForwardMessage, - ServeMessage, - ServeRequest, - ServeResponse, - ServeStreamDelta, - ServeStreamResponse, - ServeTextPart, - ServeTimedASRResponse, - ServeTTSRequest, - ServeVQGANDecodeRequest, - ServeVQGANDecodeResponse, - ServeVQGANEncodeRequest, - ServeVQGANEncodeResponse, - ServeVQPart, - ServeReferenceAudio -) +from fish_speech.inference_engine import TTSInferenceEngine +from fish_speech.models.dac.inference import load_model as load_decoder_model +from fish_speech.models.text2semantic.inference import launch_thread_safe_queue +from tools.webui.inference import get_inference_wrapper +from fish_speech.utils.schema import ServeTTSRequest + # Make einx happy os.environ["EINX_FILTER_TRACEBACK"] = "false" -HEADER_MD = """# Fish Speech +HEADER_MD = """# OpenAudio S1 -## The demo in this space is version 1.5, Please check [Fish Audio](https://fish.audio) for the best model. 
-## 该 Demo 为 Fish Speech 1.5 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO. +## The demo in this space is OpenAudio S1, Please check [Fish Audio](https://fish.audio) for the best model. +## 该 Demo 为 OpenAudio S1 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO. -A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio). -由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成. +A text-to-speech model based on DAC and Qwen3 developed by [Fish Audio](https://fish.audio). +由 [Fish Audio](https://fish.audio) 研发的基于 DAC 和 Qwen3 的多语种语音合成. -You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5). -你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1.5) 找到模型. +You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/openaudio-s1-mini). +你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/openaudio-s1-mini) 找到模型. Related code and weights are released under CC BY-NC-SA 4.0 License. 相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布. @@ -88,8 +53,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License. We are not responsible for any misuse of the model, please consider your local laws and regulations before using it. 我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规. -The model running in this WebUI is Fish Speech V1.5 Medium. -在此 WebUI 中运行的模型是 Fish Speech V1.5 Medium. +The model running in this WebUI is OpenAudio S1 Mini. +在此 WebUI 中运行的模型是 OpenAudio S1 Mini. """ TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本.""" @@ -106,7 +71,6 @@ except ImportError: return wrapper - def build_html_error_message(error): return f"""
0, - chunk_length=req.chunk_length, - max_length=4096, - prompt_tokens=prompt_tokens, - prompt_text=prompt_texts, - ) - - response_queue = queue.Queue() - llama_queue.put( - GenerateRequest( - request=request, - response_queue=response_queue, - ) - ) - - segments = [] - - while True: - result: WrappedGenerateResponse = response_queue.get() - if result.status == "error": - yield None, None, build_html_error_message(result.response) - break - - result: GenerateResponse = result.response - if result.action == "next": - break - - with autocast_exclude_mps( - device_type=decoder_model.device.type, dtype=args.precision - ): - fake_audios = decode_vq_tokens( - decoder_model=decoder_model, - codes=result.codes, - ) - - fake_audios = fake_audios.float().cpu().numpy() - segments.append(fake_audios) - - if len(segments) == 0: - return ( - None, - None, - build_html_error_message( - i18n("No audio generated, please check the input text.") - ), - ) - - # No matter streaming or not, we need to return the final audio - audio = np.concatenate(segments, axis=0) - yield None, (decoder_model.spec_transform.sample_rate, audio), None - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - except Exception as e: - er = "CUDA error: device-side assert triggered" - if er in str(e): - app.close() - else: - raise Exception(e) - -n_audios = 4 - -global_audio_list = [] -global_error_list = [] - - def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): buffer = io.BytesIO() @@ -230,13 +91,8 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): buffer.close() return wav_header_bytes -def normalize_text(user_input, use_normalization): - if use_normalization: - return ChnNormedText(raw_text=user_input).normalize() - else: - return user_input -def build_app(): +def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks: with gr.Blocks(theme=gr.themes.Base()) as app: gr.Markdown(HEADER_MD) @@ -245,7 +101,7 @@ def build_app(): None, None, js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}" - % args.theme, + % theme, ) # Inference @@ -254,20 +110,6 @@ def build_app(): text = gr.Textbox( label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 ) - refined_text = gr.Textbox( - label=i18n("Realtime Transform Text"), - placeholder=i18n( - "Normalization Result Preview (Currently Only Chinese)" - ), - lines=5, - interactive=False, - ) - - with gr.Row(): - normalize = gr.Checkbox( - label=i18n("Text Normalization"), - value=False, - ) with gr.Row(): with gr.Column(): @@ -275,45 +117,45 @@ def build_app(): with gr.Row(): chunk_length = gr.Slider( label=i18n("Iterative Prompt Length, 0 means off"), - minimum=0, - maximum=300, - value=200, + minimum=100, + maximum=400, + value=300, step=8, ) max_new_tokens = gr.Slider( label=i18n( - "Maximum tokens per batch" + "Maximum tokens per batch, 0 means no limit" ), - minimum=512, + minimum=0, maximum=2048, - value=1024, - step=64, + value=0, + step=8, ) with gr.Row(): top_p = gr.Slider( label="Top-P", - minimum=0.6, - maximum=0.9, - value=0.7, + minimum=0.7, + maximum=0.95, + value=0.8, step=0.01, ) repetition_penalty = gr.Slider( label=i18n("Repetition Penalty"), minimum=1, - maximum=1.5, - value=1.2, + maximum=1.2, + value=1.1, step=0.01, ) with gr.Row(): temperature = gr.Slider( label="Temperature", - minimum=0.6, - maximum=0.9, - value=0.7, + minimum=0.7, + maximum=1.0, + value=0.8, 
step=0.01, ) seed = gr.Number( @@ -326,24 +168,20 @@ def build_app(): with gr.Row(): gr.Markdown( i18n( - "15 to 60 seconds of reference audio, useful for specifying speaker." + "5 to 10 seconds of reference audio, useful for specifying speaker." ) ) - with gr.Row(): - # Add dropdown for selecting example audio files - example_audio_files = [f for f in os.listdir("examples") if f.endswith(".wav")] - example_audio_dropdown = gr.Dropdown( - label="Select Example Audio", - choices=[""] + example_audio_files, - value="" + reference_id = gr.Textbox( + label=i18n("Reference ID"), + placeholder="Leave empty to use uploaded references", ) with gr.Row(): use_memory_cache = gr.Radio( label=i18n("Use Memory Cache"), - choices=["never"], - value="never", + choices=["on", "off"], + value="on", ) with gr.Row(): @@ -351,7 +189,6 @@ def build_app(): label=i18n("Reference Audio"), type="filepath", ) - with gr.Row(): reference_text = gr.Textbox( label=i18n("Reference Text"), @@ -377,101 +214,16 @@ def build_app(): with gr.Row(): with gr.Column(scale=3): generate = gr.Button( - value="\U0001F3A7 " + i18n("Generate"), variant="primary" + value="\U0001f3a7 " + i18n("Generate"), + variant="primary", ) - text.input( - fn=normalize_text, inputs=[text, normalize], outputs=[refined_text] - ) - - def inference_wrapper( - text, - normalize, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - seed, - use_memory_cache, - ): - print( - "call inference wrapper", - text, - normalize, - reference_audio, - reference_text, - max_new_tokens, - chunk_length, - top_p, - repetition_penalty, - temperature, - seed, - use_memory_cache - ) - - references = [] - if reference_audio: - # 将文件路径转换为字节 - with open(reference_audio, 'rb') as audio_file: - audio_bytes = audio_file.read() - - references = [ - ServeReferenceAudio(audio=audio_bytes, text=reference_text) - ] - - req = ServeTTSRequest( - text=text, - normalize=normalize, - reference_id=None, - references=references, - max_new_tokens=max_new_tokens, - chunk_length=chunk_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - temperature=temperature, - seed=int(seed) if seed else None, - use_memory_cache=use_memory_cache, - ) - - for result in inference(req): - if result[2]: # Error message - return None, result[2] - elif result[1]: # Audio data - return result[1], None - - return None, i18n("No audio generated") - - def select_example_audio(audio_file): - if audio_file: - audio_path = os.path.join("examples", audio_file) - lab_file = os.path.splitext(audio_file)[0] + ".lab" - lab_path = os.path.join("examples", lab_file) - - if os.path.exists(lab_path): - with open(lab_path, "r", encoding="utf-8") as f: - lab_content = f.read().strip() - else: - lab_content = "" - - return audio_path, lab_content - return None, "" - - # Connect the dropdown to update reference audio and text - example_audio_dropdown.change( - fn=select_example_audio, - inputs=[example_audio_dropdown], - outputs=[reference_audio, reference_text] - ) - # Submit generate.click( - inference_wrapper, + inference_fct, [ - refined_text, - normalize, + text, + reference_id, reference_audio, reference_text, max_new_tokens, @@ -488,26 +240,24 @@ def build_app(): return app - - def parse_args(): parser = ArgumentParser() parser.add_argument( "--llama-checkpoint-path", type=Path, - default="checkpoints/fish-speech-1.5", + default="checkpoints/openaudio-s1-mini", ) parser.add_argument( "--decoder-checkpoint-path", type=Path, - 
default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + default="checkpoints/openaudio-s1-mini/codec.pth", ) - parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--half", action="store_true") parser.add_argument("--compile", action="store_true",default=True) parser.add_argument("--max-gradio-length", type=int, default=0) - parser.add_argument("--theme", type=str, default="light") + parser.add_argument("--theme", type=str, default="dark") return parser.parse_args() @@ -533,25 +283,34 @@ if __name__ == "__main__": logger.info("Decoder model loaded, warming up...") + # Create the inference engine + inference_engine = TTSInferenceEngine( + llama_queue=llama_queue, + decoder_model=decoder_model, + compile=args.compile, + precision=args.precision, + ) + # Dry run to check if the model is loaded correctly and avoid the first-time latency list( - inference( - ServeTTSRequest( - text="Hello world.", - references=[], - reference_id=None, - max_new_tokens=0, - chunk_length=200, - top_p=0.7, - repetition_penalty=1.5, - temperature=0.7, - emotion=None, - format="wav", - ) + inference_engine.inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.5, + temperature=0.7, + format="wav", ) + ) ) logger.info("Warming up done, launching the web UI...") - app = build_app() - app.queue(api_open=True).launch(show_error=True, show_api=True) + inference_fct = get_inference_wrapper(inference_engine) + + app = build_app(inference_fct, args.theme) + app.queue(api_open=True).launch(show_error=True, show_api=True, server_name="0.0.0.0", server_port=18888) diff --git a/examples/Arabic.wav b/examples/Arabic.wav index 0e5cecd8112b805841714e01550b63b98a8b92be..2e2acdf221d68c1ff00d4c50e86959669dcd6b94 100644 Binary files a/examples/Arabic.wav and b/examples/Arabic.wav differ diff --git a/examples/English.wav b/examples/English.wav index c41d923cb56b33ac51f63ab1ceab1643976bf6d0..5061c0e2658b9238b883c054d6c3b08f90501e99 100644 Binary files a/examples/English.wav and b/examples/English.wav differ diff --git a/examples/French.wav b/examples/French.wav index a36ca2193b9b391701178e378423375897191e91..5292499755d86e57263c4ee00160864b5988197a 100644 Binary files a/examples/French.wav and b/examples/French.wav differ diff --git a/examples/German.wav b/examples/German.wav index 13ef8df7a782da5a6f863f57f72abaecaf7a9180..1c3a10598b0c6345dfdd7e3989aaa155f72ea114 100644 Binary files a/examples/German.wav and b/examples/German.wav differ diff --git a/examples/Japanese.wav b/examples/Japanese.wav index d601bc9d57b518e6b4646de0c0c9d063fb7df23c..8e41057b85229d69a196945d146109edd1eea2b2 100644 --- a/examples/Japanese.wav +++ b/examples/Japanese.wav @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a23cffeac70f42e1cc69e2a0505e4c1fda50884dd34c509128d432aaf44565e5 -size 1148682 +oid sha256:3034a38260884be854cb4a3f6cb648db85ebdeeb8cab74cfae2a578dc7aaedc2 +size 132 diff --git a/examples/Korean.wav b/examples/Korean.wav index 66c9661830a819c99d88eef921a351969150651e..78b1faf981694a3353b842c2d66a912c17eb504f 100644 --- a/examples/Korean.wav +++ b/examples/Korean.wav @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c4234f119c741782e2c9c0ede4b5b864a560a355c28a23b2332e79420b69961a 
-size 1632522 +oid sha256:5767663f0c26f4dc94f45227f385c2be568aac065272466915d65eaa64fdda0f +size 132 diff --git a/examples/Nice English Ref.wav b/examples/Nice English Ref.wav index 018def3da6421b1f1ec4a57870bb3a2d235efeb4..828534a0c3042d87cfeeb3a81a1adf367a8b45c0 100644 --- a/examples/Nice English Ref.wav +++ b/examples/Nice English Ref.wav @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d00ad9768c62f9821fc01ecab3e02669581ca75c18af6549690e19ce90a09f53 -size 5254482 +oid sha256:4b707de0cfc5d2eee59dcc3fea495603fe28d95ca64d8202bcdb31537d588782 +size 132 diff --git a/examples/Spanish.wav b/examples/Spanish.wav index 31f7fe5875c9dcc81db9c87d98c8c68a4df22382..98c26d0c4173f2f067af00e882915a2a9123e9ed 100644 Binary files a/examples/Spanish.wav and b/examples/Spanish.wav differ diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml index b6bf1c8cf5a79e7ddb58ebc725f2c002c26c9486..99e6dab54d3f57bce4f6d29a9129a19a523cad75 100644 --- a/fish_speech/configs/base.yaml +++ b/fish_speech/configs/base.yaml @@ -1,87 +1,87 @@ -# Base configuration for training a model -paths: - run_dir: results/${project} - ckpt_dir: ${paths.run_dir}/checkpoints - -hydra: - run: - dir: ${paths.run_dir} - -# Lightning Trainer -trainer: - _target_: lightning.pytorch.trainer.Trainer - - default_root_dir: ${paths.run_dir} - accelerator: gpu - num_nodes: 1 - devices: auto - strategy: - _target_: lightning.pytorch.strategies.DDPStrategy - process_group_backend: nccl # This should be override when training on windows - - precision: bf16-mixed - - # disable validation by epoch end - check_val_every_n_epoch: null - val_check_interval: 5000 - max_steps: 100_000 - - # Use torch.backends.cudnn.benchmark to speed up training - benchmark: true - -# Callbacks -callbacks: - model_checkpoint: - _target_: lightning.pytorch.callbacks.ModelCheckpoint - dirpath: ${paths.ckpt_dir} - filename: "step_{step:09d}" - save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt - save_top_k: 5 # save 5 latest checkpoints - monitor: step # use step to monitor checkpoints - mode: max # save the latest checkpoint with the highest global_step - every_n_epochs: null # don't save checkpoints by epoch end - every_n_train_steps: 5000 # save checkpoints every 5000 steps - auto_insert_metric_name: false - - model_summary: - _target_: lightning.pytorch.callbacks.ModelSummary - max_depth: 2 # the maximum depth of layer nesting that the summary will include - - learning_rate_monitor: - _target_: lightning.pytorch.callbacks.LearningRateMonitor - logging_interval: step - log_momentum: false - - grad_norm_monitor: - _target_: fish_speech.callbacks.GradNormMonitor - norm_type: 2 - logging_interval: step - -# Logger -logger: - tensorboard: - _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger - save_dir: "${paths.run_dir}/tensorboard/" - name: null - log_graph: false - default_hp_metric: true - prefix: "" - - # wandb: - # _target_: lightning.pytorch.loggers.wandb.WandbLogger - # # name: "" # name of the run (normally generated by wandb) - # save_dir: "${paths.run_dir}" - # offline: False - # id: null # pass correct id to resume experiment! 
- # anonymous: null # enable anonymous logging - # project: "fish-speech" - # log_model: False # upload lightning ckpts - # prefix: "" # a string to put at the beginning of metric keys - # # entity: "" # set to name of your wandb team - # group: "" - # tags: ["vq", "hq", "finetune"] - # job_type: "" - -# Loop -train: true -test: false +# Base configuration for training a model +paths: + run_dir: results/${project} + ckpt_dir: ${paths.run_dir}/checkpoints + +hydra: + run: + dir: ${paths.run_dir} + +# Lightning Trainer +trainer: + _target_: lightning.pytorch.trainer.Trainer + + default_root_dir: ${paths.run_dir} + accelerator: gpu + num_nodes: 1 + devices: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + process_group_backend: nccl # This should be override when training on windows + + precision: bf16-mixed + + # disable validation by epoch end + check_val_every_n_epoch: null + val_check_interval: 5000 + max_steps: 100_000 + + # Use torch.backends.cudnn.benchmark to speed up training + benchmark: true + +# Callbacks +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.ckpt_dir} + filename: "step_{step:09d}" + save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save 5 latest checkpoints + monitor: step # use step to monitor checkpoints + mode: max # save the latest checkpoint with the highest global_step + every_n_epochs: null # don't save checkpoints by epoch end + every_n_train_steps: 5000 # save checkpoints every 5000 steps + auto_insert_metric_name: false + + model_summary: + _target_: lightning.pytorch.callbacks.ModelSummary + max_depth: 2 # the maximum depth of layer nesting that the summary will include + + learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: step + log_momentum: false + + grad_norm_monitor: + _target_: fish_speech.callbacks.GradNormMonitor + norm_type: 2 + logging_interval: step + +# Logger +logger: + tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.run_dir}/tensorboard/" + name: null + log_graph: false + default_hp_metric: true + prefix: "" + + # wandb: + # _target_: lightning.pytorch.loggers.wandb.WandbLogger + # # name: "" # name of the run (normally generated by wandb) + # save_dir: "${paths.run_dir}" + # offline: False + # id: null # pass correct id to resume experiment! 
+ # anonymous: null # enable anonymous logging + # project: "fish-speech" + # log_model: False # upload lightning ckpts + # prefix: "" # a string to put at the beginning of metric keys + # # entity: "" # set to name of your wandb team + # group: "" + # tags: ["vq", "hq", "finetune"] + # job_type: "" + +# Loop +train: true +test: false diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml index 0bb13622fc9dcf3ea59fd9db215536946fb346fa..aecc4d9766a18fe31c55941e01b1f590c95e77c9 100644 --- a/fish_speech/configs/lora/r_8_alpha_16.yaml +++ b/fish_speech/configs/lora/r_8_alpha_16.yaml @@ -1,4 +1,4 @@ -_target_: fish_speech.models.text2semantic.lora.LoraConfig -r: 8 -lora_alpha: 16 -lora_dropout: 0.01 +_target_: fish_speech.models.text2semantic.lora.LoraConfig +r: 8 +lora_alpha: 16 +lora_dropout: 0.01 diff --git a/fish_speech/configs/modded_dac_vq.yaml b/fish_speech/configs/modded_dac_vq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68a926deff44228fe757a22bc27f4f5b76f34408 --- /dev/null +++ b/fish_speech/configs/modded_dac_vq.yaml @@ -0,0 +1,50 @@ +_target_: fish_speech.models.dac.modded_dac.DAC +# Model setup +sample_rate: 44100 +encoder_dim: 64 +encoder_rates: [2, 4, 8, 8] +decoder_dim: 1536 +decoder_rates: [8, 8, 4, 2] +encoder_transformer_layers: [0, 0, 0, 4] +decoder_transformer_layers: [4, 0, 0, 0] +transformer_general_config: + _target_: fish_speech.models.dac.modded_dac.ModelArgs + _partial_: true + block_size: 16384 + n_local_heads: -1 + head_dim: 64 + rope_base: 10000 + norm_eps: 1e-5 + dropout_rate: 0.1 + attn_dropout_rate: 0.1 + channels_first: true +# Quantization +quantizer: + _target_: fish_speech.models.dac.rvq.DownsampleResidualVectorQuantize + input_dim: 1024 + n_codebooks: 9 + codebook_size: 1024 + codebook_dim: 8 + quantizer_dropout: 0.5 + downsample_factor: [2, 2] + post_module: &transformer_module + _target_: fish_speech.models.dac.modded_dac.WindowLimitedTransformer + causal: true + window_size: 128 # empirically this does not seem to matter + input_dim: 1024 + config: &transformer_config + _target_: fish_speech.models.dac.modded_dac.ModelArgs + block_size: 4096 + n_layer: 8 + n_head: 16 + dim: 1024 + intermediate_size: 3072 + n_local_heads: -1 + head_dim: 64 + rope_base: 10000 + norm_eps: 1e-5 + dropout_rate: 0.1 + attn_dropout_rate: 0.1 + channels_first: true + pre_module: *transformer_module + semantic_codebook_size: 4096 diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml index 8c999411703e85c7ad5972134f7a33f42b279571..00f69051c05cbe428c8ef51fa8c467f7fc708bef 100644 --- a/fish_speech/configs/text2semantic_finetune.yaml +++ b/fish_speech/configs/text2semantic_finetune.yaml @@ -1,83 +1,86 @@ -defaults: - - base - - _self_ - -project: text2semantic_finetune_dual_ar -max_length: 4096 -pretrained_ckpt_path: checkpoints/fish-speech-1.4 - -# Lightning Trainer -trainer: - accumulate_grad_batches: 1 - gradient_clip_val: 1.0 - gradient_clip_algorithm: "norm" - max_steps: 1000 - precision: bf16-true - limit_val_batches: 10 - val_check_interval: 100 - -# Dataset Configuration -tokenizer: - _target_: transformers.AutoTokenizer.from_pretrained - pretrained_model_name_or_path: ${pretrained_ckpt_path} - -# Dataset Configuration -train_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset - proto_files: - - data/protos - tokenizer: ${tokenizer} - causal: true - max_length: ${max_length} - use_speaker: false - 
interactive_prob: 0.7 - -val_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset - proto_files: - - data/protos - tokenizer: ${tokenizer} - causal: true - max_length: ${max_length} - use_speaker: false - interactive_prob: 0.7 - -data: - _target_: fish_speech.datasets.semantic.SemanticDataModule - train_dataset: ${train_dataset} - val_dataset: ${val_dataset} - num_workers: 4 - batch_size: 8 - tokenizer: ${tokenizer} - max_length: ${max_length} - -# Model Configuration -model: - _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic - model: - _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained - path: ${pretrained_ckpt_path} - load_weights: true - max_length: ${max_length} - lora_config: null - - optimizer: - _target_: torch.optim.AdamW - _partial_: true - lr: 1e-4 - weight_decay: 0 - betas: [0.9, 0.95] - eps: 1e-5 - - lr_scheduler: - _target_: torch.optim.lr_scheduler.LambdaLR - _partial_: true - lr_lambda: - _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda - _partial_: true - num_warmup_steps: 10 - -# Callbacks -callbacks: - model_checkpoint: - every_n_train_steps: ${trainer.val_check_interval} +defaults: + - base + - _self_ + +project: text2semantic_finetune_dual_ar +max_length: 4096 +pretrained_ckpt_path: checkpoints/openaudio-s1-mini + +# Lightning Trainer +trainer: + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" + max_steps: 10000 + precision: bf16-true + limit_val_batches: 10 + val_check_interval: 100 + # strategy: + # find_unused_parameters: true + # static_graph: true + +# Dataset Configuration +tokenizer: + _target_: fish_speech.tokenizer.FishTokenizer + model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken + +# Dataset Configuration +train_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +val_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +data: + _target_: fish_speech.datasets.semantic.SemanticDataModule + train_dataset: ${train_dataset} + val_dataset: ${val_dataset} + num_workers: 4 + batch_size: 4 + tokenizer: ${tokenizer} + max_length: ${max_length} + +# Model Configuration +model: + _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic + model: + _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained + path: ${pretrained_ckpt_path} + load_weights: true + max_length: ${max_length} + lora_config: null + + optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + weight_decay: 0 + betas: [0.9, 0.95] + eps: 1e-5 + + lr_scheduler: + _target_: torch.optim.lr_scheduler.LambdaLR + _partial_: true + lr_lambda: + _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda + _partial_: true + num_warmup_steps: 10 + +# Callbacks +callbacks: + model_checkpoint: + every_n_train_steps: ${trainer.val_check_interval} diff --git a/fish_speech/content_sequence.py b/fish_speech/content_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..f5885c65e5e103d2df0f5b9c27adefa5cdb637c2 --- /dev/null +++ b/fish_speech/content_sequence.py @@ -0,0 +1,367 @@ +from dataclasses import 
dataclass, field +from typing import List, Literal, Union + +import numpy as np +import torch + +from fish_speech.tokenizer import ( + IM_END_TOKEN, + MODALITY_TOKENS, + FishTokenizer, +) + + +def restore_ndarray(obj, to_tensor: bool = False): + if isinstance(obj, dict) and "__ndarray__" in obj: + obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"]) + + if to_tensor and isinstance(obj, np.ndarray): + obj = torch.from_numpy(obj.copy()) + + return obj + + +@dataclass +class BasePart: + type: Literal["text", "vq", "audio"] | None = None + cal_loss: bool = False + + +@dataclass(kw_only=True) +class VQPart(BasePart): + type = "vq" + codes: torch.Tensor + + def __post_init__(self: "VQPart"): + self.type = "vq" + self.codes = restore_ndarray(self.codes, to_tensor=True) + + +@dataclass(kw_only=True) +class TextPart(BasePart): + type = "text" + text: str | None = None + tokens: list[int] | None = None + + def __post_init__(self: "TextPart"): + self.type = "text" + if self.text is None and self.tokens is None: + raise ValueError("Either text or tokens must be provided") + + +@dataclass(kw_only=True) +class AudioPart(BasePart): + type = "audio" + features: torch.Tensor + + def __post_init__(self: "AudioPart"): + self.type = "audio" + self.features = restore_ndarray(self.features, to_tensor=True) + + +@dataclass(kw_only=True) +class EncodedMessage: + tokens: torch.Tensor + labels: torch.Tensor + vq_mask_tokens: torch.Tensor | None = None + vq_mask_labels: torch.Tensor | None = None + vq_parts: list[torch.Tensor] + vq_require_losses: torch.Tensor | None = None + audio_parts: list[torch.Tensor] + audio_masks: torch.Tensor | None = None + metadata: dict | None = None + + +@dataclass +class ContentSequence: + """ + Flexible sequence of content parts that supports interleaved multimodal format. + Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|> + """ + + parts: list[BasePart] = field(default_factory=list) + modality: Literal["text", "voice", "interleave"] | None = None + metadata: dict | None = None + + def __init__( + self: "ContentSequence", + parts: list[BasePart | dict] | None = None, + modality: Literal["text", "voice", "interleave"] | None = None, + metadata: dict | None = None, + ): + self.modality = modality + self.metadata = metadata or {} + + fixed_parts = [] + for part in parts or []: + if isinstance(part, dict): + if part["type"] == "vq": + part = VQPart(**part) + elif part["type"] == "audio": + part = AudioPart(**part) + elif part["type"] == "text": + part = TextPart(**part) + else: + raise ValueError(f"Unsupported part type: {part['type']}") + fixed_parts.append(part) + + self.parts = fixed_parts + + # If modality is specified, add it at the beginning if it's not already there + if self.modality and not ( + len(self.parts) > 0 + and isinstance(self.parts[0], dict) is False + and isinstance(self.parts[0], TextPart) + and self.parts[0].text is not None + and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality]) + ): + modality_token = MODALITY_TOKENS[self.modality] + self.parts.insert(0, TextPart(text=modality_token)) + + def append( + self: "ContentSequence", + part_or_parts: Union[BasePart, List[BasePart]], + add_end: bool = False, + speaker: Union[str, int] | None = None, + ): + """ + Append a part or list of parts to the sequence. 
+ + Args: + part_or_parts: A single part or list of parts to add + add_end: Whether to add the IM_END_TOKEN after these parts + speaker: Optional speaker identifier (name or ID) to add before the parts + """ + # Convert single part to list + parts_to_add = ( + [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts + ) + + # Add speaker token if specified + if speaker is not None: + speaker_token = f"<|speaker:{speaker}|>" + self.parts.append(TextPart(text=speaker_token)) + + # Add all the parts + self.parts.extend(parts_to_add) + + # Add end token if requested + if add_end: + self.parts.append( + TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss) + ) + + def encode( + self: "ContentSequence", + tokenizer: FishTokenizer, + add_shift: bool = True, + ignore_loss_tokens: list[str] = [], + ) -> EncodedMessage: + """ + Encode the sequence parts into tokens for the model. + + Args: + tokenizer: The tokenizer to use + add_shift: Whether to shift tokens for next-token prediction + ignore_loss_tokens: List of token strings to ignore when calculating loss + + Returns: + EncodedMessage with tensors ready for the model + """ + all_tokens = [] + all_labels = [] + + # Multi-modal elements + vq_parts = [] + vq_masks = [] + vq_require_losses = [] + + audio_parts = [] + audio_masks = [] + + ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens] + + for part in self.parts: + if isinstance(part, TextPart): + if part.tokens is None: + assert part.text is not None + tokens = tokenizer.encode(part.text) + else: + tokens = part.tokens + + tokens = torch.tensor(tokens, dtype=torch.int) + elif isinstance(part, VQPart): + curr_codes = part.codes.clone().to(torch.int) + tokens = torch.tensor( + [ + tokenizer.semantic_id_to_token_id[int(i.item())] + for i in curr_codes[0].int() + ], + dtype=torch.int, + ) + vq_parts.append(curr_codes) + vq_require_losses.append(part.cal_loss) + else: + raise ValueError(f"Unsupported part type: {type(part)}") + + all_tokens.append(tokens) + + # Set masks for different part types + if isinstance(part, VQPart): + vq_masks.append(torch.ones_like(tokens, dtype=torch.bool)) + audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) + elif isinstance(part, AudioPart): + vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) + audio_mask = torch.ones_like(tokens, dtype=torch.bool) + audio_mask[0] = False # Skip start token + audio_mask[-1] = False # Skip end token + audio_masks.append(audio_mask) + else: + vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) + audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool)) + + # Set labels based on whether we want to calculate loss for this part + if part.cal_loss and not isinstance(part, AudioPart): + all_labels.append(tokens.clone()) + else: + all_labels.append(torch.full_like(tokens, -100)) + + # Concatenate all tensors + tokens = torch.cat(all_tokens, dim=0) + labels = torch.cat(all_labels, dim=0) + vq_masks = torch.cat(vq_masks, dim=0) + audio_masks = torch.cat(audio_masks, dim=0) + vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool) + + # Apply shift if needed for next-token prediction + vq_mask_tokens = vq_masks + vq_mask_labels = vq_masks + + if add_shift: + tokens = tokens[:-1] + labels = labels[1:] + vq_masks = vq_masks[:-1] + vq_mask_tokens = vq_mask_tokens[:-1] + vq_mask_labels = vq_mask_labels[1:] + audio_masks = audio_masks[:-1] + + # Ignore specified tokens + for i in ignore_loss_token_ids: + assert i != -100 and i is not None + 
labels[labels == i] = -100 + + assert tokens.dtype in [ + torch.int, + torch.long, + ], f"Invalid dtype: {tokens.dtype}" + + return EncodedMessage( + tokens=tokens, + labels=labels, + vq_parts=vq_parts, + vq_mask_tokens=vq_mask_tokens, + vq_mask_labels=vq_mask_labels, + vq_require_losses=vq_require_losses, + audio_parts=audio_parts, + audio_masks=audio_masks, + metadata=self.metadata, + ) + + def encode_for_inference( + self: "ContentSequence", + tokenizer: FishTokenizer, + num_codebooks: int, + ) -> torch.Tensor: + encoded = self.encode(tokenizer, add_shift=False) + tokens = encoded.tokens + values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int) + values[0] = tokens + + if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and ( + encoded.audio_parts is None or len(encoded.audio_parts) == 0 + ): + return values + + if encoded.vq_parts is not None and len(encoded.vq_parts) > 0: + vq_parts = encoded.vq_parts + vq_parts = torch.cat(vq_parts, dim=1) + values[0, encoded.vq_mask_tokens] = ( + vq_parts[0] + tokenizer.semantic_begin_id + ) + values[1:, encoded.vq_mask_tokens] = vq_parts + + return values + + def visualize( + self: "ContentSequence", + tokenizer: FishTokenizer, + ignore_loss_tokens: list[str] = [], + merge_semantic_tokens: bool = False, + ): + """ + Visualize the encoded sequence with color-coded tokens. + Blue/cyan tokens contribute to loss, green tokens do not. + """ + encoded = self.encode( + tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens + ) + + # Colors for alternating tokens + colors = { + "blue": "\033[94m", # Light blue + "cyan": "\033[96m", # Cyan + "green": "\033[92m", # Light green + "dark_green": "\033[32m", # Dark green + } + blue_idx = 0 + green_idx = 0 + + def print_in_blue(x): + nonlocal blue_idx + color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"] + print(f"{color}{x}\033[0m", end="") + blue_idx += 1 + + def print_in_green(x): + nonlocal green_idx + color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"] + print(f"{color}{x}\033[0m", end="") + green_idx += 1 + + def print_semantic_token(x, count): + val = f"[<|semantic|>x{count}]" + if x == -100: + print_in_green(val) + else: + print_in_blue(val) + + count_semantic_tokens = 0 + semantic_label = None + + for tok, lab in zip(encoded.tokens, encoded.labels): + token_id = int(tok.item()) + + if merge_semantic_tokens: + if ( + tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id + and (semantic_label is None or semantic_label == lab) + ): + count_semantic_tokens += 1 + semantic_label = lab + continue + elif count_semantic_tokens > 0: + print_semantic_token(semantic_label, count_semantic_tokens) + count_semantic_tokens = 0 + semantic_label = None + + val = tokenizer.decode([int(tok.item())]) + + if lab == -100: + print_in_green(val) + else: + print_in_blue(val) + + if merge_semantic_tokens and count_semantic_tokens > 0: + print_semantic_token(semantic_label, count_semantic_tokens) + + print() diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md index 7d4612883379a16fd2d0945c431d9fb8b04b249a..700902b09db20911ef1ad678cbdce5644b84aea2 100644 --- a/fish_speech/i18n/README.md +++ b/fish_speech/i18n/README.md @@ -1,27 +1,27 @@ -## i18n Folder Attribution - -The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. 
In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: - -### fish_speech/i18n/core.py - -**Related code from RVC:** -[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) - -**Initial commit:** -add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) - -**Initial author:** -[@L4Ph](https://github.com/L4Ph) - -### fish_speech/i18n/scan.py - -**Related code from RVC:** -[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) - -**Initial commit:** -File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) - -**Initial author:** -[@towzeur](https://github.com/towzeur) - -We appreciate the contributions of the RVC project and its authors. +## i18n Folder Attribution + +The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: + +### fish_speech/i18n/core.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) + +**Initial commit:** +add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) + +**Initial author:** +[@L4Ph](https://github.com/L4Ph) + +### fish_speech/i18n/scan.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) + +**Initial commit:** +File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) + +**Initial author:** +[@towzeur](https://github.com/towzeur) + +We appreciate the contributions of the RVC project and its authors. 
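The ContentSequence module added above (fish_speech/content_sequence.py) is what turns text and DAC codes into the interleaved token matrix the text2semantic model consumes. Below is a minimal sketch of that flow; the FishTokenizer constructor argument follows the tokenizer block in text2semantic_finetune.yaml, the nine-codebook layout follows the quantizer in modded_dac_vq.yaml, and the zero-filled codes stand in for real encoder output — all three are illustrative assumptions, not canonical usage.

import torch

from fish_speech.content_sequence import ContentSequence, TextPart, VQPart
from fish_speech.tokenizer import FishTokenizer

# Assumption: the tokenizer file ships with the checkpoint, as in the finetune config.
tokenizer = FishTokenizer(model_path="checkpoints/openaudio-s1-mini/tokenizer.tiktoken")

seq = ContentSequence(modality="voice")

# Reference prompt: the transcript of the reference audio followed by its
# DAC codes (n_codebooks x frames); the zeros here are placeholders.
seq.append(
    [
        TextPart(text="Reference transcript."),
        VQPart(codes=torch.zeros(9, 50, dtype=torch.int)),
    ],
    add_end=True,
    speaker=0,
)

# Target text to synthesize for the same speaker.
seq.append(TextPart(text="Hello world."), speaker=0)

# Shape is (num_codebooks + 1, seq_len): row 0 holds text/semantic tokens,
# rows 1..9 hold the codebook values at the VQ positions.
prompt = seq.encode_for_inference(tokenizer, num_codebooks=9)
print(prompt.shape)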
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py index 0ac9702a707223257997e283ebc259b78996ad5c..981dbb3b3ecf28043ec9ff5757f947182821a246 100644 --- a/fish_speech/i18n/__init__.py +++ b/fish_speech/i18n/__init__.py @@ -1,3 +1,3 @@ -from .core import i18n - -__all__ = ["i18n"] +from .core import i18n + +__all__ = ["i18n"] diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py index 8375d9ddb4e3c2b3ec25c426d2786f2a7506a0ab..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd 100644 --- a/fish_speech/i18n/core.py +++ b/fish_speech/i18n/core.py @@ -1,40 +1,40 @@ -import json -import locale -from pathlib import Path - -I18N_FILE_PATH = Path(__file__).parent / "locale" -DEFAULT_LANGUAGE = "en_US" - - -def load_language_list(language): - with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: - language_list = json.load(f) - - return language_list - - -class I18nAuto: - def __init__(self): - i18n_file = Path(".locale") - - if i18n_file.exists(): - with open(i18n_file, "r", encoding="utf-8") as f: - language = f.read().strip() - else: - # getlocale can't identify the system's language ((None, None)) - language = locale.getdefaultlocale()[0] - - if (I18N_FILE_PATH / f"{language}.json").exists() is False: - language = DEFAULT_LANGUAGE - - self.language = language - self.language_map = load_language_list(language) - - def __call__(self, key): - return self.language_map.get(key, key) - - def __repr__(self): - return "Use Language: " + self.language - - -i18n = I18nAuto() +import json +import locale +from pathlib import Path + +I18N_FILE_PATH = Path(__file__).parent / "locale" +DEFAULT_LANGUAGE = "en_US" + + +def load_language_list(language): + with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: + language_list = json.load(f) + + return language_list + + +class I18nAuto: + def __init__(self): + i18n_file = Path(".locale") + + if i18n_file.exists(): + with open(i18n_file, "r", encoding="utf-8") as f: + language = f.read().strip() + else: + # getlocale can't identify the system's language ((None, None)) + language = locale.getdefaultlocale()[0] + + if (I18N_FILE_PATH / f"{language}.json").exists() is False: + language = DEFAULT_LANGUAGE + + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + + +i18n = I18nAuto() diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json index 32d58983a5df52d032411fd50ef1a9e3ecdeb859..d36c774313628fe9d4ee60e816f404c09935e655 100644 --- a/fish_speech/i18n/locale/en_US.json +++ b/fish_speech/i18n/locale/en_US.json @@ -1,123 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Accumulate Gradient Batches", - "Add to Processing Area": "Add to Processing Area", - "Added path successfully!": "Added path successfully!", - "Advanced Config": "Advanced Config", - "Base LLAMA Model": "Base LLAMA Model", - "Batch Inference": "Batch Inference", - "Batch Size": "Batch Size", - "Changing with the Model Path": 
"Changing with the Model Path", - "Chinese": "Chinese", - "Compile Model": "Compile Model", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", - "Copy": "Copy", - "Data Preprocessing": "Data Preprocessing", - "Data Preprocessing Path": "Data Preprocessing Path", - "Data Source": "Data Source", - "Decoder Model Config": "Decoder Model Config", - "Decoder Model Path": "Decoder Model Path", - "Disabled": "Disabled", - "Enable Reference Audio": "Enable Reference Audio", - "English": "English", - "Error Message": "Error Message", - "File Preprocessing": "File Preprocessing", - "Generate": "Generate", - "Generated Audio": "Generated Audio", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", - "Infer interface is closed": "Infer interface is closed", - "Inference Configuration": "Inference Configuration", - "Inference Server Configuration": "Inference Server Configuration", - "Inference Server Error": "Inference Server Error", - "Inferring interface is launched at {}": "Inferring interface is launched at {}", - "Initial Learning Rate": "Initial Learning Rate", - "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", - "Input Text": "Input Text", - "Invalid path: {}": "Invalid path: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", - "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", - "Japanese": "Japanese", - "LLAMA Configuration": "LLAMA Configuration", - "LLAMA Model Config": "LLAMA Model Config", - "LLAMA Model Path": "LLAMA Model Path", - "Labeling Device": "Labeling Device", - "LoRA Model to be merged": "LoRA Model to be merged", - "Maximum Audio Duration": "Maximum Audio Duration", - "Maximum Length per Sample": "Maximum Length per Sample", - "Maximum Training Steps": "Maximum Training Steps", - "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", - "Merge": "Merge", - "Merge LoRA": "Merge LoRA", - "Merge successfully": "Merge successfully", - "Minimum Audio Duration": "Minimum Audio Duration", - "Model Output Path": "Model Output Path", - "Model Size": "Model Size", - "Move": "Move", - "Move files successfully": "Move files successfully", - "No audio generated, please check the input text.": "No audio generated, please check the input text.", - "No selected options": "No selected options", - "Number of Workers": "Number of Workers", - "Open Inference Server": "Open Inference Server", - "Open Labeler WebUI": "Open Labeler WebUI", - "Open Tensorboard": "Open Tensorboard", - "Opened labeler in browser": "Opened labeler in browser", - "Optional Label Language": "Optional Label Language", - "Optional online ver": "Optional online ver", - "Output Path": "Output Path", - "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", - "Precision": "Precision", - "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", - "Put your text here.": "Put your text here.", - "Reference Audio": "Reference Audio", - "Reference Text": "Reference Text", - "Related code and 
weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", - "Remove Selected Data": "Remove Selected Data", - "Removed path successfully!": "Removed path successfully!", - "Repetition Penalty": "Repetition Penalty", - "Save model every n steps": "Save model every n steps", - "Select LLAMA ckpt": "Select LLAMA ckpt", - "Select VITS ckpt": "Select VITS ckpt", - "Select VQGAN ckpt": "Select VQGAN ckpt", - "Select source file processing method": "Select source file processing method", - "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", - "Selected: {}": "Selected: {}", - "Speaker": "Speaker", - "Speaker is identified by the folder name": "Speaker is identified by the folder name", - "Start Training": "Start Training", - "Streaming Audio": "Streaming Audio", - "Streaming Generate": "Streaming Generate", - "Tensorboard Host": "Tensorboard Host", - "Tensorboard Log Path": "Tensorboard Log Path", - "Tensorboard Port": "Tensorboard Port", - "Tensorboard interface is closed": "Tensorboard interface is closed", - "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", - "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.", - "Training Configuration": "Training Configuration", - "Training Error": "Training Error", - "Training stopped": "Training stopped", - "Type name of the speaker": "Type name of the speaker", - "Type the path or select from the dropdown": "Type the path or select from the dropdown", - "Use LoRA": "Use LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", - "Use filelist": "Use filelist", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", - "VITS Configuration": "VITS Configuration", - "VQGAN Configuration": "VQGAN Configuration", - "Validation Batch Size": "Validation Batch Size", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", - "WebUI Host": "WebUI Host", - "WebUI Port": "WebUI Port", - "Whisper Model": "Whisper Model", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", - "latest": "latest", - "new": "new", - "Realtime Transform Text": "Realtime Transform Text", - "Normalization Result 
Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", - "Text Normalization": "Text Normalization", - "Select Example Audio": "Select Example Audio" -} +{ + "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Accumulate Gradient Batches", + "Add to Processing Area": "Add to Processing Area", + "Added path successfully!": "Added path successfully!", + "Advanced Config": "Advanced Config", + "Base LLAMA Model": "Base LLAMA Model", + "Batch Inference": "Batch Inference", + "Batch Size": "Batch Size", + "Changing with the Model Path": "Changing with the Model Path", + "Chinese": "Chinese", + "Compile Model": "Compile Model", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", + "Copy": "Copy", + "Data Preprocessing": "Data Preprocessing", + "Data Preprocessing Path": "Data Preprocessing Path", + "Data Source": "Data Source", + "Decoder Model Config": "Decoder Model Config", + "Decoder Model Path": "Decoder Model Path", + "Disabled": "Disabled", + "Enable Reference Audio": "Enable Reference Audio", + "English": "English", + "Error Message": "Error Message", + "File Preprocessing": "File Preprocessing", + "Generate": "Generate", + "Generated Audio": "Generated Audio", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", + "Infer interface is closed": "Infer interface is closed", + "Inference Configuration": "Inference Configuration", + "Inference Server Configuration": "Inference Server Configuration", + "Inference Server Error": "Inference Server Error", + "Inferring interface is launched at {}": "Inferring interface is launched at {}", + "Initial Learning Rate": "Initial Learning Rate", + "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", + "Input Text": "Input Text", + "Invalid path: {}": "Invalid path: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", + "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", + "Japanese": "Japanese", + "LLAMA Configuration": "LLAMA Configuration", + "LLAMA Model Config": "LLAMA Model Config", + "LLAMA Model Path": "LLAMA Model Path", + "Labeling Device": "Labeling Device", + "LoRA Model to be merged": "LoRA Model to be merged", + "Maximum Audio Duration": "Maximum Audio Duration", + "Maximum Length per Sample": "Maximum Length per Sample", + "Maximum Training Steps": "Maximum Training Steps", + "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", + "Merge": "Merge", + "Merge LoRA": "Merge LoRA", + "Merge successfully": "Merge successfully", + "Minimum Audio Duration": "Minimum Audio Duration", + "Model Output Path": "Model Output Path", + "Model Size": "Model Size", + "Move": "Move", + "Move files 
successfully": "Move files successfully", + "No audio generated, please check the input text.": "No audio generated, please check the input text.", + "No selected options": "No selected options", + "Number of Workers": "Number of Workers", + "Open Inference Server": "Open Inference Server", + "Open Labeler WebUI": "Open Labeler WebUI", + "Open Tensorboard": "Open Tensorboard", + "Opened labeler in browser": "Opened labeler in browser", + "Optional Label Language": "Optional Label Language", + "Optional online ver": "Optional online ver", + "Output Path": "Output Path", + "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", + "Precision": "Precision", + "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", + "Put your text here.": "Put your text here.", + "Reference Audio": "Reference Audio", + "Reference Text": "Reference Text", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", + "Remove Selected Data": "Remove Selected Data", + "Removed path successfully!": "Removed path successfully!", + "Repetition Penalty": "Repetition Penalty", + "Save model every n steps": "Save model every n steps", + "Select LLAMA ckpt": "Select LLAMA ckpt", + "Select VITS ckpt": "Select VITS ckpt", + "Select VQGAN ckpt": "Select VQGAN ckpt", + "Select source file processing method": "Select source file processing method", + "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", + "Selected: {}": "Selected: {}", + "Speaker": "Speaker", + "Speaker is identified by the folder name": "Speaker is identified by the folder name", + "Start Training": "Start Training", + "Streaming Audio": "Streaming Audio", + "Streaming Generate": "Streaming Generate", + "Tensorboard Host": "Tensorboard Host", + "Tensorboard Log Path": "Tensorboard Log Path", + "Tensorboard Port": "Tensorboard Port", + "Tensorboard interface is closed": "Tensorboard interface is closed", + "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", + "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.", + "Training Configuration": "Training Configuration", + "Training Error": "Training Error", + "Training stopped": "Training stopped", + "Type name of the speaker": "Type name of the speaker", + "Type the path or select from the dropdown": "Type the path or select from the dropdown", + "Use LoRA": "Use LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", + "Use filelist": "Use filelist", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", + "VITS Configuration": "VITS Configuration", + "VQGAN Configuration": "VQGAN Configuration", + "Validation Batch Size": "Validation Batch Size", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", + "WebUI Host": "WebUI Host", + "WebUI Port": "WebUI Port", + "Whisper Model": "Whisper Model", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", + "latest": "latest", + "new": "new", + "Realtime Transform Text": "Realtime Transform Text", + "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", + "Text Normalization": "Text Normalization", + "Select Example Audio": "Select Example Audio" +} diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json index 0bde8404cdaff94e7c49be580cbba99b8f41ce29..7a4757967dd0fe3807ba4d354e75ad7a88eb510e 100644 --- a/fish_speech/i18n/locale/es_ES.json +++ b/fish_speech/i18n/locale/es_ES.json @@ -1,123 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Acumular lotes de gradientes", - "Add to Processing Area": "Agregar al Área de Procesamiento", - "Added path successfully!": "¡Ruta agregada exitosamente!", - "Advanced Config": "Configuración Avanzada", - "Base LLAMA Model": "Modelo Base LLAMA", - "Batch Inference": "Inferencia por Lote", - "Batch Size": "Tamaño del Lote", - "Changing with the Model Path": "Cambiando con la Ruta del Modelo", - "Chinese": "Chino", - "Compile Model": "Compilar Modelo", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de 
inferencia, pero aumentará el tiempo de inicio en frío", - "Copy": "Copiar", - "Data Preprocessing": "Preprocesamiento de Datos", - "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", - "Data Source": "Fuente de Datos", - "Decoder Model Config": "Configuración del modelo decodificador", - "Decoder Model Path": "Ruta del modelo decodificador", - "Disabled": "Desactivado", - "Enable Reference Audio": "Habilitar Audio de Referencia", - "English": "Inglés", - "Error Message": "Mensaje de Error", - "File Preprocessing": "Preprocesamiento de Archivos", - "Generate": "Generar", - "Generated Audio": "Audio Generado", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", - "Infer interface is closed": "La interfaz de inferencia está cerrada", - "Inference Configuration": "Configuración de Inferencia", - "Inference Server Configuration": "Configuración del Servidor de Inferencia", - "Inference Server Error": "Error del Servidor de Inferencia", - "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", - "Initial Learning Rate": "Tasa de Aprendizaje Inicial", - "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", - "Input Text": "Texto de Entrada", - "Invalid path: {}": "Ruta inválida: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", - "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", - "Japanese": "Japonés", - "LLAMA Configuration": "Configuración de LLAMA", - "LLAMA Model Config": "Configuración del Modelo LLAMA", - "LLAMA Model Path": "Ruta del Modelo LLAMA", - "Labeling Device": "Dispositivo de Etiquetado", - "LoRA Model to be merged": "Modelo LoRA a fusionar", - "Maximum Audio Duration": "Duración máxima de audio", - "Maximum Length per Sample": "Longitud Máxima por Muestra", - "Maximum Training Steps": "Pasos Máximos de Entrenamiento", - "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", - "Merge": "Fusionar", - "Merge LoRA": "Fusionar LoRA", - "Merge successfully": "Fusionado exitosamente", - "Minimum Audio Duration": "Duración mínima de audio", - "Model Output Path": "Ruta de Salida del Modelo", - "Model Size": "Tamaño del Modelo", - "Move": "Mover", - "Move files successfully": "Archivos movidos exitosamente", - "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", - "No selected options": "No hay opciones seleccionadas", - "Number of Workers": "Número de Trabajadores", - "Open Inference Server": "Abrir Servidor de Inferencia", - "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", - "Open Tensorboard": "Abrir Tensorboard", - "Opened labeler in browser": "Se abrió el etiquetador en el navegador", - "Optional Label Language": "Idioma de Etiquetado Opcional", - "Optional online ver": "Ver en línea opcional", - "Output Path": "Ruta de Salida", - "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", - "Precision": "Precisión", - "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", - "Put your text here.": "Ponga su 
texto aquí.", - "Reference Audio": "Audio de Referencia", - "Reference Text": "Texto de Referencia", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", - "Remove Selected Data": "Eliminar Datos Seleccionados", - "Removed path successfully!": "¡Ruta eliminada exitosamente!", - "Repetition Penalty": "Penalización por Repetición", - "Save model every n steps": "Guardar modelo cada n pasos", - "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", - "Select VITS ckpt": "Seleccionar punto de control VITS", - "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", - "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", - "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", - "Selected: {}": "Seleccionado: {}", - "Speaker": "Hablante", - "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", - "Start Training": "Iniciar Entrenamiento", - "Streaming Audio": "transmisión de audio", - "Streaming Generate": "síntesis en flujo", - "Tensorboard Host": "Host de Tensorboard", - "Tensorboard Log Path": "Ruta de Registro de Tensorboard", - "Tensorboard Port": "Puerto de Tensorboard", - "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", - "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", - "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", - "Training Configuration": "Configuración de Entrenamiento", - "Training Error": "Error de Entrenamiento", - "Training stopped": "Entrenamiento detenido", - "Type name of the speaker": "Escriba el nombre del hablante", - "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", - "Use LoRA": "Usar LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", - "Use filelist": "Usar lista de archivos", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", - "VITS Configuration": "Configuración de VITS", - "VQGAN Configuration": "Configuración de VQGAN", - "Validation Batch Size": "Tamaño del Lote de Validación", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", - "WebUI Host": "Host de WebUI", - "WebUI Port": "Puerto de WebUI", - "Whisper Model": "Modelo Whisper", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", - "latest": "más reciente", - "new": "nuevo", - "Realtime Transform Text": "Transformación de Texto en Tiempo Real", - "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", - "Text Normalization": "Normalización de Texto", - "Select Example Audio": "Selecionar áudio de exemplo" -} +{ + "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular lotes de gradientes", + "Add to Processing Area": "Agregar al Área de Procesamiento", + "Added path successfully!": "¡Ruta agregada exitosamente!", + "Advanced Config": "Configuración Avanzada", + "Base LLAMA Model": "Modelo Base LLAMA", + "Batch Inference": "Inferencia por Lote", + "Batch Size": "Tamaño del Lote", + "Changing with the Model Path": "Cambiando con la Ruta del Modelo", + "Chinese": "Chino", + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío", + "Copy": "Copiar", + "Data 
Preprocessing": "Preprocesamiento de Datos", + "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", + "Data Source": "Fuente de Datos", + "Decoder Model Config": "Configuración del modelo decodificador", + "Decoder Model Path": "Ruta del modelo decodificador", + "Disabled": "Desactivado", + "Enable Reference Audio": "Habilitar Audio de Referencia", + "English": "Inglés", + "Error Message": "Mensaje de Error", + "File Preprocessing": "Preprocesamiento de Archivos", + "Generate": "Generar", + "Generated Audio": "Audio Generado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", + "Infer interface is closed": "La interfaz de inferencia está cerrada", + "Inference Configuration": "Configuración de Inferencia", + "Inference Server Configuration": "Configuración del Servidor de Inferencia", + "Inference Server Error": "Error del Servidor de Inferencia", + "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", + "Initial Learning Rate": "Tasa de Aprendizaje Inicial", + "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Ruta inválida: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", + "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", + "Japanese": "Japonés", + "LLAMA Configuration": "Configuración de LLAMA", + "LLAMA Model Config": "Configuración del Modelo LLAMA", + "LLAMA Model Path": "Ruta del Modelo LLAMA", + "Labeling Device": "Dispositivo de Etiquetado", + "LoRA Model to be merged": "Modelo LoRA a fusionar", + "Maximum Audio Duration": "Duración máxima de audio", + "Maximum Length per Sample": "Longitud Máxima por Muestra", + "Maximum Training Steps": "Pasos Máximos de Entrenamiento", + "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", + "Merge": "Fusionar", + "Merge LoRA": "Fusionar LoRA", + "Merge successfully": "Fusionado exitosamente", + "Minimum Audio Duration": "Duración mínima de audio", + "Model Output Path": "Ruta de Salida del Modelo", + "Model Size": "Tamaño del Modelo", + "Move": "Mover", + "Move files successfully": "Archivos movidos exitosamente", + "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", + "No selected options": "No hay opciones seleccionadas", + "Number of Workers": "Número de Trabajadores", + "Open Inference Server": "Abrir Servidor de Inferencia", + "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "Se abrió el etiquetador en el navegador", + "Optional Label Language": "Idioma de Etiquetado Opcional", + "Optional online ver": "Ver en línea opcional", + "Output Path": "Ruta de Salida", + "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", + "Precision": "Precisión", + "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", + "Put your text here.": "Ponga su texto aquí.", + "Reference Audio": "Audio de Referencia", + "Reference Text": "Texto de 
Referencia", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", + "Remove Selected Data": "Eliminar Datos Seleccionados", + "Removed path successfully!": "¡Ruta eliminada exitosamente!", + "Repetition Penalty": "Penalización por Repetición", + "Save model every n steps": "Guardar modelo cada n pasos", + "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", + "Select VITS ckpt": "Seleccionar punto de control VITS", + "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", + "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", + "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", + "Selected: {}": "Seleccionado: {}", + "Speaker": "Hablante", + "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", + "Start Training": "Iniciar Entrenamiento", + "Streaming Audio": "transmisión de audio", + "Streaming Generate": "síntesis en flujo", + "Tensorboard Host": "Host de Tensorboard", + "Tensorboard Log Path": "Ruta de Registro de Tensorboard", + "Tensorboard Port": "Puerto de Tensorboard", + "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", + "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", + "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", + "Training Configuration": "Configuración de Entrenamiento", + "Training Error": "Error de Entrenamiento", + "Training stopped": "Entrenamiento detenido", + "Type name of the speaker": "Escriba el nombre del hablante", + "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", + "Use filelist": "Usar lista de archivos", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", + "VITS Configuration": "Configuración de VITS", + "VQGAN Configuration": "Configuración de VQGAN", + "Validation Batch Size": "Tamaño del Lote de Validación", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", + "WebUI Host": "Host de WebUI", + "WebUI Port": "Puerto de WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", + "latest": "más reciente", + "new": "nuevo", + "Realtime Transform Text": "Transformación de Texto en Tiempo Real", + "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", + "Text Normalization": "Normalización de Texto", + "Select Example Audio": "Seleccionar audio de ejemplo" +} diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json index 9d0baeb73ffd9cc1af7570ef0ac7e6018ce9527b..863b8b0b41da7e504ac0dcc4abf707f1f71a53fa 100644 --- a/fish_speech/i18n/locale/ja_JP.json +++ b/fish_speech/i18n/locale/ja_JP.json @@ -1,123 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", - "Accumulate Gradient Batches": "勾配バッチの累積", - "Add to Processing Area": "処理エリアに追加", - "Added path successfully!": "パスの追加に成功しました!", - "Advanced Config": "詳細設定", - "Base LLAMA Model": "基本LLAMAモデル", - "Batch Inference": "バッチ推論", - "Batch Size": "バッチサイズ", - "Changing with the Model Path": "モデルのパスに伴って変化する", - "Chinese": "中国語", - "Compile Model": "モデルのコンパイル", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", - "Copy": "コピー", - "Data Preprocessing": "データ前処理", - 
"Data Preprocessing Path": "データ前処理パス", - "Data Source": "データソース", - "Decoder Model Config": "デコーダーモデルの構成", - "Decoder Model Path": "デコーダーモデルのパス", - "Disabled": "無効", - "Enable Reference Audio": "リファレンスオーディオを有効にする", - "English": "英語", - "Error Message": "エラーメッセージ", - "File Preprocessing": "文書前处理", - "Generate": "生成", - "Generated Audio": "生成されたオーディオ", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", - "Infer interface is closed": "推論インターフェースが閉じられています", - "Inference Configuration": "推論設定", - "Inference Server Configuration": "推論サーバー設定", - "Inference Server Error": "推論サーバーエラー", - "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", - "Initial Learning Rate": "初期学習率", - "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", - "Input Text": "入力テキスト", - "Invalid path: {}": "無効なパス: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", - "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", - "Japanese": "日本語", - "LLAMA Configuration": "LLAMA設定", - "LLAMA Model Config": "LLAMAモデル設定", - "LLAMA Model Path": "LLAMAモデルパス", - "Labeling Device": "ラベリングデバイス", - "LoRA Model to be merged": "マージするLoRAモデル", - "Maximum Audio Duration": "最大オーディオの長さ", - "Maximum Length per Sample": "サンプルあたりの最大長", - "Maximum Training Steps": "最大トレーニングステップ数", - "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", - "Merge": "マージ", - "Merge LoRA": "LoRAのマージ", - "Merge successfully": "マージに成功しました", - "Minimum Audio Duration": "最小オーディオの長さ", - "Model Output Path": "モデル出力パス", - "Model Size": "モデルサイズ", - "Move": "移動", - "Move files successfully": "ファイルの移動に成功しました", - "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", - "No selected options": "選択されたオプションはありません", - "Number of Workers": "ワーカー数", - "Open Inference Server": "推論サーバーを開く", - "Open Labeler WebUI": "ラベラーWebUIを開く", - "Open Tensorboard": "Tensorboardを開く", - "Opened labeler in browser": "ブラウザでラベラーを開きました", - "Optional Label Language": "オプションのラベル言語", - "Optional online ver": "オプションのオンラインバージョン", - "Output Path": "出力パス", - "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", - "Precision": "精度", - "Probability of applying Speaker Condition": "話者条件を適用する確率", - "Put your text here.": "ここにテキストを入力してください。", - "Reference Audio": "リファレンスオーディオ", - "Reference Text": "リファレンステキスト", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", - "Remove Selected Data": "選択したデータを削除", - "Removed path successfully!": "パスの削除に成功しました!", - "Repetition Penalty": "反復ペナルティ", - "Save model every n steps": "nステップごとにモデルを保存", - "Select LLAMA ckpt": " LLAMA チェックポイントを選択", - "Select VITS ckpt": "VITS チェックポイントを選択", - "Select VQGAN ckpt": "VQGAN チェックポイントを選択", - "Select source file processing method": "ソースファイルの処理方法を選択", - "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", - "Selected: {}": "選択済み: {}", - "Speaker": "話者", - "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", - "Start Training": "トレーニング開始", - "Streaming Audio": "ストリーミングオーディオ", - "Streaming Generate": "ストリーミング合成", - "Tensorboard Host": "Tensorboardホスト", - "Tensorboard Log Path": "Tensorboardログパス", - "Tensorboard Port": "Tensorboardポート", - "Tensorboard interface is closed": 
"Tensorboardインターフェースが閉じられています", - "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", - "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", - "Training Configuration": "トレーニング設定", - "Training Error": "トレーニングエラー", - "Training stopped": "トレーニングが停止しました", - "Type name of the speaker": "話者の名前を入力", - "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", - "Use LoRA": "LoRAを使用", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", - "Use filelist": "ファイルリストを使用", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", - "VITS Configuration": "VITS の構成", - "VQGAN Configuration": "VQGAN の構成", - "Validation Batch Size": "検証バッチサイズ", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", - "WebUI Host": "WebUIホスト", - "WebUI Port": "WebUIポート", - "Whisper Model": "Whisperモデル", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", - "latest": "最新", - "new": "新規", - "Realtime Transform Text": "リアルタイム変換テキスト", - "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", - "Text Normalization": "テキスト正規化", - "Select Example Audio": "サンプル音声を選択" -} +{ + "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", + "Accumulate Gradient Batches": "勾配バッチの累積", + "Add to Processing Area": "処理エリアに追加", + "Added path successfully!": "パスの追加に成功しました!", + "Advanced Config": "詳細設定", + "Base LLAMA Model": "基本LLAMAモデル", + "Batch Inference": "バッチ推論", + "Batch Size": "バッチサイズ", + "Changing with the Model Path": "モデルのパスに伴って変化する", + "Chinese": "中国語", + "Compile Model": "モデルのコンパイル", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", + "Copy": "コピー", + "Data Preprocessing": "データ前処理", + "Data Preprocessing Path": "データ前処理パス", + "Data Source": "データソース", + "Decoder Model Config": "デコーダーモデルの構成", + "Decoder Model Path": "デコーダーモデルのパス", + "Disabled": "無効", + "Enable Reference Audio": "リファレンスオーディオを有効にする", + "English": "英語", + "Error Message": "エラーメッセージ", + "File Preprocessing": "文書前处理", + "Generate": "生成", + "Generated Audio": "生成されたオーディオ", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", + "Infer interface is 
closed": "推論インターフェースが閉じられています", + "Inference Configuration": "推論設定", + "Inference Server Configuration": "推論サーバー設定", + "Inference Server Error": "推論サーバーエラー", + "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", + "Initial Learning Rate": "初期学習率", + "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", + "Input Text": "入力テキスト", + "Invalid path: {}": "無効なパス: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", + "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", + "Japanese": "日本語", + "LLAMA Configuration": "LLAMA設定", + "LLAMA Model Config": "LLAMAモデル設定", + "LLAMA Model Path": "LLAMAモデルパス", + "Labeling Device": "ラベリングデバイス", + "LoRA Model to be merged": "マージするLoRAモデル", + "Maximum Audio Duration": "最大オーディオの長さ", + "Maximum Length per Sample": "サンプルあたりの最大長", + "Maximum Training Steps": "最大トレーニングステップ数", + "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", + "Merge": "マージ", + "Merge LoRA": "LoRAのマージ", + "Merge successfully": "マージに成功しました", + "Minimum Audio Duration": "最小オーディオの長さ", + "Model Output Path": "モデル出力パス", + "Model Size": "モデルサイズ", + "Move": "移動", + "Move files successfully": "ファイルの移動に成功しました", + "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", + "No selected options": "選択されたオプションはありません", + "Number of Workers": "ワーカー数", + "Open Inference Server": "推論サーバーを開く", + "Open Labeler WebUI": "ラベラーWebUIを開く", + "Open Tensorboard": "Tensorboardを開く", + "Opened labeler in browser": "ブラウザでラベラーを開きました", + "Optional Label Language": "オプションのラベル言語", + "Optional online ver": "オプションのオンラインバージョン", + "Output Path": "出力パス", + "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", + "Precision": "精度", + "Probability of applying Speaker Condition": "話者条件を適用する確率", + "Put your text here.": "ここにテキストを入力してください。", + "Reference Audio": "リファレンスオーディオ", + "Reference Text": "リファレンステキスト", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", + "Remove Selected Data": "選択したデータを削除", + "Removed path successfully!": "パスの削除に成功しました!", + "Repetition Penalty": "反復ペナルティ", + "Save model every n steps": "nステップごとにモデルを保存", + "Select LLAMA ckpt": " LLAMA チェックポイントを選択", + "Select VITS ckpt": "VITS チェックポイントを選択", + "Select VQGAN ckpt": "VQGAN チェックポイントを選択", + "Select source file processing method": "ソースファイルの処理方法を選択", + "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", + "Selected: {}": "選択済み: {}", + "Speaker": "話者", + "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", + "Start Training": "トレーニング開始", + "Streaming Audio": "ストリーミングオーディオ", + "Streaming Generate": "ストリーミング合成", + "Tensorboard Host": "Tensorboardホスト", + "Tensorboard Log Path": "Tensorboardログパス", + "Tensorboard Port": "Tensorboardポート", + "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", + "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました", + "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", + "Training Configuration": "トレーニング設定", + "Training Error": "トレーニングエラー", + "Training stopped": "トレーニングが停止しました", + "Type name of the speaker": "話者の名前を入力", + "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", + "Use LoRA": "LoRAを使用", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", + "Use filelist": "ファイルリストを使用", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", + "VITS Configuration": "VITS の構成", + "VQGAN Configuration": "VQGAN の構成", + "Validation Batch Size": "検証バッチサイズ", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", + "WebUI Host": "WebUIホスト", + "WebUI Port": "WebUIポート", + "Whisper Model": "Whisperモデル", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", + "latest": "最新", + "new": "新規", + "Realtime Transform Text": "リアルタイム変換テキスト", + "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", + "Text Normalization": "テキスト正規化", + "Select Example Audio": "サンプル音声を選択" +} diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json index f4bf1841b7c847993707ec3b8e32f5174de77214..180263874b476059870035d4c2b74ce5fa553a8a 100644 --- a/fish_speech/i18n/locale/ko_KR.json +++ b/fish_speech/i18n/locale/ko_KR.json @@ -1,123 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.", - "Accumulate Gradient Batches": "그라디언트 배치 누적", - "Add to Processing Area": "처리 영역에 추가", - "Added path successfully!": "경로가 성공적으로 추가되었습니다!", - "Advanced Config": "고급 설정", - "Base LLAMA Model": "기본 LLAMA 모델", - "Batch Inference": "배치 추론", - "Batch Size": "배치 크기", - "Changing with the Model Path": "모델 경로에 따라 변경 중", - "Chinese": "중국어", - "Compile Model": "모델 컴파일", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.", - "Copy": "복사", - "Data Preprocessing": "데이터 전처리", - "Data Preprocessing Path": "데이터 전처리 경로", - "Data Source": "데이터 소스", - "Decoder Model Config": "디코더 모델 설정", - "Decoder Model Path": "디코더 모델 경로", - "Disabled": "비활성화 됨", - "Enable Reference Audio": "참고 음성 활성화", - "English": "영어", - "Error Message": "오류 메시지", - "File Preprocessing": "파일 전처리", - "Generate": "생성", - "Generated Audio": "생성된 오디오", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 
지원합니다.", - "Infer interface is closed": "추론 인터페이스가 닫혔습니다.", - "Inference Configuration": "추론 설정", - "Inference Server Configuration": "추론 서버 설정", - "Inference Server Error": "추론 서버 오류", - "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.", - "Initial Learning Rate": "초기 학습률", - "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로", - "Input Text": "입력 텍스트", - "Invalid path: {}": "유효하지 않은 경로: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.", - "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)", - "Japanese": "일본어", - "LLAMA Configuration": "LLAMA 설정", - "LLAMA Model Config": "LLAMA 모델 설정", - "LLAMA Model Path": "LLAMA 모델 경로", - "Labeling Device": "라벨링 장치", - "LoRA Model to be merged": "병합할 LoRA 모델", - "Maximum Audio Duration": "최대 오디오 길이", - "Maximum Length per Sample": "샘플당 최대 길이", - "Maximum Training Steps": "최대 학습 단계", - "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)", - "Merge": "병합", - "Merge LoRA": "LoRA 병합", - "Merge successfully": "성공적으로 병합 되었습니다.", - "Minimum Audio Duration": "최소 오디오 길이", - "Model Output Path": "모델 출력 경로", - "Model Size": "모델 크기", - "Move": "이동", - "Move files successfully": "파일이 성공적으로 이동되었습니다.", - "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.", - "No selected options": "옵션이 선택되지 않았습니다.", - "Number of Workers": "작업자 수", - "Open Inference Server": "추론 서버 열기", - "Open Labeler WebUI": "라벨러 WebUI 열기", - "Open Tensorboard": "Tensorboard 열기", - "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.", - "Optional Label Language": "선택적 라벨 언어", - "Optional online ver": "온라인 버전 선택", - "Output Path": "출력 경로", - "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.", - "Precision": "정밀도", - "Probability of applying Speaker Condition": "화자 조건 적용 확률", - "Put your text here.": "여기에 텍스트를 입력하세요.", - "Reference Audio": "참고 오디오", - "Reference Text": "참고 텍스트", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.", - "Remove Selected Data": "선택한 데이터 제거", - "Removed path successfully!": "경로가 성공적으로 제거되었습니다!", - "Repetition Penalty": "반복 패널티", - "Save model every n steps": "n 단계마다 모델 저장", - "Select LLAMA ckpt": "LLAMA ckpt 선택", - "Select VITS ckpt": "VITS ckpt 선택", - "Select VQGAN ckpt": "VQGAN ckpt 선택", - "Select source file processing method": "소스 파일 처리 방법 선택", - "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)", - "Selected: {}": "선택됨: {}", - "Speaker": "화자", - "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다", - "Start Training": "학습 시작", - "Streaming Audio": "스트리밍 오디오", - "Streaming Generate": "스트리밍 생성", - "Tensorboard Host": "Tensorboard 호스트", - "Tensorboard Log Path": "Tensorboard 로그 경로", - "Tensorboard Port": "Tensorboard 포트", - "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다", - "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.", - "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 
체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.", - "Training Configuration": "학습 설정", - "Training Error": "학습 오류", - "Training stopped": "학습이 중지되었습니다.", - "Type name of the speaker": "화자의 이름을 입력하세요.", - "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.", - "Use LoRA": "LoRA 사용", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.", - "Use filelist": "파일 목록 사용", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.", - "VITS Configuration": "VITS 설정", - "VQGAN Configuration": "VQGAN 설정", - "Validation Batch Size": "검증 배치 크기", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.", - "WebUI Host": "WebUI 호스트", - "WebUI Port": "WebUI 포트", - "Whisper Model": "Whisper 모델", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다", - "latest": "최신", - "new": "새로운", - "Realtime Transform Text": "실시간 텍스트 변환", - "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)", - "Text Normalization": "텍스트 정규화", - "Select Example Audio": "예시 오디오 선택" -} +{ + "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.", + "Accumulate Gradient Batches": "그라디언트 배치 누적", + "Add to Processing Area": "처리 영역에 추가", + "Added path successfully!": "경로가 성공적으로 추가되었습니다!", + "Advanced Config": "고급 설정", + "Base LLAMA Model": "기본 LLAMA 모델", + "Batch Inference": "배치 추론", + "Batch Size": "배치 크기", + "Changing with the Model Path": "모델 경로에 따라 변경 중", + "Chinese": "중국어", + "Compile Model": "모델 컴파일", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.", + "Copy": "복사", + "Data Preprocessing": "데이터 전처리", + "Data Preprocessing Path": "데이터 전처리 경로", + "Data Source": "데이터 소스", + "Decoder Model Config": "디코더 모델 설정", + "Decoder Model Path": "디코더 모델 경로", + "Disabled": "비활성화 됨", + "Enable Reference Audio": "참고 음성 활성화", + "English": "영어", + "Error Message": "오류 메시지", + "File Preprocessing": "파일 전처리", + "Generate": "생성", + "Generated Audio": "생성된 오디오", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.", + "Infer interface is closed": "추론 인터페이스가 닫혔습니다.", + "Inference Configuration": "추론 설정", + "Inference Server Configuration": "추론 서버 설정", + "Inference Server Error": "추론 서버 오류", + "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.", + "Initial Learning Rate": "초기 학습률", + "Input Audio & Source Path for Transcription": 
"전사할 입력 오디오 및 소스 경로", + "Input Text": "입력 텍스트", + "Invalid path: {}": "유효하지 않은 경로: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.", + "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)", + "Japanese": "일본어", + "LLAMA Configuration": "LLAMA 설정", + "LLAMA Model Config": "LLAMA 모델 설정", + "LLAMA Model Path": "LLAMA 모델 경로", + "Labeling Device": "라벨링 장치", + "LoRA Model to be merged": "병합할 LoRA 모델", + "Maximum Audio Duration": "최대 오디오 길이", + "Maximum Length per Sample": "샘플당 최대 길이", + "Maximum Training Steps": "최대 학습 단계", + "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)", + "Merge": "병합", + "Merge LoRA": "LoRA 병합", + "Merge successfully": "성공적으로 병합 되었습니다.", + "Minimum Audio Duration": "최소 오디오 길이", + "Model Output Path": "모델 출력 경로", + "Model Size": "모델 크기", + "Move": "이동", + "Move files successfully": "파일이 성공적으로 이동되었습니다.", + "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.", + "No selected options": "옵션이 선택되지 않았습니다.", + "Number of Workers": "작업자 수", + "Open Inference Server": "추론 서버 열기", + "Open Labeler WebUI": "라벨러 WebUI 열기", + "Open Tensorboard": "Tensorboard 열기", + "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.", + "Optional Label Language": "선택적 라벨 언어", + "Optional online ver": "온라인 버전 선택", + "Output Path": "출력 경로", + "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.", + "Precision": "정밀도", + "Probability of applying Speaker Condition": "화자 조건 적용 확률", + "Put your text here.": "여기에 텍스트를 입력하세요.", + "Reference Audio": "참고 오디오", + "Reference Text": "참고 텍스트", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.", + "Remove Selected Data": "선택한 데이터 제거", + "Removed path successfully!": "경로가 성공적으로 제거되었습니다!", + "Repetition Penalty": "반복 패널티", + "Save model every n steps": "n 단계마다 모델 저장", + "Select LLAMA ckpt": "LLAMA ckpt 선택", + "Select VITS ckpt": "VITS ckpt 선택", + "Select VQGAN ckpt": "VQGAN ckpt 선택", + "Select source file processing method": "소스 파일 처리 방법 선택", + "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)", + "Selected: {}": "선택됨: {}", + "Speaker": "화자", + "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다", + "Start Training": "학습 시작", + "Streaming Audio": "스트리밍 오디오", + "Streaming Generate": "스트리밍 생성", + "Tensorboard Host": "Tensorboard 호스트", + "Tensorboard Log Path": "Tensorboard 로그 경로", + "Tensorboard Port": "Tensorboard 포트", + "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다", + "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.", + "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 
체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.", + "Training Configuration": "학습 설정", + "Training Error": "학습 오류", + "Training stopped": "학습이 중지되었습니다.", + "Type name of the speaker": "화자의 이름을 입력하세요.", + "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.", + "Use LoRA": "LoRA 사용", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.", + "Use filelist": "파일 목록 사용", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.", + "VITS Configuration": "VITS 설정", + "VQGAN Configuration": "VQGAN 설정", + "Validation Batch Size": "검증 배치 크기", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.", + "WebUI Host": "WebUI 호스트", + "WebUI Port": "WebUI 포트", + "Whisper Model": "Whisper 모델", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다", + "latest": "최신", + "new": "새로운", + "Realtime Transform Text": "실시간 텍스트 변환", + "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)", + "Text Normalization": "텍스트 정규화", + "Select Example Audio": "예시 오디오 선택" +} diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json index a5278e29fac737d9fd3da3f8e82e49ac22a96ac3..385f20272e19053ab9b6cf6463a84c8ece768c68 100644 --- a/fish_speech/i18n/locale/pt_BR.json +++ b/fish_speech/i18n/locale/pt_BR.json @@ -1,133 +1,133 @@ -{ - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", - "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", - "Add to Processing Area": "Adicionar à Área de Processamento", - "Added path successfully!": "Caminho adicionado com sucesso!", - "Advanced Config": "Configuração Avançada", - "Base LLAMA Model": "Modelo LLAMA Base", - "Batch Inference": "Inferência em Lote", - "Batch Size": "Tamanho do Lote", - "Changing with the Model Path": "Alterando com o Caminho do Modelo", - - "Compile Model": "Compilar Modelo", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", - "Copy": "Copiar", - "Data Preprocessing": "Pré-processamento de Dados", - "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", - "Data Source": "Fonte de Dados", - "Decoder Model Config": "Configuração do Modelo Decodificador", - "Decoder Model Path": "Caminho do Modelo Decodificador", - "Disabled": "Desativado", - "Enable Initial Prompt": "Habilitar Prompt Inicial", - "Enable Reference Audio": "Habilitar Áudio de Referência", - "English": 
"Inglês", - "Japanese": "Japonês", - "Chinese": "Chinês", - "Portuguese": "Português", - "Spanish": "Espanhol", - "Error Message": "Mensagem de Erro", - "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", - "File Preprocessing": "Pré-processamento de Arquivos", - "Generate": "Gerar", - "Generated Audio": "Áudio Gerado", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", - "Infer interface is closed": "A interface de inferência foi fechada", - "Inference Configuration": "Configuração de Inferência", - "Inference Server Configuration": "Configuração do Servidor de Inferência", - "Inference Server Error": "Erro do Servidor de Inferência", - "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}", - "Initial Learning Rate": "Taxa de Aprendizagem Inicial", - "Initial Prompt": "Prompt Inicial", - "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", - "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", - "Input Text": "Texto de Entrada", - "Invalid path: {}": "Caminho inválido: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", - "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", - "LLAMA Configuration": "Configuração do LLAMA", - "LLAMA Model Config": "Configuração do Modelo LLAMA", - "LLAMA Model Path": "Caminho do Modelo LLAMA", - "Labeling Device": "Dispositivo de Rotulagem", - "LoRA Model to be merged": "Modelo LoRA para mesclagem", - "Maximum Length per Sample": "Comprimento Máximo por Amostra", - "Maximum Training Steps": "Etapas Máximas de Treinamento", - "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", - "Merge": "Mesclar", - "Merge LoRA": "Mesclar LoRA", - "Merge successfully": "Mesclado com sucesso", - "Model Output Path": "Caminho de Saída do Modelo", - "Model Quantization": "Quantização do Modelo", - "Model Size": "Tamanho do Modelo", - "Move": "Mover", - "Move files successfully": "Arquivos movidos com sucesso", - "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", - "No selected options": "Nenhuma opção selecionada", - "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", - "Number of Workers": "Número de Processos", - "Open Inference Server": "Abrir Servidor de Inferência", - "Open Labeler WebUI": "Abrir WebUI de Rotulagem", - "Open Tensorboard": "Abrir Tensorboard", - "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", - "Optional Label Language": "Idioma do Rótulo (Opcional)", - "Optional online ver": "Versão online (opcional)", - "Output Path": "Caminho de Saída", - "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", - "Post-quantification Precision": "Precisão Pós-quantização", - "Precision": "Precisão", - "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", - "Put your text 
here.": "Insira seu texto aqui.", - "Quantify": "Quantizar", - "Quantify successfully": "Quantizado com sucesso", - "Realtime Transform Text": "Transformar Texto em Tempo Real", - "Reference Audio": "Áudio de Referência", - "Reference Text": "Texto de Referência", - "warning": "Aviso", - "Pre-processing begins...": "O pré-processamento começou!", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", - "Remove Selected Data": "Remover Dados Selecionados", - "Removed path successfully!": "Caminho removido com sucesso!", - "Repetition Penalty": "Penalidade de Repetição", - "Save model every n steps": "Salvar modelo a cada n etapas", - "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", - "Select source file processing method": "Escolha como processar o arquivo de origem", - "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", - "Selected: {}": "Selecionado: {}", - "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", - "Start Training": "Iniciar Treinamento", - "Streaming Audio": "Áudio em Streaming", - "Streaming Generate": "Geração em Streaming", - "Tensorboard Host": "Host do Tensorboard", - "Tensorboard Log Path": "Caminho de Log do Tensorboard", - "Tensorboard Port": "Porta do Tensorboard", - "Tensorboard interface is closed": "A interface do Tensorboard está fechada", - "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", - "Text Normalization": "Normalização de Texto", - "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", - "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", - "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", - "Training Configuration": "Configuração de Treinamento", - "Training Error": "Erro de Treinamento", - "Training stopped": "Treinamento interrompido!", - "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", - "Use LoRA": "Usar LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", - "Use filelist": "Usar lista de arquivos", - "VQGAN Configuration": "Configuração do VQGAN", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", - "WebUI Host": "Host da WebUI", - "WebUI Port": "Porta da WebUI", - "Whisper Model": "Modelo Whisper", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", - "auto": "automático", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", - "latest": "mais recente", - "new": "novo", - "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", - "You don't need to train this model!": "Não é necessário treinar este modelo!", - "Yes": "Sim", - "No": "Não", - "version:": "versão:", - "author:": "autor:" -} +{ + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", + "Add to Processing Area": "Adicionar à Área de Processamento", + "Added path successfully!": "Caminho adicionado com sucesso!", + "Advanced Config": "Configuração Avançada", + "Base LLAMA Model": "Modelo LLAMA Base", + "Batch Inference": "Inferência em Lote", + "Batch Size": "Tamanho do Lote", + "Changing with the Model Path": "Alterando com o Caminho do Modelo", + + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", + "Copy": "Copiar", + "Data Preprocessing": "Pré-processamento de Dados", + "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", + "Data Source": "Fonte de Dados", + "Decoder Model Config": "Configuração do Modelo Decodificador", + "Decoder Model Path": "Caminho do Modelo Decodificador", + "Disabled": "Desativado", + "Enable Initial Prompt": "Habilitar Prompt Inicial", + "Enable Reference Audio": "Habilitar Áudio de Referência", + "English": "Inglês", + "Japanese": "Japonês", + "Chinese": "Chinês", + "Portuguese": "Português", + "Spanish": "Espanhol", + "Error Message": "Mensagem de Erro", + "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", + "File Preprocessing": "Pré-processamento de Arquivos", + "Generate": "Gerar", + "Generated Audio": "Áudio Gerado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", + "Infer interface is closed": "A interface de inferência foi fechada", + "Inference Configuration": "Configuração de Inferência", + "Inference Server Configuration": "Configuração do Servidor de Inferência", + "Inference Server Error": "Erro do Servidor de Inferência", + "Inferring interface is launched at {}": "A interface de inferência foi iniciada em 
{}", + "Initial Learning Rate": "Taxa de Aprendizagem Inicial", + "Initial Prompt": "Prompt Inicial", + "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", + "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Caminho inválido: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", + "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", + "LLAMA Configuration": "Configuração do LLAMA", + "LLAMA Model Config": "Configuração do Modelo LLAMA", + "LLAMA Model Path": "Caminho do Modelo LLAMA", + "Labeling Device": "Dispositivo de Rotulagem", + "LoRA Model to be merged": "Modelo LoRA para mesclagem", + "Maximum Length per Sample": "Comprimento Máximo por Amostra", + "Maximum Training Steps": "Etapas Máximas de Treinamento", + "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", + "Merge": "Mesclar", + "Merge LoRA": "Mesclar LoRA", + "Merge successfully": "Mesclado com sucesso", + "Model Output Path": "Caminho de Saída do Modelo", + "Model Quantization": "Quantização do Modelo", + "Model Size": "Tamanho do Modelo", + "Move": "Mover", + "Move files successfully": "Arquivos movidos com sucesso", + "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", + "No selected options": "Nenhuma opção selecionada", + "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", + "Number of Workers": "Número de Processos", + "Open Inference Server": "Abrir Servidor de Inferência", + "Open Labeler WebUI": "Abrir WebUI de Rotulagem", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", + "Optional Label Language": "Idioma do Rótulo (Opcional)", + "Optional online ver": "Versão online (opcional)", + "Output Path": "Caminho de Saída", + "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", + "Post-quantification Precision": "Precisão Pós-quantização", + "Precision": "Precisão", + "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", + "Put your text here.": "Insira seu texto aqui.", + "Quantify": "Quantizar", + "Quantify successfully": "Quantizado com sucesso", + "Realtime Transform Text": "Transformar Texto em Tempo Real", + "Reference Audio": "Áudio de Referência", + "Reference Text": "Texto de Referência", + "warning": "Aviso", + "Pre-processing begins...": "O pré-processamento começou!", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", + "Remove Selected Data": "Remover Dados Selecionados", + "Removed path successfully!": "Caminho removido com sucesso!", + "Repetition Penalty": "Penalidade de Repetição", + "Save model every n steps": "Salvar modelo a cada n etapas", + "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", + "Select source file processing method": "Escolha como processar o arquivo de origem", + "Select the model to be trained 
(Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", + "Selected: {}": "Selecionado: {}", + "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", + "Start Training": "Iniciar Treinamento", + "Streaming Audio": "Áudio em Streaming", + "Streaming Generate": "Geração em Streaming", + "Tensorboard Host": "Host do Tensorboard", + "Tensorboard Log Path": "Caminho de Log do Tensorboard", + "Tensorboard Port": "Porta do Tensorboard", + "Tensorboard interface is closed": "A interface do Tensorboard está fechada", + "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", + "Text Normalization": "Normalização de Texto", + "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", + "Training Configuration": "Configuração de Treinamento", + "Training Error": "Erro de Treinamento", + "Training stopped": "Treinamento interrompido!", + "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", + "Use filelist": "Usar lista de arquivos", + "VQGAN Configuration": "Configuração do VQGAN", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", + "WebUI Host": "Host da WebUI", + "WebUI Port": "Porta da WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", + "auto": "automático", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", + "latest": "mais recente", + "new": "novo", + "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", + "You don't need to train this model!": "Não é necessário treinar este modelo!", + "Yes": "Sim", + "No": "Não", + "version:": "versão:", + "author:": "autor:" +} diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json index df7cd5477bfa035ea66ae7322dad46f5d054d9b0..9068ef0b9a41b9941b37644c6a4c96ec6a5d836e 100644 --- a/fish_speech/i18n/locale/zh_CN.json +++ b/fish_speech/i18n/locale/zh_CN.json @@ -1,123 +1,123 @@ -{ - "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", - "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", - "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", - "Accumulate Gradient Batches": "梯度累积批次", - "Add to Processing Area": "加入处理区", - "Added path successfully!": "添加路径成功!", - "Advanced Config": "高级参数", - "Base LLAMA Model": "基础 LLAMA 模型", - "Batch Inference": "批量推理", - "Batch Size": "批次大小", - "Changing with the Model Path": "随模型路径变化", - "Chinese": "中文", - "Compile Model": "编译模型", - "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", - "Copy": "复制", - "Data Preprocessing": "数据预处理", - "Data Preprocessing Path": "数据预处理路径", - "Data Source": "数据源", - "Decoder Model Config": "解码器模型配置", - "Decoder Model Path": "解码器模型路径", - "Disabled": "禁用", - "Enable Reference Audio": "启用参考音频", - "English": "英文", - "Error Message": "错误信息", - "File Preprocessing": "文件预处理", - "Generate": "生成", - "Generated Audio": "音频", - "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", - "Infer interface is closed": "推理界面已关闭", - "Inference Configuration": "推理配置", - "Inference Server Configuration": "推理服务器配置", - "Inference Server Error": "推理服务器错误", - "Inferring interface is launched at {}": "推理界面已在 {} 上启动", - "Initial Learning Rate": "初始学习率", - "Input Audio & Source Path for Transcription": "输入音频和转录源路径", - "Input Text": "输入文本", - "Invalid path: {}": "无效路径: {}", - "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", - "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", - "Japanese": "日文", - "LLAMA Configuration": "LLAMA 配置", - "LLAMA Model Config": "LLAMA 模型配置", - "LLAMA Model Path": "LLAMA 模型路径", - "Labeling Device": "标注加速设备", - "LoRA Model to be merged": "要合并的 LoRA 模型", - "Maximum Audio Duration": "最大音频时长", - "Maximum Length per 
Sample": "每个样本的最大长度", - "Maximum Training Steps": "最大训练步数", - "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", - "Merge": "合并", - "Merge LoRA": "合并 LoRA", - "Merge successfully": "合并成功", - "Minimum Audio Duration": "最小音频时长", - "Model Output Path": "模型输出路径", - "Model Size": "模型规模", - "Move": "移动", - "Move files successfully": "移动文件成功", - "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", - "No selected options": "没有选择的选项", - "Number of Workers": "数据加载进程数", - "Open Inference Server": "打开推理服务器", - "Open Labeler WebUI": "打开标注工具", - "Open Tensorboard": "打开 Tensorboard", - "Opened labeler in browser": "在浏览器中打开标注工具", - "Optional Label Language": "[可选] 标注语言", - "Optional online ver": "[可选] 使用在线版", - "Output Path": "输出路径", - "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", - "Precision": "精度", - "Probability of applying Speaker Condition": "应用说话人条件的概率", - "Put your text here.": "在此处输入文本.", - "Reference Audio": "参考音频", - "Reference Text": "参考文本", - "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", - "Remove Selected Data": "移除选中数据", - "Removed path successfully!": "移除路径成功!", - "Repetition Penalty": "重复惩罚", - "Save model every n steps": "每 n 步保存模型", - "Select LLAMA ckpt": "选择 LLAMA 检查点", - "Select VITS ckpt": "选择 VITS 检查点", - "Select VQGAN ckpt": "选择 VQGAN 检查点", - "Select source file processing method": "选择源文件处理方法", - "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", - "Selected: {}": "已选择: {}", - "Speaker": "说话人", - "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", - "Start Training": "开始训练", - "Streaming Audio": "流式音频", - "Streaming Generate": "流式合成", - "Tensorboard Host": "Tensorboard 监听地址", - "Tensorboard Log Path": "Tensorboard 日志路径", - "Tensorboard Port": "Tensorboard 端口", - "Tensorboard interface is closed": "Tensorboard 界面已关闭", - "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", - "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", - "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", - "Training Configuration": "训练配置", - "Training Error": "训练错误", - "Training stopped": "训练已停止", - "Type name of the speaker": "输入说话人的名称", - "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", - "Use LoRA": "使用 LoRA", - "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", - "Use filelist": "使用文件列表", - "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", - "VITS Configuration": "VITS 配置", - "VQGAN Configuration": "VQGAN 配置", - "Validation Batch Size": "验证批次大小", - "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", - "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", - "WebUI Host": "WebUI 监听地址", - "WebUI Port": "WebUI 端口", - "Whisper Model": "Whisper 模型", - "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", - "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", - "latest": "最近的检查点", - "new": "创建新的检查点", - "Realtime Transform Text": "实时规范化文本", - "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", - "Text Normalization": "文本规范化", - "Select Example Audio": "选择参考音频" -} +{ + "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", + "Accumulate Gradient Batches": "梯度累积批次", + "Add to Processing Area": "加入处理区", + "Added path successfully!": "添加路径成功!", + "Advanced Config": "高级参数", + "Base LLAMA Model": "基础 LLAMA 模型", + "Batch Inference": "批量推理", + "Batch Size": "批次大小", + "Changing with the Model Path": "随模型路径变化", + "Chinese": "中文", + "Compile Model": "编译模型", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", + "Copy": "复制", + "Data Preprocessing": "数据预处理", + "Data Preprocessing Path": "数据预处理路径", + "Data Source": "数据源", + "Decoder Model Config": "解码器模型配置", + "Decoder Model Path": "解码器模型路径", + "Disabled": "禁用", + "Enable Reference Audio": "启用参考音频", + "English": "英文", + "Error Message": "错误信息", + "File Preprocessing": "文件预处理", + "Generate": "生成", + "Generated Audio": "音频", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", + "Infer interface is closed": "推理界面已关闭", + "Inference Configuration": "推理配置", + "Inference Server Configuration": "推理服务器配置", + "Inference Server Error": "推理服务器错误", + "Inferring interface is launched at {}": "推理界面已在 {} 上启动", + "Initial Learning Rate": "初始学习率", + "Input Audio & Source Path for Transcription": "输入音频和转录源路径", + "Input Text": "输入文本", + "Invalid path: {}": "无效路径: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", + "Iterative Prompt 
Length, 0 means off": "迭代提示长度,0 表示关闭", + "Japanese": "日文", + "LLAMA Configuration": "LLAMA 配置", + "LLAMA Model Config": "LLAMA 模型配置", + "LLAMA Model Path": "LLAMA 模型路径", + "Labeling Device": "标注加速设备", + "LoRA Model to be merged": "要合并的 LoRA 模型", + "Maximum Audio Duration": "最大音频时长", + "Maximum Length per Sample": "每个样本的最大长度", + "Maximum Training Steps": "最大训练步数", + "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", + "Merge": "合并", + "Merge LoRA": "合并 LoRA", + "Merge successfully": "合并成功", + "Minimum Audio Duration": "最小音频时长", + "Model Output Path": "模型输出路径", + "Model Size": "模型规模", + "Move": "移动", + "Move files successfully": "移动文件成功", + "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", + "No selected options": "没有选择的选项", + "Number of Workers": "数据加载进程数", + "Open Inference Server": "打开推理服务器", + "Open Labeler WebUI": "打开标注工具", + "Open Tensorboard": "打开 Tensorboard", + "Opened labeler in browser": "在浏览器中打开标注工具", + "Optional Label Language": "[可选] 标注语言", + "Optional online ver": "[可选] 使用在线版", + "Output Path": "输出路径", + "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", + "Precision": "精度", + "Probability of applying Speaker Condition": "应用说话人条件的概率", + "Put your text here.": "在此处输入文本.", + "Reference Audio": "参考音频", + "Reference Text": "参考文本", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", + "Remove Selected Data": "移除选中数据", + "Removed path successfully!": "移除路径成功!", + "Repetition Penalty": "重复惩罚", + "Save model every n steps": "每 n 步保存模型", + "Select LLAMA ckpt": "选择 LLAMA 检查点", + "Select VITS ckpt": "选择 VITS 检查点", + "Select VQGAN ckpt": "选择 VQGAN 检查点", + "Select source file processing method": "选择源文件处理方法", + "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", + "Selected: {}": "已选择: {}", + "Speaker": "说话人", + "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", + "Start Training": "开始训练", + "Streaming Audio": "流式音频", + "Streaming Generate": "流式合成", + "Tensorboard Host": "Tensorboard 监听地址", + "Tensorboard Log Path": "Tensorboard 日志路径", + "Tensorboard Port": "Tensorboard 端口", + "Tensorboard interface is closed": "Tensorboard 界面已关闭", + "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", + "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", + "Training Configuration": "训练配置", + "Training Error": "训练错误", + "Training stopped": "训练已停止", + "Type name of the speaker": "输入说话人的名称", + "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", + "Use LoRA": "使用 LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", + "Use filelist": "使用文件列表", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", + "VITS Configuration": "VITS 配置", + "VQGAN Configuration": "VQGAN 配置", + "Validation Batch Size": "验证批次大小", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", + "WebUI Host": "WebUI 监听地址", + "WebUI Port": "WebUI 端口", + "Whisper Model": "Whisper 模型", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", + "latest": "最近的检查点", + "new": "创建新的检查点", + "Realtime Transform Text": "实时规范化文本", + "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", + "Text Normalization": "文本规范化", + "Select Example Audio": "选择参考音频" +} diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py index 00a39a4d08b8a19c91a8518e90bafe6ceea2231c..d0194c0f1a31dc95309c64626d13f04751a44ba1 100644 --- a/fish_speech/i18n/scan.py +++ b/fish_speech/i18n/scan.py @@ -1,122 +1,122 @@ -import ast -import glob -import json -from collections import OrderedDict -from pathlib import Path - -from loguru import logger - -from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH - - -def extract_i18n_strings(node): - i18n_strings = [] - - if ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Name) - and node.func.id == "i18n" - ): - for arg in node.args: - if isinstance(arg, ast.Str): - i18n_strings.append(arg.s) - - for child_node in ast.iter_child_nodes(node): - i18n_strings.extend(extract_i18n_strings(child_node)) - - return i18n_strings - - -# scan the directory for all .py files (recursively) -# for each file, parse the code into an AST -# for each AST, extract the i18n strings - -strings = [] -folders = ["fish_speech", "tools"] -# for filename in glob.iglob("**/*.py", recursive=True): -for folder in folders: - for f in Path(folder).rglob("*.py"): - code = f.read_text(encoding="utf-8") - if "i18n(" in code: - tree = ast.parse(code) - i18n_strings = extract_i18n_strings(tree) - logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") - strings.extend(i18n_strings) - -code_keys = set(strings) -logger.info(f"Total unique: {len(code_keys)}") - - -standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" -with open(standard_file, "r", encoding="utf-8") as f: - standard_data = json.load(f, object_pairs_hook=OrderedDict) -standard_keys = set(standard_data.keys()) - -# Define the standard file name -unused_keys = standard_keys - code_keys -logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") -for unused_key in unused_keys: - 
logger.info(f"\t{unused_key}") - -missing_keys = code_keys - standard_keys -logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") -for missing_key in missing_keys: - logger.info(f"\t{missing_key}") - -code_keys_dict = OrderedDict() -for s in strings: - code_keys_dict[s] = s - -# write back -with open(standard_file, "w", encoding="utf-8") as f: - json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) - f.write("\n") - -logger.info(f"Updated {standard_file}") - - -# Define the standard file name -standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" - -# Find all JSON files in the directory -dir_path = I18N_FILE_PATH -languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] - -# Load the standard file -with open(standard_file, "r", encoding="utf-8") as f: - standard_data = json.load(f, object_pairs_hook=OrderedDict) - -# Loop through each language file -for lang_file in languages: - # Load the language file - with open(lang_file, "r", encoding="utf-8") as f: - lang_data = json.load(f, object_pairs_hook=OrderedDict) - - # Find the difference between the language file and the standard file - diff = set(standard_data.keys()) - set(lang_data.keys()) - - miss = set(lang_data.keys()) - set(standard_data.keys()) - - # Add any missing keys to the language file - for key in diff: - lang_data[key] = "#!" + key - logger.info(f"Added missing key: {key} to {lang_file}") - - # Del any extra keys to the language file - for key in miss: - del lang_data[key] - logger.info(f"Del extra key: {key} from {lang_file}") - - # Sort the keys of the language file to match the order of the standard file - lang_data = OrderedDict( - sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) - ) - - # Save the updated language file - with open(lang_file, "w", encoding="utf-8") as f: - json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) - f.write("\n") - - logger.info(f"Updated {lang_file}") - -logger.info("Done") +import ast +import glob +import json +from collections import OrderedDict +from pathlib import Path + +from loguru import logger + +from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH + + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + + +# scan the directory for all .py files (recursively) +# for each file, parse the code into an AST +# for each AST, extract the i18n strings + +strings = [] +folders = ["fish_speech", "tools"] +# for filename in glob.iglob("**/*.py", recursive=True): +for folder in folders: + for f in Path(folder).rglob("*.py"): + code = f.read_text(encoding="utf-8") + if "i18n(" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") + strings.extend(i18n_strings) + +code_keys = set(strings) +logger.info(f"Total unique: {len(code_keys)}") + + +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) +standard_keys = set(standard_data.keys()) + +# Define the standard file name +unused_keys = standard_keys - code_keys +logger.info(f"Found {len(unused_keys)} unused 
keys in {standard_file}") +for unused_key in unused_keys: + logger.info(f"\t{unused_key}") + +missing_keys = code_keys - standard_keys +logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") +for missing_key in missing_keys: + logger.info(f"\t{missing_key}") + +code_keys_dict = OrderedDict() +for s in strings: + code_keys_dict[s] = s + +# write back +with open(standard_file, "w", encoding="utf-8") as f: + json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + +logger.info(f"Updated {standard_file}") + + +# Define the standard file name +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" + +# Find all JSON files in the directory +dir_path = I18N_FILE_PATH +languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] + +# Load the standard file +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) + +# Loop through each language file +for lang_file in languages: + # Load the language file + with open(lang_file, "r", encoding="utf-8") as f: + lang_data = json.load(f, object_pairs_hook=OrderedDict) + + # Find the difference between the language file and the standard file + diff = set(standard_data.keys()) - set(lang_data.keys()) + + miss = set(lang_data.keys()) - set(standard_data.keys()) + + # Add any missing keys to the language file + for key in diff: + lang_data[key] = "#!" + key + logger.info(f"Added missing key: {key} to {lang_file}") + + # Del any extra keys to the language file + for key in miss: + del lang_data[key] + logger.info(f"Del extra key: {key} from {lang_file}") + + # Sort the keys of the language file to match the order of the standard file + lang_data = OrderedDict( + sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) + ) + + # Save the updated language file + with open(lang_file, "w", encoding="utf-8") as f: + json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + + logger.info(f"Updated {lang_file}") + +logger.info("Done") diff --git a/fish_speech/inference_engine/__init__.py b/fish_speech/inference_engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f54de5bf7f4d70dbfcfaff10423ac75d4c7d07b0 --- /dev/null +++ b/fish_speech/inference_engine/__init__.py @@ -0,0 +1,192 @@ +import gc +import queue +from typing import Generator + +import numpy as np +import torch +from loguru import logger + +from fish_speech.inference_engine.reference_loader import ReferenceLoader +from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header +from fish_speech.inference_engine.vq_manager import VQManager +from fish_speech.models.dac.modded_dac import DAC +from fish_speech.models.text2semantic.inference import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, +) +from fish_speech.utils import autocast_exclude_mps, set_seed +from fish_speech.utils.schema import ServeTTSRequest + + +class TTSInferenceEngine(ReferenceLoader, VQManager): + + def __init__( + self, + llama_queue: queue.Queue, + decoder_model: DAC, + precision: torch.dtype, + compile: bool, + ) -> None: + + super().__init__() + + self.llama_queue = llama_queue + self.decoder_model = decoder_model + self.precision = precision + self.compile = compile + + @torch.inference_mode() + def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]: + """ + Main inference function: + - Loads the reference audio and text. 
+ - Calls the LLAMA model for inference. + - Decodes the VQ tokens to audio. + """ + + ref_id: str | None = req.reference_id + prompt_tokens, prompt_texts = [], [] + # Load the reference audio and text based on id or hash + if ref_id is not None: + prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache) + + elif req.references: + prompt_tokens, prompt_texts = self.load_by_hash( + req.references, req.use_memory_cache + ) + + # Set the random seed if provided + if req.seed is not None: + set_seed(req.seed) + logger.warning(f"set seed: {req.seed}") + + # Get the symbolic tokens from the LLAMA model + response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts) + + # Get the sample rate from the decoder model + if hasattr(self.decoder_model, "spec_transform"): + sample_rate = self.decoder_model.spec_transform.sample_rate + else: + sample_rate = self.decoder_model.sample_rate + + # If streaming, send the header + if req.streaming: + yield InferenceResult( + code="header", + audio=( + sample_rate, + np.array(wav_chunk_header(sample_rate=sample_rate)), + ), + error=None, + ) + + segments = [] + + while True: + # Get the response from the LLAMA model + wrapped_result: WrappedGenerateResponse = response_queue.get() + if wrapped_result.status == "error": + yield InferenceResult( + code="error", + audio=None, + error=( + wrapped_result.response + if isinstance(wrapped_result.response, Exception) + else Exception("Unknown error") + ), + ) + break + + # Check the response type + if not isinstance(wrapped_result.response, GenerateResponse): + raise TypeError( + "Expected GenerateResponse, got {type(wrapped_result.response).__name__}" + ) + + result: GenerateResponse = wrapped_result.response + if result.action != "next": + segment = self.get_audio_segment(result) + + if req.streaming: # Used only by the API server + yield InferenceResult( + code="segment", + audio=(sample_rate, segment), + error=None, + ) + segments.append(segment) + else: + break + + # Clean up the memory + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # Edge case: no audio generated + if len(segments) == 0: + yield InferenceResult( + code="error", + audio=None, + error=RuntimeError("No audio generated, please check the input text."), + ) + else: + # Streaming or not, return the final audio + audio = np.concatenate(segments, axis=0) + yield InferenceResult( + code="final", + audio=(sample_rate, audio), + error=None, + ) + + return None + + def send_Llama_request( + self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list + ) -> queue.Queue: + """ + Send a request to the LLAMA model to generate the symbolic tokens. + """ + + # Prepare the request + request = dict( + device=self.decoder_model.device, + max_new_tokens=req.max_new_tokens, + text=req.text, + top_p=req.top_p, + repetition_penalty=req.repetition_penalty, + temperature=req.temperature, + compile=self.compile, + iterative_prompt=req.chunk_length > 0, + chunk_length=req.chunk_length, + prompt_tokens=prompt_tokens, + prompt_text=prompt_texts, + ) + + # Create a queue to get the response + response_queue = queue.Queue() + + # Send the request to the LLAMA model + self.llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + return response_queue + + def get_audio_segment(self, result: GenerateResponse) -> np.ndarray: + """ + Decode the VQ tokens to audio. 
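+
+        Runs the DAC decoder on `result.codes` under autocast (skipped on MPS
+        devices via `autocast_exclude_mps`) and returns the waveform as a
+        float32 numpy array on the CPU.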
+ """ + + # Don't use autocast on MPS devices + with autocast_exclude_mps( + device_type=self.decoder_model.device.type, dtype=self.precision + ): + # Decode the symbolic tokens to audio + segment = self.decode_vq_tokens(codes=result.codes) + + # Convert the audio to numpy + return segment.float().cpu().numpy() diff --git a/fish_speech/inference_engine/reference_loader.py b/fish_speech/inference_engine/reference_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc3a89bb4010efb659cbf6b09e301bb0aac45d1 --- /dev/null +++ b/fish_speech/inference_engine/reference_loader.py @@ -0,0 +1,130 @@ +import io +from hashlib import sha256 +from pathlib import Path +from typing import Callable, Literal, Tuple + +import torch +import torchaudio +from loguru import logger + +from fish_speech.models.dac.modded_dac import DAC +from fish_speech.utils.file import ( + AUDIO_EXTENSIONS, + audio_to_bytes, + list_files, + read_ref_text, +) +from fish_speech.utils.schema import ServeReferenceAudio + + +class ReferenceLoader: + + def __init__(self) -> None: + """ + Component of the TTSInferenceEngine class. + Loads and manages the cache for the reference audio and text. + """ + self.ref_by_id: dict = {} + self.ref_by_hash: dict = {} + + # Make Pylance happy (attribut/method not defined...) + self.decoder_model: DAC + self.encode_reference: Callable + + # Define the torchaudio backend + backends = torchaudio.list_audio_backends() + if "ffmpeg" in backends: + self.backend = "ffmpeg" + else: + self.backend = "soundfile" + + def load_by_id( + self, + id: str, + use_cache: Literal["on", "off"], + ) -> Tuple: + + # Load the references audio and text by id + ref_folder = Path("references") / id + ref_folder.mkdir(parents=True, exist_ok=True) + ref_audios = list_files( + ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False + ) + + if use_cache == "off" or id not in self.ref_by_id: + # If the references are not already loaded, encode them + prompt_tokens = [ + self.encode_reference( + # decoder_model=self.decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + self.ref_by_id[id] = (prompt_tokens, prompt_texts) + + else: + # Reuse already encoded references + logger.info("Use same references") + prompt_tokens, prompt_texts = self.ref_by_id[id] + + return prompt_tokens, prompt_texts + + def load_by_hash( + self, + references: list[ServeReferenceAudio], + use_cache: Literal["on", "off"], + ) -> Tuple: + + # Load the references audio and text by hash + audio_hashes = [sha256(ref.audio).hexdigest() for ref in references] + + cache_used = False + prompt_tokens, prompt_texts = [], [] + for i, ref in enumerate(references): + if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash: + # If the references are not already loaded, encode them + prompt_tokens.append( + self.encode_reference( + reference_audio=ref.audio, + enable_reference_audio=True, + ) + ) + prompt_texts.append(ref.text) + self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts) + + else: + # Reuse already encoded references + prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]] + cache_used = True + + if cache_used: + logger.info("Use same references") + + return prompt_tokens, prompt_texts + + def load_audio(self, reference_audio, sr): + """ + Load the audio data from a file or bytes. 
+ """ + if len(reference_audio) > 255 or not Path(reference_audio).exists(): + audio_data = reference_audio + reference_audio = io.BytesIO(audio_data) + + waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend) + + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if original_sr != sr: + resampler = torchaudio.transforms.Resample( + orig_freq=original_sr, new_freq=sr + ) + waveform = resampler(waveform) + + audio = waveform.squeeze().numpy() + return audio diff --git a/fish_speech/inference_engine/utils.py b/fish_speech/inference_engine/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c11af150dc4a0985b11ba7675d93cf165490f35 --- /dev/null +++ b/fish_speech/inference_engine/utils.py @@ -0,0 +1,29 @@ +import io +import wave +from dataclasses import dataclass +from typing import Literal, Optional, Tuple + +import numpy as np + + +@dataclass +class InferenceResult: + code: Literal["header", "segment", "error", "final"] + audio: Optional[Tuple[int, np.ndarray]] + error: Optional[Exception] + + +def wav_chunk_header( + sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1 +) -> bytes: + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + + return wav_header_bytes diff --git a/fish_speech/inference_engine/vq_manager.py b/fish_speech/inference_engine/vq_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef2b55271016e87f7f3d3f0b6d79177fdb27dd4 --- /dev/null +++ b/fish_speech/inference_engine/vq_manager.py @@ -0,0 +1,59 @@ +from typing import Callable + +import torch +from loguru import logger + +from fish_speech.models.dac.modded_dac import DAC + + +class VQManager: + + def __init__(self): + # Make Pylance happy (attribut/method not defined...) 
+ self.decoder_model: DAC + self.load_audio: Callable + + def decode_vq_tokens(self, codes): + feature_lengths = torch.tensor( + [codes.shape[1]], device=self.decoder_model.device + ) + logger.info(f"VQ features: {codes.shape}") + + if isinstance(self.decoder_model, DAC): + return self.decoder_model.decode( + indices=codes[None], + feature_lengths=feature_lengths, + )[0].squeeze() + + raise ValueError(f"Unknown model type: {type(self.decoder_model)}") + + def encode_reference(self, reference_audio, enable_reference_audio): + if enable_reference_audio and reference_audio is not None: + # Load audios, and prepare basic info here + if hasattr(self.decoder_model, "spec_transform"): + sample_rate = self.decoder_model.spec_transform.sample_rate + else: + sample_rate = self.decoder_model.sample_rate + reference_audio_content = self.load_audio(reference_audio, sample_rate) + + audios = torch.from_numpy(reference_audio_content).to( + self.decoder_model.device + )[None, None, :] + audio_lengths = torch.tensor( + [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long + ) + logger.info( + f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds" + ) + + # VQ Encoder + if isinstance(self.decoder_model, DAC): + prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0] + logger.info(f"Encoded prompt: {prompt_tokens.shape}") + else: + raise ValueError(f"Unknown model type: {type(self.decoder_model)}") + else: + prompt_tokens = None + logger.info("No reference audio provided") + + return prompt_tokens diff --git a/fish_speech/models/dac/__init__.py b/fish_speech/models/dac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/dac/inference.py b/fish_speech/models/dac/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd0452757ec50c8a89934f24ea9cc196a45de31 --- /dev/null +++ b/fish_speech/models/dac/inference.py @@ -0,0 +1,123 @@ +from pathlib import Path + +import click +import hydra +import numpy as np +import pyrootutils +import soundfile as sf +import torch +import torchaudio +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger +from omegaconf import OmegaConf + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from fish_speech.utils.file import AUDIO_EXTENSIONS + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + + +def load_model(config_name, checkpoint_path, device="cuda"): + hydra.core.global_hydra.GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../../configs"): + cfg = compose(config_name=config_name) + + model = instantiate(cfg) + state_dict = torch.load( + checkpoint_path, map_location=device, mmap=True, weights_only=True + ) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + if any("generator" in k for k in state_dict): + state_dict = { + k.replace("generator.", ""): v + for k, v in state_dict.items() + if "generator." 
in k + } + + result = model.load_state_dict(state_dict, strict=False, assign=True) + model.eval() + model.to(device) + + logger.info(f"Loaded model: {result}") + return model + + +@torch.no_grad() +@click.command() +@click.option( + "--input-path", + "-i", + default="test.wav", + type=click.Path(exists=True, path_type=Path), +) +@click.option( + "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path) +) +@click.option("--config-name", default="modded_dac_vq") +@click.option( + "--checkpoint-path", + default="checkpoints/openaudio-s1-mini/codec.pth", +) +@click.option( + "--device", + "-d", + default="cuda", +) +def main(input_path, output_path, config_name, checkpoint_path, device): + model = load_model(config_name, checkpoint_path, device=device) + + if input_path.suffix in AUDIO_EXTENSIONS: + logger.info(f"Processing in-place reconstruction of {input_path}") + + # Load audio + audio, sr = torchaudio.load(str(input_path)) + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) + audio = torchaudio.functional.resample(audio, sr, model.sample_rate) + + audios = audio[None].to(device) + logger.info( + f"Loaded audio with {audios.shape[2] / model.sample_rate:.2f} seconds" + ) + + # VQ Encoder + audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long) + indices, indices_lens = model.encode(audios, audio_lengths) + + if indices.ndim == 3: + indices = indices[0] + + logger.info(f"Generated indices of shape {indices.shape}") + + # Save indices + np.save(output_path.with_suffix(".npy"), indices.cpu().numpy()) + elif input_path.suffix == ".npy": + logger.info(f"Processing precomputed indices from {input_path}") + indices = np.load(input_path) + indices = torch.from_numpy(indices).to(device).long() + assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}" + indices_lens = torch.tensor([indices.shape[1]], device=device, dtype=torch.long) + else: + raise ValueError(f"Unknown input type: {input_path}") + + # Restore + fake_audios, audio_lengths = model.decode(indices, indices_lens) + audio_time = fake_audios.shape[-1] / model.sample_rate + + logger.info( + f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}" + ) + + # Save audio + fake_audio = fake_audios[0, 0].float().cpu().numpy() + sf.write(output_path, fake_audio, model.sample_rate) + logger.info(f"Saved audio to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/fish_speech/models/dac/modded_dac.py b/fish_speech/models/dac/modded_dac.py new file mode 100644 index 0000000000000000000000000000000000000000..db0a7a0ccece45c2b175a6ef1d06e7193bd15a07 --- /dev/null +++ b/fish_speech/models/dac/modded_dac.py @@ -0,0 +1,1024 @@ +import math +import typing as tp +from dataclasses import dataclass +from typing import List, Optional, Union + +import hydra +import librosa +import numpy as np +import soundfile as sf +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from dac.model.base import CodecMixin +from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d +from omegaconf import OmegaConf +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + + +@dataclass +class VQResult: + z: torch.Tensor + codes: torch.Tensor + latents: torch.Tensor + codebook_loss: torch.Tensor + 
commitment_loss: torch.Tensor + semantic_distill_z: torch.Tensor | None = None + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + n_layer: int = 8 + n_head: int = 8 + dim: int = 512 + intermediate_size: int = 1536 + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + dropout_rate: float = 0.1 + attn_dropout_rate: float = 0.1 + channels_first: bool = True # to be compatible with conv1d input/output + pos_embed_type: str = "rope" # can be "rope" or "conformer" + max_relative_position: int = 128 # for conformer-style relative position embedding + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + assert self.pos_embed_type in [ + "rope", + "conformer", + ], "pos_embed_type must be either 'rope' or 'conformer'" + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return ( + k_out[:, :, : input_pos.max() + 1, :], + v_out[:, :, : input_pos.max() + 1, :], + ) + + def clear_cache(self, prompt_len): + self.k_cache[:, :, prompt_len:, :].fill_(0) + self.v_cache[:, :, prompt_len:, :].fill_(0) + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + # Only compute RoPE frequencies if using RoPE + if config.pos_embed_type == "rope": + freqs_cis = precompute_freqs_cis( + self.config.block_size, self.config.head_dim, self.config.rope_base + ) + self.register_buffer("freqs_cis", freqs_cis) + else: + self.register_buffer("freqs_cis", None) + + causal_mask = torch.tril( + torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool) + ) + self.register_buffer("causal_mask", causal_mask) + + self.max_batch_size = -1 + self.max_seq_length = -1 + self.use_kv_cache = False + + def setup_caches(self, max_batch_size, max_seq_length): + """ + This method will only be called during inference when using KV cache. 
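+
+        `max_seq_length` is rounded up to a multiple of 8 and one KVCache
+        (key/value buffers) is allocated per TransformerBlock, using the dtype
+        and device of the final RMSNorm weight.
+
+        Typical call (illustrative values):
+            model.setup_caches(max_batch_size=1, max_seq_length=2048)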
+ """ + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.norm.weight.dtype + device = self.norm.weight.device + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_length, + self.config.n_local_heads, + head_dim, + dtype, + ).to(device) + + self.use_kv_cache = True + + def forward( + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + ) -> Tensor: + if self.config.pos_embed_type == "rope": + assert ( + self.freqs_cis is not None + ), "RoPE frequencies must be initialized for RoPE positional embedding" + freqs_cis = self.freqs_cis[input_pos] + else: + freqs_cis = None + + if mask is None: # in case of non-causal model + if not self.training and self.use_kv_cache: + mask = self.causal_mask[None, None, input_pos] + mask = mask[..., : input_pos.max() + 1] + else: + mask = self.causal_mask[None, None, input_pos] + mask = mask[..., input_pos] + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.attention_layer_scale = LayerScale(config.dim, inplace=True) + self.ffn_layer_scale = LayerScale(config.dim, inplace=True) + + def forward( + self, + x: Tensor, + input_pos: Tensor, + freqs_cis: Tensor, + mask: Tensor, + ) -> Tensor: + h = x + self.attention_layer_scale( + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + ) + out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h))) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.attn_dropout_rate = config.attn_dropout_rate + self.pos_embed_type = config.pos_embed_type + + # Add relative position embedding for conformer-style + if self.pos_embed_type == "conformer": + self.max_relative_position = config.max_relative_position + num_pos_embeddings = 2 * config.max_relative_position + 1 + self.rel_pos_embeddings = nn.Parameter( + torch.zeros(num_pos_embeddings, self.head_dim) + ) + nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02) + + def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor: + # q: [B, H, S, D] + # Returns: [B, H, S, S] + positions = torch.arange(seqlen, device=q.device) + relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S] + relative_positions = torch.clamp( + relative_positions + self.max_relative_position, + 0, + 2 * self.max_relative_position, + ) + rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D] + + # Compute attention scores with relative position embeddings + q = q.transpose(1, 2) # [B, S, H, 
D] + rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S] + rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S] + return rel_logits + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + if self.pos_embed_type == "rope": + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.pos_embed_type == "conformer": + # Compute attention scores + scale = 1.0 / math.sqrt(self.head_dim) + scores = torch.matmul(q, k.transpose(-2, -1)) * scale + + # Add relative position embeddings for conformer-style + rel_scores = self._compute_conformer_pos_scores(q, seqlen) + scores = scores + rel_scores + + # Apply attention + if mask is not None: + scores = scores.masked_fill(~mask, float("-inf")) + + attn = F.softmax(scores, dim=-1) + if self.attn_dropout_rate > 0 and self.training: + attn = F.dropout(attn, p=self.attn_dropout_rate) + + y = torch.matmul(attn, v) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_dropout_rate if self.training else 0.0, + attn_mask=mask, + ) + # is_causal=True) + y = ( + y.transpose(1, 2) + .contiguous() + .view(bsz, seqlen, self.head_dim * self.n_head) + ) + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x))) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-2, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class WindowLimitedTransformer(Transformer): + """ + Transformer with window limited attention, causal. 
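+
+    When `window_size` is set, the causal mask is further restricted so that each
+    position attends to at most the `window_size` most recent positions
+    (see `make_window_limited_mask`); otherwise a plain causal mask is built.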
+ """ + + def __init__( + self, + config: ModelArgs, + input_dim: int = 512, + window_size: Optional[int] = None, + causal: bool = True, + look_ahead_conv: nn.Module = None, + ): + super().__init__(config) + self.window_size = window_size + self.causal = causal + self.channels_first = config.channels_first + self.look_ahead_conv = ( + look_ahead_conv if look_ahead_conv is not None else nn.Identity() + ) + self.input_proj = ( + nn.Linear(input_dim, config.dim) + if input_dim != config.dim + else nn.Identity() + ) + self.output_proj = ( + nn.Linear(config.dim, input_dim) + if input_dim != config.dim + else nn.Identity() + ) + + def make_window_limited_mask( + self, + max_length: int, + x_lens: Optional[Tensor] = None, + ) -> Tensor: + """ + Make mask to form window limited attention. + """ + if self.causal: + mask = torch.tril(torch.ones(max_length, max_length)) + row_indices = torch.arange(max_length).view(-1, 1) + window_size = self.window_size or max_length + valid_range = (row_indices - window_size + 1).clamp(min=0) + column_indices = torch.arange(max_length) + mask = (column_indices >= valid_range) & mask.bool() + else: + raise NotImplementedError + mask = mask.bool()[None, None] + return mask + + def make_mask( + self, + max_length: int, + x_lens: Optional[Tensor] = None, + ) -> Tensor: + """ + Make ordinary mask if window size is not specified. + """ + if self.causal: + mask = torch.tril(torch.ones(max_length, max_length)) + else: + mask = torch.ones(max_length, max_length) + mask = mask.bool()[None, None] + for i, x_len in enumerate(x_lens): + mask[:x_len, i] = 0 + mask = mask.bool()[None, None] + return mask + + def forward( + self, + x: Tensor, + x_lens: Optional[Tensor] = None, + ) -> Tensor: + if self.channels_first: + x = x.transpose(1, 2) + x = self.input_proj(x) # (B, T, D) + x = self.look_ahead_conv(x) + input_pos = torch.arange(x.shape[1], device=x.device) + # construct mask to form window limited attention + max_length = x.shape[1] + if self.window_size is not None: + mask = self.make_window_limited_mask(max_length, x_lens) + else: + mask = self.make_mask(max_length, x_lens) + mask = mask.to(x.device) + x = super().forward(x, input_pos, mask) + x = self.output_proj(x) # (B, T, D) + if self.channels_first: + x = x.transpose(1, 2) + return x + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. 
Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "zeros", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right + before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class CausalConvNet(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation=1, + stride=1, + groups=1, + padding=None, + ): + super(CausalConvNet, self).__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + self.padding = self.kernel_size - self.stride + + def forward(self, x): + pad = self.padding + extra_padding = get_extra_padding_for_conv1d( + x, self.kernel_size, self.stride, pad + ) + x = pad1d(x, (pad, extra_padding), mode="constant", value=0) + return self.conv(x).contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +class CausalTransConvNet(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None + ): + super(CausalTransConvNet, self).__init__() + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride=stride, dilation=dilation + ) + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + x = self.conv(x) + pad = self.kernel_size - self.stride + padding_right = math.ceil(pad) + padding_left = pad - padding_right + x = unpad1d(x, (padding_left, padding_right)) + return x.contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +def CausalWNConv1d(*args, **kwargs): + return CausalConvNet(*args, **kwargs).weight_norm() + + +def CausalWNConvTranspose1d(*args, **kwargs): + return CausalTransConvNet(*args, **kwargs).weight_norm() + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False): + super().__init__() + conv_class = CausalWNConv1d if causal else WNConv1d + pad = ((7 - 1) * 
dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + conv_class(dim, dim, kernel_size=1), + ) + self.causal = causal + + def forward(self, x): + y = self.block(x) + pad = x.shape[-1] - y.shape[-1] + if pad > 0: + if self.causal: + x = x[..., :-pad] + else: + x = x[..., pad // 2 : -pad // 2] + return x + y + + +class EncoderBlock(nn.Module): + def __init__( + self, + dim: int = 16, + stride: int = 1, + causal: bool = False, + n_t_layer: int = 0, + transformer_general_config=None, + ): + super().__init__() + conv_class = CausalWNConv1d if causal else WNConv1d + transformer_module = ( + nn.Identity() + if n_t_layer == 0 + else ( + WindowLimitedTransformer( + causal=causal, + input_dim=dim, + window_size=512, + config=transformer_general_config( + n_layer=n_t_layer, + n_head=dim // 64, + dim=dim, + intermediate_size=dim * 3, + ), + ) + ) + ) + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1, causal=causal), + ResidualUnit(dim // 2, dilation=3, causal=causal), + ResidualUnit(dim // 2, dilation=9, causal=causal), + Snake1d(dim // 2), + conv_class( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + transformer_module, + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + n_transformer_layers: list = [0, 0, 4, 4], + transformer_general_config: ModelArgs = None, + causal: bool = False, + ): + super().__init__() + conv_class = CausalWNConv1d if causal else WNConv1d + # Create first convolution + self.block = [conv_class(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride, n_t_layer in zip(strides, n_transformer_layers): + d_model *= 2 + self.block += [ + EncoderBlock( + d_model, + stride=stride, + causal=causal, + n_t_layer=n_t_layer, + transformer_general_config=transformer_general_config, + ) + ] + + # Create last convolution + self.block += [ + Snake1d(d_model), + conv_class(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + input_dim: int = 16, + output_dim: int = 8, + stride: int = 1, + causal: bool = False, + n_t_layer: int = 0, + transformer_general_config=None, + ): + super().__init__() + conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d + transformer_module = ( + nn.Identity() + if n_t_layer == 0 + else ( + WindowLimitedTransformer( + causal=causal, + input_dim=input_dim, + window_size=None, + config=transformer_general_config( + n_layer=n_t_layer, + n_head=input_dim // 64, + dim=input_dim, + intermediate_size=input_dim * 3, + ), + ) + ) + ) + self.block = nn.Sequential( + # transformer_module, + Snake1d(input_dim), + conv_trans_class( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ResidualUnit(output_dim, dilation=1, causal=causal), + ResidualUnit(output_dim, dilation=3, causal=causal), + ResidualUnit(output_dim, dilation=9, causal=causal), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + causal: bool = False, + 
n_transformer_layers: list = [0, 0, 0, 0], + transformer_general_config=None, + ): + super().__init__() + conv_class = CausalWNConv1d if causal else WNConv1d + # Add first conv layer + layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [ + DecoderBlock( + input_dim, + output_dim, + stride, + causal=causal, + n_t_layer=n_t_layer, + transformer_general_config=transformer_general_config, + ) + ] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + conv_class(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + quantizer: torch.nn.Module = None, + sample_rate: int = 44100, + causal: bool = True, + encoder_transformer_layers: List[int] = [0, 0, 0, 0], + decoder_transformer_layers: List[int] = [0, 0, 0, 0], + transformer_general_config=None, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder( + encoder_dim, + encoder_rates, + latent_dim, + causal=causal, + n_transformer_layers=encoder_transformer_layers, + transformer_general_config=transformer_general_config, + ) + + self.quantizer = quantizer + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + causal=causal, + n_transformer_layers=decoder_transformer_layers, + transformer_general_config=transformer_general_config, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + self.frame_length = self.hop_length * 4 + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + audio_lengths: torch.Tensor = None, + n_quantizers: int = None, + **kwargs, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
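+        audio_lengths : Tensor[B], optional
+            Number of valid samples per batch item; if None, the padded
+            length of the whole batch is assumed.
+
+        Note that, as written here, this method returns the tuple
+        ``(indices, indices_lens)`` (codebook indices and their frame
+        counts); the dictionary layout below appears to be retained from the
+        original DAC docstring for reference.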
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + # pad to multiple of self.frame_length + if audio_data.ndim == 2: + audio_data = audio_data.unsqueeze(1) + # print(audio_data.shape) + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.frame_length) * self.frame_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + if audio_lengths is None: + audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device) + + z = self.encoder(audio_data) + vq_results = self.quantizer(z, n_quantizers, **kwargs) + indices = vq_results.codes + indices_lens = torch.ceil(audio_lengths / self.frame_length).long() + return indices, indices_lens + + def decode(self, indices: torch.Tensor, feature_lengths): + if indices.ndim == 2: + indices = indices[None] + + z = self.quantizer.decode(indices) + audio_lengths = feature_lengths * self.frame_length + return self.decoder(z), audio_lengths + + def forward( + self, + audio_data: torch.Tensor, + template: torch.Tensor = None, + mask: torch.Tensor = None, + sample_rate: int = None, + n_quantizers: int = None, + **kwargs, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. 
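+
+        A minimal encode/decode round trip (sizes are illustrative; this
+        mirrors the check in the ``__main__`` block at the bottom of this
+        file)::
+
+            audio = torch.randn(1, 1, 44100)  # [B, 1, T]
+            indices, indices_lens = model.encode(audio)
+            recon, recon_lens = model.decode(indices, indices_lens)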
+ """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + vq_results = self.encode(audio_data, n_quantizers, **kwargs) + z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z + x = self.decode(z) + return x[..., :length], vq_results + + +if __name__ == "__main__": + + def filter_state_dict_shapes(params, model): + model_state_dict = model.state_dict() + filtered_state_dict = { + k: v + for k, v in params.items() + if k in model_state_dict and v.shape == model_state_dict[k].shape + } + skipped_keys = set(params.keys()) - set(filtered_state_dict.keys()) + if skipped_keys: + print( + f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}" + ) + return filtered_state_dict, skipped_keys + + model = hydra.utils.instantiate( + OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml") + ) + sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth") + filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model) + print(f"Skipped keys: {skipped_keys}") + model.load_state_dict(filtered_sd, strict=False) + model.eval() + + src_audio_path = "./test.wav" + wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False) + if len(wave_np.shape) == 1: + wave_np = wave_np[None, :] + wave_tensor = torch.from_numpy(wave_np).unsqueeze(1) + + with torch.no_grad(): + # encode 返回 (indices, indices_lens) + indices, indices_lens = model.encode(wave_tensor) + print(f"Indices shape: {indices.shape}") + print(f"Indices lengths: {indices_lens}") + + # decode 需要 indices 和 feature_lengths 两个参数 + fake_audio, audio_lengths = model.decode(indices, indices_lens) + print(f"Decoded audio shape: {fake_audio.shape}") + print(f"Audio lengths: {audio_lengths}") + + # 保存重建的音频 + sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100) diff --git a/fish_speech/models/dac/rvq.py b/fish_speech/models/dac/rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b4647be5f507b45ce00d57296785a128f5606f --- /dev/null +++ b/fish_speech/models/dac/rvq.py @@ -0,0 +1,403 @@ +import math +import typing as tp +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from dac.nn.quantize import ResidualVectorQuantize +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "zeros", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right + before the reflection happen. 
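+    For example, reflect-padding an input of length 2 by ``(0, 3)`` would be
+    rejected by ``F.pad``, so the input is first zero-padded on the right
+    until the reflection is valid and the temporary samples are trimmed off
+    afterwards (the sizes here are purely illustrative).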
+ """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class CausalConvNet(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation=1, + stride=1, + groups=1, + padding=None, + ): + super(CausalConvNet, self).__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + self.padding = self.kernel_size - self.stride + + def forward(self, x): + pad = self.padding + extra_padding = get_extra_padding_for_conv1d( + x, self.kernel_size, self.stride, pad + ) + x = pad1d(x, (pad, extra_padding), mode="constant", value=0) + return self.conv(x).contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +class CausalTransConvNet(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None + ): + super(CausalTransConvNet, self).__init__() + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride=stride, dilation=dilation + ) + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + x = self.conv(x) + pad = self.kernel_size - self.stride + padding_right = math.ceil(pad) + padding_left = pad - padding_right + x = unpad1d(x, (padding_left, padding_right)) + return x.contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + kernel_size (int): Kernel size for depthwise conv. Default: 7. + dilation (int): Dilation for depthwise conv. Default: 1. 
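+
+    Note: in this copy of the block the depthwise convolution is always
+    causal (``CausalConvNet``), and the ``drop_path`` argument documented
+    above is not actually implemented.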
+ """ # noqa: E501 + + def __init__( + self, + dim: int, + layer_scale_init_value: float = 1e-6, + mlp_ratio: float = 4.0, + kernel_size: int = 7, + dilation: int = 1, + ): + super().__init__() + convnet_type = CausalConvNet + self.dwconv = convnet_type( + dim, + dim, + kernel_size=kernel_size, + # padding=int(dilation * (kernel_size - 1) / 2), + groups=dim, + dilation=dilation, + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, int(mlp_ratio * dim) + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x, apply_residual: bool = True): + input = x + + x = self.dwconv(x) + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma * x + + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) + + if apply_residual: + x = input + x + + return x + + +@dataclass +class VQResult: + z: torch.Tensor + codes: torch.Tensor + latents: torch.Tensor + codebook_loss: torch.Tensor + commitment_loss: torch.Tensor + semantic_distill_z: torch.Tensor | None = None + + +class DownsampleResidualVectorQuantize(nn.Module): + def __init__( + self, + input_dim: int = 1024, + n_codebooks: int = 9, + codebook_dim: int = 8, + quantizer_dropout: float = 0.5, + codebook_size: int = 1024, + semantic_codebook_size: int = 4096, + downsample_factor: tuple[int] = (2, 2), + downsample_dims: tuple[int] | None = None, + pre_module: nn.Module | None = None, + post_module: nn.Module | None = None, + semantic_predictor_module: nn.Module | None = None, + ): + super().__init__() + + if downsample_dims is None: + downsample_dims = [input_dim for _ in range(len(downsample_factor))] + + all_dims = (input_dim,) + tuple(downsample_dims) + + self.semantic_quantizer = ResidualVectorQuantize( + input_dim=input_dim, + n_codebooks=1, + codebook_size=semantic_codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=0.0, + ) + + self.quantizer = ResidualVectorQuantize( + input_dim=input_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.downsample_factor = downsample_factor + self.downsample_dims = downsample_dims + + convnet_type = CausalConvNet + transconvnet_type = CausalTransConvNet + + self.downsample = nn.Sequential( + *[ + nn.Sequential( + convnet_type( + all_dims[idx], + all_dims[idx + 1], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx + 1]), + ) + for idx, factor in enumerate(downsample_factor) + ] + ) + + self.upsample = nn.Sequential( + *[ + nn.Sequential( + transconvnet_type( + all_dims[idx + 1], + all_dims[idx], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx]), + ) + for idx, factor in reversed(list(enumerate(downsample_factor))) + ] + ) + self.apply(self._init_weights) + self.pre_module = ( + pre_module if pre_module is not None else nn.Identity() + ) # leave for transformer, LSTM or Mamba or something else + self.post_module = post_module if post_module is not None else nn.Identity() + self.semantic_predictor_module = ( + semantic_predictor_module + if semantic_predictor_module is not None + else nn.Identity() + ) + + def _init_weights(self, m): + if isinstance(m, 
(nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward( + self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs + ): + # z: (B, D, T) + original_shape = z.shape + if semantic_len is None: + semantic_len = torch.LongTensor([z.shape[-1]]) + z = self.downsample(z) + z = self.pre_module(z) # B, T, D + ( + semantic_z, + semantic_codes, + semantic_latents, + semantic_commitment_loss, + semantic_codebook_loss, + ) = self.semantic_quantizer(z) + residual_z = z - semantic_z + residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + residual_z, n_quantizers=n_quantizers + ) + z = semantic_z + residual_z + commitment_loss = commitment_loss + semantic_commitment_loss + codebook_loss = codebook_loss + semantic_codebook_loss + codes = torch.cat([semantic_codes, codes], dim=1) + latents = torch.cat([semantic_latents, latents], dim=1) + z = self.post_module(z) + z = self.upsample(z) + # z: (B, D, T) + + # semantic distillation (disabled here since only used in training) + # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D + + # Pad or crop z to match original shape + diff = original_shape[-1] - z.shape[-1] + right = 0 + left = abs(diff) - right + + if diff > 0: + z = F.pad(z, (left, right)) + elif diff < 0: + z = z[..., left:] + + results = VQResult( + z=z, + codes=codes, + latents=latents, + commitment_loss=commitment_loss, + codebook_loss=codebook_loss, + ) + + return results + + # def encode(self, z): + # z = self.downsample(z) + # z = self.pre_module(z) + # _, indices, _, _, _ = self.quantizer(z.mT) + # indices = rearrange(indices, "g b l r -> b (g r) l") + # return indices + # + def decode(self, indices: torch.Tensor): + # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) + + # print(f"indices: {indices.shape}, semantic_quantizer.codebook_size: {self.semantic_quantizer.codebook_size}, quantizer.codebook_size: {self.quantizer.codebook_size}, semantic min: {indices[:, 0].min()}, max: {indices[:, 0].max()}, quantizer min: {indices[:, 1:].min()}, max: {indices[:, 1:].max()}") + + new_indices = torch.zeros_like(indices) + new_indices[:, 0] = torch.clamp( + indices[:, 0], max=self.semantic_quantizer.codebook_size - 1 + ) + new_indices[:, 1:] = torch.clamp( + indices[:, 1:], max=self.quantizer.codebook_size - 1 + ) + + z_q_semantic = self.semantic_quantizer.from_codes(new_indices[:, :1])[0] + z_q_residual = self.quantizer.from_codes(new_indices[:, 1:])[0] + z_q = z_q_semantic + z_q_residual + z_q = self.post_module(z_q) + z_q = self.upsample(z_q) + return z_q + + # def from_latents(self, latents: torch.Tensor): + # z_q, z_p, codes = super().from_latents(latents) + # z_q = self.upsample(z_q) + # return z_q, z_p, codes + + +if __name__ == "__main__": + rvq = DownsampleResidualVectorQuantize( + input_dim=512, + n_codebooks=8, + codebook_dim=8, + codebook_size=1024, + quantizer_dropout=0.5, + downsample_factor=[2, 2], + ) + rvq.eval() + x = torch.randn(2, 512, 442) + + result = rvq(x) + print(rvq) + print(result.latents.shape, result.codes.shape, result.z.shape) + + # y = rvq.from_codes(result.codes) + # print(y[0].shape) + + # y = rvq.from_latents( + + result1 = rvq(x[:, :, :40]) + print(result1.latents.shape, result1.codes.shape, result1.z.shape) + + assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8) + print("Success") diff --git a/fish_speech/models/text2semantic/inference.py 
b/fish_speech/models/text2semantic/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9442c7cc731f52db523b46434d4b90c3e2885c4f --- /dev/null +++ b/fish_speech/models/text2semantic/inference.py @@ -0,0 +1,716 @@ +import os +import queue +import threading +import time +from contextlib import nullcontext +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Optional, Tuple, Union + +import click +import numpy as np +import torch +import torch._dynamo.config +import torch._inductor.config +from loguru import logger +from tqdm import tqdm +from transformers import AutoTokenizer + +from fish_speech.content_sequence import ( + ContentSequence, + TextPart, + VQPart, +) +from fish_speech.models.text2semantic.llama import BaseModelArgs +from fish_speech.text import clean_text, split_text +from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True + +if hasattr(torch._inductor.config, "fx_graph_cache"): + # Experimental feature to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + + +from torch.nn.attention import SDPBackend, sdpa_kernel + +from fish_speech.models.text2semantic.llama import ( + BaseTransformer, + DualARTransformer, + NaiveTransformer, +) + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + temperature: torch.Tensor = 1.0, + top_p: torch.Tensor = 1.0, + repetition_penalty: torch.Tensor = 1.0, +) -> torch.Tensor: + # Apply repetition penalty + if previous_tokens is not None: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + + # Apply top-p sampling + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample( + logits, + previous_tokens: Optional[torch.Tensor] = None, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + probs = logits_to_probs( + logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def decode_one_token_ar( + model: DualARTransformer, + x: torch.Tensor, + input_pos: torch.Tensor, + semantic_ids: list, + previous_tokens: torch.Tensor = None, + **sampling_kwargs, +) -> torch.Tensor: + x = model.forward_generate(x, input_pos) + + sampling_kwargs_main = sampling_kwargs.copy() + # sampling_kwargs_main["temperature"] = 0.1 + # sampling_kwargs_main["top_p"] = 0.1 + # 
sampling_kwargs_main["repetition_penalty"] = 1.0 + + codebooks = [ + sample( + x.logits, + previous_tokens=( + previous_tokens[0] if previous_tokens is not None else None + ), # Disable repetition penalty for the token codebook + **sampling_kwargs_main, + )[0] + ] + + hidden_states = x.hidden_states + + # Cleanup the cache + for layer in model.fast_layers: + layer.attention.kv_cache.k_cache.fill_(0) + layer.attention.kv_cache.v_cache.fill_(0) + + input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long) + model.forward_generate_fast(hidden_states, input_pos) + a = codebooks[0] - model.tokenizer.semantic_begin_id + a[a < 0] = 0 + hidden_states = model.fast_embeddings(a) + codebooks.append(a) + + for codebook_idx in range(1, model.config.num_codebooks): + input_pos = torch.tensor( + [codebook_idx], device=hidden_states.device, dtype=torch.long + ) + logits = model.forward_generate_fast(hidden_states, input_pos) + chunked_logits = logits[..., :1024] + a = sample( + chunked_logits, + previous_tokens=( + previous_tokens[codebook_idx + 1] + if previous_tokens is not None + else None + ), + **sampling_kwargs, + )[0] + hidden_states = model.fast_embeddings(a) + codebooks.append(a) + + codebooks = torch.stack(codebooks, dim=0) + # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device) + # codebooks[1:, :] = torch.masked_fill( + # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID + # ) + + # print(codebooks) + return codebooks + + +def decode_n_tokens( + model: NaiveTransformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + semantic_ids: list, + decode_one_token=decode_one_token_ar, + **sampling_kwargs, +): + previous_tokens = torch.zeros( + (model.config.num_codebooks + 1, model.config.max_seq_len), + dtype=torch.int, + device=cur_token.device, + ) + + for i in tqdm(range(num_new_tokens)): + # We need to get windowed repeat penalty + win_size = 16 + if i < win_size: + window = previous_tokens[:, :win_size] + else: + window = previous_tokens[:, i - win_size : i] + + with ( + torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ) + if torch.cuda.is_available() + else nullcontext() + ): # Actually better for Inductor to codegen attention here + next_token = decode_one_token( + model=model, + x=cur_token, + input_pos=input_pos, + previous_tokens=window, + semantic_ids=semantic_ids, + **sampling_kwargs, + ) + + input_pos += 1 + cur_token = next_token.view(1, model.config.num_codebooks + 1, -1) + previous_tokens[:, i : i + 1] = next_token.view( + model.config.num_codebooks + 1, -1 + ) + + if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN): + break + + return previous_tokens[:, : i + 1] + + +@torch.no_grad() +@torch.inference_mode() +def generate( + *, + model: NaiveTransformer, + prompt: torch.Tensor, + max_new_tokens: int, + decode_one_token=decode_one_token_ar, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
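+
+    ``prompt`` is expected to have shape ``(1 + num_codebooks, T)``: one row
+    of text/semantic token ids followed by one row per codebook. The
+    returned tensor keeps the same layout, with newly sampled positions
+    appended along the time axis; generation stops early once the
+    ``IM_END_TOKEN`` is sampled in the token row.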
+ """ + + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(1) + # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>") + semantic_ids = [ + model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024) + ] + + if max_new_tokens: + if T + max_new_tokens > model.config.max_seq_len: + max_new_tokens = model.config.max_seq_len - T + logger.info(f"Truncating max_new_tokens to {max_new_tokens}") + + T_new = T + max_new_tokens + else: + T_new = model.config.max_seq_len + max_new_tokens = T_new - T + + device, dtype = prompt.device, prompt.dtype + + codebook_dim = 1 + model.config.num_codebooks + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty( + (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device + ) + empty[:, :T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + # Use non-accelerated version for now, to avoid compilation overhead + prefill_decode = decode_one_token_ar + + next_token = prefill_decode( + model, + prompt.view(1, codebook_dim, -1), + input_pos, + semantic_ids=semantic_ids, + **sampling_kwargs, + ) + seq[:, T : T + 1] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + x = decode_n_tokens( + model, + next_token.view(1, codebook_dim, -1), + input_pos, + max_new_tokens - 1, + decode_one_token=decode_one_token, + semantic_ids=semantic_ids, + **sampling_kwargs, + ) + # x = torch.cat(generated_tokens, dim=1) + seq = seq[:, : T + 1 + x.size(1)] + seq[:, T + 1 :] = x + + return seq + + +def load_model(checkpoint_path, device, precision, compile=False): + model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True) + + model = model.to(device=device, dtype=precision) + logger.info(f"Restored model from checkpoint") + + if isinstance(model, DualARTransformer): + decode_one_token = decode_one_token_ar + logger.info("Using DualARTransformer") + else: + raise ValueError("Model is not a DualARTransformer") + + if compile: + logger.info("Compiling function...") + decode_one_token = torch.compile( + decode_one_token, + fullgraph=True, + backend="inductor" if torch.cuda.is_available() else "aot_eager", + mode="reduce-overhead" if torch.cuda.is_available() else None, + ) + + return model.eval(), decode_one_token + + +@dataclass +class GenerateResponse: + action: Literal["sample", "next"] + codes: Optional[torch.Tensor] = None + text: Optional[str] = None + + +def generate_long( + *, + model, + device: str | torch.device, + decode_one_token: callable, + text: str, + num_samples: int = 1, + max_new_tokens: int = 0, + top_p: int = 0.8, + repetition_penalty: float = 1.1, + temperature: float = 0.8, + compile: bool = False, + iterative_prompt: bool = True, + chunk_length: int = 150, + prompt_text: Optional[str | list[str]] = None, + prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None, +): + assert 0 < top_p <= 1, "top_p must be in (0, 1]" + assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)" + assert 0 < temperature < 2, "temperature must be in (0, 2)" + + use_prompt = prompt_text is not None and prompt_tokens is not None + if use_prompt and isinstance(prompt_text, str): + prompt_text = [prompt_text] + prompt_tokens = [prompt_tokens] + + assert use_prompt is False or len(prompt_text) == len( + prompt_tokens + ), "Prompt text and tokens must have the same length" + + prompt_tokens = [i.cpu() for i in prompt_tokens] + + model_size = sum(p.numel() for p 
in model.parameters() if p.requires_grad) + tokenizer = model.tokenizer + base_content_sequence = ContentSequence(modality="interleave") + + texts = split_text(text, chunk_length) if iterative_prompt else [text] + max_length = model.config.max_seq_len + + if use_prompt: + for t, c in zip(prompt_text, prompt_tokens): + base_content_sequence.append( + [ + TextPart(text=t), + VQPart(codes=c), + ], + add_end=True, + ) + + encoded_prompts = base_content_sequence.encode_for_inference( + tokenizer, num_codebooks=model.config.num_codebooks + ) + if encoded_prompts.size(1) > max_length - 2048: + raise ValueError( + f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}" + ) + + encoded = [] + for text in texts: + content_sequence = ContentSequence(modality=None) + content_sequence.append(TextPart(text=text)) + encoded.append( + content_sequence.encode_for_inference( + tokenizer, num_codebooks=model.config.num_codebooks + ) + ) + logger.info(f"Encoded text: {text}") + + # Move temperature, top_p, repetition_penalty to device + # This is important so that changing params doesn't trigger recompile + temperature = torch.tensor(temperature, device=device, dtype=torch.float) + top_p = torch.tensor(top_p, device=device, dtype=torch.float) + repetition_penalty = torch.tensor( + repetition_penalty, device=device, dtype=torch.float + ) + + for sample_idx in range(num_samples): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + global_encoded = [] + seg_idx = 0 + + while seg_idx < len(encoded): + logger.info( + f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}" + ) + + seg = encoded[seg_idx] + global_encoded.append(seg) + + # Do not use previous segments to generate current segment for now + # lengths = reversed([seg.size(1) for seg in global_encoded]) + + # # Pick last 2000 tokens + # count = 0 + # for i, length in enumerate(lengths): + # count += length + # if count + length > max_length - 2048 - encoded_prompts.size(1): + # break + + # if i != 0 and i % 2 == 0: + # i -= 1 + + # # Rotate the list, always make sure first segment is included to avoid drift + # if i < len(global_encoded) - 2: + # partial_encoded = global_encoded[:2] + global_encoded[-i:] + # else: + # partial_encoded = global_encoded + + # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1) + if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2: + cat_encoded = torch.cat( + [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1 + ) + else: + cat_encoded = torch.cat([encoded_prompts, seg], dim=1) + + cat_encoded = cat_encoded.to(device=device) + prompt_length = cat_encoded.size(1) + + t0 = time.perf_counter() + y = generate( + model=model, + prompt=cat_encoded, + max_new_tokens=max_new_tokens, + decode_one_token=decode_one_token, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + if sample_idx == 0 and seg_idx == 0 and compile: + logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + t = time.perf_counter() - t0 + + tokens_generated = y.size(1) - prompt_length + tokens_sec = tokens_generated / t + logger.info( + f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec" + ) + logger.info( + f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" + ) + + if torch.cuda.is_available(): + logger.info( + f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB" + ) + 
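+            # `y` has shape (1 + num_codebooks, prompt_length + new_tokens):
+            # row 0 is the text-token row and the final position is normally
+            # the end-of-message token, so both are dropped from `codes` below.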
+ # Put the generated tokens + # since there is , we remove last token + codes = y[1:, prompt_length:-1].clone() + assert (codes >= 0).all(), f"Negative code found" + + decoded = y[:, prompt_length:].clone() + # But for global encoding, we should keep the token + + global_encoded.append(decoded.cpu()) + assert (codes >= 0).all(), f"Negative code found: {codes}" + yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx]) + seg_idx += 1 + + # This indicates the end of the current sample + yield GenerateResponse(action="next") + + +@dataclass +class WrappedGenerateResponse: + status: Literal["success", "error"] + response: Optional[GenerateResponse | Exception] = None + + +@dataclass +class GenerateRequest: + request: dict + response_queue: queue.Queue + + +def launch_thread_safe_queue( + checkpoint_path, + device, + precision, + compile: bool = False, +): + input_queue = queue.Queue() + init_event = threading.Event() + + def worker(): + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile + ) + with torch.device(device): + model.setup_caches( + max_batch_size=1, + max_seq_len=model.config.max_seq_len, + dtype=next(model.parameters()).dtype, + ) + init_event.set() + + while True: + item: GenerateRequest | None = input_queue.get() + if item is None: + break + + kwargs = item.request + response_queue = item.response_queue + + try: + for chunk in generate_long( + model=model, decode_one_token=decode_one_token, **kwargs + ): + response_queue.put( + WrappedGenerateResponse(status="success", response=chunk) + ) + except Exception as e: + response_queue.put(WrappedGenerateResponse(status="error", response=e)) + + threading.Thread(target=worker, daemon=True).start() + init_event.wait() + + return input_queue + + +def launch_thread_safe_queue_agent( + checkpoint_path, + device, + precision, + compile: bool = False, +): + input_queue = queue.Queue() + init_event = threading.Event() + + tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) + config = BaseModelArgs.from_pretrained(checkpoint_path) + + def worker(): + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile, is_agent=True + ) + + with torch.device(device): + model.setup_caches( + max_batch_size=1, + max_seq_len=model.config.max_seq_len, + dtype=next(model.parameters()).dtype, + ) + init_event.set() + + while True: + item: GenerateRequest | None = input_queue.get() + if item is None: + break + + kwargs = item.request + response_queue = item.response_queue + + try: + for token in generate_agent( + model=model, + decode_one_token=decode_one_token, + **kwargs, + ): + response_queue.put(token) + + response_queue.put("stop") + except Exception as e: + import traceback + + logger.exception(f"Error in worker: {traceback.format_exc()}") + response_queue.put("error") + + threading.Thread(target=worker, daemon=True).start() + init_event.wait() + + return input_queue, tokenizer, config + + +@click.command() +@click.option( + "--text", + type=str, + default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", +) +@click.option("--prompt-text", type=str, default=None, multiple=True) +@click.option( + "--prompt-tokens", + type=click.Path(path_type=Path, exists=True), + default=None, + multiple=True, +) +@click.option("--num-samples", type=int, default=1) +@click.option("--max-new-tokens", type=int, default=0) +@click.option("--top-p", type=float, default=0.8) +@click.option("--repetition-penalty", type=float, default=1.1) +@click.option("--temperature", type=float, default=0.8) 
+@click.option( + "--checkpoint-path", + type=click.Path(path_type=Path, exists=True), + default="checkpoints/openaudio-s1-mini", +) +@click.option("--device", type=str, default="cuda") +@click.option("--compile/--no-compile", default=False) +@click.option("--seed", type=int, default=42) +@click.option("--half/--no-half", default=False) +@click.option("--iterative-prompt/--no-iterative-prompt", default=True) +@click.option("--chunk-length", type=int, default=300) +@click.option("--output-dir", type=Path, default="temp") +def main( + text: str, + prompt_text: Optional[list[str]], + prompt_tokens: Optional[list[Path]], + num_samples: int, + max_new_tokens: int, + top_p: int, + repetition_penalty: float, + temperature: float, + checkpoint_path: Path, + device: str, + compile: bool, + seed: int, + half: bool, + iterative_prompt: bool, + chunk_length: int, + output_dir: Path, +) -> None: + os.makedirs(output_dir, exist_ok=True) + precision = torch.half if half else torch.bfloat16 + + if prompt_text is not None and len(prompt_text) != len(prompt_tokens): + raise ValueError( + f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same" + ) + + logger.info("Loading model ...") + t0 = time.time() + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile + ) + with torch.device(device): + model.setup_caches( + max_batch_size=1, + max_seq_len=model.config.max_seq_len, + dtype=next(model.parameters()).dtype, + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + logger.info(f"Time to load model: {time.time() - t0:.02f} seconds") + + if prompt_tokens is not None: + prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens] + + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + generator = generate_long( + model=model, + device=device, + decode_one_token=decode_one_token, + text=text, + num_samples=num_samples, + max_new_tokens=max_new_tokens, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + compile=compile, + iterative_prompt=iterative_prompt, + chunk_length=chunk_length, + prompt_text=prompt_text, + prompt_tokens=prompt_tokens, + ) + + idx = 0 + codes = [] + + for response in generator: + if response.action == "sample": + codes.append(response.codes) + logger.info(f"Sampled text: {response.text}") + elif response.action == "next": + if codes: + codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy") + np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy()) + logger.info(f"Saved codes to {codes_npy_path}") + logger.info(f"Next sample") + codes = [] + idx += 1 + else: + logger.error(f"Error: {response}") + + +if __name__ == "__main__": + main() diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py index 7b26793532c9e6b42de189c628aac59a477d0f66..df970400f8a073be4c4166a697245fabdf6b09b0 100644 --- a/fish_speech/models/text2semantic/lit_module.py +++ b/fish_speech/models/text2semantic/lit_module.py @@ -1,202 +1,202 @@ -from typing import Any, Optional - -import lightning as L -import torch -import torch.nn.functional as F -from lightning.pytorch.utilities.types import OptimizerLRScheduler - -import fish_speech.utils as utils -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID -from fish_speech.models.text2semantic.llama import NaiveTransformer - -log = utils.RankedLogger(__name__, rank_zero_only=True) - - -class TextToSemantic(L.LightningModule): - def 
__init__( - self, - model: NaiveTransformer, - optimizer: Any, - lr_scheduler: Any, - ): - super().__init__() - - self.model = model - self.optimizer_builder = optimizer - self.lr_scheduler_builder = lr_scheduler - - def forward(self, x): - return self.model(x) - - def on_save_checkpoint(self, checkpoint): - # Save only LoRA parameters - state_dict = checkpoint["state_dict"] - use_lora = any("lora" in name for name in state_dict.keys()) - if not use_lora: - return - - for name in list(state_dict.keys()): - if "lora" not in name: - state_dict.pop(name) - - def configure_optimizers(self) -> OptimizerLRScheduler: - # Get weight decay parameters - weight_decay_parameters, other_parameters = [], [] - for name, param in self.named_parameters(): - if ".bias" in name or "norm.weight" in name or ".embeddings." in name: - other_parameters.append(param) - else: - weight_decay_parameters.append(param) - - optimizer = self.optimizer_builder( - [ - {"params": weight_decay_parameters}, - {"params": other_parameters, "weight_decay": 0.0}, - ] - ) - - # Print the parameters and their weight decay - for i in optimizer.param_groups: - log.info( - f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" - ) - - lr_scheduler = self.lr_scheduler_builder(optimizer) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": "step", - }, - } - - # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 - def get_batch_logps( - self, - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) - average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
- """ - assert logits.shape[:-1] == labels.shape - - labels = labels.clone() - loss_mask = labels != -100 - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == -100] = 0 - - per_token_logps = torch.gather( - logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) - ).squeeze(-1) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def _step(self, batch, batch_idx, stage: str): - is_train = stage == "train" - - if is_train: - # Key part to make lora work - # Otherwise the parameters are merged, which lead to incorrect gradients - self.model.train() - - # Do positive and negative samples in the same batch to speed up training - labels = batch["labels"] - outputs = self.model( - inp=batch["inputs"], - key_padding_mask=batch["attention_masks"], - ) - token_logits = outputs.token_logits - codebook_logits = outputs.codebook_logits - - # Generate labels - base_loss = F.cross_entropy( - token_logits.view(-1, token_logits.size(-1)), - labels[:, 0].reshape(-1), - ignore_index=-100, - ) - - codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT - semantic_loss = F.cross_entropy( - codebook_logits.view(-1, codebook_logits.size(-1)), - codebook_labels.reshape(-1), - ignore_index=-100, - ) - - loss = base_loss + semantic_loss - - self.log( - f"{stage}/loss", - loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=True, - logger=True, - sync_dist=not is_train, - ) - - self.log( - f"{stage}/base_loss", - base_loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=False, - logger=True, - sync_dist=not is_train, - ) - - self.log( - f"{stage}/semantic_loss", - semantic_loss, - on_step=is_train, - on_epoch=not is_train, - prog_bar=False, - logger=True, - sync_dist=not is_train, - ) - - # Top-5 accuracy - accuracy = self.get_accuracy(codebook_logits, codebook_labels) - self.log( - f"{stage}/top_5_accuracy", - accuracy, - on_step=is_train, - on_epoch=not is_train, - prog_bar=True, - logger=True, - sync_dist=not is_train, - ) - - return loss - - def get_accuracy(self, logits, labels): - mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) - if mask.sum() == 0: - return torch.tensor(0.0, device=logits.device) - - _, indices = logits.topk(5, dim=-1) - correct = indices.eq(labels.unsqueeze(-1)) - correct[~mask] = 0 - correct = correct.sum() - accuracy = correct / mask.sum() - - return accuracy - - def training_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "train") - - def validation_step(self, batch, batch_idx): - return self._step(batch, batch_idx, "val") +from typing import Any, Optional + +import lightning as L +import torch +import torch.nn.functional as F +from lightning.pytorch.utilities.types import OptimizerLRScheduler + +import fish_speech.utils as utils +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.models.text2semantic.llama import NaiveTransformer + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +class TextToSemantic(L.LightningModule): + def __init__( + self, + model: NaiveTransformer, + optimizer: Any, + lr_scheduler: Any, + ): + super().__init__() + + self.model = model + self.optimizer_builder = optimizer + self.lr_scheduler_builder = lr_scheduler + + def forward(self, x): + return self.model(x) + + def on_save_checkpoint(self, checkpoint): + # Save only LoRA parameters + state_dict = checkpoint["state_dict"] + use_lora = any("lora" in name for name in 
state_dict.keys()) + if not use_lora: + return + + for name in list(state_dict.keys()): + if "lora" not in name: + state_dict.pop(name) + + def configure_optimizers(self) -> OptimizerLRScheduler: + # Get weight decay parameters + weight_decay_parameters, other_parameters = [], [] + for name, param in self.named_parameters(): + if ".bias" in name or "norm.weight" in name or ".embeddings." in name: + other_parameters.append(param) + else: + weight_decay_parameters.append(param) + + optimizer = self.optimizer_builder( + [ + {"params": weight_decay_parameters}, + {"params": other_parameters, "weight_decay": 0.0}, + ] + ) + + # Print the parameters and their weight decay + for i in optimizer.param_groups: + log.info( + f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" + ) + + lr_scheduler = self.lr_scheduler_builder(optimizer) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + }, + } + + # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + assert logits.shape[:-1] == labels.shape + + labels = labels.clone() + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def _step(self, batch, batch_idx, stage: str): + is_train = stage == "train" + + if is_train: + # Key part to make lora work + # Otherwise the parameters are merged, which lead to incorrect gradients + self.model.train() + + # Do positive and negative samples in the same batch to speed up training + labels = batch["labels"] + outputs = self.model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.view(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + ) + + codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.view(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + ) + + loss = base_loss + semantic_loss + + self.log( + f"{stage}/loss", + loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/base_loss", + base_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/semantic_loss", + semantic_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + # Top-5 accuracy + accuracy = self.get_accuracy(codebook_logits, codebook_labels) + self.log( + f"{stage}/top_5_accuracy", + accuracy, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + return loss + + def get_accuracy(self, logits, labels): + mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) + if mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + _, indices = logits.topk(5, dim=-1) + correct = indices.eq(labels.unsqueeze(-1)) + correct[~mask] = 0 + correct = correct.sum() + accuracy = correct / mask.sum() + + return accuracy + + def training_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "val") diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py index dcd0f0fa96a3bf43768bbb8087f976068e48a8e0..d34882cdeb554989920a3e7feadc9facdc4bdd95 100644 --- a/fish_speech/models/text2semantic/llama.py +++ b/fish_speech/models/text2semantic/llama.py @@ -1,887 +1,903 @@ -import dataclasses -import json -import math -from collections import OrderedDict -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - -import torch -import torch.nn as nn -from einops import rearrange -from loguru import logger -from torch import Tensor -from torch.nn import functional as F -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.utils.checkpoint import checkpoint -from transformers import AutoTokenizer - -from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer -from fish_speech.utils import 
RankedLogger - -from .lora import LoraConfig, setup_lora - -log = RankedLogger(__name__, rank_zero_only=True) - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -@dataclass -class BaseModelArgs: - model_type: str = "base" - - vocab_size: int = 32000 - n_layer: int = 32 - n_head: int = 32 - dim: int = 4096 - intermediate_size: int = None - n_local_heads: int = -1 - head_dim: int = 64 - rope_base: float = 10000 - norm_eps: float = 1e-5 - max_seq_len: int = 2048 - dropout: float = 0.0 - tie_word_embeddings: bool = True - attention_qkv_bias: bool = False - - # Codebook configs - codebook_size: int = 160 - num_codebooks: int = 4 - - # Gradient checkpointing - use_gradient_checkpointing: bool = True - - # Initialize the model - initializer_range: float = 0.02 - - # Dummy vars - is_reward_model: bool = False - share_codebook_embeddings: bool = True - scale_codebook_embeddings: bool = False - - def __post_init__(self): - if self.n_local_heads == -1: - self.n_local_heads = self.n_head - if self.intermediate_size is None: - hidden_dim = 4 * self.dim - n_hidden = int(2 * hidden_dim / 3) - self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head - - @staticmethod - def from_pretrained(path: str): - path = Path(path) - - if path.is_dir(): - path = path / "config.json" - - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - - match data["model_type"]: - case "naive": - cls = NaiveModelArgs - case "dual_ar": - cls = DualARModelArgs - case _: - raise ValueError(f"Unknown model type: {data['model_type']}") - - return cls(**data) - - def save(self, path: str): - with open(path, "w") as f: - json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) - - -@dataclass -class NaiveModelArgs(BaseModelArgs): - model_type: str = "naive" - - -@dataclass -class DualARModelArgs(BaseModelArgs): - model_type: str = "dual_ar" - n_fast_layer: int = 4 - fast_dim: int | None = None - fast_n_head: int | None = None - fast_n_local_heads: int | None = None - fast_head_dim: int | None = None - fast_intermediate_size: int | None = None - fast_attention_qkv_bias: bool | None = None - - def __post_init__(self): - super().__post_init__() - - self.fast_dim = self.fast_dim or self.dim - self.fast_n_head = self.fast_n_head or self.n_head - self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads - self.fast_head_dim = self.fast_head_dim or self.head_dim - self.fast_intermediate_size = ( - self.fast_intermediate_size or self.intermediate_size - ) - self.fast_attention_qkv_bias = ( - self.fast_attention_qkv_bias - if self.fast_attention_qkv_bias is not None - else self.attention_qkv_bias - ) - - -class KVCache(nn.Module): - def __init__( - self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 - ): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) - self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - - -@dataclass -class TransformerForwardResult: - token_logits: Tensor - codebook_logits: Tensor - - -@dataclass -class BaseTransformerForwardResult: - logits: Tensor - hidden_states: Tensor 
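A quick aside on the config defaults above: when intermediate_size is left unset, __post_init__ derives it from dim with the 2/3 SwiGLU rule and pads it up to a multiple of 256 via find_multiple. A minimal sketch of that arithmetic, assuming the default dim of 4096:

def find_multiple(n: int, k: int) -> int:
    # Same helper as in the diff: round n up to the next multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)

dim = 4096                          # assumed default from BaseModelArgs
hidden_dim = 4 * dim                # 16384
n_hidden = int(2 * hidden_dim / 3)  # 10922, the 2/3 SwiGLU sizing rule
intermediate_size = find_multiple(n_hidden, 256)
print(intermediate_size)            # 11008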
- - -class BaseTransformer(nn.Module): - def __init__( - self, - config: BaseModelArgs, - tokenizer: FishTokenizer | AutoTokenizer, - init_weights: bool = True, - ) -> None: - super().__init__() - self.config = config - self.tokenizer = tokenizer - self.semantic_token_ids = [ - tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS - ] - - # Slow transformer - self.embeddings = nn.Embedding( - config.vocab_size, - config.dim, - ) - self.codebook_embeddings = nn.Embedding( - config.codebook_size * config.num_codebooks, - config.dim, - ) - self.layers = nn.ModuleList( - TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) - ) - self.norm = RMSNorm(config.dim, eps=config.norm_eps) - - if self.config.tie_word_embeddings is False: - self.output = nn.Linear( - config.dim, - config.vocab_size, - bias=False, - ) - - self.register_buffer( - "freqs_cis", - precompute_freqs_cis( - config.max_seq_len, - config.dim // config.n_head, - config.rope_base, - ), - persistent=False, - ) - self.register_buffer( - "causal_mask", - torch.tril( - torch.ones( - config.max_seq_len, - config.max_seq_len, - dtype=torch.bool, - ) - ), - persistent=False, - ) - - # For kv cache - self.max_batch_size = -1 - self.max_seq_len = -1 - - if init_weights: - self.apply(self._init_weights) - - def setup_caches( - self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 - ): - if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: - return - - head_dim = self.config.dim // self.config.n_head - max_seq_len = find_multiple(max_seq_len, 8) - self.max_seq_len = max_seq_len - self.max_batch_size = max_batch_size - - for b in self.layers: - b.attention.kv_cache = KVCache( - max_batch_size, - max_seq_len, - self.config.n_local_heads, - head_dim, - dtype=dtype, - ) - - def embed(self, x: Tensor) -> Tensor: - vocab_embeds = [self.embeddings(x[:, 0])] - for i in range(self.config.num_codebooks): - emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) - semantic_token_ids_tensor = torch.tensor( - self.semantic_token_ids, device=x.device - ) - emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0 - - x = torch.stack(vocab_embeds, dim=3) - x = x.sum(dim=3) - - return x - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> BaseTransformerForwardResult: - seq_len = inp.size(2) - - # Here we want to merge the embeddings of the codebooks - x = self.embed(inp) - - freqs_cis = self.freqs_cis[:seq_len] - - # Not that the causal mask here follows the definition of scaled_dot_product_attention - # That is, FALSE means masked out - # To maintain consistency, key_padding_mask use TRUE to mask out - mask = None - if key_padding_mask is not None: - mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) - mask = mask & key_padding_mask[:, None, None, :].logical_not() - - for layer in self.layers: - if self.config.use_gradient_checkpointing and self.training: - x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) - else: - x = layer(x, freqs_cis, mask) - - # We got slow_out here - slow_out = self.norm(x) - - if self.config.tie_word_embeddings: - token_logits = F.linear(slow_out, self.embeddings.weight) - else: - token_logits = self.output(slow_out) - - return BaseTransformerForwardResult( - logits=token_logits, - hidden_states=x, - ) - - def forward_generate( - self, - inp: Tensor, - input_pos: Optional[Tensor] = None, - vq_masks: Optional[Tensor] = None, # this is not used in fact - 
return_all: bool = False, - ) -> BaseTransformerForwardResult: - # This is used for generation, optimized for torch compile - # assert ( - # self.max_seq_len != -1 and self.max_batch_size != -1 - # ), "Please call setup_caches before forward_generate" - - embeds = [] - for i in range(self.config.num_codebooks): - if self.config.share_codebook_embeddings: - _tokens = inp[:, i + 1] + i * self.config.codebook_size - else: - _tokens = inp[:, i + 1] - - emb = self.codebook_embeddings(_tokens) - embeds.append(emb) - - vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1) - # if self.config.use_codebook_mlp: - # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks - # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum) - - vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & ( - inp[:, 0] <= self.tokenizer.semantic_end_id - ) - - vq_embeds_sum[~vq_masks] = 0 - x = self.embeddings(inp[:, 0]) + vq_embeds_sum - - if input_pos is None: - input_pos = torch.arange(inp.shape[-1], device=x.device) - max_seq_len = inp.shape[-1] - else: - max_seq_len = self.max_seq_len - - mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K) - freqs_cis = self.freqs_cis[input_pos] - - for layer in self.layers: - x = layer(x, freqs_cis, mask, input_pos=input_pos) - - # If prefill, we only calculate the logits of last token - if x.size(1) > 1 and not return_all: - x = x[:, -1:] - - # We got slow_out here - slow_out = self.norm(x) - - if self.config.is_reward_model: - token_logits = self.score_output(slow_out) - elif self.config.tie_word_embeddings: - token_logits = F.linear(slow_out, self.embeddings.weight) - else: - token_logits = self.output(slow_out) - - return BaseTransformerForwardResult( - logits=token_logits, - hidden_states=x, - ) - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @staticmethod - def from_pretrained( - path: str, - load_weights: bool = False, - max_length: int | None = None, - lora_config: LoraConfig | None = None, - rope_base: int | None = None, - is_agent: bool = False, - ) -> "BaseTransformer": - config = BaseModelArgs.from_pretrained(str(path)) - if max_length is not None: - config.max_seq_len = max_length - log.info(f"Override max_seq_len to {max_length}") - - if rope_base is not None: - config.rope_base = rope_base - log.info(f"Override rope_base to {rope_base}") - - match config.model_type: - case "naive": - model_cls = NaiveTransformer - case "dual_ar": - model_cls = DualARTransformer - case _: - raise ValueError(f"Unknown model type: {config.model_type}") - - if is_agent: - tokenizer = AutoTokenizer.from_pretrained(str(path)) - else: - tokenizer_path = str(path) + "/tokenizer.tiktoken" - tokenizer = FishTokenizer(tokenizer_path) - - log.info(f"Loading model from {path}, config: {config}") - model = model_cls(config, tokenizer=tokenizer) - - if lora_config is not None: - setup_lora(model, lora_config) - log.info(f"LoRA setup: {lora_config}") - - if load_weights is False: - log.info("Randomly initialized model") - else: - - if "int8" in str(Path(path)): - logger.info("Using int8 weight-only quantization!") - from tools.llama.quantize import WeightOnlyInt8QuantHandler - - simple_quantizer = WeightOnlyInt8QuantHandler(model) 
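The mask handling in forward() above is easy to get backwards: the causal mask follows the scaled_dot_product_attention convention (False means masked out), while key_padding_mask uses True to mark padded keys, so it is inverted before being ANDed in. A small standalone sketch with assumed toy sizes:

import torch

seq_len = 4
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# True marks padded positions, mirroring key_padding_mask in forward()
key_padding_mask = torch.tensor(
    [[False, False, False, True],
     [False, False, True, True]]
)

mask = causal_mask[None, None, :seq_len, :seq_len]              # (1, 1, Q, K)
mask = mask & key_padding_mask[:, None, None, :].logical_not()  # (B, 1, Q, K)
print(mask.shape)  # torch.Size([2, 1, 4, 4])
print(mask[1, 0])  # each query row keeps only earlier, non-padded keys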
- model = simple_quantizer.convert_for_runtime() - - if "int4" in str(Path(path)): - logger.info("Using int4 quantization!") - path_comps = path.name.split("-") - assert path_comps[-2].startswith("g") - groupsize = int(path_comps[-2][1:]) - from tools.llama.quantize import WeightOnlyInt4QuantHandler - - simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - model = simple_quantizer.convert_for_runtime() - - weights = torch.load( - Path(path) / "model.pth", - map_location="cpu", - mmap=True, - weights_only=True, - ) - - if "state_dict" in weights: - logger.warning( - "Using a TextToSemantic LightningModule checkpoint, " - "please make sure it is a full model, not a LoRA model." - ) - weights = weights["state_dict"] - - if next(iter(weights.keys())).startswith("model."): - logger.info( - f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" - ) - new_weights = OrderedDict() - for k, v in weights.items(): - new_weights[k.replace("model.", "")] = v - weights = new_weights - - # Verify the name and shape of parameters since strict=False in load_state_dict. - for k, v in model.named_parameters(): - if k not in weights: - logger.warning(f"No weight for {k}") - elif v.shape != weights[k].shape: - logger.warning( - f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" - ) - - err = model.load_state_dict(weights, strict=False, assign=True) - log.info(f"Loaded weights with error: {err}") - - return model - - def save_pretrained(self, path: str, drop_lora: bool = False): - path = Path(path) - path.mkdir(parents=True, exist_ok=True) - - self.config.save(path / "config.json") - state_dict = self.state_dict() - - if drop_lora: - for key in list(state_dict.keys()): - if "lora" not in key: - continue - - state_dict.pop(key) - log.info(f"Drop LoRA parameter: {key}") - - torch.save(state_dict, path / "model.pth") - self.tokenizer.save_pretrained(path) - - -class NaiveTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: - super().__init__(config, init_weights=False, tokenizer=tokenizer) - - self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) - self.codebook_output = nn.Linear( - config.dim, - config.codebook_size * config.num_codebooks, - bias=False, - ) - - self.apply(self._init_weights) - - def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: - token_logits = result.logits - x = result.hidden_states - - # Codebook - codebook_logits = self.codebook_output(self.codebook_norm(x)) - codebook_logits = rearrange( - codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks - ) - - return TransformerForwardResult( - token_logits=token_logits, - codebook_logits=codebook_logits, - ) - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> TransformerForwardResult: - result = super().forward( - inp=inp, - key_padding_mask=key_padding_mask, - ) - return self.decode(result) - - def forward_generate( - self, x: Tensor, input_pos: Optional[Tensor] = None - ) -> TransformerForwardResult: - result = super().forward_generate(x, input_pos) - return self.decode(result) - - -class DualARTransformer(BaseTransformer): - def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: - super().__init__(config, init_weights=False, tokenizer=tokenizer) - - # Project to fast dim if needed - if config.fast_dim is not None and config.fast_dim != config.dim: - self.fast_project_in = nn.Linear(config.dim, config.fast_dim) - else: - 
self.fast_project_in = nn.Identity() - - # Fast transformer - self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim) - - # The equivalent bs is so large that sdpa doesn't work - override_config = dataclasses.replace( - config, - dim=config.fast_dim, - n_head=config.fast_n_head, - n_local_heads=config.fast_n_local_heads, - head_dim=config.fast_head_dim, - intermediate_size=config.fast_intermediate_size, - attention_qkv_bias=config.fast_attention_qkv_bias, - ) - - self.fast_layers = nn.ModuleList( - TransformerBlock(override_config, use_sdpa=False) - for _ in range(config.n_fast_layer) - ) - self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps) - self.fast_output = nn.Linear( - config.fast_dim, - config.codebook_size, - bias=False, - ) - - self.register_buffer( - "fast_freqs_cis", - precompute_freqs_cis( - config.num_codebooks, - config.fast_dim // config.fast_n_head, - config.rope_base, - ), - persistent=False, - ) - self.apply(self._init_weights) - - def setup_caches( - self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 - ): - super().setup_caches(max_batch_size, max_seq_len, dtype) - - head_dim = self.config.fast_dim // self.config.fast_n_head - - # Fast transformer - # The max seq len here is the number of codebooks - for b in self.fast_layers: - b.attention.kv_cache = KVCache( - max_batch_size, - self.config.num_codebooks, - self.config.fast_n_local_heads, - head_dim, - dtype=dtype, - ) - - def forward( - self, - inp: Tensor, - key_padding_mask: Optional[Tensor] = None, - ) -> TransformerForwardResult: - parent_result = super().forward(inp, key_padding_mask) - token_logits = parent_result.logits - x = parent_result.hidden_states - x = self.fast_project_in(x) - - # Fast transformer - fast_seq_len = self.config.num_codebooks - fast_mask = self.causal_mask[ - None, None, :fast_seq_len, :fast_seq_len - ] # (B, N, Q, K) - - # Drop the last token and rotate left - codebooks = inp[:, 1:-1, 1:] - codebooks = F.pad(codebooks, (0, 1), value=0) - codebook_embeddings = self.fast_embeddings(codebooks) - x = torch.cat([x[:, None], codebook_embeddings], dim=1) - b, s = x.size(0), x.size(2) - x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len - - # Remove padded part - codebooks = rearrange(codebooks, "b n s -> (b s) n") - codebook_mask = (codebooks == 0).all(dim=-1) - - if torch.all(codebook_mask): - # If all codebooks are padded, we keep first 8 to make sure the model runs - codebook_mask[:8] = False - - x_bs, x_len = x.size(0), x.size(1) - x = x[~codebook_mask] - - for layer in self.fast_layers: - if self.config.use_gradient_checkpointing and self.training: - x = checkpoint( - layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True - ) - else: - x = layer(x, self.fast_freqs_cis, fast_mask) - - # unflatten the batch and num_codebooks - fast_out = self.fast_norm(x) - codebook_logits = self.fast_output(fast_out) - - # Re-pad the codebook_logits - buffer = torch.zeros( - x_bs, - x_len, - codebook_logits.size(-1), - device=codebook_logits.device, - dtype=codebook_logits.dtype, - ) - buffer[~codebook_mask] = codebook_logits - codebook_logits = buffer - - assert codebook_logits.shape[1] == self.config.num_codebooks - codebook_logits = rearrange( - codebook_logits, - "(b s) n d -> b s n d", - b=b, - s=s, - n=self.config.num_codebooks, - ) - - return TransformerForwardResult( - token_logits=token_logits, - codebook_logits=codebook_logits, - ) - - def forward_generate_fast( - self, x: Tensor, input_pos: Optional[Tensor] = 
None - ) -> Tensor: - # Fast transformer - x = x.view(1, 1, -1) - - fast_mask = self.causal_mask[ - None, None, input_pos, : self.config.num_codebooks - ] # (B, N, Q, K) - fast_freqs_cis = self.fast_freqs_cis[input_pos] - - for layer in self.fast_layers: - x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) - - # unflatten the batch and num_codebooks - fast_out = self.fast_norm(x) # only take the last token - codebook_logits = self.fast_output(fast_out) - - return codebook_logits - - def forward_generate( - self, - x: Tensor, - input_pos: Optional[Tensor] = None, - vq_masks: Optional[Tensor] = None, - ) -> TransformerForwardResult: - x = super().forward_generate(x, input_pos, vq_masks) - x.hidden_states = self.fast_project_in(x.hidden_states) - return x - - -class TransformerBlock(nn.Module): - def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: - super().__init__() - self.attention = Attention(config, use_sdpa=use_sdpa) - self.feed_forward = FeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) - - def forward( - self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None - ) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - -class Attention(nn.Module): - def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): - super().__init__() - assert config.dim % config.n_head == 0 - - total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim - # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear( - config.dim, total_head_dim, bias=config.attention_qkv_bias - ) - self.wo = nn.Linear(config.dim, config.dim, bias=False) - self.kv_cache = None - - self.dropout = config.dropout - self.n_head = config.n_head - self.head_dim = config.head_dim - self.n_local_heads = config.n_local_heads - self.dim = config.dim - self.use_sdpa = use_sdpa - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook(self, state_dict, prefix, *args): - if prefix + "wq.weight" in state_dict: - wq = state_dict.pop(prefix + "wq.weight") - wk = state_dict.pop(prefix + "wk.weight") - wv = state_dict.pop(prefix + "wv.weight") - state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - - def forward( - self, - x: Tensor, - freqs_cis: Tensor, - mask: Tensor, - input_pos: Optional[Tensor] = None, - ) -> Tensor: - bsz, seqlen, _ = x.shape - - kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) - - q = q.view(bsz, seqlen, self.n_head, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) - - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) - - q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - - if self.use_sdpa: - if mask is None: - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - y = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.dropout if self.training else 0.0, - is_causal=True, - # No third party attn_mask here to use flash_attention - ) - else: - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - 
dropout_p=self.dropout if self.training else 0.0, - ) - else: - y = self.eq_scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=self.dropout if self.training else 0.0, - ) - - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - return self.wo(y) - - def eq_scaled_dot_product_attention( - self, - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - ) -> torch.Tensor: - # This is a standard scaled dot product attention - # It's low efficient, but it doesn't raise cuda error - - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) - attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - - return attn_weight @ value - - -class FeedForward(nn.Module): - def __init__(self, config: BaseModelArgs) -> None: - super().__init__() - self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) - - def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor) -> Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: - freqs = 1.0 / ( - base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) - ) - t = torch.arange(seq_len, device=freqs.device) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=torch.bfloat16) - - -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], - xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], - ], - -1, - ) - - x_out2 = x_out2.flatten(3) - return x_out2.type_as(x) +import dataclasses +import json +import math +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from loguru import logger +from torch import Tensor +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.checkpoint import checkpoint +from transformers import AutoTokenizer + +from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora +from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class BaseModelArgs: + model_type: str = "base" + + 
vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + max_seq_len: int = 2048 + dropout: float = 0.0 + tie_word_embeddings: bool = True + attention_qkv_bias: bool = False + attention_o_bias: bool = False + attention_qk_norm: bool = False + + # Codebook configs + codebook_size: int = 160 + num_codebooks: int = 4 + + # Gradient checkpointing + use_gradient_checkpointing: bool = True + + # Initialize the model + initializer_range: float = 0.02 + + # Dummy vars + is_reward_model: bool = False + scale_codebook_embeddings: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + if self.head_dim is None: + self.head_dim = self.dim // self.n_head + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + if path.is_dir(): + path = path / "config.json" + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + match data["model_type"]: + case "naive": + cls = NaiveModelArgs + case "dual_ar": + cls = DualARModelArgs + case _: + raise ValueError(f"Unknown model type: {data['model_type']}") + + return cls(**data) + + def save(self, path: str): + with open(path, "w") as f: + json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) + + +@dataclass +class NaiveModelArgs(BaseModelArgs): + model_type: str = "naive" + + +@dataclass +class DualARModelArgs(BaseModelArgs): + model_type: str = "dual_ar" + n_fast_layer: int = 4 + fast_dim: int | None = None + fast_n_head: int | None = None + fast_n_local_heads: int | None = None + fast_head_dim: int | None = None + fast_intermediate_size: int | None = None + fast_attention_qkv_bias: bool | None = None + fast_attention_qk_norm: bool | None = None + fast_attention_o_bias: bool | None = None + + def __post_init__(self): + super().__post_init__() + + self.fast_dim = self.fast_dim or self.dim + self.fast_n_head = self.fast_n_head or self.n_head + self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads + self.fast_head_dim = self.fast_head_dim or self.head_dim + self.fast_intermediate_size = ( + self.fast_intermediate_size or self.intermediate_size + ) + self.fast_attention_qkv_bias = ( + self.fast_attention_qkv_bias + if self.fast_attention_qkv_bias is not None + else self.attention_qkv_bias + ) + self.fast_attention_qk_norm = ( + self.fast_attention_qk_norm + if self.fast_attention_qk_norm is not None + else self.attention_qk_norm + ) + self.fast_attention_o_bias = ( + self.fast_attention_o_bias + if self.fast_attention_o_bias is not None + else self.attention_o_bias + ) + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +@dataclass +class 
TransformerForwardResult: + token_logits: Tensor + codebook_logits: Tensor + + +@dataclass +class BaseTransformerForwardResult: + logits: Tensor + hidden_states: Tensor + + +class BaseTransformer(nn.Module): + def __init__( + self, + config: BaseModelArgs, + tokenizer: FishTokenizer, + init_weights: bool = True, + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer + self.semantic_token_ids = list(tokenizer.semantic_id_to_token_id.values()) + + # Slow transformer + self.embeddings = nn.Embedding( + config.vocab_size, + config.dim, + ) + self.codebook_embeddings = nn.Embedding( + config.codebook_size * config.num_codebooks, + config.dim, + ) + self.layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + if self.config.tie_word_embeddings is False: + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + config.max_seq_len, + config.head_dim, + config.rope_base, + ), + persistent=False, + ) + self.register_buffer( + "causal_mask", + torch.tril( + torch.ones( + config.max_seq_len, + config.max_seq_len, + dtype=torch.bool, + ) + ), + persistent=False, + ) + + # For kv cache + self.max_batch_size = -1 + self.max_seq_len = -1 + + if init_weights: + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: + return + + max_seq_len = find_multiple(max_seq_len, 8) + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_len, + self.config.n_local_heads, + self.config.head_dim, + dtype=dtype, + ) + + def embed(self, inp: Tensor) -> Tensor: + embeds = [] + semantic_token_ids_tensor = torch.tensor( + self.semantic_token_ids, device=inp.device, dtype=inp.dtype + ) + + for i in range(self.config.num_codebooks): + emb = self.codebook_embeddings( + inp[:, i + 1] + i * self.config.codebook_size + ) + embeds.append(emb) + + vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1) + vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0 + x = self.embeddings(inp[:, 0]) + vq_embeds_sum + + return x + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> BaseTransformerForwardResult: + seq_len = inp.size(2) + + # Here we want to merge the embeddings of the codebooks + x = self.embed(inp) + + freqs_cis = self.freqs_cis[:seq_len] + + # Not that the causal mask here follows the definition of scaled_dot_product_attention + # That is, FALSE means masked out + # To maintain consistency, key_padding_mask use TRUE to mask out + mask = None + if key_padding_mask is not None: + causal = self.causal_mask[:seq_len, :seq_len] + causal = rearrange(causal, "q k -> 1 1 q k") + + atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s") + atten_mask = atten_mask.logical_not() + mask = causal & atten_mask + + # return freqs_cis, mask + + for layer in self.layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) + else: + x = layer(x, freqs_cis, mask) + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + 
token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def forward_generate( + self, + inp: Tensor, + input_pos: Optional[Tensor] = None, + return_all: bool = False, + ) -> BaseTransformerForwardResult: + x = self.embed(inp) + + if input_pos is None: + input_pos = torch.arange(inp.shape[-1], device=x.device) + max_seq_len = inp.shape[-1] + else: + max_seq_len = self.max_seq_len + + mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K) + freqs_cis = self.freqs_cis[input_pos] + + for layer in self.layers: + x = layer(x, freqs_cis, mask, input_pos=input_pos) + + # If prefill, we only calculate the logits of last token + if x.size(1) > 1 and not return_all: + x = x[:, -1:] + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.is_reward_model: + token_logits = self.score_output(slow_out) + elif self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @staticmethod + def from_pretrained( + path: str, + load_weights: bool = False, + max_length: int | None = None, + lora_config: LoraConfig | None = None, + rope_base: int | None = None, + ) -> "BaseTransformer": + config = BaseModelArgs.from_pretrained(str(path)) + if max_length is not None: + config.max_seq_len = max_length + logger.info(f"Override max_seq_len to {max_length}") + + if rope_base is not None: + config.rope_base = rope_base + logger.info(f"Override rope_base to {rope_base}") + + match config.model_type: + case "naive": + model_cls = NaiveTransformer + case "dual_ar": + model_cls = DualARTransformer + case _: + raise ValueError(f"Unknown model type: {config.model_type}") + + tokenizer = FishTokenizer.from_pretrained(path) + + logger.info(f"Loading model from {path}, config: {config}") + model = model_cls(config, tokenizer=tokenizer) + + if lora_config is not None: + setup_lora(model, lora_config) + logger.info(f"LoRA setup: {lora_config}") + + if load_weights is False: + logger.info("Randomly initialized model") + else: + + if "int8" in str(Path(path)): + logger.info("Using int8 weight-only quantization!") + from tools.llama.quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(Path(path)): + logger.info("Using int4 quantization!") + path_comps = path.name.split("-") + assert path_comps[-2].startswith("g") + groupsize = int(path_comps[-2][1:]) + from tools.llama.quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + weights = torch.load( + Path(path) / "model.pth", + map_location="cpu", + mmap=True, + weights_only=True, + ) + + if "state_dict" in weights: + logger.warning( + "Using a TextToSemantic LightningModule checkpoint, " + "please make sure it is a full model, not a LoRA model." 
+ ) + weights = weights["state_dict"] + + if next(iter(weights.keys())).startswith("model."): + logger.info( + f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" + ) + new_weights = OrderedDict() + for k, v in weights.items(): + new_weights[k.replace("model.", "")] = v + weights = new_weights + + # Remove audio related weights + for k in list(weights.keys()): + if "audio_" in k: + weights.pop(k) + + # Verify the name and shape of parameters since strict=False in load_state_dict. + for k, v in model.named_parameters(): + if k not in weights: + logger.warning(f"No weight for {k}") + elif v.shape != weights[k].shape: + logger.warning( + f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" + ) + + err = model.load_state_dict(weights, strict=False, assign=True) + logger.info(f"Loaded weights with error: {err}") + + return model + + def save_pretrained(self, path: str, drop_lora: bool = False): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + self.config.save(path / "config.json") + state_dict = self.state_dict() + + if drop_lora: + for key in list(state_dict.keys()): + if "lora" not in key: + continue + + state_dict.pop(key) + logger.info(f"Drop LoRA parameter: {key}") + + torch.save(state_dict, path / "model.pth") + self.tokenizer.save_pretrained(path) + + +class NaiveTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.codebook_output = nn.Linear( + config.dim, + config.codebook_size * config.num_codebooks, + bias=False, + ) + + self.apply(self._init_weights) + + def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: + token_logits = result.logits + x = result.hidden_states + + # Codebook + codebook_logits = self.codebook_output(self.codebook_norm(x)) + codebook_logits = rearrange( + codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + result = super().forward( + inp=inp, + key_padding_mask=key_padding_mask, + ) + return self.decode(result) + + def forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + result = super().forward_generate(x, input_pos) + return self.decode(result) + + +class DualARTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + # Project to fast dim if needed + if config.fast_dim is not None and config.fast_dim != config.dim: + self.fast_project_in = nn.Linear(config.dim, config.fast_dim) + else: + self.fast_project_in = nn.Identity() + + # Fast transformer + self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim) + + # The equivalent bs is so large that sdpa doesn't work + override_config = dataclasses.replace( + config, + dim=config.fast_dim, + n_head=config.fast_n_head, + n_local_heads=config.fast_n_local_heads, + head_dim=config.fast_head_dim, + intermediate_size=config.fast_intermediate_size, + attention_qkv_bias=config.fast_attention_qkv_bias, + attention_qk_norm=config.fast_attention_qk_norm, + attention_o_bias=config.fast_attention_o_bias, + ) + + 
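For the codebook heads in this file, decode() (NaiveTransformer above) splits one flat logit vector per position into per-codebook logits with an einops rearrange; the layout is easy to sanity-check on dummy tensors. The sizes below are assumed toy values, not the shipped config:

import torch
from einops import rearrange

batch, seq_len, num_codebooks, codebook_size = 2, 5, 4, 160  # assumed toy sizes
flat_logits = torch.randn(batch, seq_len, num_codebooks * codebook_size)

# "b n (c d) -> b n c d": the codebook index becomes its own axis
codebook_logits = rearrange(flat_logits, "b n (c d) -> b n c d", c=num_codebooks)
print(codebook_logits.shape)  # torch.Size([2, 5, 4, 160])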
self.fast_layers = nn.ModuleList( + TransformerBlock(override_config, use_sdpa=False) + for _ in range(config.n_fast_layer) + ) + self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps) + self.fast_output = nn.Linear( + config.fast_dim, + config.codebook_size, + bias=False, + ) + + self.register_buffer( + "fast_freqs_cis", + precompute_freqs_cis( + config.num_codebooks, + config.fast_head_dim, + config.rope_base, + ), + persistent=False, + ) + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + super().setup_caches(max_batch_size, max_seq_len, dtype) + + # Fast transformer + # The max seq len here is the number of codebooks + for b in self.fast_layers: + b.attention.kv_cache = KVCache( + max_batch_size, + self.config.num_codebooks, + self.config.fast_n_local_heads, + self.config.fast_head_dim, + dtype=dtype, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + parent_result = super().forward(inp, key_padding_mask) + token_logits = parent_result.logits + x = parent_result.hidden_states + x = self.fast_project_in(x) + + # Fast transformer + fast_seq_len = self.config.num_codebooks + fast_mask = self.causal_mask[ + None, None, :fast_seq_len, :fast_seq_len + ] # (B, N, Q, K) + + # Drop the last token and rotate left + codebooks = inp[:, 1:-1, 1:] + codebooks = F.pad(codebooks, (0, 1), value=0) + codebook_embeddings = self.fast_embeddings(codebooks) + x = torch.cat([x[:, None], codebook_embeddings], dim=1) + b, s = x.size(0), x.size(2) + x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len + + # Remove padded part + codebooks = rearrange(codebooks, "b n s -> (b s) n") + codebook_mask = (codebooks == 0).all(dim=-1) + + if torch.all(codebook_mask): + # If all codebooks are padded, we keep first 8 to make sure the model runs + codebook_mask[:8] = False + + x_bs, x_len = x.size(0), x.size(1) + x = x[~codebook_mask] + + for layer in self.fast_layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint( + layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True + ) + else: + x = layer(x, self.fast_freqs_cis, fast_mask) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) + codebook_logits = self.fast_output(fast_out) + + # Re-pad the codebook_logits + buffer = torch.zeros( + x_bs, + x_len, + codebook_logits.size(-1), + device=codebook_logits.device, + dtype=codebook_logits.dtype, + ) + buffer[~codebook_mask] = codebook_logits + codebook_logits = buffer + + assert codebook_logits.shape[1] == self.config.num_codebooks + codebook_logits = rearrange( + codebook_logits, + "(b s) n d -> b s n d", + b=b, + s=s, + n=self.config.num_codebooks, + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward_generate_fast( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> Tensor: + # Fast transformer + x = x.view(1, 1, -1) + + fast_mask = self.causal_mask[ + None, None, input_pos, : self.config.num_codebooks + ] # (B, N, Q, K) + fast_freqs_cis = self.fast_freqs_cis[input_pos] + + for layer in self.fast_layers: + x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) # only take the last token + codebook_logits = self.fast_output(fast_out) + + return codebook_logits + + def forward_generate( + self, + x: Tensor, + 
input_pos: Optional[Tensor] = None, + vq_masks: Optional[Tensor] = None, + ) -> TransformerForwardResult: + x = super().forward_generate(x, input_pos, vq_masks) + x.hidden_states = self.fast_project_in(x.hidden_states) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: + super().__init__() + self.attention = Attention(config, use_sdpa=use_sdpa) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear( + config.dim, total_head_dim, bias=config.attention_qkv_bias + ) + self.wo = nn.Linear( + config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias + ) + self.kv_cache = None + + if config.attention_qk_norm: + self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps) + self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps) + + self.dropout = config.dropout + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.use_sdpa = use_sdpa + self.attention_qk_norm = config.attention_qk_norm + self.config = config + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + q_size = self.n_head * self.head_dim + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + if self.attention_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.use_sdpa: + if mask is None: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + # No third party attn_mask here to use flash_attention + ) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + else: + y = self.eq_scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + 
dropout_p=self.dropout if self.training else 0.0, + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size) + + return self.wo(y) + + def eq_scaled_dot_product_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + ) -> torch.Tensor: + # This is a standard scaled dot product attention + # It's low efficient, but it doesn't raise cuda error + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + +class FeedForward(nn.Module): + def __init__(self, config: BaseModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + """ + Precomputes frequency tensors for complex exponentials (cis) + + Args: + seq_len: Length of the sequence for which positional embeddings are needed. + n_elem: Number of elements in the frequency tensor. + base: Base value for the frequency scaling (default: 10000). + + Returns: + A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16). 
+ """ + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py index bb4a6192c469bce1535b5f93e147f89ce05cca04..647ca6fcccf038e17d2cf91a2874281dff3e0938 100644 --- a/fish_speech/models/text2semantic/lora.py +++ b/fish_speech/models/text2semantic/lora.py @@ -1,92 +1,92 @@ -from dataclasses import dataclass - -import loralib as lora - - -@dataclass -class LoraConfig: - r: int - lora_alpha: float - lora_dropout: float = 0.0 - - -def setup_lora(model, lora_config): - # Replace the embedding layer with a LoRA layer - model.embeddings = lora.Embedding( - num_embeddings=model.embeddings.num_embeddings, - embedding_dim=model.embeddings.embedding_dim, - padding_idx=model.embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - model.codebook_embeddings = lora.Embedding( - num_embeddings=model.codebook_embeddings.num_embeddings, - embedding_dim=model.codebook_embeddings.embedding_dim, - padding_idx=model.codebook_embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - # Replace output layer with a LoRA layer - linears = [(model, "output")] - - # Replace all linear layers with LoRA layers - for layer in model.layers: - linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) - linears.extend( - [ - (layer.feed_forward, "w1"), - (layer.feed_forward, "w2"), - (layer.feed_forward, "w3"), - ] - ) - - if hasattr(model, "fast_layers"): - model.fast_embeddings = lora.Embedding( - num_embeddings=model.fast_embeddings.num_embeddings, - embedding_dim=model.fast_embeddings.embedding_dim, - padding_idx=model.fast_embeddings.padding_idx, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - ) - - # Dual-AR model - linears.append((model, "fast_output")) - - for layer in model.fast_layers: - linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) - linears.extend( - [ - (layer.feed_forward, "w1"), - (layer.feed_forward, "w2"), - (layer.feed_forward, "w3"), - ] - ) - - for module, layer in linears: - updated_linear = lora.Linear( - in_features=getattr(module, layer).in_features, - out_features=getattr(module, layer).out_features, - bias=getattr(module, layer).bias, - r=lora_config.r, - lora_alpha=lora_config.lora_alpha, - lora_dropout=lora_config.lora_dropout, - ) - setattr(module, layer, updated_linear) - - # Mark only the LoRA layers as trainable - lora.mark_only_lora_as_trainable(model, bias="none") - - -def get_merged_state_dict(model): - # This line will merge the state dict of the model and the LoRA parameters - model.eval() - - # Then we need to remove the LoRA parameters from the state dict - state_dict = model.state_dict() - for name in list(state_dict.keys()): - if "lora" in name: - state_dict.pop(name) - - 
return state_dict +from dataclasses import dataclass + +import loralib as lora + + +@dataclass +class LoraConfig: + r: int + lora_alpha: float + lora_dropout: float = 0.0 + + +def setup_lora(model, lora_config): + # Replace the embedding layer with a LoRA layer + model.embeddings = lora.Embedding( + num_embeddings=model.embeddings.num_embeddings, + embedding_dim=model.embeddings.embedding_dim, + padding_idx=model.embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + model.codebook_embeddings = lora.Embedding( + num_embeddings=model.codebook_embeddings.num_embeddings, + embedding_dim=model.codebook_embeddings.embedding_dim, + padding_idx=model.codebook_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Replace output layer with a LoRA layer + linears = [(model, "output")] + + # Replace all linear layers with LoRA layers + for layer in model.layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + if hasattr(model, "fast_layers"): + model.fast_embeddings = lora.Embedding( + num_embeddings=model.fast_embeddings.num_embeddings, + embedding_dim=model.fast_embeddings.embedding_dim, + padding_idx=model.fast_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Dual-AR model + linears.append((model, "fast_output")) + + for layer in model.fast_layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + for module, layer in linears: + updated_linear = lora.Linear( + in_features=getattr(module, layer).in_features, + out_features=getattr(module, layer).out_features, + bias=getattr(module, layer).bias, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + setattr(module, layer, updated_linear) + + # Mark only the LoRA layers as trainable + lora.mark_only_lora_as_trainable(model, bias="none") + + +def get_merged_state_dict(model): + # This line will merge the state dict of the model and the LoRA parameters + model.eval() + + # Then we need to remove the LoRA parameters from the state dict + state_dict = model.state_dict() + for name in list(state_dict.keys()): + if "lora" in name: + state_dict.pop(name) + + return state_dict diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py index c35b4e0dbd174b36229350f27a21b5acf0e9825b..d740bd8eed447d162e55b165965dec17130377ce 100644 --- a/fish_speech/text/__init__.py +++ b/fish_speech/text/__init__.py @@ -1,4 +1,4 @@ -from .clean import clean_text -from .spliter import split_text - -__all__ = ["clean_text", "split_text"] +from .clean import clean_text +from .spliter import split_text + +__all__ = ["clean_text", "split_text"] diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py index 2aba28fc1bc7fc6054e37534ece06c743bff9f6c..68428c406c018a5bb156908b80341429a78c0301 100644 --- a/fish_speech/text/clean.py +++ b/fish_speech/text/clean.py @@ -1,37 +1,37 @@ -import re - -SYMBOLS_MAPPING = { - "‘": "'", - "’": "'", -} - -REPLACE_SYMBOL_REGEX = re.compile( - "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) -) - - -EMOJI_REGEX = re.compile( - "[" - "\U0001F600-\U0001F64F" # emoticons - "\U0001F300-\U0001F5FF" # symbols & pictographs - "\U0001F680-\U0001F6FF" # transport & map symbols - 
"\U0001F1E0-\U0001F1FF" # flags (iOS) - "]+", - flags=re.UNICODE, -) - - -def clean_text(text): - # Clean the text - text = text.strip() - - # Replace all chinese symbols with their english counterparts - text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) - - # Remove emojis - text = EMOJI_REGEX.sub(r"", text) - - # Remove continuous periods (...) and commas (,,,) - text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text) - - return text +import re + +SYMBOLS_MAPPING = { + "‘": "'", + "’": "'", +} + +REPLACE_SYMBOL_REGEX = re.compile( + "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) +) + + +EMOJI_REGEX = re.compile( + "[" + "\U0001f600-\U0001f64f" # emoticons + "\U0001f300-\U0001f5ff" # symbols & pictographs + "\U0001f680-\U0001f6ff" # transport & map symbols + "\U0001f1e0-\U0001f1ff" # flags (iOS) + "]+", + flags=re.UNICODE, +) + + +def clean_text(text): + # Clean the text + text = text.strip() + + # Replace all chinese symbols with their english counterparts + text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + + # Remove emojis + text = EMOJI_REGEX.sub(r"", text) + + # Remove continuous periods (...) and commas (,,,) + text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text) + + return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py index 30661e4ef3796250e539aa367467bac22ecbbfb8..df079addb81cd91145f0c68f70b0da0d7251f036 100644 --- a/fish_speech/text/spliter.py +++ b/fish_speech/text/spliter.py @@ -1,130 +1,130 @@ -import re -import string - -from fish_speech.text.clean import clean_text - - -def utf_8_len(text: str): - return len(text.encode("utf-8")) - - -def break_text(texts, length, splits: set): - for text in texts: - if utf_8_len(text) <= length: - yield text - continue - - curr = "" - for char in text: - curr += char - - if char in splits: - yield curr - curr = "" - - if curr: - yield curr - - -def break_text_by_length(texts, length): - for text in texts: - if utf_8_len(text) <= length: - yield text - continue - - curr = "" - for char in text: - curr += char - - if utf_8_len(curr) >= length: - yield curr - curr = "" - - if curr: - yield curr - - -def add_cleaned(curr, segments): - curr = curr.strip() - if curr and not all(c.isspace() or c in string.punctuation for c in curr): - segments.append(curr) - - -def protect_float(text): - # Turns 3.14 into <3_f_14> to prevent splitting - return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) - - -def unprotect_float(text): - # Turns <3_f_14> into 3.14 - return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) - - -def split_text(text, length): - text = clean_text(text) - - # Break the text into pieces with following rules: - # 1. Split the text at ".", "!", "?" if text is NOT a float - # 2. If the text is longer than length, split at "," - # 3. If the text is still longer than length, split at " " - # 4. 
If the text is still longer than length, split at any character to length - - texts = [text] - texts = map(protect_float, texts) - texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) - texts = map(unprotect_float, texts) - texts = break_text(texts, length, {",", ","}) - texts = break_text(texts, length, {" "}) - texts = list(break_text_by_length(texts, length)) - - # Then, merge the texts into segments with length <= length - segments = [] - curr = "" - - for text in texts: - if utf_8_len(curr) + utf_8_len(text) <= length: - curr += text - else: - add_cleaned(curr, segments) - curr = text - - if curr: - add_cleaned(curr, segments) - - return segments - - -if __name__ == "__main__": - # Test the split_text function - - text = "This is a test sentence. This is another test sentence. And a third one." - - assert split_text(text, 50) == [ - "This is a test sentence.", - "This is another test sentence. And a third one.", - ] - assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] - assert split_text(" ", 10) == [] - assert split_text("a", 10) == ["a"] - - text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." - assert split_text(text, 50) == [ - "This is a test sentence with only commas,", - "and no dots, and no exclamation marks,", - "and no question marks, and no newlines.", - ] - - text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." - # First half split at " ", second half split at "," - assert split_text(text, 50) == [ - "This is a test sentence This is a test sentence", - "This is a test sentence. This is a test sentence,", - "This is a test sentence, This is a test sentence.", - ] - - text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" - assert split_text(text, 50) == [ - "这是一段很长的中文文本,", - "而且没有句号,也没有感叹号,", - "也没有问号,也没有换行符.", - ] +import re +import string + +from fish_speech.text.clean import clean_text + + +def utf_8_len(text: str): + return len(text.encode("utf-8")) + + +def break_text(texts, length, splits: set): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if char in splits: + yield curr + curr = "" + + if curr: + yield curr + + +def break_text_by_length(texts, length): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if utf_8_len(curr) >= length: + yield curr + curr = "" + + if curr: + yield curr + + +def add_cleaned(curr, segments): + curr = curr.strip() + if curr and not all(c.isspace() or c in string.punctuation for c in curr): + segments.append(curr) + + +def protect_float(text): + # Turns 3.14 into <3_f_14> to prevent splitting + return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) + + +def unprotect_float(text): + # Turns <3_f_14> into 3.14 + return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) + + +def split_text(text, length): + text = clean_text(text) + + # Break the text into pieces with following rules: + # 1. Split the text at ".", "!", "?" if text is NOT a float + # 2. If the text is longer than length, split at "," + # 3. If the text is still longer than length, split at " " + # 4. 
If the text is still longer than length, split at any character to length + + texts = [text] + texts = map(protect_float, texts) + texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) + texts = map(unprotect_float, texts) + texts = break_text(texts, length, {",", ","}) + texts = break_text(texts, length, {" "}) + texts = list(break_text_by_length(texts, length)) + + # Then, merge the texts into segments with length <= length + segments = [] + curr = "" + + for text in texts: + if utf_8_len(curr) + utf_8_len(text) <= length: + curr += text + else: + add_cleaned(curr, segments) + curr = text + + if curr: + add_cleaned(curr, segments) + + return segments + + +if __name__ == "__main__": + # Test the split_text function + + text = "This is a test sentence. This is another test sentence. And a third one." + + assert split_text(text, 50) == [ + "This is a test sentence.", + "This is another test sentence. And a third one.", + ] + assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] + assert split_text(" ", 10) == [] + assert split_text("a", 10) == ["a"] + + text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." + assert split_text(text, 50) == [ + "This is a test sentence with only commas,", + "and no dots, and no exclamation marks,", + "and no question marks, and no newlines.", + ] + + text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." + # First half split at " ", second half split at "," + assert split_text(text, 50) == [ + "This is a test sentence This is a test sentence", + "This is a test sentence. This is a test sentence,", + "This is a test sentence, This is a test sentence.", + ] + + text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" + assert split_text(text, 50) == [ + "这是一段很长的中文文本,", + "而且没有句号,也没有感叹号,", + "也没有问号,也没有换行符.", + ] diff --git a/fish_speech/tokenizer.py b/fish_speech/tokenizer.py index f4d512d31263dcb2abc95c3a7bf3cd4bde8c4830..9a140f9db98269b9abac8f36d5b613500f2cb881 100644 --- a/fish_speech/tokenizer.py +++ b/fish_speech/tokenizer.py @@ -1,152 +1,179 @@ -import base64 -import json -import logging -from pathlib import Path - -import tiktoken - -logger = logging.getLogger(__name__) - -# This is a modified version of the default pattern from GPT-4o, that better handles punctuations. -FISH_TIKTOKEN_PATTERN = "|".join( - [ - r"(?i:'s|'t|'re|'ve|'m|'ll|'d)", - r"\p{P}", - r"[^\r\n\p{L}\p{N}]?\p{L}+", - r"\p{N}", - r" ?[^\s\p{L}\p{N}]+[\r\n]*", - r"\s*[\r\n]+", - r"\s+(\?!\S)", - r"\s+", - ] -) -TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - -BOS_TOKEN = "<|begin_of_text|>" -EOS_TOKEN = "<|end_of_text|>" -PAD_TOKEN = "<|pad|>" -IM_START_TOKEN = "<|im_start|>" -IM_END_TOKEN = "<|im_end|>" - -MODALITY_TEXT_TOKEN = "<|text|>" -MODALITY_VOICE_TOKEN = "<|voice|>" -MODALITY_INTERLEAVE_TOKEN = "<|interleave|>" -MODALITY_TOKENS = { - "text": MODALITY_TEXT_TOKEN, - "voice": MODALITY_VOICE_TOKEN, - "interleave": MODALITY_INTERLEAVE_TOKEN, -} - -PLACEHOLDER_TOKEN = [""] * 4 -for i in range(4): - PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>" - -SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>" -SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)] - -# Warning: when you add a new special token, you should only add it to the end of the list. 
-ALL_SPECIAL_TOKENS = [ - BOS_TOKEN, - EOS_TOKEN, - PAD_TOKEN, - IM_START_TOKEN, - IM_END_TOKEN, - PLACEHOLDER_TOKEN[0], - PLACEHOLDER_TOKEN[1], - PLACEHOLDER_TOKEN[2], - PLACEHOLDER_TOKEN[3], - MODALITY_TEXT_TOKEN, - MODALITY_VOICE_TOKEN, - MODALITY_INTERLEAVE_TOKEN, - *SEMANTIC_TOKENS, -] - - -class FishTokenizer: - def __init__(self, model_path: str) -> None: - mergeable_ranks = self.load_tiktoken_bpe(model_path) - special_token_begin = len(mergeable_ranks) - self.all_special_tokens_with_ids = { - token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS) - } - self.semantic_id_to_token_id = { - i: self.all_special_tokens_with_ids[token] - for i, token in enumerate(SEMANTIC_TOKENS) - } - self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]] - self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]] - - self.tkt_model = tiktoken.core.Encoding( - name=Path(model_path).stem, - pat_str=FISH_TIKTOKEN_PATTERN, - mergeable_ranks=mergeable_ranks, - special_tokens=self.all_special_tokens_with_ids, - ) - - @staticmethod - def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: - data = {} - for line in open(tiktoken_bpe_file).read().splitlines(): - if not line: - continue - token, rank = line.split() - data[base64.b64decode(token)] = int(rank) - return data - - def get_token_id(self, token: str) -> int: - return self.all_special_tokens_with_ids[token] - - def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]: - assert isinstance(s, str) - - subs = [] - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): - subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) - - if allowed_special is True: - allowed_special = self.tkt_model.special_tokens_set - elif allowed_special is False: - allowed_special = set() - - return sum( - self.tkt_model.encode_batch( - subs, allowed_special=allowed_special, disallowed_special=set() - ), - start=[], - ) - - def decode(self, tokens: list[int]) -> str: - return self.tkt_model.decode(tokens) - - def save_pretrained(self, path: str): - path = Path(path) - path.mkdir(parents=True, exist_ok=True) - - with open(path / "tokenizer.tiktoken", "w") as f: - for token, rank in self.tkt_model._mergeable_ranks.items(): - f.write(f"{base64.b64encode(token).decode()} {rank}\n") - - with open(path / "special_tokens.json", "w") as f: - json.dump( - self.all_special_tokens_with_ids, - f, - indent=2, - ensure_ascii=False, - ) - - @staticmethod - def from_pretrained(path: str): - return FishTokenizer(Path(path) / "tokenizer.tiktoken") - - -if __name__ == "__main__": - tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken") - tokenizer.save_pretrained("checkpoints/fish-speech-0.5B") - tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B") - - print( - [ - tokenizer.decode([i]) - for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}") - ] - ) +import base64 +import json +import logging +import re +from pathlib import Path + +import tiktoken + +logger = logging.getLogger(__name__) + +# This is a modified version of the default pattern from GPT-4o, that better handles punctuations. 
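+# The alternatives below match, in order: English contraction suffixes, single
+# punctuation characters, letter runs with an optional leading non-letter, individual
+# digits, other symbol runs, newline runs, and any remaining whitespace.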
+FISH_TIKTOKEN_PATTERN = "|".join( + [ + r"(?i:'s|'t|'re|'ve|'m|'ll|'d)", + r"\p{P}", + r"[^\r\n\p{L}\p{N}]?\p{L}+", + r"\p{N}", + r" ?[^\s\p{L}\p{N}]+[\r\n]*", + r"\s*[\r\n]+", + r"\s+(\?!\S)", + r"\s+", + ] +) +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +BOS_TOKEN = "<|begin_of_text|>" +EOS_TOKEN = "<|end_of_text|>" +PAD_TOKEN = "<|pad|>" +IM_START_TOKEN = "<|im_start|>" +IM_END_TOKEN = "<|im_end|>" +PHONEME_START_TOKEN = "<|phoneme_start|>" +PHONEME_END_TOKEN = "<|phoneme_end|>" +TOOL_CALL_START_TOKEN = "<|tool_call_start|>" +TOOL_CALL_END_TOKEN = "<|tool_call_end|>" + +MODALITY_TEXT_TOKEN = "<|text|>" +MODALITY_VOICE_TOKEN = "<|voice|>" +MODALITY_INTERLEAVE_TOKEN = "<|interleave|>" +AUDIO_START_TOKEN = "<|audio_start|>" +AUDIO_END_TOKEN = "<|audio_end|>" +AUDIO_EMBED_TOKEN = "<|audio|>" +MODALITY_TOKENS = { + "text": MODALITY_TEXT_TOKEN, + "voice": MODALITY_VOICE_TOKEN, + "interleave": MODALITY_INTERLEAVE_TOKEN, +} + +SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>" +SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)] + +# Warning: when you add a new special token, you should only add it to the end of the list. +ALL_SPECIAL_TOKENS = [ + BOS_TOKEN, + EOS_TOKEN, + PAD_TOKEN, + IM_START_TOKEN, + IM_END_TOKEN, + PHONEME_START_TOKEN, + PHONEME_END_TOKEN, + TOOL_CALL_START_TOKEN, + TOOL_CALL_END_TOKEN, + MODALITY_TEXT_TOKEN, + MODALITY_VOICE_TOKEN, + MODALITY_INTERLEAVE_TOKEN, + AUDIO_START_TOKEN, + AUDIO_END_TOKEN, + AUDIO_EMBED_TOKEN, + *SEMANTIC_TOKENS, +] + + +class FishTokenizer: + def __init__( + self, model_path: str, special_tokens: list[str] = ALL_SPECIAL_TOKENS + ) -> None: + mergeable_ranks = self.load_tiktoken_bpe(model_path) + special_token_begin = len(mergeable_ranks) + self.all_special_tokens_with_ids = { + token: special_token_begin + i for i, token in enumerate(special_tokens) + } + + self.semantic_id_to_token_id = {} + end_idx = 0 + for token in special_tokens: + if token.startswith("<|semantic:"): + idx = int(re.match(r"<\|semantic:(\d+)\|>", token).group(1)) + self.semantic_id_to_token_id[idx] = self.all_special_tokens_with_ids[ + token + ] + + if idx > end_idx: + end_idx = idx + + self.semantic_begin_id = self.semantic_id_to_token_id[0] + self.semantic_end_id = self.semantic_id_to_token_id[end_idx] + + self.tkt_model = tiktoken.core.Encoding( + name=Path(model_path).stem, + pat_str=FISH_TIKTOKEN_PATTERN, + mergeable_ranks=mergeable_ranks, + special_tokens=self.all_special_tokens_with_ids, + ) + + @property + def vocab_size(self): + return len(self.tkt_model._mergeable_ranks) + + @property + def num_special_tokens(self): + return len(self.all_special_tokens_with_ids) + + @staticmethod + def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: + data = {} + for line in open(tiktoken_bpe_file).read().splitlines(): + if not line: + continue + token, rank = line.split() + if token == "=": + continue + data[base64.b64decode(token)] = int(rank) + return data + + def get_token_id(self, token: str) -> int: + return self.all_special_tokens_with_ids[token] + + def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]: + assert isinstance(s, str) + + subs = [] + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS): + subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]) + + if allowed_special is True: + allowed_special = self.tkt_model.special_tokens_set + elif allowed_special is False: + allowed_special = set() + + return sum( + self.tkt_model.encode_batch( + subs, allowed_special=allowed_special, disallowed_special=set() + ), + start=[], 
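+            # sum(..., start=[]) concatenates the per-chunk token lists into one flat list of ids.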
+ ) + + def decode(self, tokens: list[int]) -> str: + return self.tkt_model.decode(tokens) + + def save_pretrained(self, path: str): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + with open(path / "tokenizer.tiktoken", "w") as f: + for token, rank in self.tkt_model._mergeable_ranks.items(): + a = base64.b64encode(token).decode() + if a == "": + a = "=" + f.write(f"{a} {rank}\n") + + with open(path / "special_tokens.json", "w") as f: + json.dump( + self.all_special_tokens_with_ids, + f, + indent=2, + ensure_ascii=False, + ) + + @staticmethod + def from_pretrained(path: str): + special_tokens_path = Path(path) / "special_tokens.json" + if special_tokens_path.exists(): + with open(special_tokens_path) as f: + all_special_tokens_with_ids = json.load(f) + else: + all_special_tokens_with_ids = ALL_SPECIAL_TOKENS + + return FishTokenizer( + Path(path) / "tokenizer.tiktoken", all_special_tokens_with_ids + ) diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py index 185517110af6cacfbda3c3d9561e6081d825fca4..53cf2f23174ddac9bf523730aca2f6a9965d134a 100644 --- a/fish_speech/utils/__init__.py +++ b/fish_speech/utils/__init__.py @@ -1,24 +1,24 @@ -from .braceexpand import braceexpand -from .context import autocast_exclude_mps -from .file import get_latest_checkpoint -from .instantiators import instantiate_callbacks, instantiate_loggers -from .logger import RankedLogger -from .logging_utils import log_hyperparameters -from .rich_utils import enforce_tags, print_config_tree -from .utils import extras, get_metric_value, set_seed, task_wrapper - -__all__ = [ - "enforce_tags", - "extras", - "get_metric_value", - "RankedLogger", - "instantiate_callbacks", - "instantiate_loggers", - "log_hyperparameters", - "print_config_tree", - "task_wrapper", - "braceexpand", - "get_latest_checkpoint", - "autocast_exclude_mps", - "set_seed", -] +from .braceexpand import braceexpand +from .context import autocast_exclude_mps +from .file import get_latest_checkpoint +from .instantiators import instantiate_callbacks, instantiate_loggers +from .logger import RankedLogger +from .logging_utils import log_hyperparameters +from .rich_utils import enforce_tags, print_config_tree +from .utils import extras, get_metric_value, set_seed, task_wrapper + +__all__ = [ + "enforce_tags", + "extras", + "get_metric_value", + "RankedLogger", + "instantiate_callbacks", + "instantiate_loggers", + "log_hyperparameters", + "print_config_tree", + "task_wrapper", + "braceexpand", + "get_latest_checkpoint", + "autocast_exclude_mps", + "set_seed", +] diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py index 8888977ce194fc5caa9e85bcf548e3bc42a3c52c..f3ac739f01f7e10e039c68c1157d6c761064f974 100644 --- a/fish_speech/utils/braceexpand.py +++ b/fish_speech/utils/braceexpand.py @@ -1,217 +1,217 @@ -""" -Bash-style brace expansion -Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py -License: MIT -""" - -import re -import string -from itertools import chain, product -from typing import Iterable, Iterator, Optional - -__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] - - -class UnbalancedBracesError(ValueError): - pass - - -alphabet = string.ascii_uppercase + string.ascii_lowercase - -int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") -char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") -escape_re = re.compile(r"\\(.)") - - -def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: - 
"""braceexpand(pattern) -> iterator over generated strings - - Returns an iterator over the strings resulting from brace expansion - of pattern. This function implements Brace Expansion as described in - bash(1), with the following limitations: - - * A pattern containing unbalanced braces will raise an - UnbalancedBracesError exception. In bash, unbalanced braces will either - be partly expanded or ignored. - - * A mixed-case character range like '{Z..a}' or '{a..Z}' will not - include the characters '[]^_`' between 'Z' and 'a'. - - When escape is True (the default), characters in pattern can be - prefixed with a backslash to cause them not to be interpreted as - special characters for brace expansion (such as '{', '}', ','). - To pass through a a literal backslash, double it ('\\\\'). - - When escape is False, backslashes in pattern have no special - meaning and will be preserved in the output. - - Examples: - - >>> from braceexpand import braceexpand - - # Integer range - >>> list(braceexpand('item{1..3}')) - ['item1', 'item2', 'item3'] - - # Character range - >>> list(braceexpand('{a..c}')) - ['a', 'b', 'c'] - - # Sequence - >>> list(braceexpand('index.html{,.backup}')) - ['index.html', 'index.html.backup'] - - # Nested patterns - >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) - ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] - - # Prefixing an integer with zero causes all numbers to be padded to - # the same width. - >>> list(braceexpand('{07..10}')) - ['07', '08', '09', '10'] - - # An optional increment can be specified for ranges. - >>> list(braceexpand('{a..g..2}')) - ['a', 'c', 'e', 'g'] - - # Ranges can go in both directions. - >>> list(braceexpand('{4..1}')) - ['4', '3', '2', '1'] - - # Numbers can be negative - >>> list(braceexpand('{2..-1}')) - ['2', '1', '0', '-1'] - - # Unbalanced braces raise an exception. - >>> list(braceexpand('{1{2,3}')) - Traceback (most recent call last): - ... - UnbalancedBracesError: Unbalanced braces: '{1{2,3}' - - # By default, the backslash is the escape character. - >>> list(braceexpand(r'{1\\{2,3}')) - ['1{2', '3'] - - # Setting 'escape' to False disables backslash escaping. 
- >>> list(braceexpand(r'\\{1,2}', escape=False)) - ['\\\\1', '\\\\2'] - - """ - return ( - escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) - ) - - -def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: - start = 0 - pos = 0 - bracketdepth = 0 - items: list[Iterable[str]] = [] - - # print 'pattern:', pattern - while pos < len(pattern): - if escape and pattern[pos] == "\\": - pos += 2 - continue - elif pattern[pos] == "{": - if bracketdepth == 0 and pos > start: - # print 'literal:', pattern[start:pos] - items.append([pattern[start:pos]]) - start = pos - bracketdepth += 1 - elif pattern[pos] == "}": - bracketdepth -= 1 - if bracketdepth == 0: - # print 'expression:', pattern[start+1:pos] - expr = pattern[start + 1 : pos] - item = parse_expression(expr, escape) - if item is None: # not a range or sequence - items.extend([["{"], parse_pattern(expr, escape), ["}"]]) - else: - items.append(item) - start = pos + 1 # skip the closing brace - pos += 1 - - if bracketdepth != 0: # unbalanced braces - raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) - - if start < pos: - items.append([pattern[start:]]) - - return ("".join(item) for item in product(*items)) - - -def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: - int_range_match = int_range_re.match(expr) - if int_range_match: - return make_int_range(*int_range_match.groups()) - - char_range_match = char_range_re.match(expr) - if char_range_match: - return make_char_range(*char_range_match.groups()) - - return parse_sequence(expr, escape) - - -def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: - # sequence -> chain(*sequence_items) - start = 0 - pos = 0 - bracketdepth = 0 - items: list[Iterable[str]] = [] - - # print 'sequence:', seq - while pos < len(seq): - if escape and seq[pos] == "\\": - pos += 2 - continue - elif seq[pos] == "{": - bracketdepth += 1 - elif seq[pos] == "}": - bracketdepth -= 1 - elif seq[pos] == "," and bracketdepth == 0: - items.append(parse_pattern(seq[start:pos], escape)) - start = pos + 1 # skip the comma - pos += 1 - - if bracketdepth != 0: - raise UnbalancedBracesError - if not items: - return None - - # part after the last comma (may be the empty string) - items.append(parse_pattern(seq[start:], escape)) - return chain(*items) - - -def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: - if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): - padding = max(len(left), len(right)) - else: - padding = 0 - step = (int(incr) or 1) if incr else 1 - start = int(left) - end = int(right) - r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) - fmt = "%0{}d".format(padding) - return (fmt % i for i in r) - - -def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: - step = (int(incr) or 1) if incr else 1 - start = alphabet.index(left) - end = alphabet.index(right) - if start < end: - return alphabet[start : end + 1 : step] - else: - end = end or -len(alphabet) - return alphabet[start : end - 1 : -step] - - -if __name__ == "__main__": - import doctest - import sys - - failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) - if failed: - sys.exit(1) +""" +Bash-style brace expansion +Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py +License: MIT +""" + +import re +import string +from itertools import chain, product +from typing import Iterable, Iterator, Optional + 
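+# Quick usage example (expansion is lazy; see the braceexpand docstring below):
+#   list(braceexpand("shard-{000..002}.tar"))
+#   -> ['shard-000.tar', 'shard-001.tar', 'shard-002.tar']
+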
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] + + +class UnbalancedBracesError(ValueError): + pass + + +alphabet = string.ascii_uppercase + string.ascii_lowercase + +int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") +char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") +escape_re = re.compile(r"\\(.)") + + +def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: + """braceexpand(pattern) -> iterator over generated strings + + Returns an iterator over the strings resulting from brace expansion + of pattern. This function implements Brace Expansion as described in + bash(1), with the following limitations: + + * A pattern containing unbalanced braces will raise an + UnbalancedBracesError exception. In bash, unbalanced braces will either + be partly expanded or ignored. + + * A mixed-case character range like '{Z..a}' or '{a..Z}' will not + include the characters '[]^_`' between 'Z' and 'a'. + + When escape is True (the default), characters in pattern can be + prefixed with a backslash to cause them not to be interpreted as + special characters for brace expansion (such as '{', '}', ','). + To pass through a a literal backslash, double it ('\\\\'). + + When escape is False, backslashes in pattern have no special + meaning and will be preserved in the output. + + Examples: + + >>> from braceexpand import braceexpand + + # Integer range + >>> list(braceexpand('item{1..3}')) + ['item1', 'item2', 'item3'] + + # Character range + >>> list(braceexpand('{a..c}')) + ['a', 'b', 'c'] + + # Sequence + >>> list(braceexpand('index.html{,.backup}')) + ['index.html', 'index.html.backup'] + + # Nested patterns + >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) + ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] + + # Prefixing an integer with zero causes all numbers to be padded to + # the same width. + >>> list(braceexpand('{07..10}')) + ['07', '08', '09', '10'] + + # An optional increment can be specified for ranges. + >>> list(braceexpand('{a..g..2}')) + ['a', 'c', 'e', 'g'] + + # Ranges can go in both directions. + >>> list(braceexpand('{4..1}')) + ['4', '3', '2', '1'] + + # Numbers can be negative + >>> list(braceexpand('{2..-1}')) + ['2', '1', '0', '-1'] + + # Unbalanced braces raise an exception. + >>> list(braceexpand('{1{2,3}')) + Traceback (most recent call last): + ... + UnbalancedBracesError: Unbalanced braces: '{1{2,3}' + + # By default, the backslash is the escape character. + >>> list(braceexpand(r'{1\\{2,3}')) + ['1{2', '3'] + + # Setting 'escape' to False disables backslash escaping. 
+ >>> list(braceexpand(r'\\{1,2}', escape=False)) + ['\\\\1', '\\\\2'] + + """ + return ( + escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) + ) + + +def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'pattern:', pattern + while pos < len(pattern): + if escape and pattern[pos] == "\\": + pos += 2 + continue + elif pattern[pos] == "{": + if bracketdepth == 0 and pos > start: + # print 'literal:', pattern[start:pos] + items.append([pattern[start:pos]]) + start = pos + bracketdepth += 1 + elif pattern[pos] == "}": + bracketdepth -= 1 + if bracketdepth == 0: + # print 'expression:', pattern[start+1:pos] + expr = pattern[start + 1 : pos] + item = parse_expression(expr, escape) + if item is None: # not a range or sequence + items.extend([["{"], parse_pattern(expr, escape), ["}"]]) + else: + items.append(item) + start = pos + 1 # skip the closing brace + pos += 1 + + if bracketdepth != 0: # unbalanced braces + raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) + + if start < pos: + items.append([pattern[start:]]) + + return ("".join(item) for item in product(*items)) + + +def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: + int_range_match = int_range_re.match(expr) + if int_range_match: + return make_int_range(*int_range_match.groups()) + + char_range_match = char_range_re.match(expr) + if char_range_match: + return make_char_range(*char_range_match.groups()) + + return parse_sequence(expr, escape) + + +def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: + # sequence -> chain(*sequence_items) + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'sequence:', seq + while pos < len(seq): + if escape and seq[pos] == "\\": + pos += 2 + continue + elif seq[pos] == "{": + bracketdepth += 1 + elif seq[pos] == "}": + bracketdepth -= 1 + elif seq[pos] == "," and bracketdepth == 0: + items.append(parse_pattern(seq[start:pos], escape)) + start = pos + 1 # skip the comma + pos += 1 + + if bracketdepth != 0: + raise UnbalancedBracesError + if not items: + return None + + # part after the last comma (may be the empty string) + items.append(parse_pattern(seq[start:], escape)) + return chain(*items) + + +def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: + if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): + padding = max(len(left), len(right)) + else: + padding = 0 + step = (int(incr) or 1) if incr else 1 + start = int(left) + end = int(right) + r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) + fmt = "%0{}d".format(padding) + return (fmt % i for i in r) + + +def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: + step = (int(incr) or 1) if incr else 1 + start = alphabet.index(left) + end = alphabet.index(right) + if start < end: + return alphabet[start : end + 1 : step] + else: + end = end or -len(alphabet) + return alphabet[start : end - 1 : -step] + + +if __name__ == "__main__": + import doctest + import sys + + failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) + if failed: + sys.exit(1) diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py index 618c4deceaa2578cd9f0672a65d1dd55430c7dcc..f04a99290ab32f7fe5b60656075a2d03af8468d6 100644 --- a/fish_speech/utils/context.py +++ b/fish_speech/utils/context.py @@ -1,13 +1,13 @@ 
-from contextlib import nullcontext - -import torch - - -def autocast_exclude_mps( - device_type: str, dtype: torch.dtype -) -> nullcontext | torch.autocast: - return ( - nullcontext() - if torch.backends.mps.is_available() - else torch.autocast(device_type, dtype) - ) +from contextlib import nullcontext + +import torch + + +def autocast_exclude_mps( + device_type: str, dtype: torch.dtype +) -> nullcontext | torch.autocast: + return ( + nullcontext() + if torch.backends.mps.is_available() + else torch.autocast(device_type, dtype) + ) diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py index 7516ad28b31e9836e6c02a991690d3e466d3ea62..a54c22655fc52d55db8932f7c6edabe017b965f4 100644 --- a/fish_speech/utils/file.py +++ b/fish_speech/utils/file.py @@ -1,16 +1,139 @@ -import os -from pathlib import Path - - -def get_latest_checkpoint(path: Path | str) -> Path | None: - # Find the latest checkpoint - ckpt_dir = Path(path) - - if ckpt_dir.exists() is False: - return None - - ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) - if len(ckpts) == 0: - return None - - return ckpts[-1] +import os +from pathlib import Path +from typing import Union + +from loguru import logger +from natsort import natsorted + +AUDIO_EXTENSIONS = { + ".mp3", + ".wav", + ".flac", + ".ogg", + ".m4a", + ".wma", + ".aac", + ".aiff", + ".aif", + ".aifc", +} + +VIDEO_EXTENSIONS = { + ".mp4", + ".avi", +} + + +def get_latest_checkpoint(path: Path | str) -> Path | None: + # Find the latest checkpoint + ckpt_dir = Path(path) + + if ckpt_dir.exists() is False: + return None + + ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) + if len(ckpts) == 0: + return None + + return ckpts[-1] + + +def audio_to_bytes(file_path): + if not file_path or not Path(file_path).exists(): + return None + with open(file_path, "rb") as wav_file: + wav = wav_file.read() + return wav + + +def read_ref_text(ref_text): + path = Path(ref_text) + if path.exists() and path.is_file(): + with path.open("r", encoding="utf-8") as file: + return file.read() + return ref_text + + +def list_files( + path: Union[Path, str], + extensions: set[str] = set(), + recursive: bool = False, + sort: bool = True, +) -> list[Path]: + """List files in a directory. + + Args: + path (Path): Path to the directory. + extensions (set, optional): Extensions to filter. Defaults to None. + recursive (bool, optional): Whether to search recursively. Defaults to False. + sort (bool, optional): Whether to sort the files. Defaults to True. + + Returns: + list: List of files. + """ + + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Directory {path} does not exist.") + + files = [file for ext in extensions for file in path.rglob(f"*{ext}")] + + if sort: + files = natsorted(files) + + return files + + +def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: + """ + Load a Bert-VITS2 style filelist. 
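+
+    Each non-empty line is expected to hold four "|"-separated fields:
+        <audio path>|<speaker>|<language>|<text>
+    e.g. (illustrative) data/spk1/001.wav|spk1|en|Hello world.
+    where <language> is one of "zh", "jp"/"ja" or "en".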
+ """ + + files = set() + results = [] + count_duplicated, count_not_found = 0, 0 + + LANGUAGE_TO_LANGUAGES = { + "zh": ["zh", "en"], + "jp": ["jp", "en"], + "en": ["en"], + } + + with open(path, "r", encoding="utf-8") as f: + for line in f.readlines(): + splits = line.strip().split("|", maxsplit=3) + if len(splits) != 4: + logger.warning(f"Invalid line: {line}") + continue + + filename, speaker, language, text = splits + file = Path(filename) + language = language.strip().lower() + + if language == "ja": + language = "jp" + + assert language in ["zh", "jp", "en"], f"Invalid language {language}" + languages = LANGUAGE_TO_LANGUAGES[language] + + if file in files: + logger.warning(f"Duplicated file: {file}") + count_duplicated += 1 + continue + + if not file.exists(): + logger.warning(f"File not found: {file}") + count_not_found += 1 + continue + + results.append((file, speaker, languages, text)) + + if count_duplicated > 0: + logger.warning(f"Total duplicated files: {count_duplicated}") + + if count_not_found > 0: + logger.warning(f"Total files not found: {count_not_found}") + + return results diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py index d1a08fe2fd76bedc5b4ad40f8dddfa40e6951c58..f6ee463924f588a35477937fbe3c3364043bdf3e 100644 --- a/fish_speech/utils/instantiators.py +++ b/fish_speech/utils/instantiators.py @@ -1,50 +1,50 @@ -from typing import List - -import hydra -from omegaconf import DictConfig -from pytorch_lightning import Callback -from pytorch_lightning.loggers import Logger - -from .logger import RankedLogger - -log = RankedLogger(__name__, rank_zero_only=True) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config.""" - - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger +from typing import List + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger + +from .logger import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! 
Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py index 5e909c26380affa14ec2e8e92ce5ecb37dc0777e..94f94f738d1d87404354d086c30ef0ad9ab04cdc 100644 --- a/fish_speech/utils/logger.py +++ b/fish_speech/utils/logger.py @@ -1,55 +1,55 @@ -import logging -from typing import Mapping, Optional - -from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only - - -class RankedLogger(logging.LoggerAdapter): - """A multi-GPU-friendly python command line logger.""" - - def __init__( - self, - name: str = __name__, - rank_zero_only: bool = True, - extra: Optional[Mapping[str, object]] = None, - ) -> None: - """Initializes a multi-GPU-friendly python command line logger that logs on all processes - with their rank prefixed in the log message. - - :param name: The name of the logger. Default is ``__name__``. - :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. - :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. - """ - logger = logging.getLogger(name) - super().__init__(logger=logger, extra=extra) - self.rank_zero_only = rank_zero_only - - def log( - self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs - ) -> None: - """Delegate a log call to the underlying logger, after prefixing its message with the rank - of the process it's being logged from. If `'rank'` is provided, then the log will only - occur on that rank/process. - - :param level: The level to log at. Look at `logging.__init__.py` for more information. - :param msg: The message to log. - :param rank: The rank to log at. - :param args: Additional args to pass to the underlying logging function. - :param kwargs: Any additional keyword args to pass to the underlying logging function. 
- """ - if self.isEnabledFor(level): - msg, kwargs = self.process(msg, kwargs) - current_rank = getattr(rank_zero_only, "rank", None) - if current_rank is None: - raise RuntimeError( - "The `rank_zero_only.rank` needs to be set before use" - ) - msg = rank_prefixed_message(msg, current_rank) - if self.rank_zero_only: - if current_rank == 0: - self.logger.log(level, msg, *args, **kwargs) - else: - if rank is None: - self.logger.log(level, msg, *args, **kwargs) - elif current_rank == rank: - self.logger.log(level, msg, *args, **kwargs) +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = True, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py index ead61c20564687585e945bf6e88f13d803851bd2..8e3b0a2519e12845f09e5fbe86dfccbf5b345429 100644 --- a/fish_speech/utils/logging_utils.py +++ b/fish_speech/utils/logging_utils.py @@ -1,48 +1,48 @@ -from lightning.pytorch.utilities import rank_zero_only - -from fish_speech.utils import logger as log - - -@rank_zero_only -def log_hyperparameters(object_dict: dict) -> None: - """Controls which config parts are saved by lightning loggers. - - Additionally saves: - - Number of model parameters - """ - - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! 
Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - hparams["data"] = cfg["data"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) +from lightning.pytorch.utilities import rank_zero_only + +from fish_speech.utils import logger as log + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py index 5a11ba95b6e54461e9f4faba9ca1f98de6e194ab..6a465f54d610779766d51e3d1a020a3b1517fd1f 100644 --- a/fish_speech/utils/rich_utils.py +++ b/fish_speech/utils/rich_utils.py @@ -1,100 +1,100 @@ -from pathlib import Path -from typing import Sequence - -import rich -import rich.syntax -import rich.tree -from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import DictConfig, OmegaConf, open_dict -from rich.prompt import Prompt - -from fish_speech.utils import logger as log - - -@rank_zero_only -def print_config_tree( - cfg: DictConfig, - print_order: Sequence[str] = ( - "data", - "model", - "callbacks", - "logger", - "trainer", - "paths", - "extras", - ), - resolve: bool = False, - save_to_file: bool = False, -) -> None: - """Prints content of DictConfig using Rich library and its tree structure. - - Args: - cfg (DictConfig): Configuration composed by Hydra. - print_order (Sequence[str], optional): Determines in what order config components are printed. - resolve (bool, optional): Whether to resolve reference fields of DictConfig. - save_to_file (bool, optional): Whether to export config to the hydra output folder. 
- """ # noqa: E501 - - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - queue = [] - - # add fields from `print_order` to queue - for field in print_order: - ( - queue.append(field) - if field in cfg - else log.warning( - f"Field '{field}' not found in config. " - + f"Skipping '{field}' config printing..." - ) - ) - - # add all the other fields to queue (not specified in `print_order`) - for field in cfg: - if field not in queue: - queue.append(field) - - # generate config tree from queue - for field in queue: - branch = tree.add(field, style=style, guide_style=style) - - config_group = cfg[field] - if isinstance(config_group, DictConfig): - branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) - else: - branch_content = str(config_group) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - # print config tree - rich.print(tree) - - # save config tree to file - if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: - rich.print(tree, file=file) - - -@rank_zero_only -def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 - - if not cfg.get("tags"): - if "id" in HydraConfig().cfg.hydra.job: - raise ValueError("Specify tags before launching a multirun!") - - log.warning("No tags provided in config. Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") - tags = [t.strip() for t in tags.split(",") if t != ""] - - with open_dict(cfg): - cfg.tags = tags - - log.info(f"Tags: {cfg.tags}") - - if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: - rich.print(cfg.tags, file=file) +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from fish_speech.utils import logger as log + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ # noqa: E501 + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. " + + f"Skipping '{field}' config printing..." 
+ ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/fish_speech/utils/schema.py b/fish_speech/utils/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fe27b6c859cd13a98bc72c6eacf6c48bbe8706 --- /dev/null +++ b/fish_speech/utils/schema.py @@ -0,0 +1,146 @@ +import base64 +import os +import queue +from dataclasses import dataclass +from typing import Literal + +import torch +from pydantic import BaseModel, Field, conint, model_validator +from pydantic.functional_validators import SkipValidation +from typing_extensions import Annotated + +from fish_speech.content_sequence import TextPart, VQPart + + +class ServeVQPart(BaseModel): + type: Literal["vq"] = "vq" + codes: SkipValidation[list[list[int]]] + + +class ServeTextPart(BaseModel): + type: Literal["text"] = "text" + text: str + + +class ServeAudioPart(BaseModel): + type: Literal["audio"] = "audio" + audio: bytes + + +class ServeASRRequest(BaseModel): + # The audio should be an uncompressed PCM float16 audio + audios: list[bytes] + sample_rate: int = 44100 + language: Literal["zh", "en", "ja", "auto"] = "auto" + + +class ServeASRTranscription(BaseModel): + text: str + duration: float + huge_gap: bool + + +class ServeASRSegment(BaseModel): + text: str + start: float + end: float + + +class ServeTimedASRResponse(BaseModel): + text: str + segments: list[ServeASRSegment] + duration: float + + +class ServeASRResponse(BaseModel): + transcriptions: list[ServeASRTranscription] + + +class ServeRequest(BaseModel): + # Raw content sequence dict that we can use with ContentSequence(**content) + content: dict + max_new_tokens: int = 600 + top_p: float = 0.7 + repetition_penalty: float = 1.2 + temperature: float = 0.7 + streaming: bool = False + num_samples: int = 1 + early_stop_threshold: float = 1.0 + + +class ServeVQGANEncodeRequest(BaseModel): + # The audio here should be in wav, mp3, etc + audios: list[bytes] + + +class ServeVQGANEncodeResponse(BaseModel): + tokens: SkipValidation[list[list[list[int]]]] + + +class ServeVQGANDecodeRequest(BaseModel): + tokens: SkipValidation[list[list[list[int]]]] + + +class 
ServeVQGANDecodeResponse(BaseModel): + # The audio here should be in PCM float16 format + audios: list[bytes] + + +class ServeStreamDelta(BaseModel): + role: Literal["system", "assistant", "user"] | None = None + part: ServeVQPart | ServeTextPart | None = None + + +class ServeStreamResponse(BaseModel): + sample_id: int = 0 + delta: ServeStreamDelta | None = None + finish_reason: Literal["stop", "error"] | None = None + stats: dict[str, int | float | str] | None = None + + +class ServeReferenceAudio(BaseModel): + audio: bytes + text: str + + @model_validator(mode="before") + def decode_audio(cls, values): + audio = values.get("audio") + if ( + isinstance(audio, str) and len(audio) > 255 + ): # Check if audio is a string (Base64) + try: + values["audio"] = base64.b64decode(audio) + except Exception as e: + # If the audio is not a valid base64 string, we will just ignore it and let the server handle it + pass + return values + + def __repr__(self) -> str: + return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})" + + +class ServeTTSRequest(BaseModel): + text: str + chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 + # Audio format + format: Literal["wav", "pcm", "mp3"] = "wav" + # References audios for in-context learning + references: list[ServeReferenceAudio] = [] + # Reference id + # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ + # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 + reference_id: str | None = None + seed: int | None = None + use_memory_cache: Literal["on", "off"] = "off" + # Normalize text for en & zh, this increase stability for numbers + normalize: bool = True + # not usually used below + streaming: bool = False + max_new_tokens: int = 1024 + top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8 + repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1 + temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8 + + class Config: + # Allow arbitrary types for pytorch related types + arbitrary_types_allowed = True diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py index 81ce022e2e62781cc62016d70de33916c736f85d..19ea435c996f6862da8885c3e8c9e8ca2b291e32 100644 --- a/fish_speech/utils/spectrogram.py +++ b/fish_speech/utils/spectrogram.py @@ -1,122 +1,124 @@ -import torch -import torchaudio.functional as F -from torch import Tensor, nn -from torchaudio.transforms import MelScale - - -class LinearSpectrogram(nn.Module): - def __init__( - self, - n_fft=2048, - win_length=2048, - hop_length=512, - center=False, - mode="pow2_sqrt", - ): - super().__init__() - - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.mode = mode - - self.register_buffer("window", torch.hann_window(win_length), persistent=False) - - def forward(self, y: Tensor) -> Tensor: - if y.ndim == 3: - y = y.squeeze(1) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - ( - (self.win_length - self.hop_length) // 2, - (self.win_length - self.hop_length + 1) // 2, - ), - mode="reflect", - ).squeeze(1) - - spec = torch.stft( - y, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - - spec = torch.view_as_real(spec) - - if self.mode == "pow2_sqrt": - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - return spec - - -class LogMelSpectrogram(nn.Module): - 
def __init__( - self, - sample_rate=44100, - n_fft=2048, - win_length=2048, - hop_length=512, - n_mels=128, - center=False, - f_min=0.0, - f_max=None, - ): - super().__init__() - - self.sample_rate = sample_rate - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.n_mels = n_mels - self.f_min = f_min - self.f_max = f_max or float(sample_rate // 2) - - self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) - - fb = F.melscale_fbanks( - n_freqs=self.n_fft // 2 + 1, - f_min=self.f_min, - f_max=self.f_max, - n_mels=self.n_mels, - sample_rate=self.sample_rate, - norm="slaney", - mel_scale="slaney", - ) - self.register_buffer( - "fb", - fb, - persistent=False, - ) - - def compress(self, x: Tensor) -> Tensor: - return torch.log(torch.clamp(x, min=1e-5)) - - def decompress(self, x: Tensor) -> Tensor: - return torch.exp(x) - - def apply_mel_scale(self, x: Tensor) -> Tensor: - return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) - - def forward( - self, x: Tensor, return_linear: bool = False, sample_rate: int = None - ) -> Tensor: - if sample_rate is not None and sample_rate != self.sample_rate: - x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) - - linear = self.spectrogram(x) - x = self.apply_mel_scale(linear) - x = self.compress(x) - - if return_linear: - return x, self.compress(linear) - - return x +import torch +import torchaudio.functional as F +from torch import Tensor, nn +from torchaudio.transforms import MelScale + + +class LinearSpectrogram(nn.Module): + def __init__( + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", + ): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.mode = mode + self.return_complex = True + + self.register_buffer("window", torch.hann_window(win_length), persistent=False) + + def forward(self, y: Tensor) -> Tensor: + if y.ndim == 3: + y = y.squeeze(1) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + (self.win_length - self.hop_length) // 2, + (self.win_length - self.hop_length + 1) // 2, + ), + mode="reflect", + ).squeeze(1) + + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=self.return_complex, + ) + + if self.return_complex: + spec = torch.view_as_real(spec) + + if self.mode == "pow2_sqrt": + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + return spec + + +class LogMelSpectrogram(nn.Module): + def __init__( + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max or float(sample_rate // 2) + + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + + fb = F.melscale_fbanks( + n_freqs=self.n_fft // 2 + 1, + f_min=self.f_min, + f_max=self.f_max, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer( + "fb", + fb, + persistent=False, + ) + + def compress(self, x: Tensor) -> Tensor: + return torch.log(torch.clamp(x, min=1e-5)) + + def decompress(self, x: 
Tensor) -> Tensor: + return torch.exp(x) + + def apply_mel_scale(self, x: Tensor) -> Tensor: + return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) + + def forward( + self, x: Tensor, return_linear: bool = False, sample_rate: int = None + ) -> Tensor: + if sample_rate is not None and sample_rate != self.sample_rate: + x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) + + linear = self.spectrogram(x) + x = self.apply_mel_scale(linear) + x = self.compress(x) + + if return_linear: + return x, self.compress(linear) + + return x diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py index f5e02cd0c8f8dec1d002e7d634b00434e70873f9..5a34bdcfedff76c333f50ed8be050d0dd5a8f98a 100644 --- a/fish_speech/utils/utils.py +++ b/fish_speech/utils/utils.py @@ -1,136 +1,136 @@ -import random -import warnings -from importlib.util import find_spec -from typing import Callable - -import numpy as np -import torch -from omegaconf import DictConfig - -from .logger import RankedLogger -from .rich_utils import enforce_tags, print_config_tree - -log = RankedLogger(__name__, rank_zero_only=True) - - -def extras(cfg: DictConfig) -> None: - """Applies optional utilities before the task is started. - - Utilities: - - Ignoring python warnings - - Setting tags from command line - - Rich config printing - """ - - # return if no `extras` config - if not cfg.get("extras"): - log.warning("Extras config not found! ") - return - - # disable python warnings - if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! ") - warnings.filterwarnings("ignore") - - # prompt user to input tags from command line if none are provided in the config - if cfg.extras.get("enforce_tags"): - log.info("Enforcing tags! ") - enforce_tags(cfg, save_to_file=True) - - # pretty print config tree using Rich library - if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! ") - print_config_tree(cfg, resolve=True, save_to_file=True) - - -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the task function. - - This wrapper can be used to: - - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) - - save the exception to a `.log` file - - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) - - etc. (adjust depending on your needs) - - Example: - ``` - @utils.task_wrapper - def train(cfg: DictConfig) -> Tuple[dict, dict]: - - ... 
- - return metric_dict, object_dict - ``` - """ # noqa: E501 - - def wrap(cfg: DictConfig): - # execute the task - try: - metric_dict, object_dict = task_func(cfg=cfg) - - # things to do if exception occurs - except Exception as ex: - # save exception to `.log` file - log.exception("") - - # some hyperparameter combinations might be invalid or - # cause out-of-memory errors so when using hparam search - # plugins like Optuna, you might want to disable - # raising the below exception to avoid multirun failure - raise ex - - # things to always do after either success or exception - finally: - # display output dir path in terminal - log.info(f"Output dir: {cfg.paths.run_dir}") - - # always close wandb run (even if exception occurs so multirun won't fail) - if find_spec("wandb"): # check if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - return metric_dict, object_dict - - return wrap - - -def get_metric_value(metric_dict: dict, metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule.""" - - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise Exception( - f"Metric value not found! \n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") - - return metric_value - - -def set_seed(seed: int): - if seed < 0: - seed = -seed - if seed > (1 << 31): - seed = 1 << 31 - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - if torch.backends.cudnn.is_available(): - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False +import random +import warnings +from importlib.util import find_spec +from typing import Callable + +import numpy as np +import torch +from omegaconf import DictConfig + +from .logger import RankedLogger +from .rich_utils import enforce_tags, print_config_tree + +log = RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. 
+ + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: + + ... + + return metric_dict, object_dict + ``` + """ # noqa: E501 + + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or + # cause out-of-memory errors so when using hparam search + # plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.run_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +def set_seed(seed: int): + if seed < 0: + seed = -seed + if seed > (1 << 31): + seed = 1 << 31 + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if torch.backends.cudnn.is_available(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/requirements.txt b/requirements.txt index d730ec85cdfe0996143386126b3f21307d1db116..b1eebf25ed100cbb9b6ccc30c434ba1fe433a3be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch==2.3.0 +torch torchaudio transformers>=4.35.2 datasets>=2.14.5 @@ -18,8 +18,8 @@ loguru>=0.6.0 loralib>=0.1.2 natsort>=8.4.0 pyrootutils>=1.0.4 +descript-audiotools vector_quantize_pytorch==1.14.24 -samplerate>=0.2.1 resampy>=0.4.3 spaces>=0.26.1 einx[torch]==0.2.2 @@ -31,4 +31,7 @@ soundfile cachetools funasr silero-vad -tiktoken \ No newline at end of file +tiktoken +numpy +huggingface_hub +git+https://github.com/descriptinc/descript-audio-codec diff --git a/tools/api_client.py b/tools/api_client.py new file mode 100644 index 0000000000000000000000000000000000000000..ee69cd28a05ce09a671f877b4a2a5d3c4dcc6d79 --- /dev/null +++ b/tools/api_client.py @@ -0,0 +1,225 @@ +import argparse +import base64 +import wave + +import ormsgpack +import pyaudio +import requests +from pydub import AudioSegment +from pydub.playback import play + +from fish_speech.utils.file import audio_to_bytes, read_ref_text +from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest + + +def parse_args(): + + parser = argparse.ArgumentParser( + description="Send a WAV file and text to a server and receive synthesized audio.", + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--url", + "-u", + type=str, + default="http://127.0.0.1:8080/v1/tts", + help="URL of the server", + ) + parser.add_argument( + "--text", "-t", type=str, required=True, help="Text to be synthesized" + ) + parser.add_argument( + "--reference_id", + "-id", + type=str, + default=None, + help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)", + ) + parser.add_argument( + "--reference_audio", + "-ra", + type=str, + nargs="+", + default=None, + help="Path to the audio file", + ) + parser.add_argument( + "--reference_text", + "-rt", + type=str, + nargs="+", + default=None, + help="Reference text for voice synthesis", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="generated_audio", + help="Output audio file name", + ) + parser.add_argument( + "--play", + action=argparse.BooleanOptionalAction, + default=True, + help="Whether to play audio after receiving data", + ) + parser.add_argument( + "--format", type=str, choices=["wav", "mp3", "flac"], default="wav" + ) + parser.add_argument( + "--latency", + type=str, + default="normal", + choices=["normal", "balanced"], + help="Used in api.fish.audio/v1/tts", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=1024, + help="Maximum new tokens to generate. 
\n0 means no limit.", + ) + parser.add_argument( + "--chunk_length", type=int, default=300, help="Chunk length for synthesis" + ) + parser.add_argument( + "--top_p", type=float, default=0.8, help="Top-p sampling for synthesis" + ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.1, + help="Repetition penalty for synthesis", + ) + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling" + ) + + parser.add_argument( + "--streaming", type=bool, default=False, help="Enable streaming response" + ) + parser.add_argument( + "--channels", type=int, default=1, help="Number of audio channels" + ) + parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio") + parser.add_argument( + "--use_memory_cache", + type=str, + default="off", + choices=["on", "off"], + help="Cache encoded references codes in memory.\n", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="`None` means randomized inference, otherwise deterministic.\n" + "It can't be used for fixing a timbre.", + ) + parser.add_argument( + "--api_key", + type=str, + default="YOUR_API_KEY", + help="API key for authentication", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + + args = parse_args() + + idstr: str | None = args.reference_id + # priority: ref_id > [{text, audio},...] + if idstr is None: + ref_audios = args.reference_audio + ref_texts = args.reference_text + if ref_audios is None: + byte_audios = [] + else: + byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios] + if ref_texts is None: + ref_texts = [] + else: + ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts] + else: + byte_audios = [] + ref_texts = [] + pass # in api.py + + data = { + "text": args.text, + "references": [ + ServeReferenceAudio( + audio=ref_audio if ref_audio is not None else b"", text=ref_text + ) + for ref_text, ref_audio in zip(ref_texts, byte_audios) + ], + "reference_id": idstr, + "format": args.format, + "max_new_tokens": args.max_new_tokens, + "chunk_length": args.chunk_length, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "temperature": args.temperature, + "streaming": args.streaming, + "use_memory_cache": args.use_memory_cache, + "seed": args.seed, + } + + pydantic_data = ServeTTSRequest(**data) + + response = requests.post( + args.url, + data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), + stream=args.streaming, + headers={ + "authorization": f"Bearer {args.api_key}", + "content-type": "application/msgpack", + }, + ) + + if response.status_code == 200: + if args.streaming: + p = pyaudio.PyAudio() + audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format + stream = p.open( + format=audio_format, channels=args.channels, rate=args.rate, output=True + ) + + wf = wave.open(f"{args.output}.wav", "wb") + wf.setnchannels(args.channels) + wf.setsampwidth(p.get_sample_size(audio_format)) + wf.setframerate(args.rate) + + stream_stopped_flag = False + + try: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + stream.write(chunk) + wf.writeframesraw(chunk) + else: + if not stream_stopped_flag: + stream.stop_stream() + stream_stopped_flag = True + finally: + stream.close() + p.terminate() + wf.close() + else: + audio_content = response.content + audio_path = f"{args.output}.{args.format}" + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_content) + + audio = AudioSegment.from_file(audio_path, format=args.format) + if 
args.play: + play(audio) + print(f"Audio has been saved to '{audio_path}'.") + else: + print(f"Request failed with status code {response.status_code}") + print(response.json()) diff --git a/tools/api_server.py b/tools/api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..f918199a39e8982930e5505c11e909b6f96940bc --- /dev/null +++ b/tools/api_server.py @@ -0,0 +1,122 @@ +import re +from threading import Lock + +import pyrootutils +import uvicorn +from kui.asgi import ( + Depends, + FactoryClass, + HTTPException, + HttpRoute, + Kui, + OpenAPI, + Routes, +) +from kui.cors import CORSConfig +from kui.openapi.specification import Info +from kui.security import bearer_auth +from loguru import logger +from typing_extensions import Annotated + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from tools.server.api_utils import MsgPackRequest, parse_args +from tools.server.exception_handler import ExceptionHandler +from tools.server.model_manager import ModelManager +from tools.server.views import routes + + +class API(ExceptionHandler): + def __init__(self): + self.args = parse_args() + + def api_auth(endpoint): + async def verify(token: Annotated[str, Depends(bearer_auth)]): + if token != self.args.api_key: + raise HTTPException(401, None, "Invalid token") + return await endpoint() + + async def passthrough(): + return await endpoint() + + if self.args.api_key is not None: + return verify + else: + return passthrough + + self.routes = Routes( + routes, # keep existing routes + http_middlewares=[api_auth], # apply api_auth middleware + ) + + # OpenAPIの設定 + self.openapi = OpenAPI( + Info( + { + "title": "Fish Speech API", + "version": "1.5.0", + } + ), + ).routes + + # Initialize the app + self.app = Kui( + routes=self.routes + self.openapi[1:], # Remove the default route + exception_handlers={ + HTTPException: self.http_exception_handler, + Exception: self.other_exception_handler, + }, + factory_class=FactoryClass(http=MsgPackRequest), + cors_config=CORSConfig(), + ) + + # Add the state variables + self.app.state.lock = Lock() + self.app.state.device = self.args.device + self.app.state.max_text_length = self.args.max_text_length + + # Associate the app with the model manager + self.app.on_startup(self.initialize_app) + + async def initialize_app(self, app: Kui): + # Make the ModelManager available to the views + app.state.model_manager = ModelManager( + mode=self.args.mode, + device=self.args.device, + half=self.args.half, + compile=self.args.compile, + asr_enabled=self.args.load_asr_model, + llama_checkpoint_path=self.args.llama_checkpoint_path, + decoder_checkpoint_path=self.args.decoder_checkpoint_path, + decoder_config_name=self.args.decoder_config_name, + ) + + logger.info(f"Startup done, listening server at http://{self.args.listen}") + + +# Each worker process created by Uvicorn has its own memory space, +# meaning that models and variables are not shared between processes. +# Therefore, any variables (like `llama_queue` or `decoder_model`) +# will not be shared across workers. + +# Multi-threading for deep learning can cause issues, such as inconsistent +# outputs if multiple threads access the same buffers simultaneously. +# Instead, it's better to use multiprocessing or independent models per thread. 
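The comment block above is the server's key concurrency constraint: every Uvicorn worker process holds its own copy of the models, and the `Lock()` placed on `app.state` gives request handlers inside one worker a way to serialize access to them. A minimal sketch of that pattern outside the server, assuming a hypothetical `DummyModel` as a stand-in for the per-worker model held by `ModelManager`:

```python
import threading
import time


class DummyModel:
    """Hypothetical stand-in for the per-worker TTS model."""

    def generate(self, text: str) -> bytes:
        time.sleep(0.01)  # pretend this is inference touching shared buffers
        return text.encode("utf-8")


model = DummyModel()     # one copy per process, never shared across workers
lock = threading.Lock()  # plays the role of app.state.lock


def handle_request(text: str) -> bytes:
    # Only one request thread runs inference at a time; the rest block here
    # instead of interleaving writes into the same model buffers.
    with lock:
        return model.generate(text)


threads = [
    threading.Thread(target=handle_request, args=(f"utterance {i}",))
    for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```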
+ +if __name__ == "__main__": + api = API() + + # IPv6 address format is [xxxx:xxxx::xxxx]:port + match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen) + if match: + host, port = match.groups() # IPv6 + else: + host, port = api.args.listen.split(":") # IPv4 + + uvicorn.run( + api.app, + host=host, + port=int(port), + workers=api.args.workers, + log_level="info", + ) diff --git a/tools/download_models.py b/tools/download_models.py index fc735d36e5e07645d46faa035cd5cd3ad88ebdb3..79e23b0daaef125de26c654b09fbb9eeb8fe43cb 100644 --- a/tools/download_models.py +++ b/tools/download_models.py @@ -1,55 +1,55 @@ -import os - -from huggingface_hub import hf_hub_download - - -# Download -def check_and_download_files(repo_id, file_list, local_dir): - os.makedirs(local_dir, exist_ok=True) - for file in file_list: - file_path = os.path.join(local_dir, file) - if not os.path.exists(file_path): - print(f"{file} 不存在,从 Hugging Face 仓库下载...") - hf_hub_download( - repo_id=repo_id, - filename=file, - resume_download=True, - local_dir=local_dir, - local_dir_use_symlinks=False, - ) - else: - print(f"{file} 已存在,跳过下载。") - - -# 1st -repo_id_1 = "fishaudio/fish-speech-1.4" -local_dir_1 = "./checkpoints/fish-speech-1.4" -files_1 = [ - "model.pth", - "README.md", - "special_tokens_map.json", - "tokenizer_config.json", - "tokenizer.json", - "config.json", - "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", -] - -# 3rd -repo_id_3 = "fishaudio/fish-speech-1" -local_dir_3 = "./" -files_3 = [ - "ffmpeg.exe", - "ffprobe.exe", -] - -# 4th -repo_id_4 = "SpicyqSama007/fish-speech-packed" -local_dir_4 = "./" -files_4 = [ - "asr-label-win-x64.exe", -] - -check_and_download_files(repo_id_1, files_1, local_dir_1) - -check_and_download_files(repo_id_3, files_3, local_dir_3) -check_and_download_files(repo_id_4, files_4, local_dir_4) +import os + +from huggingface_hub import hf_hub_download + + +# Download +def check_and_download_files(repo_id, file_list, local_dir): + os.makedirs(local_dir, exist_ok=True) + for file in file_list: + file_path = os.path.join(local_dir, file) + if not os.path.exists(file_path): + print(f"{file} 不存在,从 Hugging Face 仓库下载...") + hf_hub_download( + repo_id=repo_id, + filename=file, + resume_download=True, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + else: + print(f"{file} 已存在,跳过下载。") + + +# 1st +repo_id_1 = "fishaudio/fish-speech-1.5" +local_dir_1 = "./checkpoints/openaudio-s1-mini" +files_1 = [ + ".gitattributes", + "model.pth", + "README.md", + "special_tokens.json", + "tokenizer.tiktoken", + "config.json", + "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +] + +# 3rd +repo_id_3 = "fishaudio/fish-speech-1" +local_dir_3 = "./" +files_3 = [ + "ffmpeg.exe", + "ffprobe.exe", +] + +# 4th +repo_id_4 = "SpicyqSama007/fish-speech-packed" +local_dir_4 = "./" +files_4 = [ + "asr-label-win-x64.exe", +] + +check_and_download_files(repo_id_1, files_1, local_dir_1) + +check_and_download_files(repo_id_3, files_3, local_dir_3) +check_and_download_files(repo_id_4, files_4, local_dir_4) diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py index 19f8d992312fd98bfb9e1e8e200ab6b5e8153337..20e2219956adc419aba91cde5d9097fad4288315 100644 --- a/tools/llama/build_dataset.py +++ b/tools/llama/build_dataset.py @@ -1,169 +1,169 @@ -import itertools -import os -import re -from collections import defaultdict -from functools import partial -from multiprocessing import Pool -from pathlib import Path - -import click -import numpy as np -from loguru import logger -from tqdm import tqdm - -from 
fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData -from fish_speech.datasets.protos.text_data_stream import pack_pb_stream -from tools.file import load_filelist - -# To avoid CPU overload -os.environ["MKL_NUM_THREADS"] = "1" -os.environ["OMP_NUM_THREADS"] = "1" - - -def task_generator_folder(root: Path, text_extension: str): - files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}")) - files = sorted(files) - - grouped_files = defaultdict(list) - for file in tqdm(files, desc=f"Grouping {root}"): - p = str(file.parent) - speaker = file.parent.name - - try: - if isinstance(text_extension, str): - texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")] - else: - texts = [ - file.with_suffix(ext).read_text(encoding="utf-8") - for ext in text_extension - ] - except Exception as e: - logger.error(f"Failed to read text {file}: {e}") - continue - - grouped_files[p].append((speaker, file, texts)) - - logger.info( - f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..." - ) - - for i in grouped_files.values(): - subset = [(f, t) for _, f, t in i] - yield i[0][0], subset, "folder" - - -def task_generator_filelist(filelist): - grouped_files = defaultdict(list) - for filename, speaker, _, text in load_filelist(filelist): - grouped_files[speaker].append((Path(filename), [text])) - - logger.info(f"Found {len(grouped_files)} groups in {filelist}") - for speaker, values in grouped_files.items(): - yield speaker, values, "filelist" - - -def run_task(task): - name, subset, source = task - - # Parse the files - sentences = [] - for file, texts in subset: - np_file = file.with_suffix(".npy") - if np_file.exists() is False: - logger.warning(f"Can't find {np_file}") - continue - - new_texts = [] - - for text in texts: - # Simple cleaning: replace { xxx } and < xxx > with space - text = re.sub(r"\{.*?\}", " ", text) - text = re.sub(r"<.*?>", " ", text) - text = re.sub(r"\s+", " ", text) - new_texts.append(text) - - try: - semantics = np.load(np_file) - except Exception as e: - logger.error(f"Failed to parse {file}: {e}") - continue - - if isinstance(semantics, np.ndarray): - semantics = semantics.tolist() - - sentences.append( - Sentence( - texts=new_texts, - semantics=[Semantics(values=s) for s in semantics], - ) - ) - - # Pack the sentences - return pack_pb_stream( - TextData( - source=source, - name=name, - sentences=sentences, - ) - ) - - -@click.command() -@click.option( - "--input", - type=click.Path(path_type=Path), - required=True, - help="A folder containing the dataset or a filelist", - multiple=True, -) -@click.option( - "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft" -) -@click.option("--num-workers", type=int, default=16) -@click.option("--text-extension", type=str, default=[".txt"], multiple=True) -@click.option( - "--shard-size", type=int, default=10, help="The maximum size of each shard in mb" -) -def main(input, output, num_workers, text_extension, shard_size): - generator_fns = [] - - for f in input: - assert f.exists(), f"{f} not found" - - if f.is_dir(): - generator_fn = task_generator_folder(f, text_extension) - else: - generator_fn = task_generator_filelist(f) - - generator_fns.append(generator_fn) - - generator_fn = itertools.chain(*generator_fns) - output.mkdir(parents=True, exist_ok=True) - - dataset_fp = None - tar_idx = 0 - written_size = 0 - - with Pool(num_workers) as p: - for result in tqdm(p.imap_unordered(run_task, generator_fn)): - if dataset_fp is None: - 
dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb") - - dataset_fp.write(result) - written_size += len(result) - - if written_size > shard_size * 1024 * 1024: - logger.info(f"Finished writing {tar_idx} shards to {output}") - dataset_fp.close() - dataset_fp = None - written_size = 0 - tar_idx += 1 - - if dataset_fp is not None: - dataset_fp.close() - - logger.info(f"Finished writing {tar_idx + 1} shards to {output}") - - -if __name__ == "__main__": - main() +import itertools +import os +import re +from collections import defaultdict +from functools import partial +from multiprocessing import Pool +from pathlib import Path + +import click +import numpy as np +from loguru import logger +from tqdm import tqdm + +from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData +from fish_speech.datasets.protos.text_data_stream import pack_pb_stream +from fish_speech.utils.file import load_filelist + +# To avoid CPU overload +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" + + +def task_generator_folder(root: Path, text_extension: str): + files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}")) + files = sorted(files) + + grouped_files = defaultdict(list) + for file in tqdm(files, desc=f"Grouping {root}"): + p = str(file.parent) + speaker = file.parent.name + + try: + if isinstance(text_extension, str): + texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")] + else: + texts = [ + file.with_suffix(ext).read_text(encoding="utf-8") + for ext in text_extension + ] + except Exception as e: + logger.error(f"Failed to read text {file}: {e}") + continue + + grouped_files[p].append((speaker, file, texts)) + + logger.info( + f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..." 
+ ) + + for i in grouped_files.values(): + subset = [(f, t) for _, f, t in i] + yield i[0][0], subset, "folder" + + +def task_generator_filelist(filelist): + grouped_files = defaultdict(list) + for filename, speaker, _, text in load_filelist(filelist): + grouped_files[speaker].append((Path(filename), [text])) + + logger.info(f"Found {len(grouped_files)} groups in {filelist}") + for speaker, values in grouped_files.items(): + yield speaker, values, "filelist" + + +def run_task(task): + name, subset, source = task + + # Parse the files + sentences = [] + for file, texts in subset: + np_file = file.with_suffix(".npy") + if np_file.exists() is False: + logger.warning(f"Can't find {np_file}") + continue + + new_texts = [] + + for text in texts: + # Simple cleaning: replace { xxx } and < xxx > with space + text = re.sub(r"\{.*?\}", " ", text) + text = re.sub(r"<.*?>", " ", text) + text = re.sub(r"\s+", " ", text) + new_texts.append(text) + + try: + semantics = np.load(np_file) + except Exception as e: + logger.error(f"Failed to parse {file}: {e}") + continue + + if isinstance(semantics, np.ndarray): + semantics = semantics.tolist() + + sentences.append( + Sentence( + texts=new_texts, + semantics=[Semantics(values=s) for s in semantics], + ) + ) + + # Pack the sentences + return pack_pb_stream( + TextData( + source=source, + name=name, + sentences=sentences, + ) + ) + + +@click.command() +@click.option( + "--input", + type=click.Path(path_type=Path), + required=True, + help="A folder containing the dataset or a filelist", + multiple=True, +) +@click.option( + "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft" +) +@click.option("--num-workers", type=int, default=16) +@click.option("--text-extension", type=str, default=[".txt"], multiple=True) +@click.option( + "--shard-size", type=int, default=10, help="The maximum size of each shard in mb" +) +def main(input, output, num_workers, text_extension, shard_size): + generator_fns = [] + + for f in input: + assert f.exists(), f"{f} not found" + + if f.is_dir(): + generator_fn = task_generator_folder(f, text_extension) + else: + generator_fn = task_generator_filelist(f) + + generator_fns.append(generator_fn) + + generator_fn = itertools.chain(*generator_fns) + output.mkdir(parents=True, exist_ok=True) + + dataset_fp = None + tar_idx = 0 + written_size = 0 + + with Pool(num_workers) as p: + for result in tqdm(p.imap_unordered(run_task, generator_fn)): + if dataset_fp is None: + dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb") + + dataset_fp.write(result) + written_size += len(result) + + if written_size > shard_size * 1024 * 1024: + logger.info(f"Finished writing {tar_idx} shards to {output}") + dataset_fp.close() + dataset_fp = None + written_size = 0 + tar_idx += 1 + + if dataset_fp is not None: + dataset_fp.close() + + logger.info(f"Finished writing {tar_idx + 1} shards to {output}") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py index a62f006ba443e14b2450bf9e15927a41556b0068..41d6397472e712d796a6668aa21e84835b87d899 100644 --- a/tools/llama/eval_in_context.py +++ b/tools/llama/eval_in_context.py @@ -1,171 +1,171 @@ -import pyrootutils -import torch -import torch.nn.functional as F -from matplotlib import pyplot as plt -from transformers import AutoTokenizer - -# register eval resolver and root -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) - -from torch.utils.data import DataLoader - -from 
fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator -from tools.llama.generate import load_model - - -def smooth( - scalars: list[float], weight: float -) -> list[float]: # Weight between 0 and 1 - last = scalars[0] # First value in the plot (first timestep) - smoothed = list() - for point in scalars: - smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value - smoothed.append(smoothed_val) # Save it - last = smoothed_val # Anchor the last smoothed value - - return smoothed - - -@torch.inference_mode() -def analyze_one_model(loader, config, weight, max_length): - device = "cuda" if torch.cuda.is_available() else "cpu" - model = load_model( - config, - weight, - device, - torch.bfloat16, - max_length, - compile=False, - )[0] - - current_step = 0 - model.eval() - - semantic_loss_sum = torch.zeros( - max_length, - dtype=torch.float32, - device=device, - ) - counter = torch.zeros( - max_length, - dtype=torch.long, - device=device, - ) - - for batch in loader: - batch = {k: v.to(device) for k, v in batch.items()} - - labels = batch["labels"] - outputs = model( - inp=batch["inputs"], - key_padding_mask=batch["attention_masks"], - ) - - token_logits = outputs.token_logits - codebook_logits = outputs.codebook_logits - - # Generate labels - base_loss = F.cross_entropy( - token_logits.reshape(-1, token_logits.size(-1)), - labels[:, 0].reshape(-1), - ignore_index=-100, - reduction="none", - ) - - codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT - semantic_loss = F.cross_entropy( - codebook_logits.reshape(-1, codebook_logits.size(-1)), - codebook_labels.reshape(-1), - ignore_index=-100, - reduction="none", - ) - - base_loss = base_loss.reshape(labels[:, 0].shape) - semantic_loss = semantic_loss.reshape(codebook_labels.shape) - - semantic_loss_frame = semantic_loss.mean(-1) - pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks - - for loss_sample, pad in zip(semantic_loss_frame, pad_pos): - semantic_loss_sum[~pad] += loss_sample[~pad] - counter[~pad] += 1 - - current_step += 1 - if current_step == 10: - break - - semantic_loss = semantic_loss.cpu() - counter = counter.cpu() - xs, ys = [], [] - - for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)): - if count > 0: - xs.append(i) - ys.append((loss / count).item()) # for better loss visualization - - smoothed_ys = smooth(ys, 0.95) - - # Unload model - del model - torch.cuda.empty_cache() - - return xs, ys, smoothed_ys - - -def main(): - tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1") - max_length = 4096 - - ds = AutoAugTextDataset( - ["data/protos/sft/云天河"], - tokenizer=tokenizer, - use_speaker=False, - interactive_prob=1.0, - max_length=max_length, - ) - - loader = DataLoader( - ds, - batch_size=8, - collate_fn=TextDataCollator(tokenizer, max_length=max_length), - num_workers=0, - shuffle=False, - ) - - plt.figure(figsize=(10, 5), dpi=200) - - plt.xlabel("Frame") - plt.ylabel("Loss") - plt.yscale("log") - plt.title("Semantic Loss") - plt.grid(which="both", axis="both") - plt.xlim(0, max_length) - - tests = [ - ( - "pertrain-medium", - "dual_ar_2_codebook_medium", - "checkpoints/text2semantic-pretrain-medium-2k-v1.pth", - ), - ( - "sft-medium", - "dual_ar_2_codebook_medium", - "checkpoints/text2semantic-sft-medium-v1.1-4k.pth", - ), - ( - "sft-large", - "dual_ar_2_codebook_large", - "checkpoints/text2semantic-sft-large-v1.1-4k.pth", - ), - ] - - for name, config, weight in tests: - xs, _, smoothed_ys = analyze_one_model(loader, config, 
weight, max_length) - plt.plot(xs, smoothed_ys, label=name) - - plt.legend() - plt.savefig("semantic_loss.png") - - -if __name__ == "__main__": - main() +import pyrootutils +import torch +import torch.nn.functional as F +from matplotlib import pyplot as plt +from transformers import AutoTokenizer + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from torch.utils.data import DataLoader + +from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator +from fish_speech.models.text2semantic.inference import load_model + + +def smooth( + scalars: list[float], weight: float +) -> list[float]: # Weight between 0 and 1 + last = scalars[0] # First value in the plot (first timestep) + smoothed = list() + for point in scalars: + smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value + smoothed.append(smoothed_val) # Save it + last = smoothed_val # Anchor the last smoothed value + + return smoothed + + +@torch.inference_mode() +def analyze_one_model(loader, config, weight, max_length): + device = "cuda" if torch.cuda.is_available() else "cpu" + model = load_model( + config, + weight, + device, + torch.bfloat16, + max_length, + compile=False, + )[0] + + current_step = 0 + model.eval() + + semantic_loss_sum = torch.zeros( + max_length, + dtype=torch.float32, + device=device, + ) + counter = torch.zeros( + max_length, + dtype=torch.long, + device=device, + ) + + for batch in loader: + batch = {k: v.to(device) for k, v in batch.items()} + + labels = batch["labels"] + outputs = model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.reshape(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + reduction="none", + ) + + codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.reshape(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + reduction="none", + ) + + base_loss = base_loss.reshape(labels[:, 0].shape) + semantic_loss = semantic_loss.reshape(codebook_labels.shape) + + semantic_loss_frame = semantic_loss.mean(-1) + pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks + + for loss_sample, pad in zip(semantic_loss_frame, pad_pos): + semantic_loss_sum[~pad] += loss_sample[~pad] + counter[~pad] += 1 + + current_step += 1 + if current_step == 10: + break + + semantic_loss = semantic_loss.cpu() + counter = counter.cpu() + xs, ys = [], [] + + for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)): + if count > 0: + xs.append(i) + ys.append((loss / count).item()) # for better loss visualization + + smoothed_ys = smooth(ys, 0.95) + + # Unload model + del model + torch.cuda.empty_cache() + + return xs, ys, smoothed_ys + + +def main(): + tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1") + max_length = 4096 + + ds = AutoAugTextDataset( + ["data/protos/sft/云天河"], + tokenizer=tokenizer, + use_speaker=False, + interactive_prob=1.0, + max_length=max_length, + ) + + loader = DataLoader( + ds, + batch_size=8, + collate_fn=TextDataCollator(tokenizer, max_length=max_length), + num_workers=0, + shuffle=False, + ) + + plt.figure(figsize=(10, 5), dpi=200) + + plt.xlabel("Frame") + plt.ylabel("Loss") + plt.yscale("log") + plt.title("Semantic Loss") + plt.grid(which="both", 
axis="both") + plt.xlim(0, max_length) + + tests = [ + ( + "pertrain-medium", + "dual_ar_2_codebook_medium", + "checkpoints/text2semantic-pretrain-medium-2k-v1.pth", + ), + ( + "sft-medium", + "dual_ar_2_codebook_medium", + "checkpoints/text2semantic-sft-medium-v1.1-4k.pth", + ), + ( + "sft-large", + "dual_ar_2_codebook_large", + "checkpoints/text2semantic-sft-large-v1.1-4k.pth", + ), + ] + + for name, config, weight in tests: + xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length) + plt.plot(xs, smoothed_ys, label=name) + + plt.legend() + plt.savefig("semantic_loss.png") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py index 510dd4fc25ac1a51593cb189ea79d8a7f429548a..1080ff5668f6712a7bd51d28476369c49806775d 100644 --- a/tools/llama/merge_lora.py +++ b/tools/llama/merge_lora.py @@ -1,95 +1,96 @@ -import shutil -from copy import deepcopy -from pathlib import Path - -import click -import hydra -import torch -from hydra import compose, initialize -from hydra.utils import instantiate -from loguru import logger - -from fish_speech.models.text2semantic.llama import BaseTransformer -from fish_speech.models.text2semantic.lora import get_merged_state_dict - - -@click.command() -@click.option("--lora-config", type=str, default="r_8_alpha_16") -@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4") -@click.option("--lora-weight", type=str, required=True) -@click.option("--output", type=str, required=True) -def merge(lora_config, base_weight, lora_weight, output): - output = Path(output) - logger.info( - f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}" - ) - - with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): - cfg = compose(config_name=lora_config) - - lora_config = instantiate(cfg) - logger.info(f"Loaded lora model with config {lora_config}") - - llama_model = BaseTransformer.from_pretrained( - path=base_weight, - load_weights=True, - lora_config=lora_config, - ) - logger.info(f"Loaded llama model") - - llama_state_dict = llama_model.state_dict() - llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k} - llama_state_dict_copy = deepcopy(llama_state_dict) - lora_state_dict = torch.load(lora_weight, map_location="cpu") - - if "state_dict" in llama_state_dict: - llama_state_dict = llama_state_dict["state_dict"] - - if "state_dict" in lora_state_dict: - lora_state_dict = lora_state_dict["state_dict"] - - # remove prefix model. 
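The `smooth()` helper in `eval_in_context.py` above is a plain exponential moving average: with smoothing weight w, each plotted point is s_t = w * s_(t-1) + (1 - w) * x_t, seeded with the first raw value. A tiny worked example with arbitrary numbers shows how strongly weight=0.95 damps the loss curve:

```python
def smooth(scalars, weight):
    # mirrors the helper above: EMA seeded with the first value
    last = scalars[0]
    out = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        out.append(last)
    return out


print(smooth([1.0, 0.0, 0.0, 0.0], weight=0.95))
# [1.0, 0.95, 0.9025, 0.857375]
# each raw value contributes only 5% of a plotted point, so the curve tracks
# the trend of the per-frame losses rather than their noise
```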
- if any(k.startswith("model.") for k in llama_state_dict.keys()): - llama_state_dict = { - k.replace("model.", ""): v - for k, v in llama_state_dict.items() - if k.startswith("model.") - } - if any(k.startswith("model.") for k in lora_state_dict.keys()): - lora_state_dict = { - k.replace("model.", ""): v - for k, v in lora_state_dict.items() - if k.startswith("model.") - } - - logger.info(f"Found {len(llama_state_dict)} keys in llama model") - logger.info(f"Found {len(lora_state_dict)} keys in lora model") - - merged_state_dict = llama_state_dict | lora_state_dict - llama_model.load_state_dict(merged_state_dict, strict=True) - logger.info(f"Merged model loaded") - - # Trigger eval mode to merge lora - llama_model.eval() - llama_model.save_pretrained(output, drop_lora=True) - logger.info(f"Saved merged model to {output}, validating") - - new_state_dict = torch.load(output / "model.pth", map_location="cpu") - original_keys = set(llama_state_dict_copy.keys()) - merged_keys = set(new_state_dict.keys()) - - assert original_keys == merged_keys, "Keys should be same" - - for key in original_keys: - diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() - if diff_l1 != 0: - break - else: - logger.error("Merged model is same as the original model") - exit(1) - - logger.info("Merged model is different from the original model, check passed") - - -if __name__ == "__main__": - merge() +import shutil +from copy import deepcopy +from pathlib import Path + +import click +import hydra +import torch +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger + +from fish_speech.models.text2semantic.llama import BaseTransformer +from fish_speech.models.text2semantic.lora import get_merged_state_dict + + +@click.command() +@click.option("--lora-config", type=str, default="r_8_alpha_16") +@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4") +@click.option("--lora-weight", type=str, required=True) +@click.option("--output", type=str, required=True) +def merge(lora_config, base_weight, lora_weight, output): + output = Path(output) + logger.info( + f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}" + ) + + with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): + cfg = compose(config_name=lora_config) + + lora_config = instantiate(cfg) + logger.info(f"Loaded lora model with config {lora_config}") + + llama_model = BaseTransformer.from_pretrained( + path=base_weight, + load_weights=True, + lora_config=lora_config, + ) + logger.info(f"Loaded llama model") + + llama_state_dict = llama_model.state_dict() + llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k} + llama_state_dict_copy = deepcopy(llama_state_dict) + lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False) + + if "state_dict" in llama_state_dict: + llama_state_dict = llama_state_dict["state_dict"] + + if "state_dict" in lora_state_dict: + lora_state_dict = lora_state_dict["state_dict"] + + # remove prefix model. 
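+    # e.g. "model.foo.weight" -> "foo.weight" ("foo" is only an illustrative key name):
+    # strip the wrapper prefix, when present, so both state dicts use matching bare
+    # keys for the dict-union merge below.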
+ if any(k.startswith("model.") for k in llama_state_dict.keys()): + llama_state_dict = { + k.replace("model.", ""): v + for k, v in llama_state_dict.items() + if k.startswith("model.") + } + if any(k.startswith("model.") for k in lora_state_dict.keys()): + lora_state_dict = { + k.replace("model.", ""): v + for k, v in lora_state_dict.items() + if k.startswith("model.") + } + + logger.info(f"Found {len(llama_state_dict)} keys in llama model") + logger.info(f"Found {len(lora_state_dict)} keys in lora model") + + merged_state_dict = llama_state_dict | lora_state_dict + llama_model.load_state_dict(merged_state_dict, strict=True) + logger.info(f"Merged model loaded") + + # Trigger eval mode to merge lora + llama_model.eval() + llama_model.save_pretrained(output, drop_lora=True) + logger.info(f"Saved merged model to {output}, validating") + + new_state_dict = torch.load(output / "model.pth", map_location="cpu") + original_keys = set(llama_state_dict_copy.keys()) + + tolerance = 1e-5 + for key in original_keys: + diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() + if diff_l1 > tolerance: + logger.info(f"Significant difference found in key: {key}") + break + + if diff_l1 <= tolerance: + logger.warning( + "Merged model seems identical to the original model. Further validation might be needed." + ) + else: + logger.info("Merged model is different from the original model, check passed") + + +if __name__ == "__main__": + merge() diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py index 2a8124da4d3634bb57e8fa5368228d51bb712f77..7cd29891829432a263fcd6a6d58bd247a7d7d587 100644 --- a/tools/llama/quantize.py +++ b/tools/llama/quantize.py @@ -1,497 +1,497 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -import datetime -import shutil - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
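The merge step itself relies on Python's dict union: in `llama_state_dict | lora_state_dict` the right-hand operand wins on key collisions, so entries saved in the LoRA checkpoint take precedence over the base model's, while keys unique to either side pass through untouched. A minimal illustration of that semantics (the key names are made up):

```python
base = {"attn.weight": "base tensor", "mlp.weight": "base tensor"}
lora = {"attn.weight": "lora tensor"}  # only the adapted entries appear here

merged = base | lora  # Python 3.9+ dict union: right-hand side wins
print(merged)
# {'attn.weight': 'lora tensor', 'mlp.weight': 'base tensor'}
```

The actual folding of the low-rank adapters into the base weights then happens when the model is switched to eval mode before `save_pretrained(output, drop_lora=True)`, per the "Trigger eval mode to merge lora" comment above.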
-import time -from pathlib import Path - -import click -import torch -import torch.nn as nn -import torch.nn.functional as F - -from fish_speech.models.text2semantic.llama import find_multiple -from tools.llama.generate import load_model - -##### Quantization Primitives ###### - - -def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): - # assumes symmetric quantization - # assumes axis == 0 - # assumes dense memory format - # TODO(future): relax ^ as needed - - # default setup for affine quantization of activations - eps = torch.finfo(torch.float32).eps - - # get min and max - min_val, max_val = torch.aminmax(x, dim=1) - - # calculate scales and zero_points based on min and max - # reference: https://fburl.com/code/srbiybme - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - device = min_val_neg.device - - # reference: https://fburl.com/code/4wll53rk - max_val_pos = torch.max(-min_val_neg, max_val_pos) - scales = max_val_pos / (float(quant_max - quant_min) / 2) - # ensure scales is the same dtype as the original tensor - scales = torch.clamp(scales, min=eps).to(x.dtype) - zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) - - # quantize based on qmin/qmax/scales/zp - # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 - x_div = x / scales.unsqueeze(-1) - x_round = torch.round(x_div) - x_zp = x_round + zero_points.unsqueeze(-1) - quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) - - return quant, scales, zero_points - - -def get_group_qparams(w, n_bit=4, groupsize=128): - # needed for GPTQ with padding - if groupsize > w.shape[-1]: - groupsize = w.shape[-1] - assert groupsize > 1 - assert w.shape[-1] % groupsize == 0 - assert w.dim() == 2 - - to_quant = w.reshape(-1, groupsize) - assert torch.isnan(to_quant).sum() == 0 - - max_val = to_quant.amax(dim=1, keepdim=True) - min_val = to_quant.amin(dim=1, keepdim=True) - max_int = 2**n_bit - 1 - scales = (max_val - min_val).clamp(min=1e-6) / max_int - zeros = min_val + scales * (2 ** (n_bit - 1)) - return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( - torch.bfloat16 - ).reshape(w.shape[0], -1) - - -def pack_scales_and_zeros(scales, zeros): - assert scales.shape == zeros.shape - assert scales.dtype == torch.bfloat16 - assert zeros.dtype == torch.bfloat16 - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - -def unpack_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - assert scales_and_zeros.dtype == torch.float - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) - - -def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): - assert groupsize > 1 - # needed for GPTQ single column quantize - if groupsize > w.shape[-1] and scales.shape[-1] == 1: - groupsize = w.shape[-1] - - assert w.shape[-1] % groupsize == 0 - assert w.dim() == 2 - - to_quant = w.reshape(-1, groupsize) - assert torch.isnan(to_quant).sum() == 0 - - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - min_val = zeros - scales * (2 ** (n_bit - 1)) - max_int = 2**n_bit - 1 - min_int = 0 - w_int32 = ( - to_quant.sub(min_val) - .div(scales) - .round() - .clamp_(min_int, max_int) - .to(torch.int32) - .reshape_as(w) 
- ) - - return w_int32 - - -def group_quantize_tensor(w, n_bit=4, groupsize=128): - scales, zeros = get_group_qparams(w, n_bit, groupsize) - w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) - scales_and_zeros = pack_scales_and_zeros(scales, zeros) - return w_int32, scales_and_zeros - - -def group_dequantize_tensor_from_qparams( - w_int32, scales, zeros, n_bit=4, groupsize=128 -): - assert groupsize > 1 - # needed for GPTQ single column dequantize - if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: - groupsize = w_int32.shape[-1] - assert w_int32.shape[-1] % groupsize == 0 - assert w_int32.dim() == 2 - - w_int32_grouped = w_int32.reshape(-1, groupsize) - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - - w_dq = ( - w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) - ) - return w_dq - - -def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): - scales, zeros = unpack_scales_and_zeros(scales_and_zeros) - return group_dequantize_tensor_from_qparams( - w_int32, scales, zeros, n_bit, groupsize - ) - - -class QuantHandler: - def __init__(self, mod): - self.mod = mod - - def create_quantized_state_dict(self) -> "StateDict": - pass - - def convert_for_runtime(self) -> "nn.Module": - pass - - -##### Weight-only int8 per-channel quantized code ###### - - -def replace_linear_weight_only_int8_per_channel(module): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - setattr( - module, - name, - WeightOnlyInt8Linear(child.in_features, child.out_features), - ) - else: - replace_linear_weight_only_int8_per_channel(child) - - -class WeightOnlyInt8QuantHandler: - def __init__(self, mod): - self.mod = mod - - @torch.no_grad() - def create_quantized_state_dict(self): - cur_state_dict = self.mod.state_dict() - for fqn, mod in self.mod.named_modules(): - if isinstance(mod, torch.nn.Linear): - int8_weight, scales, _ = dynamically_quantize_per_channel( - mod.weight.float(), -128, 127, torch.int8 - ) - cur_state_dict[f"{fqn}.weight"] = int8_weight - cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) - - return cur_state_dict - - def convert_for_runtime(self): - replace_linear_weight_only_int8_per_channel(self.mod) - return self.mod - - -class WeightOnlyInt8Linear(torch.nn.Module): - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - weight: torch.Tensor - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.register_buffer( - "weight", torch.empty((out_features, in_features), dtype=torch.int8) - ) - self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales - - -##### weight only int4 per channel groupwise quantized code ###### - - -def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): - weight_int32, scales_and_zeros = group_quantize_tensor( - weight_bf16, n_bit=4, groupsize=groupsize - ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - weight_int32, inner_k_tiles - ) - return weight_int4pack, scales_and_zeros - - -def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): - 
origin_x_size = x.size() - x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x, weight_int4pack, groupsize, scales_and_zeros - ) - new_shape = origin_x_size[:-1] + (out_features,) - c = c.reshape(new_shape) - return c - - -def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): - return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 - - -def replace_linear_int4(module, groupsize, inner_k_tiles, padding): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): - setattr( - module, - name, - WeightOnlyInt4Linear( - child.in_features, - child.out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - padding=False, - ), - ) - elif padding: - setattr( - module, - name, - WeightOnlyInt4Linear( - child.in_features, - child.out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - padding=True, - ), - ) - else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding) - - -class WeightOnlyInt4QuantHandler: - def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): - self.mod = mod - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.padding = padding - assert groupsize in [32, 64, 128, 256] - assert inner_k_tiles in [2, 4, 8] - - @torch.no_grad() - def create_quantized_state_dict(self): - cur_state_dict = self.mod.state_dict() - for fqn, mod in self.mod.named_modules(): - if isinstance(mod, torch.nn.Linear): - assert not mod.bias - out_features = mod.out_features - in_features = mod.in_features - assert out_features % 8 == 0, "require out_features % 8 == 0" - print(f"linear: {fqn}, in={in_features}, out={out_features}") - - weight = mod.weight.data - if not _check_linear_int4_k( - in_features, self.groupsize, self.inner_k_tiles - ): - if self.padding: - import torch.nn.functional as F - - print( - f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" - ) - padded_in_features = find_multiple(in_features, 1024) - weight = F.pad( - weight, pad=(0, padded_in_features - in_features) - ) - else: - print( - f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " - + "and that groupsize and inner_k_tiles*16 evenly divide into it" - ) - continue - ( - weight_int4pack, - scales_and_zeros, - ) = prepare_int4_weight_and_scales_and_zeros( - weight.to(torch.bfloat16).to("cuda"), - self.groupsize, - self.inner_k_tiles, - ) - cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") - cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") - - return cur_state_dict - - def convert_for_runtime(self): - replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) - return self.mod - - -class WeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - weight: torch.Tensor - - def __init__( - self, - in_features: int, - out_features: int, - bias=True, - device=None, - dtype=None, - groupsize: int = 128, - inner_k_tiles: int = 8, - padding: bool = True, - ) -> None: - super().__init__() - self.padding = padding - if padding: - self.origin_in_features = in_features - in_features = find_multiple(in_features, 1024) - - self.in_features = in_features - self.out_features = out_features - assert not bias, "require bias=False" - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - - assert out_features % 8 == 0, "require out_features 
% 8 == 0" - assert ( - in_features % (inner_k_tiles * 16) == 0 - ), "require in_features % (innerKTiles * 16) == 0" - self.register_buffer( - "weight", - torch.empty( - ( - out_features // 8, - in_features // (inner_k_tiles * 16), - 32, - inner_k_tiles // 2, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales_and_zeros", - torch.empty( - (in_features // groupsize, out_features, 2), dtype=torch.bfloat16 - ), - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(torch.bfloat16) - if self.padding: - import torch.nn.functional as F - - input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - return linear_forward_int4( - input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize - ) - - -def generate_folder_name(): - now = datetime.datetime.now() - folder_name = now.strftime("%Y%m%d_%H%M%S") - return folder_name - - -@click.command() -@click.option( - "--checkpoint-path", - type=click.Path(path_type=Path, exists=True), - default="checkpoints/fish-speech-1.4", -) -@click.option( - "--mode", type=str, default="int8", help="type of quantization to perform" -) -@click.option( - "--groupsize", type=int, default=128, help="Group size for int4 quantization." -) -@click.option("--timestamp", type=str, default="None", help="When to do quantization") -def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None: - - device = "cpu" - precision = torch.bfloat16 - - print("Loading model ...") - t0 = time.time() - - model, _ = load_model( - checkpoint_path=checkpoint_path, - device=device, - precision=precision, - compile=False, - ) - vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth" - now = timestamp if timestamp != "None" else generate_folder_name() - - if mode == "int8": - print( - "Quantizing model weights for int8 weight-only symmetric per-channel quantization" - ) - quant_handler = WeightOnlyInt8QuantHandler(model) - quantized_state_dict = quant_handler.create_quantized_state_dict() - - dir_name = checkpoint_path - dst_name = Path(f"checkpoints/fs-1.2-int8-{now}") - shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) - if (dst_name / vq_model).exists(): - (dst_name / vq_model).unlink() - quantize_path = dst_name / "model.pth" - - elif mode == "int4": - print( - "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" - ) - quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) - quantized_state_dict = quant_handler.create_quantized_state_dict() - - dir_name = checkpoint_path - dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}") - shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) - if (dst_name / vq_model).exists(): - (dst_name / vq_model).unlink() - quantize_path = dst_name / "model.pth" - - else: - raise ValueError( - f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]" - ) - - print(f"Writing quantized weights to {quantize_path}") - quantize_path.unlink(missing_ok=True) # remove existing file if one already there - torch.save(quantized_state_dict, quantize_path) - print(f"Quantization complete took {time.time() - t0:.02f} seconds") - - -if __name__ == "__main__": - quantize() +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +import datetime +import shutil + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
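For the int8 path above, `dynamically_quantize_per_channel` performs symmetric weight-only quantization per output channel: each row of a Linear weight gets scale = max(|row|) / ((quant_max - quant_min) / 2) and a zero point of 0, and dequantization is just int8 * scale, which `WeightOnlyInt8Linear` applies per output channel after the matmul. A simplified round-trip sketch (arbitrary shapes; the eps clamp on the scales is left out):

```python
import torch

w = torch.randn(4, 16)  # (out_features, in_features) of some Linear layer
qmin, qmax = -128, 127

# one symmetric scale per output row, zero point fixed at 0
scales = w.abs().amax(dim=1) / ((qmax - qmin) / 2)  # max|row| / 127.5
q = torch.clamp(torch.round(w / scales[:, None]), qmin, qmax).to(torch.int8)

# dequantize; the error per element is at most about half a quantization step
w_hat = q.to(torch.float32) * scales[:, None]
print((w - w_hat).abs().max().item())
```

The int4 path adds groupwise scales and zero points (`get_group_qparams` / `pack_scales_and_zeros`) and packs the weights with `torch.ops.aten._convert_weight_to_int4pack` so the fused `_weight_int4pack_mm` kernel can consume them.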
+import time +from pathlib import Path + +import click +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fish_speech.models.text2semantic.inference import load_model +from fish_speech.models.text2semantic.llama import find_multiple + +##### Quantization Primitives ###### + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + 
.to(torch.int32) + .reshape_as(w) + ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + + +##### Weight-only int8 per-channel quantized code ###### + + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr( + module, + name, + WeightOnlyInt8Linear(child.in_features, child.out_features), + ) + else: + replace_linear_weight_only_int8_per_channel(child) + + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, torch.int8 + ) + cur_state_dict[f"{fqn}.weight"] = int8_weight + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) + self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + +##### weight only int4 per channel groupwise quantized code ###### + + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4(x, weight_int4pack, 
scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x, weight_int4pack, groupsize, scales_and_zeros + ) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=False, + ), + ) + elif padding: + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=True, + ), + ) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding: + import torch.nn.functional as F + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) + else: + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) + continue + ( + weight_int4pack, + scales_and_zeros, + ) = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to("cuda"), + self.groupsize, + self.inner_k_tiles, + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + padding: bool = True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + 
assert out_features % 8 == 0, "require out_features % 8 == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales_and_zeros", + torch.empty( + (in_features // groupsize, out_features, 2), dtype=torch.bfloat16 + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +@click.command() +@click.option( + "--checkpoint-path", + type=click.Path(path_type=Path, exists=True), + default="checkpoints/fish-speech-1.4", +) +@click.option( + "--mode", type=str, default="int8", help="type of quantization to perform" +) +@click.option( + "--groupsize", type=int, default=128, help="Group size for int4 quantization." +) +@click.option("--timestamp", type=str, default="None", help="When to do quantization") +def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None: + + device = "cpu" + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + model, _ = load_model( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=False, + ) + vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + now = timestamp if timestamp != "None" else generate_folder_name() + + if mode == "int8": + print( + "Quantizing model weights for int8 weight-only symmetric per-channel quantization" + ) + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path + dst_name = Path(f"checkpoints/fs-1.2-int8-{now}") + shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) + if (dst_name / vq_model).exists(): + (dst_name / vq_model).unlink() + quantize_path = dst_name / "model.pth" + + elif mode == "int4": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" + ) + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path + dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}") + shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) + if (dst_name / vq_model).exists(): + (dst_name / vq_model).unlink() + quantize_path = dst_name / "model.pth" + + else: + raise ValueError( + f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]" + ) + + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink(missing_ok=True) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + + +if __name__ == "__main__": + quantize() diff --git a/tools/run_webui.py b/tools/run_webui.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc004ef51c9443cf100796485c0e83b856738be --- /dev/null +++ b/tools/run_webui.py @@ -0,0 +1,104 @@ +import os +from argparse import 
ArgumentParser +from pathlib import Path + +import pyrootutils +import torch +from loguru import logger + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from fish_speech.inference_engine import TTSInferenceEngine +from fish_speech.models.dac.inference import load_model as load_decoder_model +from fish_speech.models.text2semantic.inference import launch_thread_safe_queue +from fish_speech.utils.schema import ServeTTSRequest +from tools.webui import build_app +from tools.webui.inference import get_inference_wrapper + +# Make einx happy +os.environ["EINX_FILTER_TRACEBACK"] = "false" + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--llama-checkpoint-path", + type=Path, + default="checkpoints/openaudio-s1-mini", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=Path, + default="checkpoints/openaudio-s1-mini/codec.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-gradio-length", type=int, default=0) + parser.add_argument("--theme", type=str, default="light") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + args.precision = torch.half if args.half else torch.bfloat16 + + # Check if MPS or CUDA is available + if torch.backends.mps.is_available(): + args.device = "mps" + logger.info("mps is available, running on mps.") + elif not torch.cuda.is_available(): + logger.info("CUDA is not available, running on CPU.") + args.device = "cpu" + + logger.info("Loading Llama model...") + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + + logger.info("Loading VQ-GAN model...") + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("Decoder model loaded, warming up...") + + # Create the inference engine + inference_engine = TTSInferenceEngine( + llama_queue=llama_queue, + decoder_model=decoder_model, + compile=args.compile, + precision=args.precision, + ) + + # Dry run to check if the model is loaded correctly and avoid the first-time latency + list( + inference_engine.inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.5, + temperature=0.7, + format="wav", + ) + ) + ) + + logger.info("Warming up done, launching the web UI...") + + # Get the inference function with the immutable arguments + inference_fct = get_inference_wrapper(inference_engine) + + app = build_app(inference_fct, args.theme) + app.launch(show_api=True) diff --git a/tools/server/api_utils.py b/tools/server/api_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe990041bceac39a6b85d97fc019673af3d9013 --- /dev/null +++ b/tools/server/api_utils.py @@ -0,0 +1,76 @@ +from argparse import ArgumentParser +from http import HTTPStatus +from typing import Annotated, Any + +import ormsgpack +from baize.datastructures import ContentType +from kui.asgi import HTTPException, HttpRequest + +from fish_speech.inference_engine import TTSInferenceEngine +from fish_speech.utils.schema import ServeTTSRequest +from tools.server.inference import 
inference_wrapper as inference + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts") + parser.add_argument("--load-asr-model", action="store_true") + parser.add_argument( + "--llama-checkpoint-path", + type=str, + default="checkpoints/openaudio-s1-mini", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=str, + default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-text-length", type=int, default=0) + parser.add_argument("--listen", type=str, default="127.0.0.1:8080") + parser.add_argument("--workers", type=int, default=1) + parser.add_argument("--api-key", type=str, default=None) + + return parser.parse_args() + + +class MsgPackRequest(HttpRequest): + async def data( + self, + ) -> Annotated[ + Any, ContentType("application/msgpack"), ContentType("application/json") + ]: + if self.content_type == "application/msgpack": + return ormsgpack.unpackb(await self.body) + + elif self.content_type == "application/json": + return await self.json + + raise HTTPException( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + headers={"Accept": "application/msgpack, application/json"}, + ) + + +async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine): + for chunk in inference(req, engine): + if isinstance(chunk, bytes): + yield chunk + + +async def buffer_to_async_generator(buffer): + yield buffer + + +def get_content_type(audio_format): + if audio_format == "wav": + return "audio/wav" + elif audio_format == "flac": + return "audio/flac" + elif audio_format == "mp3": + return "audio/mpeg" + else: + return "application/octet-stream" diff --git a/tools/server/exception_handler.py b/tools/server/exception_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..07d595fabb7af4e00a1fb67a78b466fea0c2c0f4 --- /dev/null +++ b/tools/server/exception_handler.py @@ -0,0 +1,27 @@ +import traceback +from http import HTTPStatus + +from kui.asgi import HTTPException, JSONResponse + + +class ExceptionHandler: + + async def http_exception_handler(self, exc: HTTPException): + return JSONResponse( + dict( + statusCode=exc.status_code, + message=exc.content, + error=HTTPStatus(exc.status_code).phrase, + ), + exc.status_code, + exc.headers, + ) + + async def other_exception_handler(self, exc: Exception): + traceback.print_exc() + + status = HTTPStatus.INTERNAL_SERVER_ERROR + return JSONResponse( + dict(statusCode=status, message=str(exc), error=status.phrase), + status, + ) diff --git a/tools/server/inference.py b/tools/server/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..060e24b968d7c7b04f8f99f302cff0773b1ecdd8 --- /dev/null +++ b/tools/server/inference.py @@ -0,0 +1,45 @@ +from http import HTTPStatus + +import numpy as np +from kui.asgi import HTTPException + +from fish_speech.inference_engine import TTSInferenceEngine +from fish_speech.utils.schema import ServeTTSRequest + +AMPLITUDE = 32768 # Needs an explaination + + +def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine): + """ + Wrapper for the inference function. + Used in the API server. 
+ """ + count = 0 + for result in engine.inference(req): + match result.code: + case "header": + if isinstance(result.audio, tuple): + yield result.audio[1] + + case "error": + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content=str(result.error), + ) + + case "segment": + count += 1 + if isinstance(result.audio, tuple): + yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes() + + case "final": + count += 1 + if isinstance(result.audio, tuple): + yield result.audio[1] + return None # Stop the generator + + if count == 0: + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content="No audio generated, please check the input text.", + ) diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..77540a21e5559d38da0fcc92f68445a118f210af --- /dev/null +++ b/tools/server/model_manager.py @@ -0,0 +1,122 @@ +import torch +from funasr import AutoModel +from loguru import logger + +from fish_speech.inference_engine import TTSInferenceEngine +from fish_speech.models.dac.inference import load_model as load_decoder_model +from fish_speech.models.text2semantic.inference import ( + launch_thread_safe_queue, + launch_thread_safe_queue_agent, +) +from fish_speech.utils.schema import ServeTTSRequest +from tools.server.inference import inference_wrapper as inference + +ASR_MODEL_NAME = "iic/SenseVoiceSmall" + + +class ModelManager: + def __init__( + self, + mode: str, + device: str, + half: bool, + compile: bool, + asr_enabled: bool, + llama_checkpoint_path: str, + decoder_checkpoint_path: str, + decoder_config_name: str, + ) -> None: + + self.mode = mode + self.device = device + self.half = half + self.compile = compile + + self.precision = torch.half if half else torch.bfloat16 + + # Check if MPS or CUDA is available + if torch.backends.mps.is_available(): + self.device = "mps" + logger.info("mps is available, running on mps.") + elif not torch.cuda.is_available(): + self.device = "cpu" + logger.info("CUDA is not available, running on CPU.") + + # Load the ASR model if enabled + if asr_enabled: + self.load_asr_model(self.device) + + # Load the TTS models + self.load_llama_model( + llama_checkpoint_path, self.device, self.precision, self.compile, self.mode + ) + self.load_decoder_model( + decoder_config_name, decoder_checkpoint_path, self.device + ) + self.tts_inference_engine = TTSInferenceEngine( + llama_queue=self.llama_queue, + decoder_model=self.decoder_model, + precision=self.precision, + compile=self.compile, + ) + + # Warm up the models + if self.mode == "tts": + self.warm_up(self.tts_inference_engine) + + def load_asr_model(self, device, hub="ms") -> None: + self.asr_model = AutoModel( + model=ASR_MODEL_NAME, + device=device, + disable_pbar=True, + hub=hub, + ) + logger.info("ASR model loaded.") + + def load_llama_model( + self, checkpoint_path, device, precision, compile, mode + ) -> None: + + if mode == "tts": + self.llama_queue = launch_thread_safe_queue( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=compile, + ) + elif mode == "agent": + self.llama_queue, self.tokenizer, self.config = ( + launch_thread_safe_queue_agent( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=compile, + ) + ) + else: + raise ValueError(f"Invalid mode: {mode}") + + logger.info("LLAMA model loaded.") + + def load_decoder_model(self, config_name, checkpoint_path, device) -> None: + self.decoder_model = load_decoder_model( + 
config_name=config_name, + checkpoint_path=checkpoint_path, + device=device, + ) + logger.info("Decoder model loaded.") + + def warm_up(self, tts_inference_engine) -> None: + request = ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.2, + temperature=0.7, + format="wav", + ) + list(inference(request, tts_inference_engine)) + logger.info("Models warmed up.") diff --git a/tools/server/model_utils.py b/tools/server/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..822e2f984b2fd6b6ec8d8ed4d40d5f73a05c1039 --- /dev/null +++ b/tools/server/model_utils.py @@ -0,0 +1,129 @@ +import io +import re + +import librosa +import torch +import torchaudio +from cachetools import LRUCache, cached + +CACHE_MAXSIZE = 10000 +MICRO_BATCH_SIZE = 8 +ASR_SAMPLE_RATE = 16000 +HUGE_GAP_THRESHOLD = 4000 + + +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def batch_encode(model, audios_list: list[bytes]): + audios: list[torch.Tensor] = [ + ( + torch.from_numpy( + librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0] + )[None] + if isinstance(audio, bytes) + else audio + ) + for audio in audios_list + ] + + lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device) + max_length = lengths.max().item() + + print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s") + + padded = torch.stack( + [ + torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1]))) + for audio in audios + ] + ).to(model.device) + + features, feature_lengths = model.encode(padded, audio_lengths=lengths) + features, feature_lengths = features.cpu(), feature_lengths.cpu() + + return [feature[..., :length] for feature, length in zip(features, feature_lengths)] + + +@cached( + cache=LRUCache(maxsize=CACHE_MAXSIZE), + key=lambda model, audios: (model.device, tuple(audios)), +) +def cached_vqgan_batch_encode(model, audios: list[bytes]): + return batch_encode(model, audios) + + +@torch.no_grad() +@torch.autocast(device_type="cuda", dtype=torch.half) +def batch_vqgan_decode(model, features): + lengths = torch.tensor( + [feature.shape[-1] for feature in features], device=model.device + ) + max_length = lengths.max().item() + padded = torch.stack( + [ + torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1])) + for feature in features + ] + ).to(model.device) + + # If bs too large, we do micro batch decode + audios, audio_lengths = [], [] + for i in range(0, padded.shape[0], MICRO_BATCH_SIZE): + audio, audio_length = model.decode( + padded[i : i + MICRO_BATCH_SIZE], + feature_lengths=lengths[i : i + MICRO_BATCH_SIZE], + ) + audios.append(audio) + audio_lengths.append(audio_length) + audios = torch.cat(audios, dim=0) + audio_lengths = torch.cat(audio_lengths, dim=0) + audios, audio_lengths = audios.cpu(), audio_lengths.cpu() + + return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)] + + +@torch.no_grad() +def batch_asr(model, lock, audios, sr, language="auto"): + resampled_audios = [] + for audio in audios: + audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE) + assert audio.ndim == 1 + resampled_audios.append(audio) + + with lock: + res = model.generate( + input=resampled_audios, + batch_size=len(resampled_audios), + language=language, + use_itn=True, + ) + + results = [] + for r, audio in zip(res, audios): + text = r["text"] + text = re.sub(r"<\|.*?\|>", "", 
text) + duration = len(audio) / sr * 1000 + huge_gap = False + + if "timestamp" in r and len(r["timestamp"]) > 2: + for timestamp_a, timestamp_b in zip( + r["timestamp"][:-1], r["timestamp"][1:] + ): + # If there is a gap of more than 4 seconds, we consider it as a huge gap + if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD: + huge_gap = True + break + + # Doesn't make sense to have a huge gap at the end + if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD: + huge_gap = True + + results.append( + { + "text": text, + "duration": duration, + "huge_gap": huge_gap, + } + ) + + return results diff --git a/tools/server/views.py b/tools/server/views.py new file mode 100644 index 0000000000000000000000000000000000000000..905f6adbe0d96a12fac0170fdda5362bdb688eb8 --- /dev/null +++ b/tools/server/views.py @@ -0,0 +1,213 @@ +import io +import os +import time +from http import HTTPStatus + +import numpy as np +import ormsgpack +import soundfile as sf +import torch +from kui.asgi import ( + Body, + HTTPException, + HttpView, + JSONResponse, + Routes, + StreamResponse, + request, +) +from loguru import logger +from typing_extensions import Annotated + +from fish_speech.utils.schema import ( + ServeASRRequest, + ServeASRResponse, + ServeChatRequest, + ServeTTSRequest, + ServeVQGANDecodeRequest, + ServeVQGANDecodeResponse, + ServeVQGANEncodeRequest, + ServeVQGANEncodeResponse, +) +from tools.server.agent import get_response_generator +from tools.server.api_utils import ( + buffer_to_async_generator, + get_content_type, + inference_async, +) +from tools.server.inference import inference_wrapper as inference +from tools.server.model_manager import ModelManager +from tools.server.model_utils import ( + batch_asr, + batch_vqgan_decode, + cached_vqgan_batch_encode, +) + +MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1)) + +routes = Routes() + + +@routes.http("/v1/health") +class Health(HttpView): + @classmethod + async def get(cls): + return JSONResponse({"status": "ok"}) + + @classmethod + async def post(cls): + return JSONResponse({"status": "ok"}) + + +@routes.http.post("/v1/vqgan/encode") +async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]): + # Get the model from the app + model_manager: ModelManager = request.app.state.model_manager + decoder_model = model_manager.decoder_model + + # Encode the audio + start_time = time.time() + tokens = cached_vqgan_batch_encode(decoder_model, req.audios) + logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms") + + # Return the response + return ormsgpack.packb( + ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +@routes.http.post("/v1/vqgan/decode") +async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]): + # Get the model from the app + model_manager: ModelManager = request.app.state.model_manager + decoder_model = model_manager.decoder_model + + # Decode the audio + tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens] + start_time = time.time() + audios = batch_vqgan_decode(decoder_model, tokens) + logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms") + audios = [audio.astype(np.float16).tobytes() for audio in audios] + + # Return the response + return ormsgpack.packb( + ServeVQGANDecodeResponse(audios=audios), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +@routes.http.post("/v1/asr") +async def asr(req: Annotated[ServeASRRequest, 
Body(exclusive=True)]): + # Get the model from the app + model_manager: ModelManager = request.app.state.model_manager + asr_model = model_manager.asr_model + lock = request.app.state.lock + + # Perform ASR + start_time = time.time() + audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios] + audios = [torch.from_numpy(audio).float() for audio in audios] + + if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios): + raise HTTPException(status_code=400, content="Audio length is too long") + + transcriptions = batch_asr( + asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language + ) + logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms") + + # Return the response + return ormsgpack.packb( + ServeASRResponse(transcriptions=transcriptions), + option=ormsgpack.OPT_SERIALIZE_PYDANTIC, + ) + + +@routes.http.post("/v1/tts") +async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]): + # Get the model from the app + app_state = request.app.state + model_manager: ModelManager = app_state.model_manager + engine = model_manager.tts_inference_engine + sample_rate = engine.decoder_model.spec_transform.sample_rate + + # Check if the text is too long + if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Text is too long, max length is {app_state.max_text_length}", + ) + + # Check if streaming is enabled + if req.streaming and req.format != "wav": + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content="Streaming only supports WAV format", + ) + + # Perform TTS + if req.streaming: + return StreamResponse( + iterable=inference_async(req, engine), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + else: + fake_audios = next(inference(req, engine)) + buffer = io.BytesIO() + sf.write( + buffer, + fake_audios, + sample_rate, + format=req.format, + ) + + return StreamResponse( + iterable=buffer_to_async_generator(buffer.getvalue()), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + + +@routes.http.post("/v1/chat") +async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]): + # Check that the number of samples requested is correct + if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}", + ) + + # Get the type of content provided + content_type = request.headers.get("Content-Type", "application/json") + json_mode = "application/json" in content_type + + # Get the models from the app + model_manager: ModelManager = request.app.state.model_manager + llama_queue = model_manager.llama_queue + tokenizer = model_manager.tokenizer + config = model_manager.config + + device = request.app.state.device + + # Get the response generators + response_generator = get_response_generator( + llama_queue, tokenizer, config, req, device, json_mode + ) + + # Return the response in the correct format + if req.streaming is False: + result = response_generator() + if json_mode: + return JSONResponse(result.model_dump()) + else: + return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC) + + return StreamResponse( + iterable=response_generator(), content_type="text/event-stream" + ) diff --git a/tools/smart_pad.py b/tools/smart_pad.py index 
6ce8c4d8dd0fd63e8039822adb4424a38d8e80fe..2d52323affb189659e6418b2f6e69eb1c4268864 100644 --- a/tools/smart_pad.py +++ b/tools/smart_pad.py @@ -1,60 +1,60 @@ -import random -from multiprocessing import Pool -from pathlib import Path - -import click -import librosa -import torch.nn.functional as F -import torchaudio -from tqdm import tqdm - -from tools.file import AUDIO_EXTENSIONS, list_files - -threshold = 10 ** (-50 / 20.0) - - -def process(file): - waveform, sample_rate = torchaudio.load(str(file), backend="sox") - if waveform.size(0) > 1: - waveform = waveform.mean(dim=0, keepdim=True) - - loudness = librosa.feature.rms( - y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True - )[0] - - for i in range(len(loudness) - 1, 0, -1): - if loudness[i] > threshold: - break - - end_silent_time = (len(loudness) - i) * 512 / sample_rate - - if end_silent_time <= 0.3: - random_time = random.uniform(0.3, 0.7) - end_silent_time - waveform = F.pad( - waveform, (0, int(random_time * sample_rate)), mode="constant", value=0 - ) - - for i in range(len(loudness)): - if loudness[i] > threshold: - break - - start_silent_time = i * 512 / sample_rate - - if start_silent_time > 0.02: - waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :] - - torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate) - - -@click.command() -@click.argument("source", type=Path) -@click.option("--num-workers", type=int, default=12) -def main(source, num_workers): - files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True)) - - with Pool(num_workers) as p: - list(tqdm(p.imap_unordered(process, files), total=len(files))) - - -if __name__ == "__main__": - main() +import random +from multiprocessing import Pool +from pathlib import Path + +import click +import librosa +import torch.nn.functional as F +import torchaudio +from tqdm import tqdm + +from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files + +threshold = 10 ** (-50 / 20.0) + + +def process(file): + waveform, sample_rate = torchaudio.load(str(file), backend="sox") + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + loudness = librosa.feature.rms( + y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True + )[0] + + for i in range(len(loudness) - 1, 0, -1): + if loudness[i] > threshold: + break + + end_silent_time = (len(loudness) - i) * 512 / sample_rate + + if end_silent_time <= 0.3: + random_time = random.uniform(0.3, 0.7) - end_silent_time + waveform = F.pad( + waveform, (0, int(random_time * sample_rate)), mode="constant", value=0 + ) + + for i in range(len(loudness)): + if loudness[i] > threshold: + break + + start_silent_time = i * 512 / sample_rate + + if start_silent_time > 0.02: + waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :] + + torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate) + + +@click.command() +@click.argument("source", type=Path) +@click.option("--num-workers", type=int, default=12) +def main(source, num_workers): + files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True)) + + with Pool(num_workers) as p: + list(tqdm(p.imap_unordered(process, files), total=len(files))) + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py index bdb914c2a3b13d13a79882be3e7b33027ce5109a..977afdf3260994ef31d2189a5973a2628b26c0c5 100644 --- a/tools/vqgan/create_train_split.py +++ b/tools/vqgan/create_train_split.py @@ -1,83 +1,83 @@ -import 
math -from pathlib import Path -from random import Random - -import click -from loguru import logger -from pydub import AudioSegment -from tqdm import tqdm - -from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist - - -@click.command() -@click.argument("root", type=click.Path(exists=True, path_type=Path)) -@click.option("--val-ratio", type=float, default=None) -@click.option("--val-count", type=int, default=None) -@click.option("--filelist", default=None, type=Path) -@click.option("--min-duration", default=None, type=float) -@click.option("--max-duration", default=None, type=float) -def main(root, val_ratio, val_count, filelist, min_duration, max_duration): - if filelist: - files = [i[0] for i in load_filelist(filelist)] - else: - files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) - - if min_duration is None and max_duration is None: - filtered_files = list(map(str, [file.relative_to(root) for file in files])) - else: - filtered_files = [] - for file in tqdm(files): - try: - audio = AudioSegment.from_file(str(file)) - duration = len(audio) / 1000.0 - - if min_duration is not None and duration < min_duration: - logger.info( - f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}" - ) - continue - - if max_duration is not None and duration > max_duration: - logger.info( - f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}" - ) - continue - - filtered_files.append(str(file.relative_to(root))) - except Exception as e: - logger.info(f"Error processing {file}: {e}") - - logger.info( - f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering" - ) - - Random(42).shuffle(filtered_files) - - if val_count is None and val_ratio is None: - logger.info("Validation ratio and count not specified, using min(20%, 100)") - val_size = min(100, math.ceil(len(filtered_files) * 0.2)) - elif val_count is not None and val_ratio is not None: - logger.error("Cannot specify both val_count and val_ratio") - return - elif val_count is not None: - if val_count < 1 or val_count > len(filtered_files): - logger.error("val_count must be between 1 and number of files") - return - val_size = val_count - else: - val_size = math.ceil(len(filtered_files) * val_ratio) - - logger.info(f"Using {val_size} files for validation") - - with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: - f.write("\n".join(filtered_files[val_size:])) - - with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: - f.write("\n".join(filtered_files[:val_size])) - - logger.info("Done") - - -if __name__ == "__main__": - main() +import math +from pathlib import Path +from random import Random + +import click +from loguru import logger +from pydub import AudioSegment +from tqdm import tqdm + +from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist + + +@click.command() +@click.argument("root", type=click.Path(exists=True, path_type=Path)) +@click.option("--val-ratio", type=float, default=None) +@click.option("--val-count", type=int, default=None) +@click.option("--filelist", default=None, type=Path) +@click.option("--min-duration", default=None, type=float) +@click.option("--max-duration", default=None, type=float) +def main(root, val_ratio, val_count, filelist, min_duration, max_duration): + if filelist: + files = [i[0] for i in load_filelist(filelist)] + else: + files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) + + if min_duration is None and max_duration is None: + filtered_files = list(map(str, 
[file.relative_to(root) for file in files])) + else: + filtered_files = [] + for file in tqdm(files): + try: + audio = AudioSegment.from_file(str(file)) + duration = len(audio) / 1000.0 + + if min_duration is not None and duration < min_duration: + logger.info( + f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}" + ) + continue + + if max_duration is not None and duration > max_duration: + logger.info( + f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}" + ) + continue + + filtered_files.append(str(file.relative_to(root))) + except Exception as e: + logger.info(f"Error processing {file}: {e}") + + logger.info( + f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering" + ) + + Random(42).shuffle(filtered_files) + + if val_count is None and val_ratio is None: + logger.info("Validation ratio and count not specified, using min(20%, 100)") + val_size = min(100, math.ceil(len(filtered_files) * 0.2)) + elif val_count is not None and val_ratio is not None: + logger.error("Cannot specify both val_count and val_ratio") + return + elif val_count is not None: + if val_count < 1 or val_count > len(filtered_files): + logger.error("val_count must be between 1 and number of files") + return + val_size = val_count + else: + val_size = math.ceil(len(filtered_files) * val_ratio) + + logger.info(f"Using {val_size} files for validation") + + with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: + f.write("\n".join(filtered_files[val_size:])) + + with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: + f.write("\n".join(filtered_files[:val_size])) + + logger.info("Done") + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py index bccc721a7f5d5cb68596df62d7b4c629814538c5..e621eaddac7215cb68767d138f748bc40d29d67d 100644 --- a/tools/vqgan/extract_vq.py +++ b/tools/vqgan/extract_vq.py @@ -1,233 +1,232 @@ -import os -import subprocess as sp -import sys -import time -from datetime import timedelta -from functools import lru_cache -from pathlib import Path -from random import Random - -import click -import numpy as np -import torch -import torchaudio -from hydra import compose, initialize -from hydra.utils import instantiate -from lightning import LightningModule -from loguru import logger -from omegaconf import OmegaConf - -from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist - -# register eval resolver -OmegaConf.register_new_resolver("eval", eval) -# This file is used to convert the audio files to text files using the Whisper model. -# It's mainly used to generate the training data for the VQ model. 
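After running the create_train_split.py script above, the dataset root contains vq_train_filelist.txt and vq_val_filelist.txt with paths relative to that root. A quick sanity check of the split might look like this (the dataset path is hypothetical):

```python
from pathlib import Path

root = Path("data/my_dataset")  # hypothetical root passed to create_train_split.py
train = (root / "vq_train_filelist.txt").read_text(encoding="utf-8").splitlines()
val = (root / "vq_val_filelist.txt").read_text(encoding="utf-8").splitlines()

# The two lists should be disjoint and together cover every file that survived filtering.
assert not set(train) & set(val)
print(f"{len(train)} training files, {len(val)} validation files")
```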
- -backends = torchaudio.list_audio_backends() - -if "ffmpeg" in backends: - backend = "ffmpeg" -else: - backend = "soundfile" - -RANK = int(os.environ.get("SLURM_PROCID", 0)) -WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1)) - -logger_format = ( - "{time:YYYY-MM-DD HH:mm:ss.SSS} | " - "{level: <8} | " - "{name}:{function}:{line} | " - "{extra[rank]} - {message}" -) -logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"}) -logger.remove() -logger.add(sys.stderr, format=logger_format) - - -@lru_cache(maxsize=1) -def get_model( - config_name: str = "firefly_gan_vq", - checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", - device: str | torch.device = "cuda", -): - with initialize(version_base="1.3", config_path="../../fish_speech/configs"): - cfg = compose(config_name=config_name) - - model = instantiate(cfg) - state_dict = torch.load( - checkpoint_path, - map_location=device, - ) - if "state_dict" in state_dict: - state_dict = state_dict["state_dict"] - - if any("generator" in k for k in state_dict): - state_dict = { - k.replace("generator.", ""): v - for k, v in state_dict.items() - if "generator." in k - } - - model.load_state_dict(state_dict, strict=False) - model.eval() - model.to(device) - - logger.info(f"Loaded model") - return model - - -@torch.inference_mode() -def process_batch(files: list[Path], model) -> float: - wavs = [] - audio_lengths = [] - new_files = [] - max_length = total_time = 0 - - for file in files: - try: - wav, sr = torchaudio.load( - str(file), backend=backend - ) # Need to install libsox-dev - except Exception as e: - logger.error(f"Error reading {file}: {e}") - continue - - if wav.shape[0] > 1: - wav = wav.mean(dim=0, keepdim=True) - - wav = torchaudio.functional.resample( - wav.cuda(), sr, model.spec_transform.sample_rate - )[0] - total_time += len(wav) / model.spec_transform.sample_rate - max_length = max(max_length, len(wav)) - - wavs.append(wav) - audio_lengths.append(len(wav)) - new_files.append(file) - - files = new_files - - # Pad to max length - for i, wav in enumerate(wavs): - wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant") - - audios = torch.stack(wavs, dim=0)[:, None] - audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long) - - # Calculate lengths - indices, feature_lengths = model.encode(audios, audio_lengths) - - # Save to disk - outputs = indices.cpu().numpy() - - for file, length, feature, audio_length in zip( - files, feature_lengths, outputs, audio_lengths - ): - feature = feature[:, :length] - - # (T,) - with open(file.with_suffix(".npy"), "wb") as f: - np.save(f, feature) - - return total_time - - -@click.command() -@click.argument("folder") -@click.option("--num-workers", default=1) -@click.option("--config-name", default="firefly_gan_vq") -@click.option( - "--checkpoint-path", - default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", -) -@click.option("--batch-size", default=64) -@click.option("--filelist", default=None, type=Path) -def main( - folder: str, - num_workers: int, - config_name: str, - checkpoint_path: str, - batch_size: int, - filelist: Path, -): - if num_workers > 1 and WORLD_SIZE != num_workers: - assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both" - - logger.info(f"Spawning {num_workers} workers") - - if torch.cuda.is_available(): - visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) - if visible_devices is None: - visible_devices = 
list(range(torch.cuda.device_count())) - else: - visible_devices = visible_devices.split(",") - else: - # Set to empty string to avoid using GPU - visible_devices = [""] - - processes = [] - for i in range(num_workers): - env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)]) - env["SLURM_PROCID"] = str(i) - env["SLURM_NTASKS"] = str(num_workers) - - processes.append( - sp.Popen( - [sys.executable] + sys.argv.copy(), - env=env, - ) - ) - - for p in processes: - p.wait() - - logger.info(f"All workers finished") - return - - # This is a worker - logger.info(f"Starting worker") - if filelist: - files = [i[0] for i in load_filelist(filelist)] - else: - files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False) - - print(f"Found {len(files)} files") - files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()] - - total_files = len(files) - files = files[RANK::WORLD_SIZE] - logger.info(f"Processing {len(files)}/{total_files} files") - - # Batch processing - total_time = 0 - begin_time = time.time() - processed_files = 0 - model = get_model(config_name, checkpoint_path) - - for n_batch, idx in enumerate(range(0, len(files), batch_size)): - batch = files[idx : idx + batch_size] - batch_time = process_batch(batch, model) - - total_time += batch_time - processed_files += len(batch) - - if (n_batch + 1) % 10 == 0: - eta = ( - (time.time() - begin_time) - / processed_files - * (len(files) - processed_files) - ) - logger.info( - f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, " - + f"ETA: {timedelta(seconds=round(eta))}s" - ) - - logger.info( - f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio" - ) - - -if __name__ == "__main__": - main() +import os +import subprocess as sp +import sys +import time +from datetime import timedelta +from functools import lru_cache +from pathlib import Path +from random import Random + +import click +import numpy as np +import torch +import torchaudio +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger +from omegaconf import OmegaConf + +from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) +# This file is used to convert the audio files to text files using the Whisper model. +# It's mainly used to generate the training data for the VQ model. 
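In practice this script encodes each audio file with the codec model and writes the token indices to a sibling .npy file (the Whisper reference in the comment above is historical). Inspecting one output could look like this; the path is hypothetical and the exact array layout depends on the codec config, typically (num_codebooks, num_frames):

```python
import numpy as np

# Token indices produced by extract_vq.py for one clip (hypothetical path).
tokens = np.load("data/speaker_1/clip_0001.npy")
print(tokens.shape, tokens.dtype)  # e.g. (codebooks, frames), integer codes
```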
+ +backends = torchaudio.list_audio_backends() + +if "ffmpeg" in backends: + backend = "ffmpeg" +else: + backend = "soundfile" + +RANK = int(os.environ.get("SLURM_PROCID", 0)) +WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1)) + +logger_format = ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} | " + "{extra[rank]} - {message}" +) +logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"}) +logger.remove() +logger.add(sys.stderr, format=logger_format) + + +@lru_cache(maxsize=1) +def get_model( + config_name: str = "modded_dac_vq", + checkpoint_path: str = "checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + device: str | torch.device = "cuda", +): + with initialize(version_base="1.3", config_path="../../fish_speech/configs"): + cfg = compose(config_name=config_name) + + model = instantiate(cfg) + state_dict = torch.load( + checkpoint_path, + map_location=device, + ) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + if any("generator" in k for k in state_dict): + state_dict = { + k.replace("generator.", ""): v + for k, v in state_dict.items() + if "generator." in k + } + + model.load_state_dict(state_dict, strict=False) + model.eval() + model.to(device) + + logger.info(f"Loaded model") + return model + + +@torch.inference_mode() +def process_batch(files: list[Path], model) -> float: + wavs = [] + audio_lengths = [] + new_files = [] + max_length = total_time = 0 + + for file in files: + try: + wav, sr = torchaudio.load( + str(file), backend=backend + ) # Need to install libsox-dev + except Exception as e: + logger.error(f"Error reading {file}: {e}") + continue + + if wav.shape[0] > 1: + wav = wav.mean(dim=0, keepdim=True) + + wav = torchaudio.functional.resample( + wav.cuda(), sr, model.spec_transform.sample_rate + )[0] + total_time += len(wav) / model.spec_transform.sample_rate + max_length = max(max_length, len(wav)) + + wavs.append(wav) + audio_lengths.append(len(wav)) + new_files.append(file) + + files = new_files + + # Pad to max length + for i, wav in enumerate(wavs): + wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant") + + audios = torch.stack(wavs, dim=0)[:, None] + audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long) + + # Calculate lengths + indices, feature_lengths = model.encode(audios, audio_lengths) + + # Save to disk + outputs = indices.cpu().numpy() + + for file, length, feature, audio_length in zip( + files, feature_lengths, outputs, audio_lengths + ): + feature = feature[:, :length] + + # (T,) + with open(file.with_suffix(".npy"), "wb") as f: + np.save(f, feature) + + return total_time + + +@click.command() +@click.argument("folder") +@click.option("--num-workers", default=1) +@click.option("--config-name", default="modded_dac_vq") +@click.option( + "--checkpoint-path", + default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +) +@click.option("--batch-size", default=64) +@click.option("--filelist", default=None, type=Path) +def main( + folder: str, + num_workers: int, + config_name: str, + checkpoint_path: str, + batch_size: int, + filelist: Path, +): + if num_workers > 1 and WORLD_SIZE != num_workers: + assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both" + + logger.info(f"Spawning {num_workers} workers") + + if torch.cuda.is_available(): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if visible_devices is None: + visible_devices = 
list(range(torch.cuda.device_count())) + else: + visible_devices = visible_devices.split(",") + else: + # Set to empty string to avoid using GPU + visible_devices = [""] + + processes = [] + for i in range(num_workers): + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)]) + env["SLURM_PROCID"] = str(i) + env["SLURM_NTASKS"] = str(num_workers) + + processes.append( + sp.Popen( + [sys.executable] + sys.argv.copy(), + env=env, + ) + ) + + for p in processes: + p.wait() + + logger.info(f"All workers finished") + return + + # This is a worker + logger.info(f"Starting worker") + if filelist: + files = [i[0] for i in load_filelist(filelist)] + else: + files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False) + + print(f"Found {len(files)} files") + files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()] + + total_files = len(files) + files = files[RANK::WORLD_SIZE] + logger.info(f"Processing {len(files)}/{total_files} files") + + # Batch processing + total_time = 0 + begin_time = time.time() + processed_files = 0 + model = get_model(config_name, checkpoint_path) + + for n_batch, idx in enumerate(range(0, len(files), batch_size)): + batch = files[idx : idx + batch_size] + batch_time = process_batch(batch, model) + + total_time += batch_time + processed_files += len(batch) + + if (n_batch + 1) % 10 == 0: + eta = ( + (time.time() - begin_time) + / processed_files + * (len(files) - processed_files) + ) + logger.info( + f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, " + + f"ETA: {timedelta(seconds=round(eta))}s" + ) + + logger.info( + f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio" + ) + + +if __name__ == "__main__": + main() diff --git a/tools/webui/__init__.py b/tools/webui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b9a0259aaf77778709d81994bf9885e1ec4b2f --- /dev/null +++ b/tools/webui/__init__.py @@ -0,0 +1,155 @@ +from typing import Callable + +import gradio as gr + +from fish_speech.i18n import i18n +from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER + + +def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks: + with gr.Blocks(theme=gr.themes.Base()) as app: + gr.Markdown(HEADER_MD) + + # Use light theme by default + app.load( + None, + None, + js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}" + % theme, + ) + + # Inference + with gr.Row(): + with gr.Column(scale=3): + text = gr.Textbox( + label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 + ) + + with gr.Row(): + with gr.Column(): + with gr.Tab(label=i18n("Advanced Config")): + with gr.Row(): + chunk_length = gr.Slider( + label=i18n("Iterative Prompt Length, 0 means off"), + minimum=100, + maximum=400, + value=300, + step=8, + ) + + max_new_tokens = gr.Slider( + label=i18n( + "Maximum tokens per batch, 0 means no limit" + ), + minimum=0, + maximum=2048, + value=0, + step=8, + ) + + with gr.Row(): + top_p = gr.Slider( + label="Top-P", + minimum=0.7, + maximum=0.95, + value=0.8, + step=0.01, + ) + + repetition_penalty = gr.Slider( + label=i18n("Repetition Penalty"), + minimum=1, + maximum=1.2, + value=1.1, + step=0.01, + ) + + with gr.Row(): + temperature = gr.Slider( + label="Temperature", + minimum=0.7, + maximum=1.0, + value=0.8, + step=0.01, + ) + seed = gr.Number( + label="Seed", + 
info="0 means randomized inference, otherwise deterministic", + value=0, + ) + + with gr.Tab(label=i18n("Reference Audio")): + with gr.Row(): + gr.Markdown( + i18n( + "5 to 10 seconds of reference audio, useful for specifying speaker." + ) + ) + with gr.Row(): + reference_id = gr.Textbox( + label=i18n("Reference ID"), + placeholder="Leave empty to use uploaded references", + ) + + with gr.Row(): + use_memory_cache = gr.Radio( + label=i18n("Use Memory Cache"), + choices=["on", "off"], + value="on", + ) + + with gr.Row(): + reference_audio = gr.Audio( + label=i18n("Reference Audio"), + type="filepath", + ) + with gr.Row(): + reference_text = gr.Textbox( + label=i18n("Reference Text"), + lines=1, + placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", + value="", + ) + + with gr.Column(scale=3): + with gr.Row(): + error = gr.HTML( + label=i18n("Error Message"), + visible=True, + ) + with gr.Row(): + audio = gr.Audio( + label=i18n("Generated Audio"), + type="numpy", + interactive=False, + visible=True, + ) + + with gr.Row(): + with gr.Column(scale=3): + generate = gr.Button( + value="\U0001f3a7 " + i18n("Generate"), + variant="primary", + ) + + # Submit + generate.click( + inference_fct, + [ + text, + reference_id, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + seed, + use_memory_cache, + ], + [audio, error], + concurrency_limit=1, + ) + + return app diff --git a/tools/webui/inference.py b/tools/webui/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cd1d75599ad38b757b57aaf33ee32b62936338 --- /dev/null +++ b/tools/webui/inference.py @@ -0,0 +1,89 @@ +import html +from functools import partial +from typing import Any, Callable + +from fish_speech.i18n import i18n +from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest + + +def inference_wrapper( + text, + reference_id, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + seed, + use_memory_cache, + engine, +): + """ + Wrapper for the inference function. + Used in the Gradio interface. + """ + + if reference_audio: + references = get_reference_audio(reference_audio, reference_text) + else: + references = [] + + req = ServeTTSRequest( + text=text, + reference_id=reference_id if reference_id else None, + references=references, + max_new_tokens=max_new_tokens, + chunk_length=chunk_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + seed=int(seed) if seed else None, + use_memory_cache=use_memory_cache, + ) + + for result in engine.inference(req): + match result.code: + case "final": + return result.audio, None + case "error": + return None, build_html_error_message(i18n(result.error)) + case _: + pass + + return None, i18n("No audio generated") + + +def get_reference_audio(reference_audio: str, reference_text: str) -> list: + """ + Get the reference audio bytes. + """ + + with open(reference_audio, "rb") as audio_file: + audio_bytes = audio_file.read() + + return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)] + + +def build_html_error_message(error: Any) -> str: + + error = error if isinstance(error, Exception) else Exception("Unknown error") + + return f""" +
+    <div style="color: red; font-weight: bold;">
+        {html.escape(str(error))}
+    </div>
+ """ + + +def get_inference_wrapper(engine) -> Callable: + """ + Get the inference function with the immutable arguments. + """ + + return partial( + inference_wrapper, + engine=engine, + ) diff --git a/tools/webui/variables.py b/tools/webui/variables.py new file mode 100644 index 0000000000000000000000000000000000000000..db42d5d797e821e9a34832bea4344ecb726c97db --- /dev/null +++ b/tools/webui/variables.py @@ -0,0 +1,14 @@ +from fish_speech.i18n import i18n + +HEADER_MD = f"""# Fish Speech + +{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")} + +{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")} + +{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")} + +{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")} +""" + +TEXTBOX_PLACEHOLDER = i18n("Put your text here.") diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py index 8ac720d2dad94fca5087852928f7e6ce57d1b92d..1c34efbe44928def767305d77882f1fcb0c7eccb 100644 --- a/tools/whisper_asr.py +++ b/tools/whisper_asr.py @@ -1,176 +1,176 @@ -""" -Used to transcribe all audio files in one folder into another folder. -e.g. -Directory structure: ---pre_data_root -----SP_1 -------01.wav -------02.wav -------...... -----SP_2 -------01.wav -------02.wav -------...... -Use -python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1 -to transcribe the first speaker. - -Use -python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2 -to transcribe the second speaker. - -Note: Be aware of your audio sample rate, which defaults to 44.1kHz. 
-""" - -import re -from pathlib import Path - -import click -import soundfile as sf -from faster_whisper import WhisperModel -from loguru import logger -from pydub import AudioSegment -from tqdm import tqdm - -from tools.file import AUDIO_EXTENSIONS, list_files - - -@click.command() -@click.option("--model-size", default="large-v3", help="Size of the Whisper model") -@click.option( - "--compute-type", - default="float16", - help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]", -) -@click.option("--audio-dir", required=True, help="Directory containing audio files") -@click.option( - "--save-dir", required=True, help="Directory to save processed audio files" -) -@click.option( - "--sample-rate", - default=44100, - type=int, - help="Output sample rate, default to input sample rate", -) -@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") -@click.option("--language", default="auto", help="Language of the transcription") -@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing") -def main( - model_size, - compute_type, - audio_dir, - save_dir, - sample_rate, - device, - language, - initial_prompt, -): - logger.info("Loading / Downloading Faster Whisper model...") - - model = WhisperModel( - model_size, - device=device, - compute_type=compute_type, - download_root="faster_whisper", - ) - - logger.info("Model loaded.") - - save_path = Path(save_dir) - save_path.mkdir(parents=True, exist_ok=True) - - audio_files = list_files( - path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True - ) - - for file_path in tqdm(audio_files, desc="Processing audio file"): - file_stem = file_path.stem - file_suffix = file_path.suffix - - rel_path = Path(file_path).relative_to(audio_dir) - (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) - - audio = AudioSegment.from_file(file_path) - - segments, info = model.transcribe( - file_path, - beam_size=5, - language=None if language == "auto" else language, - initial_prompt=initial_prompt, - ) - - print( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) - print("Total len(ms): ", len(audio)) - - whole_text = None - for segment in segments: - id, start, end, text = ( - segment.id, - segment.start, - segment.end, - segment.text, - ) - print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text)) - if not whole_text: - whole_text = text - else: - whole_text += ", " + text - - whole_text += "." 
- - audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}" - audio.export(audio_save_path, format=file_suffix[1:]) - print(f"Exported {audio_save_path}") - - transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab" - with open( - transcript_save_path, - "w", - encoding="utf-8", - ) as f: - f.write(whole_text) - - -if __name__ == "__main__": - main() - exit(0) - - audio = AudioSegment.from_wav( - r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav" - ) - - model_size = "large-v3" - - model = WhisperModel( - model_size, - device="cuda", - compute_type="float16", - download_root="faster_whisper", - ) - - segments, info = model.transcribe( - r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav", - beam_size=5, - ) - - print( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) - print("Total len(ms): ", len(audio)) - - for i, segment in enumerate(segments): - print( - "Segment %03d [%.2fs -> %.2fs] %s" - % (i, segment.start, segment.end, segment.text) - ) - start_ms = int(segment.start * 1000) - end_ms = int(segment.end * 1000) - segment_audio = audio[start_ms:end_ms] - segment_audio.export(f"segment_{i:03d}.wav", format="wav") - print(f"Exported segment_{i:03d}.wav") - - print("All segments have been exported.") +""" +Used to transcribe all audio files in one folder into another folder. +e.g. +Directory structure: +--pre_data_root +----SP_1 +------01.wav +------02.wav +------...... +----SP_2 +------01.wav +------02.wav +------...... +Use +python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1 +to transcribe the first speaker. + +Use +python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2 +to transcribe the second speaker. + +Note: Be aware of your audio sample rate, which defaults to 44.1kHz. 
+""" + +import re +from pathlib import Path + +import click +import soundfile as sf +from faster_whisper import WhisperModel +from loguru import logger +from pydub import AudioSegment +from tqdm import tqdm + +from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files + + +@click.command() +@click.option("--model-size", default="large-v3", help="Size of the Whisper model") +@click.option( + "--compute-type", + default="float16", + help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]", +) +@click.option("--audio-dir", required=True, help="Directory containing audio files") +@click.option( + "--save-dir", required=True, help="Directory to save processed audio files" +) +@click.option( + "--sample-rate", + default=44100, + type=int, + help="Output sample rate, default to input sample rate", +) +@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") +@click.option("--language", default="auto", help="Language of the transcription") +@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing") +def main( + model_size, + compute_type, + audio_dir, + save_dir, + sample_rate, + device, + language, + initial_prompt, +): + logger.info("Loading / Downloading Faster Whisper model...") + + model = WhisperModel( + model_size, + device=device, + compute_type=compute_type, + download_root="faster_whisper", + ) + + logger.info("Model loaded.") + + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + + for file_path in tqdm(audio_files, desc="Processing audio file"): + file_stem = file_path.stem + file_suffix = file_path.suffix + + rel_path = Path(file_path).relative_to(audio_dir) + (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) + + audio = AudioSegment.from_file(file_path) + + segments, info = model.transcribe( + file_path, + beam_size=5, + language=None if language == "auto" else language, + initial_prompt=initial_prompt, + ) + + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + print("Total len(ms): ", len(audio)) + + whole_text = None + for segment in segments: + id, start, end, text = ( + segment.id, + segment.start, + segment.end, + segment.text, + ) + print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text)) + if not whole_text: + whole_text = text + else: + whole_text += ", " + text + + whole_text += "." 
+ + audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}" + audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}") + + transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab" + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(whole_text) + + +if __name__ == "__main__": + main() + exit(0) + + audio = AudioSegment.from_wav( + r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav" + ) + + model_size = "large-v3" + + model = WhisperModel( + model_size, + device="cuda", + compute_type="float16", + download_root="faster_whisper", + ) + + segments, info = model.transcribe( + r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav", + beam_size=5, + ) + + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + print("Total len(ms): ", len(audio)) + + for i, segment in enumerate(segments): + print( + "Segment %03d [%.2fs -> %.2fs] %s" + % (i, segment.start, segment.end, segment.text) + ) + start_ms = int(segment.start * 1000) + end_ms = int(segment.end * 1000) + segment_audio = audio[start_ms:end_ms] + segment_audio.export(f"segment_{i:03d}.wav", format="wav") + print(f"Exported segment_{i:03d}.wav") + + print("All segments have been exported.")
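
A few usage sketches for the tools patched above; these are illustrative only, and every path or name not shown in the patch is a placeholder. The VQ-extraction tool shards its work by reading SLURM-style environment variables and slicing the file list with files[RANK::WORLD_SIZE]. A minimal, self-contained sketch of that interleaved-slice pattern:

# Sketch of the RANK/WORLD_SIZE sharding used by the VQ-extraction tool above.
# The file names here are placeholders for illustration only.
import os

RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))

files = [f"clip_{i:04d}.wav" for i in range(10)]
shard = files[RANK::WORLD_SIZE]  # each rank takes a disjoint, interleaved slice
print(f"rank {RANK}/{WORLD_SIZE} processes {len(shard)} of {len(files)} files")

Because the slice is deterministic, spawning the same script with different SLURM_PROCID values (as the launcher above does via subprocess) guarantees every file is handled exactly once.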
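
tools/webui/inference.py freezes the engine argument with functools.partial so that the callback handed to Gradio only exposes the per-request inputs. The toy sketch below shows that pattern with a DummyEngine stand-in; it does not construct a real TTSInferenceEngine.

# Toy sketch of the partial() pattern from tools/webui/inference.py.
# DummyEngine is a placeholder; the real code uses TTSInferenceEngine.
from functools import partial


class DummyEngine:
    def synthesize(self, text: str) -> str:
        return f"audio for {text!r}"


def inference_wrapper(text, seed, engine):
    # The real wrapper builds a ServeTTSRequest and iterates engine.inference();
    # this stand-in just returns an (audio, error) pair with the same shape.
    return engine.synthesize(text), None


webui_fn = partial(inference_wrapper, engine=DummyEngine())
audio, error = webui_fn("Hello world", 42)
print(audio, error)

Keeping the engine out of the Gradio signature is what lets build_app() wire the same component list to generate.click() without ever knowing about the inference backend.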
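
tools/whisper_asr.py writes one .lab transcript per clip by joining the Faster-Whisper segments with commas and appending a period. A trimmed sketch of that flow for a single file; the path is a placeholder, and the CPU/int8 settings are assumptions chosen so the example can run without a GPU:

# Single-file sketch of the transcription flow in tools/whisper_asr.py.
# "data/SP_1/01.wav" is a placeholder path.
from pathlib import Path

from faster_whisper import WhisperModel

audio_path = Path("data/SP_1/01.wav")
model = WhisperModel("large-v3", device="cpu", compute_type="int8")
segments, info = model.transcribe(str(audio_path), beam_size=5)

texts = [seg.text.strip() for seg in segments]
whole_text = (", ".join(texts) + ".") if texts else ""
audio_path.with_suffix(".lab").write_text(whole_text, encoding="utf-8")
print(f"{info.language} ({info.language_probability:.2f}): {whole_text}")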
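
Once both tools have run over a speaker folder, each clip should sit next to a .npy file of VQ token ids and a .lab transcript. A small sketch for sanity-checking that pairing; the folder is a placeholder following the docstring's layout, and the exact token-array shape depends on the codec config (typically codebooks by frames):

# Sanity-check sketch: pair each extracted token file with its transcript.
# "data/SP_1" is a placeholder folder.
from pathlib import Path

import numpy as np

for npy_path in Path("data/SP_1").rglob("*.npy"):
    lab_path = npy_path.with_suffix(".lab")
    tokens = np.load(npy_path)  # int array of VQ token ids saved by the extractor
    text = lab_path.read_text(encoding="utf-8") if lab_path.exists() else "<missing>"
    print(npy_path.name, tokens.shape, text[:40])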