|
'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-16 19:47:31
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-15 13:27:37
FilePath: /EndoSAM/endoSAM/utils.py
|
Description: EndoSAM utility functions
|
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
|
import logging
import os
import shutil
from copy import deepcopy
from typing import Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image
|
|
|
|
|
def plot_progress(logger, save_dir, train_loss, val_loss, name):
    """
    Plot the training (and, if available, validation) loss curves and save
    the figure as <name>.png under save_dir. Both loss histories are
    sequences of (epoch, loss) pairs.
    """
    assert len(train_loss) != 0
    train_loss = np.array(train_loss)
    try:
        font = {'weight': 'normal', 'size': 18}
        matplotlib.rc('font', **font)

        fig = plt.figure(figsize=(30, 24))
        ax = fig.add_subplot(111)
        ax.plot(train_loss[:, 0], train_loss[:, 1], color='b', ls='-', label="loss_tr")
        if len(val_loss) != 0:
            val_loss = np.array(val_loss)
            ax.plot(val_loss[:, 0], val_loss[:, 1], color='r', ls='-', label="loss_val")

        ax.set_xlabel("epoch")
        ax.set_ylabel("loss")
        ax.legend()
        ax.set_title(name)
        fig.savefig(os.path.join(save_dir, name + ".png"))
        plt.cla()
        plt.close(fig)
    except Exception:
        logger.info(f"failed to plot {name} training progress")
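
# Example usage (a sketch; the (epoch, loss) pairs are made up):
#
#   logger = setup_logger("endosam", "./train.log")
#   train_hist = [(0, 1.25), (1, 0.93), (2, 0.71)]
#   val_hist = [(0, 1.31), (1, 1.02)]
#   plot_progress(logger, "./plots", train_hist, val_hist, "adapter_loss")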
|
|
|
|
|
def save_checkpoint(adapter_model, optimizer, epoch, best_val_loss, train_losses, val_losses, save_dir):
    """Save adapter weights and training state; save_dir is the full path of the checkpoint file."""
    torch.save({
        'epoch': epoch,
        'best_val_loss': best_val_loss,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'weights': adapter_model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, save_dir)
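
# Loading counterpart (a sketch; assumes `model` and `optimizer` were built
# with the same architecture and settings used at save time):
#
#   ckpt = torch.load("ckpt/best.pth", map_location="cpu")
#   model.load_state_dict(ckpt['weights'])
#   optimizer.load_state_dict(ckpt['optimizer'])
#   start_epoch = ckpt['epoch'] + 1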
|
|
|
|
|
def one_hot_embedding_3d(labels, dim=1, class_num=21):
    '''
    Convert an integer label map into a one-hot encoding.

    :param labels: integer label tensor of shape B x 1 x H x W (dimension
        `dim` must have length one)
    :param dim: dimension holding the class channel
    :param class_num: number of classes N
    :return: one-hot float tensor of shape B x N x H x W
    '''
    one_hot_labels = labels.clone()
    data_dim = list(one_hot_labels.shape)
    if data_dim[dim] != 1:
        raise AssertionError("labels should have a channel with length equal to one.")
    data_dim[dim] = class_num
    o = torch.zeros(size=data_dim, dtype=one_hot_labels.dtype, device=one_hot_labels.device)
    return o.scatter_(dim, one_hot_labels, 1).contiguous().float()
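
# Example (a sketch): a binary label map becomes a 2-channel one-hot tensor.
#
#   labels = torch.randint(0, 2, (4, 1, 64, 64))   # B=4, integer labels in {0, 1}
#   one_hot = one_hot_embedding_3d(labels, class_num=2)
#   assert one_hot.shape == (4, 2, 64, 64)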
|
|
|
|
|
def setup_logger(logger_name, log_file, level=logging.INFO):
    """Create a logger that writes to both log_file and the console."""
    log_setup = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
    log_setup.setLevel(level)
    log_setup.propagate = False
    if not log_setup.handlers:
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setFormatter(formatter)
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        log_setup.addHandler(file_handler)
        log_setup.addHandler(stream_handler)
    return log_setup
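
# Example usage (a sketch; logger name and file path are arbitrary):
#
#   logger = setup_logger("endosam", "./train.log")
#   logger.info("training started")   # appears in train.log and on the console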
|
|
|
|
|
def make_if_dont_exist(folder_path, overwrite=False):
    """Create folder_path; if it already exists, wipe and recreate it only when overwrite is True."""
    if os.path.exists(folder_path):
        if not overwrite:
            print(f'{folder_path} exists, no overwrite here.')
        else:
            print(f"{folder_path} overwritten")
            shutil.rmtree(folder_path, ignore_errors=True)
            os.makedirs(folder_path)
    else:
        os.makedirs(folder_path)
        print(f"{folder_path} created!")
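
# Example (a sketch; the directory name is arbitrary):
#
#   make_if_dont_exist("./exp/run_01", overwrite=True)   # start from a clean directory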
|
|
|
|
|
|
|
def postprocess_masks(masks, input_size, original_size):
    """
    Remove padding and upscale masks to the original image size.

    Arguments:
      masks (torch.Tensor): Batched masks from the mask_decoder,
        in BxCxHxW format.
      input_size (tuple(int, int)): The size of the image input to the
        model, in (H, W) format. Used to remove padding.
      original_size (tuple(int, int)): The original size of the image
        before resizing for input to the model, in (H, W) format.

    Returns:
      (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
        is given by original_size.
    """
    # Upscale the low-resolution decoder masks to the padded square input
    # size (1024 x 1024 for SAM's default image encoder), crop away the
    # padding, then resize to the original image size.
    masks = F.interpolate(
        masks,
        (1024, 1024),
        mode="bilinear",
        align_corners=False,
    )
    masks = masks[..., : input_size[0], : input_size[1]]
    masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
    return masks
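
# Example (a sketch with made-up sizes): a 720x1280 frame resized so its
# longest side is 1024 becomes 576x1024, then is padded to 1024x1024; the
# decoder's 256x256 masks are mapped back to the original 720x1280.
#
#   low_res_masks = torch.rand(1, 1, 256, 256)
#   full_masks = postprocess_masks(low_res_masks, (576, 1024), (720, 1280))
#   assert full_masks.shape == (1, 1, 720, 1280)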
|
|
|
|
|
def preprocess(x: torch.Tensor, img_size: int) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    # SAM's default normalization constants (ImageNet mean/std on a 0-255 scale).
    # Created on x's device so the function also works for GPU tensors.
    pixel_mean = torch.tensor([123.675, 116.28, 103.53], device=x.device).view(-1, 1, 1)
    pixel_std = torch.tensor([58.395, 57.12, 57.375], device=x.device).view(-1, 1, 1)
    x = (x - pixel_mean) / pixel_std

    # Pad the right and bottom edges so the output is img_size x img_size.
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
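
# Example (a sketch; assumes the longest side was already resized to img_size,
# e.g. with ResizeLongestSide below):
#
#   image = torch.rand(3, 576, 1024) * 255.0   # CxHxW, values in 0-255
#   padded = preprocess(image, 1024)
#   assert padded.shape[-2:] == (1024, 1024)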
|
|
|
|
|
class ResizeLongestSide:
    """
    Resizes images to longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy arrays and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # The spatial dimensions of a BxCxHxW tensor are dims 2 and 3,
        # not 0 and 1 (which are batch and channel).
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
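
# Example usage (a sketch; image and point values are made up):
#
#   transform = ResizeLongestSide(1024)
#   image = np.zeros((720, 1280, 3), dtype=np.uint8)           # HxWxC uint8
#   resized = transform.apply_image(image)                     # -> 576x1024x3
#   points = np.array([[640.0, 360.0]])                        # (x, y) in original frame
#   scaled = transform.apply_coords(points, image.shape[:2])   # -> [[512., 288.]]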
|
|