# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI

import json
import math
import os

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional, Union, Sequence
from huggingface_hub import snapshot_download
from transformers import AutoModel
from vector_quantize_pytorch import ResidualFSQ

from .descriptaudiocodec.dac.model import dac as dac2
from .quantization.vq import ResidualVectorQuantizer
from .semantic_module import Encoder, Decoder


class EncodedResult:
    def __init__(self, audio_codes):
        self.audio_codes = audio_codes


class HiggsAudioFeatureExtractor(nn.Module):
    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate

    def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
        # Convert a librosa-style numpy waveform to a (batch, channel, time) torch tensor
        audio_signal = torch.tensor(raw_audio)
        audio_signal = audio_signal.unsqueeze(0)
        if len(audio_signal.shape) < 3:
            audio_signal = audio_signal.unsqueeze(0)
        return {"input_values": audio_signal}


class HiggsAudioTokenizer(nn.Module):
    def __init__(
        self,
        n_filters: int = 32,
        D: int = 128,
        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
        ratios: Sequence[int] = [8, 5, 4, 2],  # downsampling by 320
        sample_rate: int = 16000,
        bins: int = 1024,
        n_q: int = 8,
        codebook_dim: int = None,
        normalize: bool = False,
        causal: bool = False,
        semantic_techer: str = "hubert_base_general",
        last_layer_semantic: bool = True,
        merge_mode: str = "concat",
        downsample_mode: str = "step_down",
        semantic_mode: str = "classic",
        vq_scale: int = 1,
        semantic_sample_rate: int = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.hop_length = np.prod(ratios)
        self.semantic_techer = semantic_techer
        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # 50 Hz
        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate
        self.encoder = dac2.Encoder(64, ratios, D)
        self.decoder_2 = dac2.Decoder(D, 1024, ratios)
        self.last_layer_semantic = last_layer_semantic
        self.device = device
        if semantic_techer == "hubert_base":
            self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768
        elif semantic_techer == "wavlm_base_plus":
            self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768
        elif semantic_techer == "hubert_base_general":
            self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
        if semantic_sample_rate is not None:
            self.semantic_sample_rate = semantic_sample_rate

        self.semantic_model.eval()

        # Freeze the semantic teacher; its parameters receive no gradients
        for param in self.semantic_model.parameters():
            param.requires_grad = False

        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)

        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
        self.decoder_semantic = Decoder(
            code_dim=self.encoder_semantic_dim,
            output_channels=self.semantic_dim,
            decode_channels=self.semantic_dim,
        )

        # out_D=D+768
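        # Descriptive note (added): an integer `bins` selects residual vector quantization
        # (RVQ) with that many entries per codebook, while a sequence of FSQ levels selects
        # ResidualFSQ instead, as the branch below shows.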
        if isinstance(bins, int):  # RVQ
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.quantizer_dim,
                codebook_dim=codebook_dim,
                n_q=n_q,
                bins=bins,
            )
            self.quantizer_type = "RVQ"
        else:  # RFSQ
            self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
            self.quantizer_type = "RFSQ"

        self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
        self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
        self.fc_post2 = nn.Linear(self.quantizer_dim, D)

        self.downsample_mode = downsample_mode
        if downsample_mode == "avg":
            self.semantic_pooling = nn.AvgPool1d(
                kernel_size=self.semantic_downsample_factor,
                stride=self.semantic_downsample_factor,
            )

        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)

    @property
    def tps(self):
        return self.frame_rate

    @property
    def sampling_rate(self):
        return self.sample_rate

    @property
    def num_codebooks(self):
        return self.n_q

    @property
    def codebook_size(self):
        return self.quantizer_dim

    def get_last_layer(self):
        return self.decoder.layers[-1].weight

    def calculate_rec_loss(self, rec, target):
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()
        return rec_loss

    @torch.no_grad()
    def get_regress_target(self, x):
        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)

        if (
            self.semantic_techer == "hubert_base"
            or self.semantic_techer == "hubert_base_general"
            or self.semantic_techer == "wavlm_base_plus"
        ):
            x = x[:, 0, :]
            x = F.pad(x, (160, 160))
            target = self.semantic_model(x, output_hidden_states=True).hidden_states
            target = torch.stack(target, dim=1)  # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)

            # average for all layers
            target = target.mean(1)
            # target = target[9]
            # if self.hop_length > 320:
            #     target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)

        elif self.semantic_techer == "w2v_bert2":
            target = self.semantic_model(x)

        elif self.semantic_techer.startswith("whisper"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("mert_music"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("qwen_audio_omni"):
            target = self.semantic_model(x)

        if self.downsample_mode == "step_down":
            if self.semantic_downsample_factor > 1:
                target = target[:, :: self.semantic_downsample_factor, :]

        elif self.downsample_mode == "avg":
            target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
        return target

    def forward(self, x: torch.Tensor, bw: int):
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            quantized = quantized.transpose(1, 2)
        else:
            quantized, codes = self.quantizer(e)
            commit_loss = torch.tensor(0.0)

        quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)

        o_semantic = self.decoder_semantic(quantized_semantic)
        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
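
        # Descriptive note (added): the tuple below carries the reconstructed waveform, the
        # quantizer commitment loss, the semantic reconstruction loss, and a placeholder
        # for an unused fourth term.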
        return o, commit_loss, semantic_recon_loss, None

    def encode(
        self,
        audio_path_or_wv,
        sr=None,
        loudness_normalize=False,
        loudness_threshold=-23.0,
    ):
        if isinstance(audio_path_or_wv, str):
            wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
        else:
            wv = audio_path_or_wv
            assert sr is not None
        if loudness_normalize:
            import pyloudnorm as pyln

            meter = pyln.Meter(sr)
            l = meter.integrated_loudness(wv)
            wv = pyln.normalize.loudness(wv, l, loudness_threshold)
        if sr != self.sampling_rate:
            wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
        if self.audio_tokenizer_feature_extractor is not None:
            inputs = self.audio_tokenizer_feature_extractor(
                raw_audio=wv,
                sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
                return_tensors="pt",
            )
            input_values = inputs["input_values"].to(self.device)
        else:
            input_values = torch.from_numpy(wv).float().unsqueeze(0)
        with torch.no_grad():
            encoder_outputs = self._xcodec_encode(input_values)
            vq_code = encoder_outputs.audio_codes[0]
        return vq_code

    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
        bw = target_bw

        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            pad_size = 160 * self.semantic_downsample_factor
            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            if e_acoustic.shape[2] > e_semantic.shape[2]:
                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
            else:
                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            codes = codes.permute(1, 0, 2)
        else:
            quantized, codes = self.quantizer(e)
            codes = codes.permute(0, 2, 1)

        # return codes
        return EncodedResult(codes)

    def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
        if self.quantizer_type == "RVQ":
            vq_code = vq_code.permute(1, 0, 2)
            quantized = self.quantizer.decode(vq_code)
            quantized = quantized.transpose(1, 2)
        else:
            vq_code = vq_code.permute(0, 2, 1)
            quantized = self.quantizer.get_output_from_indices(vq_code)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
        o = self.decoder_2(quantized_acoustic)
        return o.cpu().numpy()


def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
    is_local = os.path.exists(tokenizer_name_or_path)
    if not is_local:
        tokenizer_path = snapshot_download(tokenizer_name_or_path)
    else:
        tokenizer_path = tokenizer_name_or_path
    config_path = os.path.join(tokenizer_path, "config.json")
    model_path = os.path.join(tokenizer_path, "model.pth")
    with open(config_path) as f:
        config = json.load(f)
    model = HiggsAudioTokenizer(
        **config,
        device=device,
    )
    parameter_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(parameter_dict, strict=False)
    model.to(device)
    model.eval()
    return model
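

# Example usage (illustrative sketch, added): load a tokenizer checkpoint, tokenize a
# waveform into discrete codes, and decode the codes back to audio. The repository id
# and audio path below are placeholders, not values defined in this module, and a CUDA
# device is assumed (pass device="cpu" otherwise).
if __name__ == "__main__":
    tokenizer = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer")  # placeholder repo id
    vq_code = tokenizer.encode("example.wav")  # (n_q, T) codes at tokenizer.tps frames per second
    waveform = tokenizer.decode(vq_code.unsqueeze(0))  # (1, 1, samples) numpy waveform
    print(vq_code.shape, waveform.shape)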