# OSUM-EChat / wenet / osum_echat / wav_instrcut_tools.py
# Uploaded by xlgeng, commit 841f290 ("开始部署", i.e. "start deployment").
# Handles the wav & text instruction (question/answer prompt) processing.
import codecs
import json
import random
from gxl_ai_utils.utils import utils_file
import torch
import os
# "ASR + X" attribute names; each one also names a "<task_name>.json" prompt file
# in the (currently disabled) loaders further down this module.
task_names = ["age", "gender", "style", "emotion", "caption"]
# Maps a task tag string to the key used in the question/answer lookup tables.
# Insertion order is kept: the five "<TRANSCRIBE> <X>" tags first, then the
# five attribute-only "<X>" tags.
map_dict = {
    # Transcription combined with one attribute -> plain task name.
    **{
        "<TRANSCRIBE> <STYLE>": "style",
        "<TRANSCRIBE> <AGE>": "age",
        "<TRANSCRIBE> <GENDER>": "gender",
        "<TRANSCRIBE> <EMOTION>": "emotion",
        "<TRANSCRIBE> <CAPTION>": "caption",
    },
    # Attribute only (no transcription) -> task name with "_only_X" suffix.
    **{
        "<STYLE>": "style_only_X",
        "<AGE>": "age_only_X",
        "<GENDER>": "gender_only_X",
        "<EMOTION>": "emotion_only_X",
        "<CAPTION>": "caption_only_X",
    },
}
# Directory of the question prompt JSON files for the "ASR + X" tasks
# (the disabled loader below reads "<task_name>.json" from it).
question_info_dir = "/mnt/sfs/asr/code/osum_osum_echat/examples/wenetspeech/whisper/conf/language_follow_prompt/Q_with_asr"
# Alternative Windows development path (disabled):
# question_info_dir = "D:\osum_echat_workspace_2\osum_using_ofical\osum_osum_echat\examples\wenetspeech\whisper\conf\language_follow_prompt\Q_with_asr"
# Lookup tables read by get_question_prompt_by_task / get_question_wav_path_by_task.
# Every loader in this section is commented out, so this module itself leaves
# both tables empty. NOTE(review): confirm whether another module fills them.
question_big_dict = {}
question_wav_big_dict = {}
# Disabled loader: text question prompts for the "ASR + X" tasks
# (values of each task JSON).
# for task_name in task_names:
#     task_json_path = os.path.join(question_info_dir, f"{task_name}.json")
#     dict_i = utils_file.load_dict_from_json(task_json_path)
#     question_strs = list(dict_i.values())
#     question_big_dict[task_name] = question_strs
#
#
# Disabled loader: spoken question wav paths for the "ASR + X" tasks,
# one path per JSON key and per speaker index 1..20.
# for task_name in task_names:
#     task_json_path = os.path.join(question_info_dir, f"{task_name}.json")
#     dict_i = utils_file.load_dict_from_json(task_json_path)
#     question_strs = list(dict_i.keys())  # use the dict keys as the list of question strings
#     question_wav_big_dict[task_name] = []
#     for i in range(1, 21):
#         new_question_strs = [f"/mnt/sfs/asr/update_data/Q_with_asr_added_by_20250406/{task_name}_{x}_speaker_{i}.wav" for x in question_strs]
#         question_wav_big_dict[task_name].extend(new_question_strs)
# Re-pointed to the attribute-only ("only X") question prompt directory; this
# overwrites the "Q_with_asr" value assigned above.
question_info_dir = "/mnt/sfs/asr/code/osum_osum_echat/examples/wenetspeech/whisper/conf/language_follow_prompt/Q"
# Alternative Windows development path (disabled):
# question_info_dir = "D:\osum_echat_workspace_2\osum_using_ofical\osum_osum_echat\examples\wenetspeech\whisper\conf\language_follow_prompt\Q"
# Disabled loader: text question prompts for the "only X" tasks, stored under
# the "<task_name>_only_X" key.
# for task_name in task_names:
#     task_json_path = os.path.join(question_info_dir, f"{task_name}.json")
#     dict_i = utils_file.load_dict_from_json(task_json_path)
#     question_strs = list(dict_i.values())
#     question_big_dict[task_name+"_only_X"] = question_strs
# Disabled loader: spoken question wav paths for the "only X" tasks,
# one path per JSON key and per speaker index 1..20.
# for task_name in task_names:
#     task_json_path = os.path.join(question_info_dir, f"{task_name}.json")
#     dict_i = utils_file.load_dict_from_json(task_json_path)
#     question_strs = list(dict_i.keys())  # use the dict keys as the list of question strings
#     question_wav_big_dict[task_name+"_only_X"] = []
#     for i in range(1,21):
#         new_question_strs = [f"/mnt/sfs/asr/update_data/Q_only_X_added_by_20250406/{task_name}_{x}_speaker_{i}.wav" for x in question_strs]
#         question_wav_big_dict[task_name+"_only_X"].extend(new_question_strs)
# Answer templates.
# Directory of the answer template JSON files for the "ASR + X" tasks
# (the disabled loader below reads "<task_name>.json" from it).
answer_info_dir = "/mnt/sfs/asr/code/osum_osum_echat/examples/wenetspeech/whisper/conf/language_follow_prompt/Q_with_asr/answer"
# Alternative Windows development path (disabled):
# answer_info_dir = "D:\osum_echat_workspace_2\osum_using_ofical\osum_osum_echat\examples\wenetspeech\whisper\conf\language_follow_prompt\Q_with_asr\\answer"
# Lookup table read by get_answer_prompt_by_task, which indexes it by task key
# and then by answer tag. Both loaders below are commented out, so this module
# itself leaves it empty. NOTE(review): confirm whether another module fills it.
answer_big_dict = {}
# Disabled loader: answer templates for the "ASR + X" tasks.
# for task_name in task_names:
#     task_json_path = os.path.join(answer_info_dir, f"{task_name}.json")
#     dict_i = utils_file.load_dict_from_json(task_json_path)
#     answer_big_dict[task_name] = dict_i
# Re-pointed to the attribute-only ("only X") answer directory; this overwrites
# the "Q_with_asr/answer" value assigned above.
answer_info_dir = "/mnt/sfs/asr/code/osum_osum_echat/examples/wenetspeech/whisper/conf/language_follow_prompt/Q/answer"
# Alternative Windows development path (disabled):
# answer_info_dir = "D:\osum_echat_workspace_2\osum_using_ofical\osum_osum_echat\examples\wenetspeech\whisper\conf\language_follow_prompt\Q\\answer"
# Disabled loader: answer templates for the "only X" tasks, stored under the
# "<task_name>_only_X" key.
# for task_name in task_names:
#     task_json_path = os.path.join(answer_info_dir, f"{task_name}.json")
#     dict_i = utils_file.load_dict_from_json(task_json_path)
#     answer_big_dict[task_name+"_only_X"] = dict_i
def get_question_prompt_by_task(task_tag):
    """Pick a random text question prompt for the given task tag.

    Args:
        task_tag: Tag string used as a key into ``map_dict``
            (for example ``"<TRANSCRIBE> <AGE>"`` or ``"<AGE>"``).

    Returns:
        One entry chosen with ``random.choice`` from the list registered in
        ``question_big_dict`` for the task, or ``None`` when the tag is not
        in ``map_dict`` or no prompts are registered for it.
    """
    task_key = map_dict.get(task_tag)
    if task_key is None:
        return None
    candidates = question_big_dict.get(task_key, [])
    # No prompts loaded for this task: report it with None instead of raising.
    return random.choice(candidates) if len(candidates) > 0 else None
def get_question_wav_path_by_task(task_tag):
    """Pick a random spoken-question wav path for the given task tag.

    Args:
        task_tag: Tag string used as a key into ``map_dict``
            (for example ``"<TRANSCRIBE> <CAPTION>"`` or ``"<CAPTION>"``).

    Returns:
        One entry chosen with ``random.choice`` from the list registered in
        ``question_wav_big_dict`` for the task, or ``None`` when the tag is
        not in ``map_dict`` or no entries are registered for it.
    """
    task_key = map_dict.get(task_tag)
    if task_key is None:
        return None
    wav_candidates = question_wav_big_dict.get(task_key, [])
    # No wav entries loaded for this task: report it with None instead of raising.
    return random.choice(wav_candidates) if len(wav_candidates) > 0 else None
def get_answer_prompt_by_task(task_tag, answer_tag, asr_txt=None):
    """Pick a random answer template for a task and fill in the ASR text.

    Args:
        task_tag: Tag string used as a key into ``map_dict``.
        answer_tag: Key into the task's answer dictionary (the label whose
            answer templates should be used).
        asr_txt: Optional transcription. When given, the chosen template must
            contain ``"{}"`` and the text is inserted with ``str.format``;
            when omitted, the chosen template must not contain ``"{}"``.

    Returns:
        The answer string, or ``None`` (after logging an error through
        ``utils_file.logging_error``) when the task tag is unknown, no
        templates exist for ``answer_tag``, or the randomly chosen template
        does not match the presence/absence of ``asr_txt``.
    """
    task_key = map_dict.get(task_tag)
    if task_key is None:
        utils_file.logging_error(f"task_tag {task_tag} not in map_dict")
        return None

    templates = answer_big_dict.get(task_key, {}).get(answer_tag, [])
    if len(templates) == 0:
        utils_file.logging_error(f"answer_tag {answer_tag} not in answer_dict of {task_tag}")
        return None

    # Draw one template at random, then validate it against asr_txt.
    template = random.choice(templates)
    has_placeholder = "{}" in template

    if asr_txt is None:
        if has_placeholder:
            # Template expects a transcription but none was supplied.
            utils_file.logging_error(f"answer_str {template} contain","{}", 'but asr_txt is None')
            return None
        return template

    if not has_placeholder:
        # A transcription was supplied but the template has nowhere to put it.
        utils_file.logging_error(f"answer_str {template} not contain","{}", 'but asr_txt is not None')
        return None
    return template.format(asr_txt)
if __name__ == '__main__':
    # Manual smoke test: call each lookup four times and print what it returns.
    task_tag = "<CAPTION>"
    type_tag = "church_bells"
    demo_calls = (
        (get_question_wav_path_by_task, (task_tag,)),
        (get_question_prompt_by_task, (task_tag,)),
        (get_answer_prompt_by_task, (task_tag, type_tag)),
        (get_answer_prompt_by_task, (task_tag, type_tag, "I am a child")),
    )
    for lookup, call_args in demo_calls:
        for _ in range(4):
            question_str = lookup(*call_args)
            print(question_str)
    # Disabled one-off maintenance snippet: rewrite caption.json with lower-cased keys.
    # caption_json_path = "D:\osum_echat_workspace_2\osum_using_ofical\osum_osum_echat\examples\wenetspeech\whisper\conf\language_follow_prompt\Q_with_asr\\answer\caption.json"
    # caption_dict = utils_file.load_dict_from_json(caption_json_path)
    # new_caption_dict = {}
    # utils_file.print_list(list(caption_dict.keys()))
    # for k, v in caption_dict.items():
    #     new_k = k.lower()
    #     new_caption_dict[new_k] = v
    # with codecs.open(caption_json_path, 'w', encoding='utf-8') as f:
    #     json.dump(new_caption_dict, f, ensure_ascii=False, indent=4)