# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import copy
import json
import librosa
import logging
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torch.nn.functional as F
from gxl_ai_utils.utils import utils_file
from torch.nn.utils.rnn import pad_sequence

from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer

torchaudio.set_audio_backend("soundfile")

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])

# Patterns shared by process_text / process_text2, compiled once.
_CJK_SPACING_RE = re.compile(r'\s*([\u4e00-\u9fff])\s*')
_LT_SPACING_RE = re.compile(r'\s*<\s*')
_GT_SPACING_RE = re.compile(r'\s*>\s*')


def _normalize_text(text, lowercase):
    """Strip whitespace around CJK characters and angle brackets.

    Args:
        text: Input string.
        lowercase: When true, the text is lower-cased after the CJK
            whitespace removal and before the bracket clean-up.

    Returns:
        The normalized string.
    """
    # 1. Remove whitespace on both sides of every CJK ideograph.
    text = _CJK_SPACING_RE.sub(r'\1', text)
    # 2. Optionally lower-case (only affects cased characters).
    if lowercase:
        text = text.lower()
    # 3. Remove whitespace on both sides of '<' and '>'.
    text = _LT_SPACING_RE.sub('<', text)
    text = _GT_SPACING_RE.sub('>', text)
    return text


def process_text(text):
    """Normalize a transcript: drop spaces around CJK characters and
    around '<' / '>', and lower-case the text.

    Args:
        text: Input string.

    Returns:
        The normalized string.
    """
    return _normalize_text(text, lowercase=True)


def process_text2(text, task_tag):
    """Same normalization as ``process_text``, but the text is lower-cased
    only when ``task_tag`` is the empty string.

    Args:
        text: Input string.
        task_tag: Task tag; any value other than ``""`` keeps the original
            letter case.

    Returns:
        The normalized string.
    """
    return _normalize_text(text, lowercase=(task_tag == ""))


def insert_at_position(lst, item_str, position, is_wav: bool):
    """Place ``item_str`` at the 1-based ``position`` of ``lst`` (in place).

    If ``lst`` is shorter than ``position`` it is padded with the
    placeholder string ``"-1"`` first.  If the target slot already holds a
    dict, only its ``'wav'`` (``is_wav`` true) or ``'txt'`` (``is_wav``
    false) entry is overwritten from ``item_str``; otherwise the slot is
    replaced by ``item_str`` itself.

    Args:
        lst: List to modify.
        item_str: Dict providing at least the ``'wav'`` or ``'txt'`` key
            that is being merged.
        position: 1-based slot index.
        is_wav: Selects which key is merged into an occupied slot.

    Returns:
        The same ``lst`` object.

    Raises:
        ValueError: If ``position`` is smaller than 1.
        AssertionError: If the occupied slot is not a dict.
    """
    if position < 1:
        # A non-positive position would silently index from the end of the
        # list (or fail on an empty one); reject it explicitly.
        raise ValueError(f'position must be >= 1, got {position}')
    index = position - 1
    # Pad in one go with the "-1" placeholder up to the target length.
    if len(lst) < position:
        lst.extend(["-1"] * (position - len(lst)))
    if lst[index] != "-1":
        assert isinstance(lst[index], dict), f'lst[index] is not a dict {lst[index]}'
        merged_key = 'wav' if is_wav else 'txt'
        lst[index][merged_key] = item_str[merged_key]
    else:
        lst[index] = item_str
    return lst


def _parse_indexed_tag(tag, s):
    """Match ``s`` against ``<tag>_<digits>``.

    Returns:
        ``(True, number)`` on a full match, ``(False, -1)`` otherwise.
    """
    match = re.fullmatch(rf"{tag}_(\d+)", s)
    if match:
        return True, int(match.group(1))
    return False, -1


def check_wav_format(s):
    """Return ``(True, N)`` if ``s`` is exactly ``wav_N``, else ``(False, -1)``."""
    return _parse_indexed_tag("wav", s)


def check_txt_format(s):
    """Return ``(True, N)`` if ``s`` is exactly ``txt_N``, else ``(False, -1)``."""
    return _parse_indexed_tag("txt", s)


def load_dict_list_from_jsonl(jsonl_file_path) -> list:
    """Load a JSON-lines file into a list, one parsed object per line.

    Blank lines are ignored; lines that are not valid JSON are logged and
    skipped.

    Args:
        jsonl_file_path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        List of the successfully parsed objects, in file order.
    """
    records = []
    # Built-in open() only breaks lines on \n / \r / \r\n.  codecs.open()
    # additionally splits on characters such as U+2028, which may legally
    # appear unescaped inside a JSON string and would cut a record in two.
    with open(jsonl_file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except ValueError as e:  # json.JSONDecodeError is a ValueError
                logging.warning('Skipping malformed JSON at %s:%d: %s',
                                jsonl_file_path, line_no, e)
    return records


def url_opener(data):
    """ Give url or local file, return file descriptor
        Inplace operation.

        Each ``sample['src']`` must have the form
        ``<combine_path>|<shard_path>``.  ``combine_path`` is either ``-``
        (no side metadata) or a JSON-lines file whose records are indexed
        by their ``'key'`` field into ``big_dict``.  ``shard_path`` is a
        local file (or ``file://`` URL), or any other URL, which is
        streamed through ``wget``.

        Args:
            data(Iterable[dict]): samples holding ``src``

        Returns:
            Iterable[{src, stream, big_dict}], plus ``process`` when the
            shard is streamed through ``wget``.
    """
    for sample in data:
        assert 'src' in sample
        # TODO(Binbin Zhang): support HTTP
        url = sample['src']
        if "|" not in url:
            utils_file.logging_error(f'OSUM-EChat url_opener 错误,url格式不正确 {url}, 不含有|')
            continue
        # Split once only, so a '|' inside the shard URL cannot break the
        # two-way unpacking.
        combine_path, shard_path = url.split('|', 1)
        if combine_path == "-":
            big_dict = None
        else:
            try:
                dict_list = load_dict_list_from_jsonl(combine_path)
            except Exception as e:
                utils_file.logging_error(f'OSUM-EChat url_opener 错误,加载combine_path {combine_path} 失败 {e}')
                dict_list = []
            # Records without a 'key' cannot be looked up later; skip them
            # instead of aborting the whole pipeline with a KeyError.
            big_dict = {
                item['key']: item
                for item in dict_list
                if isinstance(item, dict) and 'key' in item
            }
        try:
            pr = urlparse(shard_path)
            # local file
            if pr.scheme == '' or pr.scheme == 'file':
                # open() does not understand the "file://" prefix, so use
                # the parsed path component for file URLs.
                local_path = pr.path if pr.scheme == 'file' else shard_path
                stream = open(local_path, 'rb')
            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
            else:
                # Argument list without a shell: the shard path is never
                # interpreted by /bin/sh.
                process = Popen(['wget', '-q', '-O', '-', shard_path],
                                stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
        except Exception as ex:
            logging.warning('Failed to open {}: {}'.format(shard_path, ex))
            continue
        sample.update(stream=stream, big_dict=big_dict)
        yield sample


def _load_mono_waveform(file_obj):
    """Decode audio with torchaudio and average multi-channel audio down to
    a single channel.

    Returns:
        ``(waveform, sample_rate)`` where ``waveform`` keeps a leading
        channel dimension of size 1 when it was down-mixed.
    """
    waveform, sample_rate = torchaudio.load(file_obj)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    return waveform, sample_rate


def _fill_missing_wav(example):
    """Give an example without audio a random-noise placeholder waveform
    (160000 samples tagged as 16000 Hz, i.e. 10 seconds).

    NOTE(review): presumably this serves text-only samples - confirm.
    """
    if 'wav' not in example:
        example['wav'] = torch.randn(1, 160000)
        example['sample_rate'] = 16000


def _fill_example_from_member(example, prefix, postfix, file_obj, big_dict):
    """Store the content of one tar member in ``example``.

    When ``big_dict`` is given, ``txt`` / ``task`` / ``extra`` are taken
    from ``big_dict[prefix]`` and the member's own bytes are ignored for
    those fields; audio is always decoded from the member.  Members with
    any other extension are ignored.

    Raises:
        ValueError: ``big_dict`` is given but has no usable record for
            ``prefix``, or the ``extra`` member is not valid JSON.
    """
    info_dict = None
    if big_dict is not None:
        # The record is validated for every member of the sample, as soon
        # as side metadata is in use.
        if prefix not in big_dict:
            raise ValueError(f'{prefix} not in big_dict')
        info_dict = big_dict[prefix]
        if 'txt' not in info_dict or 'task' not in info_dict or 'extra' not in info_dict:
            raise ValueError(f'info_dict {info_dict} not include txt, task, extra')

    if postfix in ('txt', 'task', 'extra'):
        if info_dict is not None:
            example[postfix] = info_dict[postfix]
        else:
            content = file_obj.read().decode('utf8').strip()
            # 'extra' is stored as JSON text; the other two are plain text.
            example[postfix] = json.loads(content) if postfix == 'extra' else content
    elif postfix in AUDIO_FORMAT_SETS:
        waveform, sample_rate = _load_mono_waveform(file_obj)
        example['wav'] = waveform
        example['sample_rate'] = sample_rate


def tar_file_and_group_full_data(data, total_num=0):
    """ Expand a stream of open tar files into a stream of tar file contents.
        And groups the file with same prefix

        Consecutive members sharing the same name-without-extension form
        one example.  A member that fails to parse invalidates its example,
        which is then dropped.  Examples without audio receive a random
        placeholder waveform; inside the loop a missing ``txt`` is replaced
        by ``''``, whereas the last example of a shard is only emitted if
        it has ``txt``.

        Args:
            data: Iterable[{src, stream}], optionally with ``big_dict``
                (side metadata keyed by prefix) and ``process``.
            total_num: Total number of shards, used only in the progress
                log message.

        Returns:
            Iterable[{key, wav, txt, sample_rate, history}], plus ``task``
            and ``extra`` when available.
    """
    index = 0
    for sample in data:
        index += 1
        assert 'stream' in sample
        stream = None
        try:
            stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
            # .get(): tolerate samples from an opener that sets no metadata.
            big_dict = sample.get('big_dict')
            prev_prefix = None
            example = {'history': []}
            valid = True
            for tarinfo in stream:
                if tarinfo.isdir():
                    # Directory entries carry no data (extractfile() gives
                    # None) and usually no extension; skipping them keeps
                    # one such entry from aborting the whole shard.
                    continue
                name = tarinfo.name
                pos = name.rfind('.')
                assert pos > 0, f' pos {pos}'
                prefix, postfix = name[:pos], name[pos + 1:]
                if prev_prefix is not None and prefix != prev_prefix:
                    # Prefix changed: the previous example is complete.
                    example['key'] = prev_prefix
                    if valid:
                        example.setdefault('txt', '')
                        _fill_missing_wav(example)
                        yield example
                    example = {'history': []}
                    valid = True
                with stream.extractfile(tarinfo) as file_obj:
                    try:
                        _fill_example_from_member(example, prefix, postfix,
                                                  file_obj, big_dict)
                    except Exception as ex:
                        # Best effort: drop this example, keep the shard.
                        valid = False
                        utils_file.logging_error('error to parse ex: {}'.format(ex))
                prev_prefix = prefix
            # Flush the last example of the shard.  It must have parsed
            # cleanly, like every example flushed inside the loop.
            if prev_prefix is not None and valid and 'txt' in example:
                example['key'] = prev_prefix
                _fill_missing_wav(example)
                utils_file.logging_info(f'*************OSUM-EChat SHUCHU第{index}/{total_num}个tar包')
                yield example
        except Exception as ex:
            logging.warning(
                'In tar_file_and_group: {} when processing {}'.format(
                    ex, sample['src']))
        finally:
            if stream is not None:
                stream.close()
            if 'process' in sample:
                sample['process'].communicate()
            sample['stream'].close()


# for history
# elif check_wav_format(postfix)[0]:
#     position = check_wav_format(postfix)[1]
#     waveform, sample_rate = torchaudio.load(file_obj)
#     if sample_rate != 16000:
#         waveform = torchaudio.transforms.Resample(
#             orig_freq=sample_rate, new_freq=16000)(waveform)
#     feat = do_compute_log_mel_spectrogram(waveform)
#     history_item = {'wav': feat, "txt": "", 'position': position}
#     insert_at_position(example['history'], history_item, position, is_wav=True)
#
# elif check_txt_format(postfix)[0]:
#     position = check_txt_format(postfix)[1]
#     txt_str = file_obj.read().decode(
#         'utf8').strip()
#     history_item = {'wav': '', "txt": txt_str, 'position':
position} # insert_at_position(example['history'], history_item, position, is_wav=False) def parse_raw(data): """ Parse key/wav/txt from json line Args: data: Iterable[str], str is a json line has key/wav/txt Returns: Iterable[{key, wav, txt, sample_rate}] """ for sample in data: assert 'src' in sample json_line = sample['src'] obj = json.loads(json_line) assert 'key' in obj assert 'wav' in obj assert 'txt' in obj key = obj['key'] wav_file = obj['wav'] txt = obj['txt'] try: if 'start' in obj: assert 'end' in obj sample_rate = torchaudio.info(wav_file).sample_rate start_frame = int(obj['start'] * sample_rate) end_frame = int(obj['end'] * sample_rate) waveform, _ = torchaudio.load(filepath=wav_file, num_frames=end_frame - start_frame, frame_offset=start_frame) else: waveform, sample_rate = torchaudio.load(wav_file) # 检查音频的维度 num_channels = waveform.shape[0] # 如果音频是多通道的,则进行通道平均 if num_channels > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) example = copy.deepcopy(obj) # copy and keep all the fields example['wav'] = waveform # overwrite wav example['sample_rate'] = sample_rate yield example except Exception as ex: logging.warning('Failed to read {}'.format(wav_file)) def parse_speaker(data, speaker_table_path): speaker_dict = {} with open(speaker_table_path, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split() speaker_dict[arr[0]] = int(arr[1]) for sample in data: assert 'speaker' in sample speaker = sample['speaker'] sample['speaker'] = speaker_dict.get(speaker, 0) yield sample global_style_dict = { "朗读": "新闻科普", "科普百科": "新闻科普", "悬疑恐怖": "恐怖故事", "童话故事": "童话故事", "客服": "客服", "诗歌": "诗歌散文", "散文": "诗歌散文", "武侠评书": "有声书", "小说": "有声书", "历史": "有声书", "科幻": "有声书", "对话": "日常口语", "口语": "日常口语", "幽默": "其他", "其他": "其他", } # global_chat_dict = utils_file.load_dict_from_scp("/mnt/sfs/asr/update_data/3500_chat_asr/osum_echat_all_3500_with_asr_chat.scp") asr_X_set = {" ", "