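# OSUM-EChat / wenet / osum_echat / utils4llmasr.py
# Repository page residue kept for provenance: "xlgeng's picture",
# commit message "开始部署" (start deployment), commit 841f290.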
import random
from typing import Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
from wenet.utils.common import pad_list
from gxl_ai_utils.utils import utils_file
def add_sos_eos4speech_llm(ys_pad: torch.Tensor, sos: int, eos: int,
                           ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder input/target batches: append <eos> to the target only.

    No <sos> is prepended. The input keeps the original tokens (padding
    positions are refilled with ``eos``), while the target is each sequence
    followed by one ``eos`` and padded with ``ignore_id``.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>; kept for interface compatibility, not used
        eos (int): index of <eos>
        ignore_id (int): index of padding; every entry equal to it is dropped

    Returns:
        ys_in (torch.Tensor) : (B, Lmax), padded with ``eos``
        ys_out (torch.Tensor) : (B, Lmax + 1), padded with ``ignore_id``
        (shapes assume ``pad_list`` pads to the longest sequence and that at
        least one row of ``ys_pad`` contains no padding)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos4speech_llm(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, 11, 11],
                [ 7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _eos = torch.tensor([eos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    # Strip the padding so every row holds only its real tokens.
    ys = [y[y != ignore_id] for y in ys_pad]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    # The stripped rows are used as the input unchanged (no <sos> in front).
    return pad_list(ys, eos), pad_list(ys_out, ignore_id)
# Prompt table, filled from conf/prompt.yaml the first time a prompt is requested.
global_prompt_dict = None


def get_prompt_by_task(task_name):
    """Pick one prompt for ``task_name`` at random.

    The prompt table is loaded from ``conf/prompt.yaml`` on the first call and
    reused from the module-level ``global_prompt_dict`` afterwards. Each call
    draws a uniformly random index, so repeated calls for the same task can
    return different prompts.

    Args:
        task_name: key into the prompt table.

    Returns:
        One entry of the prompt collection stored under ``task_name``.
    """
    global global_prompt_dict
    if global_prompt_dict is None:
        global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt.yaml')
    candidates = global_prompt_dict[task_name]
    chosen = random.randint(0, len(candidates) - 1)
    return candidates[chosen]
import torch
def merge_labels_with_valid_adjacent(
        labels_embeds1, labels_target1, labels_mask1,
        labels_embeds2, labels_target2, labels_mask2,
        pad_value=0, ignore_id=-100
):
    """Merge two label groups so that their valid positions are contiguous.

    For every sample, the positions selected by ``labels_mask1`` are followed
    directly by the positions selected by ``labels_mask2``; everything that is
    left over is moved to the tail as padding.

    Args:
        labels_embeds1 (Tensor): embeddings of group 1, shape (B, L1, D)
        labels_target1 (Tensor): targets of group 1, shape (B, L1)
        labels_mask1 (Tensor): mask of group 1, shape (B, L1); non-zero = valid
        labels_embeds2 (Tensor): embeddings of group 2, shape (B, L2, D)
        labels_target2 (Tensor): targets of group 2, shape (B, L2)
        labels_mask2 (Tensor): mask of group 2, shape (B, L2); non-zero = valid
        pad_value (int): fill value for padded embedding positions
        ignore_id (int): fill value for padded target positions (e.g. IGNORE_ID)

    Returns:
        merged_embeds (Tensor): shape (B, L1+L2, D), same dtype as the inputs
        merged_target (Tensor): shape (B, L1+L2)
        merged_mask (Tensor): shape (B, L1+L2), False/0 on padded positions
    """
    batch_size = labels_embeds1.size(0)
    embed_dim = labels_embeds1.size(2)
    max_len = labels_embeds1.size(1) + labels_embeds2.size(1)

    if batch_size == 0:
        # torch.stack rejects an empty list, so build the empty results directly.
        return (
            labels_embeds1.new_empty((0, max_len, embed_dim)),
            labels_target1.new_empty((0, max_len)),
            labels_mask1.new_empty((0, max_len)),
        )

    merged_embeds = []
    merged_target = []
    merged_mask = []
    for i in range(batch_size):
        # Indices of the valid positions in each group.
        keep1 = torch.where(labels_mask1[i])[0]
        keep2 = torch.where(labels_mask2[i])[0]

        # Valid segments of both groups, placed back to back.
        valid_embeds = torch.cat(
            [labels_embeds1[i, keep1], labels_embeds2[i, keep2]], dim=0)
        valid_target = torch.cat(
            [labels_target1[i, keep1], labels_target2[i, keep2]], dim=0)
        valid_mask = torch.cat(
            [labels_mask1[i, keep1], labels_mask2[i, keep2]], dim=0)

        pad_length = max_len - valid_embeds.size(0)
        # The padding takes the embeddings' dtype explicitly: letting torch.full
        # infer it from pad_value would upcast fp16/bf16 embeddings to float32
        # whenever a float pad_value is passed.
        embeds_pad = torch.full((pad_length, embed_dim), pad_value,
                                dtype=valid_embeds.dtype,
                                device=labels_embeds1.device)
        # int64 is what torch.full infers for an integer ignore_id; spelled out
        # here so the target dtype does not depend on that inference.
        target_pad = torch.full((pad_length,), ignore_id,
                                dtype=torch.long,
                                device=labels_target1.device)
        mask_pad = torch.zeros(pad_length, dtype=torch.bool,
                               device=labels_mask1.device)

        merged_embeds.append(torch.cat([valid_embeds, embeds_pad], dim=0))
        merged_target.append(torch.cat([valid_target, target_pad], dim=0))
        merged_mask.append(torch.cat([valid_mask, mask_pad], dim=0))

    # Every per-sample tensor already lives on its input's device.
    return (
        torch.stack(merged_embeds, dim=0),
        torch.stack(merged_target, dim=0),
        torch.stack(merged_mask, dim=0),
    )
def make_streaming_mode_from_s2s_old(text_tokens_padded, text_tokens_lens, speech_tokens_padded, speech_tokens_lens, ):
    """Interleave text and speech tokens into one stream, with a 999 end-of-text marker.

    For every sample the valid text tokens are emitted in groups of up to 6,
    each group followed by up to 18 valid speech tokens, until the text is used
    up; the speech tokens that are left are appended afterwards. The literal id
    999 is inserted once per sample with non-empty text: right after the last
    text group when that group is shorter than 6, otherwise after the speech
    group that follows the last (full) text group.

    Args:
        text_tokens_padded: (B, Lmax) text token ids
        text_tokens_lens: (B,) number of valid text tokens per sample
        speech_tokens_padded: (B, Lmax2) speech token ids
        speech_tokens_lens: (B,) number of valid speech tokens per sample

    Returns:
        streaming_mode_tokens_padded: (B, T) int64, zero-padded, where T is the
            longest interleaved stream in the batch
        streaming_mode_tokens_lens: (B,) length of each interleaved stream

    Raises:
        AssertionError: if a sample has ``text_len * 3 > speech_len``.
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device

    # Read the lengths as plain ints once: using 0-dim tensors as loop bounds
    # and slice indices forces a host sync on every comparison.
    text_lens = torch.as_tensor(text_tokens_lens).reshape(-1).tolist()
    speech_lens = torch.as_tensor(speech_tokens_lens).reshape(-1).tolist()

    # The text length must not exceed one third of the speech length.
    for i in range(batch_size):
        assert text_lens[i] * 3 <= speech_lens[i], \
            f"Batch {i}: Text tokens * 3 should be less than speech tokens"

    streaming_mode_tokens_list = []
    streaming_mode_lens = []
    for i in range(batch_size):
        text_len = text_lens[i]
        speech_len = speech_lens[i]
        text_list = text_tokens_padded[i, :text_len].tolist()
        speech_list = speech_tokens_padded[i, :speech_len].tolist()

        streaming_tokens = []
        text_idx = 0
        speech_idx = 0
        while text_idx < text_len:
            # Text group: at most 6 tokens, clipped at the end of the text.
            chunk_size = min(6, text_len - text_idx)
            streaming_tokens.extend(text_list[text_idx:text_idx + chunk_size])
            text_idx += chunk_size

            # A short (final) text group is closed with the 999 marker at once.
            if chunk_size < 6:
                streaming_tokens.append(999)

            # Speech group: at most 18 tokens, clipped at the end of the speech.
            speech_chunk = min(18, speech_len - speech_idx)
            streaming_tokens.extend(speech_list[speech_idx:speech_idx + speech_chunk])
            speech_idx += speech_chunk

            # Text ended exactly on a full group: the marker goes after the speech group.
            if text_idx == text_len and text_len % 6 == 0:
                streaming_tokens.append(999)

        # Append whatever speech tokens are left.
        streaming_tokens.extend(speech_list[speech_idx:])

        # Stored as an int64 tensor on the input's device.
        streaming_mode_tokens_list.append(torch.tensor(streaming_tokens, dtype=torch.int64, device=device))
        streaming_mode_lens.append(len(streaming_tokens))

    streaming_mode_tokens_padded = pad_sequence(streaming_mode_tokens_list, batch_first=True, padding_value=0)
    streaming_mode_tokens_lens = torch.tensor(streaming_mode_lens, device=device)
    return streaming_mode_tokens_padded, streaming_mode_tokens_lens
def make_streaming_mode_from_s2s(text_tokens_padded, text_tokens_lens, speech_tokens_padded, speech_tokens_lens, ):
    """Interleave text and speech tokens into one stream per sample.

    For every sample the valid text tokens are emitted in groups of up to 6,
    each group followed by up to 18 valid speech tokens, and so on until the
    text is used up; the remaining speech tokens are appended at the end.

    Note: the think-end token sequence [13708, 766, 835, 29] is not treated
    specially here; make_streaming_mode_from_s2s4think handles it.

    Args:
        text_tokens_padded: (B, Lmax) text token ids
        text_tokens_lens: (B,) number of valid text tokens per sample
        speech_tokens_padded: (B, Lmax2) speech token ids
        speech_tokens_lens: (B,) number of valid speech tokens per sample

    Returns:
        streaming_mode_tokens_padded: (B, T) int64, zero-padded, where T is the
            longest interleaved stream in the batch
        streaming_mode_tokens_lens: (B,) length of each interleaved stream

    Raises:
        AssertionError: if a sample has ``text_len * 3 > speech_len``.
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device

    # Read the lengths as plain ints once: using 0-dim tensors as loop bounds
    # and slice indices forces a host sync on every comparison.
    text_lens = torch.as_tensor(text_tokens_lens).reshape(-1).tolist()
    speech_lens = torch.as_tensor(speech_tokens_lens).reshape(-1).tolist()

    # The text length must not exceed one third of the speech length.
    for i in range(batch_size):
        assert text_lens[i] * 3 <= speech_lens[i], \
            f"Batch {i}: Text tokens * 3 should be less than speech tokens"

    streaming_mode_tokens_list = []
    streaming_mode_lens = []
    for i in range(batch_size):
        text_len = text_lens[i]
        speech_len = speech_lens[i]
        text_list = text_tokens_padded[i, :text_len].tolist()
        speech_list = speech_tokens_padded[i, :speech_len].tolist()

        streaming_tokens = []
        text_idx = 0
        speech_idx = 0
        # text_idx is the left edge of the next group, so it never equals text_len inside the loop.
        while text_idx < text_len:
            # Text group: at most 6 tokens, clipped at the end of the text.
            chunk_size = min(6, text_len - text_idx)
            streaming_tokens.extend(text_list[text_idx:text_idx + chunk_size])
            text_idx += chunk_size

            # Speech group: at most 18 tokens, clipped at the end of the speech.
            speech_chunk = min(18, speech_len - speech_idx)
            streaming_tokens.extend(speech_list[speech_idx:speech_idx + speech_chunk])
            speech_idx += speech_chunk

        # Append whatever speech tokens are left.
        streaming_tokens.extend(speech_list[speech_idx:])

        streaming_mode_tokens_list.append(torch.tensor(streaming_tokens, dtype=torch.int64, device=device))
        streaming_mode_lens.append(len(streaming_tokens))

    streaming_mode_tokens_padded = pad_sequence(streaming_mode_tokens_list, batch_first=True, padding_value=0)
    streaming_mode_tokens_lens = torch.tensor(streaming_mode_lens, device=device)
    return streaming_mode_tokens_padded, streaming_mode_tokens_lens
def make_streaming_mode_from_s2s4think(
        text_tokens_padded, text_tokens_lens,
        speech_tokens_padded, speech_tokens_lens,
):
    """Interleave text and speech tokens, emitting a leading "think" prefix in one piece.

    If the valid text of a sample contains the token sequence
    [13708, 766, 835, 29], everything up to and including its first occurrence
    is emitted first, without any speech in between. The rest of the text is
    then interleaved as usual: groups of up to 6 text tokens, each followed by
    up to 18 speech tokens, and finally all speech tokens that are left.

    Args:
        text_tokens_padded: (B, Lmax) text token ids
        text_tokens_lens: (B,) number of valid text tokens per sample
        speech_tokens_padded: (B, Lmax2) speech token ids
        speech_tokens_lens: (B,) number of valid speech tokens per sample

    Returns:
        streaming_mode_tokens_padded: (B, T) int64, zero-padded, where T is the
            longest interleaved stream in the batch
        streaming_mode_tokens_lens: (B,) length of each interleaved stream

    Raises:
        AssertionError: if a sample has ``text_len * 3 > speech_len``.
    """
    text_tokens_padded = text_tokens_padded.to(torch.int64)
    speech_tokens_padded = speech_tokens_padded.to(torch.int64)
    batch_size = text_tokens_padded.size(0)
    device = text_tokens_padded.device

    # Read the lengths as plain ints once: using 0-dim tensors as loop bounds,
    # range() arguments and slice indices forces a host sync on every use.
    text_lens = torch.as_tensor(text_tokens_lens).reshape(-1).tolist()
    speech_lens = torch.as_tensor(speech_tokens_lens).reshape(-1).tolist()

    # The text length must not exceed one third of the speech length.
    for i in range(batch_size):
        assert text_lens[i] * 3 <= speech_lens[i], \
            f"Batch {i}: Text tokens * 3 should be <= speech tokens"

    streaming_mode_tokens_list = []
    streaming_mode_lens = []

    # Sub-sequence that closes the prefix emitted in one piece.
    target_seq = [13708, 766, 835, 29]
    seq_len = len(target_seq)

    for i in range(batch_size):
        text_len = text_lens[i]
        speech_len = speech_lens[i]
        # Valid text and speech tokens of this sample.
        text_list = text_tokens_padded[i, :text_len].tolist()
        speech_list = speech_tokens_padded[i, :speech_len].tolist()

        streaming_tokens = []

        # Sliding-window search for the first occurrence of target_seq.
        prefix_end_idx = 0
        for j in range(text_len - seq_len + 1):
            if text_list[j:j + seq_len] == target_seq:
                prefix_end_idx = j + seq_len
                break

        # When found, emit the whole prefix (target_seq included) at once.
        if prefix_end_idx > 0:
            streaming_tokens.extend(text_list[:prefix_end_idx])
        text_idx = prefix_end_idx

        speech_idx = 0
        # Regular "6 text + 18 speech" interleaving from text_idx onwards.
        while text_idx < text_len:
            # Text group (at most 6).
            chunk_size = min(6, text_len - text_idx)
            streaming_tokens.extend(text_list[text_idx:text_idx + chunk_size])
            text_idx += chunk_size

            # Speech group (at most 18).
            speech_chunk = min(18, speech_len - speech_idx)
            streaming_tokens.extend(speech_list[speech_idx:speech_idx + speech_chunk])
            speech_idx += speech_chunk

        # Append whatever speech tokens are left.
        streaming_tokens.extend(speech_list[speech_idx:])

        streaming_mode_tokens_list.append(
            torch.tensor(streaming_tokens, dtype=torch.int64, device=device)
        )
        streaming_mode_lens.append(len(streaming_tokens))

    # Pad all streams to the same length.
    streaming_mode_tokens_padded = pad_sequence(
        streaming_mode_tokens_list,
        batch_first=True,
        padding_value=0
    )
    streaming_mode_tokens_lens = torch.tensor(streaming_mode_lens, device=device)
    return streaming_mode_tokens_padded, streaming_mode_tokens_lens
def do_embedding_for_two_embeds(input_token_ids, dividing_id, embedding1, embedding2):
    """Embed ids drawn from two concatenated vocabularies as if they shared one table.

    Ids below ``dividing_id`` are looked up in ``embedding1``; all other ids are
    shifted down by ``dividing_id`` and looked up in ``embedding2``.

    Args:
        input_token_ids: integer tensor of ids in [0, vocab_size1 + vocab_size2),
            typically (B, Lmax); any shape with at least one dimension works
        dividing_id: int, size of the first vocabulary
        embedding1: nn.Embedding(vocab_size1, embedding_dim)
        embedding2: nn.Embedding(vocab_size2, embedding_dim); assumed to share
            embedding1's embedding_dim

    Returns:
        Tensor of shape ``input_token_ids.shape + (embedding_dim,)``, e.g.
        (B, Lmax, D), with embedding1's weight dtype and device.
    """
    input_token_ids = input_token_ids.to(torch.int64)
    out_dtype = embedding1.weight.dtype

    in_first_vocab = input_token_ids < dividing_id
    # Every id not in the first vocabulary belongs to the second one.
    in_second_vocab = ~in_first_vocab

    # Sized from the full input shape so inputs need not be exactly 2-D.
    res_output = torch.zeros(*input_token_ids.shape, embedding1.embedding_dim,
                             dtype=out_dtype,
                             device=embedding1.weight.device)
    res_output[in_first_vocab] = embedding1(
        input_token_ids[in_first_vocab]).to(out_dtype)
    # Cast is needed: indexed assignment requires matching dtypes.
    res_output[in_second_vocab] = embedding2(
        input_token_ids[in_second_vocab] - dividing_id).to(out_dtype)
    return res_output
def do_convert_num2text(num_str: str):
    """Rewrite the digits in a string as Chinese numerals.

    Surrounding whitespace is stripped, then the text is passed to
    ``cn2an.transform`` in ``"an2cn"`` mode.

    Args:
        num_str: string containing Arabic numerals.

    Returns:
        The string returned by ``cn2an.transform``.
    """
    # Imported lazily so cn2an is only needed when this helper is called.
    import cn2an
    return cn2an.transform(num_str.strip(), "an2cn")
def _do_test_for_streaming_chat():
    """Manual smoke check on device ``npu:0``; prints shapes and devices only.

    Covers make_streaming_mode_from_s2s, do_embedding_for_two_embeds and a
    slice whose end lies past the tensor's length.
    """
    npu = torch.device('npu:0')

    # --- make_streaming_mode_from_s2s ---
    text_tokens_padded = torch.randint(0, 100, (3, 10)).to(npu)
    text_tokens_lens = torch.tensor([5, 7, 3]).to(npu)
    speech_tokens_padded = torch.randint(100, 200, (3, 150)).to(npu)
    speech_tokens_lens = torch.tensor([100, 120, 80]).to(npu)
    stream_out = make_streaming_mode_from_s2s(
        text_tokens_padded, text_tokens_lens,
        speech_tokens_padded, speech_tokens_lens)
    streaming_mode_tokens_padded, streaming_mode_tokens_lens = stream_out
    print(streaming_mode_tokens_padded.shape)
    print(streaming_mode_tokens_padded.device)
    print(streaming_mode_tokens_lens)
    print(streaming_mode_tokens_lens.device)

    # --- do_embedding_for_two_embeds ---
    input_token_ids = torch.randint(0, 100, (3, 10)).to(npu)
    dividing_id = 50
    embedding1 = torch.nn.Embedding(50, 10).to(npu)
    embedding2 = torch.nn.Embedding(50, 10).to(npu)
    res_output = do_embedding_for_two_embeds(input_token_ids, dividing_id, embedding1, embedding2)
    print(res_output.shape)
    print(res_output.device)

    # --- slicing past the end of a tensor ---
    a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, ]).to(npu)
    print(a[3:1000])
if __name__ == '__main__':
    # Nothing is executed when the module is run as a script.
    pass