# OSUM-EChat/wenet/osum_echat/llmasr_model_instruct_version.py
import logging
import os
from typing import Dict, List, Optional, Union
import torchaudio
import torch
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, StoppingCriteriaList
from patches.cumstom_stop_criteria import InterruptStopper, S2SStopCriteria, MaxTokenStopper
from patches.custom_speech_ngram_blocking import SpeechOnlyNGramBlockingLogitsProcessor, OSUM_chat_LogitsProcessor
from wenet.osum_echat.utils4llmasr import make_streaming_mode_from_s2s, do_embedding_for_two_embeds
from wenet.transformer.encoder import TransformerEncoder, TransformerEncoder2
from wenet.osum_echat.utils4llmasr import *
from gxl_ai_utils.utils import utils_file
from wenet.osum_echat.downsampler import get_downsampler, osum_echat2Conv1dSubsampling
from wenet.transformer.swish import New_gelu4npu
from wenet.utils.mask import make_pad_mask
class LLMASR_Model(nn.Module):
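    """Speech-LLM (instruct version): a speech encoder plus adapter feeding an
    optionally LoRA-wrapped Qwen-style causal LLM, with extra embedding/head layers
    for discrete speech-token output (ASR, TTS, speech-to-text / speech-to-speech chat).
    """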
def __init__(self,
encoder,
encoder_output_dim,
llm_path,
lora=True, lora_alpha=32, lora_rank=8, lora_dropout=0.1,
is_inference=False,
downsample_rate=1,
adapter_type='osum_echat2',
speech_token_num=0,
train_speech_out=False):
""""""
super().__init__()
utils_file.logging_limit_print(f"instruct_version: LLMASR_Model init, is_inference={is_inference}, downsample_rate={downsample_rate}, adapter_type={adapter_type}, speech_token_num={speech_token_num}, train_speech_out={train_speech_out}")
self.downsample_rate = downsample_rate
self.encoder = encoder
self.ln_speech = nn.LayerNorm(encoder_output_dim)
        # adapter (connection) layer, ~51.6M params
if adapter_type == 'osum_echat':
self.speech_transformer = TransformerEncoder(
input_size=encoder_output_dim,
output_size=encoder_output_dim,
attention_heads=4,
linear_units=2560,
num_blocks=4,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="linear",
pos_enc_layer_type="abs_pos",
normalize_before=True
)
else:
self.speech_transformer = None
self.llama_model = AutoModelForCausalLM.from_pretrained(
llm_path,
# torch_dtype=torch.float32 if is_inference else torch.float16,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
output_hidden_states=True,
)
self.s2s_stop_criteria = None
self.max_token_criteria_list = None
self.max_length = 4000
self.min_length = 1
self.num_beams = 4
self.do_sample = True
self.top_p = 0.9
self.top_k = 5
self.repetition_penalty = 1.05
self.length_penalty = 1.0
self.temperature = 1.0
self.IGNORE_ID = -100
# lora
self.lora = lora
if lora:
utils_file.logging_limit_print("OSUM-EChat: 使用lora了")
# target_modules = ['w_pack', 'o_proj', 'gate_proj', 'down_proj']
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj']
self.peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=is_inference,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules,
)
self.llama_model = get_peft_model(self.llama_model, self.peft_config)
# tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
llm_path, use_fast=False, trust_remote_code=True)
"""
设置分词器的pad_token和padding的方向。
"""
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.tokenizer.padding_side = "right"
self.eos_token_id = self.tokenizer.eos_token_id
if hasattr(self.llama_model.config, 'hidden_size'):
utils_file.logging_limit_print(
f"self.llama_model.config.hidden_size: {self.llama_model.config.hidden_size}")
if adapter_type == 'osum_echat2':
self.down_sample_2 = osum_echat2Conv1dSubsampling(encoder_output_dim, self.llama_model.config.hidden_size)
elif adapter_type == 'osum_echat':
self.down_sample_2 = get_downsampler(downsample_rate, encoder_output_dim)
self.speech_llama_proj = nn.Linear(
encoder_output_dim, self.llama_model.config.hidden_size)
else:
            raise NotImplementedError("self.llama_model.config.hidden_size does not exist")
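        # With LoRA, PEFT wraps the base model one level deeper, so the embedding table
        # and lm_head are reached through .model.model / .model.lm_head.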
self.embed_tokens = self.llama_model.model.model.embed_tokens if self.lora else self.llama_model.model.embed_tokens
self.lm_head = self.llama_model.model.lm_head if self.lora else self.llama_model.lm_head
self.llm_vocab_size = self.lm_head.weight.shape[0]
self.speech_token_num = speech_token_num
# init speech token module
if speech_token_num > 0:
            utils_file.logging_info(f'OSUM-EChat: speech-token generation enabled, speech_token_num: {speech_token_num}')
self.speech_token_emded = torch.nn.Embedding(speech_token_num + 2, self.llama_model.config.hidden_size)
self.speech_head = torch.nn.Linear(self.llama_model.config.hidden_size, speech_token_num)
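            # speech_token_num + 2 embedding rows leave room for the speech bos/eos ids
            # (speech_token_num and speech_token_num + 1) used by _add_bos_eos; the head
            # only predicts the speech_token_num real codebook entries.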
else:
            # no speech-token output: use identity placeholders
self.speech_head = nn.Identity()
self.speech_token_emded = nn.Identity()
self.speech_model = nn.Identity()
self.train_speech_out = train_speech_out
        utils_file.logging_info(f'OSUM-EChat: train speech output: {self.train_speech_out}')
self.loss_fct = CrossEntropyLoss(reduction='mean')
        self.unk_token_id = 7672  # id of "&&"
self.add_embed_head = True
self.init_custom_speech_repetition_penalty()
self.init_custom_stop_criteria()
def set_task_type(self, task_type: str):
"""设置任务类型,用于设置生成的初始类型
Args:
task_type (str): 任务类型,从("ASR", "TTS", "S2S")选择
"""
assert task_type in ("ASR", "TTS", "S2S")
if task_type == "ASR":
self.llama_model.text_phase = True
elif task_type == "TTS":
self.llama_model.text_phase = False
elif task_type == "S2S":
self.llama_model.text_phase = True
def do_add_speech_embed_head(self):
if self.add_embed_head:
self.llama_model.speech_token_emded = self.speech_token_emded.to(torch.bfloat16)
self.llama_model.speech_head = self.speech_head.to(torch.bfloat16)
# self.llama_model.speech_token_emded = self.speech_token_emded.to(torch.bfloat16)
            # self.llama_model.speech_head = self.speech_head.to(torch.bfloat16)  # used when LoRA is enabled
self.add_embed_head = False
def init_custom_speech_repetition_penalty(self):
"""
"""
self.s2s_repetition_penalty = LogitsProcessorList()
# self.speech_repetition_penalty = SpeechOnlyRepetitionPenaltyLogitsProcessor(speech_token_num=4097, penalty=1.5)
self.speech_repetition_penalty = SpeechOnlyNGramBlockingLogitsProcessor(speech_token_num=4097, repeat_times=5,
special_token_repeat_times_dict={
1446: 10})
self.osum_chat_logit_processor1 = OSUM_chat_LogitsProcessor([99119, 1808, 7863], [102185, 17714, 31252])
self.s2s_repetition_penalty.append(self.osum_chat_logit_processor1)
self.s2s_repetition_penalty.append(self.speech_repetition_penalty)
self.llama_model.speech_repetition_penalty = self.speech_repetition_penalty
def init_custom_stop_criteria(self):
"""
创建需要的stop criteria
1. 对于t2t任务,遇到text_eos停止
2. 对于t2s任务,遇到speech_eos停止
3. 对于s2s任务,遇到speech_eos停止
同时要取消原本的停止条件
if generation_config._eos_token_tensor is not None:
取消 generation_config._eos_token_tensor 的停止,尝试直接给一个大于vocb_size的eos_token
"""
self.interrupt = InterruptStopper()
self.s2s_stop_criteria = StoppingCriteriaList()
self.s2s_stop_criteria.append(S2SStopCriteria(text_eos_id=151645, speech_eos_id=self.speech_token_num - 1))
self.s2s_stop_criteria.append(MaxTokenStopper(2000))
self.s2s_stop_criteria.append(self.interrupt)
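        # 151645 is Qwen's <|im_end|> id; the speech eos is the last id of the speech codebook.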
def get_label_embedding(self, labels, labels_lengths, unk_id=7672):
""""""
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
unk_mask = (labels == unk_id) # B, L
labels_pad_mask = labels_pad_mask | unk_mask #
labels = labels.masked_fill(labels_pad_mask, 0)
labels_embeds = self.embed_tokens(labels)
labels_target = labels.masked_fill(labels_pad_mask, self.IGNORE_ID) # B, L
labels_mask = ~labels_pad_mask
return labels_embeds, labels_target, labels_mask
def get_speech_token_label_embedding(self, speech_token_labels, speech_tokens_length):
""""""
speech_tokens_pad_mask = make_pad_mask(speech_tokens_length) # B, L
speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
speech_token_labels_embeds = self.speech_token_emded(speech_token_labels)
        # Offset speech token ids by the text vocab size so the targets index into the
        # concatenated [text logits, speech logits] space built in forward().
        speech_token_labels = speech_token_labels + self.llm_vocab_size
speech_token_labels_target = speech_token_labels.masked_fill(speech_tokens_pad_mask, self.IGNORE_ID) # B, L
speech_token_labels_mask = ~speech_tokens_pad_mask
return speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask
def _get_embedding_for_history(self, history_batch, device):
"""
prompt_patern1,prompt,history, wav, prompt_patern2,txt,answer_wav,
historcy_batch的内容是:
[ big_embed,
[
{wav: feat (L,D:80),->过encoder+link ,得到(L1, 2048)
txt: labels (L,), ->labels_embeds = self.embed_tokens(labels) -> (L2, 2048), 带txt eos
},->(L1+L2, 2048)
{wav: feat (L,D),
txt: labels (L,),
},->(L3+L4, 2048)
]-> (L1+L2+L3+L4, 2048),len:L1+L2+L3+L4
[],->(1, 2048) ,len:0
[
{wav: feat (L,D),
txt: labels (L,),
},
],-> (L1+L2, 2048),len:L1+L2
]
将每一条的历史信息的embedding拼接起来,如果有空历史信息,则用0pad, 最后得到pad后的history_embedding(B, L, D), history_lens(B)
Args:
history_batch:
device:
Returns:
history_embedding: B, L, D
history_lens: B
"""
assistant_start ="<|im_end|>\n<|im_start|>assistant\n"
assistant_start_id = self.tokenizer([assistant_start], return_tensors="pt"
)['input_ids'].to(device)
assistant_start_embedding = self.embed_tokens(assistant_start_id.squeeze(0))
assistant_end ="<|im_end|>\n"
assistant_end_id = self.tokenizer([assistant_end], return_tensors="pt"
)['input_ids'].to(device)
assistant_end_embedding = self.embed_tokens(assistant_end_id.squeeze(0))
user_start = "<|im_start|>user\n"
user_start_id = self.tokenizer([user_start], return_tensors="pt"
)['input_ids'].to(device)
user_start_embedding = self.embed_tokens(user_start_id.squeeze(0))
user_end ="<|im_end|>\n"
user_end_id = self.tokenizer([user_end], return_tensors="pt"
)['input_ids'].to(device)
user_end_embedding = self.embed_tokens(user_end_id.squeeze(0))
batch_embeddings = []
history_lens = []
        # if no sample in the batch has any history, return early
if all(len(history) == 0 for history in history_batch):
return None, None
for history in history_batch:
history_embeds = []
for item in history:
wav_feat = item['wav'].to(device) # shape: (L, D)
wav_feat = wav_feat.unsqueeze(0).to(device) # shape: (1, L, D)
wav_embed, wav_mask = self._get_embedding_from_wav(wav_feat, torch.tensor([wav_feat.size(1)], device=device, dtype=torch.long))
wav_embed = wav_embed.squeeze(0) # shape: (L, D)
if len(history_embeds) != 0:
                    history_embeds.append(user_start_embedding)  # skip the user start for the first turn
history_embeds.append(wav_embed)
history_embeds.append(user_end_embedding)
history_embeds.append(assistant_start_embedding)
labels = item['txt'] # shape: (L,)
labels = torch.tensor(labels, device=device, dtype=torch.long)
                embed = self.embed_tokens(labels)  # (L2, D), usually L2 = L
history_embeds.append(embed)
history_embeds.append(assistant_end_embedding)
            history_embeds.append(user_start_embedding)  # finally append a user start
if history_embeds:
                # concatenate the embeddings of all history turns: (sum(Li), D)
full_embed = torch.cat(history_embeds, dim=0)
history_lens.append(full_embed.size(0))
else:
                # empty history
full_embed = torch.zeros((1, self.embed_tokens.embedding_dim), device=device)
history_lens.append(0)
batch_embeddings.append(full_embed)
        # pad to the max length in the batch
padded_embeddings = pad_sequence(batch_embeddings, batch_first=True, padding_value=0.0) # (B, L, D)
history_lens = torch.tensor(history_lens, device=device, dtype=torch.long)
padded_embeddings = padded_embeddings.to(device)
return padded_embeddings, history_lens
def forward(self,
batch,
device,
):
""""""
output_type = batch['output_type']
# qwen_instruct_prompt_pattern_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
qwen_instruct_prompt_pattern_chat_s2s = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n"
qwen_instruct_prompt_pattern_chat_s2s_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside <think>...</think end>, analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n"
qwen_instruct_prompt_pattern_chat_s2s_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (content + paralinguistic cues) and respond with interleaved text and emotionally-matched synthetic speech.<|im_end|>\n"
qwen_instruct_prompt_pattern_chat_s2s_streaming_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (both content and paralinguistic cues). Before responding, output your reasoning in <think>...</think end>. Then reply with interleaved text and emotionally matched synthetic speech.<|im_end|>\n"
qwen_instruct_prompt_pattern_chat_s2t = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n"
qwen_instruct_prompt_pattern__chat_t2t = "<|im_start|>system\nYou are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. You understand user input in text then respond exclusively with appropriate text.<|im_end|>\n"
qwen_instruct_prompt_pattern_1_understand = "<|im_start|>system\nYou are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.<|im_end|>\n"
qwen_instruct_prompt_pattern_1_tts = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.<|im_end|>\n"
qwen_instruct_prompt_pattern_1_tts_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural speech from text input and output both audio and the original text in interleaved format.<|im_end|>\n"
qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
qwen_instruct_prompt_pattern_1_s2t_thinking = "<|im_start|>system\nYou are OSUM-chat, a thinking-enabled speech-to-text dialogue assistant by ASLP Lab. You not only comprehend the semantic meaning and paralinguistic cues in speech but also engage in deliberate reasoning to process such information. Based on this thinking process, you then respond exclusively with appropriate text.<|im_end|>\n"
# user_start = "<|im_start|>user\n"
        # select the system prompt according to the output type
if output_type == "s2t_chat":
system_prompt = qwen_instruct_prompt_pattern_chat_s2t
elif output_type == "s2t_chat_fake":
system_prompt = qwen_instruct_prompt_pattern_chat_s2s_think
elif output_type == "text":
system_prompt = qwen_instruct_prompt_pattern_1_understand
elif output_type == "speech2text_token" or output_type == "speech2text_token_history":
system_prompt = qwen_instruct_prompt_pattern_chat_s2s
elif output_type == "text2token":
system_prompt = qwen_instruct_prompt_pattern_1_tts
elif output_type == "speech2text_token_streaming":
system_prompt = qwen_instruct_prompt_pattern_chat_s2s_streaming
elif output_type == "speech2text_token_think":
system_prompt = qwen_instruct_prompt_pattern_chat_s2s_think
elif output_type == "text2token_streaming":
system_prompt = qwen_instruct_prompt_pattern_1_tts_streaming
elif output_type == "text2text":
system_prompt = qwen_instruct_prompt_pattern__chat_t2t
elif output_type == "s2t_chat_think":
system_prompt = qwen_instruct_prompt_pattern_1_s2t_thinking
else:
system_prompt = qwen_instruct_prompt_pattern_1_old
# if output_type == "speech2text_token_history":
# if output_type == "text2text" or output_type == "text":
# qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old
# elif output_type == "speech2text_token" or output_type == "speech2text_token_streaming" or output_type == "text2text" or output_type == "s2t_chat":
# qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_chat
# elif output_type == "text2token":
# qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_tts
# else:
# qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old
system_prompt = system_prompt + "<|im_start|>user\n"
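        # Sequence layout fed to the LLM:
        #   <system prompt><|im_start|>user\n [optional prompt] [speech or text input]
        #   <|im_end|>\n<|im_start|>assistant\n [text / speech-token answer]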
rank = int(os.environ.get('RANK', 0))
utils_file.logging_limit_print(f'xxx output_type {output_type}, rank {rank}')
# if output_type == "s2t_chat":
# output_type = "text"
# assert output_type in ['text', 'speech2text_token', 'text2token'], f"output_type:{output_type} not support"
# speech inputs
        if output_type in ('text', 's2t_chat', 's2t_chat_fake', 's2t_chat_think', 'speech2text_token',
                           'speech2text_token_streaming', 'speech2text_token_think', 'speech2text_token_history'):
wavs = batch['feats'].to(device)
# utils_file.logging_limit_print(f'xxx wav shape {wavs.shape}')
wavs_len = batch['feats_lengths'].to(device)
B = wavs.shape[0]
# utils_file.logging_limit_print(f"xxx {wavs_len}")
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
# utils_file.logging_limit_print(f'xxx speech embeding shape {speech_embeds.shape}')
# utils_file.logging_limit_print(f'xxx speech mask shape {speech_masks.shape}')
# utils_file.logging_limit_print(f'xxx speech mask 0 {speech_masks[0]}')
speech_target = torch.full(speech_masks.shape, self.IGNORE_ID).to(
speech_embeds.device)
# utils_file.logging_limit_print(f'xxx speech target shape {speech_target.shape}')
# utils_file.logging_limit_print(f'xxx speech target 0 {speech_target[0]}')
# add bos and eos
speech_embeds, speech_masks, speech_target = self._add_bos_eos(0+self.speech_token_num,
1+self.speech_token_num,
speech_embeds, speech_masks, speech_target)
elif output_type == "text2token" or output_type == "text2token_streaming":
labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device) - 1  # minus 1 to drop the eos
B = labels.shape[0]
            # text2token: take the text sequence as the input
max_len = max(labels_lengths) + 1
labels_pad_mask = make_pad_mask(labels_lengths, max_len=max_len)
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
# add bos and eos
# speech_embeds, speech_masks, speech_target = self._add_bos_eos(0+self.speech_token_num,
# 1 + self.speech_token_num,
# speech_embeds, speech_masks, speech_target)
else: # text2text
speech_embeds = None
speech_masks = None
speech_target = None
# utils_file.logging_limit_print(f'xxx after add bos eos speech embeding shape {speech_embeds.shape}')
# utils_file.logging_limit_print(f'xxx after add bos eos speech mask shape {speech_masks.shape}')
# utils_file.logging_limit_print(f'xxx after add bos eos speech target shape {speech_target.shape}')
# utils_file.logging_limit_print(f'xxx after add bos eos speech mask 0 {speech_masks[0]}')
# utils_file.logging_limit_print(f'xxx after add bos eos speech target 0 {speech_target[0]}')
# prompt
if 'prompt' in batch:
prompt = batch['prompt'].to(device)
prompt_lengths = batch['prompt_lengths'].to(device)
prompt_pad_mask = make_pad_mask(prompt_lengths) # B, L
prompt = prompt.masked_fill(prompt_pad_mask, self.tokenizer.eos_token_id)
prompt_embeds = self.embed_tokens(prompt) # B, L, D
prompt_target = torch.full(prompt.shape, self.IGNORE_ID).to(
device) # B, L
prompt_mask = ~prompt_pad_mask
# utils_file.logging_limit_print(f'xxx prompt embeding shape {prompt_embeds.shape}')
# utils_file.logging_limit_print(f'xxx prompt mask shape {prompt_mask.shape}')
# utils_file.logging_limit_print(f'xxx prompt target shape {prompt_target.shape}')
else:
prompt_embeds = None
prompt_mask = None
prompt_target = None
inputs_embeds_list = []
attention_mask_list = []
target_list = []
prompt_pattern1 = self.tokenizer([system_prompt] * len(batch['target']), return_tensors="pt"
)['input_ids'].to(device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
prompt_pattern1_lens = torch.tensor([len(i) for i in prompt_pattern1]).to(device)
prompt_pattern1_mask = ~make_pad_mask(prompt_pattern1_lens)
prompt_pattern1_target = torch.full(prompt_pattern1.shape, self.IGNORE_ID).to(
device) # B, L
# user_start_id = self.tokenizer([user_start] * len(batch['target']), return_tensors="pt"
# )['input_ids'].to(device)
# user_start_embeds = self.embed_tokens(user_start_id)
# user_start_lens = torch.tensor([len(i) for i in user_start_id]).to(device)
# user_start_mask = ~make_pad_mask(user_start_lens)
# user_start_target = torch.full(user_start_id.shape, self.IGNORE_ID).to(
# device) # B, L
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(batch['target']), return_tensors="pt"
)['input_ids'].to(device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
prompt_pattern2_lens = torch.tensor([len(i) for i in prompt_pattern2]).to(device)
prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens)
prompt_pattern2_target = torch.full(prompt_pattern2.shape, self.IGNORE_ID).to(
device) # B, L
inputs_embeds_list.append(prompt_pattern1_embeds)
attention_mask_list.append(prompt_pattern1_mask)
target_list.append(prompt_pattern1_target)
streaming_error = False
if output_type == "speech2text_token_streaming":
rank = int(os.environ.get('RANK', 0))
            utils_file.logging_limit_print('start processing the speech2text_token streaming task')
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
speech_token_labels = batch['speech_tokens'].to(device)
speech_tokens_length = batch['speech_tokens_length'].to(device)
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_tokens_pad_mask = make_pad_mask(speech_tokens_length) # B, L
speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
speech_token_labels = speech_token_labels + self.llm_vocab_size
if rank == 0:
utils_file.logging_limit_print(f'labels.shape {labels.shape}')
utils_file.logging_limit_print(f'labels_lengths.shape {labels_lengths.shape}')
utils_file.logging_limit_print(f'labels[0] {labels[0]}')
utils_file.logging_limit_print(f'------------------------')
utils_file.logging_limit_print(f'speech_token_labels.shape {speech_token_labels.shape}')
utils_file.logging_limit_print(f'speech_tokens_length.shape {speech_tokens_length.shape}')
utils_file.logging_limit_print(f'speech_token_labels[0] {speech_token_labels[0]}')
utils_file.logging_limit_print(f'==========================')
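            # make_streaming_mode_from_s2s interleaves the text labels with the (offset)
            # speech-token labels into a single streaming-style target sequence.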
streaming_concat_ids, streaming_concat_lens = make_streaming_mode_from_s2s(labels, labels_lengths,
speech_token_labels,
speech_tokens_length)
if rank == 0:
utils_file.logging_limit_print(f'streaming_concat_ids.shape {streaming_concat_ids.shape}')
utils_file.logging_limit_print(f'streaming_concat_lens.shape {streaming_concat_lens.shape}')
utils_file.logging_limit_print(f'streaming_concat_lens {streaming_concat_lens[0]}')
utils_file.logging_limit_print(f'xxx streaming_concat_ids[0] : {streaming_concat_ids[0]}')
utils_file.logging_limit_print(f'------------------------')
streaming_concat_embeddings = do_embedding_for_two_embeds(streaming_concat_ids, self.llm_vocab_size, self.embed_tokens,
self.speech_token_emded)
streaming_concat_pad_mask = make_pad_mask(streaming_concat_lens)
streaming_concat_target = streaming_concat_ids.masked_fill(streaming_concat_pad_mask, self.IGNORE_ID)
streaming_concat_mask = ~streaming_concat_pad_mask
if rank == 0:
utils_file.logging_limit_print(f'streaming_concat_embeddings.shape {streaming_concat_embeddings.shape}')
utils_file.logging_limit_print(f'streaming_concat_mask shape {streaming_concat_mask.shape}')
utils_file.logging_limit_print(f'------------------------')
            # if prompt_embeds is not None:  # for the s2s dialogue task the user prompt input is no longer used
# inputs_embeds_list.append(prompt_embeds)
# attention_mask_list.append(prompt_mask)
# target_list.append(prompt_target)
# ===================history===================================
history_batch = batch.get('history', [])
history_embedding, history_lens = self._get_embedding_for_history(history_batch, device)
if history_embedding is not None:
                utils_file.logging_info('OSUM-EChat: embedding the dialogue history')
history_pad_mask = make_pad_mask(history_lens) # B, L
history_target = torch.full(history_pad_mask.shape, self.IGNORE_ID).to(device) # B, L
history_mask = ~history_pad_mask
inputs_embeds_list.append(history_embedding)
attention_mask_list.append(history_mask)
target_list.append(history_target)
utils_file.logging_limit_print(f'xxx history embeding shape {history_embedding.shape}')
utils_file.logging_limit_print(f'xxx history mask shape {history_mask.shape}')
utils_file.logging_limit_print(f'xxx history target shape {history_target.shape}')
else:
utils_file.logging_limit_print(f'history is None')
# ==========================history end ===================
inputs_embeds_list.extend(
[ speech_embeds, prompt_pattern2_embeds, streaming_concat_embeddings])
attention_mask_list.extend([speech_masks, prompt_pattern2_mask, streaming_concat_mask])
target_list.extend([speech_target, prompt_pattern2_target, streaming_concat_target])
elif output_type == "text2token_streaming":
rank = int(os.environ.get('RANK', 0))
            utils_file.logging_limit_print('start the tts streaming task')
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
speech_token_labels = batch['speech_tokens'].to(device)
speech_tokens_length = batch['speech_tokens_length'].to(device)
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_tokens_pad_mask = make_pad_mask(speech_tokens_length) # B, L
speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
speech_token_labels = speech_token_labels + self.llm_vocab_size
streaming_concat_ids, streaming_concat_lens = make_streaming_mode_from_s2s(labels, labels_lengths,
speech_token_labels,
speech_tokens_length)
streaming_concat_embeddings = do_embedding_for_two_embeds(streaming_concat_ids, self.llm_vocab_size,
self.embed_tokens,
self.speech_token_emded)
streaming_concat_pad_mask = make_pad_mask(streaming_concat_lens)
streaming_concat_target = streaming_concat_ids.masked_fill(streaming_concat_pad_mask, self.IGNORE_ID)
streaming_concat_mask = ~streaming_concat_pad_mask
            # if prompt_embeds is not None:  # for the TTS task the user prompt input is no longer used
# inputs_embeds_list.append(prompt_embeds)
# attention_mask_list.append(prompt_mask)
# target_list.append(prompt_target)
inputs_embeds_list.extend(
[ speech_embeds, prompt_pattern2_embeds, streaming_concat_embeddings])
attention_mask_list.extend([speech_masks, prompt_pattern2_mask, streaming_concat_mask])
target_list.extend([speech_target, prompt_pattern2_target, streaming_concat_target])
elif output_type == 'speech2text_token' or output_type == "speech2text_token_think" or output_type == "speech2text_token_history":
            utils_file.logging_limit_print('xxx start processing the speech2text_token task')
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
speech_token_labels = batch['speech_tokens'].to(device)
speech_tokens_length = batch['speech_tokens_length'].to(device)
labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding(
speech_token_labels, speech_tokens_length)
            # if prompt_embeds is not None:  # for the s2s dialogue task the user prompt input is no longer used
# inputs_embeds_list.append(prompt_embeds)
# attention_mask_list.append(prompt_mask)
# target_list.append(prompt_target)
# ===================history===================================
history_batch = batch.get('history', [])
history_embedding, history_lens = self._get_embedding_for_history(history_batch, device)
if history_embedding is not None:
                utils_file.logging_info('OSUM-EChat: embedding the dialogue history')
history_pad_mask = make_pad_mask(history_lens) # B, L
history_target = torch.full(history_pad_mask.shape, self.IGNORE_ID).to(device) # B, L
history_mask = ~history_pad_mask
inputs_embeds_list.append(history_embedding)
attention_mask_list.append(history_mask)
target_list.append(history_target)
utils_file.logging_limit_print(f'xxx history embeding shape {history_embedding.shape}')
utils_file.logging_limit_print(f'xxx history mask shape {history_mask.shape}')
utils_file.logging_limit_print(f'xxx history target shape {history_target.shape}')
else:
utils_file.logging_limit_print(f'history is None')
# ==========================history end ===================
inputs_embeds_list.extend(
[ speech_embeds, prompt_pattern2_embeds, labels_embeds, speech_token_labels_embeds])
attention_mask_list.extend([speech_masks, prompt_pattern2_mask, labels_mask, speech_token_labels_mask])
target_list.extend([speech_target, prompt_pattern2_target, labels_target, speech_token_labels_target])
elif output_type == "text2token":
speech_token_labels = batch['speech_tokens'].to(device)
speech_tokens_length = batch['speech_tokens_length'].to(device)
speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding(
speech_token_labels, speech_tokens_length)
            # if prompt_embeds is not None:  # for the TTS task the user prompt input is no longer used
# inputs_embeds_list.append(prompt_embeds)
# attention_mask_list.append(prompt_mask)
# target_list.append(prompt_target)
inputs_embeds_list.extend([ speech_embeds, prompt_pattern2_embeds, speech_token_labels_embeds])
attention_mask_list.extend([speech_masks, prompt_pattern2_mask, speech_token_labels_mask])
target_list.extend([speech_target, prompt_pattern2_target, speech_token_labels_target])
elif output_type == "text" or output_type == 's2t_chat' or output_type == "s2t_chat_fake" or output_type == "s2t_chat_think":
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
            if prompt_embeds is not None and output_type == 'text':  # for the s2t_chat dialogue task the user prompt input is no longer used
inputs_embeds_list.append(prompt_embeds)
attention_mask_list.append(prompt_mask)
target_list.append(prompt_target)
            elif output_type not in ('s2t_chat', 's2t_chat_fake', 's2t_chat_think'):
utils_file.logging_limit_print(
f'prompt is None,task: {batch["task"]}, prompt_embeds:{prompt_embeds}, prompt_mask:{prompt_mask}')
inputs_embeds_list.extend([ speech_embeds, prompt_pattern2_embeds, labels_embeds])
attention_mask_list.extend([speech_masks, prompt_pattern2_mask, labels_mask])
target_list.extend([speech_target, prompt_pattern2_target, labels_target])
elif output_type == "text2text":
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
if prompt_embeds is not None:
inputs_embeds_list.append(prompt_embeds)
attention_mask_list.append(prompt_mask)
target_list.append(prompt_target)
else:
utils_file.logging_limit_print(
f'prompt is None,task: {batch["task"]}, prompt_embeds:{prompt_embeds}, prompt_mask:{prompt_mask}')
inputs_embeds_list.extend([ prompt_pattern2_embeds, labels_embeds])
attention_mask_list.extend([ prompt_pattern2_mask, labels_mask])
target_list.extend([ prompt_pattern2_target, labels_target])
else:
raise NotImplementedError(f'output_type {output_type} not support')
inputs_embeds = torch.cat(inputs_embeds_list, dim=1)
# utils_file.logging_limit_print(f'xxx final inputs_embeds shape {inputs_embeds.shape}')
attention_mask = torch.cat(attention_mask_list, dim=1)
# utils_file.logging_limit_print(f'xxx final attention_mask shape {attention_mask.shape}')
# utils_file.logging_limit_print(f'xxx final attention_mask 0 {attention_mask[0]}')
target = torch.cat(target_list, dim=1)
# utils_file.logging_limit_print(f'xxx final target shape {target.shape}')
# utils_file.logging_limit_print(f'xxx final target 0 {target[0]}')
# utils_file.logging_limit_print(f'OSUM-EChat output_type: {output_type}')
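        # Positions are counted over non-padded tokens only; padded slots get a dummy position of 1.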
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# utils_file.logging_limit_print(f'xxx final position_ids shape {position_ids.shape}')
# utils_file.logging_limit_print(f'xxx final position_ids 0 {position_ids[0]}')
        if output_type in ('text', 's2t_chat', 's2t_chat_fake', 's2t_chat_think', 'text2text'):
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
labels=target,
attention_mask=attention_mask,
position_ids=position_ids.to(inputs_embeds.device)
)
loss = outputs['loss']
return {"loss": loss,"output_type": output_type}
else:
            utils_file.logging_limit_print('running the custom forward of llama_model')
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
# labels=target,
attention_mask=attention_mask,
position_ids=position_ids.to(inputs_embeds.device)
)
hidden_states = outputs['hidden_states'][-1]
logits = self.lm_head(hidden_states)
logits2 = self.speech_head(hidden_states) # speech_head
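            # Concatenate text-head and speech-head logits along the vocab dimension; speech
            # targets were offset by llm_vocab_size so they index into this joint space.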
combined_logits = torch.cat([logits, logits2], dim=-1)
# combined_logits = self.new_lm_head(hidden_states)
shift_logits = combined_logits[..., :-1, :].contiguous()
shift_target = target[..., 1:].contiguous()
# utils_file.logging_limit_print(
# f'xxx shift_logits shape: {shift_logits.shape}, shift_target shape: {shift_target.shape}')
# utils_file.logging_limit_print(f'xxx shift_target 0 {shift_target[0]}')
            shift_logits = shift_logits.view(-1, combined_logits.shape[-1])  # note: the last dim follows the combined (text + speech) logits size
shift_target = shift_target.view(-1)
shift_target = shift_target.to(shift_logits.device)
loss = self.loss_fct(shift_logits, shift_target)
loss.requires_grad_(True)
return {"loss": loss,"output_type": output_type}
def generate_s2s_streaming(
self,
wavs,
wavs_len,
prompt,
):
self.llama_model.eval()
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num,
speech_embeds, speech_masks, None)
device = speech_embeds.device
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
qwen_instruct_prompt_pattern_1_s_input_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant created by . You understand both the meaning and paralinguistic cues in users' speech, and respond appropriately with text or voice.<|im_end|>\n<|im_start|>user\n"
qwen_instruct_prompt_pattern_1_t2t_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant created by . You understand user input in text and respond with accurate and helpful text replies.<|im_end|>\n<|im_start|>user\n"
qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
qwen_instruct_prompt_pattern_1 =qwen_instruct_prompt_pattern_1_old
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1], return_tensors="pt"
)['input_ids'].to(device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
prompt_pattern1_lens = torch.tensor([len(i) for i in prompt_pattern1]).to(device)
prompt_pattern1_mask = ~make_pad_mask(prompt_pattern1_lens)
prompt_pattern1_target = torch.full(prompt_pattern1.shape, self.IGNORE_ID).to(
device) # B, L
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] , return_tensors="pt"
)['input_ids'].to(device)
start_id = prompt_pattern2[0][-1]
new_prompt_pattern2 = prompt_pattern2[:,:-1]
prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2)
prompt_pattern2_lens = torch.tensor([len(i) for i in new_prompt_pattern2],dtype=torch.long).to(device)
prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens)
prompt_pattern2_target = torch.full(new_prompt_pattern2.shape, self.IGNORE_ID).to(
device) # B, L
embeds = torch.cat([prompt_pattern1_embeds,prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
max_len = 350
hyps = [start_id]
print(f'start_id: {start_id}')
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
batch_size = 1
top_k = 10
top_p = 0.9
temperature = 1
cache = llm_out.past_key_values
token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
inferring_txt = True
txt_finished = False
speeech_finished = False
hyps_text = ""
speech_eos_num = 0
txt_list = []
token_list = []
i_num = 0
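        # Streaming decode alternates phases: up to 6 text tokens via lm_head, then up to
        # 18 speech tokens via speech_head, until the text eos and two speech eos (4096)
        # tokens are produced, with a hard cap of ~300 steps.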
for i in range(max_len):
if inferring_txt and not txt_finished:
for i_txt in range(6):
i_num += 1
if i_num> 300:
break
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.lm_head(hidden_states)
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
# next_token_ids = torch.argmax(token_logits, dim=-1)
print(i_num, next_token_ids, f'txt')
hyps.append(next_token_ids.item())
token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
if next_token_ids == self.eos_token_id:
txt_finished = True
hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
print("hyps_text:", hyps_text)
print("text is over")
break
txt_list.append(next_token_ids.item())
hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
print("hyps_text:", hyps_text)
inferring_txt = False
elif not speeech_finished:
for i_speech in range(18):
i_num += 1
if i_num> 300:
break
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.speech_head(hidden_states)
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
# next_token_ids = torch.argmax(token_logits, dim=-1)
hyps.append(next_token_ids.item())
print(i_num, next_token_ids, f'speech')
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
if next_token_ids == 4096:
speech_eos_num += 1
                        print('encountered 4096')
if speech_eos_num >= 2:
print("speech is over")
speeech_finished = True
break
token_list.append(next_token_ids.item())
inferring_txt = True
if speeech_finished:
break
if i_num > 300:
break
return [hyps_text + "|" + str(token_list)]
def generate(
self,
wavs,
wavs_len,
prompt,
**kwargs
):
self.llama_model.eval()
self.set_task_type("ASR")
self.do_add_speech_embed_head()
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
speech_embeds, speech_masks, None)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, an audio understanding. You can transcribe speech accurately and anaosum_echat2e paralinguistic cues to provide precise text responses.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16:
embeds = embeds.to(torch.bfloat16)
atts = atts.to(torch.bfloat16)
outputs = self.llama_model.generate(
inputs_embeds=embeds,
max_new_tokens=self.max_length,
cache_implementation="static",
# num_beams=self.num_beams,
do_sample=self.do_sample,
min_length=self.min_length,
top_p=self.top_p,
top_k=self.top_k,
repetition_penalty=self.repetition_penalty,
length_penalty=self.length_penalty,
temperature=self.temperature,
# attention_mask=atts,
eos_token_id=151645,
pad_token_id=-100,
stopping_criteria=self.max_token_criteria_list,
do_compile=True,
)
output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
return output_text
def generate4chat(
self,
wavs,
wavs_len,
prompt=" ",
do_sample=True,
top_k=2,
top_p=1,
temperature=0.4,
**kwargs
):
print(f'do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}')
self.llama_model.eval()
self.set_task_type("ASR")
self.do_add_speech_embed_head()
# self.do_merge_embed_head()
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
speech_embeds, speech_masks, None)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
# # sft
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
# embeds = embeds.to(torch.float16)
embeds = embeds.to(torch.bfloat16)
atts = atts.to(torch.bfloat16)
outputs = self.llama_model.generate(
inputs_embeds=embeds,
max_new_tokens=self.max_length,
cache_implementation="static",
# num_beams=1,
do_sample=do_sample,
min_length=self.min_length,
top_p=top_p,
top_k=top_k,
repetition_penalty=self.repetition_penalty,
length_penalty=1,
temperature=temperature,
# attention_mask=atts,
eos_token_id=151645,
pad_token_id=-100,
do_compile=True,
stopping_criteria=self.max_token_criteria_list,
)
output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
return output_text
def generate4chat_think(
self,
wavs,
wavs_len,
prompt=" ",
do_sample=True,
top_k=2,
top_p=1,
temperature=0.4,
**kwargs
):
print(f'do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}')
self.llama_model.eval()
self.set_task_type("ASR")
self.do_add_speech_embed_head()
# self.do_merge_embed_head()
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
speech_embeds, speech_masks, None)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
# # sft
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a thinking-enabled speech-to-text dialogue assistant by ASLP Lab. You not only comprehend the semantic meaning and paralinguistic cues in speech but also engage in deliberate reasoning to process such information. Based on this thinking process, you then respond exclusively with appropriate text.<|im_end|>\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
# embeds = embeds.to(torch.float16)
embeds = embeds.to(torch.bfloat16)
atts = atts.to(torch.bfloat16)
outputs = self.llama_model.generate(
inputs_embeds=embeds,
max_new_tokens=self.max_length,
cache_implementation="static",
# num_beams=1,
do_sample=do_sample,
min_length=self.min_length,
top_p=top_p,
top_k=top_k,
repetition_penalty=self.repetition_penalty,
length_penalty=1,
temperature=temperature,
# attention_mask=atts,
eos_token_id=151645,
pad_token_id=-100,
do_compile=True,
stopping_criteria=self.max_token_criteria_list,
)
output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
return output_text
def generate_tts(
self,
device,
text,
):
        # ===================== prepare the input embeddings =====================
self.llama_model.eval()
        # get the device the model lives on
# device = self.llama_model.device
self.set_task_type("TTS")
self.do_add_speech_embed_head()
# labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device)
# labels = text[:,:]
labels = self.tokenizer(
[text],
return_tensors="pt",
add_special_tokens=False
).to(
self.embed_tokens.weight.device).input_ids # (1, L)
labels = labels.to(device)
labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device)
print(f'label_lengths:{labels_lengths}')
print(f'labels:{labels}')
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
# speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
# 1 + self.speech_token_num,
# speech_embeds, speech_masks, speech_target)
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
hyps = [self.speech_token_num - 1]
speech_begin_token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
embeds = torch.cat([prompt_pattern1_embeds,
speech_embeds,
prompt_pattern2_embeds,
speech_begin_token_emb], dim=1).to(torch.bfloat16)
        # set top_k / top_p / temperature / stopping criteria
# max_len = 250
top_k = 15 # 5
top_p = 0.8 # 0.9
temperature = 1.2 # 1
print(f"tts eos id = {self.speech_token_num - 1}")
llm_out = self.llama_model.generate(inputs_embeds=embeds,
max_new_tokens=self.max_length,
eos_token_id=self.speech_token_num - 1,
cache_implementation="static",
do_sample=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stopping_criteria=StoppingCriteriaList([MaxTokenStopper(2000)]),
do_compile=True,
repetition_penalty=1.0,
)
return llm_out
def generate_tts_streaming(
self,
device,
prompt,
text,
):
self.llama_model.eval()
# labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device)
# labels = text[:,:]
labels = self.tokenizer(
[text],
return_tensors="pt",
add_special_tokens=False
).to(
self.embed_tokens.weight.device).input_ids # (1, L)
labels = labels.to(device)
labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device)
print(f'label_lengths:{labels_lengths}')
print(f'labels:{labels}')
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
1 + self.speech_token_num,
speech_embeds, speech_masks, speech_target)
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
# embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
# ----------------
start_id = prompt_pattern2[0][-1]
new_prompt_pattern2 = prompt_pattern2[:, :-1]
prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2)
prompt_pattern2_lens = torch.tensor([len(i) for i in new_prompt_pattern2], dtype=torch.long).to(device)
prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens)
prompt_pattern2_target = torch.full(new_prompt_pattern2.shape, self.IGNORE_ID).to(
device) # B, L
embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
max_len = 350
hyps = [start_id]
print(f'start_id: {start_id}')
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
batch_size = 1
top_k = 10
top_p = 0.9
temperature = 1
cache = llm_out.past_key_values
token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
inferring_txt = True
txt_finished = False
speeech_finished = False
hyps_text = ""
speech_eos_num = 0
txt_list = []
token_list = []
i_num = 0
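        # Streaming decode alternates phases: up to 6 text tokens via lm_head, then up to
        # 18 speech tokens via speech_head, until the text eos and two speech eos (4096)
        # tokens are produced, with a hard cap of ~300 steps.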
for i in range(max_len):
if inferring_txt and not txt_finished:
for i_txt in range(6):
i_num += 1
if i_num > 300:
break
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.lm_head(hidden_states)
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
# next_token_ids = torch.argmax(token_logits, dim=-1)
print(i_num, next_token_ids, f'txt')
hyps.append(next_token_ids.item())
token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
if next_token_ids == self.eos_token_id:
txt_finished = True
hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
print("hyps_text:", hyps_text)
print("text is over")
break
txt_list.append(next_token_ids.item())
hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
print("hyps_text:", hyps_text)
inferring_txt = False
elif not speeech_finished:
for i_speech in range(18):
i_num += 1
if i_num > 300:
break
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.speech_head(hidden_states)
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
# next_token_ids = torch.argmax(token_logits, dim=-1)
hyps.append(next_token_ids.item())
print(i_num, next_token_ids, f'speech')
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
if next_token_ids == 4096:
speech_eos_num += 1
                        print('encountered 4096')
if speech_eos_num >= 2:
print("speech is over")
speeech_finished = True
break
token_list.append(next_token_ids.item())
inferring_txt = True
if i_num > 300:
break
return [hyps_text + "|" + str(token_list)]
def generate_text2text(
self,
device,
text,
):
self.llama_model.eval()
# labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device)
# labels = text[:,:]
labels = self.tokenizer(
[text],
return_tensors="pt",
add_special_tokens=False
).to(
self.embed_tokens.weight.device).input_ids # (1, L)
labels = labels.to(device)
labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device)
# print(f'label_lengths:{labels_lengths}')
# print(f'labels:{labels}')
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
# speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
# 1 + self.speech_token_num,
# speech_embeds, speech_masks, speech_target)
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
# # sft
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. You understand user input in text then respond exclusively with appropriate text.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
# embeds = embeds.to(torch.float16)
embeds = embeds.to(torch.bfloat16)
atts = atts.to(torch.bfloat16)
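# 151645 is the <|im_end|> id in the Qwen2 tokenizer family used here; text generation stops there.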
outputs = self.llama_model.generate(
inputs_embeds=embeds,
max_new_tokens=200,
num_beams=1,
do_sample=False,
min_length=self.min_length,
repetition_penalty=1.0,
length_penalty=1.0,
temperature=self.temperature,
attention_mask=atts,
eos_token_id=151645,
pad_token_id=-100,
do_compile=True,
cache_implementation="static",
)
output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
# output_text = [item.replace('<|endoftext|>', '') for item in output_text]
return output_text
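# A minimal calling sketch (hypothetical driver code, not part of this file; assumes `model`
# is an initialized LLMASR_Model and `device` is its torch device):
#     with torch.no_grad():
#         reply = model.generate_text2text(device, "How is the weather today?")
#     print(reply[0])  # batch_decode returns a list with one string per sample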
def generate_s2s_no_stream_with_repetition_penalty(
self,
wavs,
wavs_len,
):
self.llama_model.eval()
self.set_task_type("S2S")
self.do_add_speech_embed_head()
# ===================== prepare the input embeddings =====================
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None,
speech_embeds, speech_masks, None)
device = speech_embeds.device
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n<|im_start|>user\n"
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
hyps = [4098]
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
embeds = torch.cat(
[prompt_pattern1_embeds, speech_embeds, token_emb, prompt_pattern2_embeds],
dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
# max_len = 350
top_k = 10
top_p = 0.9
temperature = 1.2
invalid_eos = 10000000
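# eos_token_id is set to an out-of-vocabulary value so the built-in EOS check never fires;
# stopping is delegated entirely to the custom S2S stopping criteria below.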
self.osum_chat_logit_processor1.init_match_found()
llm_out = self.llama_model.generate(inputs_embeds=embeds,
max_new_tokens=self.max_length,
eos_token_id=invalid_eos,
cache_implementation="static",
do_sample=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
logits_processor=self.s2s_repetition_penalty,
stopping_criteria=self.s2s_stop_criteria,
do_compile=True,
repetition_penalty=1.0,
)
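# split the generated sequence at the first <|im_end|> (id 151645): tokens up to just before it
# form the text reply, tokens after it (minus the trailing stop token) are the discrete speech tokens.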
text_eos_idx = (llm_out[0] == 151645).nonzero(as_tuple=True)[0][0].item()
text_res = llm_out[:, :text_eos_idx - 1]
speech_res = llm_out[:, text_eos_idx + 1:-1]
# print("llm_out", llm_out)
output_text = self.tokenizer.batch_decode(text_res, add_special_tokens=False, skip_special_tokens=True)
# print(f'output_text:{output_text}')
# print(f'speech_res:{speech_res}')
return (output_text, text_res, speech_res)
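# A minimal calling sketch (hypothetical driver code, not part of this file; assumes `model`
# is an initialized LLMASR_Model with the speech head attached and `feats`, `feats_len` are
# the acoustic features and lengths that the wenet encoder expects):
#     with torch.no_grad():
#         text, text_ids, speech_ids = model.generate_s2s_no_stream_with_repetition_penalty(
#             feats, feats_len)
#     print(text[0])        # decoded reply text
#     print(speech_ids[0])  # discrete speech tokens for the token-to-wav stage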
def generate_s2s_no_stream_think_with_repetition_penalty(
self,
wavs,
wavs_len,
):
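# Same decoding pipeline as generate_s2s_no_stream_with_repetition_penalty above; the only
# difference is the system prompt, which asks the model to emit <think>...</think> reasoning
# before the spoken reply.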
self.llama_model.eval()
self.set_task_type("S2S")
self.do_add_speech_embed_head()
# ===================== prepare the input embeddings =====================
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None,
speech_embeds, speech_masks, None)
device = speech_embeds.device
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n<|im_start|>user\n"
qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside <think>...</think end>, analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n<|im_start|>user\n"
# qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
hyps = [4098]
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
embeds = torch.cat(
[prompt_pattern1_embeds, speech_embeds, token_emb, prompt_pattern2_embeds],
dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
# utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
# max_len = 350
top_k = 10
top_p = 0.9
temperature = 1.2
invalid_eos = 10000000
self.osum_chat_logit_processor1.init_match_found()  # non-think mode needs no matching
llm_out = self.llama_model.generate(inputs_embeds=embeds,
max_new_tokens=self.max_length,
eos_token_id=invalid_eos,
cache_implementation="static",
do_sample=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
logits_processor=self.s2s_repetition_penalty,
stopping_criteria=self.s2s_stop_criteria,
do_compile=True,
repetition_penalty=1.0,
)
text_eos_idx = (llm_out[0] == 151645).nonzero(as_tuple=True)[0][0].item()
text_res = llm_out[:, :text_eos_idx - 1]
speech_res = llm_out[:, text_eos_idx + 1:-1]
# print("llm_out", llm_out)
output_text = self.tokenizer.batch_decode(text_res, add_special_tokens=False, skip_special_tokens=True)
# print(f'output_text:{output_text}')
# print(f'speech_res:{speech_res}')
return (output_text, text_res, speech_res)
def _get_embedding_from_wav(self, wavs, wavs_len):
"""
return:
wav_embedding: (b, l, v)
wav_mask: (b, l), wav为有效值的位置为true
"""
encoder_out, encoder_mask = self.encoder(wavs, wavs_len)
speech_embeds, encoder_mask = self.down_sample_2(encoder_out, encoder_mask)
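# downsample the encoder output to a lower frame rate before the adapter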
if self.speech_transformer is not None:
filled_wavs_len = encoder_mask.squeeze(1).sum(-1)
speech_embeds, encoder_mask = self.speech_transformer(speech_embeds, filled_wavs_len)
# if rank == 0:
# utils_file.logging_limit_print(
# f'out of link shape: {speech_embeds.shape}, first 20 values of the first encoder frame:\n {speech_embeds[0][0][:20]}')
# utils_file.logging_limit_print(
# 'get_embedding_from_wav(): speech_embeds.shape, after self.speech_transformer(speech_embeds, speech_lens):',
# speech_embeds.shape)
speech_embeds = self.speech_llama_proj(speech_embeds)
# if rank == 0:
# utils_file.logging_limit_print(
# f'out of speech_llama_proj shape: {speech_embeds.shape}, first 20 values of the first encoder frame:\n {speech_embeds[0][0][:20]}')
# utils_file.logging_limit_print(
# 'get_embedding_from_wav(): speech_embeds.shape, after self.speech_llama_proj(speech_embeds):',
# speech_embeds.shape)
return speech_embeds, encoder_mask.squeeze(1)
def _get_embedding_from_text(self, text):
"""
Tokenize the string first, then map the token ids to embeddings.
Args:
text: str
Returns:
text_embeds: (1, L, D)
text_embeds_len: (1,), length of the token sequence
"""
text_id = self.tokenizer(
text,
return_tensors="pt",
add_special_tokens=False
).to(
self.embed_tokens.weight.device).input_ids
text_embeds = self.embed_tokens(text_id)
text_embeds_len = torch.tensor([text_embeds.size(1)], dtype=torch.long)
return text_embeds, text_embeds_len
def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
B = len(inputs_embeds)
bos_eos_target = torch.full([B, 1], self.IGNORE_ID).to(inputs_embeds.device) # B,1
bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device) # B, 1
if bos is not None:
bos_embed = self.speech_token_emded(torch.full([B, 1],
bos).to(inputs_embeds.device)) # B, 1, D
inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1) # B, (1+T), D
attention_mask = torch.cat((bos_eos_mask, attention_mask), 1) # B, (1+T)
if target is not None:
target = torch.cat((bos_eos_target, target), 1) # B, (1+T), D
if eos is not None:
eos_embed = self.speech_token_emded(torch.full([B, 1],
eos).to(inputs_embeds.device)) # B, 1, D
inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1) # B, (1+T+1), D
attention_mask = torch.cat((attention_mask, bos_eos_mask), 1) # B, (1+T+1)
if target is not None:
target = torch.cat((target, bos_eos_target), 1) # B, (1+T+1), D
return inputs_embeds, attention_mask, target
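# Shape walk-through (illustrative): with B=1, T=10, D=1024, passing both a bos and an eos id
# grows inputs_embeds from (1, 10, 1024) to (1, 12, 1024); the added positions get mask=True and
# target=IGNORE_ID so they never contribute to the loss.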
def infer_sample_teach_force(
self,
wavs,
wavs_len,
prompt,
text,
speech_token,
):
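# Teacher-forced speech decoding: at every step the *reference* speech token from
# `speech_token` is fed back as the next input embedding, while the freely sampled token is
# only recorded in `hyps` for inspection.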
labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=wavs.device)
labels = text[:, :]
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 +self.speech_token_num,
1 +self.speech_token_num,
speech_embeds, speech_masks, speech_target)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
atts = atts.half()
device = wavs.device
inputs_embeds = embeds.to(device)
speech_token_list = speech_token[0].tolist()
speech_token_list_len = len(speech_token_list)
print(f'speech_token_list_len:{speech_token_list_len}')
max_len = 200
beam = 3
beam_size = beam
running_size = beam
output_token = []
hyps = [self.speech_token_num - 1]
scores = [1.0]
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
batch_size = 1
top_k = 10
top_p = 0.9
temperature = 1.0
cache = llm_out.past_key_values
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
repetition_penalty = 1.1
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
for i in range(max_len):
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.speech_head(hidden_states)
# probs = F.log_softmax(token_logits[:,-1], dim=-1)[0]
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
print(i, next_token_ids)
if next_token_ids == self.speech_token_num - 1:
print("break la!")
print("hyps:", hyps)
break
hyps.append(next_token_ids.item())
token_emb = self.speech_token_emded(torch.tensor(speech_token_list[i]).to(device)).unsqueeze(0)
res = []
for i in hyps[1:]:
# if i != self.speech_token_num-1:
res.append(i)
print(res)
return [res]
def _sampler(
self,
logits: torch.Tensor,
temperatures: Union[torch.Tensor, None],
top_ps: torch.Tensor,
top_ks: torch.Tensor,
) -> torch.Tensor:
"""
Sample the next token id from logits with temperature scaling, top-p and top-k filtering.
Args:
logits: (batch, 1, vocab_size), logits for the last position only
temperatures: (batch,) temperature per sample, or None for greedy argmax
top_ps: (batch,) nucleus (top-p) threshold per sample
top_ks: (batch,) top-k cutoff per sample
Returns:
next_token_ids: (batch,) sampled token ids
"""
print(f'logits:{logits.shape}')
assert logits.size(1) == 1
logits = logits.squeeze(1) # (batch_size, vocab_size)
if temperatures is None:
return torch.argmax(logits, dim=-1).squeeze(dim=-1)
# Apply temperature scaling.
logits.div_(temperatures.unsqueeze(dim=1))
# Calculate probabilities with softmax.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# Apply top-p, top-k.
probs_sum = torch.cumsum(probs_sort, dim=-1)
top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
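# (probs_sum - probs_sort) is the cumulative probability mass *before* each sorted token, so a
# token is kept only while the mass accumulated ahead of it is still within top_p.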
probs_sort = torch.where(top_ps_mask, 0, probs_sort)
top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
probs_sort = torch.where(top_ks_mask, 0, probs_sort)
# Re-normalization.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
next_token_ids = torch.multinomial(probs, num_samples=1,
replacement=True).squeeze(dim=-1)
return next_token_ids
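# Toy illustration of the filtering above (numbers are illustrative only):
#   probs_sort = [0.5, 0.3, 0.1, 0.1], top_p = 0.7, top_k = 3
#   mass before each token = [0.0, 0.5, 0.8, 0.9] -> top-p keeps the first two tokens,
#   top-k keeps the first three, so [0.5, 0.3] is renormalized to [0.625, 0.375] and sampled.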
def infer_sample4speech2text_token_teacher_force(
self,
wavs,
wavs_len,
prompt,
speech_token=None,
answer_text=None,
):
self.llama_model.eval()
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num,
speech_embeds, speech_masks, None)
device = speech_embeds.device
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
text_token = self.tokenizer([answer_text], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
text_token_embeds = self.embed_tokens(text_token)
embeds = torch.cat([prompt_embeds, speech_embeds, text_token_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
atts = atts.half()
inputs_embeds = embeds.to(speech_embeds.device)
max_len = 150
beam = 3
beam_size = beam
running_size = beam
output_token = []
hyps = [self.speech_token_num]
hyps_text = ""
scores = [1.0]
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
# speech_token_list = speech_token[0]
# speech_token_list_len = len(speech_token_list)
if speech_token is not None:
print(f'speech_token_list_len:{len(speech_token[0])}')
print(f'speech_token:{speech_token[0]}')
batch_size = 1
top_k = 10
top_p = 0.9
temperature = 1.2
cache = llm_out.past_key_values
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
repetition_penalty = 1.1
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
is_speech_token = False
speech_eos_num = 0
for i in range(max_len):
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.speech_head(hidden_states)
# probs = F.log_softmax(token_logits[:,-1], dim=-1)[0]
# if i ==2 or i == 80:
# torch.save(probs, f'probs_{i}.pt')
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
if next_token_ids == self.speech_token_num - 1:
print('hit speech EOS token 4096')
break
hyps.append(next_token_ids.item())
# if 1+i >= len(speech_token[0]):
# break
# token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i+1]]).to(device)).unsqueeze(0)
token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
res = []
for i in hyps[1:]:
# if i != self.speech_token_num-1:
res.append(i)
print(res)
return [answer_text + str(res[2:])]
def infer_sample4speech2text_token_teacher_force2(
self,
wavs,
wavs_len,
prompt,
speech_token=None,
answer_text=None,
):
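# Unlike the method above, the answer text is not fed as context here; the reference
# `speech_token` sequence is fed back as the decoder input step by step (one step per
# reference token), and the sampled outputs are collected in `hyps` for comparison.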
self.llama_model.eval()
speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num ,
speech_embeds, speech_masks, None)
device = speech_embeds.device
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
text_token = self.tokenizer([answer_text], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
# text_token_embeds = self.embed_tokens(text_token)
embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
if self.embed_tokens.weight.dtype == torch.float16:
utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
embeds = embeds.to(torch.float16)
atts = atts.half()
inputs_embeds = embeds.to(speech_embeds.device)
max_len = 150
beam = 3
beam_size = beam
running_size = beam
output_token = []
hyps = [self.speech_token_num - 1]
hyps_text = ""
scores = [1.0]
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
# speech_token_list = speech_token[0]
# speech_token_list_len = len(speech_token_list)
if speech_token is not None:
print(f'speech_token_list_len:{len(speech_token)}')
print(f'speech_token:{speech_token}')
batch_size = 1
top_k = 10
top_p = 0.9
temperature = 1.2
cache = llm_out.past_key_values
token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
repetition_penalty = 1.1
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
is_speech_token = False
speech_eos_num = 0
token_num = len(speech_token)
for i in range(token_num):
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
hidden_states = llm_out.hidden_states[-1]
token_logits = self.speech_head(hidden_states)
# probs = F.log_softmax(token_logits[:,-1], dim=-1)[0]
# if i ==2 or i == 80:
# torch.save(probs, f'probs_{i}.pt')
next_token_ids = self._sampler(
token_logits,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
if next_token_ids == self.speech_token_num - 1:
print('hit speech EOS token 4096')
break
hyps.append(next_token_ids.item())
# if 1+i >= len(speech_token[0]):
# break
# token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i+1]]).to(device)).unsqueeze(0)
token_emb = self.embed_tokens(torch.tensor([speech_token[i]]).to(device)).unsqueeze(0)
res = []
for i in hyps[1:]:
# if i != self.speech_token_num-1:
res.append(i)
print(res)
return [hyps_text + str(res)]