'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-30 16:14:13
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 01:50:37
FilePath: /EndoSAM/endoSAM/test.py
Description: inference script for the fine-tuned EndoSAM adapter
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
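
# Example usage (the config path below is illustrative):
#   python test.py --cfg ../config/endovis.yaml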
import argparse
import json
import os

import cv2
import numpy as np
import torch
import wget
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from dataset import EndoVisDataset
from loss import jaccard
from model import EndoSAMAdapter
from segment_anything.build_sam import sam_model_registry
from utils import make_if_dont_exist, one_hot_embedding_3d
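
# Official SAM checkpoint URLs published by Meta AI, keyed by backbone type.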
COMMON_MODEL_LINKS = {
    'default': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
}

def parse_command():
    """Parse command-line arguments for the inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', default=None, type=str, help='path to config file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_command()
    cfg_path = args.cfg
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the experiment configuration; fail fast if it is missing.
    if cfg_path is not None:
        if os.path.exists(cfg_path):
            cfg = OmegaConf.load(cfg_path)
        else:
            raise FileNotFoundError(f'config file {cfg_path} not found')
    else:
        raise ValueError('config file not specified')
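
    # Download the SAM checkpoint if the config does not already point to an
    # existing one, then record the download path back into the config file.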
    if 'sam_model_dir' not in OmegaConf.to_container(cfg)['model'].keys() \
            or OmegaConf.is_missing(cfg.model, 'sam_model_dir') \
            or not os.path.exists(cfg.model.sam_model_dir):
        print("Didn't find SAM Checkpoint. Downloading from Facebook AI...")
        parent_dir = os.path.dirname(os.getcwd())
        model_dir = os.path.join(parent_dir, 'sam_ckpts')
        make_if_dont_exist(model_dir, overwrite=True)
        checkpoint = os.path.join(model_dir, cfg.model.sam_model_type + '.pth')
        wget.download(COMMON_MODEL_LINKS[cfg.model.sam_model_type], checkpoint)
        OmegaConf.update(cfg, 'model.sam_model_dir', checkpoint)
        OmegaConf.save(cfg, cfg_path)
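
    # Resolve the experiment name and output directories from the config.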
    exp = cfg.experiment_name
    root_dir = cfg.dataset.dataset_dir
    img_format = cfg.dataset.img_format
    ann_format = cfg.dataset.ann_format
    model_path = cfg.model_folder
    model_exp_path = os.path.join(model_path, exp)
    test_path = cfg.test_folder
    test_exp_path = os.path.join(test_path, exp)
    test_exp_mask_path = os.path.join(test_exp_path, 'mask')
    test_exp_overlay_path = os.path.join(test_exp_path, 'overlay')
    make_if_dont_exist(test_exp_path)
    make_if_dont_exist(test_exp_mask_path)
    make_if_dont_exist(test_exp_overlay_path)
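
    # Build the test split dataset and its loader (no shuffling at test time).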
    test_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format, mode='test', encoder_size=cfg.model.encoder_size)
    test_loader = DataLoader(test_dataset, batch_size=cfg.test_bs, shuffle=False, num_workers=cfg.num_workers)
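
    # Instantiate the SAM components, wrap them in the EndoSAM adapter, and
    # restore the fine-tuned weights saved by the training script.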
    sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder = sam_model_registry[cfg.model.sam_model_type](checkpoint=cfg.model.sam_model_dir, customized=cfg.model.sam_model_customized)
    model = EndoSAMAdapter(device, cfg.model.class_num, sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder, num_token=cfg.num_token).to(device)
    weights = torch.load(os.path.join(model_exp_path, 'model.pth'), map_location=device)['endosam_state_dict']
    model.load_state_dict(weights)
    model.eval()
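
    # Run inference image by image, tracking the per-image mask IoU.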
    iou_dict = {}
    ious = []
    with torch.no_grad():
        for img, ann, name, img_bgr in test_loader:
            cv2.destroyAllWindows()
            img = img.to(device)
            ann = ann.to(device).unsqueeze(1).long()
            ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
            pred, pred_quality = model(img)
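
            # IoU is only meaningful when the prediction contains a foreground
            # class; all-background predictions are recorded as NaN.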
            mask_iou = np.nan
            if torch.unique(pred).size()[0] > 1:
                iou = jaccard(ann, pred)
                mask_iou = iou.item()
            iou_dict[name[0]] = mask_iou
            ious.append(mask_iou)
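
            # Collapse the class dimension and binarize the mask for saving.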
            pred = torch.argmax(pred, dim=1)
            numpy_pred = pred.squeeze(0).detach().cpu().numpy()
            numpy_pred[numpy_pred != 0] = 255
            img_bgr = img_bgr.squeeze(0).detach().cpu().numpy()

            # Render the predicted mask as a red overlay on the original image.
            overlay = np.zeros_like(img_bgr)
            overlay[:, :, 2][numpy_pred == 255] = 255  # red channel in BGR order
            alpha = 0.5  # overlay opacity
            result = cv2.addWeighted(img_bgr, 1 - alpha, overlay, alpha, 0)

            # Preview the result for up to one second; press 'q' to close the
            # preview window early.
            cv2.imshow('Result', result)
            key = cv2.waitKey(1000)  # wait at most 1000 ms (1 second) for a key press
            if key == ord('q'):
                cv2.destroyAllWindows()
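
            # Save the binary mask and the overlay image to the experiment folders.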
            cv2.imwrite(os.path.join(test_exp_mask_path, f'{name[0]}.png'), numpy_pred.astype(np.uint8))
            cv2.imwrite(os.path.join(test_exp_overlay_path, f'{name[0]}.png'), result)
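
    # Dump the per-image IoUs to JSON and report the average. np.nanmean skips
    # the images that produced no foreground prediction (recorded as NaN above).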
    with open(os.path.join(test_exp_path, 'mask_ious.json'), 'w') as f:
        json.dump(iou_dict, f, indent=4, sort_keys=False)
    avg_iou = np.nanmean(ious)
    print(f'average intersection over union of mask: {avg_iou}')