# Copyright 2024 Yiwei Guo
# Licensed under Apache 2.0

"""Extract VQ indexes using wav2vec2.0 model (from fairseq)"""

import argparse
import logging
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from kaldiio import WriteHelper
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining

logging.basicConfig(level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')

class Extractor:
    def __init__(self, checkpoint="pretrained/wav2vec2-large-lv60/", device="cuda"):
        self.device = device
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)
        model = Wav2Vec2ForPreTraining.from_pretrained(checkpoint)
        model.to(self.device)
        if self.device != "cpu":
            model.half()  # fp16 inference; fp16 convolutions are not supported on CPU
        model.eval()
        self.model = model
        self.feature_extractor = feature_extractor
        logging.info(self.model)
        # inference only: freeze all parameters
        for p in self.model.parameters():
            p.requires_grad_(False)
    
    def extract(self, wav: np.ndarray, sample_rate: int) -> torch.Tensor:
        with torch.no_grad():
            # the feature extractor accepts numpy audio directly
            input_values = self.feature_extractor(wav, return_tensors="pt", sampling_rate=sample_rate).input_values
            input_values = input_values.to(device=self.device, dtype=self.model.dtype)
            outputs = self.model.wav2vec2(input_values)
            # outputs[1] holds the CNN features that feed the quantizer
            hidden_states = self.model.dropout_features(outputs[1])
            batch_size, sequence_length, hidden_size = hidden_states.shape
            # project to (num_groups * num_vars) logits and take the argmax per
            # group: the hard code assignment of the Gumbel vector quantizer
            hidden_states = self.model.quantizer.weight_proj(hidden_states)
            hidden_states = hidden_states.view(batch_size * sequence_length * self.model.quantizer.num_groups, -1)
            codevector_idx = hidden_states.argmax(dim=-1)
            idxs = codevector_idx.view(batch_size, sequence_length, self.model.quantizer.num_groups)
        return idxs[0].cpu()  # [L, Groups]

    def get_codebook(self) -> np.ndarray:
        quantizer = self.model.quantizer
        codebook = quantizer.codevectors  # (1, 640, 384)
        codebook = codebook.view(quantizer.num_groups, quantizer.num_vars, -1)  # (2, 320, 384)
        return codebook.cpu().numpy()
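
# A minimal sketch (not part of the original script): pairing the [L, Groups]
# index matrix from extract() with get_codebook() recovers the quantized
# codevectors for one utterance:
#
#   codebook = extractor.get_codebook()        # (num_groups, num_vars, dim)
#   vecs = np.concatenate(
#       [codebook[g, idxs[:, g].numpy()] for g in range(codebook.shape[0])],
#       axis=-1,
#   )  # (L, num_groups * dim), the concatenated per-group codevectors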

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--wav-scp', type=str)
    parser.add_argument("--out-dir", type=str)
    parser.add_argument('--model', default="pretrained/wav2vec2-large-lv60/", type=str)
    args = parser.parse_args()
    
    extractor = Extractor(checkpoint=args.model, device="cuda" if torch.cuda.is_available() else "cpu")

    out_dir = Path(args.out_dir).absolute()
    out_dir.mkdir(parents=True, exist_ok=True)  # WriteHelper needs the output directory to exist
    with open(args.wav_scp, 'r') as f, torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        for line in tqdm(f.readlines()):
            uttid, wav_path = line.strip().split(maxsplit=1)
            logging.info("Extracting " + uttid)
            audio, sample_rate = sf.read(wav_path)
            idxs = extractor.extract(audio, sample_rate=sample_rate)
            # torch tensors have no .astype; convert to numpy and cast, since
            # kaldiio writes float matrices rather than integer ones
            idxs = idxs.numpy().astype(np.float64)
            writer(uttid, idxs)
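
    # Reading the features back (a sketch, using kaldiio's scp loaders):
    #   import kaldiio
    #   for uttid, mat in kaldiio.load_scp_sequential(str(out_dir / "feats.scp")):
    #       ...  # mat is the [L, Groups] index matrix stored as float64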