import os
import re

import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

from segmentation import segment_batchalign
from segmentation import segment_SaT
from segmentation import segment_SaT_cunit_3l
from segmentation import segment_SaT_cunit_12l
from segmentation import segment_SaT_cunit_3l_r32a64
from segmentation import segment_SaT_cunit_3l_r64a128
from segmentation import segment_SaT_cunit_3l_no_shuffle


def clean_text(text):
    # Lowercase, strip punctuation, and trim surrounding whitespace.
    return re.sub(r"[^\w\s]", "", text.lower()).strip()


def eval_segmentation(dataset_path, segmentation_model, model_name="unknown", chunk_num=10):
    os.makedirs("benchmark_result/segmentation", exist_ok=True)
    df = pd.read_csv(dataset_path)

    results = []
    for i in tqdm(range(0, len(df), chunk_num), desc="Evaluating chunks"):
        chunk = df.iloc[i:i + chunk_num]
        if len(chunk) < chunk_num:
            continue

        # Concatenate the chunk's utterances into one word sequence and build the
        # gold boundary labels: 1 marks the final word of each utterance, 0 otherwise.
        word_sequence = []
        gt_label_sequence = []
        for row in chunk["cleaned_transcription"]:
            if pd.isna(row):
                continue
            cleaned = clean_text(row)
            words = cleaned.split()
            if not words:
                continue
            word_sequence.extend(words)
            gt_label_sequence.extend([0] * (len(words) - 1) + [1])

        input_text = " ".join(word_sequence)
        predicted_labels = segmentation_model(input_text)

        if len(predicted_labels) != len(gt_label_sequence):
            print(f"Label length mismatch at chunk {i}. Skipping...")
            continue

        results.append({
            "word_sequence": input_text,
            "gt_label_sequence": " ".join(map(str, gt_label_sequence)),
            "predict_label_sequence": " ".join(map(str, predicted_labels))
        })

    result_df = pd.DataFrame(results)
    result_df.to_csv(f"benchmark_result/segmentation/{model_name}_results.csv", index=False)

    # Flatten the per-chunk label sequences into single lists for corpus-level scoring.
    all_gt = []
    all_pred = []
    for row in results:
        all_gt.extend(map(int, row["gt_label_sequence"].split()))
        all_pred.extend(map(int, row["predict_label_sequence"].split()))

    tp = sum((g == 1 and p == 1) for g, p in zip(all_gt, all_pred))
    fp = sum((g == 0 and p == 1) for g, p in zip(all_gt, all_pred))
    fn = sum((g == 1 and p == 0) for g, p in zip(all_gt, all_pred))

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_gt, all_pred, average='binary', zero_division=0
    )

    print(f"{model_name} - TP: {tp}, FP: {fp}, FN: {fn}")
    print(f"{model_name} - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")
    return precision, recall, f1


if __name__ == "__main__":
    dataset_path = "./data/enni_salt_for_segmentation/test.csv"

    # print("Evaluating BatchAlign segmentation model...")
    # batchalign_precision, batchalign_recall, batchalign_f1 = eval_segmentation(
    #     dataset_path, segment_batchalign, "batchalign"
    # )

    print("\nEvaluating SaT segmentation model...")
    sat_precision, sat_recall, sat_f1 = eval_segmentation(
        dataset_path, segment_SaT, "SaT"
    )

    print("\nEvaluating SaT_cunit_3l segmentation model...")
    sat_cunit_3l_precision, sat_cunit_3l_recall, sat_cunit_3l_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l, "SaT_cunit_3l"
    )

    print("\nEvaluating SaT_cunit_12l segmentation model...")
    sat_cunit_12l_precision, sat_cunit_12l_recall, sat_cunit_12l_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_12l, "SaT_cunit_12l"
    )

    print("\nEvaluating SaT_cunit_3l_r32a64 segmentation model...")
    sat_cunit_3l_r32a64_precision, sat_cunit_3l_r32a64_recall, sat_cunit_3l_r32a64_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l_r32a64, "SaT_cunit_3l_r32a64"
    )

    print("\nEvaluating SaT_cunit_3l_r64a128 segmentation model...")
    sat_cunit_3l_r64a128_precision, sat_cunit_3l_r64a128_recall, sat_cunit_3l_r64a128_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l_r64a128, "SaT_cunit_3l_r64a128"
    )

    print("\nEvaluating SaT_cunit_3l_no_shuffle segmentation model...")
    sat_cunit_3l_no_shuffle_precision, sat_cunit_3l_no_shuffle_recall, sat_cunit_3l_no_shuffle_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l_no_shuffle, "SaT_cunit_3l_no_shuffle"
    )

    print("\n" + "=" * 80)
    print("COMPARISON RESULTS:")
    print("=" * 80)
    # print(f"BatchAlign - Precision: {batchalign_precision:.3f}, Recall: {batchalign_recall:.3f}, F1: {batchalign_f1:.3f}")
    print(f"SaT - Precision: {sat_precision:.3f}, Recall: {sat_recall:.3f}, F1: {sat_f1:.3f}")
    print(f"SaT_cunit_3l - Precision: {sat_cunit_3l_precision:.3f}, Recall: {sat_cunit_3l_recall:.3f}, F1: {sat_cunit_3l_f1:.3f}")
    print(f"SaT_cunit_12l - Precision: {sat_cunit_12l_precision:.3f}, Recall: {sat_cunit_12l_recall:.3f}, F1: {sat_cunit_12l_f1:.3f}")
    print(f"SaT_cunit_3l_r32a64 - Precision: {sat_cunit_3l_r32a64_precision:.3f}, Recall: {sat_cunit_3l_r32a64_recall:.3f}, F1: {sat_cunit_3l_r32a64_f1:.3f}")
    print(f"SaT_cunit_3l_r64a128 - Precision: {sat_cunit_3l_r64a128_precision:.3f}, Recall: {sat_cunit_3l_r64a128_recall:.3f}, F1: {sat_cunit_3l_r64a128_f1:.3f}")
    print(f"SaT_cunit_3l_no_shuffle - Precision: {sat_cunit_3l_no_shuffle_precision:.3f}, Recall: {sat_cunit_3l_no_shuffle_recall:.3f}, F1: {sat_cunit_3l_no_shuffle_f1:.3f}")