# IndexTTS-2-Demo / indextts/s2mel/wav2vecbert_extract.py
# Commit fba9477 ("init") by kemuriririn.
from transformers import SeamlessM4TFeatureExtractor
from transformers import Wav2Vec2BertModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
import pickle
import math
import json
import safetensors
import json5
# from codec.kmeans.repcodec_model import RepCodec
from startts.examples.ftchar.models.codec.kmeans.repcodec_model import RepCodec
class JsonHParams:
    """Attribute-style view over a (possibly nested) configuration mapping.

    Each keyword argument becomes an instance attribute. Dict values --
    including dict subclasses such as ``OrderedDict`` -- are wrapped
    recursively, so nested entries can be read as ``cfg.model.semantic_codec``.
    Mapping-style access (``cfg["key"]``, ``"key" in cfg``, ``len(cfg)``,
    ``keys()`` / ``items()`` / ``values()``) is supported as well.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance rather than ``type(v) == dict`` so that dict
            # subclasses are wrapped too instead of staying plain mappings.
            if isinstance(v, dict):
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        """Return the stored keys (a view over the instance ``__dict__``)."""
        return self.__dict__.keys()

    def items(self):
        """Return ``(key, value)`` pairs of the stored entries."""
        return self.__dict__.items()

    def values(self):
        """Return the stored values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        # Backed by attribute lookup: a missing key raises AttributeError.
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
def override_config(base_config, new_config):
    """Return ``base_config`` recursively updated with ``new_config``.

    Nested dicts are merged key by key; any other value in ``new_config``
    replaces the corresponding entry of ``base_config``. If ``new_config``
    holds a dict where ``base_config`` holds a non-dict (or nothing), the
    dict from ``new_config`` is used as-is. Neither input is mutated.
    """
    merged = dict(base_config)
    for key, value in new_config.items():
        if isinstance(value, dict):
            base_value = merged.get(key)
            if not isinstance(base_value, dict):
                base_value = {}
            merged[key] = override_config(base_value, value)
        else:
            merged[key] = value
    return merged


def get_lowercase_keys_config(cfg):
    """Return a copy of ``cfg`` with the keys at every nesting level lower-cased."""
    return {
        key.lower(): get_lowercase_keys_config(value) if isinstance(value, dict) else value
        for key, value in cfg.items()
    }


def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary.

    If the file contains a ``"base_config"`` entry, that path (relative to
    the ``WORK_DIR`` environment variable) is loaded first and the current
    file's values are merged over it with ``override_config``.

    Args:
        config_fn (str): path to the JSON5 configuration file.
        lowercase (bool, optional): whether to lower-case every key (at all
            nesting levels). Defaults to False.

    Returns:
        dict: dictionary that stores configurations.

    Raises:
        RuntimeError: if ``base_config`` is used but ``WORK_DIR`` is not set.
    """
    with open(config_fn, "r", encoding="utf-8") as f:
        config_ = json5.loads(f.read())

    if "base_config" in config_:
        # Load the parent configuration, then let this file's values win.
        work_dir = os.getenv("WORK_DIR")
        if work_dir is None:
            raise RuntimeError(
                f"{config_fn} declares base_config={config_['base_config']!r}, "
                "but the WORK_DIR environment variable is not set"
            )
        p_config_path = os.path.join(work_dir, config_["base_config"])
        p_config_ = _load_config(p_config_path)
        config_ = override_config(p_config_, config_)

    if lowercase:
        config_ = get_lowercase_keys_config(config_)
    return config_
def load_config(config_fn, lowercase=False):
    """Read a configuration file and wrap it for attribute-style access.

    Args:
        config_fn (str): path of the configuration file, parsed by
            ``_load_config``.
        lowercase (bool, optional): forwarded to ``_load_config``; when True
            the keys are converted to lower case. Defaults to False.

    Returns:
        JsonHParams: the configuration, with nested dicts reachable as
        attributes (e.g. ``cfg.model.semantic_codec``).
    """
    raw_config = _load_config(config_fn, lowercase=lowercase)
    return JsonHParams(**raw_config)
class Extract_wav2vectbert:
    """Extract w2v-BERT features and quantize them into MaskGCT semantic codes.

    All checkpoints are read from ``model_dir`` (``./MaskGCT_model`` by
    default, resolved against the current working directory):

    * ``w2v_bert/`` -- ``Wav2Vec2BertModel`` weights and the
      ``SeamlessM4TFeatureExtractor`` config
    * ``wav2vec2bert_stats.pt`` -- dict with ``"mean"`` and ``"var"`` used to
      normalize the selected hidden state
    * ``maskgct.json`` -- config whose ``model.semantic_codec`` section builds
      the ``RepCodec``
    * ``semantic_codec/model.safetensors`` -- ``RepCodec`` weights
    """

    # Index into ``hidden_states`` of the Wav2Vec2BertModel output that is
    # normalized and quantized (17 is the value the original code used).
    semantic_layer = 17

    def __init__(self, device, model_dir="./MaskGCT_model"):
        # ``import safetensors`` alone does not load the ``safetensors.torch``
        # submodule; import it explicitly rather than relying on another
        # library having imported it first.
        from safetensors.torch import load_model as load_safetensors_model

        self.device = device
        w2v_bert_dir = os.path.join(model_dir, "w2v_bert")

        # Hub equivalent (per the original commented-out line):
        # Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(w2v_bert_dir)
        self.semantic_model.eval()
        self.semantic_model.to(device)

        # map_location="cpu": the stats may have been saved from GPU tensors;
        # load on CPU first so this also works on CPU-only hosts, then move.
        self.stat_mean_var = torch.load(
            os.path.join(model_dir, "wav2vec2bert_stats.pt"), map_location="cpu"
        )
        self.semantic_mean = self.stat_mean_var["mean"].to(device)
        self.semantic_std = torch.sqrt(self.stat_mean_var["var"]).to(device)

        self.processor = SeamlessM4TFeatureExtractor.from_pretrained(w2v_bert_dir)

        cfg_maskgct = load_config(os.path.join(model_dir, "maskgct.json"))
        cfg = cfg_maskgct.model.semantic_codec
        self.semantic_code_ckpt = os.path.join(model_dir, "semantic_codec", "model.safetensors")
        self.semantic_codec = RepCodec(cfg=cfg)
        self.semantic_codec.eval()
        self.semantic_codec.to(device)
        load_safetensors_model(self.semantic_codec, self.semantic_code_ckpt)

    @torch.no_grad()
    def extract_features(self, speech):
        """Run the SeamlessM4T feature extractor on 16 kHz audio.

        Args:
            speech: waveform(s) at 16 kHz; the ``__main__`` demo passes a
                2-D ``[B, T]`` numpy array.

        Returns:
            tuple: ``(input_features, attention_mask)`` as PyTorch tensors,
            not yet moved to ``self.device`` (e.g. ``[2, 620, 160]`` and
            ``[2, 620]`` per the original comment).
        """
        inputs = self.processor(speech, sampling_rate=16000, return_tensors="pt")
        return inputs["input_features"], inputs["attention_mask"]

    @torch.no_grad()
    def extract_semantic_code(self, input_features, attention_mask):
        """Encode features with w2v-BERT, normalize one layer and quantize it.

        Returns:
            tuple: ``(semantic_code, rec_feat)`` as returned by
            ``RepCodec.quantize`` on the normalized hidden state.
        """
        outputs = self.semantic_model(  # Wav2Vec2BertModel
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        feat = outputs.hidden_states[self.semantic_layer]  # (B, T, C)
        # .to(feat) matches both the device and dtype of the hidden state.
        feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)
        semantic_code, rec_feat = self.semantic_codec.quantize(feat)  # (B, T)
        return semantic_code, rec_feat

    def feature_extract(self, prompt_speech):
        """Full pipeline: raw 16 kHz audio -> ``(semantic_code, rec_feat)``."""
        input_features, attention_mask = self.extract_features(prompt_speech)
        input_features = input_features.to(self.device)
        attention_mask = attention_mask.to(self.device)
        semantic_code, rec_feat = self.extract_semantic_code(input_features, attention_mask)
        return semantic_code, rec_feat
if __name__ == '__main__':
    # Smoke test: load one clip at 16 kHz, batch three identical copies and
    # print the shapes of the resulting semantic codes / reconstructed features.
    speech_path = 'test/magi1.wav'
    speech = librosa.load(speech_path, sr=16000)[0]  # mono waveform, shape [T]
    speech = np.stack([speech, speech, speech])  # batch of 3 copies, shape [3, T]
    print(speech.shape)
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    Extract_feature = Extract_wav2vectbert(device)
    semantic_code, rec_feat = Extract_feature.feature_extract(speech)
    print(semantic_code.shape, rec_feat.shape)