# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI
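"""Audio tokenizer used by Higgs Audio.

Summary inferred from the code below: ``HiggsAudioTokenizer`` encodes a waveform
with an acoustic encoder (Descript audio codec) and a frozen semantic teacher
model (e.g. HuBERT or WavLM), fuses the two feature streams, quantizes them with
residual vector quantization (RVQ) or residual FSQ, and decodes the codes back
to audio. ``load_higgs_audio_tokenizer`` restores a checkpoint from a local
directory or the Hugging Face Hub.
"""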

import json
import math
import os
from typing import Optional, Union, Sequence

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from huggingface_hub import snapshot_download
from transformers import AutoModel
from vector_quantize_pytorch import ResidualFSQ

from .descriptaudiocodec.dac.model import dac as dac2
from .quantization.vq import ResidualVectorQuantizer
from .semantic_module import Encoder, Decoder


class EncodedResult:
    def __init__(self, audio_codes):
        self.audio_codes = audio_codes


class HiggsAudioFeatureExtractor(nn.Module):
    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate

    def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
        # Convert from librosa to torch
        audio_signal = torch.tensor(raw_audio)
        audio_signal = audio_signal.unsqueeze(0)
        if len(audio_signal.shape) < 3:
            audio_signal = audio_signal.unsqueeze(0)
        return {"input_values": audio_signal}


class HiggsAudioTokenizer(nn.Module):
    def __init__(
        self,
        n_filters: int = 32,
        D: int = 128,
        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
        ratios: Sequence[int] = [8, 5, 4, 2],  # downsampling by 320
        sample_rate: int = 16000,
        bins: int = 1024,
        n_q: int = 8,
        codebook_dim: Optional[int] = None,
        normalize: bool = False,
        causal: bool = False,
        semantic_techer: str = "hubert_base_general",
        last_layer_semantic: bool = True,
        merge_mode: str = "concat",
        downsample_mode: str = "step_down",
        semantic_mode: str = "classic",
        vq_scale: int = 1,
        semantic_sample_rate: Optional[int] = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.hop_length = np.prod(ratios)
        self.semantic_techer = semantic_techer
        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # 50 Hz
        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate
        self.encoder = dac2.Encoder(64, ratios, D)
        self.decoder_2 = dac2.Decoder(D, 1024, ratios)
        self.last_layer_semantic = last_layer_semantic
        self.device = device

        if semantic_techer == "hubert_base":
            self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768
        elif semantic_techer == "wavlm_base_plus":
            self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768
        elif semantic_techer == "hubert_base_general":
            self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768
        # Overwrite the semantic model sample rate to ensure semantic_downsample_factor is an integer
        if semantic_sample_rate is not None:
            self.semantic_sample_rate = semantic_sample_rate

        self.semantic_model.eval()
        # Freeze the semantic model; its parameters do not require gradients
        for param in self.semantic_model.parameters():
            param.requires_grad = False

        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
        self.decoder_semantic = Decoder(
            code_dim=self.encoder_semantic_dim,
            output_channels=self.semantic_dim,
            decode_channels=self.semantic_dim,
        )

        # out_D = D + 768
        if isinstance(bins, int):  # RVQ
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.quantizer_dim,
                codebook_dim=codebook_dim,
                n_q=n_q,
                bins=bins,
            )
            self.quantizer_type = "RVQ"
        else:  # RFSQ
            self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
            self.quantizer_type = "RFSQ"

        self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
        self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
        self.fc_post2 = nn.Linear(self.quantizer_dim, D)

        self.downsample_mode = downsample_mode
        if downsample_mode == "avg":
            self.semantic_pooling = nn.AvgPool1d(
                kernel_size=self.semantic_downsample_factor,
                stride=self.semantic_downsample_factor,
            )

        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)

    @property
    def tps(self):
        return self.frame_rate

    @property
    def sampling_rate(self):
        return self.sample_rate

    @property
    def num_codebooks(self):
        return self.n_q

    @property
    def codebook_size(self):
        return self.quantizer_dim

    def get_last_layer(self):
        return self.decoder.layers[-1].weight

    def calculate_rec_loss(self, rec, target):
        # Cosine-distance reconstruction loss between L2-normalized features
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()
        return rec_loss

    def get_regress_target(self, x):
        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)

        if (
            self.semantic_techer == "hubert_base"
            or self.semantic_techer == "hubert_base_general"
            or self.semantic_techer == "wavlm_base_plus"
        ):
            x = x[:, 0, :]
            x = F.pad(x, (160, 160))
            target = self.semantic_model(x, output_hidden_states=True).hidden_states
            target = torch.stack(target, dim=1)  # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
            # average over all layers
            target = target.mean(1)
            # target = target[9]
            # if self.hop_length > 320:
            #     target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
        elif self.semantic_techer == "w2v_bert2":
            target = self.semantic_model(x)
        elif self.semantic_techer.startswith("whisper"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)
        elif self.semantic_techer.startswith("mert_music"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)
        elif self.semantic_techer.startswith("qwen_audio_omni"):
            target = self.semantic_model(x)

        if self.downsample_mode == "step_down":
            if self.semantic_downsample_factor > 1:
                target = target[:, :: self.semantic_downsample_factor, :]
        elif self.downsample_mode == "avg":
            target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
        return target

    def forward(self, x: torch.Tensor, bw: int):
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        e = torch.cat([e_acoustic, e_semantic], dim=1)
        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            quantized = quantized.transpose(1, 2)
        else:
            quantized, codes = self.quantizer(e)
            commit_loss = torch.tensor(0.0)

        quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)
        o_semantic = self.decoder_semantic(quantized_semantic)
        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)

        return o, commit_loss, semantic_recon_loss, None

    def encode(
        self,
        audio_path_or_wv,
        sr=None,
        loudness_normalize=False,
        loudness_threshold=-23.0,
    ):
        if isinstance(audio_path_or_wv, str):
            wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
        else:
            wv = audio_path_or_wv
            assert sr is not None
        if loudness_normalize:
            import pyloudnorm as pyln

            meter = pyln.Meter(sr)
            loudness = meter.integrated_loudness(wv)
            wv = pyln.normalize.loudness(wv, loudness, loudness_threshold)
        if sr != self.sampling_rate:
            wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
        if self.audio_tokenizer_feature_extractor is not None:
            inputs = self.audio_tokenizer_feature_extractor(
                raw_audio=wv,
                sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
                return_tensors="pt",
            )
            input_values = inputs["input_values"].to(self.device)
        else:
            input_values = torch.from_numpy(wv).float().unsqueeze(0)
        with torch.no_grad():
            encoder_outputs = self._xcodec_encode(input_values)
            vq_code = encoder_outputs.audio_codes[0]
        return vq_code

    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> EncodedResult:
        bw = target_bw

        e_semantic_input = self.get_regress_target(x).detach()
        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            pad_size = 160 * self.semantic_downsample_factor
            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            if e_acoustic.shape[2] > e_semantic.shape[2]:
                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
            else:
                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]

        e = torch.cat([e_acoustic, e_semantic], dim=1)
        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            codes = codes.permute(1, 0, 2)
        else:
            quantized, codes = self.quantizer(e)
            codes = codes.permute(0, 2, 1)

        # return codes
        return EncodedResult(codes)

    def decode(self, vq_code: torch.Tensor) -> np.ndarray:
        if self.quantizer_type == "RVQ":
            vq_code = vq_code.permute(1, 0, 2)
            quantized = self.quantizer.decode(vq_code)
            quantized = quantized.transpose(1, 2)
        else:
            vq_code = vq_code.permute(0, 2, 1)
            quantized = self.quantizer.get_output_from_indices(vq_code)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)
        return o.cpu().numpy()


def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
    is_local = os.path.exists(tokenizer_name_or_path)
    if not is_local:
        tokenizer_path = snapshot_download(tokenizer_name_or_path)
    else:
        tokenizer_path = tokenizer_name_or_path
    config_path = os.path.join(tokenizer_path, "config.json")
    model_path = os.path.join(tokenizer_path, "model.pth")
    with open(config_path) as f:
        config = json.load(f)
    model = HiggsAudioTokenizer(
        **config,
        device=device,
    )
    parameter_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(parameter_dict, strict=False)
    model.to(device)
    model.eval()
    return model
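

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. Assumptions: the
    # checkpoint id "path/or/hub-id/of/tokenizer" and the input "example.wav"
    # are placeholders; point them at a real tokenizer checkpoint (a directory
    # or Hub repo containing config.json and model.pth) and a real audio file.
    tokenizer = load_higgs_audio_tokenizer("path/or/hub-id/of/tokenizer", device="cuda")

    # Encode a waveform (or a path to one) into discrete codes; with the
    # default RVQ configuration the result has shape (n_q, n_frames).
    codes = tokenizer.encode("example.wav")
    print("codes:", tuple(codes.shape))

    # decode() expects a leading batch dimension, so add one; the output is a
    # numpy array of shape (1, 1, n_samples) at the tokenizer's sample rate.
    wav = tokenizer.decode(codes.unsqueeze(0))
    torchaudio.save("reconstruction.wav", torch.from_numpy(wav[0]), tokenizer.sampling_rate)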