"""Speech-LLM helpers: target construction, prompt sampling, label merging,
streaming (interleaved text/speech) token layouts and split-vocabulary embedding."""
import random
from typing import List, Optional, Sequence, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.utils.common import pad_list
from gxl_ai_utils.utils import utils_file

# Streaming layout: each step emits up to 6 text tokens followed by up to 18
# speech tokens, i.e. a 1:3 text/speech ratio (checked by _check_text_speech_ratio).
_TEXT_CHUNK_SIZE = 6
_SPEECH_CHUNK_SIZE = 18
# Marker inserted once after the text stream by the legacy layout
# (make_streaming_mode_from_s2s_old); presumably an end-of-text id -- TODO confirm.
_LEGACY_TEXT_END_TOKEN = 999
# Default token-id sequence searched for by make_streaming_mode_from_s2s4think.
# Presumably the tokenization of an end-of-thinking marker -- TODO confirm.
_THINK_PREFIX_END_TOKENS = (13708, 766, 835, 29)


def add_sos_eos4speech_llm(ys_pad: torch.Tensor, sos: int, eos: int,
                           ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder input/output targets for the speech LLM.

    Unlike the classic ``add_sos_eos`` no ``<sos>`` is prepended: the input is
    the unpadded target itself (re-padded with ``eos``), and the output is the
    target followed by a single ``<eos>`` (padded with ``ignore_id``).

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax).
            Every element equal to ``ignore_id`` is removed, wherever it occurs.
        sos (int): index of <sos>. Unused; kept for interface compatibility.
        eos (int): index of <eos>; also used as the padding value of ``ys_in``.
        ignore_id (int): index of padding.

    Returns:
        ys_in (torch.Tensor): (B, L), padded with ``eos``
        ys_out (torch.Tensor): (B, L + 1), padded with ``ignore_id``
        where L is presumably the longest unpadded row (computed by pad_list).

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos4speech_llm(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, 11, 11],
                [ 7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    eos_token = torch.tensor([eos], dtype=torch.long, requires_grad=False,
                             device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # strip padding from every row
    ys_in = list(ys)
    ys_out = [torch.cat([y, eos_token], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


# Lazily-loaded prompt table, shared by every call to get_prompt_by_task.
global_prompt_dict = None


def get_prompt_by_task(task_name):
    """Return a randomly chosen prompt for ``task_name``.

    The prompt table is loaded from ``conf/prompt.yaml`` on first use and cached
    in the module-level ``global_prompt_dict``. Sampling a random entry gives
    prompt diversity across training samples.

    Args:
        task_name: key into the prompt table. Its value is assumed to be a
            non-empty sequence of prompts -- TODO confirm against the YAML.

    Returns:
        One randomly selected element of ``global_prompt_dict[task_name]``.
    """
    global global_prompt_dict
    if global_prompt_dict is None:
        # Relative path: presumably resolved against the current working directory.
        global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt.yaml')
    return random.choice(global_prompt_dict[task_name])


def merge_labels_with_valid_adjacent(
        labels_embeds1, labels_target1, labels_mask1,
        labels_embeds2, labels_target2, labels_mask2,
        pad_value=0, ignore_id=-100
):
    """Merge two label groups so that valid positions are packed to the front.

    For each sample, the valid (mask != 0) positions of group 1 are followed
    directly by the valid positions of group 2. All remaining slots are moved
    to the end and filled with padding.

    Args:
        labels_embeds1 (Tensor): group-1 embeddings, (B, L1, D)
        labels_target1 (Tensor): group-1 targets, (B, L1)
        labels_mask1 (Tensor): group-1 mask, (B, L1); nonzero marks valid
        labels_embeds2 (Tensor): group-2 embeddings, (B, L2, D)
        labels_target2 (Tensor): group-2 targets, (B, L2)
        labels_mask2 (Tensor): group-2 mask, (B, L2); nonzero marks valid
        pad_value (int | float): fill value for padded embedding rows
        ignore_id (int): fill value for padded targets (e.g. IGNORE_ID)

    Returns:
        merged_embeds (Tensor): (B, L1 + L2, D), same dtype as the valid embeddings
        merged_target (Tensor): (B, L1 + L2)
        merged_mask (Tensor): (B, L1 + L2); padded positions are zero/False
    """
    batch_size = labels_embeds1.size(0)
    max_len = labels_embeds1.size(1) + labels_embeds2.size(1)
    embed_dim = labels_embeds1.size(2)

    if batch_size == 0:
        # torch.stack() rejects an empty list; return correctly shaped empties.
        return (labels_embeds1.new_zeros((0, max_len, embed_dim)),
                labels_target1.new_full((0, max_len), ignore_id),
                labels_mask1.new_zeros((0, max_len)))

    merged_embeds = []
    merged_target = []
    merged_mask = []
    for i in range(batch_size):
        # Indices of the valid positions of each group.
        valid1 = torch.where(labels_mask1[i])[0]
        valid2 = torch.where(labels_mask2[i])[0]

        # Concatenate the valid parts back to back.
        valid_embeds = torch.cat([labels_embeds1[i, valid1], labels_embeds2[i, valid2]], dim=0)
        valid_target = torch.cat([labels_target1[i, valid1], labels_target2[i, valid2]], dim=0)
        valid_mask = torch.cat([labels_mask1[i, valid1], labels_mask2[i, valid2]], dim=0)

        pad_length = max_len - valid_embeds.size(0)
        # Build the padding in the embeddings' own dtype: otherwise a float
        # pad_value yields a float32 pad and torch.cat type promotion would
        # silently upcast half/bfloat16 embeddings.
        embed_pad = torch.full((pad_length, embed_dim), pad_value,
                               dtype=valid_embeds.dtype, device=labels_embeds1.device)
        # Default (int64 for an int ignore_id) dtype kept on purpose so the
        # merged target dtype matches the historical behaviour.
        target_pad = torch.full((pad_length,), ignore_id, device=labels_target1.device)
        mask_pad = torch.zeros(pad_length, dtype=torch.bool, device=labels_mask1.device)

        merged_embeds.append(torch.cat([valid_embeds, embed_pad], dim=0))
        merged_target.append(torch.cat([valid_target, target_pad], dim=0))
        merged_mask.append(torch.cat([valid_mask, mask_pad], dim=0))

    return (torch.stack(merged_embeds, dim=0),
            torch.stack(merged_target, dim=0),
            torch.stack(merged_mask, dim=0))


def _lens_to_ints(lens) -> List[int]:
    """Convert a (B,) length tensor (or any int sequence) into Python ints.

    Done once per call so the per-sample loops avoid repeated 0-d tensor
    arithmetic (and the host syncs it implies on accelerator devices).
    """
    if torch.is_tensor(lens):
        return [int(x) for x in lens.tolist()]
    return [int(x) for x in lens]


def _check_text_speech_ratio(text_lens: List[int], speech_lens: List[int],
                             batch_size: int) -> None:
    """Assert every sample has at least 3 speech tokens per text token.

    The ratio 3 equals _SPEECH_CHUNK_SIZE / _TEXT_CHUNK_SIZE.
    """
    for i in range(batch_size):
        assert text_lens[i] * 3 <= speech_lens[i], \
            f"Batch {i}: Text tokens * 3 should be <= speech tokens"


def _interleave_text_speech(text_list: List[int], speech_list: List[int],
                            text_start: int, text_len: int) -> List[int]:
    """Interleave text and speech chunks, then append the leftover speech.

    Starting at ``text_start`` and going up to ``text_len``, each step emits the
    next _TEXT_CHUNK_SIZE text tokens followed by the next _SPEECH_CHUNK_SIZE
    speech tokens. Slicing past the end of a list yields fewer (or zero)
    tokens, so short final chunks and exhausted speech are handled naturally.
    """
    out = []
    speech_idx = 0
    for text_idx in range(text_start, text_len, _TEXT_CHUNK_SIZE):
        out.extend(text_list[text_idx:text_idx + _TEXT_CHUNK_SIZE])
        out.extend(speech_list[speech_idx:speech_idx + _SPEECH_CHUNK_SIZE])
        speech_idx += _SPEECH_CHUNK_SIZE
    out.extend(speech_list[speech_idx:])
    return out


def _pad_streaming_sequences(token_lists: List[List[int]], device):
    """Right-pad per-sample token lists with 0 into an int64 (B, T) tensor.

    Returns:
        (padded tokens on ``device``, per-sample lengths on ``device``)
    """
    padded = pad_sequence([torch.tensor(t, dtype=torch.int64) for t in token_lists],
                          batch_first=True, padding_value=0).to(device)
    lens = torch.tensor([len(t) for t in token_lists], device=device)
    return padded, lens


def make_streaming_mode_from_s2s_old(text_tokens_padded, text_tokens_lens,
                                     speech_tokens_padded, speech_tokens_lens, ):
    """Legacy streaming layout: interleave text/speech with a 999 end marker.

    Asserts text_len * 3 <= speech_len for every sample. Then, per sample, it
    emits 6 text tokens, 18 speech tokens, 6 text tokens, 18 speech tokens, and
    so on until the text is used up, followed by all remaining speech tokens.
    One marker (999) is inserted after the text stream:
    - directly after a short (< 6) final text chunk, before its speech chunk;
    - after the final speech chunk when text_len is an exact multiple of 6.
    No marker is added for an empty text.

    Args:
        text_tokens_padded: (B, Lmax)
        text_tokens_lens: (B,)
        speech_tokens_padded: (B, Lmax2)
        speech_tokens_lens: (B,)

    Returns:
        streaming_mode_tokens_padded: (B, T) int64, zero-padded, T <= Lmax + Lmax2 + 1
        streaming_mode_tokens_lens: (B,)
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device
    text_lens = _lens_to_ints(text_tokens_lens)
    speech_lens = _lens_to_ints(speech_tokens_lens)
    _check_text_speech_ratio(text_lens, speech_lens, batch_size)

    streaming_lists = []
    for i in range(batch_size):
        text_len = text_lens[i]
        speech_len = speech_lens[i]
        text_list = text_tokens_padded[i, :text_len].tolist()
        speech_list = speech_tokens_padded[i, :speech_len].tolist()

        streaming_tokens = []
        text_idx = 0
        speech_idx = 0
        while text_idx < text_len:
            chunk_size = min(_TEXT_CHUNK_SIZE, text_len - text_idx)
            streaming_tokens.extend(text_list[text_idx:text_idx + chunk_size])
            text_idx += chunk_size
            # A short chunk can only be the last one: close the text right away.
            if chunk_size < _TEXT_CHUNK_SIZE:
                streaming_tokens.append(_LEGACY_TEXT_END_TOKEN)

            speech_chunk = min(_SPEECH_CHUNK_SIZE, speech_len - speech_idx)
            streaming_tokens.extend(speech_list[speech_idx:speech_idx + speech_chunk])
            speech_idx += speech_chunk

            # Text length is a multiple of the chunk size: close after this speech chunk.
            if text_idx == text_len and text_len % _TEXT_CHUNK_SIZE == 0:
                streaming_tokens.append(_LEGACY_TEXT_END_TOKEN)

        streaming_tokens.extend(speech_list[speech_idx:])
        streaming_lists.append(streaming_tokens)

    return _pad_streaming_sequences(streaming_lists, device)


def make_streaming_mode_from_s2s(text_tokens_padded, text_tokens_lens,
                                 speech_tokens_padded, speech_tokens_lens, ):
    """Interleave text and speech tokens into a streaming sequence.

    Asserts text_len * 3 <= speech_len for every sample. Then, per sample, it
    emits 6 text tokens, 18 speech tokens, 6 text tokens, 18 speech tokens, and
    so on until the text is used up, followed by all remaining speech tokens.

    Args:
        text_tokens_padded: (B, Lmax)
        text_tokens_lens: (B,)
        speech_tokens_padded: (B, Lmax2)
        speech_tokens_lens: (B,)

    Returns:
        streaming_mode_tokens_padded: (B, T) int64, zero-padded, T <= Lmax + Lmax2
        streaming_mode_tokens_lens: (B,)
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device
    text_lens = _lens_to_ints(text_tokens_lens)
    speech_lens = _lens_to_ints(speech_tokens_lens)
    _check_text_speech_ratio(text_lens, speech_lens, batch_size)

    streaming_lists = []
    for i in range(batch_size):
        text_list = text_tokens_padded[i, :text_lens[i]].tolist()
        speech_list = speech_tokens_padded[i, :speech_lens[i]].tolist()
        streaming_lists.append(
            _interleave_text_speech(text_list, speech_list, 0, text_lens[i]))

    return _pad_streaming_sequences(streaming_lists, device)


def make_streaming_mode_from_s2s4think(
        text_tokens_padded,
        text_tokens_lens,
        speech_tokens_padded,
        speech_tokens_lens,
        target_seq: Optional[Sequence[int]] = None,
):
    """Streaming layout that emits a "thinking" text prefix un-interleaved.

    Each sample's text is searched for the first occurrence of ``target_seq``.
    If it is found, all text up to and including it is emitted first, in one
    piece. The remaining text is then interleaved with speech exactly as in
    make_streaming_mode_from_s2s (6 text / 18 speech per step, then leftover
    speech). Without a match, the output equals make_streaming_mode_from_s2s.

    Args:
        text_tokens_padded: (B, Lmax)
        text_tokens_lens: (B,)
        speech_tokens_padded: (B, Lmax2)
        speech_tokens_lens: (B,)
        target_seq: token ids marking the end of the prefix; defaults to
            _THINK_PREFIX_END_TOKENS.

    Returns:
        streaming_mode_tokens_padded: (B, T) int64, zero-padded, T <= Lmax + Lmax2
        streaming_mode_tokens_lens: (B,)
    """
    target = list(_THINK_PREFIX_END_TOKENS if target_seq is None else target_seq)
    target_len = len(target)

    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device
    text_lens = _lens_to_ints(text_tokens_lens)
    speech_lens = _lens_to_ints(speech_tokens_lens)
    _check_text_speech_ratio(text_lens, speech_lens, batch_size)

    streaming_lists = []
    for i in range(batch_size):
        text_len = text_lens[i]
        text_list = text_tokens_padded[i, :text_len].tolist()
        speech_list = speech_tokens_padded[i, :speech_lens[i]].tolist()

        # Sliding-window search for the first occurrence of the target sequence;
        # prefix_end stays 0 (no prefix) when it is absent.
        prefix_end = 0
        for j in range(text_len - target_len + 1):
            if text_list[j:j + target_len] == target:
                prefix_end = j + target_len
                break

        streaming_tokens = text_list[:prefix_end]
        streaming_tokens.extend(
            _interleave_text_speech(text_list, speech_list, prefix_end, text_len))
        streaming_lists.append(streaming_tokens)

    return _pad_streaming_sequences(streaming_lists, device)


def do_embedding_for_two_embeds(input_token_ids, dividing_id, embedding1, embedding2):
    """Embed ids from two concatenated vocabularies as if they were one table.

    Ids below ``dividing_id`` are looked up in ``embedding1``. Ids at or above it
    are shifted down by ``dividing_id`` and looked up in ``embedding2``.

    Args:
        input_token_ids: integer ids of any shape (typically (B, Lmax)), in
            [0, vocab_size1 + vocab_size2)
        dividing_id: int, size of the first vocabulary
        embedding1: nn.Embedding(vocab_size1, embedding_dim)
        embedding2: nn.Embedding(vocab_size2, embedding_dim)

    Returns:
        Tensor of shape (*input_token_ids.shape, embedding_dim), with
        embedding1's dtype and device.
    """
    input_token_ids = input_token_ids.to(torch.int64)
    in_first_vocab = input_token_ids < dividing_id
    in_second_vocab = ~in_first_vocab
    out_dtype = embedding1.weight.dtype

    res_output = torch.zeros(*input_token_ids.shape, embedding1.embedding_dim,
                             dtype=out_dtype, device=embedding1.weight.device)
    res_output[in_first_vocab] = embedding1(input_token_ids[in_first_vocab]).to(out_dtype)
    res_output[in_second_vocab] = embedding2(
        input_token_ids[in_second_vocab] - dividing_id).to(out_dtype)
    return res_output


def do_convert_num2text(num_str: str):
    """Convert Arabic numerals in a string to Chinese numerals.

    Args:
        num_str: input string; surrounding whitespace is stripped first.

    Returns:
        The result of ``cn2an.transform(num_str, "an2cn")``.
    """
    import cn2an  # optional dependency, imported only when needed
    num_str = num_str.strip()
    return cn2an.transform(num_str, "an2cn")


def _do_test_for_streaming_chat(device='npu:0'):
    """Manual smoke test: print shapes/devices of the streaming and embedding helpers.

    Args:
        device: device to run on. Defaults to 'npu:0' as before; an NPU device
            presumably needs the Ascend torch_npu backend, so pass 'cpu' elsewhere.
    """
    device = torch.device(device)

    # test make_streaming_mode_from_s2s
    text_tokens_padded = torch.randint(0, 100, (3, 10)).to(device)
    text_tokens_lens = torch.tensor([5, 7, 3]).to(device)
    speech_tokens_padded = torch.randint(100, 200, (3, 150)).to(device)
    speech_tokens_lens = torch.tensor([100, 120, 80]).to(device)
    streaming_mode_tokens_padded, streaming_mode_tokens_lens = make_streaming_mode_from_s2s(
        text_tokens_padded, text_tokens_lens, speech_tokens_padded, speech_tokens_lens)
    print(streaming_mode_tokens_padded.shape)
    print(streaming_mode_tokens_padded.device)
    print(streaming_mode_tokens_lens)
    print(streaming_mode_tokens_lens.device)

    # test do_embedding_for_two_embeds
    input_token_ids = torch.randint(0, 100, (3, 10)).to(device)
    dividing_id = 50
    embedding1 = torch.nn.Embedding(50, 10).to(device)
    embedding2 = torch.nn.Embedding(50, 10).to(device)
    res_output = do_embedding_for_two_embeds(input_token_ids, dividing_id, embedding1, embedding2)
    print(res_output.shape)
    print(res_output.device)

    # Slicing past the end of a tensor is safe (returns the available tail);
    # the chunked interleaving relies on this.
    a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, ]).to(device)
    print(a[3:1000])


if __name__ == '__main__':
    pass