from typing import Dict, List, Union
from PIL.Image import Image as PILImage
from numpy import ndarray
from torch import Tensor
from wandb import Image as WandbImage
from PIL import Image
import numpy as np
import torch
from einops import rearrange
import wandb
# Per-channel ImageNet mean/std, used for optional input normalization
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def wandb_mvimage_log(outputs: Dict[str, Tensor], max_num: int = 4, max_view: int = 8) -> List[WandbImage]:
"""Organize multi-view images in Dict `outputs` for wandb logging.
Only process values in Dict `outputs` that have keys containing the word "images",
which should be in the shape of (B, V, 3, H, W).
"""
    formatted_images = []
    for k in outputs.keys():
        if "images" in k and outputs[k] is not None:  # (B, V, 3, H, W)
            assert outputs[k].ndim == 5
            num, view = outputs[k].shape[:2]
            num, view = min(num, max_num), min(view, max_view)
            # Tile the selected samples into one grid: batch along rows, views along columns
            mvimages = rearrange(outputs[k][:num, :view], "b v c h w -> c (b h) (v w)")
            formatted_images.append(
                wandb.Image(
                    tensor_to_image(mvimages.detach()),
                    caption=k,
                )
            )
    return formatted_images
def tensor_to_image(tensor: Tensor, return_pil: bool = False) -> Union[ndarray, PILImage]:
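    """Convert a (C, H, W) or (B, C, H, W) tensor with values in [0, 1] to an
    (H, W, 3) uint8 array, or to a PIL image if `return_pil` is True.
    """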
    if tensor.ndim == 4:  # (B, C, H, W) -> concatenate the batch horizontally
        tensor = rearrange(tensor, "b c h w -> c h (b w)")
    assert tensor.ndim == 3  # (C, H, W)
    assert tensor.shape[0] in [1, 3]  # grayscale or RGB (RGBA is not handled here)
    if tensor.shape[0] == 1:  # repeat the single channel to get an RGB image
        tensor = tensor.repeat(3, 1, 1)
    # Clamp to [0, 1] before the uint8 cast to avoid wrap-around on out-of-range values
    image = (tensor.permute(1, 2, 0).cpu().float().clamp(0., 1.).numpy() * 255).astype(np.uint8)  # (H, W, C)
    if return_pil:
        image = Image.fromarray(image)
    return image
def load_image(image_path: str, rgba: bool = False, imagenet_norm: bool = False) -> Tensor:
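    """Load an image from `image_path` as a float tensor in [0, 1] with shape (C, H, W).

    If `rgba` is False, an alpha channel (when present) is composited onto a white
    background; if `imagenet_norm` is True, the RGB channels are normalized with
    the ImageNet statistics above.
    """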
    image = Image.open(image_path)
    if image.mode not in ("RGB", "RGBA"):  # e.g., grayscale or palette images
        image = image.convert("RGB")
    tensor_image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.  # (C, H, W) in [0, 1]
    if not rgba and tensor_image.shape[0] == 4:
        # Composite the RGB channels onto a white background using the alpha mask
        mask = tensor_image[3:4]
        tensor_image = tensor_image[:3] * mask + (1. - mask)
    if imagenet_norm:
        mean = torch.tensor(IMAGENET_MEAN, dtype=tensor_image.dtype, device=tensor_image.device).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, dtype=tensor_image.dtype, device=tensor_image.device).view(3, 1, 1)
        # Normalize only the RGB channels so a kept alpha channel is left untouched
        tensor_image[:3] = (tensor_image[:3] - mean) / std
    return tensor_image  # (C, H, W)
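

# Minimal usage sketch (illustrative, not part of the original module): the
# tensor shapes and the image path below are assumptions for demonstration.
if __name__ == "__main__":
    # A fake (B, V, 3, H, W) multi-view batch with values in [0, 1]
    outputs = {"pred_images": torch.rand(2, 6, 3, 64, 64)}
    logged = wandb_mvimage_log(outputs)  # one wandb.Image per matching key
    print(f"{len(logged)} image(s) prepared for wandb")

    # Tile a random (B, C, H, W) batch into a single (H, B*W, 3) uint8 array
    array = tensor_to_image(torch.rand(4, 3, 32, 32))
    print(array.shape, array.dtype)  # (32, 128, 3) uint8

    # Hypothetical path; uncomment to load and normalize a local image
    # image = load_image("input.png", imagenet_norm=True)  # (3, H, W)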