#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Evaluate a VAD (voice activity detection) service against an annotated test set.

For every entry of ``<test_set>/vad.json`` the audio file is sent to a Gradio
endpoint, the predicted speech segments are compared with the annotated ones
on a discretised time axis, the per-file metrics are appended to a JSONL file,
and duration-weighted running averages are shown in a progress bar.
"""
import argparse
import json
import os
from pathlib import Path
import sys

# Make modules located two directories above this script importable.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
from gradio_client import Client
import numpy as np
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
from tqdm import tqdm


# Metric keys, in the order they are written to the output rows and progress bar.
_METRIC_NAMES = ("accuracy", "precision", "recall", "f1")


def get_args():
    """Parse and return the command-line arguments of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_set",
        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\vad",
        type=str
    )
    parser.add_argument(
        "--output_file",
        default=r"evaluation.jsonl",
        type=str
    )
    parser.add_argument("--expected_sample_rate", default=8000, type=int)
    args = parser.parse_args()
    return args


def get_metrics(ground_truth, predictions, total_duration, step=0.01):
    """
    Evaluate predicted segments against reference segments on a discretised time axis.

    The interval ``[0, total_duration)`` is sampled every ``step`` seconds; a
    sample is labelled 1 when it falls inside any segment (both segment ends
    inclusive) and 0 otherwise. The two label arrays are then scored as a
    binary classification problem.

    :param ground_truth: reference segments, ``[[start1, end1], [start2, end2], ...]``
    :param predictions: predicted segments, same format as ``ground_truth``
    :param total_duration: total audio duration in seconds
    :param step: discretisation step in seconds (default 10 ms)
    :return: dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``
    """
    # Sampled time axis.
    time_points = np.arange(0, total_duration, step)

    if time_points.size == 0:
        # Zero-length (or negative-length) audio yields no samples to score;
        # report zeros instead of handing empty arrays to the sklearn scorers.
        return {name: 0.0 for name in _METRIC_NAMES}

    # Label arrays, one entry per sampled time point.
    y_true = np.zeros_like(time_points, dtype=int)
    y_pred = np.zeros_like(time_points, dtype=int)

    # Mark the reference speech segments.
    for start, end in ground_truth:
        mask = (time_points >= start) & (time_points <= end)
        y_true[mask] = 1

    # Mark the predicted speech segments.
    for start, end in predictions:
        mask = (time_points >= start) & (time_points <= end)
        y_pred[mask] = 1

    # zero_division=0 makes precision/recall/f1 return 0 when their denominator is empty.
    result = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }
    return result


def main():
    """Run the evaluation over the whole test set and write one JSON line per file."""
    args = get_args()

    client = Client("http://127.0.0.1:7866/")

    test_set = Path(args.test_set)
    output_file = Path(args.output_file)

    annotation_file = test_set / "vad.json"

    with open(annotation_file.as_posix(), "r", encoding="utf-8") as f:
        annotation = json.load(f)

    total = 0
    total_duration = 0
    # Sum of (metric * duration) per metric; divided by total_duration for the averages.
    weighted_sums = {name: 0 for name in _METRIC_NAMES}

    # The context managers guarantee the output file and the progress bar are
    # closed even if a request or a JSON decode fails midway.
    with open(output_file.as_posix(), "w", encoding="utf-8") as f, \
            tqdm(desc="evaluation", total=len(annotation)) as progress_bar:
        for row in annotation:
            audio_path = test_set / row["filename"]
            ground_truth_vad_segments = row["vad_segments"]

            # The endpoint returns four values; only the last one (a JSON string) is used.
            _, _, _, message = client.predict(
                audio_file_t={
                    "path": audio_path.as_posix(),
                    "meta": {"_type": "gradio.FileData"}
                },
                audio_microphone_t=None,
                start_ring_rate=0.5,
                end_ring_rate=0.3,
                ring_max_length=10,
                min_silence_length=6,
                max_speech_length=100000,
                min_speech_length=15,
                # engine="fsmn-vad-by-webrtcvad-nx2-dns3",
                engine="silero-vad-by-webrtcvad-nx2-dns3",
                api_name="/when_click_vad_button"
            )
            js = json.loads(message)
            prediction_vad_segments = js["vad_segments"]
            duration = js["duration"]

            metrics = get_metrics(ground_truth_vad_segments, prediction_vad_segments, duration)

            row_ = {
                "filename": audio_path.as_posix(),
                "duration": duration,
                "ground_truth": ground_truth_vad_segments,
                "prediction": prediction_vad_segments,
            }
            for name in _METRIC_NAMES:
                row_[name] = metrics[name]
            row_ = json.dumps(row_, ensure_ascii=False)
            f.write(f"{row_}\n")

            total += 1
            total_duration += duration
            for name in _METRIC_NAMES:
                weighted_sums[name] += metrics[name] * duration

            # Duration-weighted averages; while no audio time has accumulated
            # yet (e.g. only zero-length files so far) report 0 rather than
            # dividing by zero.
            if total_duration > 0:
                averages = {
                    name: weighted_sums[name] / total_duration
                    for name in _METRIC_NAMES
                }
            else:
                averages = {name: 0.0 for name in _METRIC_NAMES}

            progress_bar.update(1)
            postfix = {"total": total}
            postfix.update(averages)
            postfix["total_duration"] = f"{round(total_duration / 60, 4)}min"
            progress_bar.set_postfix(postfix)


if __name__ == "__main__":
    main()