# Source: cc_vad/examples/evaluation/step_3_show_vad.py (author: HoneyTian, commit 6efeebe "update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys
import tempfile
# Directory containing this script; its grandparent ("../../") is appended to
# sys.path, presumably so project-local packages import when run from anywhere.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
def get_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with a single attribute ``eval_file``: path to the
        JSON-lines evaluation file (defaults to ``evaluation.jsonl``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--eval_file", type=str, default=r"evaluation.jsonl")
    return cli.parse_args()
def show_image(signal: np.ndarray,
               ground_truth_probs: np.ndarray,
               prediction_probs: np.ndarray,
               sample_rate: int = 8000,
               ):
    """Show the waveform twice, overlaid with the ground-truth and predicted tracks.

    The x axis is time in seconds (sample index / ``sample_rate``). The top
    panel overlays ``ground_truth_probs`` and the bottom panel overlays
    ``prediction_probs``; both are plotted against the same time axis as
    ``signal``, so they are expected to have the same length. Blocks in
    ``plt.show()`` until the window is closed (backend-dependent).

    Args:
        signal: waveform samples.
        ground_truth_probs: per-sample reference track (drawn in gray).
        prediction_probs: per-sample predicted track (drawn in gray).
        sample_rate: samples per second, used only to build the time axis.
    """
    time_axis = np.arange(0, len(signal)) / sample_rate

    plt.figure(figsize=(12, 5))
    panels = (
        ("ground_truth", ground_truth_probs),
        ("prediction", prediction_probs),
    )
    for position, (panel_title, track) in enumerate(panels, start=1):
        plt.subplot(2, 1, position)  # 2 rows, 1 column; `position`-th slot
        plt.plot(time_axis, signal, color="b")
        plt.plot(time_axis, track, color="gray")
        plt.title(panel_title)

    plt.subplots_adjust(hspace=0.5)  # extra vertical gap between the panels
    plt.show()
# Prediction latency compensation, in samples:
#   encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size
# NOTE(review): these values mirror the hard-coded `3 * (3 - 1) // 2 * 80` of
# the original script; presumably they must match the evaluated model's config.
_ENCODER_NUM_LAYERS = 3
_ENCODER_KERNEL_SIZE = 3
_HOP_SIZE = 80


def _to_float32(signal):
    """Convert PCM samples returned by ``wavfile.read`` to float32 in roughly [-1, 1)."""
    if signal.dtype == np.uint8:
        # 8-bit PCM is unsigned with its zero level at 128.
        return (signal.astype(np.float32) - 128.0) / 128.0
    if np.issubdtype(signal.dtype, np.integer):
        # Signed PCM: int16 -> /32768, int32 -> /2**31.
        return signal.astype(np.float32) / float(-np.iinfo(signal.dtype).min)
    # Float wavs are already normalised.
    return signal.astype(np.float32)


def _segments_to_mask(segments, sample_rate, length):
    """Rasterise ``[begin, end]`` intervals (seconds) into a 0/1 float32 mask of ``length`` samples."""
    mask = np.zeros(shape=(length,), dtype=np.float32)
    for begin, end in segments:
        # Clamp at 0 so a slightly negative start cannot wrap to the array end.
        start = max(int(begin * sample_rate), 0)
        stop = max(int(end * sample_rate), 0)
        mask[start:stop] = 1
    return mask


def _shift_left(probs, shift):
    """Advance ``probs`` by ``shift`` samples while keeping its length.

    The vacated tail is filled with the last ``shift`` samples. ``shift`` is
    clamped to ``[0, len(probs)]``; 0 (or >= len) returns an unshifted copy.
    """
    n = len(probs)
    shift = min(max(int(shift), 0), n)
    # `probs[n - shift:]` rather than `probs[-shift:]`: the latter is the whole
    # array when shift == 0, which would double the output length.
    return np.concatenate([probs[shift:], probs[n - shift:]], axis=-1)


def main():
    """Plot every evaluated file's waveform against its ground-truth and predicted VAD masks.

    Reads the JSON-lines file given by ``--eval_file``. Each non-blank row must
    contain ``filename`` (a wav path), ``ground_truth`` and ``prediction``
    (lists of ``[begin, end]`` pairs in seconds), and ``accuracy``,
    ``precision``, ``recall``, ``f1`` (printed). The prediction mask is shifted
    left by a fixed latency before plotting. One matplotlib window per row.
    """
    args = get_args()

    # Loop-invariant: depends only on the model constants above.
    p = int(_ENCODER_NUM_LAYERS * (_ENCODER_KERNEL_SIZE - 1) // 2 * _HOP_SIZE)
    print(f"p: {p}")

    with open(args.eval_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines in the jsonl file
            row = json.loads(line)
            filename = row["filename"]

            sample_rate, signal = wavfile.read(filename=filename)
            signal = _to_float32(signal)
            signal_length = len(signal)

            ground_truth_probs = _segments_to_mask(row["ground_truth"], sample_rate, signal_length)
            prediction_probs = _segments_to_mask(row["prediction"], sample_rate, signal_length)
            prediction_probs = _shift_left(prediction_probs, p)

            print(
                f"{filename}: accuracy={row['accuracy']}, precision={row['precision']}, "
                f"recall={row['recall']}, f1={row['f1']}"
            )
            show_image(signal,
                       ground_truth_probs, prediction_probs,
                       sample_rate=sample_rate,
                       )
    return


if __name__ == "__main__":
    main()