lj1995 commited on
Commit
0db7fec
·
verified ·
1 Parent(s): 1645547

Delete onnx_export.py

Browse files
Files changed (1) hide show
  1. onnx_export.py +0 -334
onnx_export.py DELETED
@@ -1,334 +0,0 @@
1
- from module.models_onnx import SynthesizerTrn, symbols
2
- from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
3
- import torch
4
- import torchaudio
5
- from torch import nn
6
- from feature_extractor import cnhubert
7
- cnhubert_base_path = "pretrained_models/chinese-hubert-base"
8
- cnhubert.cnhubert_base_path=cnhubert_base_path
9
- ssl_model = cnhubert.get_model()
10
- from text import cleaned_text_to_sequence
11
- import soundfile
12
- from tools.my_utils import load_audio
13
- import os
14
- import json
15
-
16
- def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
17
- hann_window = torch.hann_window(win_size).to(
18
- dtype=y.dtype, device=y.device
19
- )
20
- y = torch.nn.functional.pad(
21
- y.unsqueeze(1),
22
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
23
- mode="reflect",
24
- )
25
- y = y.squeeze(1)
26
- spec = torch.stft(
27
- y,
28
- n_fft,
29
- hop_length=hop_size,
30
- win_length=win_size,
31
- window=hann_window,
32
- center=center,
33
- pad_mode="reflect",
34
- normalized=False,
35
- onesided=True,
36
- return_complex=False,
37
- )
38
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
39
- return spec
40
-
41
-
42
- class DictToAttrRecursive(dict):
43
- def __init__(self, input_dict):
44
- super().__init__(input_dict)
45
- for key, value in input_dict.items():
46
- if isinstance(value, dict):
47
- value = DictToAttrRecursive(value)
48
- self[key] = value
49
- setattr(self, key, value)
50
-
51
- def __getattr__(self, item):
52
- try:
53
- return self[item]
54
- except KeyError:
55
- raise AttributeError(f"Attribute {item} not found")
56
-
57
- def __setattr__(self, key, value):
58
- if isinstance(value, dict):
59
- value = DictToAttrRecursive(value)
60
- super(DictToAttrRecursive, self).__setitem__(key, value)
61
- super().__setattr__(key, value)
62
-
63
- def __delattr__(self, item):
64
- try:
65
- del self[item]
66
- except KeyError:
67
- raise AttributeError(f"Attribute {item} not found")
68
-
69
-
70
- class T2SEncoder(nn.Module):
71
- def __init__(self, t2s, vits):
72
- super().__init__()
73
- self.encoder = t2s.onnx_encoder
74
- self.vits = vits
75
-
76
- def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
77
- codes = self.vits.extract_latent(ssl_content)
78
- prompt_semantic = codes[0, 0]
79
- bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
80
- all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
81
- bert = bert.unsqueeze(0)
82
- prompt = prompt_semantic.unsqueeze(0)
83
- return self.encoder(all_phoneme_ids, bert), prompt
84
-
85
-
86
- class T2SModel(nn.Module):
87
- def __init__(self, t2s_path, vits_model):
88
- super().__init__()
89
- dict_s1 = torch.load(t2s_path, map_location="cpu")
90
- self.config = dict_s1["config"]
91
- self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
92
- self.t2s_model.load_state_dict(dict_s1["weight"])
93
- self.t2s_model.eval()
94
- self.vits_model = vits_model.vq_model
95
- self.hz = 50
96
- self.max_sec = self.config["data"]["max_sec"]
97
- self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
98
- self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
99
- self.t2s_model = self.t2s_model.model
100
- self.t2s_model.init_onnx()
101
- self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
102
- self.first_stage_decoder = self.t2s_model.first_stage_decoder
103
- self.stage_decoder = self.t2s_model.stage_decoder
104
- #self.t2s_model = torch.jit.script(self.t2s_model)
105
-
106
- def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
107
- early_stop_num = self.t2s_model.early_stop_num
108
-
109
- #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
110
- x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
111
-
112
- prefix_len = prompts.shape[1]
113
-
114
- #[1,N,512] [1,N]
115
- y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
116
-
117
- stop = False
118
- for idx in range(1, 1500):
119
- #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
120
- enco = self.stage_decoder(y, k, v, y_emb, x_example)
121
- y, k, v, y_emb, logits, samples = enco
122
- if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
123
- stop = True
124
- if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
125
- stop = True
126
- if stop:
127
- break
128
- y[0, -1] = 0
129
-
130
- return y[:, -idx:].unsqueeze(0)
131
-
132
- def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
133
- #self.onnx_encoder = torch.jit.script(self.onnx_encoder)
134
- if dynamo:
135
- export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
136
- onnx_encoder_export_output = torch.onnx.dynamo_export(
137
- self.onnx_encoder,
138
- (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
139
- export_options=export_options
140
- )
141
- onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
142
- return
143
-
144
- torch.onnx.export(
145
- self.onnx_encoder,
146
- (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
147
- f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
148
- input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
149
- output_names=["x", "prompts"],
150
- dynamic_axes={
151
- "ref_seq": {1 : "ref_length"},
152
- "text_seq": {1 : "text_length"},
153
- "ref_bert": {0 : "ref_length"},
154
- "text_bert": {0 : "text_length"},
155
- "ssl_content": {2 : "ssl_length"},
156
- },
157
- opset_version=16
158
- )
159
- x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
160
-
161
- torch.onnx.export(
162
- self.first_stage_decoder,
163
- (x, prompts),
164
- f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
165
- input_names=["x", "prompts"],
166
- output_names=["y", "k", "v", "y_emb", "x_example"],
167
- dynamic_axes={
168
- "x": {1 : "x_length"},
169
- "prompts": {1 : "prompts_length"},
170
- },
171
- verbose=False,
172
- opset_version=16
173
- )
174
- y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
175
-
176
- torch.onnx.export(
177
- self.stage_decoder,
178
- (y, k, v, y_emb, x_example),
179
- f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
180
- input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
181
- output_names=["y", "k", "v", "y_emb", "logits", "samples"],
182
- dynamic_axes={
183
- "iy": {1 : "iy_length"},
184
- "ik": {1 : "ik_length"},
185
- "iv": {1 : "iv_length"},
186
- "iy_emb": {1 : "iy_emb_length"},
187
- "ix_example": {1 : "ix_example_length"},
188
- },
189
- verbose=False,
190
- opset_version=16
191
- )
192
-
193
-
194
- class VitsModel(nn.Module):
195
- def __init__(self, vits_path):
196
- super().__init__()
197
- dict_s2 = torch.load(vits_path,map_location="cpu")
198
- self.hps = dict_s2["config"]
199
- self.hps = DictToAttrRecursive(self.hps)
200
- self.hps.model.semantic_frame_rate = "25hz"
201
- self.vq_model = SynthesizerTrn(
202
- self.hps.data.filter_length // 2 + 1,
203
- self.hps.train.segment_size // self.hps.data.hop_length,
204
- n_speakers=self.hps.data.n_speakers,
205
- **self.hps.model
206
- )
207
- self.vq_model.eval()
208
- self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
209
-
210
- def forward(self, text_seq, pred_semantic, ref_audio):
211
- refer = spectrogram_torch(
212
- ref_audio,
213
- self.hps.data.filter_length,
214
- self.hps.data.sampling_rate,
215
- self.hps.data.hop_length,
216
- self.hps.data.win_length,
217
- center=False
218
- )
219
- return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
220
-
221
-
222
- class GptSoVits(nn.Module):
223
- def __init__(self, vits, t2s):
224
- super().__init__()
225
- self.vits = vits
226
- self.t2s = t2s
227
-
228
- def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
229
- pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
230
- audio = self.vits(text_seq, pred_semantic, ref_audio)
231
- if debug:
232
- import onnxruntime
233
- sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
234
- audio1 = sess.run(None, {
235
- "text_seq" : text_seq.detach().cpu().numpy(),
236
- "pred_semantic" : pred_semantic.detach().cpu().numpy(),
237
- "ref_audio" : ref_audio.detach().cpu().numpy()
238
- })
239
- return audio, audio1
240
- return audio
241
-
242
- def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
243
- self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
244
- pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
245
- torch.onnx.export(
246
- self.vits,
247
- (text_seq, pred_semantic, ref_audio),
248
- f"onnx/{project_name}/{project_name}_vits.onnx",
249
- input_names=["text_seq", "pred_semantic", "ref_audio"],
250
- output_names=["audio"],
251
- dynamic_axes={
252
- "text_seq": {1 : "text_length"},
253
- "pred_semantic": {2 : "pred_length"},
254
- "ref_audio": {1 : "audio_length"},
255
- },
256
- opset_version=17,
257
- verbose=False
258
- )
259
-
260
-
261
- class SSLModel(nn.Module):
262
- def __init__(self):
263
- super().__init__()
264
- self.ssl = ssl_model
265
-
266
- def forward(self, ref_audio_16k):
267
- return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
268
-
269
-
270
- def export(vits_path, gpt_path, project_name):
271
- vits = VitsModel(vits_path)
272
- gpt = T2SModel(gpt_path, vits)
273
- gpt_sovits = GptSoVits(vits, gpt)
274
- ssl = SSLModel()
275
- ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
276
- text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
277
- ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
278
- text_bert = torch.randn((text_seq.shape[1], 1024)).float()
279
- ref_audio = torch.randn((1, 48000 * 5)).float()
280
- # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
281
- ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
282
- ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
283
-
284
- try:
285
- os.mkdir(f"onnx/{project_name}")
286
- except:
287
- pass
288
-
289
- ssl_content = ssl(ref_audio_16k).float()
290
-
291
- debug = False
292
-
293
- if debug:
294
- a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
295
- soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
296
- soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
297
- return
298
-
299
- a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
300
-
301
- soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
302
-
303
- gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
304
-
305
- MoeVSConf = {
306
- "Folder" : f"{project_name}",
307
- "Name" : f"{project_name}",
308
- "Type" : "GPT-SoVits",
309
- "Rate" : vits.hps.data.sampling_rate,
310
- "NumLayers": gpt.t2s_model.num_layers,
311
- "EmbeddingDim": gpt.t2s_model.embedding_dim,
312
- "Dict": "BasicDict",
313
- "BertPath": "chinese-roberta-wwm-ext-large",
314
- "Symbol": symbols,
315
- "AddBlank": False
316
- }
317
-
318
- MoeVSConfJson = json.dumps(MoeVSConf)
319
- with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
320
- json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
321
-
322
-
323
- if __name__ == "__main__":
324
- try:
325
- os.mkdir("onnx")
326
- except:
327
- pass
328
-
329
- gpt_path = "GPT_weights/nahida-e25.ckpt"
330
- vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
331
- exp_path = "nahida"
332
- export(vits_path, gpt_path, exp_path)
333
-
334
- # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)