import copy
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

from omni_speech.constants import IGNORE_INDEX

# Anomaly detection is a global autograd debugging switch that slows down every
# backward pass. It stays enabled by default (previous behavior); set
# TORCH_DETECT_ANOMALY=0 to turn it off.
if os.environ.get('TORCH_DETECT_ANOMALY', '1') != '0':
    torch.autograd.set_detect_anomaly(True)

# CosyVoice is optional: only SpeechGeneratorCosyVoice needs it. The checkout
# location can be overridden with the COSYVOICE_PATH environment variable.
COSYVOICE_PATH = os.environ.get('COSYVOICE_PATH', '/mnt/lzy/LLaMA-Omni/CosyVoice/')
sys.path.append(COSYVOICE_PATH)
try:
    from cosyvoice.cli.cosyvoice import CosyVoice
except Exception:  # best effort, but never swallow KeyboardInterrupt/SystemExit
    CosyVoice = None
    print('CosyVoice not found')

# Kernel size of the optional temporal Conv1d in SpeechGeneratorCTCQwen;
# any value <= 0 disables the convolution.
if 'SPEECH_GEN_CONV_KERNEL' in os.environ:
    SPEECH_GEN_CONV_KERNEL = int(os.environ['SPEECH_GEN_CONV_KERNEL'])
    print(f'Using SPEECH_GEN_CONV_KERNEL={SPEECH_GEN_CONV_KERNEL}')
else:
    SPEECH_GEN_CONV_KERNEL = -1

# Presence (not value) of DISTILL_EMBEDDING switches SpeechGeneratorCosyVoice
# to the embedding-distillation objective.
DISTILL_EMBEDDING = 'DISTILL_EMBEDDING' in os.environ
if DISTILL_EMBEDDING:
    print('DISTILL_EMBEDDING is set.')


def lengths_to_padding_mask(lens):
    """Return a (bsz, max(lens)) bool mask that is True at padded positions.

    Args:
        lens: 1-D integer tensor of per-sample lengths.
    """
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens, device=lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def _uniform_assignment(src_lens, tgt_lens):
    """Map each target position to a source index by uniform stretching.

    Returns a (bsz, max(tgt_lens)) long tensor where entry t of row b is
    floor(t * src_lens[b] / tgt_lens[b]). Positions past tgt_lens[b] carry
    meaningless values and are expected to be masked by the caller.
    """
    tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device)
    ratio = tgt_lens / src_lens
    index_t = (tgt_indices / ratio.view(-1, 1)).long()
    return index_t


def _parse_decoder_config(spec):
    """Parse a bracketed, comma-separated integer spec such as "(2,1024,8,4096)"."""
    return [int(value) for value in spec[1:-1].split(",")]


def _make_decoder_config(config, n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads):
    """Return a deep copy of ``config`` resized for the small speech decoder."""
    _config = copy.deepcopy(config)
    _config.hidden_size = n_dims
    _config.num_hidden_layers = n_layers
    _config.num_attention_heads = n_heads
    _config.num_key_value_heads = n_kv_heads
    _config.intermediate_size = n_inter_dims
    _config._attn_implementation = "flash_attention_2"
    return _config


def _upsample_reps(reps, upsample_factor, tgt_units=None):
    """Stretch variable-length representations onto a common upsampled grid.

    Each sample of length L is repeated to ``L * upsample_factor`` frames. When
    ``tgt_units`` is given, a sample is stretched further if needed so that it is
    at least as long as its non-ignored target units (CTC needs input >= target).

    Args:
        reps: list of 2-D tensors (L_i, D), one per sample.
        upsample_factor: integer repetition factor.
        tgt_units: optional (bsz, W) unit ids padded with IGNORE_INDEX.

    Returns:
        (copied_reps, attention_mask, position_ids): a (bsz, T, D) tensor zeroed
        at padded frames, a (bsz, T) bool mask that is True on valid frames, and
        (bsz, T) position ids 0..T-1.
    """
    device = reps[0].device
    src_lens = torch.tensor([len(rep) for rep in reps], dtype=torch.long, device=device)
    up_lens = src_lens * upsample_factor
    if tgt_units is not None:
        tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
        up_lens = torch.max(up_lens, tgt_lens)

    padded = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
    padding_mask = lengths_to_padding_mask(up_lens)
    # Padded frames point at index 0 so the gather stays in range; they are zeroed below.
    mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(padding_mask, 0)
    copied_reps = torch.gather(
        padded,
        1,
        mapped_inputs.unsqueeze(-1).expand(*mapped_inputs.size(), padded.size(-1)),
    )
    copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)

    bsz, max_len = padding_mask.shape
    position_ids = torch.arange(max_len, device=copied_reps.device).unsqueeze(0).expand(bsz, -1)
    return copied_reps, ~padding_mask, position_ids


def _run_decoder_layers(layers, hidden_states, attention_mask, position_ids):
    """Apply each decoder layer in turn, keeping only the hidden-state output."""
    for layer in layers:
        # Decoder layers return a tuple whose first element is the hidden state.
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )[0]
    return hidden_states


def _ctc_loss(ctc_logits, attention_mask, tgt_units, blank):
    """Summed CTC loss normalized by the total number of target units.

    Args:
        ctc_logits: (bsz, T, V) decoder logits.
        attention_mask: (bsz, T) bool mask, True on valid frames.
        tgt_units: (bsz, W) unit ids right-padded with IGNORE_INDEX.
        blank: index of the CTC blank class.
    """
    ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
    ctc_lens = attention_mask.long().sum(dim=-1)
    ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
    ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens)
    ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
    ctc_loss = F.ctc_loss(
        ctc_lprobs.transpose(0, 1),  # ctc_loss expects (T, bsz, V)
        ctc_tgt_flat,
        ctc_lens,
        ctc_tgt_lens,
        reduction="sum",
        zero_infinity=True,
        blank=blank,
    )
    # Guard against an all-empty batch, which would otherwise divide by zero.
    return ctc_loss / max(ctc_tgt_lens.sum().item(), 1)


def _ctc_greedy_decode(ctc_logits, attention_mask, blank):
    """Frame-wise argmax over the logits, with padded frames set to ``blank``."""
    ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
    return ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, blank)


class SpeechGeneratorCTC(nn.Module):
    """Non-autoregressive CTC speech-unit decoder built from Llama decoder layers.

    ``config.ctc_decoder_config`` is a bracketed string
    "(n_layers,n_dims,n_heads,n_inter_dims)". The output layer has
    ``unit_vocab_size + 1`` classes; the extra last class is the CTC blank.
    """

    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims = _parse_decoder_config(config.ctc_decoder_config)
        _config = _make_decoder_config(config, n_layers, n_dims, n_heads, n_inter_dims, n_heads)
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)

    def upsample(self, reps, tgt_units=None):
        """Upsample a list of (L_i, D) reps; see ``_upsample_reps``."""
        return _upsample_reps(reps, self.upsample_factor, tgt_units)

    def _decode(self, reps, tgt_units=None):
        """Upsample, project and run the decoder; return (logits, attention_mask)."""
        hidden_states, attention_mask, position_ids = self.upsample(reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        hidden_states = _run_decoder_layers(self.layers, hidden_states, attention_mask, position_ids)
        return self.output_proj(hidden_states), attention_mask

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the normalized CTC loss for one batch.

        Only positions whose ``labels`` entry is not IGNORE_INDEX are kept from
        each ``tgt_reps`` row before upsampling.
        """
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        ctc_logits, attention_mask = self._decode(tgt_label_reps, tgt_units)
        return _ctc_loss(ctc_logits, attention_mask, tgt_units, self.unit_vocab_size)

    def predict(self, tgt_reps):
        """Greedy CTC prediction (1, T) for a single (L, D) representation."""
        ctc_logits, attention_mask = self._decode([tgt_reps])
        return _ctc_greedy_decode(ctc_logits, attention_mask, self.unit_vocab_size)


class SpeechGeneratorCTCQwen(nn.Module):
    """CTC speech-unit decoder built from Qwen2 decoder layers.

    ``config.ctc_decoder_config`` is "(n_layers,n_dims,n_heads,n_inter_dims,n_kv_heads)".
    When SPEECH_GEN_CONV_KERNEL > 0, each sample is padded with learnable frames,
    projected and passed through a temporal Conv1d *before* upsampling; otherwise
    the projection happens after upsampling.
    """

    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = _parse_decoder_config(
            config.ctc_decoder_config
        )
        _config = _make_decoder_config(config, n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads)
        self.upsample_factor = config.ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)
        self.use_temporal_conv = SPEECH_GEN_CONV_KERNEL > 0
        if self.use_temporal_conv:
            # With padding=0 and K//2 learnable frames on each side, an odd kernel
            # preserves the sequence length; an even kernel adds one frame.
            self.temporal_conv = nn.Conv1d(n_dims, n_dims, SPEECH_GEN_CONV_KERNEL, padding=0)
            # The pads are concatenated with the *un-projected* reps, so they must
            # have the LLM hidden width (previously n_dims, which only worked when
            # n_dims == hidden_size).
            self.learnable_pad_left = nn.Parameter(
                torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, config.hidden_size)
            )
            self.learnable_pad_right = nn.Parameter(
                torch.zeros(SPEECH_GEN_CONV_KERNEL // 2, config.hidden_size)
            )

    def upsample(self, reps, tgt_units=None):
        """Upsample a list of (L_i, D) reps; see ``_upsample_reps``."""
        return _upsample_reps(reps, self.upsample_factor, tgt_units)

    def _conv_project(self, rep):
        """Pad one (L, hidden_size) rep, project it to n_dims and apply the temporal conv."""
        padded = torch.cat([self.learnable_pad_left, rep, self.learnable_pad_right], dim=0)
        projected = self.input_proj(padded)[None]  # (1, L + pads, n_dims)
        return self.temporal_conv(projected.transpose(1, 2)).transpose(1, 2)[0]

    def _decode(self, reps, tgt_units=None):
        """Shared forward/predict path; returns (logits, attention_mask)."""
        if self.use_temporal_conv:
            reps = [self._conv_project(rep) for rep in reps]
        hidden_states, attention_mask, position_ids = self.upsample(reps, tgt_units)
        if not self.use_temporal_conv:
            # Covers every kernel value <= 0 (a kernel of 0 used to skip this).
            hidden_states = self.input_proj(hidden_states)
        hidden_states = _run_decoder_layers(self.layers, hidden_states, attention_mask, position_ids)
        return self.output_proj(hidden_states), attention_mask

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the normalized CTC loss for one batch (see SpeechGeneratorCTC.forward)."""
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        ctc_logits, attention_mask = self._decode(tgt_label_reps, tgt_units)
        return _ctc_loss(ctc_logits, attention_mask, tgt_units, self.unit_vocab_size)

    def predict(self, tgt_reps):
        """Greedy CTC prediction (1, T) for one rep, using the same path as training."""
        ctc_logits, attention_mask = self._decode([tgt_reps])
        return _ctc_greedy_decode(ctc_logits, attention_mask, self.unit_vocab_size)


class SpeechGeneratorCEQwen(nn.Module):
    """Qwen2-layer speech-unit decoder trained with next-unit cross-entropy.

    Inputs are not upsampled (factor 1) beyond stretching to the target length.
    Position t is trained to predict unit t + 1. No ``predict`` is implemented.
    """

    def __init__(self, config):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads = _parse_decoder_config(
            config.ctc_decoder_config
        )
        _config = _make_decoder_config(config, n_layers, n_dims, n_heads, n_inter_dims, n_kv_heads)
        self.upsample_factor = 1
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = config.unit_vocab_size
        self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1)

    def upsample(self, reps, tgt_units=None):
        """Upsample a list of (L_i, D) reps; see ``_upsample_reps``."""
        return _upsample_reps(reps, self.upsample_factor, tgt_units)

    def forward(self, tgt_reps, labels, tgt_units):
        """Return the mean next-unit cross-entropy over non-ignored target units."""
        tgt_label_reps = [
            tgt_rep[label != IGNORE_INDEX] for tgt_rep, label in zip(tgt_reps, labels)
        ]
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        hidden_states = _run_decoder_layers(self.layers, hidden_states, attention_mask, position_ids)

        # Shift so that the state at position t is scored against unit t + 1.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_states.size(-1))
        shift_labels = tgt_units[..., 1:].contiguous().reshape(-1)
        assert shift_labels.size(0) == shift_hidden_states.size(0), (
            "decoder length does not match tgt_units length"
        )
        logits = self.output_proj(shift_hidden_states).float()
        # Ignore padded target positions explicitly rather than relying on the
        # CrossEntropyLoss default of -100.
        loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
        return loss_fct(logits, shift_labels)


class SpeechGeneratorCosyVoice(nn.Module):
    """Speech head that conditions a CosyVoice LLM on projected text representations.

    Two objectives, selected by DISTILL_EMBEDDING:
      * distillation: cosine loss between projected reps and CosyVoice text embeddings;
      * default: CosyVoice LLM loss on speech tokens produced by ``inference_label``.
    """

    # Label ids dropped from the representations in addition to IGNORE_INDEX.
    # NOTE(review): presumably Qwen2 chat-template tokens; confirm.
    _EXCLUDED_TOKEN_IDS = (151644, 151645, 198)
    # CosyVoice speaker id ("English female").
    _SPEAKER = '英文女'

    def __init__(self, config):
        super().__init__()
        if CosyVoice is None:
            raise ImportError('CosyVoice could not be imported; check COSYVOICE_PATH.')
        self.cosyvoice = CosyVoice('CosyVoice/pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
        # forward() uses these modules; they had been dropped from __init__ while
        # the previous (two-model) version was commented out.
        self.input_proj = nn.Sequential(
            nn.Linear(config.hidden_size, 1024),
            nn.GELU(),
            nn.Linear(1024, 512),
        )
        # NOTE(review): the earlier version took the LLM from a separate non-JIT,
        # fp32 CosyVoice instance; confirm training through this one is intended.
        self.llm = self.cosyvoice.model.llm
        if DISTILL_EMBEDDING:
            self.criterion = nn.CosineEmbeddingLoss()

    def _select_reps(self, tgt_rep, label):
        """Keep positions whose label is neither IGNORE_INDEX nor an excluded id.

        Builds a new mask instead of writing into ``label``, so the caller's
        label tensor is not modified.
        """
        keep = label != IGNORE_INDEX
        for token_id in self._EXCLUDED_TOKEN_IDS:
            keep &= label != token_id
        return tgt_rep[keep]

    def _distill_loss(self, tgt_reps, labels, answer):
        """Cosine loss between projected reps and CosyVoice text embeddings of ``answer``."""
        tgt_label_reps = []
        target_embeddings = []
        for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
            tgt_label_reps.append(self._select_reps(tgt_rep, label))
            normalized_text = self.cosyvoice.frontend.text_normalize(ans, split=True)
            text_tokens = [
                self.cosyvoice.frontend._extract_text_token(norm_text)[0]
                for norm_text in normalized_text
            ]
            # Each segment is (1, T_i); join all segments along the token axis
            # (previously only the last segment was embedded).
            tts_text_token_all = torch.cat(text_tokens, dim=1)
            target_embeddings.append(self.llm.text_embedding(tts_text_token_all))
        tgt_label_reps = torch.stack(tgt_label_reps)
        target_embeddings = torch.stack(target_embeddings).squeeze(1)
        hidden_states = self.input_proj(tgt_label_reps).reshape(-1, 512)
        target_embeddings = target_embeddings.reshape(-1, 512)
        return self.criterion(
            hidden_states,
            target_embeddings,
            torch.ones(hidden_states.size(0)).to(hidden_states.device),
        )

    def _speech_token_loss(self, tgt_reps, labels, answer):
        """CosyVoice LLM loss on speech tokens synthesized from ``answer``."""
        tgt_label_reps = []
        batch_speech_tokens = []
        embeddings = []
        # The speaker embedding does not depend on the sample; look it up once.
        speaker_embedding = self.cosyvoice.frontend.frontend_embedding(self._SPEAKER)
        for tgt_rep, label, ans in zip(tgt_reps, labels, answer):
            tgt_label_reps.append(self._select_reps(tgt_rep, label))
            speech_token = self.cosyvoice.inference_label(ans, self._SPEAKER, stream=False)
            speech_tokens = []
            for chunk in speech_token:
                speech_tokens.append(chunk['tts_speech_token'].squeeze(0))
                speech_tokens.append(torch.tensor([0]))  # separator between chunks
            speech_tokens = torch.cat(speech_tokens, dim=0)
            if speech_tokens.size(0) > 1:
                speech_tokens = speech_tokens[:-1]  # drop the trailing separator
            batch_speech_tokens.append(speech_tokens)
            embeddings.append(speaker_embedding['llm_embedding'].squeeze(0))

        tgt_label_reps = torch.stack(tgt_label_reps)
        batch_speech_token = torch.stack(batch_speech_tokens)
        embeddings = torch.stack(embeddings)
        hidden_states = self.input_proj(tgt_label_reps)
        batch_size = hidden_states.size(0)
        batch = {
            'text_feature': hidden_states,
            'text_token_len': torch.tensor([hidden_states.size(1)]).repeat(batch_size),
            'speech_token': batch_speech_token,
            'speech_token_len': torch.tensor([batch_speech_token.size(1)]).repeat(batch_size),
            'embedding': embeddings,
        }
        output = self.llm.forward_ours(batch, 'cuda')
        return output['loss']

    def forward(self, tgt_reps, labels, answer):
        """Return the training loss for the objective selected by DISTILL_EMBEDDING."""
        if DISTILL_EMBEDDING:
            return self._distill_loss(tgt_reps, labels, answer)
        return self._speech_token_loss(tgt_reps, labels, answer)