""" |
|
This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to |
|
ONNX. |
|
|
|
Usage: |
|
|
|
python3 -m zipvoice.bin.onnx_export \ |
|
--model-name zipvoice \ |
|
--token-file data/tokens_emilia.txt \ |
|
--checkpoint exp/zipvoice/epoch-11-avg-4.pt \ |
|
--model-config conf/zipvoice_base.json \ |
|
--onnx-model-dir exp/zipvoice_onnx |
|
|
|
`--model-name` can be `zipvoice` or `zipvoice_distill`, |
|
which are the models before and after distillation, respectively. |
|
""" |

import argparse
import json
import os
from typing import Dict

import onnx
import safetensors.torch
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import SimpleTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--onnx-model-dir",
        type=str,
        default="exp",
        help="Directory to save the exported models.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model to be exported.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default="data/tokens_emilia.txt",
        help="The file that maps tokens to ids, i.e. a text file with "
        "'{token}\t{token_id}' per line.",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default="exp_zipvoice/epoch-11-avg-4.pt",
        help="The model checkpoint.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default="conf/zipvoice_base.json",
        help="The model configuration file.",
    )

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add metadata to an ONNX model. The model file is changed in place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)
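
# A minimal sketch of reading the metadata back at load time with onnxruntime
# (hypothetical path; custom_metadata_map is a plain dict of strings):
#
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("exp/zipvoice_onnx/fm_decoder.onnx")
#   meta = sess.get_modelmeta().custom_metadata_map
#   feat_dim = int(meta["feat_dim"])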


class OnnxTextModel(nn.Module):
    def __init__(self, model: nn.Module):
        """A wrapper for the ZipVoice text encoder."""
        super().__init__()
        self.embed = model.embed
        self.text_encoder = model.text_encoder
        self.pad_id = model.pad_id

    def forward(
        self,
        tokens: Tensor,
        prompt_tokens: Tensor,
        prompt_features_len: Tensor,
        speed: Tensor,
    ) -> Tensor:
        # Concatenate prompt and target tokens and append one pad token;
        # the padding mask marks only that final pad position.
        cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
        cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
        tokens_len = cat_tokens.shape[1] - 1
        padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)

        embed = self.embed(cat_tokens)
        embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)

        # Estimate the total number of feature frames from the prompt's
        # frames-per-token ratio, divided by the speed factor.
        features_len = torch.ceil(
            (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
        ).to(dtype=torch.int64)

        token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
            dtype=torch.int64
        )

        # Upsample each token embedding to token_dur frames, then fill the
        # remaining frames with the embedding of the trailing pad token.
        text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
        text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])

        text_condition = torch.cat(
            [
                text_condition,
                embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
            ],
            dim=1,
        )

        return text_condition
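
# A worked example of the length arithmetic above, using the dummy inputs from
# export_text_encoder below (a 2-token prompt with 10 prompt frames, 4 target
# tokens, speed 1.0): tokens_len = 2 + 4 = 6, features_len =
# ceil(10 / 2 * 6 / 1.0) = 30, token_dur = 30 // 6 = 5, so each of the 6 token
# embeddings is repeated 5 times and no trailing pad frames are needed.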


class OnnxFlowMatchingModel(nn.Module):
    def __init__(self, model: nn.Module):
        """A wrapper for the ZipVoice flow-matching decoder."""
        super().__init__()
        self.distill = model.distill
        self.fm_decoder = model.fm_decoder
        self.model_func = model.forward_fm_decoder
        self.feat_dim = model.feat_dim

    def forward(
        self,
        t: Tensor,
        x: Tensor,
        text_condition: Tensor,
        speech_condition: Tensor,
        guidance_scale: Tensor,
    ) -> Tensor:
        if self.distill:
            # The distilled model takes the guidance scale as a direct input.
            return self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                guidance_scale=guidance_scale,
            )
        else:
            # Classifier-free guidance: run the unconditional and conditional
            # branches in one batched forward pass, then combine the outputs.
            # The speech prompt is dropped from the unconditional branch only
            # in the second half of the trajectory (t > 0.5).
            x = x.repeat(2, 1, 1)
            text_condition = torch.cat(
                [torch.zeros_like(text_condition), text_condition], dim=0
            )
            speech_condition = torch.cat(
                [
                    torch.where(
                        t > 0.5, torch.zeros_like(speech_condition), speech_condition
                    ),
                    speech_condition,
                ],
                dim=0,
            )
            guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
            data_uncond, data_cond = self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
            ).chunk(2, dim=0)
            v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
            return v
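
# A minimal sketch of how the returned velocity could drive a Euler sampler at
# inference time (hypothetical loop; not part of this export script):
#
#   x = torch.randn(1, num_frames, feat_dim)
#   num_steps = 16
#   for i in range(num_steps):
#       t = torch.tensor(i / num_steps)
#       v = fm_decoder(t, x, text_condition, speech_condition, guidance_scale)
#       x = x + v / num_steps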


def export_text_encoder(
    model: OnnxTextModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the text encoder model to ONNX format.

    Args:
      model:
        The input model.
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    # Dummy inputs that fix the dtypes; the sequence dimensions are marked
    # as dynamic below.
    tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
    prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
    prompt_features_len = torch.tensor(10, dtype=torch.int64)
    speed = torch.tensor(1.0, dtype=torch.float32)

    # Trace first so the length arithmetic is recorded in a TorchScript graph
    # before the ONNX export.
    model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))

    torch.onnx.export(
        model,
        (tokens, prompt_tokens, prompt_features_len, speed),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
        output_names=["text_condition"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "prompt_tokens": {0: "N", 1: "T"},
            "text_condition": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ZipVoice text encoder",
    }
    print(f"meta_data: {meta_data}")
    add_meta_data(filename=filename, meta_data=meta_data)

    print(f"Exported to {filename}")


def export_fm_decoder(
    model: OnnxFlowMatchingModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the flow-matching decoder model to ONNX format.

    Args:
      model:
        The input model.
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    feat_dim = model.feat_dim
    seq_len = 200
    t = torch.tensor(0.5, dtype=torch.float32)
    x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    guidance_scale = torch.tensor(1.0, dtype=torch.float32)

    model = torch.jit.trace(
        model, (t, x, text_condition, speech_condition, guidance_scale)
    )

    torch.onnx.export(
        model,
        (t, x, text_condition, speech_condition, guidance_scale),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
        output_names=["v"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "text_condition": {0: "N", 1: "T"},
            "speech_condition": {0: "N", 1: "T"},
            "v": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ZipVoice flow-matching decoder",
        "feat_dim": str(feat_dim),
    }
    print(f"meta_data: {meta_data}")
    add_meta_data(filename=filename, meta_data=meta_data)

    print(f"Exported to {filename}")
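
# A minimal sketch of exercising the dynamic axes of the exported decoder
# (hypothetical; any sequence length T should work since T is dynamic):
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("exp/zipvoice_onnx/fm_decoder.onnx")
#   T = 123
#   feat_dim = int(sess.get_modelmeta().custom_metadata_map["feat_dim"])
#   (v,) = sess.run(
#       ["v"],
#       {
#           "t": np.array(0.5, dtype=np.float32),
#           "x": np.random.randn(1, T, feat_dim).astype(np.float32),
#           "text_condition": np.random.randn(1, T, feat_dim).astype(np.float32),
#           "speech_condition": np.random.randn(1, T, feat_dim).astype(np.float32),
#           "guidance_scale": np.array(1.0, dtype=np.float32),
#       },
#   )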


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))

    with open(params.model_config, "r") as f:
        model_config = json.load(f)
    for key, value in model_config["model"].items():
        setattr(params, key, value)
    for key, value in model_config["feature"].items():
        setattr(params, key, value)

    tokenizer = SimpleTokenizer(params.token_file)
    tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

    if params.model_name == "zipvoice":
        model = ZipVoice(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        assert params.model_name == "zipvoice_distill"
        model = ZipVoiceDistill(
            **model_config["model"],
            **tokenizer_config,
        )

    model_ckpt = params.checkpoint
    if model_ckpt.endswith(".safetensors"):
        safetensors.torch.load_model(model, model_ckpt)
    elif model_ckpt.endswith(".pt"):
        load_checkpoint(filename=model_ckpt, model=model, strict=True)
    else:
        raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")

    device = torch.device("cpu")
    model = model.to(device)
    model.eval()

    # Replace the scaled (training-time) modules with plain ones that ONNX
    # can represent.
    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    print("Exporting model")
    os.makedirs(params.onnx_model_dir, exist_ok=True)
    opset_version = 11

    text_encoder = OnnxTextModel(model=model)
    text_encoder_file = f"{params.onnx_model_dir}/text_encoder.onnx"
    export_text_encoder(
        model=text_encoder,
        filename=text_encoder_file,
        opset_version=opset_version,
    )

    fm_decoder = OnnxFlowMatchingModel(model=model)
    fm_decoder_file = f"{params.onnx_model_dir}/fm_decoder.onnx"
    export_fm_decoder(
        model=fm_decoder,
        filename=fm_decoder_file,
        opset_version=opset_version,
    )

    print("Generating int8 quantized models")

    text_encoder_int8_file = f"{params.onnx_model_dir}/text_encoder_int8.onnx"
    quantize_dynamic(
        model_input=text_encoder_file,
        model_output=text_encoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    fm_decoder_int8_file = f"{params.onnx_model_dir}/fm_decoder_int8.onnx"
    quantize_dynamic(
        model_input=fm_decoder_file,
        model_output=fm_decoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
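
    # A minimal sketch of loading a quantized model (hypothetical check; the
    # int8 files are loaded exactly like the float ones with onnxruntime):
    #
    #   import onnxruntime as ort
    #
    #   sess = ort.InferenceSession(fm_decoder_int8_file)
    #   print([i.name for i in sess.get_inputs()])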

    print("Done!")


if __name__ == "__main__":
    main()