# Spaces: Running on Zero
from typing import * | |
from PIL.Image import Image as PILImage | |
from numpy import ndarray | |
from torch import Tensor | |
from wandb import Image as WandbImage | |
from PIL import Image | |
import numpy as np | |
import torch | |
from einops import rearrange | |
import wandb | |
# Conventional ImageNet per-channel normalization statistics, ordered (R, G, B),
# for pixel values already scaled to [0, 1]. `load_image(..., imagenet_norm=True)`
# applies them as (x - mean) / std.
IMAGENET_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
IMAGENET_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]
def wandb_mvimage_log(outputs: Dict[str, Tensor], max_num: int = 4, max_view: int = 8) -> List[WandbImage]:
    """Organize multi-view images in Dict `outputs` into wandb images for logging.

    Only entries whose key contains the substring "images" and whose value is
    not None are processed; each must be a 5-D tensor of shape (B, V, C, H, W).
    At most `max_num` samples and `max_view` views are kept per entry, tiled
    into one grid image with samples stacked vertically and views laid out
    horizontally.

    Args:
        outputs: mapping from names to tensors; keys without "images" are ignored.
        max_num: maximum number of samples (grid rows) to log per entry.
        max_view: maximum number of views (grid columns) to log per entry.

    Returns:
        One `wandb.Image` per processed entry, captioned with its key.

    Raises:
        ValueError: if a processed value is not 5-D. (The channel count is
            checked later by `tensor_to_image`.)
    """
    formatted_images = []
    for key, value in outputs.items():
        if "images" not in key or value is None:
            continue
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if value.ndim != 5:
            raise ValueError(
                f"Expected `{key}` to have shape (B, V, C, H, W), got {tuple(value.shape)}"
            )
        # Slicing past the end is safe, so this equals clamping with min(B, max_num)
        # and min(V, max_view).
        grid = rearrange(value[:max_num, :max_view], "b v c h w -> c (b h) (v w)")
        formatted_images.append(
            wandb.Image(tensor_to_image(grid.detach()), caption=key)
        )
    return formatted_images
def tensor_to_image(tensor: Tensor, return_pil: bool = False) -> "Union[ndarray, PILImage]":
    """Convert an image tensor with values expected in [0, 1] to an 8-bit RGB image.

    Args:
        tensor: a (C, H, W) image, or a (B, C, H, W) batch that is tiled
            horizontally into a single (C, H, B*W) image. C must be 1
            (grayscale, replicated to 3 channels) or 3 (RGB); RGBA is not
            supported.
        return_pil: if True, return a `PIL.Image.Image` instead of an array.

    Returns:
        An (H, W, 3) `np.uint8` array, or the equivalent PIL image. Values
        outside [0, 1] saturate to 0 / 255.

    Raises:
        ValueError: if the tensor is not 3-D/4-D or has neither 1 nor 3 channels.
    """
    if tensor.ndim == 4:  # (B, C, H, W) -> (C, H, B*W)
        tensor = rearrange(tensor, "b c h w -> c h (b w)")
    if tensor.ndim != 3:
        raise ValueError(f"Expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(tensor.shape)}")
    if tensor.shape[0] not in (1, 3):
        raise ValueError(f"Expected 1 (grayscale) or 3 (RGB) channels, got {tensor.shape[0]}")
    if tensor.shape[0] == 1:
        tensor = tensor.repeat(3, 1, 1)
    # detach() lets tensors that require grad be converted to numpy.
    array = tensor.detach().permute(1, 2, 0).cpu().float().numpy()  # (H, W, C)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (e.g. 1.2 -> 50, -0.1 -> 231) instead of saturating.
    image = np.clip(array * 255, 0, 255).astype(np.uint8)
    if return_pil:
        image = Image.fromarray(image)
    return image
def load_image(image_path: str, rgba: bool = False, imagenet_norm: bool = False) -> Tensor:
    """Load an image file as a float tensor of shape (C, H, W) with values in [0, 1].

    Images in modes other than RGB/RGBA (grayscale, palette, CMYK, ...) are
    converted to RGBA when they carry an alpha band or transparency info, and
    to RGB otherwise. The decoded image therefore always has 3 or 4 channels.

    Args:
        image_path: path of the image file to open.
        rgba: if False, an alpha channel is composited onto a white background
            and dropped, giving 3 channels; if True, it is kept.
        imagenet_norm: if True, standardize the RGB channels with
            IMAGENET_MEAN / IMAGENET_STD; a kept alpha channel is left as is.

    Returns:
        Float tensor of shape (C, H, W) with C in {3, 4}.
    """
    # The context manager closes the underlying file handle deterministically.
    with Image.open(image_path) as image:
        if image.mode not in ("RGB", "RGBA"):
            # Without this, 2-D modes (L, P) break permute(2, 0, 1), LA yields
            # 2 channels, and CMYK's K band would be mistaken for alpha.
            bands = image.getbands()
            has_alpha = "A" in bands or "a" in bands or "transparency" in image.info
            image = image.convert("RGBA" if has_alpha else "RGB")
        pixels = np.array(image)  # (H, W, C) uint8
    tensor_image = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.  # (C, H, W) in [0, 1]
    if not rgba and tensor_image.shape[0] == 4:
        mask = tensor_image[3:4]
        tensor_image = tensor_image[:3] * mask + (1. - mask)  # white background
    if imagenet_norm:
        mean = torch.tensor(IMAGENET_MEAN, dtype=tensor_image.dtype, device=tensor_image.device).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, dtype=tensor_image.dtype, device=tensor_image.device).view(3, 1, 1)
        # Normalize only RGB; a kept alpha channel (rgba=True) passes through unchanged
        # instead of failing to broadcast against the 3-channel statistics.
        rgb = (tensor_image[:3] - mean) / std
        tensor_image = torch.cat([rgb, tensor_image[3:]], dim=0)
    return tensor_image  # (C, H, W)