from transformers import SeamlessM4TFeatureExtractor
from transformers import Wav2Vec2BertModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
import pickle
import math
import json
import safetensors.torch  # import the torch submodule so safetensors.torch.load_model is available
import json5
# from codec.kmeans.repcodec_model import RepCodec
from startts.examples.ftchar.models.codec.kmeans.repcodec_model import RepCodec

class JsonHParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
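

# Illustrative usage sketch (not part of the original pipeline): JsonHParams
# recursively wraps nested dicts, so configuration values can be read either
# as attributes or with dict-style indexing. The config keys and values below
# are hypothetical and only serve as an example.
def _jsonhparams_usage_example():
    demo = JsonHParams(**{"model": {"semantic_codec": {"codebook_size": 8192}}})
    # The nested "model" dict has itself become a JsonHParams instance,
    # so both access styles reach the same value.
    assert demo.model.semantic_codec.codebook_size == 8192
    assert demo["model"]["semantic_codec"]["codebook_size"] == 8192
    return demo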


def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary.

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.

    Returns:
        dict: dictionary that stores configurations
    """
    with open(config_fn, "r") as f:
        data = f.read()
    config_ = json5.loads(data)
    if "base_config" in config_:
        # Load the base configuration (resolved relative to the WORK_DIR
        # environment variable, which must be set) and override it with the
        # values from this file. Note: override_config and
        # get_lowercase_keys_config are not defined in this file and must be
        # provided by the surrounding codebase.
        p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
        p_config_ = _load_config(p_config_path)
        config_ = override_config(p_config_, config_)
    if lowercase:
        # change keys in config_ to lower case
        config_ = get_lowercase_keys_config(config_)
    return config_


def load_config(config_fn, lowercase=False):
    """Load configurations into a JsonHParams object.

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.

    Returns:
        JsonHParams: an object that stores configurations
    """
    config_ = _load_config(config_fn, lowercase=lowercase)
    # Create a JsonHParams object from the configuration dict.
    cfg = JsonHParams(**config_)
    return cfg
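

# Illustrative usage sketch: load_config parses a JSON5 file and exposes its
# sections as nested attributes, which is how the MaskGCT config is consumed
# in Extract_wav2vectbert below. The path mirrors the one used in that class;
# this is an example and assumes the file exists on disk.
def _load_config_usage_example():
    cfg = load_config("./MaskGCT_model/maskgct.json")
    # Sub-configuration for the semantic codec, accessed attribute-style.
    return cfg.model.semantic_codec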


class Extract_wav2vectbert:
    def __init__(self, device):
        # Wav2Vec2-BERT encoder used to produce semantic features.
        # semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained("./MaskGCT_model/w2v_bert/")
        self.semantic_model.eval()
        self.semantic_model.to(device)

        # Pre-computed mean/variance statistics used to normalize the features.
        self.stat_mean_var = torch.load("./MaskGCT_model/wav2vec2bert_stats.pt")
        self.semantic_mean = self.stat_mean_var["mean"].to(device)
        self.semantic_std = torch.sqrt(self.stat_mean_var["var"]).to(device)

        self.processor = SeamlessM4TFeatureExtractor.from_pretrained("./MaskGCT_model/w2v_bert/")
        self.device = device

        # Semantic codec (RepCodec) that quantizes the normalized features.
        cfg_maskgct = load_config("./MaskGCT_model/maskgct.json")
        cfg = cfg_maskgct.model.semantic_codec
        self.semantic_code_ckpt = "./MaskGCT_model/semantic_codec/model.safetensors"
        self.semantic_codec = RepCodec(cfg=cfg)
        self.semantic_codec.eval()
        self.semantic_codec.to(device)
        safetensors.torch.load_model(self.semantic_codec, self.semantic_code_ckpt)

    def extract_features(self, speech):  # speech: [B, T]
        inputs = self.processor(speech, sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"]
        attention_mask = inputs["attention_mask"]
        return input_features, attention_mask  # e.g. [2, 620, 160], [2, 620]

    def extract_semantic_code(self, input_features, attention_mask):
        # Inference only, so gradients are not needed.
        with torch.no_grad():
            vq_emb = self.semantic_model(  # Wav2Vec2BertModel
                input_features=input_features,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            feat = vq_emb.hidden_states[17]  # (B, T, C)
            feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)
            semantic_code, rec_feat = self.semantic_codec.quantize(feat)  # (B, T)
        return semantic_code, rec_feat

    def feature_extract(self, prompt_speech):
        input_features, attention_mask = self.extract_features(prompt_speech)
        input_features = input_features.to(self.device)
        attention_mask = attention_mask.to(self.device)
        semantic_code, rec_feat = self.extract_semantic_code(input_features, attention_mask)
        return semantic_code, rec_feat


if __name__ == '__main__':
    speech_path = 'test/magi1.wav'
    speech = librosa.load(speech_path, sr=16000)[0]
    speech = np.c_[speech, speech, speech].T  # stack three copies into a batch: (3, T), e.g. (3, 198559)
    print(speech.shape)
    Extract_feature = Extract_wav2vectbert('cuda:0')
    semantic_code, rec_feat = Extract_feature.feature_extract(speech)
    print(semantic_code.shape, rec_feat.shape)