|
|
|
|
|
import argparse |
|
import json |
|
import os |
|
from pathlib import Path |
|
import sys |
|
|
|
# Absolute path of the directory that contains this script.
pwd = os.path.abspath(os.path.dirname(__file__))

# Append the directory two levels above this script to the module search path,
# so modules living there become importable when the script is run directly.
# NOTE(review): no project-local import is visible in this file -- confirm this is still needed.
sys.path.append(os.path.join(pwd, "../../"))
|
|
|
import librosa |
|
from gradio_client import Client |
|
import numpy as np |
|
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score |
|
from tqdm import tqdm |
|
|
|
|
|
def get_args():
    """Parse the command-line options of the VAD evaluation script.

    Options:
        --test_set: directory holding ``vad.json`` and the audio files it lists.
        --output_file: path of the JSON-lines file that receives per-file results.
        --expected_sample_rate: integer sample rate option (default 8000).

    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # (flag, converter, default) for every supported option.
    option_specs = (
        ("--test_set", str, r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\vad"),
        ("--output_file", str, r"evaluation.jsonl"),
        ("--expected_sample_rate", int, 8000),
    )
    for flag, converter, default_value in option_specs:
        parser.add_argument(flag, default=default_value, type=converter)

    return parser.parse_args()
|
|
|
|
|
def get_metrics(ground_truth, predictions, total_duration, step=0.01):
    """Score predicted VAD segments against reference segments sample by sample.

    The timeline ``[0, total_duration)`` is sampled every ``step`` seconds.  A
    sample counts as speech when it falls inside any segment (both segment ends
    inclusive).  Accuracy, precision, recall and F1 are computed over these
    binary samples with speech as the positive class.

    :param ground_truth: reference segments, ``[[start1, end1], [start2, end2], ...]``
        (seconds).
    :param predictions: predicted segments, same format as ``ground_truth``.
    :param total_duration: audio length in seconds.
    :param step: sampling interval in seconds (default 0.01, i.e. 10 ms).
    :return: dict with float values for ``accuracy``, ``precision``, ``recall``
        and ``f1``.  Precision/recall/F1 are 0.0 when their denominator is zero
        (e.g. no speech predicted).  All four are 0.0 when the timeline has no
        samples (``total_duration <= 0``), so callers never receive NaN.
    """
    # Derive the number of samples from an integer instead of np.arange with a
    # float step: e.g. 0.07 / 0.01 == 7.000000000000001 would otherwise yield an
    # extra sample at t == total_duration.  The small tolerance absorbs that error.
    num_points = max(int(np.ceil(total_duration / step - 1e-9)), 0)
    if num_points == 0:
        # Empty timeline: nothing to compare.  Return zeros rather than NaN so
        # duration-weighted aggregates in the caller stay finite.
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}

    time_points = np.arange(num_points) * step

    y_true = np.zeros(num_points, dtype=bool)
    for start, end in ground_truth:
        y_true |= (time_points >= start) & (time_points <= end)

    y_pred = np.zeros(num_points, dtype=bool)
    for start, end in predictions:
        y_pred |= (time_points >= start) & (time_points <= end)

    # Confusion counts for the "speech" class.
    true_pos = int(np.count_nonzero(y_true & y_pred))
    false_pos = int(np.count_nonzero(~y_true & y_pred))
    false_neg = int(np.count_nonzero(y_true & ~y_pred))
    correct = int(np.count_nonzero(y_true == y_pred))

    # Zero denominators map to 0.0, matching sklearn's zero_division=0.
    precision_den = true_pos + false_pos
    recall_den = true_pos + false_neg
    f1_den = 2 * true_pos + false_pos + false_neg

    result = {
        "accuracy": correct / num_points,
        "precision": true_pos / precision_den if precision_den else 0.0,
        "recall": true_pos / recall_den if recall_den else 0.0,
        "f1": 2 * true_pos / f1_den if f1_den else 0.0,
    }
    return result
|
|
|
|
|
def main():
    """Evaluate a running VAD service against the annotated test set.

    Loads ``<test_set>/vad.json`` (a list of rows, each with ``filename`` and
    ``vad_segments``), sends every referenced audio file to the Gradio app at
    ``http://127.0.0.1:7866/`` (endpoint ``/when_click_vad_button``), and scores
    the returned segments with ``get_metrics``.  One JSON line per file is
    written to ``--output_file``.  The progress bar shows running averages of
    the metrics, weighted by each file's duration.
    """
    args = get_args()

    client = Client("http://127.0.0.1:7866/")

    test_set = Path(args.test_set)
    output_file = Path(args.output_file)

    annotation_file = test_set / "vad.json"
    with open(annotation_file.as_posix(), "r", encoding="utf-8") as f:
        annotation = json.load(f)

    total = 0
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    total_duration = 0

    # Context managers ensure the progress bar is finalized and the output file
    # is closed even if a request to the service raises.
    with tqdm(desc="evaluation") as progress_bar:
        with open(output_file.as_posix(), "w", encoding="utf-8") as f:
            for row in annotation:
                filename = test_set / row["filename"]
                ground_truth_vad_segments = row["vad_segments"]

                # The endpoint returns four outputs; only the last one (a JSON
                # string with "vad_segments" and "duration") is used here.
                _, _, _, message = client.predict(
                    audio_file_t={
                        "path": filename.as_posix(),
                        "meta": {"_type": "gradio.FileData"},
                    },
                    audio_microphone_t=None,
                    start_ring_rate=0.5,
                    end_ring_rate=0.3,
                    ring_max_length=10,
                    min_silence_length=6,
                    max_speech_length=100000,
                    min_speech_length=15,
                    engine="silero-vad-by-webrtcvad-nx2-dns3",
                    api_name="/when_click_vad_button",
                )
                js = json.loads(message)
                prediction_vad_segments = js["vad_segments"]
                duration = js["duration"]

                metrics = get_metrics(ground_truth_vad_segments, prediction_vad_segments, duration)
                accuracy = metrics["accuracy"]
                precision = metrics["precision"]
                recall = metrics["recall"]
                f1 = metrics["f1"]

                row_ = {
                    "filename": filename.as_posix(),
                    "duration": duration,
                    "ground_truth": ground_truth_vad_segments,
                    "prediction": prediction_vad_segments,
                    "accuracy": accuracy,
                    "precision": precision,
                    "recall": recall,
                    "f1": f1,
                }
                f.write(f"{json.dumps(row_, ensure_ascii=False)}\n")

                # Accumulate duration-weighted sums so longer files count more.
                total += 1
                total_duration += duration
                total_accuracy += accuracy * duration
                total_precision += precision * duration
                total_recall += recall * duration
                total_f1 += f1 * duration

                if total_duration > 0:
                    average_accuracy = total_accuracy / total_duration
                    average_precision = total_precision / total_duration
                    average_recall = total_recall / total_duration
                    average_f1 = total_f1 / total_duration
                else:
                    # Only zero-length files so far; avoid ZeroDivisionError.
                    average_accuracy = average_precision = average_recall = average_f1 = 0.0

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "total": total,
                    "accuracy": average_accuracy,
                    "precision": average_precision,
                    "recall": average_recall,
                    "f1": average_f1,
                    "total_duration": f"{round(total_duration / 60, 4)}min",
                })
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
|
|