from transformers import SeamlessM4TFeatureExtractor
from transformers import Wav2Vec2BertModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
import pickle
import math
import json
import safetensors.torch  # required for safetensors.torch.load_model below
import json5
# from codec.kmeans.repcodec_model import RepCodec
from startts.examples.ftchar.models.codec.kmeans.repcodec_model import RepCodec

class JsonHParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
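
# Minimal usage sketch for JsonHParams (illustrative keys, not the actual
# MaskGCT config): nested dicts become nested JsonHParams objects, so both
# attribute access and item access work on the result.
#   hp = JsonHParams(model={"semantic_codec": {"codebook_size": 8192}})
#   hp.model.semantic_codec.codebook_size     # -> 8192
#   hp["model"]["semantic_codec"].keys()      # -> dict_keys(['codebook_size'])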


def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.

    Returns:
        dict: dictionary that stores configurations
    """
    with open(config_fn, "r") as f:
        data = f.read()
    config_ = json5.loads(data)
    if "base_config" in config_:
        # load configurations from new path
        p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
        p_config_ = _load_config(p_config_path)
        config_ = override_config(p_config_, config_)
    if lowercase:
        # change keys in config_ to lower case
        config_ = get_lowercase_keys_config(config_)
    return config_
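

# NOTE: override_config() and get_lowercase_keys_config() are referenced by
# _load_config() above but are not defined or imported in this file (in the
# Amphion codebase they live in its utils module). The minimal sketches below
# assume the usual semantics: a recursive merge and recursive key lowercasing.
def override_config(base_config, new_config):
    """Recursively update base_config with the values from new_config."""
    for k, v in new_config.items():
        if isinstance(v, dict) and isinstance(base_config.get(k), dict):
            base_config[k] = override_config(base_config[k], v)
        else:
            base_config[k] = v
    return base_config


def get_lowercase_keys_config(config_):
    """Return a copy of config_ with all keys recursively lower-cased."""
    new_config = {}
    for k, v in config_.items():
        if isinstance(v, dict):
            v = get_lowercase_keys_config(v)
        new_config[k.lower()] = v
    return new_config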


def load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.

    Returns:
        JsonHParams: an object that stores configurations
    """
    config_ = _load_config(config_fn, lowercase=lowercase)
    # create a JsonHParams object from the configuration dict
    cfg = JsonHParams(**config_)
    return cfg

class Extract_wav2vectbert:
    def __init__(self, device):
        # Wav2Vec2-BERT encoder (local copy of facebook/w2v-bert-2.0)
        self.semantic_model = Wav2Vec2BertModel.from_pretrained("./MaskGCT_model/w2v_bert/")
        self.semantic_model.eval()
        self.semantic_model.to(device)

        # Per-dimension mean/std used to normalize the encoder hidden states
        self.stat_mean_var = torch.load("./MaskGCT_model/wav2vec2bert_stats.pt", map_location="cpu")
        self.semantic_mean = self.stat_mean_var["mean"].to(device)
        self.semantic_std = torch.sqrt(self.stat_mean_var["var"]).to(device)

        # SeamlessM4T feature extractor turns raw 16 kHz waveforms into input features
        self.processor = SeamlessM4TFeatureExtractor.from_pretrained("./MaskGCT_model/w2v_bert/")
        self.device = device

        # RepCodec semantic codec, configured and loaded from the MaskGCT checkpoint
        cfg_maskgct = load_config('./MaskGCT_model/maskgct.json')
        cfg = cfg_maskgct.model.semantic_codec
        self.semantic_code_ckpt = './MaskGCT_model/semantic_codec/model.safetensors'
        self.semantic_codec = RepCodec(cfg=cfg)
        self.semantic_codec.eval()
        self.semantic_codec.to(device)
        safetensors.torch.load_model(self.semantic_codec, self.semantic_code_ckpt)

    @torch.no_grad()
    def extract_features(self, speech):  # speech: [B, T] waveform(s) at 16 kHz
        inputs = self.processor(speech, sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"]
        attention_mask = inputs["attention_mask"]
        return input_features, attention_mask  # [B, T', 160], [B, T'] (T' = number of frames)

    @torch.no_grad()
    def extract_semantic_code(self, input_features, attention_mask):
        vq_emb = self.semantic_model(           # Wav2Vec2BertModel
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        feat = vq_emb.hidden_states[17]  # (B, T, C)
        feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)

        semantic_code, rec_feat = self.semantic_codec.quantize(feat)  # (B, T)
        return semantic_code, rec_feat

    def feature_extract(self, prompt_speech):
        input_features, attention_mask = self.extract_features(prompt_speech)
        input_features = input_features.to(self.device)
        attention_mask = attention_mask.to(self.device)
        semantic_code, rec_feat = self.extract_semantic_code(input_features, attention_mask)
        return semantic_code, rec_feat
            
if __name__ == '__main__':
    speech_path = 'test/magi1.wav'
    speech = librosa.load(speech_path, sr=16000)[0]
    # stack three copies of the waveform to form a batch: [3, T]
    speech = np.c_[speech, speech, speech].T
    print(speech.shape)

    Extract_feature = Extract_wav2vectbert('cuda:0')
    semantic_code, rec_feat = Extract_feature.feature_extract(speech)
    print(semantic_code.shape, rec_feat.shape)
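
    # Optional follow-up (sketch, not part of the original pipeline): persist
    # the extracted codes for later use. Output paths are illustrative.
    #   np.save('test/magi1_semantic_code.npy', semantic_code.cpu().numpy())
    #   np.save('test/magi1_rec_feat.npy', rec_feat.cpu().numpy())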