File size: 3,220 Bytes
95c3696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import time
from threading import Lock

import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from funasr.models.seaco_paraformer.model import SeacoParaformer

# Monkey patch to disable hotwords: replace `generate_hotwords_list` on the
# SeacoParaformer class with a no-op that accepts any arguments and always
# returns None. This runs at import time and applies to the class itself.
SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None


def load_model(*, device="cuda"):
    """Build the Chinese and English Paraformer ASR models.

    Args:
        device: Device string passed to both ``AutoModel`` constructors
            (keyword-only, defaults to ``"cuda"``).

    Returns:
        A ``(zh_model, en_model)`` tuple: the ``"paraformer-zh"`` model
        followed by the ``"paraformer-en"`` model. Both are created with
        the progress bar disabled.
    """
    # Options common to both models; only the model name differs.
    shared_options = {"device": device, "disable_pbar": True}

    zh_model = AutoModel(model="paraformer-zh", **shared_options)
    en_model = AutoModel(model="paraformer-en", **shared_options)
    return zh_model, en_model


@torch.no_grad()
def batch_asr_internal(model, audios, sr):
    """Transcribe a batch of waveforms and flag suspicious silences.

    Each audio is converted to a 1-D float tensor, resampled from ``sr`` to
    16 kHz, and the whole batch is passed to ``model.generate`` in one call.

    Args:
        model: Object exposing ``generate(input=..., batch_size=...)`` that
            returns one dict per input containing at least ``"text"`` and
            optionally ``"timestamp"``.
        audios: Iterable of waveforms, each a ``numpy.ndarray`` or a
            ``torch.Tensor``. Singleton dimensions are squeezed away, so
            ``(1, N)`` is accepted; the result must be one-dimensional.
        sr: Sample rate of the input waveforms, in Hz.

    Returns:
        A list with one dict per audio, with keys ``"text"`` (the recognized
        text), ``"duration"`` (input length in milliseconds) and
        ``"huge_gap"`` (``True`` when a long silence was detected).

    Raises:
        ValueError: If an audio is not one-dimensional after squeezing.
    """
    target_sr = 16000

    resampled_audios = []
    durations = []
    for audio in audios:
        # Convert NumPy arrays to float PyTorch tensors.
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()

        # Make sure the audio is one-dimensional.
        if audio.dim() > 1:
            audio = audio.squeeze()
        if audio.dim() != 1:
            raise ValueError(
                f"expected mono audio, got shape {tuple(audio.shape)}"
            )

        # Measure the duration on the flattened waveform: len() of the raw
        # input would count the leading axis, i.e. 1 for a (1, N) input.
        durations.append(audio.shape[0] / sr * 1000)

        # Resampling is only needed when the rates actually differ.
        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)
        resampled_audios.append(audio)

    # Nothing to transcribe; avoid calling the model with batch_size=0.
    if not resampled_audios:
        return []

    res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))

    results = []
    for r, duration in zip(res, durations):
        huge_gap = False

        # NOTE(review): timestamps are compared against a duration in
        # milliseconds, so they are presumably [start_ms, end_ms] pairs.
        # Results with two or fewer timestamps are never flagged.
        if "timestamp" in r and len(r["timestamp"]) > 2:
            timestamps = r["timestamp"]
            # A pause of more than 5 seconds between consecutive segments
            # counts as a huge gap.
            has_inner_gap = any(
                following[0] - previous[1] > 5000
                for previous, following in zip(timestamps[:-1], timestamps[1:])
            )
            # More than 3 seconds of audio after the last segment is also
            # treated as a huge gap.
            has_trailing_gap = duration - timestamps[-1][1] > 3000
            huge_gap = has_inner_gap or has_trailing_gap

        results.append(
            {
                "text": r["text"],
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results


# Module-level lock. NOTE(review): no code visible in this file acquires it;
# confirm whether external callers use it before removing.
global_lock = Lock()


def batch_asr(model, audios, sr):
    """Transcribe a batch of audios via ``batch_asr_internal``.

    All three arguments are forwarded unchanged, and the delegate's result
    is returned as-is.
    """
    transcripts = batch_asr_internal(model, audios, sr)
    return transcripts


def is_chinese(text):
    """Report whether ``text`` is Chinese.

    This is currently a stub: ``text`` is not inspected and the function
    returns ``True`` for every input.
    """
    return True


def calculate_wer(text1, text2):
    """Compute the word error rate of ``text2`` against ``text1``.

    Both strings are split on whitespace, so text without whitespace is
    compared as a single token. The rate is the word-level edit distance
    (insertions, deletions and substitutions each cost 1) divided by the
    number of words in ``text1``.

    Args:
        text1: Reference text.
        text2: Hypothesis text.

    Returns:
        The word error rate as a float. When ``text1`` contains no words the
        edit distance itself is returned (``0.0`` if ``text2`` is empty too)
        instead of dividing by zero.
    """
    reference = text1.split()
    hypothesis = text2.split()

    # Edit distance using two rows of the DP table: `previous` holds the
    # distances for the first i-1 reference words, `current` for the first i.
    previous = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        current = [i]
        for j, hyp_word in enumerate(hypothesis, start=1):
            if ref_word == hyp_word:
                current.append(previous[j - 1])
            else:
                current.append(
                    min(previous[j], current[j - 1], previous[j - 1]) + 1
                )
        previous = current

    edits = previous[-1]

    # Guard the denominator so an empty reference cannot raise
    # ZeroDivisionError.
    return edits / max(len(reference), 1)


if __name__ == "__main__":
    # Smoke test and rough benchmark: transcribe a sample file and a clip of
    # its first five seconds, then time ten repeated batch calls.
    zh_model, en_model = load_model()

    # Decode the file once and use the sample rate it reports rather than
    # assuming 44.1 kHz.
    waveform, sample_rate = torchaudio.load("lengyue.wav")
    first_channel = waveform[0]
    audios = [
        first_channel,
        first_channel[: sample_rate * 5],
    ]
    print(batch_asr(zh_model, audios, sample_rate))

    # perf_counter is a monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    for _ in range(10):
        batch_asr(zh_model, audios, sample_rate)
    print("Time taken:", time.perf_counter() - start_time)