import base64
import datetime
import json
import logging
import os
import librosa
import spaces
import gradio as gr
import sys
import time
import traceback
import torch
import torchaudio
from common_utils.utils4infer import get_feat_from_wav_path, load_model_and_tokenizer, token_list2wav
sys.path.insert(0, '.')
sys.path.insert(0, './tts')
sys.path.insert(0, './tts/third_party/Matcha-TTS')
from patches import modelling_qwen2_infer_gpu  # apply inference patches
from tts.cosyvoice.cli.cosyvoice import CosyVoice
from tts.cosyvoice.utils.file_utils import load_wav
is_npu = False
try:
    import torch_npu
    is_npu = True  # only mark NPU available if the import succeeds
except ImportError:
    print("torch_npu is not available. If you want to use an NPU, please install it.")
from huggingface_hub import hf_hub_download
# Download the .pt checkpoints from Hugging Face
CHECKPOINT_PATH_A = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="language_think_final.pt")
CHECKPOINT_PATH_B = None
# CHECKPOINT_PATH_B= hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="tag_think_final.pt")
CONFIG_PATH = "./conf/ct_config.yaml"
NAME_A = "language_think"
NAME_B = "tag_think"
cosyvoice_model_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="CosyVoice-300M-25Hz.tar")
# Extract the tar archive into the current directory
os.system(f"tar -xvf {cosyvoice_model_path}")
print("Finished extracting the CosyVoice model archive")
cosyvoice_model_path = "./CosyVoice-300M-25Hz"
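# A safer stdlib alternative to the os.system("tar ...") call above (a sketch;
# it assumes the archive unpacks into ./CosyVoice-300M-25Hz just like the shell
# command does, and hf_hub_download returns the cached archive path):
#   import tarfile
#   with tarfile.open(hf_hub_download(repo_id="ASLP-lab/OSUM-EChat",
#                                     filename="CosyVoice-300M-25Hz.tar")) as tf:
#       tf.extractall(".")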
print("开始加载模型 A...")
model_a, tokenizer_a = load_model_and_tokenizer(CHECKPOINT_PATH_A, CONFIG_PATH)
model_a
print("\n开始加载模型 B...")
if CHECKPOINT_PATH_B is not None:
model_b, tokenizer_b = load_model_and_tokenizer(CHECKPOINT_PATH_B, CONFIG_PATH)
model_b.eval().cuda()
else:
model_b, tokenizer_b = None, None
loaded_models = {
    NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
    NAME_B: {"model": model_b, "tokenizer": tokenizer_b},
} if model_b is not None else {
    NAME_A: {"model": model_a, "tokenizer": tokenizer_a},
}
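# loaded_models acts as a name -> {"model", "tokenizer"} registry; the Gradio
# model selector below looks models up here by NAME_A / NAME_B.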
print("\n所有模型已加载完毕。")
# cosyvoice = CosyVoice(cosyvoice_model_path)
# cosyvoice.eval().cuda()
# Encode the lab logo image as Base64
with open("./tts/assert/实验室.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
# Task-name -> prompt mapping
TASK_PROMPT_MAPPING = {
"empathetic_s2s_dialogue with think": "THINK",
"empathetic_s2s_dialogue no think": "s2s_no_think",
"empathetic_s2t_dialogue with think": "s2t_think",
"empathetic_s2t_dialogue no think": "s2t_no_think",
"ASR (Automatic Speech Recognition)": "转录这段音频中的语音内容为文字。",
"SRWT (Speech Recognition with Timestamps)": "请识别音频内容,并对所有英文词和中文字进行时间对齐,标注格式为<>,时间精度0.1秒。",
"VED (Vocal Event Detection)(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "请将音频转化为文字,并在末尾添加相关音频事件标签,标签格式为<>。",
"SER (Speech Emotion Recognition)(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请将音频内容转录成文字记录,并在记录末尾标注情感标签,以<>表示。",
"SSR (Speaking Style Recognition)(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "请将音频中的讲话内容转化为文字,并在结尾处注明风格标签,用<>表示。",
"SGC (Speaker Gender Classification)(类别:female,male)": "请将音频转录为文字,并在文本末尾标注性别标签,标签格式为<>。",
"SAP (Speaker Age Prediction)(类别:child、adult和old)": "请将这段音频转录成文字,并在末尾加上年龄标签,格式为<>。",
"STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。",
"Only Age Prediction(类别:child、adult和old)": "请根据音频分析发言者的年龄并输出年龄标签,标签格式为<>。",
"Only Gender Classification(类别:female,male)": "根据下述音频内容判断说话者性别,返回性别标签,格式为<>.",
"Only Style Recognition(类别:新闻科普,恐怖故事,童话故事,客服,诗歌散文,有声书,日常口语,其他)": "对于以下音频,请直接判断风格并返回风格标签,标签格式为<>。",
"Only Emotion Recognition(类别:sad,anger,neutral,happy,surprise,fear,disgust,和other)": "请鉴别音频中的发言者情感并标出,标签格式为<>。",
"Only Event Detection(类别:laugh,cough,cry,screaming,sigh,throat clearing,sneeze,other)": "对音频进行标签化,返回音频事件标签,标签格式为<>。",
"ASR+AGE+GENDER": '请将这段音频进行转录,并在转录完成的文本末尾附加<年龄> <性别>标签。',
"AGE+GENDER": "请识别以下音频发言者的年龄和性别.",
"ASR+STYLE+AGE+GENDER": "请对以下音频内容进行转录,并在文本结尾分别添加<风格>、<年龄>、<性别>标签。",
"STYLE+AGE+GENDER": "请对以下音频进行分析,识别说话风格、说话者年龄和性别。",
"ASR with punctuations": "需对提供的语音文件执行文本转换,同时为转换结果补充必要的标点。",
"ASR EVENT AGE GENDER": "请将以下音频内容进行转录,并在转录完成的文本末尾分别附加<音频事件>、<年龄>、<性别>标签。",
"ASR EMOTION AGE GENDER": "请将下列音频内容进行转录,并在转录文本的末尾分别添加<情感>、<年龄>、<性别>标签。",
}
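# Note: most entries above map a dropdown label to the literal (Chinese) prompt
# string passed to the model, while the four empathetic-dialogue entries map to
# sentinel values ("THINK", "s2s_no_think", "s2t_think", "s2t_no_think") that
# true_decode_fuc() dispatches on instead of treating them as prompts.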
prompt_path = hf_hub_download(repo_id="ASLP-lab/OSUM-EChat", filename="prompt.wav")
prompt_audio_choices = [
    {"name": "拟人", "value": prompt_path},
]
prompt_audio_cache = {}
for item in prompt_audio_choices:
prompt_audio_cache[item["value"]] = load_wav(item["value"], 22050)
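# Prompt audio is resampled to 22050 Hz above; this is assumed to match the
# sample rate CosyVoice-300M expects for its reference (prompt) speech.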
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
@spaces.GPU
def true_decode_fuc(model, tokenizer, input_wav_path, input_prompt, task_choice, prompt_speech_data):
"""
合并所有推理逻辑的单个函数,处理所有任务类型
"""
print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
# 检查音频输入合法性
if input_wav_path is None and not input_prompt.endswith(("_TTS", "_T2T")):
print("音频信息未输入,且不是T2S或T2T任务")
return "错误:需要音频输入"
    if input_wav_path is not None:
        waveform, sample_rate = torchaudio.load(input_wav_path)
        # Downmix to mono and resample to 16 kHz
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
        waveform = waveform.squeeze(0)
        # Whisper-style 80-bin log-mel spectrogram (25 ms window, 10 ms hop at 16 kHz)
        window = torch.hann_window(400)
        stft = torch.stft(waveform, 400, 160, window=window, return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2
        # Build the mel filterbank at the resampled rate (16 kHz), not the original one
        filters = torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
        mel_spec = filters @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        feat = log_spec.transpose(0, 1)
        feat_lens = torch.tensor([feat.shape[0]], dtype=torch.int64).cuda()
        feat = feat.unsqueeze(0).cuda()
        feat = feat.to(torch.bfloat16)
        print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
    else:
        feat = None
        feat_lens = None
    # Common setup: timing and model device placement
    start_time = time.time()
    res_text = None
    model.eval().cuda()  # use the model selected by the caller, not the global model_a
try:
        # 1. TTS task
        if input_prompt.endswith("_TTS"):
            text_for_tts = input_prompt.replace("_TTS", "")
            # T2S inference: produces a tensor of speech tokens
            res_tensor = model.generate_tts(device=torch.device("cuda"), text=text_for_tts)[0]
            res_token_list = res_tensor.tolist()
            res_text = res_token_list[:-1]  # drop the final token
            print(f"T2S inference took {time.time() - start_time:.2f} s")
        # 2. Custom-prompt task
        elif input_prompt.endswith("_self_prompt"):
            prompt = input_prompt.replace("_self_prompt", "")
            # S2T inference with the user-supplied prompt
            res_text = model.generate(
                wavs=feat,
                wavs_len=feat_lens,
                prompt=prompt,
                cache_implementation="static"
            )[0]
            print(f"S2T inference took {time.time() - start_time:.2f} s")
        # 3. T2T task
        elif input_prompt.endswith("_T2T"):
            question_txt = input_prompt.replace("_T2T", "")
            print(f'Starting T2T inference, question_txt: {question_txt}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate_text2text(
                device=torch.device("cuda"),
                text=question_txt
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"T2T inference took {time.time() - start_time:.2f} s")
        # 4. S2S task without thinking
        elif input_prompt in ["识别语音内容,并以文字方式作出回答。",
                              "请推断对这段语音回答时的情感,标注情感类型,撰写流畅自然的聊天回复,并生成情感语音token。",
                              "s2s_no_think"]:
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            output_text, text_res, speech_res = model.generate_s2s_no_stream_with_repetition_penalty(
                wavs=feat,
                wavs_len=feat_lens,
            )
            if is_npu: torch_npu.npu.synchronize()
            # Pack text and speech tokens as "text|token_list" for the output stage
            res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
            print(f"S2S inference took {time.time() - start_time:.2f} s")
        # 5. S2S task with thinking
        elif input_prompt == "THINK":
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            output_text, text_res, speech_res = model.generate_s2s_no_stream_think_with_repetition_penalty(
                wavs=feat,
                wavs_len=feat_lens,
            )
            if is_npu: torch_npu.npu.synchronize()
            res_text = f'{output_text[0]}|{str(speech_res[0].tolist()[1:])}'
            print(f"S2S (think) inference took {time.time() - start_time:.2f} s")
        # 6. S2T chat task without thinking
        elif input_prompt == "s2t_no_think":
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate4chat(
                wavs=feat,
                wavs_len=feat_lens,
                cache_implementation="static"
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"S2T4Chat inference took {time.time() - start_time:.2f} s")
        # 7. S2T chat task with thinking
        elif input_prompt == "s2t_think":
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate4chat_think(
                wavs=feat,
                wavs_len=feat_lens,
                cache_implementation="static"
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"S2T4Chat (think) inference took {time.time() - start_time:.2f} s")
        # 8. Default S2T task
        else:
            print(f'feat shape: {feat.shape}, feat_lens: {feat_lens}')
            if is_npu: torch_npu.npu.synchronize()
            res_text = model.generate(
                wavs=feat,
                wavs_len=feat_lens,
                prompt=input_prompt,
                cache_implementation="static"
            )[0]
            if is_npu: torch_npu.npu.synchronize()
            print(f"S2T inference took {time.time() - start_time:.2f} s")
        # Normalize age tags to <adult> (string results only; TTS yields a token list)
        if isinstance(res_text, str):
            res_text = res_text.replace("<youth>", "<adult>").replace("<middle_age>", "<adult>").replace("<middle>", "<adult>")
    except Exception as e:
        print(f"Inference failed: {str(e)}")
        traceback.print_exc()
        return f"错误:{str(e)}", None
    output_res = res_text
    # Post-process the output
    wav_path_output = input_wav_path
    if task_choice == "TTS任务" or "empathetic_s2s_dialogue" in task_choice:
        if isinstance(output_res, list):  # TTS case
            # Speech synthesis is currently disabled; return the raw speech tokens.
            # wav_path_output = token_list2wav(output_res, prompt_speech_data, wav_path, cosyvoice)
            output_res = "生成的token: " + str(output_res)
        elif isinstance(output_res, str) and "|" in output_res:  # S2S case
            try:
                text_res, token_list_str = output_res.split("|", 1)  # split on the first "|" only
                token_list = json.loads(token_list_str)
                # Speech synthesis is currently disabled; only the text reply is returned.
                # time_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                # wav_path = f"./tmp/{time_str}.wav"
                # wav_path_output = token_list2wav(token_list, prompt_speech_data, wav_path, cosyvoice)
                output_res = text_res
            except (ValueError, json.JSONDecodeError) as e:
                print(f"Failed to parse S2S output: {e}")
                output_res = f"错误:无法解析模型输出 - {output_res}"
return output_res, wav_path_output
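# A minimal local usage sketch for true_decode_fuc (hypothetical; assumes a mono
# WAV at ./demo.wav and a CUDA device provided by the @spaces.GPU context):
#   text, wav = true_decode_fuc(model_a, tokenizer_a, "./demo.wav",
#                               TASK_PROMPT_MAPPING["ASR (Automatic Speech Recognition)"],
#                               "ASR (Automatic Speech Recognition)", None)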
def save_to_jsonl(if_correct, wav, prompt, res):
data = {
"if_correct": if_correct,
"wav": wav,
"task": prompt,
"res": res
}
with open("results.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps(data, ensure_ascii=False) + "\n")
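# A minimal sketch (hypothetical helper, not wired into the UI) for reading back
# the feedback records written by save_to_jsonl(), e.g. for offline review:
def load_feedback(path="results.jsonl"):
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                records.append(json.loads(line))
    return records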
def download_audio(input_wav_path):
return input_wav_path if input_wav_path else None
# --- Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown(
f"""
<div style="display: flex; align-items: center; justify-content: center; text-align: center;">
<h1 style="font-family: 'Arial', sans-serif; color: #014377; font-size: 32px; margin-bottom: 0; display: inline-block; vertical-align: middle;">
OSUM Speech Understanding Model Test
</h1>
</div>
"""
)
    # Model selector
    with gr.Row():
        model_selector = gr.Radio(
            choices=list(loaded_models.keys()),  # options come from the loaded-model registry
            value=NAME_A,  # default selection
            label="选择推理模型",
            interactive=True
        )
with gr.Row():
with gr.Column(scale=1, min_width=300):
audio_input = gr.Audio(label="录音", sources=["microphone", "upload"], type="filepath", visible=True)
with gr.Column(scale=1, min_width=300):
output_text = gr.Textbox(label="输出结果", lines=6, placeholder="生成的结果将显示在这里...",
interactive=False)
with gr.Row():
task_dropdown = gr.Dropdown(label="任务",
choices=list(TASK_PROMPT_MAPPING.keys()) + ["自主输入文本", "TTS任务", "T2T任务"],
value="empathetic_s2s_dialogue with think")
prompt_speech_dropdown = gr.Dropdown(label="参考音频(prompt_speech)",
choices=[(item["name"], item["value"]) for item in prompt_audio_choices],
value=prompt_audio_choices[0]["value"], visible=True)
custom_prompt_input = gr.Textbox(label="自定义任务提示", placeholder="请输入自定义任务提示...", visible=False)
tts_input = gr.Textbox(label="TTS输入文本", placeholder="请输入TTS任务的文本...", visible=False)
t2t_input = gr.Textbox(label="T2T输入文本", placeholder="请输入T2T任务的文本...", visible=False)
audio_player = gr.Audio(label="播放音频", type="filepath", interactive=False)
with gr.Row():
download_button = gr.DownloadButton("下载音频", variant="secondary",
elem_classes=["button-height", "download-button"])
submit_button = gr.Button("开始处理", variant="primary", elem_classes=["button-height", "submit-button"])
with gr.Row(visible=False) as confirmation_row:
gr.Markdown("请判断结果是否正确:")
confirmation_buttons = gr.Radio(choices=["正确", "错误"], label="", interactive=True, container=False,
elem_classes="confirmation-buttons")
save_button = gr.Button("提交反馈", variant="secondary")
with gr.Row():
with gr.Column(scale=1, min_width=800):
gr.Markdown(f"""...""") # 省略底部HTML
def show_confirmation(output_res, input_wav_path, input_prompt):
return gr.update(visible=True), output_res, input_wav_path, input_prompt
def save_result(if_correct, wav, prompt, res):
save_to_jsonl(if_correct, wav, prompt, res)
return gr.update(visible=False)
    # handle_submit receives the selected model name as its first argument
def handle_submit(selected_model_name, input_wav_path, task_choice, custom_prompt, tts_text, t2t_text,
prompt_speech):
        # 1. Look up the chosen model and tokenizer in the registry
        print(f"Selected model: {selected_model_name}")
        model_info = loaded_models[selected_model_name]
        model_to_use = model_info["model"]
        tokenizer_to_use = model_info["tokenizer"]
        # 2. Build the prompt for the selected task
prompt_speech_data = prompt_audio_cache.get(prompt_speech)
if task_choice == "自主输入文本":
input_prompt = custom_prompt + "_self_prompt"
elif task_choice == "TTS任务":
input_prompt = tts_text + "_TTS"
elif task_choice == "T2T任务":
input_prompt = t2t_text + "_T2T"
else:
input_prompt = TASK_PROMPT_MAPPING.get(task_choice, "未知任务类型")
        output_res, wav_path_output = true_decode_fuc(model_to_use, tokenizer_to_use, input_wav_path, input_prompt,
                                                      task_choice, prompt_speech_data)
return output_res, wav_path_output
    # --- Event bindings ---
task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "自主输入文本"), inputs=task_dropdown,
outputs=custom_prompt_input)
task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "TTS任务"), inputs=task_dropdown,
outputs=tts_input)
task_dropdown.change(fn=lambda choice: gr.update(visible=choice == "T2T任务"), inputs=task_dropdown,
outputs=t2t_input)
submit_button.click(
fn=handle_submit,
        # model_selector is included as the first input
inputs=[model_selector, audio_input, task_dropdown, custom_prompt_input, tts_input, t2t_input,
prompt_speech_dropdown],
outputs=[output_text, audio_player]
).then(
fn=show_confirmation,
inputs=[output_text, audio_input, task_dropdown],
outputs=[confirmation_row, output_text, audio_input, task_dropdown]
)
download_button.click(fn=download_audio, inputs=[audio_input], outputs=[download_button])
save_button.click(fn=save_result, inputs=[confirmation_buttons, audio_input, task_dropdown, output_text],
outputs=confirmation_row)
# Note: no explicit warm-up is performed here; the first request absorbs any
# remaining initialization cost.
print("\nLaunching the Gradio interface...")
if __name__ == '__main__':
demo.launch()