File size: 3,046 Bytes
6efeebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys
import tempfile

# Absolute directory of this script. The directory two levels above it is
# appended to sys.path, presumably so project-local packages can be imported
# when the script is run directly (nothing visible in this file imports from it).
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm


def get_args():
    """Parse the command line.

    Returns:
        argparse.Namespace whose ``eval_file`` attribute is the path of the
        JSON-lines evaluation file (default ``evaluation.jsonl``).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--eval_file", type=str, default="evaluation.jsonl")
    return arg_parser.parse_args()


def show_image(signal: np.ndarray,
               ground_truth_probs: np.ndarray,
               prediction_probs: np.ndarray,
               sample_rate: int = 8000,
               ):
    """Show the waveform twice, once overlaid with each mask, and display it.

    The figure has two vertically stacked panels: the top one overlays
    ``ground_truth_probs`` and the bottom one overlays ``prediction_probs``.
    In both, the waveform is drawn in blue and the mask in gray.

    Args:
        signal: waveform samples; the x axis is ``index / sample_rate`` seconds.
        ground_truth_probs: values plotted on the same time axis as ``signal``
            (so expected to have the same length).
        prediction_probs: likewise, for the prediction panel.
        sample_rate: samples per second used to build the time axis.
    """
    time_axis = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))

    panels = (
        ("ground_truth", ground_truth_probs),
        ("prediction", prediction_probs),
    )
    # Two rows, one column; panel positions are 1 (top) and 2 (bottom).
    for position, (panel_title, mask) in enumerate(panels, start=1):
        plt.subplot(2, 1, position)
        plt.plot(time_axis, signal, color="b")
        plt.plot(time_axis, mask, color="gray")
        plt.title(panel_title)

    # Leave vertical room between the two panels for the titles.
    plt.subplots_adjust(hspace=0.5)

    plt.show()


def _segments_to_mask(segments, sample_rate: int, length: int) -> np.ndarray:
    """Return a float32 array of ``length`` zeros, set to 1 inside each segment.

    ``segments`` is an iterable of ``(begin, end)`` pairs. Each bound is
    multiplied by ``sample_rate`` and truncated to an integer sample index.
    Indices beyond ``length`` are clipped by slicing.
    """
    mask = np.zeros(shape=(length,), dtype=np.float32)
    for begin, end in segments:
        mask[int(begin * sample_rate):int(end * sample_rate)] = 1
    return mask


def _shift_left(probs: np.ndarray, shift: int) -> np.ndarray:
    """Shift ``probs`` left by ``shift`` samples while keeping its length.

    The vacated tail is filled with the original last ``shift`` samples. A
    non-positive ``shift`` returns ``probs`` unchanged. Without this guard,
    ``probs[0:]`` + ``probs[-0:]`` would double the length and break plotting.
    """
    if shift <= 0:
        return probs
    # np.concatenate rather than np.concat: the latter only exists in NumPy >= 2.0.
    return np.concatenate([probs[shift:], probs[-shift:]], axis=-1)


def main():
    """Plot ground-truth vs. predicted segments for each row of ``--eval_file``.

    Each non-blank line of the eval file is parsed as a JSON object. The script
    reads the keys ``filename`` (a wav path), ``duration``, ``ground_truth``,
    ``prediction``, ``accuracy``, ``precision``, ``recall`` and ``f1``.
    ``ground_truth`` and ``prediction`` are lists of ``[begin, end]`` pairs;
    they are multiplied by the wav's sample rate to get sample indices.
    """
    args = get_args()

    # Number of samples the prediction mask is shifted left by before plotting.
    # p = encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size * sample_rate
    # NOTE(review): these values are hard-coded and do not depend on each
    # file's sample rate. 80 is presumably the hop size in samples; confirm
    # against the model config.
    encoder_num_layers = 3
    encoder_kernel_size = 3
    hop_size_samples = 80
    p = int(encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size_samples)
    print(f"p: {p}")

    with open(args.eval_file, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            if not line.strip():
                continue
            row = json.loads(line)
            filename = row["filename"]
            duration = row["duration"]
            ground_truth = row["ground_truth"]
            prediction = row["prediction"]

            accuracy = row["accuracy"]
            precision = row["precision"]
            recall = row["recall"]
            f1 = row["f1"]
            print(
                f"filename: {filename}, duration: {duration}, "
                f"accuracy: {accuracy}, precision: {precision}, "
                f"recall: {recall}, f1: {f1}"
            )

            sample_rate, signal = wavfile.read(
                filename=filename,
            )
            # Scale integer samples by 2**15. This assumes 16-bit PCM; confirm
            # for other wav formats (e.g. float or 32-bit files).
            signal = np.array(signal / (1 << 15), dtype=np.float32)
            signal_length = len(signal)

            ground_truth_probs = _segments_to_mask(ground_truth, sample_rate, signal_length)
            prediction_probs = _segments_to_mask(prediction, sample_rate, signal_length)
            prediction_probs = _shift_left(prediction_probs, p)

            show_image(signal,
                       ground_truth_probs, prediction_probs,
                       sample_rate=sample_rate,
                       )


if __name__ == "__main__":
    # Run the viewer only when executed as a script, not when imported.
    main()