import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp
from train import *
import argparse
from utils.config import Config
from tools.mask_convert import mask_save
import numpy as np # [PR] for histogram-based PR accumulation
import csv
# =========================== [PR] Utilities BEGIN ===========================
class PRHistogram:
    # Memory-friendly PR accumulator. Call update(probs, mask) repeatedly inside
    # your test loop, then call export_csv(path) after the loop.
    # - probs: torch.Tensor in [0,1], shape [B,H,W], "change" probability
    # - mask: torch.Tensor of 0/1 (or 0/255), shape [B,H,W]
    def __init__(self, nbins: int = 1000):
        self.nbins = int(nbins)
        self.pos_hist = np.zeros(self.nbins, dtype=np.int64)
        self.neg_hist = np.zeros(self.nbins, dtype=np.int64)
        self.bin_edges = np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        p = probs.detach().float().cpu().numpy().ravel()
        g = (mask.detach().cpu().numpy().ravel() > 0).astype(np.uint8)
        pos_counts, _ = np.histogram(p[g == 1], bins=self.bin_edges)
        neg_counts, _ = np.histogram(p[g == 0], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts
    def compute_curve(self):
        # Accumulate TP/FP going from the highest threshold down to the lowest
        pos_cum = np.cumsum(self.pos_hist[::-1])
        neg_cum = np.cumsum(self.neg_hist[::-1])
        TP = pos_cum
        FP = neg_cum
        FN = self.pos_hist.sum() - TP
        TN = None  # TN is not needed for the curve
        denom_prec = np.maximum(TP + FP, 1)
        denom_rec = np.maximum(TP + FN, 1)
        precision = TP / denom_prec
        recall = TP / denom_rec
        # F1 = 2PR/(P+R)
        denom_f1 = np.maximum(precision + recall, 1e-12)
        f1 = 2.0 * precision * recall / denom_f1
        # IoU = TP / (TP + FP + FN)
        denom_iou = np.maximum(TP + FP + FN, 1)
        iou = TP / denom_iou
        thresholds = self.bin_edges[::-1][1:]  # threshold sequence matching the cumulative direction above
        return thresholds, precision, recall, f1, iou, TP, FP, FN
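
    # [Note] Worked micro-example of the reversed-cumsum logic above (illustrative values):
    # with nbins = 4, bin_edges = [0, .25, .5, .75, 1], pos_hist = [0, 1, 0, 2] and
    # neg_hist = [3, 1, 0, 0], the thresholds run [.75, .5, .25, 0] and
    # TP = cumsum([2, 0, 1, 0]) = [2, 2, 3, 3], FP = cumsum([0, 0, 1, 3]) = [0, 0, 1, 4];
    # e.g. at threshold .75 only the two positive pixels with prob >= .75 count as detections.
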
    def export_csv(self, save_path: str):
        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        np.savetxt(
            save_path,
            np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments=""
        )
        return save_path
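

# [Example] Minimal usage sketch for PRHistogram with random tensors; it is not called
# anywhere in this script, and the shapes and CSV path are illustrative assumptions only.
def _pr_histogram_example(csv_path=os.path.join("pr_demo", "pr_example.csv")):
    hist = PRHistogram(nbins=100)
    for _ in range(3):                               # stand-in for the test loop
        probs = torch.rand(2, 32, 32)                # [B, H, W] "change" probabilities in [0, 1]
        mask = (torch.rand(2, 32, 32) > 0.5).long()  # [B, H, W] 0/1 ground-truth mask
        hist.update(probs, mask)
    return hist.export_csv(csv_path)                 # writes threshold/precision/recall/F1/IoU columns
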
# Global PR object (create when needed)
_PR = None
def pr_init(nbins: int = 1000):
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=nbins)
    return _PR
def pr_update_from_outputs(raw_predictions, mask, cfg):
    # Try to derive probs ∈ [0,1] from various model outputs in this repo.
    # This covers:
    #   - cfg.argmax=True: 2-channel logits -> softmax class-1 prob
    #   - single-channel logits -> sigmoid
    #   - net == 'maskcd' (list/tuple outputs)
    # Modify here if your network has a special head.
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)
    if getattr(cfg, 'argmax', False):
        logits = raw_predictions
        if logits.dim() == 4 and logits.size(1) >= 2:
            probs = torch.softmax(logits, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(logits.squeeze(1))
    else:
        if getattr(cfg, 'net', '') == 'maskcd':
            if isinstance(raw_predictions, (list, tuple)):
                logits = raw_predictions[0]
            else:
                logits = raw_predictions
            probs = torch.sigmoid(logits).squeeze(1)
        else:
            logits = raw_predictions
            if logits.dim() == 4 and logits.size(1) == 1:
                logits = logits.squeeze(1)
            probs = torch.sigmoid(logits)
    if mask.dim() == 4 and mask.size(1) == 1:
        mask_ = mask.squeeze(1)
    else:
        mask_ = mask
    _PR.update(probs, (mask_ > 0).to(probs.dtype))
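

# [Example] A small, hypothetical sanity check for pr_update_from_outputs; it is not
# called anywhere. The SimpleNamespace cfg only mimics the attributes read above
# (argmax, net), and the tensor shapes are illustrative assumptions.
def _pr_update_example():
    from types import SimpleNamespace
    dummy_cfg = SimpleNamespace(argmax=True, net='demo')  # 2-channel logits -> softmax path
    logits = torch.randn(2, 2, 64, 64)                    # [B, C=2, H, W] raw logits
    gt = (torch.rand(2, 64, 64) > 0.5).long()             # [B, H, W] 0/1 ground truth
    pr_init(nbins=1000)
    pr_update_from_outputs(logits, gt, dummy_cfg)
    thresholds, precision, recall, *_ = _PR.compute_curve()
    return thresholds, precision, recall
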
def pr_export(base_dir: str, cfg):
    # Export PR CSV to base_dir/pr_<net>.csv
    global _PR
    if _PR is None:
        return None
    save_path = os.path.join(base_dir, f"pr_{getattr(cfg, 'net', 'model')}.csv")
    out = _PR.export_csv(save_path)
    print(f"[PR] saved: {out}")
    return out
# ============================ [PR] Utilities END ============================
# -------------------- [Per-Image] per-image metric helpers --------------------
def _safe_div(a, b, eps=1e-12):
    return a / max(b, eps)
def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
"""
pred_np, gt_np: 0/1 二值 numpy 数组, shape [H,W]
返回: dict 包含 TP/FP/TN/FN 与各类指标
"""
pred_bin = (pred_np > 0).astype(np.uint8)
gt_bin = (gt_np > 0).astype(np.uint8)
TP = int(((pred_bin == 1) & (gt_bin == 1)).sum())
FP = int(((pred_bin == 1) & (gt_bin == 0)).sum())
TN = int(((pred_bin == 0) & (gt_bin == 0)).sum())
FN = int(((pred_bin == 0) & (gt_bin == 1)).sum())
precision = _safe_div(TP, (TP + FP))
recall = _safe_div(TP, (TP + FN))
f1 = _safe_div(2 * precision * recall, (precision + recall))
iou = _safe_div(TP, (TP + FP + FN))
oa = _safe_div(TP + TN, (TP + TN + FP + FN))
return {
"TP": TP, "FP": FP, "TN": TN, "FN": FN,
"OA": oa, "Precision": precision, "Recall": recall, "F1": f1, "IoU": iou
}
# --------------------------------------------------------------------
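

# [Example] A tiny, hypothetical check of per_image_stats on a 2x2 case; it is not
# called anywhere. With one TP, one FP, one FN and one TN, the expected values are
# Precision = Recall = 0.5, F1 = 0.5, IoU = 1/3 and OA = 0.5.
def _per_image_stats_example():
    pred = np.array([[1, 1],
                     [0, 0]], dtype=np.uint8)
    gt = np.array([[1, 0],
                   [1, 0]], dtype=np.uint8)
    return per_image_stats(pred, gt)
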
def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    # New: tables-only mode (no visualization images are exported)
    parser.add_argument("--tables-only", action="store_true",
                        help="Only generate tables and CSV files (overall table, per-image CSV, "
                             "per-image TXT, PR-curve CSV); do not generate mask visualization images")
    return parser.parse_args()
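

# [Example] Typical invocations (illustrative; the checkpoint path below is an assumption
# and must point to a real file, while configs/cdlama.py is the default config):
#   python <this script> -c configs/cdlama.py --ckpt <path/to/checkpoint.ckpt>
#   python <this script> -c configs/cdlama.py --ckpt <path/to/checkpoint.ckpt> --tables-only
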
if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)
    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    assert ckpt is not None
    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)
    # Output directory for the mask images (used only when images are written)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Output directory for the per-image tables (.txt); with --tables-only they go under tables_only instead
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)
    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()
    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2
    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')
    results = []          # used only when visualization images are written
    per_image_rows = []   # [Per-Image] collect per-image metrics
    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        # === Call 1: initialize the PR accumulator ===
        pr_init(nbins=1000)
        for batch in tqdm(test_loader):
            raw_predictions, mask, img_id = model(batch[0].cuda(), batch[1].cuda()), batch[2].cuda(), batch[3]
            # === Call 2: update the PR accumulator ===
            pr_update_from_outputs(raw_predictions, mask, cfg)
            if cfg.net == 'SARASNet':
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter trades off precision against recall to obtain a higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param
            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5
            # ====== Accumulate overall test metrics ======
            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)
            # ====== [Per-Image] compute and collect per-image metrics ======
            for i in range(raw_predictions.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])
                # per-image statistics
                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({
                    "img_id": mask_name,
                    "TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
                    "OA": stats["OA"], "Precision": stats["Precision"],
                    "Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
                })
                # Collect image-writing tasks only when visualizations are requested
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))
    # ====== Print the overall metrics ======
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]
    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]
    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)
    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)
    print(result_table)
    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')
    # ====== Write images or not, depending on the mode ======
    if not args.tables_only:
        if not os.path.exists(masks_output_dir):
            os.makedirs(masks_output_dir)
        print(masks_output_dir)
        # write the masks with a multiprocessing pool
        t0 = time.time()
        mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
        t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: skipping mask visualization; exporting tables/CSV only.")
    # ====== [Per-Image] write all per-image metrics into a single CSV ======
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg, 'net', 'model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id", "TP", "FP", "TN", "FN", "OA", "Precision", "Recall", "F1", "IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA", "Precision", "Recall", "F1", "IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")
    # ====== [Per-Image] write a small table (.txt) for each image ======
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        # confusion-matrix counts first
        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])
        # then the ratio metrics
        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision", f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")
    # ===== [PR] Export at program end =====
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")