Spaces:
Running
on
Zero
Running
on
Zero
import copy | |
import os | |
import random | |
import re | |
import yaml | |
from cn2an import an2cn | |
from gxl_ai_utils.utils import utils_file | |
from wenet.utils.init_tokenizer import init_tokenizer | |
from gxl_ai_utils.config.gxl_config import GxlNode | |
from wenet.utils.init_model import init_model | |
import logging | |
import librosa | |
import torch | |
import torchaudio | |
def load_model_and_tokenizer(checkpoint_path, config_path, device=None):
    """Load a model checkpoint together with its tokenizer.

    Args:
        checkpoint_path (str): path to the model weights file.
        config_path (str): path to the YAML model configuration file.
        device (torch.device | str | None): if given, the model is moved to
            this device after being cast to bfloat16. ``None`` (the default)
            leaves the model wherever ``init_model`` put it, which is the
            previous behaviour.

    Returns:
        model: the loaded model, cast to ``torch.bfloat16`` and in eval mode.
        tokenizer: tokenizer built from the configs returned by ``init_model``.
    """
    print(f"正在从以下路径加载模型: {checkpoint_path}")
    args = GxlNode({"checkpoint": checkpoint_path})
    configs = utils_file.load_dict_from_yaml(config_path)
    # init_model may rewrite the configs; the tokenizer is built from the
    # returned version, not the raw YAML.
    model, configs = init_model(args, configs)
    model = model.to(torch.bfloat16)
    if device is not None:
        model = model.to(device)
    model.eval()  # inference mode: disables dropout etc.
    tokenizer = init_tokenizer(configs)
    print(f"模型 {checkpoint_path} ")
    return model, tokenizer
def token_list2wav(token_list, prompt_speech, wav_path, cosyvoice):
    """Synthesize speech from a list of speech tokens and save it as a wav file.

    Args:
        token_list: iterable of speech-token ids; every element is coerced
            with ``int`` before use.
        prompt_speech: prompt audio, passed unchanged to
            ``cosyvoice.inference_zero_shot_gz_22k``.
        wav_path: output path; its parent directory is created first.
        cosyvoice: model object exposing ``inference_zero_shot_gz_22k`` and
            ``sample_rate``.

    Returns:
        ``wav_path``, after the audio has been written.
    """
    speech_tokens = list(map(int, token_list))
    # NOTE(review): the two text arguments are fixed strings; presumably the
    # synthesized content is driven by token_list -- confirm against CosyVoice.
    synthesis = cosyvoice.inference_zero_shot_gz_22k(
        '收到好友从远方寄来的生日礼物。',
        '希望你以后能够做的比我还好呦。',
        prompt_speech,
        stream=False,
        token_list=speech_tokens,
    )
    utils_file.makedir_for_file(wav_path)
    torchaudio.save(wav_path, synthesis['tts_speech'], cosyvoice.sample_rate)
    print(f'语音合成完成,保存到 {wav_path}')
    return wav_path
def do_resample(input_wav_path, target_sample_rate=16000):
    """Load an audio file, down-mix it to mono and resample it.

    Args:
        input_wav_path (str): path to any file ``torchaudio.load`` can read.
        target_sample_rate (int): desired output sample rate in Hz
            (default 16000, the previously hard-coded value).

    Returns:
        tuple: ``(waveform, target_sample_rate)`` where ``waveform`` is a
        single-channel tensor of shape ``(1, num_samples)``.
    """
    # torchaudio.load returns (channels, frames) by default.
    waveform, sample_rate = torchaudio.load(input_wav_path)
    if waveform.shape[0] > 1:
        # Average the channels but keep the channel dimension.
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != target_sample_rate:
        # Resample is an identity when the rates match, so building it is
        # only worthwhile when they differ.
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    return waveform, target_sample_rate
def get_feat_from_wav_path(input_wav_path, device:torch.device=torch.device('cuda')):
    """Compute a batched log-mel feature tensor for a single audio file.

    The audio is loaded and resampled with ``do_resample``, then:
    400-point STFT (hop 160, Hann window) -> power spectrum with the last
    frame dropped -> 80-band librosa mel projection -> log10 with a 1e-10
    floor -> values clamped to at most 8 below the maximum -> rescaled as
    ``(x + 4) / 4``. This appears to mirror the Whisper-style log-mel recipe.

    Args:
        input_wav_path (str): path to the audio file.
        device (torch.device): device for both returned tensors
            (default: CUDA).

    Returns:
        feat: bfloat16 tensor of shape ``(1, num_frames, 80)`` on ``device``.
        feat_lens: int64 tensor ``[num_frames]`` on ``device``.
    """
    audio, sr = do_resample(input_wav_path)
    audio = audio.squeeze(0)
    n_fft, hop_length = 400, 160
    spectrum = torch.stft(audio, n_fft, hop_length,
                          window=torch.hann_window(n_fft), return_complex=True)
    # spectrum is (freq_bins, frames); drop the final frame.
    power = spectrum[..., :-1].abs() ** 2
    mel_basis = torch.from_numpy(librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=80))
    log_mel = torch.clamp(mel_basis @ power, min=1e-10).log10()
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    log_mel = (log_mel + 4.0) / 4.0
    frames = log_mel.transpose(0, 1)  # -> (frames, n_mels)
    feat_lens = torch.tensor([frames.shape[0]], dtype=torch.int64).to(device)
    feat = frames.unsqueeze(0).to(device).to(torch.bfloat16)
    return feat, feat_lens
def do_format_shard_manifest4one(input_shards_path, tmp_file_path=None):
    """Rewrite a shard list file into ``"combine_path|shard_path"`` lines.

    Two input layouts are recognised by file name:

    * a path containing ``combines_list.txt``: every listed entry is paired
      with ``<tar_root>/<pure file name>.tar``, where ``tar_root`` is the
      first line of the sibling ``combines_tar_root.txt``;
    * anything else (plain shard list): every entry becomes ``-|<entry>``.

    Args:
        input_shards_path (str): path of the list file to convert.
        tmp_file_path (str | None): where to write the manifest. Defaults to
            a randomly named ``.txt`` file under ``~/.cache/.temp`` (with
            ``~`` expanded to the user's home directory).

    Returns:
        str | None: the manifest path, or ``None`` when the required
        ``combines_tar_root.txt`` does not exist.
    """
    if tmp_file_path is None:
        # expanduser: file APIs do not expand '~' themselves, so the literal
        # path would otherwise be a relative directory named '~'.
        tmp_file_path = os.path.expanduser(
            f'~/.cache/.temp/{random.randint(10000, 99999)}.txt')
    data_path_i = input_shards_path
    utils_file.logging_info(f'path:{data_path_i} ')
    final_data_list_i = utils_file.load_list_file_clean(data_path_i)
    # Decide the data type from the file name.
    if "combines_list.txt" in data_path_i:
        print('是 combine类型的数据')
        tar_root_path = data_path_i.replace('combines_list.txt', 'combines_tar_root.txt')
        if not os.path.exists(tar_root_path):
            utils_file.logging_error(
                f'combine_list.txt:{data_path_i} 对应的 combines_tar_root.txt:{tar_root_path} 不存在')
            return None
        tar_root = utils_file.load_first_row_clean(tar_root_path)
        if tar_root.endswith('/'):
            tar_root = tar_root[:-1]
        utils_file.logging_info(f' tar_root:{tar_root}')
        # Each line: "combine_path|shard_path"
        new_final_data_list_i = [
            f'{data_path_j}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(data_path_j)}.tar'
            for data_path_j in final_data_list_i
        ]
    else:
        print('不是 combine类型的数据,是传统shard类型的数据')
        new_final_data_list_i = [f'-|{data_path_j}' for data_path_j in final_data_list_i]
    utils_file.logging_info(f'true load num is : {len(new_final_data_list_i)}')
    # NOTE(review): assumes write_list_to_file creates missing parent
    # directories (e.g. ~/.cache/.temp) -- confirm in gxl_ai_utils.
    utils_file.write_list_to_file(new_final_data_list_i, tmp_file_path)
    return tmp_file_path
def convert_numbers_in_string(s, converter=None):
    """Replace every Arabic-numeral number in ``s`` with its Chinese reading.

    A number is an optional minus sign, one or more digits and an optional
    fractional part (a period followed by at least one digit). Compared with
    the earlier pattern this:

    * no longer swallows a trailing period ("共3." now yields the match "3"
      instead of "3.", which the converter presumably rejected, leaving the
      number unconverted);
    * treats a hyphen directly after a digit as a separator rather than a
      minus sign ("1-2" yields "1" and "2", not "1" and "-2").

    Args:
        s (str): input text.
        converter (callable | None): maps a number string to its replacement.
            Defaults to ``cn2an.an2cn``.

    Returns:
        str: ``s`` with each number replaced. A number for which the
        converter raises ``ValueError`` is left as it was.
    """
    if converter is None:
        converter = an2cn
    # (?<!\d): a '-' preceded by a digit is a separator, not a sign.
    pattern = r'(?<!\d)-?\d+(?:\.\d+)?'

    def replace_func(match):
        num_str = match.group()
        try:
            return converter(num_str)
        except ValueError:
            # Not convertible: keep the original text.
            return num_str

    return re.sub(pattern, replace_func, s)
def get_test_conf(config_path):
    """Load a YAML config and derive a deterministic evaluation dataset config.

    Side effect: ``configs['dataset_conf']['filter_conf']['filter_no_extra_info']``
    is also set to False in the returned ``configs``.

    Args:
        config_path (str): path to the YAML configuration file.

    Returns:
        tuple: ``(configs, test_conf)``. ``configs`` is the parsed config.
        ``test_conf`` is a deep copy of ``configs['dataset_conf']`` in which
        the length/ratio filters are relaxed, augmentation, shuffling and
        sorting are switched off, feature dither is zeroed, and batching is
        static with batch size 1.
    """
    with open(config_path, 'r', encoding='utf-8') as fin:
        print(f"加载配置文件 {config_path}")
        # NOTE(review): FullLoader; yaml.safe_load would suffice for plain config data.
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    base_dataset_conf = configs['dataset_conf']
    base_dataset_conf['filter_conf']['filter_no_extra_info'] = False
    test_conf = copy.deepcopy(base_dataset_conf)

    # Very loose limits so that test items are not filtered out.
    # max_length is not overridden and keeps its value from the config.
    test_conf['filter_conf'].update({
        'min_length': 10,
        'token_max_length': 102400,
        'token_min_length': 1,
        'max_output_input_ratio': 102400,
        'min_output_input_ratio': 0,
        'filter_no_extra_info': False,
        'max_seq_len': 102400,
    })
    # No augmentation, no reordering, a single pass over the data.
    test_conf.update({
        'speed_perturb': False,
        'spec_aug': False,
        'spec_sub': False,
        'spec_trim': False,
        'shuffle': False,
        'sort': False,
        'cycle': 1,
        # NOTE(review): list_shuffle stays True while shuffle/sort are off --
        # confirm this is intended for evaluation.
        'list_shuffle': True,
    })
    # Only the first present feature config (fbank before mfcc) gets dither 0.
    for feat_conf_key in ('fbank_conf', 'mfcc_conf'):
        if feat_conf_key in test_conf:
            test_conf[feat_conf_key]['dither'] = 0.0
            break
    test_conf['batch_conf'].update({'batch_type': "static", 'batch_size': 1})
    test_conf.update({
        'split_num': 1,
        'multi_num': 1,
        'other_filter_conf': {},
        'data_recover': False,
    })
    return configs, test_conf