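# OSUM-EChat demo Space: downloads the OSUM-EChat checkpoints and the CosyVoice
# vocoder from the Hugging Face Hub, then serves speech-understanding and
# spoken-dialogue tasks through a Gradio UI on ZeroGPU.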
import ast
import base64
import datetime
import json
import logging
import os
import sys
import time
import traceback

import spaces  # on ZeroGPU, import spaces before any CUDA-touching library
import gradio as gr
import librosa
import torch
import torchaudio

sys.path.insert(0, '.')
sys.path.insert(0, './tts')
sys.path.insert(0, './tts/third_party/Matcha-TTS')

from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
from patches import modelling_qwen2_infer_gpu  # apply the Qwen2 inference patches
from tts.cosyvoice.cli.cosyvoice import CosyVoice
from tts.cosyvoice.utils.file_utils import load_wav
is_npu = False
try:
    import torch_npu
    is_npu = True
except ImportError:
    print("torch_npu is not available. If you want to run on an NPU, please install it.")
from huggingface_hub import hf_hub_download

# Download the .pt checkpoints from Hugging Face
CHECKPOINT_PATH_A = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="language_think_final.pt")
CHECKPOINT_PATH_B = None
# CHECKPOINT_PATH_B = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="tag_think_final.pt")
CONFIG_PATH = "./conf/ct_config.yaml"
NAME_A = "language_think"
NAME_B = "tag_think"

cosyvoice_model_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="CosyVoice-300M-25Hz.tar")
# Extract the tar archive into the current working directory
os.system(f"tar -xvf {cosyvoice_model_path}")
print("Finished extracting the CosyVoice model files")
cosyvoice_model_path = "./CosyVoice-300M-25Hz"
print("开始加载模型 A...")
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
model_a
print("\n开始加载模型 B...")
if CHECKPOINT_PATH_B is not None:
model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
model_b.eval().cuda()
else:
model_b, tokenizer_b = None, None
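# Register the loaded models under their UI names; when checkpoint B is absent,
# only model A is exposed in the selector.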
loaded_models = {
    NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
    NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
} if model_b is not None else {
    NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
}
print("\nAll models loaded.")
# cosyvoice = CosyVoice(cosyvoice_model_path)
# cosyvoice.eval().cuda()

# Encode the lab logo as Base64 for embedding in the page HTML
with open("./tts/assert/实验室.png", "rb") as image_file:
    encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
# Task name -> model prompt mapping
TASK_PROMPT_MAPPING = {
"empathetic_s2s_dialogue with think": "THINK",
"empathetic_s2s_dialogue no think": "s2s_no_think",
"empathetic_s2t_dialogue with think": "s2t_think",
"empathetic_s2t_dialogue no think": "s2t_no_think",
"ASR (Automatic Speech Recognition)": "转录这段音频中的语音内容为文字。",
"SRWT (Speech Recognition with Timestamps)": "请识别音频内容,并对所有英文词和中文字进行时间对齐,标注格式为<>,时间精度0.1秒。",
"VED (Vocal Event Detection)(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "请将音频转化为文字,并在末尾添加相关音频事件标签,标签格式为<>。",
"SER (Speech Emotion Recognition)(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请将音频内容转录成文字记录,并在记录末尾标注情感标签,以<>表示。",
"SSR (Speaking Style Recognition)(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "请将音频中的讲话内容转化为文字,并在结尾处注明风格标签,用<>表示。",
"SGC (Speaker Gender Classification)(类别:female,male)": "请将音频转录为文字,并在文本末尾标注性别标签,标签格式为<>。",
"SAP (Speaker Age Prediction)(类别:child、adult和old)": "请将这段音频转录成文字,并在末尾加上年龄标签,格式为<>。",
"STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。",
"Only Age Prediction(类别:child、adult和old)": "请根据音频分析发言者的年龄并输出年龄标签,标签格式为<>。",
"Only Gender Classification(类别:female,male)": "根据下述音频内容判断说话者性别,返回性别标签,格式为<>.",
"Only Style Recognition(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "对于以下音频,请直接判断风格并返回风格标签,标签格式为<>。",
"Only Emotion Recognition(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请鉴别音频中的发言者情感并标出,标签格式为<>。",
"Only Event Detection(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "对音频进行标签化,返回音频事件标签,标签格式为<>。",
"ASR+AGE+GENDER": '请将这段音频进行转录,并在转录完成的文本末尾附加<年龄> <性别>标签。',
"AGE+GENDER": "请识别以下音频发言者的年龄和性别.",
"ASR+STYLE+AGE+GENDER": "请对以下音频内容进行转录,并在文本结尾分别添加<风格>、<年龄>、<性别>标签。",
"STYLE+AGE+GENDER": "请对以下音频进行分析,识别说话风格、说话者年龄和性别。",
"ASR with punctuations": "需对提供的语音文件执行文本转换,同时为转换结果补充必要的标点。",
"ASR EVENT AGE GENDER": "请将以下音频内容进行转录,并在转录完成的文本末尾分别附加<音频事件>、<年龄>、<性别>标签。",
"ASR EMOTION AGE GENDER": "请将下列音频内容进行转录,并在转录文本的末尾分别添加<情感>、<年龄>、<性别>标签。",
}
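# The prompt strings above are passed to the model verbatim, so they are kept
# in Chinese; only the dict keys appear in the UI dropdown. Example lookup:
#   prompt = TASK_PROMPT_MAPPING["ASR (Automatic Speech Recognition)"]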
prompt_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="prompt.wav")
prompt_audio_choices = [
    {"name": "Human-like", "value": prompt_path},
]
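# Pre-load each reference audio once at 22.05 kHz (CosyVoice's working rate)
# so it is not re-read from disk on every request.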
prompt_audio_cache = {}
for item in prompt_audio_choices:
prompt_audio_cache[item["value"]] = load_wav(item["value"], 22050)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
@spaces.GPU
def true_decode_func(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):
    """
    Single entry point that dispatches every task type to the right inference call.
    """
    print(f"wav_path: {input_wav_path}, prompt: {input_prompt}")
    # Audio is mandatory unless this is a text-only (TTS / T2T) task
    if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
        print("No audio provided, and this is not a TTS or T2T task")
        return "Error: audio input is required", None
    if input_wav_path is not None:
        waveform, sample_rate = torchaudio.load(input_wav_path)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)  # downmix to mono
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
        waveform = waveform.squeeze(0)
        window = torch.hann_window(400)
        stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2
        # Build the mel filterbank at the resampled rate (16 kHz), not the source rate
        filters = torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
        mel_spec = filters @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        feat = log_spec.transpose(0, 1)
        feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
        feat = feat.unsqueeze(0).cuda()
        feat = feat.to(torch.bfloat16)
        print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
    else:
        feat = None
        feat_lens = None
    # Common setup: put the selected model in eval mode on the GPU
    start_time = time.time()
    res_text = None
    model.eval().cuda()
    try:
        # 1. TTS task
        if input_prompt.endswith("_TTS"):
            text_for_tts = input_prompt.replace("_TTS", "")
            # T2S inference: generate speech tokens and drop the trailing EOS token
            res_tensor = model.generate_tts(device=torch.device("cuda"), text=text_for_tts)[0]
            res_token_list = res_tensor.tolist()
            res_text = res_token_list[:-1]
            print(f"T2S inference took {time.time() - start_time:.2f} s")
        # 2. Custom-prompt task: plain S2T with a user-supplied instruction
        elif input_prompt.endswith("_self_prompt"):
            prompt = input_prompt.replace("_self_prompt", "")
            res_text = model.generate(
                wavs=feat,
                wavs_len=feat_lens,
                prompt=prompt,
                cache_implementation="static"
            )[0]
            print(f"S2T inference took {time.time() - start_time:.2f} s")
        # 3. T2T task: text in, text out
        elif input_prompt.endswith("_T2T"):
            question_txt = input_prompt.replace("_T2T", "")
            print(f'Starting T2T inference, question_txt: {question_txt}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate_text2text(
                device=torch.device("cuda"),
                text=question_txt
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"T2T inference took {time.time() - start_time:.2f} s")
        # 4. S2S task without thinking
        elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
                              "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
                              "s2s_no_think"]:
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(
                wavs=feat,
                wavs_len=feat_lens,
            )
            if is_npu: torch_npu.npu.synchronize()
            # Pack "<reply text>|<speech token list>" for the post-processing below
            res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
            print(f"S2S inference took {time.time() - start_time:.2f} s")
        # 5. S2S task with thinking
        elif input_prompt == "THINK":
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(
                wavs=feat,
                wavs_len=feat_lens,
            )
            if is_npu: torch_npu.npu.synchronize()
            res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
            print(f"S2S (think) inference took {time.time() - start_time:.2f} s")
        # 6. S2T chat task without thinking
        elif input_prompt == "s2t_no_think":
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate4chat(
                wavs=feat,
                wavs_len=feat_lens,
                cache_implementation="static"
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"S2T4Chat inference took {time.time() - start_time:.2f} s")
        # 7. S2T chat task with thinking
        elif input_prompt == "s2t_think":
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate4chat_think(
                wavs=feat,
                wavs_len=feat_lens,
                cache_implementation="static"
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"S2T4Chat (think) inference took {time.time() - start_time:.2f} s")
        # 8. Default: S2T with one of the predefined task prompts
        else:
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate(
                wavs=feat,
                wavs_len=feat_lens,
                prompt=input_prompt,
                cache_implementation="static"
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"S2T inference took {time.time() - start_time:.2f} s")
        # Normalize fine-grained age tags to the coarse <adult> label (string outputs only;
        # TTS returns a token list)
        if isinstance(res_text, str):
            res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>", "<adult>")
    except Exception as e:
        print(f"Inference failed: {str(e)}")
        return f"Error: {str(e)}", None
    output_res = res_text

    # Post-process the output
    wav_path_output = input_wav_path
    if task_choice == "TTS task" or "empathetic_s2s_dialogue" in task_choice:
        if isinstance(output_res, list):  # TTS case: a raw speech-token list
            # cosyvoice.eval()
            # time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            # wav_path = f"./tmp/{time_str}.wav"
            # wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
            # wav_path_output = get_wav_from_token_list(output_res, prompt_speech_data)
            output_res = "Generated tokens: " + str(output_res)
        elif isinstance(output_res, str) and "|" in output_res:  # S2S case: "text|token list"
            try:
                text_res, token_list_str = output_res.split("|")
                token_list = json.loads(token_list_str)
                # cosyvoice.eval()
                time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                wav_path = f"./tmp/{time_str}.wav"
                # wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
                # wav_path_output = get_wav_from_token_list(token_list, prompt_speech_data)
                output_res = text_res
            except (ValueError, json.JSONDecodeError) as e:
                print(f"Error while processing the S2S output: {e}")
                output_res = f"Error: failed to parse the model output - {output_res}"
    return output_res, wav_path_output
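# A minimal usage sketch (hypothetical wav path; must run under the @spaces.GPU
# context):
#   text, wav = true_decode_func(model_a, tokenizer_a, "sample.wav",
#                                TASK_PROMPT_MAPPING["ASR (Automatic Speech Recognition)"],
#                                "ASR (Automatic Speech Recognition)", None)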
def save_to_jsonl(if_correct, wav, prompt, res):
data = {
"if_correct": if_correct,
"wav": wav,
"task": prompt,
"res": res
}
with open("results.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps(data, ensure_ascii=False) + "\n")
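# Each feedback record becomes one JSON object per line in results.jsonl, e.g.
# (hypothetical values):
#   {"if_correct": "correct", "wav": "/tmp/x.wav", "task": "ASR (Automatic Speech Recognition)", "res": "..."}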
def download_audio(input_wav_path):
return input_wav_path if input_wav_path else None
# --- Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown(
f"""
<div style="display: flex; align-items: center; justify-content: center; text-align: center;">
<h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
OSUM Speech Understanding Model Test
</h1>
</div>
"""
)
    # Model selector
    with gr.Row():
        model_selector = gr.Radio(
            choices=list(loaded_models.keys()),  # options come from the loaded-model dict
            value=NAME_A,  # default
            label="Inference model",
            interactive=True
        )
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            audio_input = gr.Audio(label="Audio (record or upload)", sources=["microphone", "upload"], type="filepath", visible=True)
        with gr.Column(scale=1, min_width=300):
            output_text = gr.Textbox(label="Output", lines=6, placeholder="Results will appear here...",
                                     interactive=False)
    with gr.Row():
        task_dropdown = gr.Dropdown(label="Task",
                                    choices=list(TASK_PROMPT_MAPPING.keys()) + ["Custom prompt", "TTS task", "T2T task"],
                                    value="empathetic_s2s_dialogue with think")
        prompt_speech_dropdown = gr.Dropdown(label="Reference audio (prompt_speech)",
                                             choices=[(item["name"], item["value"]) for item in prompt_audio_choices],
                                             value=prompt_audio_choices[0]["value"], visible=True)
    custom_prompt_input = gr.Textbox(label="Custom task prompt", placeholder="Enter a custom task prompt...", visible=False)
    tts_input = gr.Textbox(label="TTS input text", placeholder="Enter text for the TTS task...", visible=False)
    t2t_input = gr.Textbox(label="T2T input text", placeholder="Enter text for the T2T task...", visible=False)
    audio_player = gr.Audio(label="Audio playback", type="filepath", interactive=False)
    with gr.Row():
        download_button = gr.DownloadButton("Download audio", variant="secondary",
                                            elem_classes=["button-height", "download-button"])
        submit_button = gr.Button("Run", variant="primary", elem_classes=["button-height", "submit-button"])
    with gr.Row(visible=False) as confirmation_row:
        gr.Markdown("Is the result correct?")
        confirmation_buttons = gr.Radio(choices=["correct", "incorrect"], label="", interactive=True, container=False,
                                        elem_classes="confirmation-buttons")
        save_button = gr.Button("Submit feedback", variant="secondary")
    # Footer
    with gr.Row():
        with gr.Column(scale=1, min_width=800):
            gr.Markdown(f"""...""")  # footer HTML omitted
def show_confirmation(output_res, input_wav_path, input_prompt):
return gr.update(visible=True), output_res, input_wav_path, input_prompt
def save_result(if_correct, wav, prompt, res):
save_to_jsonl(if_correct, wav, prompt, res)
return gr.update(visible=False)
    # handle_submit receives the selected model name as its first argument
    def handle_submit(selected_model_name, input_wav_path, task_choice, custom_prompt, tts_text, t2t_text,
                      prompt_speech):
        # 1. Look up the model and tokenizer for the selected name
        print(f"User selected: {selected_model_name}")
        model_info = loaded_models[selected_model_name]
        model_to_use = model_info["model"]
        tokenizer_to_use = model_info["tokenizer"]
        # 2. Build the prompt; text-only tasks are tagged with a suffix
        prompt_speech_data = prompt_audio_cache.get(prompt_speech)
        if task_choice == "Custom prompt":
            input_prompt = custom_prompt + "_self_prompt"
        elif task_choice == "TTS task":
            input_prompt = tts_text + "_TTS"
        elif task_choice == "T2T task":
            input_prompt = t2t_text + "_T2T"
        else:
            input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "unknown task type")
        output_res, wav_path_output = true_decode_func(model_to_use, tokenizer_to_use, input_wav_path, input_prompt,
                                                       task_choice, prompt_speech_data)
        return output_res, wav_path_output
    # --- Event bindings: show only the textbox that matches the selected task ---
    task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "Custom prompt"), inputs=task_dropdown,
                         outputs=custom_prompt_input)
    task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "TTS task"), inputs=task_dropdown,
                         outputs=tts_input)
    task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "T2T task"), inputs=task_dropdown,
                         outputs=t2t_input)
    submit_button.click(
        fn=handle_submit,
        # model_selector comes first so handle_submit knows which model to run
        inputs=[model_selector, audio_input, task_dropdown, custom_prompt_input, tts_input, t2t_input,
                prompt_speech_dropdown],
        outputs=[output_text, audio_player]
    ).then(
        fn=show_confirmation,
        inputs=[output_text, audio_input, task_dropdown],
        outputs=[confirmation_row, output_text, audio_input, task_dropdown]
    )
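    # The .then() chain runs after inference returns, revealing the feedback row.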
download_button.click(fn=download_audio, inputs=[audio_input], outputs=[download_button])
save_button.click(fn=save_result, inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
outputs=confirmation_row)
# Warm up the models
print("Warming up models...")
# Launch the Gradio UI
print("\nLaunching the Gradio UI...")
if __name__ == '__main__':
demo.launch() |