Spaces:
Running
on
Zero
Running
on
Zero
Delete onnx_export.py
Browse files- onnx_export.py +0 -334
onnx_export.py
DELETED
@@ -1,334 +0,0 @@
|
|
1 |
-
from module.models_onnx import SynthesizerTrn, symbols
|
2 |
-
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
3 |
-
import torch
|
4 |
-
import torchaudio
|
5 |
-
from torch import nn
|
6 |
-
from feature_extractor import cnhubert
|
7 |
-
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
|
8 |
-
cnhubert.cnhubert_base_path=cnhubert_base_path
|
9 |
-
ssl_model = cnhubert.get_model()
|
10 |
-
from text import cleaned_text_to_sequence
|
11 |
-
import soundfile
|
12 |
-
from tools.my_utils import load_audio
|
13 |
-
import os
|
14 |
-
import json
|
15 |
-
|
16 |
-
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
17 |
-
hann_window = torch.hann_window(win_size).to(
|
18 |
-
dtype=y.dtype, device=y.device
|
19 |
-
)
|
20 |
-
y = torch.nn.functional.pad(
|
21 |
-
y.unsqueeze(1),
|
22 |
-
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
23 |
-
mode="reflect",
|
24 |
-
)
|
25 |
-
y = y.squeeze(1)
|
26 |
-
spec = torch.stft(
|
27 |
-
y,
|
28 |
-
n_fft,
|
29 |
-
hop_length=hop_size,
|
30 |
-
win_length=win_size,
|
31 |
-
window=hann_window,
|
32 |
-
center=center,
|
33 |
-
pad_mode="reflect",
|
34 |
-
normalized=False,
|
35 |
-
onesided=True,
|
36 |
-
return_complex=False,
|
37 |
-
)
|
38 |
-
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
39 |
-
return spec
|
40 |
-
|
41 |
-
|
42 |
-
class DictToAttrRecursive(dict):
|
43 |
-
def __init__(self, input_dict):
|
44 |
-
super().__init__(input_dict)
|
45 |
-
for key, value in input_dict.items():
|
46 |
-
if isinstance(value, dict):
|
47 |
-
value = DictToAttrRecursive(value)
|
48 |
-
self[key] = value
|
49 |
-
setattr(self, key, value)
|
50 |
-
|
51 |
-
def __getattr__(self, item):
|
52 |
-
try:
|
53 |
-
return self[item]
|
54 |
-
except KeyError:
|
55 |
-
raise AttributeError(f"Attribute {item} not found")
|
56 |
-
|
57 |
-
def __setattr__(self, key, value):
|
58 |
-
if isinstance(value, dict):
|
59 |
-
value = DictToAttrRecursive(value)
|
60 |
-
super(DictToAttrRecursive, self).__setitem__(key, value)
|
61 |
-
super().__setattr__(key, value)
|
62 |
-
|
63 |
-
def __delattr__(self, item):
|
64 |
-
try:
|
65 |
-
del self[item]
|
66 |
-
except KeyError:
|
67 |
-
raise AttributeError(f"Attribute {item} not found")
|
68 |
-
|
69 |
-
|
70 |
-
class T2SEncoder(nn.Module):
|
71 |
-
def __init__(self, t2s, vits):
|
72 |
-
super().__init__()
|
73 |
-
self.encoder = t2s.onnx_encoder
|
74 |
-
self.vits = vits
|
75 |
-
|
76 |
-
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
77 |
-
codes = self.vits.extract_latent(ssl_content)
|
78 |
-
prompt_semantic = codes[0, 0]
|
79 |
-
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
80 |
-
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
81 |
-
bert = bert.unsqueeze(0)
|
82 |
-
prompt = prompt_semantic.unsqueeze(0)
|
83 |
-
return self.encoder(all_phoneme_ids, bert), prompt
|
84 |
-
|
85 |
-
|
86 |
-
class T2SModel(nn.Module):
|
87 |
-
def __init__(self, t2s_path, vits_model):
|
88 |
-
super().__init__()
|
89 |
-
dict_s1 = torch.load(t2s_path, map_location="cpu")
|
90 |
-
self.config = dict_s1["config"]
|
91 |
-
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
|
92 |
-
self.t2s_model.load_state_dict(dict_s1["weight"])
|
93 |
-
self.t2s_model.eval()
|
94 |
-
self.vits_model = vits_model.vq_model
|
95 |
-
self.hz = 50
|
96 |
-
self.max_sec = self.config["data"]["max_sec"]
|
97 |
-
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
|
98 |
-
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
99 |
-
self.t2s_model = self.t2s_model.model
|
100 |
-
self.t2s_model.init_onnx()
|
101 |
-
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
|
102 |
-
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
103 |
-
self.stage_decoder = self.t2s_model.stage_decoder
|
104 |
-
#self.t2s_model = torch.jit.script(self.t2s_model)
|
105 |
-
|
106 |
-
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
107 |
-
early_stop_num = self.t2s_model.early_stop_num
|
108 |
-
|
109 |
-
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
110 |
-
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
111 |
-
|
112 |
-
prefix_len = prompts.shape[1]
|
113 |
-
|
114 |
-
#[1,N,512] [1,N]
|
115 |
-
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
116 |
-
|
117 |
-
stop = False
|
118 |
-
for idx in range(1, 1500):
|
119 |
-
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
120 |
-
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
121 |
-
y, k, v, y_emb, logits, samples = enco
|
122 |
-
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
123 |
-
stop = True
|
124 |
-
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
125 |
-
stop = True
|
126 |
-
if stop:
|
127 |
-
break
|
128 |
-
y[0, -1] = 0
|
129 |
-
|
130 |
-
return y[:, -idx:].unsqueeze(0)
|
131 |
-
|
132 |
-
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
|
133 |
-
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
134 |
-
if dynamo:
|
135 |
-
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
136 |
-
onnx_encoder_export_output = torch.onnx.dynamo_export(
|
137 |
-
self.onnx_encoder,
|
138 |
-
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
139 |
-
export_options=export_options
|
140 |
-
)
|
141 |
-
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
142 |
-
return
|
143 |
-
|
144 |
-
torch.onnx.export(
|
145 |
-
self.onnx_encoder,
|
146 |
-
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
147 |
-
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
|
148 |
-
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
149 |
-
output_names=["x", "prompts"],
|
150 |
-
dynamic_axes={
|
151 |
-
"ref_seq": {1 : "ref_length"},
|
152 |
-
"text_seq": {1 : "text_length"},
|
153 |
-
"ref_bert": {0 : "ref_length"},
|
154 |
-
"text_bert": {0 : "text_length"},
|
155 |
-
"ssl_content": {2 : "ssl_length"},
|
156 |
-
},
|
157 |
-
opset_version=16
|
158 |
-
)
|
159 |
-
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
160 |
-
|
161 |
-
torch.onnx.export(
|
162 |
-
self.first_stage_decoder,
|
163 |
-
(x, prompts),
|
164 |
-
f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
|
165 |
-
input_names=["x", "prompts"],
|
166 |
-
output_names=["y", "k", "v", "y_emb", "x_example"],
|
167 |
-
dynamic_axes={
|
168 |
-
"x": {1 : "x_length"},
|
169 |
-
"prompts": {1 : "prompts_length"},
|
170 |
-
},
|
171 |
-
verbose=False,
|
172 |
-
opset_version=16
|
173 |
-
)
|
174 |
-
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
175 |
-
|
176 |
-
torch.onnx.export(
|
177 |
-
self.stage_decoder,
|
178 |
-
(y, k, v, y_emb, x_example),
|
179 |
-
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
180 |
-
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
181 |
-
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
182 |
-
dynamic_axes={
|
183 |
-
"iy": {1 : "iy_length"},
|
184 |
-
"ik": {1 : "ik_length"},
|
185 |
-
"iv": {1 : "iv_length"},
|
186 |
-
"iy_emb": {1 : "iy_emb_length"},
|
187 |
-
"ix_example": {1 : "ix_example_length"},
|
188 |
-
},
|
189 |
-
verbose=False,
|
190 |
-
opset_version=16
|
191 |
-
)
|
192 |
-
|
193 |
-
|
194 |
-
class VitsModel(nn.Module):
|
195 |
-
def __init__(self, vits_path):
|
196 |
-
super().__init__()
|
197 |
-
dict_s2 = torch.load(vits_path,map_location="cpu")
|
198 |
-
self.hps = dict_s2["config"]
|
199 |
-
self.hps = DictToAttrRecursive(self.hps)
|
200 |
-
self.hps.model.semantic_frame_rate = "25hz"
|
201 |
-
self.vq_model = SynthesizerTrn(
|
202 |
-
self.hps.data.filter_length // 2 + 1,
|
203 |
-
self.hps.train.segment_size // self.hps.data.hop_length,
|
204 |
-
n_speakers=self.hps.data.n_speakers,
|
205 |
-
**self.hps.model
|
206 |
-
)
|
207 |
-
self.vq_model.eval()
|
208 |
-
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
209 |
-
|
210 |
-
def forward(self, text_seq, pred_semantic, ref_audio):
|
211 |
-
refer = spectrogram_torch(
|
212 |
-
ref_audio,
|
213 |
-
self.hps.data.filter_length,
|
214 |
-
self.hps.data.sampling_rate,
|
215 |
-
self.hps.data.hop_length,
|
216 |
-
self.hps.data.win_length,
|
217 |
-
center=False
|
218 |
-
)
|
219 |
-
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
|
220 |
-
|
221 |
-
|
222 |
-
class GptSoVits(nn.Module):
|
223 |
-
def __init__(self, vits, t2s):
|
224 |
-
super().__init__()
|
225 |
-
self.vits = vits
|
226 |
-
self.t2s = t2s
|
227 |
-
|
228 |
-
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
|
229 |
-
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
230 |
-
audio = self.vits(text_seq, pred_semantic, ref_audio)
|
231 |
-
if debug:
|
232 |
-
import onnxruntime
|
233 |
-
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
|
234 |
-
audio1 = sess.run(None, {
|
235 |
-
"text_seq" : text_seq.detach().cpu().numpy(),
|
236 |
-
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
|
237 |
-
"ref_audio" : ref_audio.detach().cpu().numpy()
|
238 |
-
})
|
239 |
-
return audio, audio1
|
240 |
-
return audio
|
241 |
-
|
242 |
-
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
|
243 |
-
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
|
244 |
-
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
245 |
-
torch.onnx.export(
|
246 |
-
self.vits,
|
247 |
-
(text_seq, pred_semantic, ref_audio),
|
248 |
-
f"onnx/{project_name}/{project_name}_vits.onnx",
|
249 |
-
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
250 |
-
output_names=["audio"],
|
251 |
-
dynamic_axes={
|
252 |
-
"text_seq": {1 : "text_length"},
|
253 |
-
"pred_semantic": {2 : "pred_length"},
|
254 |
-
"ref_audio": {1 : "audio_length"},
|
255 |
-
},
|
256 |
-
opset_version=17,
|
257 |
-
verbose=False
|
258 |
-
)
|
259 |
-
|
260 |
-
|
261 |
-
class SSLModel(nn.Module):
|
262 |
-
def __init__(self):
|
263 |
-
super().__init__()
|
264 |
-
self.ssl = ssl_model
|
265 |
-
|
266 |
-
def forward(self, ref_audio_16k):
|
267 |
-
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
268 |
-
|
269 |
-
|
270 |
-
def export(vits_path, gpt_path, project_name):
|
271 |
-
vits = VitsModel(vits_path)
|
272 |
-
gpt = T2SModel(gpt_path, vits)
|
273 |
-
gpt_sovits = GptSoVits(vits, gpt)
|
274 |
-
ssl = SSLModel()
|
275 |
-
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
276 |
-
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
277 |
-
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
278 |
-
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
279 |
-
ref_audio = torch.randn((1, 48000 * 5)).float()
|
280 |
-
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
|
281 |
-
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
|
282 |
-
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
|
283 |
-
|
284 |
-
try:
|
285 |
-
os.mkdir(f"onnx/{project_name}")
|
286 |
-
except:
|
287 |
-
pass
|
288 |
-
|
289 |
-
ssl_content = ssl(ref_audio_16k).float()
|
290 |
-
|
291 |
-
debug = False
|
292 |
-
|
293 |
-
if debug:
|
294 |
-
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
|
295 |
-
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
|
296 |
-
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
|
297 |
-
return
|
298 |
-
|
299 |
-
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
|
300 |
-
|
301 |
-
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
302 |
-
|
303 |
-
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
|
304 |
-
|
305 |
-
MoeVSConf = {
|
306 |
-
"Folder" : f"{project_name}",
|
307 |
-
"Name" : f"{project_name}",
|
308 |
-
"Type" : "GPT-SoVits",
|
309 |
-
"Rate" : vits.hps.data.sampling_rate,
|
310 |
-
"NumLayers": gpt.t2s_model.num_layers,
|
311 |
-
"EmbeddingDim": gpt.t2s_model.embedding_dim,
|
312 |
-
"Dict": "BasicDict",
|
313 |
-
"BertPath": "chinese-roberta-wwm-ext-large",
|
314 |
-
"Symbol": symbols,
|
315 |
-
"AddBlank": False
|
316 |
-
}
|
317 |
-
|
318 |
-
MoeVSConfJson = json.dumps(MoeVSConf)
|
319 |
-
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
|
320 |
-
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
|
321 |
-
|
322 |
-
|
323 |
-
if __name__ == "__main__":
|
324 |
-
try:
|
325 |
-
os.mkdir("onnx")
|
326 |
-
except:
|
327 |
-
pass
|
328 |
-
|
329 |
-
gpt_path = "GPT_weights/nahida-e25.ckpt"
|
330 |
-
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
|
331 |
-
exp_path = "nahida"
|
332 |
-
export(vits_path, gpt_path, exp_path)
|
333 |
-
|
334 |
-
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|