import random
from typing import Dict, Iterator, List, Tuple, Union
from fairseq import utils
import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from src.slam_llm.models.vallex.transformers import (
    LayerNorm,
    TransformerEncoder,
    TransformerEncoderLayer,
)
from src.slam_llm.models.vallex.vallex_config import VallexConfig
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
from dataclasses import dataclass


@dataclass
class ModelOutput:
    logits: torch.Tensor
    loss: torch.Tensor
    acc: torch.Tensor


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, scale=1, prob_mask=None):
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if prob_mask is not None:
        lprobs = lprobs.masked_fill(prob_mask, 0.0)
        n_class = (1 - prob_mask.float()).sum()
    else:
        n_class = lprobs.size(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    # nll_loss = nll_loss * scale
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True) * scale
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        pad_mask_float = (1 - pad_mask.to(torch.float)).sum()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        # without an ignore_index every position counts toward the normalizer
        pad_mask_float = float(target.numel())
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (n_class - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss / pad_mask_float, nll_loss / pad_mask_float
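

# A minimal, hypothetical sanity check for label_smoothed_nll_loss: random logits over a small
# vocabulary, with one target position set to the padding index so it is ignored. Not called
# anywhere; it only illustrates the expected shapes.
def _demo_label_smoothed_nll_loss():
    vocab, pad_idx = 8, 1
    lprobs = F.log_softmax(torch.randn(2, 5, vocab), dim=-1).view(-1, vocab)  # (B*T, V)
    target = torch.randint(0, vocab, (2, 5)).view(-1)                         # (B*T,)
    target[0] = pad_idx                                                       # ignored position
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs, target, epsilon=0.1, ignore_index=pad_idx
    )
    print(loss.item(), nll_loss.item())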


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx=None
    ):
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state=None,
        timestep=None,
        positions=None,
    ):
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)
        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
        positions = utils.make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
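

# A small, hypothetical usage sketch for SinusoidalPositionalEmbedding: positions are counted
# from non-padding tokens (fairseq convention), so padded positions receive the padding row.
def _demo_sinusoidal_positions():
    pad = 1
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=pad, init_size=32)
    tokens = torch.tensor([
        [5, 6, 7, pad, pad],   # length 3, right-padded
        [5, 6, 7, 8, 9],       # length 5
    ])
    out = pos_emb(tokens)      # (B, T, 16)
    print(out.shape)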


class Transpose(nn.Identity):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


class VALLF(PreTrainedModel):
    config_class = VallexConfig

    def __init__(
        self,
        config: VallexConfig
    ):
        super().__init__(config)
        self.ar_at_dict = Dictionary.load(self.config.ar_at_dict)
        self.ar_st_dict = Dictionary.load(self.config.ar_st_dict)
        self.nar_at_dict = Dictionary.load(self.config.nar_at_dict)
        self.nar_st_dict = Dictionary.load(self.config.nar_st_dict)
        self.ar_at_dict.tts_flag = self.ar_at_dict.add_symbol("<TTS>")
        self.ar_st_dict.asr_flag = self.ar_st_dict.add_symbol("<ASR>")
        self.ar_st_dict.mt_flag = self.ar_st_dict.add_symbol("<MT>")
        self.padding_idx = self.ar_at_dict.pad()
        self.config = config
        d_model = self.config.n_dim
        nar_scale_factor = self.config.nar_scale_factor
        prepend_bos = self.config.prepend_bos
        norm_first = self.config.norm_first
        num_layers = self.config.n_layer
        self.NUM_AUDIO_TOKENS = self.ar_at_dict.eos()
        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = nn.Embedding(len(self.ar_st_dict), d_model, self.ar_st_dict.pad())  # W_x
        if config.only_ar:
            pass
        else:
            self.nar_text_embedding = nn.Embedding(len(self.nar_st_dict), d_model, self.nar_st_dict.pad())

        # ID self.NUM_AUDIO_TOKENS     -> PAD
        # ID self.NUM_AUDIO_TOKENS + 1 -> BOS
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = EncodecDecoderLstm(
            dictionary=self.ar_at_dict, emb_dim=d_model
        )
        self.ar_text_prenet = nn.Identity()
        self.ar_audio_prenet = nn.Identity()
        self.ar_text_position = SinusoidalPositionalEmbedding(
            d_model,
            padding_idx=self.ar_at_dict.pad(),
            init_size=1024 + self.ar_at_dict.pad() + 1
        )
        self.ar_audio_position = SinusoidalPositionalEmbedding(
            d_model,
            padding_idx=self.ar_at_dict.pad(),
            init_size=1024 + self.ar_at_dict.pad() + 1
        )
        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model,
                self.config.n_head,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            d_model, len(self.ar_at_dict), bias=False
        )
        self.rng = random.Random(0)
        self.num_heads = self.config.n_head
        self.prefix_mode = self.config.prefix_mode
        self.num_quantizers = self.config.num_quantizers
        assert self.num_quantizers >= 1

        if config.only_ar:
            pass
        else:
            if self.num_quantizers > 1:
                self.nar_audio_embeddings = NATEncodecDecoderLstm(
                    codecs=[0, 1, 2, 3, 4, 5, 6, 7], dictionary=self.nar_at_dict, emb_dim=d_model
                )  # W_a
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()
                self.nar_text_position = SinusoidalPositionalEmbedding(
                    d_model,
                    padding_idx=self.nar_at_dict.pad(),
                    init_size=1024 + self.nar_at_dict.pad() + 1
                )
                self.nar_audio_position = SinusoidalPositionalEmbedding(
                    d_model,
                    padding_idx=self.nar_at_dict.pad(),
                    init_size=1024 + self.nar_at_dict.pad() + 1
                )
                self.nar_decoder = TransformerEncoder(
                    TransformerEncoderLayer(
                        nar_d_model,
                        int(self.num_heads * nar_scale_factor),
                        dim_feedforward=nar_d_model * 4,
                        dropout=0.1,
                        batch_first=True,
                        norm_first=norm_first,
                        adaptive_layer_norm=True,
                    ),
                    num_layers=int(num_layers * nar_scale_factor),
                    norm=nn.LayerNorm(nar_d_model) if norm_first else None,
                )
                self.nar_predict_layers = nn.ModuleList(
                    [
                        nn.Linear(nar_d_model, len(self.nar_at_dict), bias=False)
                        for i in range(self.num_quantizers)
                    ]
                )
                self.nar_stage_embeddings = None

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param
        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair
        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # inputs, targets
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=self.NUM_AUDIO_TOKENS + 1),
                targets,
            )
        return targets[:, :-1], targets[:, 1:]
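

# A standalone sketch of the shift performed by VALLF.pad_y_eos (hypothetical token ids): the
# EOS id is added at every masked position and at one appended step, then inputs and targets
# are offset by one, here using the prepend-BOS variant.
def _demo_pad_y_eos_shift(eos_id: int = 1024, bos_id: int = 1025):
    y = torch.tensor([[11, 12, 13, 0]])          # last position is padding
    y_mask_int = torch.tensor([[0, 0, 0, 1]])    # 1 marks padded positions
    targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
    inputs = F.pad(targets[:, :-1], (1, 0), value=bos_id)
    print(inputs.tolist(), targets.tolist())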


class VALLE(VALLF):
    config_class = VallexConfig

    def __init__(
        self,
        config: VallexConfig,
        **kwargs,
    ):
        super(VALLE, self).__init__(
            config,
            **kwargs,
        )
        print(config)
        self.config = config
        d_model = self.config.n_dim
        self.eps = config.eps
        self.language_ID = {
            'en': 0,
            'zh': 1,
        }
        self.ar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
        self.nar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
        self.embed_scale = 32.0

    def forward(
        self,
        zh,
        en
    ):
        """
        Each of `zh` and `en` is a dict of batch tensors:
            "zh": {
                "st_tokens": zh_st,
                "at_tokens_wbos": zh_prev_at,
                "at_tokens_tgt": zh_tgt_at,
                "self_atten_mask": zh_self_atten_mask,
                "padding_mask": zh_padding_mask,
                "langid": zh_id.long()
            },
            "en": {
                "st_tokens": en_st,
                "at_tokens_wbos": en_prev_at,
                "at_tokens_tgt": en_tgt_at,
                "self_atten_mask": en_self_atten_mask,
                "padding_mask": en_padding_mask,
                "langid": en_id.long()
            }
        """
        flag = (np.random.randint(low=0, high=1000) % 2 == 0)  # randomly pick zh or en for this step
        if flag:
            data = zh
        else:
            data = en
        st_tokens = data["st_tokens"]
        at_tokens_wbos = data["at_tokens_wbos"]
        at_tokens_tgt = data["at_tokens_tgt"]
        self_atten_mask = data["self_atten_mask"]
        padding_mask = data["padding_mask"]
        langid = data["langid"]

        # text (semantic-token) stream: token + language + positional embeddings
        st_len = st_tokens.size(1)
        st_emb = self.embed_scale * self.ar_text_embedding(st_tokens)
        src_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
        st_emb += src_lang_emb
        st_pos = self.ar_text_position(st_tokens)
        st_emb += st_pos

        # audio (acoustic-token) stream: LSTM-smoothed codec embeddings + language + position
        at_emb, _ = self.ar_audio_embedding(at_tokens_wbos, None)
        at_emb = self.embed_scale * at_emb
        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
        at_emb += tgt_lang_emb
        at_pos = self.ar_audio_position(at_tokens_wbos)
        at_emb += at_pos

        # concatenate [text; audio] and run the AR decoder with the provided masks
        x = torch.concat([st_emb, at_emb], dim=1)
        x = self.ar_decoder(
            x,
            mask=self_atten_mask,
            src_key_padding_mask=padding_mask
        )
        x = self.ar_predict_layer(x)
        x = x[:, st_len:, :]  # keep only the audio positions for the loss
        loss, nll_loss, lprob, right_rate = self.calculate_loss(
            x, at_tokens_tgt
        )
        return ModelOutput(logits=lprob, loss=loss, acc=right_rate), right_rate

    def calculate_loss(self, encoder_out, target, reduce=True, scale=1.0, prob_mask=None, acc=True):
        lprob = self.get_normalized_probs(encoder_out, log_probs=True)
        with torch.no_grad():
            mask = target.ne(self.padding_idx)
            n_correct = torch.sum(
                lprob.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
            )
            total = torch.sum(mask)
            right_rate = n_correct * 100.0 / total
        lprob, target = lprob.view(-1, lprob.size(-1)), target.view(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprob,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
            scale=scale,
            prob_mask=prob_mask
        )
        return loss, nll_loss, lprob, right_rate

    def get_normalized_probs(self, encoder_out, log_probs, sample=None):
        if torch.is_tensor(encoder_out):
            logits = encoder_out.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)

    def inference_24L(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language: torch.Tensor = None,
        text_language: torch.Tensor = None,
        best_of: int = 1,
        length_penalty: float = 1.0,
        return_worst: bool = False,
        at_eos: int = -1
    ) -> torch.Tensor:
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape
        assert torch.all(x_lens > 0)
        self.NUM_AUDIO_TOKENS = at_eos

        # text embedding plus language embeddings for the enrolled prompt and the target text
        text = x
        x = self.embed_scale * self.ar_text_embedding(text)
        prompt_language_id = prompt_language.to(x.device)
        text_language_id = text_language.to(x.device)
        src_lang_emb = self.embed_scale * self.ar_language_embedding(prompt_language_id)
        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(text_language_id)
        x[:, :enroll_x_lens, :] += src_lang_emb
        x[:, enroll_x_lens:, :] += tgt_lang_emb
        x = self.ar_text_prenet(x)
        x_pos = self.ar_text_position(text)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=self.ar_at_dict.tts_flag)
        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        kv_cache = None
        use_kv_caching = True
        sum_logprobs = torch.zeros(best_of, device=y.device)  # implement batch decoding here
        x = x.repeat(best_of, 1, 1)
        y = y.repeat(best_of, 1)
        lstm_h = None
        first_ar = True

        while True:
            if first_ar:
                # first step: embed the full prompt + BOS sequence and keep the LSTM state
                y_emb, lstm_h = self.ar_audio_embedding(y, lstm_h)
                y_emb = y_emb * self.embed_scale
                y_emb = self.ar_audio_prenet(y_emb)
                y_pos = self.ar_audio_position(y)
                y_emb[:, :prefix_len] = y_emb[:, :prefix_len] + src_lang_emb
                y_emb[:, prefix_len:] = y_emb[:, prefix_len:] + tgt_lang_emb
                xy_pos_token = torch.concat([x_pos + x, y_pos + y_emb], dim=1)
                first_ar = False
            else:
                # later steps: embed only the newly sampled token, reusing the LSTM state
                y_emb_cur, lstm_h = self.ar_audio_embedding(y[:, -1:], lstm_h)
                y_emb_cur = y_emb_cur * self.embed_scale
                y_emb_cur = self.ar_audio_prenet(y_emb_cur)
                y_pos_cur = self.ar_audio_position(y)[:, -1:]
                y_emb_cur = y_emb_cur + src_lang_emb
                y_emb_cur = y_emb_cur + tgt_lang_emb
                xy_pos_token = torch.concat([xy_pos_token, y_pos_cur + y_emb_cur], dim=1)
            # print(xy_pos_token.size())

            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)

            if use_kv_caching and kv_cache is not None:
                xy_pos = xy_pos_token[:, [-1]]
                xy_attn_mask = xy_attn_mask[:, [-1]]
            else:
                xy_pos = xy_pos_token

            xy_dec, kv_cache = self.ar_decoder.infer(
                xy_pos,
                mask=xy_attn_mask,
                past_kv=kv_cache,
                use_cache=use_kv_caching,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples, current_logprobs = topk_sampling(
                logits, top_k=top_k, top_p=1, temperature=temperature
            )
            sum_logprobs += current_logprobs * (y[:, -1] != self.NUM_AUDIO_TOKENS)
            samples[y[:, -1] == self.NUM_AUDIO_TOKENS] = self.NUM_AUDIO_TOKENS
            completed = (samples[:, -1] == self.NUM_AUDIO_TOKENS).all()
            if (
                completed
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 32
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise SyntaxError(
                        "well trained model shouldn't reach here."
                    )
                lengths = torch.sum(y != self.NUM_AUDIO_TOKENS, dim=1)
                avg_logprobs = sum_logprobs / lengths ** length_penalty
                # choose the best beam according to sum_logprobs
                best_beam = y[torch.argmax(avg_logprobs), :]
                worst_beam = y[torch.argmin(avg_logprobs), :]
                # strip all eos tokens
                best_beam = best_beam[best_beam != self.NUM_AUDIO_TOKENS]
                worst_beam = worst_beam[worst_beam != self.NUM_AUDIO_TOKENS]
                if return_worst:
                    y = worst_beam.unsqueeze(0)
                else:
                    y = best_beam.unsqueeze(0)
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break
            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos):]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1:],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.embed_scale * self.nar_text_embedding(text)
        # Add language embedding
        prompt_language_id = prompt_language.to(x.device)
        text_language_id = text_language.to(x.device)
        src_lang_emb = self.embed_scale * self.nar_language_embedding(prompt_language_id)
        tgt_lang_emb = self.embed_scale * self.nar_language_embedding(text_language_id)
        x[:, :enroll_x_lens, :] += src_lang_emb
        x[:, enroll_x_lens:, :] += tgt_lang_emb
        x = self.nar_text_prenet(x)
        x_pos = self.nar_text_position(text)

        if self.prefix_mode == 0:
            for i, predict_layer in enumerate(
                self.nar_predict_layers
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)
                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len:])
                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)
                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += self.embed_scale * self.nar_audio_embeddings(
                        prompts[..., i + 1]
                    )[0]
                    y_emb[:, prefix_len:] += self.embed_scale * self.nar_audio_embeddings(samples)[0]
        else:
            y_pos = self.nar_audio_position(y[:, int(self.ar_audio_prepend_bos):])
            ref_at_emb = self.embed_scale * self.nar_audio_embeddings(prompts)[0] + src_lang_emb
            est_at = y[:, prefix_len + int(self.ar_audio_prepend_bos):].unsqueeze(-1)
            # predict the remaining codebooks one stage at a time, feeding back the estimates
            for i in range(1, 8):
                y_emb, _ = self.nar_audio_embeddings(est_at)
                y_emb = self.embed_scale * y_emb + tgt_lang_emb
                y_emb = torch.concat([ref_at_emb, y_emb], dim=1)
                xy_pos = torch.concat([x + x_pos, y_emb + y_pos], dim=1)
                xy_dec = self.nar_decoder(
                    xy_pos
                )
                logits = self.nar_predict_layers[i - 1](xy_dec[:, text_len + prefix_len:])
                # print(logits.size(), xy_pos.size(), xy_dec.size())
                samples = torch.argmax(logits, dim=-1)
                est_at = torch.concat([est_at, samples.unsqueeze(-1)], dim=-1)
                codes.append(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)
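

# A minimal sketch of the text/audio attention mask assembled step by step in inference_24L
# (hypothetical lengths): text rows attend only to text, audio rows attend to all text plus
# earlier audio; True marks a blocked connection.
def _demo_prefix_lm_mask(x_len: int = 3, y_len: int = 4) -> torch.Tensor:
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)      # text rows cannot see audio
    y_attn_mask = F.pad(
        torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
        (x_len, 0),
        value=False,                                                  # audio rows see all text
    )
    return torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)       # (x_len + y_len, x_len + y_len)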


def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )
        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0
        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
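

# A quick, hypothetical check for top_k_top_p_filtering: with top_k=2 only the two largest
# logits survive; the rest are pushed to -inf and get zero probability after softmax.
def _demo_top_k_filtering():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
    print(filtered)   # tensor([[2., 1., -inf, -inf]])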


def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
    return token, current_logprobs
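

# A hypothetical sampling sketch: draw one token per row from temperature-scaled,
# top-k-filtered logits and report its log-probability under the filtered distribution.
def _demo_topk_sampling():
    torch.manual_seed(0)
    logits = torch.randn(2, 16)                        # (batch, vocab)
    token, current_logprobs = topk_sampling(logits, top_k=5, top_p=1.0, temperature=0.9)
    print(token.shape, current_logprobs.shape)         # torch.Size([2, 1]) torch.Size([2])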


class SLSTM(nn.Module):
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional=False):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
        if bidirectional:
            self.out_fc = nn.Linear(dimension * 2, dimension)
        else:
            self.out_fc = None

    def forward(self, x, hidden=None):
        # x: (B, D, T) -> (T, B, D) for the (non-batch-first) LSTM
        x = x.permute(2, 0, 1)
        y, hidden = self.lstm(x, hidden)
        if self.out_fc is not None:
            y = self.out_fc(y)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)  # back to (B, D, T)
        return y, hidden
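

# A hypothetical shape check for SLSTM: it consumes (B, D, T), runs a skip-connected LSTM over
# the time axis, and returns (B, D, T) plus the LSTM state for incremental decoding.
def _demo_slstm():
    slstm = SLSTM(dimension=32, num_layers=2)
    y, hidden = slstm(torch.randn(4, 32, 10))                # full sequence
    y_next, hidden = slstm(torch.randn(4, 32, 1), hidden)    # one extra step, reusing state
    print(y.shape, y_next.shape)                             # (4, 32, 10), (4, 32, 1)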


class EncodecDecoderLstm(nn.Module):
    def __init__(self, dictionary, emb_dim,
                 out_dim=None,
                 num_layers=3, lstm_skip=True, lstm_bidire=False,
                 activation_param={'alpha': 1.0}, **kwargs):
        super().__init__()
        # Identity()
        if out_dim is None:
            out_dim = emb_dim
        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
        self.elu = nn.ELU(**activation_param)
        self.embedding_dim = emb_dim
        self.padding_idx = dictionary.pad()
        self.emb = nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index)

    def forward(self, x, hidden=None):
        """
        Args:
            x (LongTensor): (B, T) acoustic-token ids
        """
        # print(x.size())
        quantized_out = self.emb(x)                                   # (B, T, D)
        out, hidden = self.slstm(quantized_out.permute(0, 2, 1), hidden)
        out = self.elu(out)
        return out.permute(0, 2, 1), hidden                           # (B, T, D)
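

# A hypothetical smoke test for EncodecDecoderLstm, using a tiny fairseq Dictionary built on
# the fly (real runs load the codec dictionaries from the paths in VallexConfig).
def _demo_encodec_decoder_lstm():
    toy_dict = Dictionary()
    for i in range(16):
        toy_dict.add_symbol(str(i))
    emb = EncodecDecoderLstm(dictionary=toy_dict, emb_dim=32)
    tokens = torch.randint(0, len(toy_dict), (2, 7))   # (batch, time) token ids
    out, hidden = emb(tokens)
    print(out.shape)                                   # (2, 7, 32)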


class NATEncodecDecoderLstm(nn.Module):
    def __init__(self, codecs, dictionary, emb_dim,
                 out_dim=None,
                 num_layers=3, lstm_skip=True, lstm_bidire=False,
                 activation_param={'alpha': 1.0}, **kwargs):
        super().__init__()
        # Identity()
        if out_dim is None:
            out_dim = emb_dim
        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
        self.elu = nn.ELU(**activation_param)
        self.codecs = codecs
        self.embedding_dim = emb_dim
        self.padding_idx = dictionary.pad()
        self.emb_list = nn.ModuleList(
            [nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) for i in range(len(self.codecs))]
        )

    def forward(self, x, hidden=None):
        """
        Args:
            x (LongTensor): (B, T) or (B, T, C) acoustic-token ids, one column per codebook
        """
        if len(x.size()) == 2:
            x = x.unsqueeze(-1)
        if x.size(2) != len(self.codecs) and x.size(1) == len(self.codecs):
            x = x.permute(0, 2, 1)
        # sum the per-codebook embeddings before the shared LSTM
        quantized_out = 0
        for i in range(x.size(2)):
            quantized = self.emb_list[i](x[:, :, i])
            quantized_out = quantized_out + quantized
        # quantized_out = quantized_out / len(self.codecs)
        out, hidden = self.slstm(quantized_out.permute(0, 2, 1), hidden)
        out = self.elu(out)
        return out.permute(0, 2, 1), hidden
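

# A hypothetical smoke test for NATEncodecDecoderLstm: embeddings of the first three codec
# streams are summed before the shared LSTM (toy dictionary as above).
def _demo_nat_encodec_decoder_lstm():
    toy_dict = Dictionary()
    for i in range(16):
        toy_dict.add_symbol(str(i))
    emb = NATEncodecDecoderLstm(codecs=list(range(8)), dictionary=toy_dict, emb_dim=32)
    tokens = torch.randint(0, len(toy_dict), (2, 7, 3))   # (batch, time, 3 codebooks)
    out, hidden = emb(tokens)
    print(out.shape)                                      # (2, 7, 32)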


AutoModel.register(VallexConfig, VALLE)
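

# A hedged construction sketch: the dictionary paths below are placeholders, and any
# VallexConfig fields beyond those read in VALLF.__init__ are assumptions. After the
# registration above, AutoModel.from_config(cfg) resolves to VALLE.
def _demo_build_valle():
    cfg = VallexConfig(
        ar_at_dict="path/to/ar_at_dict.txt",   # placeholder paths to fairseq dictionaries
        ar_st_dict="path/to/ar_st_dict.txt",
        nar_at_dict="path/to/nar_at_dict.txt",
        nar_st_dict="path/to/nar_st_dict.txt",
    )
    return AutoModel.from_config(cfg)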