Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,440 Bytes
841f290 daa8d34 841f290 daa8d34 841f290 daa8d34 841f290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import copy
import os
import random
import re
import yaml
from cn2an import an2cn
from gxl_ai_utils.utils import utils_file
from wenet.utils.init_tokenizer import init_tokenizer
from gxl_ai_utils.config.gxl_config import GxlNode
from wenet.utils.init_model import init_model
import logging
import librosa
import torch
import torchaudio
def load_model_and_tokenizer(checkpoint_path, config_path):
    """Load a model from a checkpoint and build the matching tokenizer.

    The model is cast to ``torch.bfloat16`` and ``eval()`` is called on it.
    No device placement happens here; callers move the model themselves.

    Args:
        checkpoint_path (str): Path to the model weights file.
        config_path (str): Path to the model configuration file (read with
            ``utils_file.load_dict_from_yaml``).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    print(f"正在从以下路径加载模型: {checkpoint_path}")
    init_args = GxlNode({"checkpoint": checkpoint_path})
    raw_configs = utils_file.load_dict_from_yaml(config_path)
    # init_model returns a config object alongside the model; the tokenizer
    # is built from that returned config rather than from the raw file dict.
    model, resolved_configs = init_model(init_args, raw_configs)
    model = model.to(torch.bfloat16)
    model.eval()  # switch to evaluation mode
    tokenizer = init_tokenizer(resolved_configs)
    print(f"模型 {checkpoint_path} ")
    return model, tokenizer
def token_list2wav(token_list, prompt_speech, wav_path, cosyvoice):
    """Synthesize audio from a token sequence and save it to ``wav_path``.

    Args:
        token_list: Iterable of values convertible with ``int()``; passed to
            the synthesizer as a list of ints.
        prompt_speech: Prompt audio forwarded unchanged to the synthesizer.
        wav_path: Destination file path; its parent directory is prepared via
            ``utils_file.makedir_for_file`` before saving.
        cosyvoice: Object providing ``inference_zero_shot_gz_22k`` and a
            ``sample_rate`` attribute.

    Returns:
        The ``wav_path`` that was written.
    """
    int_tokens = list(map(int, token_list))
    # The two text arguments are fixed placeholders; the audio content is
    # driven by the explicit token_list passed below.
    synthesis = cosyvoice.inference_zero_shot_gz_22k(
        '收到好友从远方寄来的生日礼物。',
        '希望你以后能够做的比我还好呦。',
        prompt_speech,
        stream=False,
        token_list=int_tokens,
    )
    utils_file.makedir_for_file(wav_path)
    torchaudio.save(wav_path, synthesis['tts_speech'], cosyvoice.sample_rate)
    print(f'语音合成完成,保存到 {wav_path}')
    return wav_path
def do_resample(input_wav_path):
    """Load an audio file, down-mix it to one channel and resample to 16 kHz.

    Args:
        input_wav_path: Path of the audio file to read with ``torchaudio.load``.

    Returns:
        tuple: ``(waveform, 16000)`` where ``waveform`` keeps a leading
        channel axis of size one after down-mixing.
    """
    target_rate = 16000
    audio, source_rate = torchaudio.load(input_wav_path)
    # The first axis is treated as channels: average them into a single one.
    has_multiple_channels = audio.shape[0] > 1
    if has_multiple_channels:
        audio = audio.mean(dim=0, keepdim=True)
    to_target_rate = torchaudio.transforms.Resample(
        orig_freq=source_rate, new_freq=target_rate)
    return to_target_rate(audio), target_rate
def get_feat_from_wav_path(input_wav_path, device:torch.device=torch.device('cuda')):
    """Compute normalized 80-bin log-mel features for one audio file.

    Steps: load/resample via ``do_resample``, STFT with a 400-sample Hann
    window and hop of 160, power spectrum (last frame dropped), mel
    projection, ``log10`` with a floor of 1e-10, clamp to within 8.0 of the
    maximum, then rescale with ``(x + 4) / 4``.

    Args:
        input_wav_path: Path of the audio file.
        device: Device the returned tensors are moved to.

    Returns:
        tuple: ``(feat, feat_lens)`` -- ``feat`` is a bfloat16 tensor of shape
        ``(1, num_frames, 80)`` and ``feat_lens`` an int64 tensor holding
        ``[num_frames]``, both on ``device``.
    """
    n_fft, hop_length, n_mels = 400, 160, 80

    audio, rate = do_resample(input_wav_path)
    audio = audio.squeeze(0)  # drop the single channel axis

    spectrum = torch.stft(
        audio,
        n_fft,
        hop_length,
        window=torch.hann_window(n_fft),
        return_complex=True,
    )
    power = spectrum[..., :-1].abs() ** 2

    mel_basis = torch.from_numpy(
        librosa.filters.mel(sr=rate, n_fft=n_fft, n_mels=n_mels))
    mel_power = mel_basis @ power

    log_mel = torch.clamp(mel_power, min=1e-10).log10()
    # Limit the dynamic range to 8.0 (in log10 units) below the peak value.
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    log_mel = (log_mel + 4.0) / 4.0

    frames = log_mel.transpose(0, 1)  # (num_frames, n_mels)
    feat_lens = torch.tensor([frames.shape[0]], dtype=torch.int64).to(device)
    feat = frames.unsqueeze(0).to(device).to(torch.bfloat16)
    return feat, feat_lens
def do_format_shard_manifest4one(input_shards_path, tmp_file_path=None):
    """Rewrite a shard list file into ``combine_path|shard_path`` lines.

    Two input layouts are recognised from the path itself:

    * A path containing ``combines_list.txt``: a sibling
      ``combines_tar_root.txt`` supplies the tar root (its first row, with one
      trailing ``/`` removed). Each entry becomes
      ``<entry>|<tar_root>/<pure name of entry>.tar``.
    * Anything else: each entry is a plain shard path and becomes
      ``-|<entry>``.

    Args:
        input_shards_path (str): Path of the list file to convert.
        tmp_file_path (str | None): Where to write the converted list. When
            ``None``, a randomly named file under ``~/.cache/.temp/`` (with
            ``~`` expanded to the user's home directory) is used.

    Returns:
        str | None: The path the manifest was written to, or ``None`` when the
        input is a combine list whose ``combines_tar_root.txt`` is missing
        (an error is logged in that case and nothing is written).
    """
    if tmp_file_path is None:
        # Python file APIs do not expand "~"; resolve it explicitly so the
        # manifest lands in the home directory instead of a literal "~" dir
        # relative to the current working directory.
        tmp_file_path = os.path.expanduser(
            f'~/.cache/.temp/{random.randint(10000, 99999)}.txt')

    utils_file.logging_info(f'path:{input_shards_path} ')
    entries = utils_file.load_list_file_clean(input_shards_path)

    # Decide which list layout we are dealing with.
    if "combines_list.txt" in input_shards_path:
        print('是 combine类型的数据')
        tar_root_path = input_shards_path.replace(
            'combines_list.txt', 'combines_tar_root.txt')
        if not os.path.exists(tar_root_path):
            utils_file.logging_error(
                f'combine_list.txt:{input_shards_path} 对应的 combines_tar_root.txt:{tar_root_path} 不存在')
            return None
        tar_root = utils_file.load_first_row_clean(tar_root_path)
        if tar_root.endswith('/'):
            tar_root = tar_root[:-1]
        utils_file.logging_info(f' tar_root:{tar_root}')
        # Line format: "combine_path|shard_path"
        manifest_lines = [
            f'{entry}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(entry)}.tar'
            for entry in entries
        ]
    else:
        print('不是 combine类型的数据,是传统shard类型的数据')
        manifest_lines = [f'-|{entry}' for entry in entries]

    utils_file.logging_info(f'true load num is : {len(manifest_lines)}')
    utils_file.write_list_to_file(manifest_lines, tmp_file_path)
    return tmp_file_path
def convert_numbers_in_string(s):
    """Replace Arabic numerals embedded in ``s`` with their ``an2cn`` form.

    Integers, decimals and numbers with a leading minus sign are matched.
    A dot is treated as a decimal point only when at least one digit follows
    it, so punctuation directly after a number (e.g. ``"3."``) is left in
    place instead of being absorbed into the number.

    Args:
        s (str): Text that may contain numbers.

    Returns:
        str: ``s`` with every matched number replaced by the result of
        ``an2cn``; a match that ``an2cn`` rejects with ``ValueError`` is kept
        unchanged.
    """
    # Optional sign, integer part, and an optional fractional part that
    # requires digits after the dot.
    pattern = r'-?\d+(?:\.\d+)?'

    def replace_func(match):
        num_str = match.group()
        try:
            return an2cn(num_str)
        except ValueError:
            # Conversion failed (not a number an2cn accepts): keep the text.
            return num_str

    return re.sub(pattern, replace_func, s)
def get_test_conf(config_path):
    """Load a YAML config and derive a dataset config for testing from it.

    The derived config is a deep copy of ``configs['dataset_conf']`` with
    filtering limits relaxed, augmentation and sorting switched off, dither
    zeroed and batching fixed to static batches of size one.

    Note that ``filter_no_extra_info`` is set to ``False`` on the loaded
    config itself before copying, so the returned ``configs`` carries that
    change too.

    Args:
        config_path: Path of the YAML configuration file.

    Returns:
        tuple: ``(configs, test_conf)`` -- the loaded configuration and the
        derived test dataset configuration.
    """
    with open(config_path, 'r', encoding='utf-8') as fin:
        print(f"加载配置文件 {config_path}")
        # NOTE(review): FullLoader is not meant for untrusted files; confirm
        # config files always come from a trusted source.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
    test_conf = copy.deepcopy(configs['dataset_conf'])

    # 'max_length' is deliberately left as configured; the original carried a
    # disabled override of 3000 with a remark about Whisper's 30 s limit.
    test_conf['filter_conf'].update({
        'min_length': 10,
        'token_max_length': 102400,
        'token_min_length': 1,
        'max_output_input_ratio': 102400,
        'min_output_input_ratio': 0,
        'filter_no_extra_info': False,
        'max_seq_len': 102400,
    })

    test_conf.update({
        'speed_perturb': False,
        'spec_aug': False,
        'spec_sub': False,
        'spec_trim': False,
        'shuffle': False,
        'sort': False,
        'cycle': 1,
        'list_shuffle': True,
    })

    # Zero the dither of whichever feature config comes first (fbank wins).
    for feature_key in ('fbank_conf', 'mfcc_conf'):
        if feature_key in test_conf:
            test_conf[feature_key]['dither'] = 0.0
            break

    test_conf['batch_conf'].update({
        'batch_type': "static",
        'batch_size': 1,
    })

    test_conf.update({
        'split_num': 1,
        'multi_num': 1,
        'other_filter_conf': {},
        'data_recover': False,
    })
    return configs, test_conf
|