import copy
import os
import random
import re

import yaml
from cn2an import an2cn
from gxl_ai_utils.utils import utils_file
from wenet.utils.init_tokenizer import init_tokenizer
from gxl_ai_utils.config.gxl_config import GxlNode
from wenet.utils.init_model import init_model
import logging
import librosa
import torch
import torchaudio



def load_model_and_tokenizer(checkpoint_path, config_path):
    """
    Load the model and its tokenizer.
    Args:
        checkpoint_path (str): path to the model checkpoint file
        config_path (str): path to the model YAML config file
    Returns:
        model: the loaded model (bfloat16, eval mode)
        tokenizer: the initialized tokenizer
    """
    print(f"Loading model from: {checkpoint_path}")
    args = GxlNode({"checkpoint": checkpoint_path})
    configs = utils_file.load_dict_from_yaml(config_path)
    model, configs = init_model(args, configs)
    model = model.to(torch.bfloat16)
    model.eval()  # switch to evaluation mode
    tokenizer = init_tokenizer(configs)
    print(f"Model {checkpoint_path} loaded.")
    return model, tokenizer
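
# A minimal usage sketch (not wired into anything): the checkpoint and config
# paths below are hypothetical placeholders, not files shipped with this repo.
def _demo_load_model_and_tokenizer():
    model, tokenizer = load_model_and_tokenizer(
        checkpoint_path='exp/final.pt',   # hypothetical path
        config_path='conf/train.yaml')    # hypothetical path
    return model.cuda(), tokenizer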

def token_list2wav(token_list, prompt_speech, wav_path, cosyvoice):
    """Synthesize a wav file from a list of speech tokens with CosyVoice."""
    token_list = [int(i) for i in token_list]
    j = cosyvoice.inference_zero_shot_gz_22k(
        '收到好友从远方寄来的生日礼物。',
        '希望你以后能够做的比我还好呦。', prompt_speech, stream=False, token_list=token_list)
    utils_file.makedir_for_file(wav_path)
    torchaudio.save(wav_path, j['tts_speech'], cosyvoice.sample_rate)
    print(f'Speech synthesis finished, saved to {wav_path}')
    return wav_path
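
# Sketch of the intended call pattern, assuming an already-constructed CosyVoice
# instance exposing the custom inference_zero_shot_gz_22k method used above.
# 'prompt.wav' and the token values are hypothetical placeholders.
def _demo_token_list2wav(cosyvoice):
    prompt_speech, _ = torchaudio.load('prompt.wav')
    token_list = ['12', '873', '5021']  # speech tokens as produced upstream
    return token_list2wav(token_list, prompt_speech, 'out/demo.wav', cosyvoice)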

def do_resample(input_wav_path):
    """Load a wav file, downmix to mono, and resample to 16 kHz."""
    waveform, sample_rate = torchaudio.load(input_wav_path)
    if waveform.shape[0] > 1:  # multi-channel: average channels down to mono
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)
    return waveform, 16000
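
# Quick sketch: do_resample always yields 16 kHz mono audio regardless of the
# source file's channel count or rate. 'input.wav' is a hypothetical placeholder.
def _demo_do_resample():
    waveform, sr = do_resample('input.wav')
    assert sr == 16000 and waveform.shape[0] == 1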


def get_feat_from_wav_path(input_wav_path, device: torch.device = torch.device('cuda')):
    """Compute Whisper-style 80-bin log-mel features for a single wav file."""
    waveform, sample_rate = do_resample(input_wav_path)
    waveform = waveform.squeeze(0)
    window = torch.hann_window(400)
    # 25 ms window (400 samples) with 10 ms hop (160 samples) at 16 kHz
    stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2  # power spectrum, dropping the last frame
    filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=400, n_mels=80))
    mel_spec = filters @ magnitudes
    # log compression and normalization, as in Whisper's feature extraction
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    feat = log_spec.transpose(0, 1)  # (num_frames, 80)
    feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).to(device)
    feat = feat.unsqueeze(0).to(device)
    feat = feat.to(torch.bfloat16)
    return feat, feat_lens
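
# Minimal sketch: extract features for one utterance and inspect their shape.
# 'test.wav' is a hypothetical placeholder; feat comes back as (1, num_frames, 80).
def _demo_get_feat():
    feat, feat_lens = get_feat_from_wav_path('test.wav', device=torch.device('cpu'))
    print(feat.shape, feat_lens)  # e.g. torch.Size([1, T, 80]) tensor([T])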



def do_format_shard_manifest4one(input_shards_path, tmp_file_path=None):
    """Normalize a shards list file into "combine_path|shard_path" manifest lines."""
    if tmp_file_path is None:
        # expanduser so the temp file lands in the real home directory
        tmp_file_path = os.path.expanduser(f'~/.cache/.temp/{random.randint(10000, 99999)}.txt')
    data_path_i = input_shards_path
    utils_file.logging_info(f'path:{data_path_i} ')
    final_data_list_i = utils_file.load_list_file_clean(data_path_i)
    # determine the data type
    if "combines_list.txt" in data_path_i:
        print('combine-type data')
        tar_root_path = data_path_i.replace('combines_list.txt', 'combines_tar_root.txt')
        if not os.path.exists(tar_root_path):
            utils_file.logging_error(
                f'combines_list.txt:{data_path_i} has no matching combines_tar_root.txt:{tar_root_path}')
            return
        tar_root = utils_file.load_first_row_clean(tar_root_path)
        if tar_root.endswith('/'):
            tar_root = tar_root[:-1]
        utils_file.logging_info(f' tar_root:{tar_root}')
        new_final_data_list_i = []
        for data_path_j in final_data_list_i:
            # "combine_path|shard_path"
            tmp_lines = f'{data_path_j}|{tar_root}/{utils_file.do_get_file_pure_name_from_path(data_path_j)}.tar'
            new_final_data_list_i.append(tmp_lines)
    else:
        print('not combine-type data; treating it as traditional shard-type data')
        new_final_data_list_i = [f'-|{data_path_j}' for data_path_j in final_data_list_i]

    utils_file.logging_info(f'true load num is : {len(new_final_data_list_i)}')
    utils_file.write_list_to_file(new_final_data_list_i, tmp_file_path)
    return tmp_file_path
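
# Sketch: normalize a shards list into a manifest the data loader can consume.
# The input path is a hypothetical placeholder following the combines_list.txt
# naming convention checked above.
def _demo_format_manifest():
    manifest_path = do_format_shard_manifest4one('data/train/combines_list.txt')
    print(manifest_path)  # temp file with one "combine_path|shard_path" line per shard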



def convert_numbers_in_string(s):
    # regex matching numbers (integers, decimals, and negatives)
    pattern = r'-?\d+\.?\d*'

    def replace_func(match):
        num_str = match.group()
        try:
            # convert the Arabic numeral to Chinese numerals
            return an2cn(num_str)
        except ValueError:
            # if conversion fails (e.g., not a valid number), keep the original text
            return num_str
    # replace every matched number in the string
    return re.sub(pattern, replace_func, s)
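
# Example: Arabic numerals are rewritten as Chinese numerals via cn2an.an2cn,
# leaving the surrounding text untouched.
def _demo_convert_numbers():
    print(convert_numbers_in_string('共3人'))      # -> 共三人
    print(convert_numbers_in_string('温度-2.5度'))  # -> 温度负二点五度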

def get_test_conf(config_path):
    with open(config_path, 'r', encoding='utf-8') as fin:
        print(f"加载配置文件 {config_path}")
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False
    test_conf = copy.deepcopy(configs['dataset_conf'])

    # test_conf['filter_conf']['max_length'] = 3000  # Whisper handles at most 30 s; 102400
    test_conf['filter_conf']['min_length'] = 10
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 1
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['filter_conf']['filter_no_extra_info'] = False
    test_conf['filter_conf']['max_seq_len'] = 102400
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    test_conf['cycle'] = 1
    test_conf['list_shuffle'] = True
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = 1
    test_conf['split_num'] = 1
    test_conf['multi_num'] = 1
    test_conf['other_filter_conf'] = {}
    test_conf['data_recover'] = False
    return configs, test_conf
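
# Sketch: build a deterministic inference-time dataset config from a training
# YAML. 'conf/train.yaml' is a hypothetical placeholder.
def _demo_get_test_conf():
    configs, test_conf = get_test_conf('conf/train.yaml')
    # shuffling, sorting, and all augmentations are now off; batch_size is 1,
    # which suits one-by-one decoding.
    return configs, test_conf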