'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-16 19:47:31
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-15 13:27:37
FilePath: /EndoSAM/endoSAM/utils.py
Description: EndoSAM utilities functions
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
import os
import numpy as np
import shutil
import logging
from torch.nn import functional as F
import torch
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
from copy import deepcopy
import matplotlib.pyplot as plt
from typing import Tuple
import matplotlib
def plot_progress(logger, save_dir, train_loss, val_loss, name):
"""
    Plot the training (and optionally validation) loss curves and save them as a PNG.
    Should probably be improved.
    :return:
"""
assert len(train_loss) != 0
train_loss = np.array(train_loss)
try:
font = {'weight': 'normal',
'size': 18}
matplotlib.rc('font', **font)
fig = plt.figure(figsize=(30, 24))
ax = fig.add_subplot(111)
ax.plot(train_loss[:,0], train_loss[:,1], color='b', ls='-', label="loss_tr")
if len(val_loss) != 0:
val_loss = np.array(val_loss)
ax.plot(val_loss[:, 0], val_loss[:, 1], color='r', ls='-', label="loss_val")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.legend()
ax.set_title(name)
fig.savefig(os.path.join(save_dir, name + ".png"))
plt.cla()
plt.close(fig)
    except Exception:
        logger.info(f"failed to plot {name} training progress")
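# Usage sketch (input format inferred from the indexing above, paths hypothetical):
# both loss histories are sequences of (epoch, loss) pairs.
#   train_losses = [(0, 1.20), (1, 0.85), (2, 0.61)]
#   val_losses = [(0, 1.10), (2, 0.70)]
#   plot_progress(logger, 'exp/plots', train_losses, val_losses, 'adapter_loss')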
def save_checkpoint(adapter_model, optimizer, epoch, best_val_loss, train_losses, val_losses, save_dir):
torch.save({
'epoch': epoch,
'best_val_loss': best_val_loss,
'train_losses': train_losses,
'val_losses': val_losses,
'weights': adapter_model.state_dict(),
'optimizer': optimizer.state_dict(),
}, save_dir)
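# Usage sketch (hypothetical variables and path): bundle the adapter weights and
# optimizer state so training can resume later.
#   save_checkpoint(adapter_model, optimizer, epoch, best_val_loss,
#                   train_losses, val_losses, os.path.join(ckpt_dir, 'latest.pth'))
#   ckpt = torch.load(os.path.join(ckpt_dir, 'latest.pth'))
#   adapter_model.load_state_dict(ckpt['weights'])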
def one_hot_embedding_3d(labels, dim=1, class_num=21):
'''
    :param labels: integer label map of shape B x 1 x H x W
    :param dim: channel dimension to expand (default 1)
    :param class_num: number of classes N
    :return: one-hot tensor of shape B x N x H x W
    '''
    # torch.Tensor.scatter_ expects int64 indices, so cast the label map explicitly
    one_hot_labels = labels.clone().long()
data_dim = list(one_hot_labels.shape)
if data_dim[dim] != 1:
raise AssertionError("labels should have a channel with length equal to one.")
data_dim[dim] = class_num
o = torch.zeros(size=data_dim, dtype=one_hot_labels.dtype, device=one_hot_labels.device)
return o.scatter_(dim, one_hot_labels, 1).contiguous().float()
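# Example (made-up tensors): a (B, 1, H, W) integer label map becomes a
# (B, class_num, H, W) one-hot volume.
#   labels = torch.randint(0, 21, (2, 1, 4, 4))
#   one_hot = one_hot_embedding_3d(labels, dim=1, class_num=21)  # -> (2, 21, 4, 4)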
def setup_logger(logger_name, log_file, level=logging.INFO):
log_setup = logging.getLogger(logger_name)
formatter = logging.Formatter('%(asctime)s %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
log_setup.setLevel(level)
log_setup.propagate = False
if not log_setup.handlers:
fileHandler = logging.FileHandler(log_file, mode='w')
fileHandler.setFormatter(formatter)
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
log_setup.addHandler(fileHandler)
log_setup.addHandler(streamHandler)
return log_setup
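# Usage sketch (hypothetical log path): the logger writes to both the console and
# a file, and can be retrieved elsewhere with logging.getLogger('train').
#   logger = setup_logger('train', os.path.join(log_dir, 'train.log'))
#   logger.info('training started')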
def make_if_dont_exist(folder_path, overwrite=False):
if os.path.exists(folder_path):
if not overwrite:
print(f'{folder_path} exists, no overwrite here.')
else:
print(f"{folder_path} overwritten")
shutil.rmtree(folder_path, ignore_errors = True)
os.makedirs(folder_path)
else:
os.makedirs(folder_path)
print(f"{folder_path} created!")
# taken from sam.postprocess_masks of https://github.com/facebookresearch/segment-anything
def postprocess_masks(masks, input_size, original_size):
"""
Remove padding and upscale masks to the original image size.
Arguments:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
"""
    # 1024 x 1024 is the padded input resolution of the SAM image encoder.
    masks = F.interpolate(
        masks,
        (1024, 1024),
mode="bilinear",
align_corners=False,
)
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
return masks
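# Shape sketch (illustrative sizes): a low-res decoder mask is first upscaled to
# the padded 1024 x 1024 model space, cropped to the unpadded input size, then
# resized to the original image size.
#   low_res = torch.randn(1, 1, 256, 256)
#   full = postprocess_masks(low_res, input_size=(1024, 768), original_size=(1080, 810))
#   # full.shape == (1, 1, 1080, 810)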
def preprocess(x: torch.Tensor, img_size: int) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
    # Normalize colors with SAM's default ImageNet mean/std, created on the
    # same device as the input to avoid device-mismatch errors.
    pixel_mean = torch.tensor([123.675, 116.28, 103.53], device=x.device).view(-1, 1, 1)
    pixel_std = torch.tensor([58.395, 57.12, 57.375], device=x.device).view(-1, 1, 1)
    x = (x - pixel_mean) / pixel_std
# Pad
h, w = x.shape[-2:]
padh = img_size - h
padw = img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
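# Sketch of the intended call order (assumed to mirror SAM's own pipeline):
# resize the longest side to 1024, convert to a CHW float tensor, then normalize
# and pad to a 1024 x 1024 square with this function.
#   transform = ResizeLongestSide(1024)
#   resized = transform.apply_image(image)                      # H x W x 3 uint8
#   tensor = torch.as_tensor(resized).permute(2, 0, 1).float()  # 3 x H x W
#   padded = preprocess(tensor, img_size=1024)                  # 3 x 1024 x 1024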
class ResizeLongestSide:
"""
Resizes images to longest side 'target_length', as well as provides
methods for resizing coordinates and boxes. Provides methods for
transforming both numpy array and batched torch tensors.
"""
def __init__(self, target_length: int) -> None:
self.target_length = target_length
def apply_image(self, image: np.ndarray) -> np.ndarray:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return np.array(resize(to_pil_image(image), target_size))
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
"""
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
"""
Expects batched images with shape BxCxHxW and float format. This
transformation may not exactly match apply_image. apply_image is
the transformation expected by the model.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
        # Input is BCHW, so the spatial size lives in dims 2 and 3.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
def apply_coords_torch(
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format.
"""
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
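# Usage sketch (made-up sizes): the same scale factor is applied to the image and
# to any point or box prompts so they stay aligned after resizing.
#   transform = ResizeLongestSide(1024)
#   image = np.zeros((1080, 810, 3), dtype=np.uint8)
#   resized = transform.apply_image(image)                   # -> (1024, 768, 3)
#   points = np.array([[405.0, 540.0]])                      # (x, y) in the original frame
#   scaled = transform.apply_coords(points, (1080, 810))     # (x, y) in the resized frame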