#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import sys

# Make the project root importable when this script is run directly.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

from gradio_client import Client
import numpy as np
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_set",
        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\vad",
        type=str
    )
    parser.add_argument(
        "--output_file",
        default=r"evaluation.jsonl",
        type=str
    )
    parser.add_argument("--expected_sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args
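

# Typical invocation (the paths are illustrative; --test_set must point at a
# directory that contains a vad.json annotation file):
#   python step_1_run_evaluation.py \
#       --test_set /path/to/vad/test_set \
#       --output_file evaluation.jsonl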


def get_metrics(ground_truth, predictions, total_duration, step=0.01):
    """
    Evaluation by discretizing the timeline into fixed-step points.
    :param ground_truth: list of reference intervals, [[start1, end1], [start2, end2], ...]
    :param predictions: list of predicted intervals, same format
    :param total_duration: total audio duration in seconds
    :param step: discretization step in seconds (default 10 ms)
    :return: dict of evaluation metrics
    """
    # Sample the timeline at fixed intervals.
    time_points = np.arange(0, total_duration, step)

    # Build binary label arrays (1 = speech, 0 = non-speech).
    y_true = np.zeros_like(time_points, dtype=int)
    y_pred = np.zeros_like(time_points, dtype=int)

    # Mark reference speech intervals.
    for start, end in ground_truth:
        mask = (time_points >= start) & (time_points <= end)
        y_true[mask] = 1

    # Mark predicted speech intervals.
    for start, end in predictions:
        mask = (time_points >= start) & (time_points <= end)
        y_pred[mask] = 1

    # Compute point-wise classification metrics.
    result = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }
    return result
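

# Toy example (a sketch, not part of the evaluation run): with
# ground_truth=[[0.0, 1.0]], predictions=[[0.5, 1.5]] and total_duration=2.0,
# the 10 ms grid yields 101 reference-positive points, 101 predicted-positive
# points, and an overlap of about 51 points, so precision and recall both come
# out near 0.505 (modulo floating-point boundary effects of np.arange).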


def main():
    args = get_args()

    # This script assumes a local Gradio VAD demo is already serving on port 7866.
    client = Client("http://127.0.0.1:7866/")

    test_set = Path(args.test_set)
    output_file = Path(args.output_file)

    annotation_file = test_set / "vad.json"
    with open(annotation_file.as_posix(), "r", encoding="utf-8") as f:
        annotation = json.load(f)
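
    # Assumed annotation layout, inferred from the fields read below (the
    # authoritative schema lives with the dataset, not in this script):
    # [
    #   {"filename": "relative/path.wav", "vad_segments": [[0.27, 1.93], ...]},
    #   ...
    # ]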

    total = 0
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    total_duration = 0

    progress_bar = tqdm(desc="evaluation")
    with open(output_file.as_posix(), "w", encoding="utf-8") as f:
        for row in annotation:
            filename = row["filename"]
            ground_truth_vad_segments = row["vad_segments"]
            filename = test_set / filename

            # Run VAD through the Gradio endpoint; the last element of the
            # returned tuple is a JSON message with segments and duration.
            _, _, _, message = client.predict(
                audio_file_t={
                    "path": filename.as_posix(),
                    "meta": {"_type": "gradio.FileData"}
                },
                audio_microphone_t=None,
                start_ring_rate=0.5,
                end_ring_rate=0.3,
                ring_max_length=10,
                min_silence_length=6,
                max_speech_length=100000,
                min_speech_length=15,
                # engine="fsmn-vad-by-webrtcvad-nx2-dns3",
                engine="silero-vad-by-webrtcvad-nx2-dns3",
                api_name="/when_click_vad_button"
            )
            js = json.loads(message)
            prediction_vad_segments = js["vad_segments"]
            duration = js["duration"]

            metrics = get_metrics(ground_truth_vad_segments, prediction_vad_segments, duration)
            accuracy = metrics["accuracy"]
            precision = metrics["precision"]
            recall = metrics["recall"]
            f1 = metrics["f1"]

            # Write one JSON line per evaluated file.
            row_ = {
                "filename": filename.as_posix(),
                "duration": duration,
                "ground_truth": ground_truth_vad_segments,
                "prediction": prediction_vad_segments,
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            f.write(f"{row_}\n")

            # Accumulate duration-weighted sums so that long files count
            # proportionally in the running averages.
            total += 1
            total_duration += duration
            total_accuracy += accuracy * duration
            total_precision += precision * duration
            total_recall += recall * duration
            total_f1 += f1 * duration

            average_accuracy = total_accuracy / total_duration
            average_precision = total_precision / total_duration
            average_recall = total_recall / total_duration
            average_f1 = total_f1 / total_duration

            progress_bar.update(1)
            progress_bar.set_postfix({
                "total": total,
                "accuracy": average_accuracy,
                "precision": average_precision,
                "recall": average_recall,
                "f1": average_f1,
                "total_duration": f"{round(total_duration / 60, 4)}min",
            })
    return


if __name__ == "__main__":
    main()
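

# To reproduce (a sketch of the intended workflow, assumed from the code above):
#   1. start the Gradio VAD demo so that it serves http://127.0.0.1:7866/
#   2. run this script against a test set directory containing vad.json
#   3. read per-file metrics from evaluation.jsonl; the progress bar reports
#      duration-weighted running averages, e.g. f1=0.90 over 60 s and f1=0.50
#      over 20 s average to (0.90 * 60 + 0.50 * 20) / 80 = 0.80, not 0.70.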