import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp
from train import *
import argparse
from utils.config import Config
from tools.mask_convert import mask_save
import numpy as np  # [PR] for histogram-based PR accumulation
import csv


# =========================== [PR] Utilities BEGIN ===========================
class PRHistogram:
    # Memory-friendly PR accumulator. Call update(probs, mask) repeatedly inside
    # your test loop, then call export_csv(path) after the loop.
    # - probs: torch.Tensor in [0,1], shape [B,H,W], "change" probability
    # - mask: torch.Tensor of 0/1 (or 0/255), shape [B,H,W]
    def __init__(self, nbins: int = 1000):
        import numpy as _np
        self.nbins = int(nbins)
        self.pos_hist = _np.zeros(self.nbins, dtype=_np.int64)
        self.neg_hist = _np.zeros(self.nbins, dtype=_np.int64)
        self.bin_edges = _np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        import numpy as _np
        p = probs.detach().float().cpu().numpy().ravel()
        g = (mask.detach().cpu().numpy().ravel() > 0).astype(_np.uint8)
        pos_counts, _ = _np.histogram(p[g == 1], bins=self.bin_edges)
        neg_counts, _ = _np.histogram(p[g == 0], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        import numpy as _np
        # Cumulative sums from the highest threshold down to the lowest give TP/FP per threshold
        pos_cum = _np.cumsum(self.pos_hist[::-1])
        neg_cum = _np.cumsum(self.neg_hist[::-1])
        TP = pos_cum
        FP = neg_cum
        FN = self.pos_hist.sum() - TP
        TN = None  # TN is not needed for the curve
        denom_prec = _np.maximum(TP + FP, 1)
        denom_rec = _np.maximum(TP + FN, 1)
        precision = TP / denom_prec
        recall = TP / denom_rec
        # F1 = 2PR/(P+R)
        denom_f1 = _np.maximum(precision + recall, 1e-12)
        f1 = 2.0 * precision * recall / denom_f1
        # IoU = TP / (TP + FP + FN)
        denom_iou = _np.maximum(TP + FP + FN, 1)
        iou = TP / denom_iou
        thresholds = self.bin_edges[::-1][1:]  # threshold sequence matching the cumulative direction above
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
        import numpy as _np, os as _os
        _os.makedirs(_os.path.dirname(save_path), exist_ok=True)
        _np.savetxt(
            save_path,
            _np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments=""
        )
        return save_path


# Global PR object (created when needed)
_PR = None


def pr_init(nbins: int = 1000):
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=nbins)
    return _PR
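
# [PR] Minimal usage sketch (illustrative only; the test loop in __main__ below does the
# same thing via pr_init / pr_update_from_outputs / pr_export). `loader`, `model`, and the
# output path here are placeholders, not names defined in this repo:
#
#     pr = pr_init(nbins=1000)
#     for images, masks in loader:
#         probs = torch.sigmoid(model(images)).squeeze(1)  # [B, H, W] probabilities in [0, 1]
#         pr.update(probs, masks)                          # masks: 0/1 (or 0/255) ground truth
#     pr.export_csv("outputs/pr_curve.csv")                # columns: threshold, precision, recall, f1, iou, TP, FP, FN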
def pr_update_from_outputs(raw_predictions, mask, cfg):
    # Try to derive probs ∈ [0,1] from the various model outputs in this repo.
    # This covers:
    # - cfg.argmax=True: 2-channel logits -> softmax class-1 probability
    # - single-channel logits -> sigmoid
    # - net == 'maskcd' (list/tuple outputs)
    # Modify here if your network has a special head.
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)
    if getattr(cfg, 'argmax', False):
        logits = raw_predictions
        if logits.dim() == 4 and logits.size(1) >= 2:
            probs = torch.softmax(logits, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(logits.squeeze(1))
    else:
        if getattr(cfg, 'net', '') == 'maskcd':
            if isinstance(raw_predictions, (list, tuple)):
                logits = raw_predictions[0]
            else:
                logits = raw_predictions
            probs = torch.sigmoid(logits).squeeze(1)
        else:
            logits = raw_predictions
            if logits.dim() == 4 and logits.size(1) == 1:
                logits = logits.squeeze(1)
            probs = torch.sigmoid(logits)
    if mask.dim() == 4 and mask.size(1) == 1:
        mask_ = mask.squeeze(1)
    else:
        mask_ = mask
    _PR.update(probs, (mask_ > 0).to(probs.dtype))


def pr_export(base_dir: str, cfg):
    # Export the PR-curve CSV to base_dir/pr_<net>.csv
    global _PR
    if _PR is None:
        return None
    save_path = os.path.join(base_dir, f"pr_{getattr(cfg, 'net', 'model')}.csv")
    out = _PR.export_csv(save_path)
    print(f"[PR] saved: {out}")
    return out
# ============================ [PR] Utilities END ============================


# -------------------- [Per-Image] per-image metric utilities --------------------
def _safe_div(a, b, eps=1e-12):
    return a / max(b, eps)


def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """
    pred_np, gt_np: binary (0/1) numpy arrays, shape [H, W]
    Returns: dict with TP/FP/TN/FN and the derived metrics.
    """
    pred_bin = (pred_np > 0).astype(np.uint8)
    gt_bin = (gt_np > 0).astype(np.uint8)
    TP = int(((pred_bin == 1) & (gt_bin == 1)).sum())
    FP = int(((pred_bin == 1) & (gt_bin == 0)).sum())
    TN = int(((pred_bin == 0) & (gt_bin == 0)).sum())
    FN = int(((pred_bin == 0) & (gt_bin == 1)).sum())
    precision = _safe_div(TP, (TP + FP))
    recall = _safe_div(TP, (TP + FN))
    f1 = _safe_div(2 * precision * recall, (precision + recall))
    iou = _safe_div(TP, (TP + FP + FN))
    oa = _safe_div(TP + TN, (TP + TN + FP + FN))
    return {
        "TP": TP, "FP": FP, "TN": TN, "FN": FN,
        "OA": oa, "Precision": precision, "Recall": recall,
        "F1": f1, "IoU": iou
    }
# --------------------------------------------------------------------


def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    # New: tables-only mode (no visualization images are exported)
    parser.add_argument("--tables-only", action="store_true",
                        help="Only generate tables and CSVs (overall table, per-image CSV, per-image TXT, "
                             "PR-curve CSV); do not generate mask visualization images")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)
    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    assert ckpt is not None
    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)
    # Output directory for visualization images (used only when images are written)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Output directory for per-image tables (.txt); with --tables-only they go into tables_only instead
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2
    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []  # image-writing tasks, used only when visualizations are generated
    per_image_rows = []  # [Per-Image] collected per-image metrics
    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        # === [PR] call 1: initialize the accumulator ===
        pr_init(nbins=1000)
        for input in tqdm(test_loader):
            raw_predictions, mask, img_id = model(input[0].cuda(), input[1].cuda()), input[2].cuda(), input[3]
            # === [PR] call 2: accumulate this batch ===
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter balances precision and recall to obtain a higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            # ====== accumulate overall test metrics ======
            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # ====== [Per-Image] compute and collect per-image metrics ======
            for i in range(pred.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])
                # per-image statistics
                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({
                    "img_id": mask_name,
                    "TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
                    "OA": stats["OA"], "Precision": stats["Precision"],
                    "Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
                })
                # collect image-writing tasks only when visualization images are requested
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    # ====== print overall metrics ======
    metrics = [test_prec.compute(), test_recall.compute(), test_f1.compute(), test_iou.compute()]
    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]
    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)
    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)
    print(result_table)

    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    # ====== write visualization images or skip, depending on the mode ======
    if not args.tables_only:
        if not os.path.exists(masks_output_dir):
            os.makedirs(masks_output_dir)
        print(masks_output_dir)
        # write images with a multiprocessing pool
        t0 = time.time()
        with mpp.Pool(processes=mp.cpu_count()) as pool:
            pool.map(mask_save, results)
        t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: skipping mask visualization images, exporting tables/CSVs only.")

    # ====== [Per-Image] write all per-image metrics into one combined CSV ======
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg, 'net', 'model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id", "TP", "FP", "TN", "FN", "OA", "Precision", "Recall", "F1", "IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA", "Precision", "Recall", "F1", "IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # ====== [Per-Image] write a small table (.txt) for each image ======
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        # confusion-matrix counts first
        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])
        # then the ratio metrics
        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision", f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # ===== [PR] Export at program end =====
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")
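
# Example invocations (illustrative; assumes this script is saved as test.py, and the
# checkpoint/output paths below are placeholders, the config path is the argparse default):
#
#     # full run: overall table, per-image CSV/TXT, PR-curve CSV, and mask visualizations
#     python test.py -c configs/cdlama.py --ckpt path/to/checkpoint.ckpt
#
#     # tables-only run: skip mask_rgb images; tables go under <output_dir>/tables_only
#     python test.py -c configs/cdlama.py --ckpt path/to/checkpoint.ckpt --tables-only --output_dir results/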