File size: 2,623 Bytes
032c113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import cv2
import os
import glob
import numpy as np

def concat_image_heatmap(img1_path, img2_path, label_path, mask_path, heatmap_path, output_path):
    """Compose a two-row comparison grid from several images and save it.

    Layout (every tile is resized to the height/width of ``img1``)::

        | img1 | img2    | label           |
        | mask | heatmap | (black padding) |

    Args:
        img1_path: Path of the first image; its size defines the tile size.
        img2_path: Path of the second image.
        label_path: Optional label image path. When falsy, when the file
            does not exist, or when it cannot be decoded, a black tile is
            used in its place.
        mask_path: Path of the mask image.
        heatmap_path: Path of the heatmap image.
        output_path: Destination file passed to ``cv2.imwrite``.

    Returns:
        None. Progress and failures are reported with ``print``: the
        function returns early without writing if any required input
        (img1, img2, mask, heatmap) cannot be read, and reports a failed
        write instead of claiming success.
    """
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    mask = cv2.imread(mask_path)
    heatmap = cv2.imread(heatmap_path)
    label = cv2.imread(label_path) if label_path and os.path.exists(label_path) else None

    # Name exactly the required inputs that failed to load, so the user
    # does not have to guess which of the four paths is the problem.
    required = (
        (img1_path, img1),
        (img2_path, img2),
        (mask_path, mask),
        (heatmap_path, heatmap),
    )
    missing = [str(path) for path, image in required if image is None]
    if missing:
        print(f"❌ Missing image: {', '.join(missing)}")
        return

    h, w = img1.shape[:2]
    size = (w, h)  # cv2.resize expects (width, height)
    img2 = cv2.resize(img2, size)
    mask = cv2.resize(mask, size)
    heatmap = cv2.resize(heatmap, size)
    # A black placeholder keeps the top row three tiles wide when no label is available.
    label = cv2.resize(label, size) if label is not None else np.zeros_like(img1)

    top_row = np.concatenate([img1, img2, label], axis=1)
    bottom_row = np.concatenate([mask, heatmap], axis=1)

    # Right-pad the narrower row with black so both rows have the same
    # width before they are stacked vertically.
    max_width = max(top_row.shape[1], bottom_row.shape[1])
    if top_row.shape[1] < max_width:
        pad = max_width - top_row.shape[1]
        top_row = cv2.copyMakeBorder(top_row, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=0)
    if bottom_row.shape[1] < max_width:
        pad = max_width - bottom_row.shape[1]
        bottom_row = cv2.copyMakeBorder(bottom_row, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=0)

    full_image = np.concatenate([top_row, bottom_row], axis=0)
    # cv2.imwrite can report failure (e.g. a missing target directory) by
    # returning False rather than raising, so check the result.
    if not cv2.imwrite(output_path, full_image):
        print(f"❌ Failed to write: {output_path}")
        return
    print(f"✅ Saved: {output_path}")

def batch_process(img1_dir, img2_dir, label_dir, mask_dir, heatmap_dir, output_dir, pattern="*.png"):
    """Build a comparison grid for every matching file in *img1_dir*.

    Files are paired by identical filename across all input directories.
    Each result is written to *output_dir* as ``<stem>_full<ext>``; for
    example, ``tile_01.png`` becomes ``tile_01_full.png``.

    Args:
        img1_dir: Directory whose files matching *pattern* drive the batch.
        img2_dir: Directory holding the same-named second images.
        label_dir: Directory holding same-named labels, or None / "" to
            skip labels (a black tile is used instead).
        mask_dir: Directory holding the same-named masks.
        heatmap_dir: Directory holding the same-named heatmaps.
        output_dir: Destination directory; created if it does not exist.
        pattern: Glob pattern selecting files in *img1_dir*
            (default ``"*.png"``, matching the original behavior).
    """
    os.makedirs(output_dir, exist_ok=True)
    # glob returns files in arbitrary order; sort for reproducible runs/logs.
    img1_paths = sorted(glob.glob(os.path.join(img1_dir, pattern)))
    if not img1_paths:
        print(f"⚠️ No files matching {pattern!r} in {img1_dir}")
        return

    for img1_path in img1_paths:
        filename = os.path.basename(img1_path)
        # splitext only touches the real extension; str.replace would also
        # rewrite ".png" occurring elsewhere in the name.
        stem, ext = os.path.splitext(filename)
        img2_path = os.path.join(img2_dir, filename)
        label_path = os.path.join(label_dir, filename) if label_dir else None
        mask_path = os.path.join(mask_dir, filename)
        heatmap_path = os.path.join(heatmap_dir, filename)
        output_path = os.path.join(output_dir, f"{stem}_full{ext}")

        concat_image_heatmap(img1_path, img2_path, label_path, mask_path, heatmap_path, output_path)

# Paths used when this file is run as a script.
img1_dir = "data/WHU_CD/test/image1"
img2_dir = "data/WHU_CD/test/image2"
label_dir = "data/WHU_CD/test/label"  # may be set to None to skip labels
mask_dir = "mask_connect_test_dir/mask_rgb"
heatmap_dir = "mask_connect_test_dir/grad_cam/model.net.decoderhead.LHBlock2"
output_dir = "mask_heatmap_concat_dir"

# Run the batch only when executed directly, so importing this module
# (e.g. to reuse concat_image_heatmap) does not start processing files.
if __name__ == "__main__":
    batch_process(img1_dir, img2_dir, label_dir, mask_dir, heatmap_dir, output_dir)