import logging
import os
from typing import Dict, List, Optional, Union

import torchaudio
import torch
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, StoppingCriteriaList

from patches.cumstom_stop_criteria import InterruptStopper, S2SStopCriteria, MaxTokenStopper
from patches.custom_speech_ngram_blocking import SpeechOnlyNGramBlockingLogitsProcessor, OSUM_chat_LogitsProcessor
from wenet.osum_echat.utils4llmasr import make_streaming_mode_from_s2s, do_embedding_for_two_embeds
from wenet.transformer.encoder import TransformerEncoder, TransformerEncoder2
from wenet.osum_echat.utils4llmasr import *
from gxl_ai_utils.utils import utils_file
from wenet.osum_echat.downsampler import get_downsampler, osum_echat2Conv1dSubsampling
from wenet.transformer.swish import New_gelu4npu
from wenet.utils.mask import make_pad_mask


class LLMASR_Model(nn.Module):
    def __init__(self,
                 encoder,
                 encoder_output_dim,
                 llm_path,
                 lora=True,
                 lora_alpha=32,
                 lora_rank=8,
                 lora_dropout=0.1,
                 is_inference=False,
                 downsample_rate=1,
                 adapter_type='osum_echat2',
                 speech_token_num=0,
                 train_speech_out=False):
        """Speech-LLM wrapper: speech encoder + adapter + (LoRA-)LLM, with an optional speech-token output head."""
        super().__init__()
        utils_file.logging_limit_print(f"instruct_version: LLMASR_Model init, is_inference={is_inference}, downsample_rate={downsample_rate}, adapter_type={adapter_type}, speech_token_num={speech_token_num}, train_speech_out={train_speech_out}")
        self.downsample_rate = downsample_rate
        self.encoder = encoder
        self.ln_speech = nn.LayerNorm(encoder_output_dim)
        # connection (adapter) layer, ~51.6M params
        if adapter_type == 'osum_echat':
            self.speech_transformer = TransformerEncoder(
                input_size=encoder_output_dim,
                output_size=encoder_output_dim,
                attention_heads=4,
                linear_units=2560,
                num_blocks=4,
                dropout_rate=0.1,
                positional_dropout_rate=0.1,
                attention_dropout_rate=0.0,
                input_layer="linear",
                pos_enc_layer_type="abs_pos",
                normalize_before=True
            )
        else:
            self.speech_transformer = None

        self.llama_model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            # torch_dtype=torch.float32 if is_inference else torch.float16,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            output_hidden_states=True,
        )
        self.s2s_stop_criteria = None
        self.max_token_criteria_list = None
        self.max_length = 4000
        self.min_length = 1
        self.num_beams = 4
        self.do_sample = True
        self.top_p = 0.9
        self.top_k = 5
        self.repetition_penalty = 1.05
        self.length_penalty = 1.0
        self.temperature = 1.0
        self.IGNORE_ID = -100

        # lora
        self.lora = lora
        if lora:
            utils_file.logging_limit_print("OSUM-EChat: 使用lora了")
            # target_modules = ['w_pack', 'o_proj', 'gate_proj', 'down_proj']
            target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj']
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=is_inference,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=target_modules,
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)

        # tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            llm_path, use_fast=False, trust_remote_code=True)
        # Set the tokenizer's pad_token and padding side.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.padding_side = "right"
        self.eos_token_id = self.tokenizer.eos_token_id

        if hasattr(self.llama_model.config, 'hidden_size'):
            utils_file.logging_limit_print(
                f"self.llama_model.config.hidden_size: {self.llama_model.config.hidden_size}")
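            # Note: two adapter variants are wired up below. 'osum_echat2' uses a Conv1d
            # subsampling block that maps encoder features directly to the LLM hidden size,
            # while 'osum_echat' uses a configurable downsampler followed by a linear
            # projection (speech_llama_proj) into the LLM embedding space.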
            if adapter_type == 'osum_echat2':
                self.down_sample_2 = osum_echat2Conv1dSubsampling(encoder_output_dim,
                                                                  self.llama_model.config.hidden_size)
            elif adapter_type == 'osum_echat':
                self.down_sample_2 = get_downsampler(downsample_rate, encoder_output_dim)
                self.speech_llama_proj = nn.Linear(
                    encoder_output_dim, self.llama_model.config.hidden_size)
        else:
            raise NotImplementedError("self.llama_model.config.hidden_size not exist")

        self.embed_tokens = self.llama_model.model.model.embed_tokens if self.lora else self.llama_model.model.embed_tokens
        self.lm_head = self.llama_model.model.lm_head if self.lora else self.llama_model.lm_head
        self.llm_vocab_size = self.lm_head.weight.shape[0]
        self.speech_token_num = speech_token_num

        # init speech token module
        if speech_token_num > 0:
            utils_file.logging_info(f'OSUM-EChat: 进行语音token生成任务, speech_token_num: {speech_token_num}')
            self.speech_token_emded = torch.nn.Embedding(speech_token_num + 2, self.llama_model.config.hidden_size)
            self.speech_head = torch.nn.Linear(self.llama_model.config.hidden_size, speech_token_num)
        else:
            # no-op placeholders
            self.speech_head = nn.Identity()
            self.speech_token_emded = nn.Identity()
            self.speech_model = nn.Identity()
        self.train_speech_out = train_speech_out
        utils_file.logging_info(f'OSUM-EChat: 是否进行语音输出训练:{self.train_speech_out}')
        self.loss_fct = CrossEntropyLoss(reduction='mean')
        self.unk_token_id = 7672  # token id corresponding to "&&"
        self.add_embed_head = True
        self.init_custom_speech_repetition_penalty()
        self.init_custom_stop_criteria()

    def set_task_type(self, task_type: str):
        """Set the task type, which controls the initial generation phase.

        Args:
            task_type (str): one of ("ASR", "TTS", "S2S")
        """
        assert task_type in ("ASR", "TTS", "S2S")
        if task_type == "ASR":
            self.llama_model.text_phase = True
        elif task_type == "TTS":
            self.llama_model.text_phase = False
        elif task_type == "S2S":
            self.llama_model.text_phase = True

    def do_add_speech_embed_head(self):
        if self.add_embed_head:
            self.llama_model.speech_token_emded = self.speech_token_emded.to(torch.bfloat16)
            self.llama_model.speech_head = self.speech_head.to(torch.bfloat16)
            # self.llama_model.speech_token_emded = self.speech_token_emded.to(torch.bfloat16)
            # self.llama_model.speech_head = self.speech_head.to(torch.bfloat16)
            # used when LoRA is enabled
            self.add_embed_head = False

    def init_custom_speech_repetition_penalty(self):
        """ """
        self.s2s_repetition_penalty = LogitsProcessorList()
        # self.speech_repetition_penalty = SpeechOnlyRepetitionPenaltyLogitsProcessor(speech_token_num=4097, penalty=1.5)
        self.speech_repetition_penalty = SpeechOnlyNGramBlockingLogitsProcessor(speech_token_num=4097,
                                                                                repeat_times=5,
                                                                                special_token_repeat_times_dict={1446: 10})
        self.osum_chat_logit_processor1 = OSUM_chat_LogitsProcessor([99119, 1808, 7863], [102185, 17714, 31252])
        self.s2s_repetition_penalty.append(self.osum_chat_logit_processor1)
        self.s2s_repetition_penalty.append(self.speech_repetition_penalty)
        self.llama_model.speech_repetition_penalty = self.speech_repetition_penalty
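    # The custom stop criteria below hard-code 151645 as the text EOS, which corresponds
    # to <|im_end|> in the Qwen chat template used by this model's prompts, and use
    # speech_token_num - 1 as the speech EOS (the last id of the speech-token range).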

    def init_custom_stop_criteria(self):
        """Create the stopping criteria we need:
        1. for t2t tasks, stop on text_eos
        2. for t2s tasks, stop on speech_eos
        3. for s2s tasks, stop on speech_eos
        The default stopping condition (`if generation_config._eos_token_tensor is not None:`)
        must also be disabled; we do so by passing an eos_token larger than the vocab size.
        """
        self.interrupt = InterruptStopper()
        self.s2s_stop_criteria = StoppingCriteriaList()
        self.s2s_stop_criteria.append(S2SStopCriteria(text_eos_id=151645, speech_eos_id=self.speech_token_num - 1))
        self.s2s_stop_criteria.append(MaxTokenStopper(2000))
        self.s2s_stop_criteria.append(self.interrupt)

    def get_label_embedding(self, labels, labels_lengths, unk_id=7672):
        """Embed text labels and build the matching targets/mask; pad and unk positions are ignored."""
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        unk_mask = (labels == unk_id)  # B, L
        labels_pad_mask = labels_pad_mask | unk_mask
        # labels = labels.masked_fill(labels_pad_mask, 0)
        labels_embeds = self.embed_tokens(labels)
        labels_target = labels.masked_fill(labels_pad_mask, self.IGNORE_ID)  # B, L
        labels_mask = ~labels_pad_mask
        return labels_embeds, labels_target, labels_mask

    def get_speech_token_label_embedding(self, speech_token_labels, speech_tokens_length):
        """Embed speech-token labels and build targets/mask; target ids are offset by llm_vocab_size."""
        speech_tokens_pad_mask = make_pad_mask(speech_tokens_length)  # B, L
        speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
        speech_token_labels_embeds = self.speech_token_emded(speech_token_labels)
        # utils_file.logging_limit_print(f'进行speech_token_labels修改,修改前 speech_token_labels',
        #                                speech_token_labels.shape, speech_token_labels[0][-1], speech_token_labels[0][0])
        speech_token_labels = speech_token_labels + self.llm_vocab_size
        # utils_file.logging_limit_print(f'进行speech_token_labels修改,修改后 speech_token_labels',
        #                                speech_token_labels.shape, speech_token_labels[0][-1], speech_token_labels[0][0])
        speech_token_labels_target = speech_token_labels.masked_fill(speech_tokens_pad_mask, self.IGNORE_ID)  # B, L
        speech_token_labels_mask = ~speech_tokens_pad_mask
        return speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask

    def _get_embedding_for_history(self, history_batch, device):
        """Overall sequence layout: prompt_pattern1, prompt, history, wav, prompt_pattern2, txt, answer_wav.

        history_batch looks like:
        [
            [
                {wav: feat (L, D=80) -> encoder + adapter -> (L1, 2048),
                 txt: labels (L,) -> self.embed_tokens(labels) -> (L2, 2048), with text eos},  -> (L1+L2, 2048)
                {wav: feat (L, D), txt: labels (L,)},                                          -> (L3+L4, 2048)
            ],   -> (L1+L2+L3+L4, 2048), len: L1+L2+L3+L4
            [],  -> (1, 2048), len: 0
            [
                {wav: feat (L, D), txt: labels (L,)},
            ],   -> (L1+L2, 2048), len: L1+L2
        ]

        Each sample's history embeddings are concatenated; empty histories are zero-padded,
        yielding the padded history_embedding (B, L, D) and history_lens (B).

        Args:
            history_batch:
            device:
        Returns:
            history_embedding: B, L, D
            history_lens: B
        """
        assistant_start = "<|im_end|>\n<|im_start|>assistant\n"
        assistant_start_id = self.tokenizer([assistant_start], return_tensors="pt")['input_ids'].to(device)
        assistant_start_embedding = self.embed_tokens(assistant_start_id.squeeze(0))
        assistant_end = "<|im_end|>\n"
        assistant_end_id = self.tokenizer([assistant_end], return_tensors="pt")['input_ids'].to(device)
        assistant_end_embedding = self.embed_tokens(assistant_end_id.squeeze(0))
        user_start = "<|im_start|>user\n"
        user_start_id = self.tokenizer([user_start], return_tensors="pt")['input_ids'].to(device)
        user_start_embedding = self.embed_tokens(user_start_id.squeeze(0))
        user_end = "<|im_end|>\n"
        user_end_id = self.tokenizer([user_end], return_tensors="pt")['input_ids'].to(device)
        user_end_embedding = self.embed_tokens(user_end_id.squeeze(0))

        batch_embeddings = []
        history_lens = []
        # if no sample carries any history, return early
        if all(len(history) == 0 for history in history_batch):
            return None, None
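        # Each history turn is rendered as: [user_start] + wav_embed + user_end +
        # assistant_start + text_embed + assistant_end + user_start. The leading user_start
        # is skipped for the first turn because the system prompt already ends with
        # "<|im_start|>user\n"; the final user_start lets the current query follow the history.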
        for history in history_batch:
            history_embeds = []
            for item in history:
                wav_feat = item['wav'].to(device)  # shape: (L, D)
                wav_feat = wav_feat.unsqueeze(0).to(device)  # shape: (1, L, D)
                wav_embed, wav_mask = self._get_embedding_from_wav(
                    wav_feat, torch.tensor([wav_feat.size(1)], device=device, dtype=torch.long))
                wav_embed = wav_embed.squeeze(0)  # shape: (L, D)
                if len(history_embeds) != 0:
                    history_embeds.append(user_start_embedding)  # the first user start is omitted
                history_embeds.append(wav_embed)
                history_embeds.append(user_end_embedding)
                history_embeds.append(assistant_start_embedding)
                labels = item['txt']  # shape: (L,)
                labels = torch.tensor(labels, device=device, dtype=torch.long)
                embed = self.embed_tokens(labels)  # (L2, D), usually L2 == L
                history_embeds.append(embed)
                history_embeds.append(assistant_end_embedding)
                history_embeds.append(user_start_embedding)  # append a user start at the end
            if history_embeds:
                # concatenate all history entries: (sum(Li), D)
                full_embed = torch.cat(history_embeds, dim=0)
                history_lens.append(full_embed.size(0))
            else:
                # empty history
                full_embed = torch.zeros((1, self.embed_tokens.embedding_dim), device=device)
                history_lens.append(0)
            batch_embeddings.append(full_embed)

        # pad to the longest history in the batch
        padded_embeddings = pad_sequence(batch_embeddings, batch_first=True, padding_value=0.0)  # (B, L, D)
        history_lens = torch.tensor(history_lens, device=device, dtype=torch.long)
        padded_embeddings = padded_embeddings.to(device)
        return padded_embeddings, history_lens

    def forward(self, batch, device):
        """Training forward: build the concatenated prompt/speech/label embedding sequence for the given output_type and compute the LM loss."""
        output_type = batch['output_type']
        # qwen_instruct_prompt_pattern_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_chat_s2s = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2s_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside ..., analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2s_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (content + paralinguistic cues) and respond with interleaved text and emotionally-matched synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2s_streaming_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (both content and paralinguistic cues). Before responding, output your reasoning in .... Then reply with interleaved text and emotionally matched synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2t = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n"
        qwen_instruct_prompt_pattern__chat_t2t = "<|im_start|>system\nYou are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. 
You understand user input in text then respond exclusively with appropriate text.<|im_end|>\n" qwen_instruct_prompt_pattern_1_understand = "<|im_start|>system\nYou are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.<|im_end|>\n" qwen_instruct_prompt_pattern_1_tts = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.<|im_end|>\n" qwen_instruct_prompt_pattern_1_tts_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural speech from text input and output both audio and the original text in interleaved format.<|im_end|>\n" qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n" qwen_instruct_prompt_pattern_1_s2t_thinking = "<|im_start|>system\nYou are OSUM-chat, a thinking-enabled speech-to-text dialogue assistant by ASLP Lab. You not only comprehend the semantic meaning and paralinguistic cues in speech but also engage in deliberate reasoning to process such information. Based on this thinking process, you then respond exclusively with appropriate text.<|im_end|>\n" # user_start = "<|im_start|>user\n" # 赋予不同的系统提示。 if output_type == "s2t_chat": system_prompt = qwen_instruct_prompt_pattern_chat_s2t elif output_type == "s2t_chat_fake": system_prompt = qwen_instruct_prompt_pattern_chat_s2s_think elif output_type == "text": system_prompt = qwen_instruct_prompt_pattern_1_understand elif output_type == "speech2text_token" or output_type == "speech2text_token_history": system_prompt = qwen_instruct_prompt_pattern_chat_s2s elif output_type == "text2token": system_prompt = qwen_instruct_prompt_pattern_1_tts elif output_type == "speech2text_token_streaming": system_prompt = qwen_instruct_prompt_pattern_chat_s2s_streaming elif output_type == "speech2text_token_think": system_prompt = qwen_instruct_prompt_pattern_chat_s2s_think elif output_type == "text2token_streaming": system_prompt = qwen_instruct_prompt_pattern_1_tts_streaming elif output_type == "text2text": system_prompt = qwen_instruct_prompt_pattern__chat_t2t elif output_type == "s2t_chat_think": system_prompt = qwen_instruct_prompt_pattern_1_s2t_thinking else: system_prompt = qwen_instruct_prompt_pattern_1_old # if output_type == "speech2text_token_history": # if output_type == "text2text" or output_type == "text": # qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old # elif output_type == "speech2text_token" or output_type == "speech2text_token_streaming" or output_type == "text2text" or output_type == "s2t_chat": # qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_chat # elif output_type == "text2token": # qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_tts # else: # qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old system_prompt = system_prompt + "<|im_start|>user\n" rank = int(os.environ.get('RANK', 0)) utils_file.logging_limit_print(f'xxx output_type {output_type}, rank {rank}') # if output_type == "s2t_chat": # output_type = "text" # assert output_type in ['text', 'speech2text_token', 'text2token'], f"output_type:{output_type} not support" # speech inputs if output_type == 'text' or output_type == 's2t_chat' or output_type == 's2t_chat_fake' or output_type== "s2t_chat_think" or output_type == 'speech2text_token' or output_type == 
"speech2text_token_streaming" or output_type == "speech2text_token_think" or output_type == "speech2text_token_history": wavs = batch['feats'].to(device) # utils_file.logging_limit_print(f'xxx wav shape {wavs.shape}') wavs_len = batch['feats_lengths'].to(device) B = wavs.shape[0] # utils_file.logging_limit_print(f"xxx {wavs_len}") speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len) # utils_file.logging_limit_print(f'xxx speech embeding shape {speech_embeds.shape}') # utils_file.logging_limit_print(f'xxx speech mask shape {speech_masks.shape}') # utils_file.logging_limit_print(f'xxx speech mask 0 {speech_masks[0]}') speech_target = torch.full(speech_masks.shape, self.IGNORE_ID).to( speech_embeds.device) # utils_file.logging_limit_print(f'xxx speech target shape {speech_target.shape}') # utils_file.logging_limit_print(f'xxx speech target 0 {speech_target[0]}') # add bos and eos speech_embeds, speech_masks, speech_target = self._add_bos_eos(0+self.speech_token_num, 1+self.speech_token_num, speech_embeds, speech_masks, speech_target) elif output_type == "text2token" or output_type == "text2token_streaming": labels = batch['target'].to(device) labels_lengths = batch['target_lengths'].to(device) -1 # 减1是因为要去掉eos B = labels.shape[0] # text 2 token ,拿到文本序列, max_len = max(labels_lengths) + 1 labels_pad_mask = make_pad_mask(labels_lengths, max_len=max_len) labels = labels.masked_fill(labels_pad_mask, 0) speech_embeds = self.embed_tokens(labels) # B, L, D speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to( speech_embeds.device) speech_masks = ~labels_pad_mask # add bos and eos # speech_embeds, speech_masks, speech_target = self._add_bos_eos(0+self.speech_token_num, # 1 + self.speech_token_num, # speech_embeds, speech_masks, speech_target) else: # text2text speech_embeds = None speech_masks = None speech_target = None # utils_file.logging_limit_print(f'xxx after add bos eos speech embeding shape {speech_embeds.shape}') # utils_file.logging_limit_print(f'xxx after add bos eos speech mask shape {speech_masks.shape}') # utils_file.logging_limit_print(f'xxx after add bos eos speech target shape {speech_target.shape}') # utils_file.logging_limit_print(f'xxx after add bos eos speech mask 0 {speech_masks[0]}') # utils_file.logging_limit_print(f'xxx after add bos eos speech target 0 {speech_target[0]}') # prompt if 'prompt' in batch: prompt = batch['prompt'].to(device) prompt_lengths = batch['prompt_lengths'].to(device) prompt_pad_mask = make_pad_mask(prompt_lengths) # B, L prompt = prompt.masked_fill(prompt_pad_mask, self.tokenizer.eos_token_id) prompt_embeds = self.embed_tokens(prompt) # B, L, D prompt_target = torch.full(prompt.shape, self.IGNORE_ID).to( device) # B, L prompt_mask = ~prompt_pad_mask # utils_file.logging_limit_print(f'xxx prompt embeding shape {prompt_embeds.shape}') # utils_file.logging_limit_print(f'xxx prompt mask shape {prompt_mask.shape}') # utils_file.logging_limit_print(f'xxx prompt target shape {prompt_target.shape}') else: prompt_embeds = None prompt_mask = None prompt_target = None inputs_embeds_list = [] attention_mask_list = [] target_list = [] prompt_pattern1 = self.tokenizer([system_prompt] * len(batch['target']), return_tensors="pt" )['input_ids'].to(device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) prompt_pattern1_lens = torch.tensor([len(i) for i in prompt_pattern1]).to(device) prompt_pattern1_mask = ~make_pad_mask(prompt_pattern1_lens) prompt_pattern1_target = torch.full(prompt_pattern1.shape, self.IGNORE_ID).to( 
device) # B, L # user_start_id = self.tokenizer([user_start] * len(batch['target']), return_tensors="pt" # )['input_ids'].to(device) # user_start_embeds = self.embed_tokens(user_start_id) # user_start_lens = torch.tensor([len(i) for i in user_start_id]).to(device) # user_start_mask = ~make_pad_mask(user_start_lens) # user_start_target = torch.full(user_start_id.shape, self.IGNORE_ID).to( # device) # B, L qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(batch['target']), return_tensors="pt" )['input_ids'].to(device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) prompt_pattern2_lens = torch.tensor([len(i) for i in prompt_pattern2]).to(device) prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens) prompt_pattern2_target = torch.full(prompt_pattern2.shape, self.IGNORE_ID).to( device) # B, L inputs_embeds_list.append(prompt_pattern1_embeds) attention_mask_list.append(prompt_pattern1_mask) target_list.append(prompt_pattern1_target) streaming_error = False if output_type == "speech2text_token_streaming": rank = int(os.environ.get('RANK', 0)) utils_file.logging_limit_print(f'开始处理speech2text_token streaming 任务') labels = batch['target'].to(device) labels_lengths = batch['target_lengths'].to(device) speech_token_labels = batch['speech_tokens'].to(device) speech_tokens_length = batch['speech_tokens_length'].to(device) labels_pad_mask = make_pad_mask(labels_lengths) # B, L labels = labels.masked_fill(labels_pad_mask, 0) speech_tokens_pad_mask = make_pad_mask(speech_tokens_length) # B, L speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0) speech_token_labels = speech_token_labels + self.llm_vocab_size if rank == 0: utils_file.logging_limit_print(f'labels.shape {labels.shape}') utils_file.logging_limit_print(f'labels_lengths.shape {labels_lengths.shape}') utils_file.logging_limit_print(f'labels[0] {labels[0]}') utils_file.logging_limit_print(f'------------------------') utils_file.logging_limit_print(f'speech_token_labels.shape {speech_token_labels.shape}') utils_file.logging_limit_print(f'speech_tokens_length.shape {speech_tokens_length.shape}') utils_file.logging_limit_print(f'speech_token_labels[0] {speech_token_labels[0]}') utils_file.logging_limit_print(f'==========================') streaming_concat_ids, streaming_concat_lens = make_streaming_mode_from_s2s(labels, labels_lengths, speech_token_labels, speech_tokens_length) if rank == 0: utils_file.logging_limit_print(f'streaming_concat_ids.shape {streaming_concat_ids.shape}') utils_file.logging_limit_print(f'streaming_concat_lens.shape {streaming_concat_lens.shape}') utils_file.logging_limit_print(f'streaming_concat_lens {streaming_concat_lens[0]}') utils_file.logging_limit_print(f'xxx streaming_concat_ids[0] : {streaming_concat_ids[0]}') utils_file.logging_limit_print(f'------------------------') streaming_concat_embeddings = do_embedding_for_two_embeds(streaming_concat_ids, self.llm_vocab_size, self.embed_tokens, self.speech_token_emded) streaming_concat_pad_mask = make_pad_mask(streaming_concat_lens) streaming_concat_target = streaming_concat_ids.masked_fill(streaming_concat_pad_mask, self.IGNORE_ID) streaming_concat_mask = ~streaming_concat_pad_mask if rank == 0: utils_file.logging_limit_print(f'streaming_concat_embeddings.shape {streaming_concat_embeddings.shape}') utils_file.logging_limit_print(f'streaming_concat_mask shape {streaming_concat_mask.shape}') 
utils_file.logging_limit_print(f'------------------------') # if prompt_embeds is not None: # 对于s2s 对话任务,不再使用user prompt 输入 # inputs_embeds_list.append(prompt_embeds) # attention_mask_list.append(prompt_mask) # target_list.append(prompt_target) # ===================history=================================== history_batch = batch.get('history', []) history_embedding, history_lens = self._get_embedding_for_history(history_batch, device) if history_embedding is not None: utils_file.logging_info(f'OSUM-EChat: 进行历史信息的embedding') history_pad_mask = make_pad_mask(history_lens) # B, L history_target = torch.full(history_pad_mask.shape, self.IGNORE_ID).to(device) # B, L history_mask = ~history_pad_mask inputs_embeds_list.append(history_embedding) attention_mask_list.append(history_mask) target_list.append(history_target) utils_file.logging_limit_print(f'xxx history embeding shape {history_embedding.shape}') utils_file.logging_limit_print(f'xxx history mask shape {history_mask.shape}') utils_file.logging_limit_print(f'xxx history target shape {history_target.shape}') else: utils_file.logging_limit_print(f'history is None') # ==========================history end =================== inputs_embeds_list.extend( [ speech_embeds, prompt_pattern2_embeds, streaming_concat_embeddings]) attention_mask_list.extend([speech_masks, prompt_pattern2_mask, streaming_concat_mask]) target_list.extend([speech_target, prompt_pattern2_target, streaming_concat_target]) elif output_type == "text2token_streaming": rank = int(os.environ.get('RANK', 0)) utils_file.logging_limit_print(f'开始tts streaming 任务') labels = batch['target'].to(device) labels_lengths = batch['target_lengths'].to(device) speech_token_labels = batch['speech_tokens'].to(device) speech_tokens_length = batch['speech_tokens_length'].to(device) labels_pad_mask = make_pad_mask(labels_lengths) # B, L labels = labels.masked_fill(labels_pad_mask, 0) speech_tokens_pad_mask = make_pad_mask(speech_tokens_length) # B, L speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0) speech_token_labels = speech_token_labels + self.llm_vocab_size streaming_concat_ids, streaming_concat_lens = make_streaming_mode_from_s2s(labels, labels_lengths, speech_token_labels, speech_tokens_length) streaming_concat_embeddings = do_embedding_for_two_embeds(streaming_concat_ids, self.llm_vocab_size, self.embed_tokens, self.speech_token_emded) streaming_concat_pad_mask = make_pad_mask(streaming_concat_lens) streaming_concat_target = streaming_concat_ids.masked_fill(streaming_concat_pad_mask, self.IGNORE_ID) streaming_concat_mask = ~streaming_concat_pad_mask # if prompt_embeds is not None: # 对于tts 对话任务,不再使用user prompt 输入 # inputs_embeds_list.append(prompt_embeds) # attention_mask_list.append(prompt_mask) # target_list.append(prompt_target) inputs_embeds_list.extend( [ speech_embeds, prompt_pattern2_embeds, streaming_concat_embeddings]) attention_mask_list.extend([speech_masks, prompt_pattern2_mask, streaming_concat_mask]) target_list.extend([speech_target, prompt_pattern2_target, streaming_concat_target]) elif output_type == 'speech2text_token' or output_type == "speech2text_token_think" or output_type == "speech2text_token_history": utils_file.logging_limit_print(f'xxx 开始处理speech2text_token任务') labels = batch['target'].to(device) labels_lengths = batch['target_lengths'].to(device) speech_token_labels = batch['speech_tokens'].to(device) speech_tokens_length = batch['speech_tokens_length'].to(device) labels_embeds, labels_target, labels_mask = 
self.get_label_embedding(labels, labels_lengths) speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding( speech_token_labels, speech_tokens_length) # if prompt_embeds is not None: # 对于s2s 对话任务,不再使用user prompt 输入 # inputs_embeds_list.append(prompt_embeds) # attention_mask_list.append(prompt_mask) # target_list.append(prompt_target) # ===================history=================================== history_batch = batch.get('history', []) history_embedding, history_lens = self._get_embedding_for_history(history_batch, device) if history_embedding is not None: utils_file.logging_info(f'OSUM-EChat: 进行历史信息的embedding') history_pad_mask = make_pad_mask(history_lens) # B, L history_target = torch.full(history_pad_mask.shape, self.IGNORE_ID).to(device) # B, L history_mask = ~history_pad_mask inputs_embeds_list.append(history_embedding) attention_mask_list.append(history_mask) target_list.append(history_target) utils_file.logging_limit_print(f'xxx history embeding shape {history_embedding.shape}') utils_file.logging_limit_print(f'xxx history mask shape {history_mask.shape}') utils_file.logging_limit_print(f'xxx history target shape {history_target.shape}') else: utils_file.logging_limit_print(f'history is None') # ==========================history end =================== inputs_embeds_list.extend( [ speech_embeds, prompt_pattern2_embeds, labels_embeds, speech_token_labels_embeds]) attention_mask_list.extend([speech_masks, prompt_pattern2_mask, labels_mask, speech_token_labels_mask]) target_list.extend([speech_target, prompt_pattern2_target, labels_target, speech_token_labels_target]) elif output_type == "text2token": speech_token_labels = batch['speech_tokens'].to(device) speech_tokens_length = batch['speech_tokens_length'].to(device) speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding( speech_token_labels, speech_tokens_length) # if prompt_embeds is not None: # 对于tts 对话任务,不再使用user prompt 输入 # inputs_embeds_list.append(prompt_embeds) # attention_mask_list.append(prompt_mask) # target_list.append(prompt_target) inputs_embeds_list.extend([ speech_embeds, prompt_pattern2_embeds, speech_token_labels_embeds]) attention_mask_list.extend([speech_masks, prompt_pattern2_mask, speech_token_labels_mask]) target_list.extend([speech_target, prompt_pattern2_target, speech_token_labels_target]) elif output_type == "text" or output_type == 's2t_chat' or output_type == "s2t_chat_fake" or output_type == "s2t_chat_think": labels = batch['target'].to(device) labels_lengths = batch['target_lengths'].to(device) labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths) if prompt_embeds is not None and output_type == 'text': # 对于s2t_chat 对话任务,不再使用user prompt 输入 inputs_embeds_list.append(prompt_embeds) attention_mask_list.append(prompt_mask) target_list.append(prompt_target) elif output_type != 's2t_chat' or output_type != "s2t_chat_fake" or output_type != "s2t_chat_think": utils_file.logging_limit_print( f'prompt is None,task: {batch["task"]}, prompt_embeds:{prompt_embeds}, prompt_mask:{prompt_mask}') inputs_embeds_list.extend([ speech_embeds, prompt_pattern2_embeds, labels_embeds]) attention_mask_list.extend([speech_masks, prompt_pattern2_mask, labels_mask]) target_list.extend([speech_target, prompt_pattern2_target, labels_target]) elif output_type == "text2text": labels = batch['target'].to(device) labels_lengths = batch['target_lengths'].to(device) 
labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths) if prompt_embeds is not None: inputs_embeds_list.append(prompt_embeds) attention_mask_list.append(prompt_mask) target_list.append(prompt_target) else: utils_file.logging_limit_print( f'prompt is None,task: {batch["task"]}, prompt_embeds:{prompt_embeds}, prompt_mask:{prompt_mask}') inputs_embeds_list.extend([ prompt_pattern2_embeds, labels_embeds]) attention_mask_list.extend([ prompt_pattern2_mask, labels_mask]) target_list.extend([ prompt_pattern2_target, labels_target]) else: raise NotImplementedError(f'output_type {output_type} not support') inputs_embeds = torch.cat(inputs_embeds_list, dim=1) # utils_file.logging_limit_print(f'xxx final inputs_embeds shape {inputs_embeds.shape}') attention_mask = torch.cat(attention_mask_list, dim=1) # utils_file.logging_limit_print(f'xxx final attention_mask shape {attention_mask.shape}') # utils_file.logging_limit_print(f'xxx final attention_mask 0 {attention_mask[0]}') target = torch.cat(target_list, dim=1) # utils_file.logging_limit_print(f'xxx final target shape {target.shape}') # utils_file.logging_limit_print(f'xxx final target 0 {target[0]}') # utils_file.logging_limit_print(f'OSUM-EChat output_type: {output_type}') position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) # utils_file.logging_limit_print(f'xxx final position_ids shape {position_ids.shape}') # utils_file.logging_limit_print(f'xxx final position_ids 0 {position_ids[0]}') if output_type == 'text' or output_type == 's2t_chat' or output_type == "s2t_chat_fake" or output_type == "s2t_chat_think" or output_type == "text2text": outputs = self.llama_model( inputs_embeds=inputs_embeds, labels=target, attention_mask=attention_mask, position_ids=position_ids.to(inputs_embeds.device) ) loss = outputs['loss'] return {"loss": loss,"output_type": output_type} else: utils_file.logging_limit_print(f'进行llama_model的 diy forward') outputs = self.llama_model( inputs_embeds=inputs_embeds, # labels=target, attention_mask=attention_mask, position_ids=position_ids.to(inputs_embeds.device) ) hidden_states = outputs['hidden_states'][-1] logits = self.lm_head(hidden_states) logits2 = self.speech_head(hidden_states) # speech_head combined_logits = torch.cat([logits, logits2], dim=-1) # combined_logits = self.new_lm_head(hidden_states) shift_logits = combined_logits[..., :-1, :].contiguous() shift_target = target[..., 1:].contiguous() # utils_file.logging_limit_print( # f'xxx shift_logits shape: {shift_logits.shape}, shift_target shape: {shift_target.shape}') # utils_file.logging_limit_print(f'xxx shift_target 0 {shift_target[0]}') shift_logits = shift_logits.view(-1, combined_logits.shape[-1]) # 注意这里维度的调整,根据logits2的维度相应改变 shift_target = shift_target.view(-1) shift_target = shift_target.to(shift_logits.device) loss = self.loss_fct(shift_logits, shift_target) loss.requires_grad_(True) return {"loss": loss,"output_type": output_type} def generate_s2s_streaming( self, wavs, wavs_len, prompt, ): self.llama_model.eval() speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len) speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num, speech_embeds, speech_masks, None) device = speech_embeds.device prompt = self.tokenizer([prompt], return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_embeds = self.embed_tokens(prompt) qwen_instruct_prompt_pattern_1_s_input_chat = "<|im_start|>system\nYou are OSUM-chat, a 
dialogue assistant created by . You understand both the meaning and paralinguistic cues in users' speech, and respond appropriately with text or voice.<|im_end|>\n<|im_start|>user\n" qwen_instruct_prompt_pattern_1_t2t_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant created by . You understand user input in text and respond with accurate and helpful text replies.<|im_end|>\n<|im_start|>user\n" qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" qwen_instruct_prompt_pattern_1 =qwen_instruct_prompt_pattern_1_old prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1], return_tensors="pt" )['input_ids'].to(device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) prompt_pattern1_lens = torch.tensor([len(i) for i in prompt_pattern1]).to(device) prompt_pattern1_mask = ~make_pad_mask(prompt_pattern1_lens) prompt_pattern1_target = torch.full(prompt_pattern1.shape, self.IGNORE_ID).to( device) # B, L qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] , return_tensors="pt" )['input_ids'].to(device) start_id = prompt_pattern2[0][-1] new_prompt_pattern2 = prompt_pattern2[:,:-1] prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2) prompt_pattern2_lens = torch.tensor([len(i) for i in new_prompt_pattern2],dtype=torch.long).to(device) prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens) prompt_pattern2_target = torch.full(new_prompt_pattern2.shape, self.IGNORE_ID).to( device) # B, L embeds = torch.cat([prompt_pattern1_embeds,prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) if self.embed_tokens.weight.dtype == torch.float16: # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16') embeds = embeds.to(torch.float16) max_len = 350 hyps = [start_id] print(f'start_id: {start_id}') llm_out = self.llama_model( inputs_embeds=embeds, past_key_values=None, output_hidden_states=True ) batch_size = 1 top_k = 10 top_p = 0.9 temperature = 1 cache = llm_out.past_key_values token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) temperatures_tensor = None if not temperature else torch.FloatTensor( [temperature] * batch_size).to(device) inferring_txt = True txt_finished = False speeech_finished = False hyps_text = "" speech_eos_num = 0 txt_list = [] token_list = [] i_num = 0 for i in range(max_len): if inferring_txt and not txt_finished: for i_txt in range(6): i_num += 1 if i_num> 300: break llm_out = self.llama_model( inputs_embeds=token_emb, past_key_values=cache, output_hidden_states=True ) cache = llm_out.past_key_values hidden_states = llm_out.hidden_states[-1] token_logits = self.lm_head(hidden_states) next_token_ids = self._sampler( token_logits, temperatures_tensor, top_ps_tensor, top_ks_tensor, ) # next_token_ids = torch.argmax(token_logits, dim=-1) print(i_num, next_token_ids, f'txt') hyps.append(next_token_ids.item()) token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) if next_token_ids == self.eos_token_id: txt_finished = True hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False) print("hyps_text:", hyps_text) print("text 
is over") break txt_list.append(next_token_ids.item()) hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False) print("hyps_text:", hyps_text) inferring_txt = False elif not speeech_finished: for i_speech in range(18): i_num += 1 if i_num> 300: break llm_out = self.llama_model( inputs_embeds=token_emb, past_key_values=cache, output_hidden_states=True ) cache = llm_out.past_key_values hidden_states = llm_out.hidden_states[-1] token_logits = self.speech_head(hidden_states) next_token_ids = self._sampler( token_logits, temperatures_tensor, top_ps_tensor, top_ks_tensor, ) # next_token_ids = torch.argmax(token_logits, dim=-1) hyps.append(next_token_ids.item()) print(i_num, next_token_ids, f'speech') token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) if next_token_ids == 4096: speech_eos_num += 1 print(f'遇到 4096') if speech_eos_num >= 2: print("speech is over") speeech_finished = True break token_list.append(next_token_ids.item()) inferring_txt = True if speeech_finished: break if i_num > 300: break return [hyps_text + "|" + str(token_list)] def generate( self, wavs, wavs_len, prompt, **kwargs ): self.llama_model.eval() self.set_task_type("ASR") self.do_add_speech_embed_head() speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len) speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num, speech_embeds, speech_masks, None) prompt = self.tokenizer([prompt], return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_embeds = self.embed_tokens(prompt) qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, an audio understanding. You can transcribe speech accurately and anaosum_echat2e paralinguistic cues to provide precise text responses.<|im_end|>\n<|im_start|>user\n" prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16: embeds = embeds.to(torch.bfloat16) atts = atts.to(torch.bfloat16) outputs = self.llama_model.generate( inputs_embeds=embeds, max_new_tokens=self.max_length, cache_implementation="static", # num_beams=self.num_beams, do_sample=self.do_sample, min_length=self.min_length, top_p=self.top_p, top_k=self.top_k, repetition_penalty=self.repetition_penalty, length_penalty=self.length_penalty, temperature=self.temperature, # attention_mask=atts, eos_token_id=151645, pad_token_id=-100, stopping_criteria=self.max_token_criteria_list, do_compile=True, ) output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True) return output_text def generate4chat( self, wavs, wavs_len, prompt=" ", do_sample=True, top_k=2, top_p=1, temperature=0.4, **kwargs ): print(f'do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}') self.llama_model.eval() self.set_task_type("ASR") 
self.do_add_speech_embed_head() # self.do_merge_embed_head() speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len) speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num, speech_embeds, speech_masks, None) prompt = self.tokenizer([prompt], return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_embeds = self.embed_tokens(prompt) # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" # # sft qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n<|im_start|>user\n" prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16: # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16') # embeds = embeds.to(torch.float16) embeds = embeds.to(torch.bfloat16) atts = atts.to(torch.bfloat16) outputs = self.llama_model.generate( inputs_embeds=embeds, max_new_tokens=self.max_length, cache_implementation="static", # num_beams=1, do_sample=do_sample, min_length=self.min_length, top_p=top_p, top_k=top_k, repetition_penalty=self.repetition_penalty, length_penalty=1, temperature=temperature, # attention_mask=atts, eos_token_id=151645, pad_token_id=-100, do_compile=True, stopping_criteria=self.max_token_criteria_list, ) output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True) return output_text def generate4chat_think( self, wavs, wavs_len, prompt=" ", do_sample=True, top_k=2, top_p=1, temperature=0.4, **kwargs ): print(f'do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}') self.llama_model.eval() self.set_task_type("ASR") self.do_add_speech_embed_head() # self.do_merge_embed_head() speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len) speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num, speech_embeds, speech_masks, None) prompt = self.tokenizer([prompt], return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_embeds = self.embed_tokens(prompt) # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" # # sft qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a thinking-enabled speech-to-text dialogue assistant by ASLP Lab. You not only comprehend the semantic meaning and paralinguistic cues in speech but also engage in deliberate reasoning to process such information. 
Based on this thinking process, you then respond exclusively with appropriate text.<|im_end|>\n" prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16: # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16') # embeds = embeds.to(torch.float16) embeds = embeds.to(torch.bfloat16) atts = atts.to(torch.bfloat16) outputs = self.llama_model.generate( inputs_embeds=embeds, max_new_tokens=self.max_length, cache_implementation="static", # num_beams=1, do_sample=do_sample, min_length=self.min_length, top_p=top_p, top_k=top_k, repetition_penalty=self.repetition_penalty, length_penalty=1, temperature=temperature, # attention_mask=atts, eos_token_id=151645, pad_token_id=-100, do_compile=True, stopping_criteria=self.max_token_criteria_list, ) output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True) return output_text def generate_tts( self, device, text, ): # =====================准备input embedding===================== self.llama_model.eval() # 得到模型所在的device # device = self.llama_model.device self.set_task_type("TTS") self.do_add_speech_embed_head() # labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device) # labels = text[:,:] labels = self.tokenizer( [text], return_tensors="pt", add_special_tokens=False ).to( self.embed_tokens.weight.device).input_ids # (1, L) labels = labels.to(device) labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device) print(f'label_lengths:{labels_lengths}') print(f'labels:{labels}') labels_pad_mask = make_pad_mask(labels_lengths) # B, L labels = labels.masked_fill(labels_pad_mask, 0) speech_embeds = self.embed_tokens(labels) # B, L, D speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to( speech_embeds.device) speech_masks = ~labels_pad_mask # speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num, # 1 + self.speech_token_num, # speech_embeds, speech_masks, speech_target) qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. 
You generate natural and fluent speech from text input.<|im_end|>\n<|im_start|>user\n" prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) hyps = [self.speech_token_num - 1] speech_begin_token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds, speech_begin_token_emb], dim=1).to(torch.bfloat16) # 指定top_k top_p temperature stop # max_len = 250 top_k = 15 # 5 top_p = 0.8 # 0.9 temperature = 1.2 # 1 print(f"tts eos id = {self.speech_token_num - 1}") llm_out = self.llama_model.generate(inputs_embeds=embeds, max_new_tokens=self.max_length, eos_token_id=self.speech_token_num - 1, cache_implementation="static", do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, stopping_criteria=StoppingCriteriaList([MaxTokenStopper(2000)]), do_compile=True, repetition_penalty=1.0, ) return llm_out def generate_tts_streaming( self, device, prompt, text, ): self.llama_model.eval() # labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device) # labels = text[:,:] labels = self.tokenizer( [text], return_tensors="pt", add_special_tokens=False ).to( self.embed_tokens.weight.device).input_ids # (1, L) labels = labels.to(device) labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device) print(f'label_lengths:{labels_lengths}') print(f'labels:{labels}') labels_pad_mask = make_pad_mask(labels_lengths) # B, L labels = labels.masked_fill(labels_pad_mask, 0) speech_embeds = self.embed_tokens(labels) # B, L, D speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to( speech_embeds.device) speech_masks = ~labels_pad_mask speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num, speech_embeds, speech_masks, speech_target) qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) prompt = self.tokenizer([prompt], return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_embeds = self.embed_tokens(prompt) # embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) # ---------------- start_id = prompt_pattern2[0][-1] new_prompt_pattern2 = prompt_pattern2[:, :-1] prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2) prompt_pattern2_lens = torch.tensor([len(i) for i in new_prompt_pattern2], dtype=torch.long).to(device) prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens) prompt_pattern2_target = torch.full(new_prompt_pattern2.shape, self.IGNORE_ID).to( device) # B, L embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) if self.embed_tokens.weight.dtype == torch.float16: # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16') embeds = embeds.to(torch.float16) max_len = 350 hyps = [start_id] print(f'start_id: {start_id}') llm_out = self.llama_model( inputs_embeds=embeds, past_key_values=None, output_hidden_states=True ) batch_size = 1 top_k = 10 top_p = 0.9 temperature = 1 cache = llm_out.past_key_values token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) temperatures_tensor = None if not temperature else torch.FloatTensor( [temperature] * batch_size).to(device) inferring_txt = True txt_finished = False speeech_finished = False hyps_text = "" speech_eos_num = 0 txt_list = [] token_list = [] i_num = 0 for i in range(max_len): if inferring_txt and not txt_finished: for i_txt in range(6): i_num += 1 if i_num > 300: break llm_out = self.llama_model( inputs_embeds=token_emb, past_key_values=cache, output_hidden_states=True ) cache = llm_out.past_key_values hidden_states = llm_out.hidden_states[-1] token_logits = self.lm_head(hidden_states) next_token_ids = self._sampler( token_logits, temperatures_tensor, top_ps_tensor, top_ks_tensor, ) # next_token_ids = torch.argmax(token_logits, dim=-1) print(i_num, next_token_ids, f'txt') hyps.append(next_token_ids.item()) token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) if next_token_ids == self.eos_token_id: txt_finished = True hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False) print("hyps_text:", hyps_text) print("text is over") break txt_list.append(next_token_ids.item()) hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False) print("hyps_text:", hyps_text) inferring_txt = False elif not speeech_finished: for i_speech in range(18): i_num += 1 if i_num > 300: break llm_out = self.llama_model( inputs_embeds=token_emb, past_key_values=cache, output_hidden_states=True ) cache = 
llm_out.past_key_values hidden_states = llm_out.hidden_states[-1] token_logits = self.speech_head(hidden_states) next_token_ids = self._sampler( token_logits, temperatures_tensor, top_ps_tensor, top_ks_tensor, ) # next_token_ids = torch.argmax(token_logits, dim=-1) hyps.append(next_token_ids.item()) print(i_num, next_token_ids, f'speech') token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0) if next_token_ids == 4096: speech_eos_num += 1 print(f'遇到 4096') if speech_eos_num >= 2: print("speech is over") speeech_finished = True break token_list.append(next_token_ids.item()) inferring_txt = True if i_num > 300: break return [hyps_text + "|" + str(token_list)] def generate_text2text( self, device, text, ): self.llama_model.eval() # labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device) # labels = text[:,:] labels = self.tokenizer( [text], return_tensors="pt", add_special_tokens=False ).to( self.embed_tokens.weight.device).input_ids # (1, L) labels = labels.to(device) labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device) # print(f'label_lengths:{labels_lengths}') # print(f'labels:{labels}') labels_pad_mask = make_pad_mask(labels_lengths) # B, L labels = labels.masked_fill(labels_pad_mask, 0) speech_embeds = self.embed_tokens(labels) # B, L, D speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to( speech_embeds.device) speech_masks = ~labels_pad_mask # speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num, # 1 + self.speech_token_num, # speech_embeds, speech_masks, speech_target) # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" # # sft qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. 
You understand user input in text then respond exclusively with appropriate text.<|im_end|>\n<|im_start|>user\n" prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1) qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n" prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt" )['input_ids'].to(speech_embeds.device) prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2) embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1) atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device) if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16: # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16') # embeds = embeds.to(torch.float16) embeds = embeds.to(torch.bfloat16) atts = atts.to(torch.bfloat16) outputs = self.llama_model.generate( inputs_embeds=embeds, max_new_tokens=200, num_beams=1, do_sample=False, min_length=self.min_length, repetition_penalty=1.0, length_penalty=1.0, temperature=self.temperature, attention_mask=atts, eos_token_id=151645, pad_token_id=-100, do_compile=True, cache_implementation="static", ) output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True) # output_text = [item.replace('<|endoftext|>', '') for item in output_text] return output_text def generate_s2s_no_stream_with_repetition_penalty( self, wavs, wavs_len, ): self.llama_model.eval() self.set_task_type("S2S") self.do_add_speech_embed_head() # =====================准备input embedding===================== speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len) speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None, speech_embeds, speech_masks, None) device = speech_embeds.device # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n" qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n<|im_start|>user\n" # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 

    def generate_s2s_no_stream_with_repetition_penalty(
            self,
            wavs,
            wavs_len,
    ):
        self.llama_model.eval()
        self.set_task_type("S2S")
        self.do_add_speech_embed_head()
        # ===================== prepare input embeddings =====================
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None,
                                                           speech_embeds, speech_masks, None)
        device = speech_embeds.device
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n<|im_start|>user\n"
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len),
                                         return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len),
                                         return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
        hyps = [4098]
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        embeds = torch.cat(
            [prompt_pattern1_embeds, speech_embeds, token_emb, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            embeds = embeds.to(torch.float16)
        # max_len = 350
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        invalid_eos = 10000000
        self.osum_chat_logit_processor1.init_match_found()
        llm_out = self.llama_model.generate(inputs_embeds=embeds,
                                            max_new_tokens=self.max_length,
                                            eos_token_id=invalid_eos,
                                            cache_implementation="static",
                                            do_sample=True,
                                            temperature=temperature,
                                            top_k=top_k,
                                            top_p=top_p,
                                            logits_processor=self.s2s_repetition_penalty,
                                            stopping_criteria=self.s2s_stop_criteria,
                                            do_compile=True,
                                            repetition_penalty=1.0,
                                            )
        text_eos_idx = (llm_out[0] == 151645).nonzero(as_tuple=True)[0][0].item()
        text_res = llm_out[:, :text_eos_idx - 1]
        speech_res = llm_out[:, text_eos_idx + 1:-1]
        # print("llm_out", llm_out)
        output_text = self.tokenizer.batch_decode(text_res, add_special_tokens=False, skip_special_tokens=True)
        # print(f'output_text:{output_text}')
        # print(f'speech_res:{speech_res}')
        return (output_text, text_res, speech_res)
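    """
    Illustrative sketch (not from the original file): the non-streaming S2S path above relies
    on the text segment ending with Qwen's <|im_end|> id (151645); everything after it is
    treated as speech tokens. The toy tensor below only demonstrates that slicing.

        llm_out = torch.tensor([[9906, 1879, 151645, 120, 87, 4051, 4096]])
        text_eos_idx = (llm_out[0] == 151645).nonzero(as_tuple=True)[0][0].item()
        text_res = llm_out[:, :text_eos_idx - 1]      # text ids, mirroring the slicing used above
        speech_res = llm_out[:, text_eos_idx + 1:-1]  # speech ids, trailing eos stripped
    """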

    def generate_s2s_no_stream_think_with_repetition_penalty(
            self,
            wavs,
            wavs_len,
    ):
        self.llama_model.eval()
        self.set_task_type("S2S")
        self.do_add_speech_embed_head()
        # ===================== prepare input embeddings =====================
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None,
                                                           speech_embeds, speech_masks, None)
        device = speech_embeds.device
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside ..., analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n<|im_start|>user\n"
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len),
                                         return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len),
                                         return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
        hyps = [4098]
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        embeds = torch.cat(
            [prompt_pattern1_embeds, speech_embeds, token_emb, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            embeds = embeds.to(torch.float16)
        # max_len = 350
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        invalid_eos = 10000000
        self.osum_chat_logit_processor1.init_match_found()  # keyword matching is only needed outside think mode
        llm_out = self.llama_model.generate(inputs_embeds=embeds,
                                            max_new_tokens=self.max_length,
                                            eos_token_id=invalid_eos,
                                            cache_implementation="static",
                                            do_sample=True,
                                            temperature=temperature,
                                            top_k=top_k,
                                            top_p=top_p,
                                            logits_processor=self.s2s_repetition_penalty,
                                            stopping_criteria=self.s2s_stop_criteria,
                                            do_compile=True,
                                            repetition_penalty=1.0,
                                            )
        text_eos_idx = (llm_out[0] == 151645).nonzero(as_tuple=True)[0][0].item()
        text_res = llm_out[:, :text_eos_idx - 1]
        speech_res = llm_out[:, text_eos_idx + 1:-1]
        # print("llm_out", llm_out)
        output_text = self.tokenizer.batch_decode(text_res, add_special_tokens=False, skip_special_tokens=True)
        # print(f'output_text:{output_text}')
        # print(f'speech_res:{speech_res}')
        return (output_text, text_res, speech_res)

    def _get_embedding_from_wav(self, wavs, wavs_len):
        """
        Returns:
            wav_embedding: (b, l, v)
            wav_mask: (b, l), True at positions holding valid (non-padded) frames
        """
        encoder_out, encoder_mask = self.encoder(wavs, wavs_len)
        speech_embeds, encoder_mask = self.down_sample_2(encoder_out, encoder_mask)
        if self.speech_transformer is not None:
            filled_wavs_len = encoder_mask.squeeze(1).sum(-1)
            speech_embeds, encoder_mask = self.speech_transformer(speech_embeds, filled_wavs_len)
            # if rank == 0:
            #     utils_file.logging_limit_print(f'adapter output shape: {speech_embeds.shape}, first 20 values of first frame:\n {speech_embeds[0][0][:20]}')
            # utils_file.logging_limit_print('get_embedding_from_wav(): speech_embeds.shape after self.speech_transformer(speech_embeds, speech_lens):', speech_embeds.shape)
            speech_embeds = self.speech_llama_proj(speech_embeds)
            # if rank == 0:
            #     utils_file.logging_limit_print(f'speech_llama_proj output shape: {speech_embeds.shape}, first 20 values of first frame:\n {speech_embeds[0][0][:20]}')
            # utils_file.logging_limit_print('get_embedding_from_wav(): speech_embeds.shape after self.speech_llama_proj(speech_embeds):', speech_embeds.shape)
        return speech_embeds, encoder_mask.squeeze(1)

    def _get_embedding_from_text(self, text):
        """Tokenize a string and map the token ids to word embeddings.

        Args:
            text: str
        Returns:
            text_embeds: (1, L, D)
        """
        text_id = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=False
        ).to(self.embed_tokens.weight.device).input_ids
        text_embeds = self.embed_tokens(text_id)
        text_embeds_len = torch.tensor([text_embeds.size(1)], dtype=torch.long)
        return text_embeds, text_embeds_len
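    """
    Illustrative sketch (assumed shapes, not from the original file): _get_embedding_from_wav()
    chains encoder -> down_sample_2 -> optional speech_transformer -> speech_llama_proj, so a
    batch of padded waveforms comes out as LLM-sized frame embeddings plus a boolean validity mask.

        wavs = torch.randn(2, 16000 * 3)                  # two 3-second 16 kHz clips, padded to equal length
        wavs_len = torch.tensor([16000 * 3, 16000 * 2])   # true sample counts per clip
        speech_embeds, speech_masks = model._get_embedding_from_wav(wavs, wavs_len)
        # speech_embeds: (2, T', llm_hidden); speech_masks: (2, T') with True at valid frames
    """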

    def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
        B = len(inputs_embeds)
        bos_eos_target = torch.full([B, 1], self.IGNORE_ID).to(inputs_embeds.device)  # B, 1
        bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device)  # B, 1
        if bos is not None:
            bos_embed = self.speech_token_emded(torch.full([B, 1], bos).to(inputs_embeds.device))  # B, 1, D
            inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1)  # B, (1+T), D
            attention_mask = torch.cat((bos_eos_mask, attention_mask), 1)  # B, (1+T)
            if target is not None:
                target = torch.cat((bos_eos_target, target), 1)  # B, (1+T), D
        if eos is not None:
            eos_embed = self.speech_token_emded(torch.full([B, 1], eos).to(inputs_embeds.device))  # B, 1, D
            inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1)  # B, (1+T+1), D
            attention_mask = torch.cat((attention_mask, bos_eos_mask), 1)  # B, (1+T+1)
            if target is not None:
                target = torch.cat((target, bos_eos_target), 1)  # B, (1+T+1), D
        return inputs_embeds, attention_mask, target

    def infer_sample_teach_force(
            self,
            wavs,
            wavs_len,
            prompt,
            text,
            speech_token,
    ):
        labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=wavs.device)
        labels = text[:, :]
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        labels = labels.masked_fill(labels_pad_mask, 0)
        speech_embeds = self.embed_tokens(labels)  # B, L, D
        speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(speech_embeds.device)
        speech_masks = ~labels_pad_mask
        speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
                                                                       1 + self.speech_token_num,
                                                                       speech_embeds, speech_masks, speech_target)
        prompt = self.tokenizer([prompt], return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            embeds = embeds.to(torch.float16)
            atts = atts.half()
        device = wavs.device
        inputs_embeds = embeds.to(device)
        speech_token_list = speech_token[0].tolist()
        speech_token_list_len = len(speech_token_list)
        print(f'speech_token_list_len:{speech_token_list_len}')
        max_len = 200
        beam = 3
        beam_size = beam
        running_size = beam
        output_token = []
        hyps = [self.speech_token_num - 1]
        scores = [1.0]
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.0
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        repetition_penalty = 1.1
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        for i in range(max_len):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            token_logits = self.speech_head(hidden_states)
            # probs = F.log_softmax(token_logits[:, -1], dim=-1)[0]
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids)
            if next_token_ids == self.speech_token_num - 1:
                print("break la!")
                print("hyps:", hyps)
                break
            hyps.append(next_token_ids.item())
            token_emb = self.speech_token_emded(torch.tensor(speech_token_list[i]).to(device)).unsqueeze(0)
        res = []
        for i in hyps[1:]:
            # if i != self.speech_token_num - 1:
            res.append(i)
        print(res)
        return [res]
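    """
    Illustrative note (not from the original file): infer_sample_teach_force() is a
    teacher-forced debugging loop. It samples a token at every step only to inspect it,
    but always feeds the ground-truth speech token back in, so step i+1 is conditioned on
    speech_token_list[i] rather than on the model's own prediction:

        for i, gt_token in enumerate(speech_token_list):
            pred = sample_step(model, token_emb)                       # hypothetical one-step helper
            token_emb = model.speech_token_emded(
                torch.tensor([gt_token], device=device)).unsqueeze(0)  # ground truth, not `pred`
    """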

    def _sampler(
            self,
            logits: torch.Tensor,
            temperatures: Union[torch.Tensor, None],
            top_ps: torch.Tensor,
            top_ks: torch.Tensor,
    ) -> torch.Tensor:
        """Sample one token id per batch item from single-step logits.

        Args:
            logits: (batch, 1, vocab_size)
            temperatures: per-sample temperature, or None for greedy argmax
            top_ps: per-sample nucleus (top-p) threshold
            top_ks: per-sample top-k cutoff
        Returns:
            sampled token ids, one per batch item
        """
        print(f'logits:{logits.shape}')
        assert logits.size(1) == 1
        logits = logits.squeeze(1)  # (batch_size, vocab_size)
        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)
        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))
        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        next_token_ids = torch.multinomial(probs, num_samples=1, replacement=True).squeeze(dim=-1)
        return next_token_ids
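    """
    Usage sketch for _sampler() (illustrative values, not from the original file): it expects
    one-step logits of shape (batch, 1, vocab) plus per-sample temperature / top-p / top-k
    tensors, and falls back to greedy argmax when `temperatures` is None.

        logits = torch.randn(1, 1, 4097)                    # one decoding step over the speech vocab
        temperature = torch.FloatTensor([1.2])
        top_p = torch.FloatTensor([0.9])
        top_k = torch.LongTensor([10])
        next_id = model._sampler(logits, temperature, top_p, top_k)  # shape (1,), sampled id
        greedy_id = model._sampler(logits, None, top_p, top_k)       # argmax path
    """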

    def infer_sample4speech2text_token_teacher_force(
            self,
            wavs,
            wavs_len,
            prompt,
            speech_token=None,
            answer_text=None,
    ):
        self.llama_model.eval()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num,
                                                           1 + self.speech_token_num,
                                                           speech_embeds, speech_masks, None)
        device = speech_embeds.device
        prompt = self.tokenizer([prompt], return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        text_token = self.tokenizer([answer_text], return_tensors="pt")['input_ids'].to(speech_embeds.device)
        text_token_embeds = self.embed_tokens(text_token)
        embeds = torch.cat([prompt_embeds, speech_embeds, text_token_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            embeds = embeds.to(torch.float16)
            atts = atts.half()
        inputs_embeds = embeds.to(speech_embeds.device)
        max_len = 150
        beam = 3
        beam_size = beam
        running_size = beam
        output_token = []
        hyps = [self.speech_token_num]
        hyps_text = ""
        scores = [1.0]
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        # speech_token_list = speech_token[0]
        # speech_token_list_len = len(speech_token_list)
        if speech_token is not None:
            print(f'speech_token_list_len:{len(speech_token[0])}')
            print(f'speech_token:{speech_token[0]}')
        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        repetition_penalty = 1.1
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        is_speech_token = False
        speech_eos_num = 0
        for i in range(max_len):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            token_logits = self.speech_head(hidden_states)
            # probs = F.log_softmax(token_logits[:, -1], dim=-1)[0]
            # if i == 2 or i == 80:
            #     torch.save(probs, f'probs_{i}.pt')
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
            if next_token_ids == self.speech_token_num - 1:
                print('hit speech eos token 4096')
                break
            hyps.append(next_token_ids.item())
            # if 1 + i >= len(speech_token[0]):
            #     break
            # token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i + 1]]).to(device)).unsqueeze(0)
            token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        res = []
        for i in hyps[1:]:
            # if i != self.speech_token_num - 1:
            res.append(i)
        print(res)
        return [answer_text + str(res[2:])]
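    """
    Illustrative note (not from the original file): the two teacher-force variants differ mainly
    in what they feed back at each step. The first re-embeds its own previous prediction, while
    the second (below) re-embeds the reference `speech_token[i]`, e.g.:

        # variant 1: free-running on embeddings of its own hypotheses
        token_emb = model.embed_tokens(torch.tensor(hyps[-1:], device=device)).unsqueeze(0)
        # variant 2: teacher-forced on the provided reference token sequence
        token_emb = model.embed_tokens(torch.tensor([speech_token[i]], device=device)).unsqueeze(0)
    """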

    def infer_sample4speech2text_token_teacher_force2(
            self,
            wavs,
            wavs_len,
            prompt,
            speech_token=None,
            answer_text=None,
    ):
        self.llama_model.eval()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num,
                                                           1 + self.speech_token_num,
                                                           speech_embeds, speech_masks, None)
        device = speech_embeds.device
        prompt = self.tokenizer([prompt], return_tensors="pt")['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        text_token = self.tokenizer([answer_text], return_tensors="pt")['input_ids'].to(speech_embeds.device)
        # text_token_embeds = self.embed_tokens(text_token)
        embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            embeds = embeds.to(torch.float16)
            atts = atts.half()
        inputs_embeds = embeds.to(speech_embeds.device)
        max_len = 150
        beam = 3
        beam_size = beam
        running_size = beam
        output_token = []
        hyps = [self.speech_token_num - 1]
        hyps_text = ""
        scores = [1.0]
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        # speech_token_list = speech_token[0]
        # speech_token_list_len = len(speech_token_list)
        if speech_token is not None:
            print(f'speech_token_list_len:{len(speech_token)}')
            print(f'speech_token:{speech_token}')
        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        repetition_penalty = 1.1
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        is_speech_token = False
        speech_eos_num = 0
        token_num = len(speech_token)
        for i in range(token_num):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            token_logits = self.speech_head(hidden_states)
            # probs = F.log_softmax(token_logits[:, -1], dim=-1)[0]
            # if i == 2 or i == 80:
            #     torch.save(probs, f'probs_{i}.pt')
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
            if next_token_ids == self.speech_token_num - 1:
                print('hit speech eos token 4096')
                break
            hyps.append(next_token_ids.item())
            # if 1 + i >= len(speech_token[0]):
            #     break
            # token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i + 1]]).to(device)).unsqueeze(0)
            token_emb = self.embed_tokens(torch.tensor([speech_token[i]]).to(device)).unsqueeze(0)
        res = []
        for i in hyps[1:]:
            # if i != self.speech_token_num - 1:
            res.append(i)
        print(res)
        return [hyps_text + str(res)]