# File size: 13,644 Bytes
# Revision: 032c113
import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp
from train import *
import argparse
from utils.config import Config
from tools.mask_convert import mask_save
import numpy as np # [PR] for histogram-based PR accumulation
import csv
# =========================== [PR] Utilities BEGIN ===========================
class PRHistogram:
    """Memory-friendly precision/recall accumulator backed by two histograms.

    Call :meth:`update` repeatedly inside the test loop, then call
    :meth:`export_csv` (or :meth:`compute_curve`) once after the loop. Only
    per-bin counts are stored, so memory use depends on ``nbins`` and not on
    the number of pixels seen.
    """

    def __init__(self, nbins: int = 1000):
        """Create an empty accumulator with ``nbins`` equal-width bins on [0, 1]."""
        import numpy as _np

        self.nbins = int(nbins)
        # Per-bin counts of scores whose ground truth is positive / negative.
        self.pos_hist = _np.zeros(self.nbins, dtype=_np.int64)
        self.neg_hist = _np.zeros(self.nbins, dtype=_np.int64)
        # nbins + 1 equally spaced edges from 0.0 to 1.0.
        self.bin_edges = _np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        """Add one batch of scores to the histograms.

        Args:
            probs: torch.Tensor of "change" scores. They are binned against
                edges spanning [0, 1].
            mask: torch.Tensor of ground truth; every value > 0 is treated as
                positive, so both 0/1 and 0/255 masks work. Must contain the
                same number of elements as ``probs``.

        Raises:
            ValueError: if ``probs`` and ``mask`` differ in element count.
        """
        import numpy as _np

        p = probs.detach().float().cpu().numpy().ravel()
        g = mask.detach().cpu().numpy().ravel() > 0
        if p.shape != g.shape:
            # Fail with a readable message instead of numpy's boolean-index error.
            raise ValueError(
                f"probs has {p.size} elements but mask has {g.size}; "
                "they must describe the same pixels"
            )
        pos_counts, _ = _np.histogram(p[g], bins=self.bin_edges)
        neg_counts, _ = _np.histogram(p[~g], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        """Turn the accumulated histograms into threshold-indexed metrics.

        Returns:
            Tuple ``(thresholds, precision, recall, f1, iou, TP, FP, FN)`` of
            arrays with ``nbins`` entries each, ordered from the highest
            threshold to the lowest. Entry ``k`` counts as predicted positive
            every score that fell into one of the top ``k + 1`` bins.
            Ratios whose denominator is zero are reported as 0.
        """
        import numpy as _np

        # Cumulative sums over the reversed histograms walk from the highest
        # bin downwards, giving TP/FP for progressively lower thresholds.
        TP = _np.cumsum(self.pos_hist[::-1])
        FP = _np.cumsum(self.neg_hist[::-1])
        FN = self.pos_hist.sum() - TP

        # Clamp the integer denominators to 1 so empty bins yield 0, not NaN.
        precision = TP / _np.maximum(TP + FP, 1)
        recall = TP / _np.maximum(TP + FN, 1)
        # F1 = 2PR / (P + R)
        f1 = 2.0 * precision * recall / _np.maximum(precision + recall, 1e-12)
        # IoU = TP / (TP + FP + FN)
        iou = TP / _np.maximum(TP + FP + FN, 1)

        # Lower edge of each bin, in the same high-to-low order as the sums.
        thresholds = self.bin_edges[::-1][1:]
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        """Write the curve returned by :meth:`compute_curve` to a CSV file.

        Args:
            save_path: Destination file. Its parent directory is created when
                missing; a bare filename is written to the current directory.

        Returns:
            ``save_path`` unchanged.
        """
        import os as _os

        import numpy as _np

        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()

        # os.makedirs("") raises FileNotFoundError, so only create a parent
        # directory when the path actually has one.
        parent = _os.path.dirname(save_path)
        if parent:
            _os.makedirs(parent, exist_ok=True)

        _np.savetxt(
            save_path,
            _np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments="",
        )
        return save_path
# Global PR object (create when needed)
_PR = None


def pr_init(nbins: int = 1000):
    """Return the module-wide PR accumulator, creating it on first use.

    Args:
        nbins: Bin count handed to ``PRHistogram`` when the accumulator is
            created. It has no effect if an accumulator already exists.

    Returns:
        The shared ``PRHistogram`` instance stored in ``_PR``.
    """
    global _PR
    if _PR is not None:
        return _PR
    _PR = PRHistogram(nbins=nbins)
    return _PR
def pr_update_from_outputs(raw_predictions, mask, cfg):
    """Derive "change" scores from a model output and feed the PR accumulator.

    The conversion depends on the config:
      - ``cfg.argmax`` truthy: a 4-D output with at least two channels goes
        through softmax and channel 1 is kept; anything else is squeezed on
        dim 1 and passed through sigmoid.
      - ``cfg.net == 'maskcd'``: the first element of a list/tuple output (or
        the output itself) goes through sigmoid, then dim 1 is squeezed.
      - otherwise: a 4-D single-channel output is squeezed on dim 1, then
        passed through sigmoid.
    Adjust this function if a network has a different head.

    NOTE(review): sigmoid is applied here, while the test loop thresholds the
    same non-argmax outputs at 0.5 directly. If those outputs are already
    probabilities, the scores recorded here are squashed twice; confirm
    against the model definitions.

    Args:
        raw_predictions: Model output (tensor, or list/tuple for maskcd).
        mask: Ground-truth tensor; values > 0 count as positive.
        cfg: Config object; ``argmax`` and ``net`` are read with defaults.
    """
    import torch

    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)

    if getattr(cfg, 'argmax', False):
        multi_channel = raw_predictions.dim() == 4 and raw_predictions.size(1) >= 2
        if multi_channel:
            probs = torch.softmax(raw_predictions, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(raw_predictions.squeeze(1))
    elif getattr(cfg, 'net', '') == 'maskcd':
        # maskcd may return several outputs; only the first one is scored.
        head = raw_predictions[0] if isinstance(raw_predictions, (list, tuple)) else raw_predictions
        probs = torch.sigmoid(head).squeeze(1)
    else:
        head = raw_predictions
        if head.dim() == 4 and head.size(1) == 1:
            head = head.squeeze(1)
        probs = torch.sigmoid(head)

    # Drop a singleton channel dimension from the mask when present.
    target = mask.squeeze(1) if (mask.dim() == 4 and mask.size(1) == 1) else mask
    _PR.update(probs, (target > 0).to(probs.dtype))
def pr_export(base_dir: str, cfg):
    """Write the accumulated PR curve to ``<base_dir>/pr_<net>.csv``.

    ``<net>`` is ``cfg.net`` when the config has that attribute and
    ``"model"`` otherwise.

    Returns:
        Whatever the accumulator's ``export_csv`` returns, or ``None`` when no
        accumulator has been created.
    """
    import os

    global _PR
    if _PR is not None:
        net_name = getattr(cfg, 'net', 'model')
        target = os.path.join(base_dir, f"pr_{net_name}.csv")
        written = _PR.export_csv(target)
        print(f"[PR] saved: {written}")
        return written
    return None
# ============================ [PR] Utilities END ============================
# -------------------- [Per-Image] per-image metric utilities --------------------
def _safe_div(a, b, eps=1e-12):
    """Divide ``a`` by ``b``, substituting ``eps`` when ``eps`` exceeds ``b``."""
    denominator = eps if eps > b else b
    return a / denominator


def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """Compute confusion counts and derived metrics for a single image.

    Args:
        pred_np: Predicted mask; every value > 0 counts as foreground.
        gt_np: Ground-truth mask; every value > 0 counts as foreground.

    Returns:
        dict with integer ``TP``/``FP``/``TN``/``FN`` and float ``OA``,
        ``Precision``, ``Recall``, ``F1`` and ``IoU``. A ratio whose
        denominator is zero comes out as 0.0.
    """
    pred_fg = pred_np > 0
    gt_fg = gt_np > 0
    pred_bg = ~pred_fg
    gt_bg = ~gt_fg

    tp = int(np.count_nonzero(pred_fg & gt_fg))
    fp = int(np.count_nonzero(pred_fg & gt_bg))
    tn = int(np.count_nonzero(pred_bg & gt_bg))
    fn = int(np.count_nonzero(pred_bg & gt_fg))

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)

    stats = {"TP": tp, "FP": fp, "TN": tn, "FN": fn}
    stats["OA"] = _safe_div(tp + tn, tp + tn + fp + fn)
    stats["Precision"] = precision
    stats["Recall"] = recall
    stats["F1"] = _safe_div(2 * precision * recall, precision + recall)
    stats["IoU"] = _safe_div(tp, tp + fp + fn)
    return stats
# --------------------------------------------------------------------
def get_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace with ``config``, ``ckpt``, ``output_dir`` and
        ``tables_only``.
    """
    # The text must be passed as ``description=``: the first positional
    # parameter of ArgumentParser is ``prog``, so passing it positionally put
    # the whole sentence into the usage line as the program name.
    parser = argparse.ArgumentParser(
        description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    # Tables-only mode: export tables/CSVs without the visualisation images.
    parser.add_argument("--tables-only", action="store_true",
                        help="仅生成表格与CSV(总体表、逐图CSV、逐图TXT、小计PR曲线CSV),不生成mask可视化图片")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Test entry point: load a checkpoint, run the test set, then export the
    # overall metric table, per-image metrics, optional mask images and the
    # PR-curve CSV.
    args = get_args()
    cfg = Config.fromfile(args.config)

    # A checkpoint given on the command line wins over the config value.
    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    if ckpt is None:
        # Explicit raise instead of assert: asserts are stripped under -O.
        raise ValueError("no checkpoint given: pass --ckpt or set test_ckpt_path in the config")

    base_dir = args.output_dir if args.output_dir else os.path.dirname(ckpt)

    # Directory for the mask images (only used when images are written).
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Directory for the per-image .txt tables; tables-only runs get their own
    # "tables_only" directory.
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    # NOTE(review): map_location presumably remaps weights saved on cuda:1 to
    # cuda:0 - confirm against the checkpoint loader.
    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1':'cuda:0'}, cfg = cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []         # image-writing jobs; filled only when images are wanted
    per_image_rows = []  # one metrics dict per test image

    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')

        # PR call 1: create the accumulator.
        pr_init(nbins=1000)

        # "batch" rather than "input" so the builtin is not shadowed.
        for batch in tqdm(test_loader):
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask = batch[2].cuda()
            img_id = batch[3]

            # PR call 2: accumulate this batch.
            # NOTE(review): this runs before the SARASNet mask resize below;
            # if prediction and mask sizes differ for that network, the update
            # receives mismatched tensors - confirm.
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter is balance precision and recall to get higher F1-score
                raw_predictions[:,1,:,:] = raw_predictions[:,1,:,:] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            # ====== accumulate the overall metrics ======
            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # ====== [Per-Image] compute and collect per-image metrics ======
            # Iterate over pred, not raw_predictions: the maskcd branch above
            # indexes raw_predictions with [0], so it may be a list/tuple
            # without a .shape, whereas pred is always a tensor here.
            for i in range(pred.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])

                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({
                    "img_id": mask_name,
                    "TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
                    "OA": stats["OA"], "Precision": stats["Precision"],
                    "Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
                })

                # Queue an image-writing job only when images are wanted.
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    # ====== print the overall metrics ======
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    total_metrics = [test_oa.compute().cpu().numpy()]
    total_metrics.extend(np.mean([item.cpu() for item in per_class]) for per_class in metrics)

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    for i in range(2):
        item = [i, '--']
        item.extend(np.round(per_class[i].cpu().numpy(), 4) for per_class in metrics)
        result_table.add_row(item)

    total = ['total'] + [np.round(v, 4) for v in total_metrics]
    result_table.add_row(total)

    print(result_table)

    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    # Context manager so the appended results are flushed and the handle closed.
    with open(file_name, "a") as f:
        f.write(current_time+'\n')
        f.write(str(result_table)+'\n')

    # ====== write images unless running in tables-only mode ======
    if not args.tables_only:
        os.makedirs(masks_output_dir, exist_ok=True)
        print(masks_output_dir)

        # Write the images with a process pool; the with-block releases the
        # worker processes once map() has returned.
        t0 = time.time()
        with mpp.Pool(processes=mp.cpu_count()) as pool:
            pool.map(mask_save, results)
            t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: 跳过可视化图片的生成,仅导出表格/CSV。")

    # ====== [Per-Image] write all per-image metrics into one CSV ======
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg,'net','model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id","TP","FP","TN","FN","OA","Precision","Recall","F1","IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA","Precision","Recall","F1","IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # ====== [Per-Image] write one small table (.txt) per image ======
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        # Confusion-matrix counts first.
        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])
        # Then the ratio metrics.
        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision",f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # ===== [PR] Export at program end =====
    # Best effort: a failed PR export must not discard the results written above.
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")
|