lj1995 committed on
Commit
6c168a1
·
1 Parent(s): e27e3fe

Delete feature_extractor

Browse files
feature_extractor/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from . import cnhubert, whisper_enc
2
-
3
- content_module_map = {
4
- 'cnhubert': cnhubert,
5
- 'whisper': whisper_enc
6
- }
 
 
 
 
 
 
 
feature_extractor/cnhubert.py DELETED
@@ -1,109 +0,0 @@
1
- import time
2
-
3
- import librosa
4
- import torch
5
- import torch.nn.functional as F
6
- import soundfile as sf
7
- import os
8
- from transformers import logging as tf_logging
9
- tf_logging.set_verbosity_error()
10
-
11
- import logging
12
- logging.getLogger("numba").setLevel(logging.WARNING)
13
-
14
- from transformers import (
15
- Wav2Vec2FeatureExtractor,
16
- HubertModel,
17
- )
18
-
19
- import utils
20
- import torch.nn as nn
21
-
22
- cnhubert_base_path = None
23
-
24
-
25
- class CNHubert(nn.Module):
26
- def __init__(self):
27
- super().__init__()
28
- if os.path.exists(cnhubert_base_path):...
29
- else:raise FileNotFoundError(cnhubert_base_path)
30
- self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
31
- self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
32
- cnhubert_base_path, local_files_only=True
33
- )
34
-
35
- def forward(self, x):
36
- input_values = self.feature_extractor(
37
- x, return_tensors="pt", sampling_rate=16000
38
- ).input_values.to(x.device)
39
- feats = self.model(input_values)["last_hidden_state"]
40
- return feats
41
-
42
-
43
- # class CNHubertLarge(nn.Module):
44
- # def __init__(self):
45
- # super().__init__()
46
- # self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
47
- # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
48
- # def forward(self, x):
49
- # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
50
- # feats = self.model(input_values)["last_hidden_state"]
51
- # return feats
52
- #
53
- # class CVec(nn.Module):
54
- # def __init__(self):
55
- # super().__init__()
56
- # self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
57
- # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
58
- # def forward(self, x):
59
- # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
60
- # feats = self.model(input_values)["last_hidden_state"]
61
- # return feats
62
- #
63
- # class cnw2v2base(nn.Module):
64
- # def __init__(self):
65
- # super().__init__()
66
- # self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
67
- # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
68
- # def forward(self, x):
69
- # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
70
- # feats = self.model(input_values)["last_hidden_state"]
71
- # return feats
72
-
73
-
74
- def get_model():
75
- model = CNHubert()
76
- model.eval()
77
- return model
78
-
79
-
80
- # def get_large_model():
81
- # model = CNHubertLarge()
82
- # model.eval()
83
- # return model
84
- #
85
- # def get_model_cvec():
86
- # model = CVec()
87
- # model.eval()
88
- # return model
89
- #
90
- # def get_model_cnw2v2base():
91
- # model = cnw2v2base()
92
- # model.eval()
93
- # return model
94
-
95
-
96
- def get_content(hmodel, wav_16k_tensor):
97
- with torch.no_grad():
98
- feats = hmodel(wav_16k_tensor)
99
- return feats.transpose(1, 2)
100
-
101
-
102
- if __name__ == "__main__":
103
- model = get_model()
104
- src_path = "/Users/Shared/原音频2.wav"
105
- wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
106
- model = model
107
- wav_16k_tensor = wav_16k_tensor
108
- feats = get_content(model, wav_16k_tensor)
109
- print(feats.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
feature_extractor/whisper_enc.py DELETED
@@ -1,25 +0,0 @@
1
- import torch
2
-
3
-
4
- def get_model():
5
- import whisper
6
-
7
- model = whisper.load_model("small", device="cpu")
8
-
9
- return model.encoder
10
-
11
-
12
- def get_content(model=None, wav_16k_tensor=None):
13
- from whisper import log_mel_spectrogram, pad_or_trim
14
-
15
- dev = next(model.parameters()).device
16
- mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
17
- # if torch.cuda.is_available():
18
- # mel = mel.to(torch.float16)
19
- feature_len = mel.shape[-1] // 2
20
- assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
21
- with torch.no_grad():
22
- feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
23
- :1, :feature_len, :
24
- ].transpose(1, 2)
25
- return feature