Spaces:
Running
on
Zero
Running
on
Zero
import random | |
from typing import Tuple | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from wenet.utils.common import pad_list | |
from gxl_ai_utils.utils import utils_file | |
def add_sos_eos4speech_llm(ys_pad: torch.Tensor, sos: int, eos: int,
                           ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder input/target label tensors for the speech LLM.

    Unlike the classic ``add_sos_eos`` recipe, no ``<sos>`` is prepended:
    the input keeps the valid tokens unchanged and the target is the valid
    tokens followed by a single ``<eos>``.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>; accepted for signature compatibility
            only, it is not used by this function
        eos (int): index of <eos>; appended to every target and also used
            as the padding value of the input
        ignore_id (int): index of padding in ``ys_pad``; also used as the
            padding value of the target

    Returns:
        ys_in (torch.Tensor) : (B, Lmax)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos4speech_llm(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, 11, 11],
                [ 7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    eos_tail = torch.tensor([eos],
                            dtype=torch.long,
                            requires_grad=False,
                            device=ys_pad.device)
    # Strip the padding so each row only holds its valid tokens.
    ys = [y[y != ignore_id] for y in ys_pad]
    ys_out = [torch.cat([y, eos_tail], dim=0) for y in ys]
    # NOTE(review): pad_list is assumed to right-pad every sequence to the
    # longest one with the given value, as the example above shows.
    return pad_list(ys, eos), pad_list(ys_out, ignore_id)
global_prompt_dict = None


def get_prompt_by_task(task_name):
    """Return a randomly chosen prompt for the given task.

    The task-to-prompts mapping is loaded from ``conf/prompt.yaml`` on the
    first call and cached in the module-level ``global_prompt_dict``; one
    entry of the task's prompt list is then drawn uniformly at random so
    that prompts vary between calls.

    Args:
        task_name: key of the task in the prompt dictionary.

    Returns:
        One prompt taken from the list stored under ``task_name``.
    """
    global global_prompt_dict
    if global_prompt_dict is None:
        # Lazy one-time load; later calls reuse the cached dictionary.
        global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt.yaml')
    candidates = global_prompt_dict[task_name]
    picked = random.randint(0, len(candidates) - 1)
    return candidates[picked]
import torch | |
def merge_labels_with_valid_adjacent(
        labels_embeds1, labels_target1, labels_mask1,
        labels_embeds2, labels_target2, labels_mask2,
        pad_value=0, ignore_id=-100
):
    """Merge two label groups so that valid positions are contiguous.

    For every batch item the valid (mask != 0) positions of group 1 are
    followed directly by the valid positions of group 2; all remaining
    positions are moved to the end and filled with padding.

    Args:
        labels_embeds1 (Tensor): embeddings of group 1, shape (B, L1, D)
        labels_target1 (Tensor): targets of group 1, shape (B, L1)
        labels_mask1 (Tensor): mask of group 1, shape (B, L1)
        labels_embeds2 (Tensor): embeddings of group 2, shape (B, L2, D)
        labels_target2 (Tensor): targets of group 2, shape (B, L2)
        labels_mask2 (Tensor): mask of group 2, shape (B, L2)
        pad_value (int): fill value for padded embedding positions
        ignore_id (int): fill value for padded target positions

    Returns:
        merged_embeds (Tensor): merged embeddings, shape (B, L1+L2, D)
        merged_target (Tensor): merged targets, shape (B, L1+L2)
        merged_mask (Tensor): merged mask, shape (B, L1+L2)
    """
    batch_size = labels_embeds1.size(0)
    max_len = labels_embeds1.size(1) + labels_embeds2.size(1)
    embed_dim = labels_embeds1.size(2)

    merged_embeds = []
    merged_target = []
    merged_mask = []
    for i in range(batch_size):
        # Indices of the valid positions in each group.
        valid_indices1 = torch.where(labels_mask1[i])[0]
        valid_indices2 = torch.where(labels_mask2[i])[0]

        # Valid segments of group 1 followed by those of group 2.
        valid_embeds = torch.cat([
            labels_embeds1[i, valid_indices1],
            labels_embeds2[i, valid_indices2],
        ], dim=0)
        valid_target = torch.cat([
            labels_target1[i, valid_indices1],
            labels_target2[i, valid_indices2],
        ], dim=0)
        valid_mask = torch.cat([
            labels_mask1[i, valid_indices1],
            labels_mask2[i, valid_indices2],
        ], dim=0)

        # Build the padding with new_full/new_zeros so it inherits dtype and
        # device from the valid data; a default-dtype pad tensor would let
        # torch.cat promote the result (e.g. fp16 embeds -> float32).
        pad_length = max_len - valid_embeds.size(0)
        merged_embeds.append(torch.cat([
            valid_embeds,
            valid_embeds.new_full((pad_length, embed_dim), pad_value),
        ], dim=0))
        merged_target.append(torch.cat([
            valid_target,
            valid_target.new_full((pad_length,), ignore_id),
        ], dim=0))
        merged_mask.append(torch.cat([
            valid_mask,
            valid_mask.new_zeros(pad_length),
        ], dim=0))

    # Stack the per-item results back into batch tensors.
    return (
        torch.stack(merged_embeds, dim=0),
        torch.stack(merged_target, dim=0),
        torch.stack(merged_mask, dim=0),
    )
def make_streaming_mode_from_s2s_old(text_tokens_padded, text_tokens_lens, speech_tokens_padded, speech_tokens_lens, ):
    """Interleave text and speech tokens (legacy variant with a 999 marker).

    For every batch item the valid text tokens are emitted in groups of up
    to 6, each group followed by up to 18 valid speech tokens, until the
    text is used up; the remaining speech tokens are appended at the end.
    A 999 token is inserted once the text ends: right after a final text
    group shorter than 6, or after the final speech group when the text
    length is a multiple of 6.

    Args:
        text_tokens_padded: (B, Lmax)
        text_tokens_lens: (B,)
        speech_tokens_padded: (B, Lmax2)
        speech_tokens_lens: (B,)

    Returns:
        streaming_mode_tokens_padded: int64 tensor, zero-padded to the
            longest interleaved sequence in the batch
        streaming_mode_tokens_lens: (B,) lengths of the interleaved sequences

    Every item must satisfy ``text_len * 3 <= speech_len``; this is checked
    with ``assert`` for the whole batch before any interleaving starts.
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    device = text_tokens_padded.device
    num_items = text_tokens_padded.size(0)

    # The number of text tokens must not exceed a third of the speech tokens.
    for item in range(num_items):
        assert text_tokens_lens[item] * 3 <= speech_tokens_lens[item], \
            f"Batch {item}: Text tokens * 3 should be less than speech tokens"

    merged_rows = []
    merged_lens = []
    for item in range(num_items):
        n_text = text_tokens_lens[item]
        n_speech = speech_tokens_lens[item]
        text_row = text_tokens_padded[item, :n_text]
        speech_row = speech_tokens_padded[item, :n_speech].to(torch.int64)

        merged = []
        text_pos, speech_pos = 0, 0
        while text_pos < n_text:
            # Next text group (at most 6 tokens, clipped at the end).
            text_take = min(6, n_text - text_pos)
            merged += text_row[text_pos:text_pos + text_take].tolist()
            text_pos += text_take

            # A short group means the text just ended: mark it with 999.
            if text_take < 6:
                merged.append(999)

            # Next speech group (at most 18 tokens, clipped at the end).
            speech_take = min(18, n_speech - speech_pos)
            merged += speech_row[speech_pos:speech_pos + speech_take].tolist()
            speech_pos += speech_take

        # Text ended exactly on a group boundary: add the 999 marker here.
        # NOTE(review): the original indentation was lost; this check is
        # placed after the loop, which gives the same result as inside the
        # loop for any non-empty text — confirm the intended placement.
        if text_pos == n_text and n_text % 6 == 0:
            merged.append(999)

        # Whatever speech is left goes at the end.
        merged += speech_row[speech_pos:].tolist()

        merged_rows.append(torch.tensor(merged, dtype=torch.int64, device=device))
        merged_lens.append(len(merged))

    streaming_mode_tokens_padded = pad_sequence(
        merged_rows, batch_first=True, padding_value=0
    ).to(device)
    streaming_mode_tokens_lens = torch.tensor(merged_lens, device=device)
    return streaming_mode_tokens_padded, streaming_mode_tokens_lens
def make_streaming_mode_from_s2s(text_tokens_padded, text_tokens_lens, speech_tokens_padded, speech_tokens_lens, ):
    """Interleave text and speech tokens into one streaming sequence.

    For every batch item the valid text tokens are emitted in groups of up
    to 6, each group followed by up to 18 valid speech tokens, until the
    text is used up; the remaining speech tokens are then appended.

    Args:
        text_tokens_padded: (B, Lmax)
        text_tokens_lens: (B,)
        speech_tokens_padded: (B, Lmax2)
        speech_tokens_lens: (B,)

    Returns:
        streaming_mode_tokens_padded: int64 tensor of shape (B, Lout), where
            Lout is the longest interleaved sequence in the batch; shorter
            rows are right-padded with 0
        streaming_mode_tokens_lens: (B,) lengths of the interleaved sequences

    Raises:
        AssertionError: if ``text_len * 3 > speech_len`` for any item.
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device

    # Move lengths and tokens to Python once; slicing tensors chunk by chunk
    # would trigger a host/device synchronisation for every group.
    text_lens = torch.as_tensor(text_tokens_lens).tolist()
    speech_lens = torch.as_tensor(speech_tokens_lens).tolist()
    text_rows = text_tokens_padded.tolist()
    speech_rows = speech_tokens_padded.tolist()

    # The number of text tokens must not exceed a third of the speech tokens.
    for i in range(batch_size):
        assert text_lens[i] * 3 <= speech_lens[i], \
            f"Batch {i}: Text tokens * 3 should be <= speech tokens"

    streaming_mode_tokens_list = []
    streaming_mode_lens = []
    for i in range(batch_size):
        text_tokens = text_rows[i][:text_lens[i]]
        speech_tokens = speech_rows[i][:speech_lens[i]]

        streaming_tokens = []
        speech_idx = 0
        # List slices clip at the end, so the last groups may be shorter.
        for text_idx in range(0, text_lens[i], 6):
            streaming_tokens.extend(text_tokens[text_idx:text_idx + 6])
            streaming_tokens.extend(speech_tokens[speech_idx:speech_idx + 18])
            speech_idx += 18

        # Append the speech tokens that were not consumed by the interleave.
        streaming_tokens.extend(speech_tokens[speech_idx:])

        streaming_mode_tokens_list.append(
            torch.tensor(streaming_tokens, dtype=torch.int64, device=device))
        streaming_mode_lens.append(len(streaming_tokens))

    streaming_mode_tokens_padded = pad_sequence(
        streaming_mode_tokens_list, batch_first=True, padding_value=0)
    streaming_mode_tokens_lens = torch.tensor(streaming_mode_lens, device=device)
    return streaming_mode_tokens_padded, streaming_mode_tokens_lens
def make_streaming_mode_from_s2s4think(
        text_tokens_padded, text_tokens_lens,
        speech_tokens_padded, speech_tokens_lens,
):
    """Interleave text and speech tokens, keeping a leading prefix intact.

    If the valid text contains the token sub-sequence
    ``[13708, 766, 835, 29]``, everything up to and including its first
    occurrence is emitted as one uninterrupted prefix. The rest of the text
    is then interleaved as usual: groups of up to 6 text tokens, each
    followed by up to 18 speech tokens, with the remaining speech tokens
    appended at the end.

    Args:
        text_tokens_padded: (B, Lmax)
        text_tokens_lens: (B,)
        speech_tokens_padded: (B, Lmax2)
        speech_tokens_lens: (B,)

    Returns:
        streaming_mode_tokens_padded: int64 tensor, zero-padded to the
            longest interleaved sequence in the batch
        streaming_mode_tokens_lens: (B,) lengths of the interleaved sequences
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    device = text_tokens_padded.device
    num_items = text_tokens_padded.size(0)

    # The number of text tokens must not exceed a third of the speech tokens.
    for item in range(num_items):
        assert text_tokens_lens[item] * 3 <= speech_tokens_lens[item], \
            f"Batch {item}: Text tokens * 3 should be <= speech tokens"

    # Sub-sequence that closes the prefix.
    # NOTE(review): presumably the token ids of a "<think_end>" tag — confirm.
    marker = [13708, 766, 835, 29]
    marker_len = len(marker)

    merged_rows = []
    merged_lens = []
    for item in range(num_items):
        n_text = text_tokens_lens[item]
        n_speech = speech_tokens_lens[item]
        text_ids = text_tokens_padded[item, :n_text].tolist()
        speech_row = speech_tokens_padded[item, :n_speech]

        # Sliding-window search for the first occurrence of the marker;
        # 0 means "not found", in which case there is no prefix.
        prefix_stop = next(
            (start + marker_len
             for start in range(n_text - marker_len + 1)
             if text_ids[start:start + marker_len] == marker),
            0,
        )
        merged = list(text_ids[:prefix_stop])
        text_pos = prefix_stop
        speech_pos = 0

        # Regular "6 text + 18 speech" interleave for the remaining text.
        while text_pos < n_text:
            text_take = min(6, n_text - text_pos)
            merged += text_ids[text_pos:text_pos + text_take]
            text_pos += text_take

            speech_take = min(18, n_speech - speech_pos)
            merged += speech_row[speech_pos:speech_pos + speech_take].tolist()
            speech_pos += speech_take

        # Finally append every speech token that is still left.
        merged += speech_row[speech_pos:].tolist()

        merged_rows.append(torch.tensor(merged, dtype=torch.int64, device=device))
        merged_lens.append(len(merged))

    # Pad all rows to the same length.
    streaming_mode_tokens_padded = pad_sequence(
        merged_rows, batch_first=True, padding_value=0
    ).to(device)
    streaming_mode_tokens_lens = torch.tensor(merged_lens, device=device)
    return streaming_mode_tokens_padded, streaming_mode_tokens_lens
def do_embedding_for_two_embeds(input_token_ids, dividing_id, embedding1, embedding2):
    """Embed ids with two tables that act as one concatenated vocabulary.

    Ids below ``dividing_id`` are looked up in ``embedding1``; all other ids
    are shifted down by ``dividing_id`` and looked up in ``embedding2``.

    Args:
        input_token_ids: integer tensor of any shape, typically (B, Lmax),
            with ids in [0, vocab_size1 + vocab_size2)
        dividing_id: int, size of the first vocabulary
        embedding1: nn.Embedding(vocab_size1, embedding_dim)
        embedding2: nn.Embedding(vocab_size2, embedding_dim)

    Returns:
        Tensor of shape ``input_token_ids.shape + (embedding_dim,)``, i.e.
        (B, Lmax, D) for 2-D input, using the dtype and device of
        ``embedding1.weight``.
    """
    input_token_ids = input_token_ids.to(torch.int64)
    mask4embedding1 = input_token_ids < dividing_id
    mask4embedding2 = ~mask4embedding1

    out_dtype = embedding1.weight.dtype
    # Build the output from the input's own shape so any rank is supported.
    res_output = torch.zeros(*input_token_ids.shape, embedding1.embedding_dim,
                             dtype=out_dtype, device=embedding1.weight.device)
    res_output[mask4embedding1] = embedding1(input_token_ids[mask4embedding1]).to(out_dtype)
    # Second-table ids are offset by the size of the first vocabulary.
    res_output[mask4embedding2] = embedding2(
        input_token_ids[mask4embedding2] - dividing_id).to(out_dtype)
    return res_output
def do_convert_num2text(num_str: str):
    """Convert the numbers in a string to Chinese numerals.

    Surrounding whitespace is stripped and the result is passed to
    ``cn2an.transform`` in ``"an2cn"`` mode.

    Args:
        num_str: string containing the numbers to convert.

    Returns:
        The value returned by ``cn2an.transform`` for the stripped string.
    """
    # Imported lazily so the module can be loaded without cn2an installed.
    import cn2an
    return cn2an.transform(num_str.strip(), "an2cn")
def _do_test_for_streaming_chat():
    """Manual smoke test that prints result shapes/devices on ``npu:0``.

    Exercises ``make_streaming_mode_from_s2s`` and
    ``do_embedding_for_two_embeds`` with random inputs and finally prints an
    over-long slice of a small tensor. Nothing is asserted.
    """
    # --- make_streaming_mode_from_s2s ---
    text_cpu = torch.randint(0, 100, (3, 10))
    npu = torch.device('npu:0')
    text_tokens_padded = text_cpu.to(npu)
    text_tokens_lens = torch.tensor([5, 7, 3]).to(npu)
    speech_tokens_padded = torch.randint(100, 200, (3, 150)).to(npu)
    speech_tokens_lens = torch.tensor([100, 120, 80]).to(npu)
    merged_padded, merged_lens = make_streaming_mode_from_s2s(
        text_tokens_padded, text_tokens_lens,
        speech_tokens_padded, speech_tokens_lens)
    print(merged_padded.shape)
    print(merged_padded.device)
    print(merged_lens)
    print(merged_lens.device)

    # --- do_embedding_for_two_embeds ---
    input_token_ids = torch.randint(0, 100, (3, 10)).to(npu)
    dividing_id = 50
    embedding1 = torch.nn.Embedding(50, 10).to(npu)
    embedding2 = torch.nn.Embedding(50, 10).to(npu)
    embedded = do_embedding_for_two_embeds(input_token_ids, dividing_id, embedding1, embedding2)
    print(embedded.shape)
    print(embedded.device)

    # Slicing past the end of a tensor: prints the elements from index 3 on.
    sample = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, ]).to(npu)
    print(sample[3:1000])
if __name__ == '__main__':
    # Script entry point: the body is only an empty string literal, so
    # running this module directly performs no action.
    """"""