# alltracker_demo/utils/improc.py
# (hosting-page residue: "aharley's picture", commit "added basics", 6d95ea1)
import torch
import numpy as np
import utils.basic
import utils.py
from sklearn.decomposition import PCA
from matplotlib import cm
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
import torchvision
# Small offset added to embeddings before they are handed to PCA
# (see pca_embed and pca_embed_together).
EPS = 1e-6
from skimage.color import (
rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
def _convert(input_, type_):
return {
'float': input_.float(),
'double': input_.double(),
}.get(type_, input_)
def _generic_transform_sk_3d(transform, in_type='', out_type=''):
    """Lift a per-image numpy transform to a batch of torch tensors.

    Args:
        transform: callable taking an H x W x C numpy array and returning an
            array with the channel axis last.
        in_type: dtype name passed to ``_convert`` before the transform.
        out_type: dtype name passed to ``_convert`` after the transform.

    Returns:
        A function that iterates over the first axis of its input, applies
        ``transform`` to each C x H x W item, and stacks the results.
    """
    def _one(image):
        # remember where the image lived so the result can be sent back there
        origin = image.device
        channels_last = _convert(image.cpu(), in_type).permute(1, 2, 0).detach().numpy()
        result = torch.from_numpy(transform(channels_last)).float().permute(2, 0, 1)
        return _convert(result, out_type).to(origin)

    def apply_transform(input_):
        return torch.stack([_one(image) for image in input_])

    return apply_transform
# Batched HSV -> RGB conversion for torch tensors, wrapping skimage's hsv2rgb.
hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
def preprocess_color_tf(x):
    """TensorFlow variant of preprocess_color: cast to float32, scale by 1/255, subtract 0.5."""
    # imported here so the rest of the module does not require tensorflow
    import tensorflow as tf
    as_float = tf.cast(x, tf.float32)
    scaled = as_float * 1./255
    return scaled - 0.5
def preprocess_color(x):
    """Cast to float32, scale by 1/255 and subtract 0.5.

    Accepts either a numpy array or a torch tensor; values in [0, 255] land
    in [-0.5, 0.5].
    """
    as_float = x.astype(np.float32) if isinstance(x, np.ndarray) else x.float()
    return as_float * 1./255 - 0.5
def pca_embed(emb, keep, valid=None):
    """Reduce each image of an embedding batch to ``keep`` PCA components.

    A separate PCA is fitted per image. Helper for reduce_emb.

    Args:
        emb: tensor of shape B x C x H x W.
        keep: number of principal components to keep.
        valid: optional mask with B*H*W elements (e.g. B x 1 x H x W);
            entries > 0 mark pixels that are used to fit the PCA. Masked-out
            pixels are zeroed in the output.

    Returns:
        A float32 CPU tensor of shape B x keep x H x W. Images containing
        NaNs (before or after the projection) come back as all zeros.
    """
    emb = emb + EPS
    emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy()  # B x H x W x C
    # the shape must be known before the mask can be reshaped
    B, H, W, C = np.shape(emb)

    if valid is not None:
        # one flat boolean mask per image
        valid = valid.cpu().detach().numpy().reshape(B, H*W) > 0

    emb_reduced = []
    for b, img in enumerate(emb):
        if np.isnan(img).any():
            emb_reduced.append(np.zeros([H, W, keep]))
            continue

        pixels_kd = np.reshape(img, (H*W, C))
        mask = None if valid is None else valid[b]
        pixels_kd_pca = pixels_kd if mask is None else pixels_kd[mask]

        P = PCA(keep)
        P.fit(pixels_kd_pca)

        pixels3d = P.transform(pixels_kd)
        if mask is not None:
            # add a trailing axis so the H*W mask broadcasts over components
            pixels3d = pixels3d * mask[:, None]

        out_img = np.reshape(pixels3d, [H, W, keep]).astype(np.float32)
        if np.isnan(out_img).any():
            emb_reduced.append(np.zeros([H, W, keep]))
            continue

        emb_reduced.append(out_img)

    emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
    return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
def pca_embed_together(emb, keep):
    """Project a whole embedding batch onto ``keep`` shared PCA components.

    Unlike pca_embed, a single PCA is fitted on the pixels of all images at
    once. Helper for reduce_emb.

    Args:
        emb: tensor of shape B x C x H x W.
        keep: number of principal components to keep.

    Returns:
        A float32 CPU tensor of shape B x keep x H x W; all zeros if NaNs
        are present before or after the projection.
    """
    shifted = emb + EPS
    channels_last = shifted.permute(0, 2, 3, 1).cpu().detach().float().numpy()  # B x H x W x C
    B, H, W, C = np.shape(channels_last)

    def _blank():
        return torch.zeros(B, keep, H, W)

    if np.isnan(channels_last).any():
        return _blank()

    flat = np.reshape(channels_last, (B*H*W, C))
    model = PCA(keep)
    model.fit(flat)
    projected = np.reshape(model.transform(flat), [B, H, W, keep]).astype(np.float32)

    if np.isnan(projected).any():
        return _blank()
    return torch.from_numpy(projected).permute(0, 3, 1, 2)
def reduce_emb(emb, valid=None, inbound=None, together=False):
    """Reduce an embedding to 3 channels with PCA, for visualisation.

    Args:
        emb: tensor of shape S x C x H x W.
        valid: optional mask forwarded to pca_embed (ignored when
            ``together`` is True).
        inbound: optional tensor multiplied into ``emb`` for the second
            return value.
        together: fit one PCA over the whole batch instead of one per image.

    Returns:
        ``(reduced_emb, emb_inbound)``: the 3-channel embedding shifted by
        -0.5 after ``utils.basic.normalize``, and ``emb*inbound`` (or None
        when ``inbound`` is None).
    """
    # unpacking doubles as a check that emb is 4-D
    _S, _C, _H, _W = list(emb.size())

    # four components are fitted; the leading one is dropped below
    n_components = 4
    if together:
        reduced = pca_embed_together(emb, n_components)
    else:
        reduced = pca_embed(emb, n_components, valid)
    reduced = utils.basic.normalize(reduced[:, 1:]) - 0.5

    masked = None if inbound is None else emb*inbound
    return reduced, masked
def get_feat_pca(feat, valid=None):
    """Return the 3-channel PCA visualisation of a 4-D feature map.

    ``feat`` is B x C x D x W; a 3-D volume should be reduced along one
    spatial axis before calling this. The PCA is fitted jointly over the
    batch (``together=True``).
    """
    # unpacking doubles as a check that feat is 4-D
    _B, _C, _D, _W = list(feat.size())
    reduced, _unused_inbound = reduce_emb(feat, valid=valid, inbound=None, together=True)
    return reduced
def gif_and_tile(ims, just_gif=False):
    """Stack frames into a sequence, optionally with a tiled strip attached.

    Args:
        ims: list of S equally-shaped 4-D tensors.
        just_gif: return only the stacked sequence.

    Returns:
        The frames stacked along a new dim 1. Unless ``just_gif`` is set,
        every step of that sequence is additionally concatenated (along
        dim 3 of the 5-D result) with all frames joined along their dim 2.
    """
    n_frames = len(ims)
    sequence = torch.stack(ims, dim=1)
    if just_gif:
        return sequence
    # the same strip of all frames is attached to every step of the sequence
    strip = torch.cat(ims, dim=2).unsqueeze(dim=1).repeat(1, n_frames, 1, 1, 1)
    return torch.cat([sequence, strip], dim=3)
def back2color(i, blacken_zeros=False):
    """Undo preprocess_color: map [-0.5, 0.5] floats to a uint8 CPU tensor.

    With ``blacken_zeros`` set, entries exactly equal to 0.0 are first
    replaced by -0.5, so they come out as 0 (black) instead of mid-gray.
    """
    if blacken_zeros:
        black = torch.tensor([-0.5])
        if i.is_cuda:
            black = black.cuda()
        i = torch.where(i==0.0, black, i)
    return ((i+0.5)*255).type(torch.ByteTensor)
def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):
    """Render one isotropic 2-D Gaussian per point.

    Args:
        xy: B x N x 2 tensor of float (x, y) coordinates.
        sigma: standard deviation of the Gaussians.
        grid_xs, grid_ys: B x N x Y x X tensors of pixel coordinates.
        norm: rescale the maps with ``utils.basic.normalize``.

    Returns:
        A B x N x Y x X tensor. Points outside the grid (with a half-pixel
        margin) are moved far away, so their maps are effectively zero.
    """
    B, N, Y, X = list(grid_xs.shape)

    mu_x = xy[:,:,0].clone()
    mu_y = xy[:,:,1].clone()

    x_valid = (mu_x>-0.5) & (mu_x<float(X+0.5))
    y_valid = (mu_y>-0.5) & (mu_y<float(Y+0.5))
    not_valid = ~(x_valid & y_valid)

    # push out-of-bounds points far off the grid so they contribute nothing
    mu_x[not_valid] = -10000
    mu_y[not_valid] = -10000

    # B x N x 1 x 1 broadcasts against the grids; no explicit repeat needed
    mu_x = mu_x.reshape(B, N, 1, 1)
    mu_y = mu_y.reshape(B, N, 1, 1)

    sigma_sq = sigma*sigma
    sq_dist = (grid_xs - mu_x)**2 + (grid_ys - mu_y)**2

    # Gaussian normaliser 1 / (2*pi*sigma^2); the earlier expression
    # `1./2.*np.pi*sigma_sq` evaluated to (pi*sigma^2)/2 by precedence.
    amplitude = 1./(2.*np.pi*sigma_sq)
    gauss = amplitude*torch.exp(-sq_dist/(2.*sigma_sq))

    if norm:
        # normalize so each gaussian peaks at 1
        gauss_ = gauss.reshape(B*N, Y, X)
        gauss_ = utils.basic.normalize(gauss_)
        gauss = gauss_.reshape(B, N, Y, X)

    return gauss
def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True):
    """Build a Y x X Gaussian heatmap for each of the N points in ``xy``.

    ``xy`` is B x N x 2; the result of xy2heatmap on a freshly built
    pixel grid is returned (B x N x Y x X).
    """
    B, N, D = list(xy.shape)
    assert(D==2)

    # grid_y and grid_x are B x Y x X; tile them once per point
    grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=xy.device)
    tiled_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)
    tiled_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)

    return xy2heatmap(xy, sigma, tiled_xs, tiled_ys, norm=norm)
def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False):
    """Return per-point heatmaps (B x N x Y x X) for the B x N x 2 points ``xy``.

    With ``round`` set, the heatmaps are thresholded at 0.5 into a 0/1
    float mask.
    """
    B, N, D = list(xy.shape)
    assert(D==2)
    prior = xy2heatmaps(xy, Y, X, sigma=sigma)
    return (prior > 0.5).float() if round else prior
def seq2color(im, norm=True, colormap='coolwarm'):
    """Colour a B x S x H x W stack by the index of the latest active slice.

    Args:
        im: B x S x H x W tensor; S is treated as a sequence axis and
            zeros mark invalid pixels.
        norm: accepted for interface compatibility; not used.
        colormap: one of 'coolwarm', 'PiYG', 'winter', 'spring', 'onediff'.

    Returns:
        A B x 3 x H x W float tensor in [-0.5, 0.5] on the same device as
        ``im`` (previously the result was always forced onto CUDA, which
        failed on CPU-only machines).
    """
    B, S, H, W = list(im.shape)
    device = im.device

    # prep a mask of the valid pixels, so we can blacken the invalids later
    mask = torch.max(im, dim=1, keepdim=True)[0]

    # turn the S dim into an explicit sequence: slice s is weighted (s+1)/S
    coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S)
    coeffs = torch.from_numpy(coeffs).float().to(device)
    coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W)

    # now im is in [1/S, 1], except for the invalid parts which are 0;
    # keep the highest valid coeff at each pixel
    im = torch.max(im * coeffs, dim=1, keepdim=True)[0]

    simple_maps = {
        'coolwarm': cm.coolwarm,
        'PiYG': cm.PiYG,
        'winter': cm.winter,
        'spring': cm.spring,
    }

    out = []
    for b in range(B):
        # im_ is H x W
        im_ = np.squeeze(im[b].detach().cpu().numpy())
        if colormap in simple_maps:
            im_ = simple_maps[colormap](im_)[:, :, :3]
        elif colormap=='onediff':
            # the first slice gets the 'spring' colour, all later ones 'winter'
            im_ = np.reshape(im_, (-1))
            im0_ = cm.spring(im_)[:, :3]
            im1_ = cm.winter(im_)[:, :3]
            im1_[im_==1/float(S)] = im0_[im_==1/float(S)]
            im_ = np.reshape(im1_, (H, W, 3))
        else:
            assert(False) # invalid colormap
        # move channels into dim 0
        im_ = np.transpose(im_, [2, 0, 1])
        out.append(torch.from_numpy(im_).float().to(device))
    out = torch.stack(out, dim=0)

    # blacken the invalid pixels, instead of using the 0-color
    out = out*mask

    # put it in [-0.5, 0.5]
    return out - 0.5
def colorize(d):
    """Map a single-channel image through the 'inferno' colormap.

    Args:
        d: H x W or 1 x H x W tensor; values are passed to the colormap
            as-is, so they are presumably expected in [0, 1] — callers in
            this file normalise first.

    Returns:
        A 3 x H x W CPU tensor of RGB values scaled by 255.
    """
    if d.ndim==2:
        d = d.unsqueeze(dim=0)
    else:
        assert(d.ndim==3)

    # attribute access works on every matplotlib release; cm.get_cmap was
    # removed in matplotlib 3.9
    color_map = cm.inferno

    C,H,W = d.shape
    assert(C==1)

    flat = d.reshape(-1).detach().cpu().numpy()
    color = np.array(color_map(flat)) * 255 # rgba
    color = np.reshape(color[:,:3], [H*W, 3])
    return torch.from_numpy(color).permute(1,0).reshape(3,H,W)
def oned2inferno(d, norm=True, do_colorize=False):
    """Convert a 1-channel input to a 3-channel uint8 image.

    Args:
        d: B x H x W or B x 1 x H x W tensor.
        norm: rescale with ``utils.basic.normalize`` first.
        do_colorize: colour each image with colorize instead of
            replicating the single channel.

    Returns:
        A B x 3 x H x W uint8 CPU tensor.
    """
    # if it's just B x H x W, add a C dim
    if d.ndim==3:
        d = d.unsqueeze(dim=1)
    B, C, H, W = list(d.shape)
    assert(C==1)

    if norm:
        d = utils.basic.normalize(d)

    if not do_colorize:
        rgb = d.repeat(1, 3, 1, 1)*255.0
    else:
        rgb = torch.zeros(B, 3, H, W)
        for b in range(B):
            rgb[b] = colorize(d[b])

    return rgb.type(torch.ByteTensor)
def oned2gray(d, norm=True):
    """Convert a 1-channel input to a 3-channel grayscale uint8 image.

    ``d`` is B x H x W or B x 1 x H x W; with ``norm`` it is rescaled by
    ``utils.basic.normalize`` first. Returns a B x 3 x H x W uint8 CPU tensor.
    """
    # if it's just B x H x W, add a C dim
    if d.ndim==3:
        d = d.unsqueeze(dim=1)
    B, C, H, W = list(d.shape)
    assert(C==1)

    if norm:
        d = utils.basic.normalize(d)

    replicated = d.repeat(1,3,1,1)
    return (255.0*replicated).type(torch.ByteTensor)
def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):
    """Stamp a frame number onto the first image of ``vis``.

    The number is formatted with ``utils.basic.strnum`` and rendered by
    draw_frame_str_on_vis, which previously existed here as a line-for-line
    copy; only the default ``top`` offset differs between the two.

    Returns:
        A 1 x C x H x W uint8 tensor holding the annotated first image.
    """
    frame_str = utils.basic.strnum(frame_id)
    return draw_frame_str_on_vis(vis, frame_str, scale=scale, left=left, top=top, shadow=shadow)
def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):
    """Render ``frame_str`` in white onto the first image of ``vis``.

    Args:
        vis: batch of channels-first images; only ``vis[0]`` is drawn on.
        frame_str: text to draw.
        scale: OpenCV font scale.
        left, top: text origin in pixels (from the left / from the top).
        shadow: draw a filled black box behind the text.

    Returns:
        A 1 x C x H x W uint8 tensor holding the annotated image.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    white = (255, 255, 255)
    black = (0, 0, 0)
    thickness = 1

    # channels last, then RGB -> BGR for OpenCV
    canvas = np.transpose(vis.detach().cpu().numpy()[0], [1, 2, 0])
    canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)

    (text_w, text_h), _ = cv2.getTextSize(frame_str, font, scale, thickness)
    if shadow:
        cv2.rectangle(canvas, (left, top-text_h), (left + text_w, top+1), black, -1)

    cv2.putText(canvas, frame_str, (left, top), font, scale, white, thickness)

    canvas = cv2.cvtColor(canvas.astype(np.uint8), cv2.COLOR_BGR2RGB)
    return torch.from_numpy(canvas).permute(2, 0, 1).unsqueeze(0)
# Default image read by ColorMap2d when no filename is given.
COLORMAP_FILE = "./utils/bremm.png"
class ColorMap2d:
    """Two-dimensional colormap backed by an image file.

    Calling an instance with N x 2 coordinates returns one RGB colour per
    row, read from the image at the scaled (x, y) position.
    """

    def __init__(self, filename=None):
        """Load the colormap image (``COLORMAP_FILE`` when ``filename`` is falsy)."""
        self._colormap_file = filename or COLORMAP_FILE
        self._img = plt.imread(self._colormap_file)
        self._height = self._img.shape[0]
        self._width = self._img.shape[1]

    def __call__(self, X):
        """Look up colours for the N x 2 array ``X`` of (x, y) pairs.

        Coordinates are scaled by (width-1) / (height-1), truncated to
        integers and clipped to the image bounds.

        Returns:
            An N x 3 float array of colours.
        """
        assert len(X.shape) == 2
        output = np.zeros((X.shape[0], 3))
        for i in range(X.shape[0]):
            x, y = X[i, :]
            xp = int((self._width-1) * x)
            yp = int((self._height-1) * y)
            xp = np.clip(xp, 0, self._width-1)
            yp = np.clip(yp, 0, self._height-1)
            pixel = self._img[yp, xp]
            # drop any alpha channel: an RGBA pixel has 4 values and cannot
            # be assigned into the 3-wide output row
            output[i, :] = pixel[:3] if np.ndim(pixel) else pixel
        return output
def get_n_colors(N, sequential=False):
    """Return ``N`` RGB colours sampled evenly from a matplotlib colormap.

    Args:
        N: number of colours.
        sequential: sample 'winter' instead of 'gist_rainbow'.

    Returns:
        A list of ``N`` uint8 arrays of length 3.
    """
    if N <= 0:
        return []
    color_map = cm.winter if sequential else cm.gist_rainbow
    # guard the N == 1 case, which used to divide by zero
    denom = max(N - 1, 1)
    label_colors = []
    for ii in range(N):
        rgba = color_map(ii/denom)
        label_colors.append((np.array(rgba) * 255).astype(np.uint8)[:3])
    return label_colors
class Summ_writer(object):
    """Helper that renders tensors as images/videos and logs them to a writer."""

    def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
        """Store the logging configuration and decide what to save at this step.

        Args:
            writer: object providing ``add_image`` / ``add_video`` /
                ``add_scalar`` (as used by the summ_* methods).
            global_step: current step; passed to the writer with each log.
            log_freq: visual summaries are produced when
                ``global_step % log_freq == 0``.
            fps: frame rate used for videos.
            scalar_freq: scalar logging period; values below 1 are raised to 1.
            just_gif: forwarded to gif_and_tile by the sequence summaries.
        """
        self.writer = writer
        self.global_step = global_step
        self.log_freq = log_freq
        self.fps = fps
        self.just_gif = just_gif
        self.maxwidth = 10000
        self.scalar_freq = max(scalar_freq,1)

        self.save_this = (self.global_step % self.log_freq == 0)
        # scalars are always logged on steps where visuals are logged
        scalar_due = (self.global_step % self.scalar_freq == 0)
        self.save_scalar = True if self.save_this else scalar_due
def summ_gif(self, name, tensor, blacken_zeros=False):
    """Log the first sequence of a B x S x C x H x W tensor.

    float32 input is converted with back2color first. A single-frame
    sequence is written with ``add_image``, anything longer with
    ``add_video``.

    Returns:
        The 1 x S x C x H x W tensor that was written.
    """
    assert tensor.dtype in {torch.uint8,torch.float32}
    if tensor.dtype == torch.float32:
        tensor = back2color(tensor, blacken_zeros=blacken_zeros)

    first_clip = tensor[0:1]
    n_frames = first_clip.shape[1]
    if n_frames != 1:
        self.writer.add_video(name, first_clip, fps=self.fps, global_step=self.global_step)
    else:
        # first_clip is 1 x 1 x C x H x W
        self.writer.add_image(name, first_clip[0,0], global_step=self.global_step)
    return first_clip
def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1):
    """Draw the boxes of the first batch element onto its image.

    Args:
        rgb: B x 3 x H x W image batch (converted with back2color).
        boxlist: B x N x 4 boxes.
        scores: optional B x N scores; defaults to all ones.
        tids: optional B x N integer ids; defaults to all zeros.
        linewidth: line thickness in pixels.

    Returns:
        A 1 x 3 x H x W float tensor produced by preprocess_color.
    """
    B, C, H, W = list(rgb.shape)
    assert(C==3)
    B2, N, D = list(boxlist.shape)
    assert(B2==B)
    assert(D==4)

    rgb = back2color(rgb)
    if scores is None:
        scores = torch.ones(B2, N).float()
    if tids is None:
        tids = torch.zeros(B2, N).long()

    def _first_as_numpy(t):
        return t[0].cpu().detach().numpy()

    drawn = self.draw_boxlist2d_on_image_py(
        _first_as_numpy(rgb),
        _first_as_numpy(boxlist),
        _first_as_numpy(scores),
        _first_as_numpy(tids),
        linewidth=linewidth)

    # back to a channels-first batch of one
    drawn = torch.from_numpy(drawn).type(torch.ByteTensor).permute(2, 0, 1)
    drawn = preprocess_color(torch.unsqueeze(drawn, dim=0))
    return torch.reshape(drawn, [1, C, H, W])
def draw_boxlist2d_on_image_py(self, rgb, boxlist, scorelist, tidlist, linewidth=1):
    """Draw boxes (and non-trivial scores) onto one image with OpenCV.

    All inputs are numpy arrays.

    Args:
        rgb: 3 x H x W image (it is transposed to channels-last here).
        boxlist: N x 4 boxes, unpacked as (xmin, ymin, xmax, ymax).
        scorelist: N scores; boxes whose score is ~0 are skipped, and a
            score that is not ~1 is printed at the box corner.
        tidlist: N integer ids selecting the colour from 'tab20'.
        linewidth: line thickness in pixels.

    Returns:
        The annotated H x W x 3 uint8 image.
    """
    rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
    rgb = rgb.astype(np.uint8).copy()

    H, W, C = rgb.shape
    assert(C==3)
    N, D = boxlist.shape
    assert(D==4)
    M = scorelist.shape[0]
    assert(M==N)
    O = tidlist.shape[0]
    assert(M==O)

    # plt.get_cmap is available on all matplotlib releases; cm.get_cmap was
    # removed in matplotlib 3.9
    palette = plt.get_cmap('tab20').colors

    for (box, score, tid) in zip(boxlist, scorelist, tidlist):
        if np.isclose(score, 0.0, atol=1e-3):
            continue

        xmin, ymin, xmax, ymax = box

        # wrap the id so ids beyond the palette size reuse colours instead
        # of raising IndexError
        color = np.array(palette[int(tid) % len(palette)])*255.0
        # plain Python floats are accepted as a colour by every OpenCV build
        color = tuple(float(c) for c in color.round())

        if not np.isclose(score, 1.0, atol=1e-3):
            cv2.putText(rgb,
                        '%.2f' % (score),
                        (int(xmin), int(ymin)),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, # font size
                        color)

        xmin = int(np.clip(xmin,0,W-1))
        ymin = int(np.clip(ymin,0,H-1))
        xmax = int(np.clip(xmax,0,W-1))
        ymax = int(np.clip(ymax,0,H-1))

        cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_4)
        cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_4)
        cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_4)
        cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_4)

    return rgb
def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, frame_str=None, only_return=False, linewidth=1):
    """Draw ``boxlist`` on ``rgb`` and pass the result to summ_rgb."""
    # unpacking doubles as a check that rgb is 4-D
    _B, _C, _H, _W = list(rgb.shape)
    annotated = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth)
    return self.summ_rgb(name, annotated, frame_id=frame_id, frame_str=frame_str, only_return=only_return)
def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
    """Summarise a list of RGB frames as a video.

    Args:
        name: tag passed to the writer.
        ims: list of S image batches (uint8 or float32), combined by
            gif_and_tile.
        frame_ids: optional S frame numbers stamped onto the frames.
        frame_strs: optional S strings stamped onto the frames.
        blacken_zeros: forwarded to back2color / summ_gif.
        only_return: return the visualisation without logging it.

    Returns:
        None when this step is not being saved; otherwise the B x S x C x
        H x W uint8 visualisation (``only_return``) or the result of
        summ_gif.
    """
    if not self.save_this:
        return None

    vis = gif_and_tile(ims, just_gif=self.just_gif)

    assert vis.dtype in {torch.uint8,torch.float32}
    if vis.dtype == torch.float32:
        vis = back2color(vis, blacken_zeros)

    B, S, C, H, W = list(vis.shape)

    if frame_ids is not None:
        assert(len(frame_ids)==S)
        for s in range(S):
            vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])

    if frame_strs is not None:
        assert(len(frame_strs)==S)
        for s in range(S):
            vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])

    if int(W) > self.maxwidth:
        # vis is 5-D here, so the width is the LAST axis; slicing axis 3
        # would crop the height instead
        vis = vis[..., :self.maxwidth]

    if only_return:
        return vis
    return self.summ_gif(name, vis, blacken_zeros)
def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
    """Summarise the first image of a B x C x H x W batch.

    float32 input is converted with back2color. Optional steps, in order:
    halve the resolution, stamp ``frame_id`` / ``frame_str``, crop the width
    to ``self.maxwidth``.

    Returns:
        None when this step is not being saved; otherwise the 1 x C x H x W
        visualisation (``only_return``) or the result of summ_gif.
    """
    if not self.save_this:
        return None

    assert ims.dtype in {torch.uint8,torch.float32}
    if ims.dtype == torch.float32:
        ims = back2color(ims, blacken_zeros)

    # ims is B x C x H x W; keep just the first one
    vis = ims[0:1]
    B, C, H, W = list(vis.shape)

    if halfres:
        vis = F.interpolate(vis, scale_factor=0.5)
    if frame_id is not None:
        vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
    if frame_str is not None:
        vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)

    # W was measured before any halving
    if int(W) > self.maxwidth:
        vis = vis[:,:,:,:self.maxwidth]

    if only_return:
        return vis
    return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
def flow2color(self, flow, clip=0.0):
    """Colour-code the first flow field of a B x 2 x H x W batch.

    Direction is mapped to hue and magnitude to value, with saturation
    fixed at 0.75.

    Args:
        flow: B x 2 x H x W tensor.
        clip: magnitude mapped to full brightness; 0 means "use the largest
            absolute component of this flow field".

    Returns:
        A 1 x 3 x H x W uint8 CPU tensor.
    """
    B, C, H, W = list(flow.size())
    assert(C==2)
    flow = flow[0:1].detach()

    if clip==0:
        clip = torch.max(torch.abs(flow)).item()
        if clip==0:
            # an all-zero flow would otherwise be divided by zero and turn
            # into NaNs; any positive scale leaves it at zero (black)
            clip = 1.0

    flow = torch.clamp(flow, -clip, clip)/clip

    radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) # B x 1 x H x W
    radius_clipped = torch.clamp(radius, 0.0, 1.0)

    angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W

    hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
    saturation = torch.ones_like(hue) * 0.75
    value = radius_clipped
    hsv = torch.cat([hue, saturation, value], dim=1) # B x 3 x H x W

    rgb = hsv_to_rgb(hsv)
    return (rgb*255.0).type(torch.ByteTensor)
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
    """Colour-code a flow batch with flow2color and summarise it via summ_rgb.

    Returns None when this step is not being saved.
    """
    if not self.save_this:
        return None
    colored = self.flow2color(im, clip=clip)
    return self.summ_rgb(name, colored, only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
    """Summarise a list of single-channel maps as a video.

    Args:
        name: tag passed to the writer.
        ims: list of maps. With ``bev`` / ``fro`` each is 5-D and is reduced
            over dim 3 / dim 2 (max if ``reduce_max``, else mean).
        frame_ids, frame_strs: optional per-frame annotations.
        logvis: with a non-zero ``max_val``, visualise log(1 + im).
        max_val: if non-zero, clamp to [0, max_val] and divide by it
            instead of normalising.
        norm: normalise with ``utils.basic.normalize`` (disabled when
            ``max_val`` is used).
        only_return: return the visualisation without logging it.
        do_colorize: forwarded to oned2inferno.

    Returns:
        None when this step is not being saved; otherwise the B x S x 3 x
        H x W uint8 visualisation (``only_return``) or the result of
        summ_gif. The summ_gif result used to be dropped, unlike in the
        sibling summaries.
    """
    if not self.save_this:
        return None

    if bev:
        B, C, H, _, W = list(ims[0].shape)
        if reduce_max:
            ims = [torch.max(im, dim=3)[0] for im in ims]
        else:
            ims = [torch.mean(im, dim=3) for im in ims]
    elif fro:
        B, C, _, H, W = list(ims[0].shape)
        if reduce_max:
            ims = [torch.max(im, dim=2)[0] for im in ims]
        else:
            ims = [torch.mean(im, dim=2) for im in ims]

    if len(ims) != 1: # sequence
        im = gif_and_tile(ims, just_gif=self.just_gif)
    else:
        im = torch.stack(ims, dim=1) # single frame

    B, S, C, H, W = list(im.shape)

    if logvis and max_val:
        max_val = np.log(max_val)
        im = torch.log(torch.clamp(im, 0)+1.0)
        im = torch.clamp(im, 0, max_val)
        im = im/max_val
        norm = False
    elif max_val:
        im = torch.clamp(im, 0, max_val)
        im = im/max_val
        norm = False

    if norm:
        # normalize before oned2inferno,
        # so that the ranges are similar within B across S
        im = utils.basic.normalize(im)

    im = im.view(B*S, C, H, W)
    vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
    vis = vis.view(B, S, 3, H, W)

    if frame_ids is not None:
        assert(len(frame_ids)==S)
        for s in range(S):
            vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])

    if frame_strs is not None:
        assert(len(frame_strs)==S)
        for s in range(S):
            vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])

    if W > self.maxwidth:
        vis = vis[...,:self.maxwidth]

    if only_return:
        return vis
    return self.summ_gif(name, vis)
def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
    """Summarise the first single-channel map of a batch as an image.

    Args:
        name: tag passed to the writer.
        im: B x 1 x H x W map, or a 5-D volume that ``bev`` / ``fro``
            reduce over dim 3 / dim 2 (max if ``max_along_y``, else mean).
        logvis: with a non-zero ``max_val``, visualise log(im).
        max_val: if non-zero, clamp to [0, max_val] and divide by it
            instead of normalising.
        norm: let oned2inferno normalise (disabled when ``max_val`` is used).
        frame_id, frame_str, only_return, shadow: forwarded to summ_rgb.

    Returns:
        None when this step is not being saved; otherwise the result of
        summ_rgb.
    """
    if not self.save_this:
        return None

    if bev:
        B, C, H, _, W = list(im.shape)
        im = torch.max(im, dim=3)[0] if max_along_y else torch.mean(im, dim=3)
    elif fro:
        B, C, _, H, W = list(im.shape)
        im = torch.max(im, dim=2)[0] if max_along_y else torch.mean(im, dim=2)
    else:
        B, C, H, W = list(im.shape)

    im = im[0:1] # just the first one
    assert(C==1)

    if max_val:
        if logvis:
            max_val = np.log(max_val)
            im = torch.log(im)
            im = torch.clamp(im, 0, max_val)
            im = im/max_val
        else:
            im = torch.clamp(im, 0, max_val)/max_val
        norm = False

    vis = oned2inferno(im, norm=norm)
    if W > self.maxwidth:
        vis = vis[...,:self.maxwidth]
    return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
def summ_4chan(self, name, im, norm=True, frame_id=None, frame_str=None, only_return=False):
    """Summarise a 4-channel map as a single label-coloured image.

    Each channel of the first batch element is normalised, rounded, and
    rendered by summ_seg as label (channel index + 1) with 'tab20'. The
    channel-0 rendering is then overwritten, in order, wherever the
    rendering of channel k (k = 0, 1, 2) is black, by the rendering of
    channel k+1.

    Args:
        name: tag passed to the writer.
        im: B x 4 x H x W tensor.
        norm: accepted for interface compatibility; not used.
        frame_id, frame_str, only_return: forwarded to summ_rgb.

    Returns:
        None when this step is not being saved; otherwise the result of
        summ_rgb.
    """
    if not self.save_this:
        return None

    B, C, H, W = list(im.shape)
    im = im[0:1] # just the first one
    assert(C==4)

    layers = []
    for c in range(4):
        chan = utils.basic.normalize(im[:,c:c+1]).round()
        layers.append(self.summ_seg('', chan[:,0]*(c+1), only_return=True, colormap='tab20'))

    # masks are taken from the untouched renderings, before any compositing
    empties = [
        (torch.max(layer, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1)==0
        for layer in layers[:3]
    ]

    vis0 = layers[0]
    for k, empty in enumerate(empties):
        vis0[empty] = layers[k+1][empty]

    vis0 = vis0.cpu()
    return self.summ_rgb(name, vis0, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return)
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
    """Summarise a sequence of feature maps as a video.

    Args:
        name: tag passed to the writer.
        feats: list of S feature tensors, stacked along a new dim 1. If the
            stack is 6-D (B x S x C x D x H x W) it is reduced over dim 3
            (``fro``) or dim 4.
        valids: optional list of S masks used for a masked mean in the
            6-D case.
        pca: show a PCA colouring; otherwise the mean absolute activation.
        only_return, frame_ids, frame_strs: forwarded to the summary
            method that renders the result.

    Returns:
        None when this step is not being saved; otherwise the result of
        summ_oneds (``pca=False``) or summ_rgbs.
    """
    if not self.save_this:
        return None

    if valids is not None:
        valids = torch.stack(valids, dim=1)
    feats = torch.stack(feats, dim=1)
    # feats leads with B x S x C

    if feats.ndim==6:
        # feats is B x S x C x D x H x W
        reduce_dim = 3 if fro else 4
        if valids is None:
            feats = torch.mean(feats, dim=reduce_dim)
        else:
            valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
            feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)

    B, S, C, D, W = list(feats.size())

    if not pca:
        # feats leads with B x S x C; collapse C to a single channel
        feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
        feats = torch.unbind(feats, dim=1)
        return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)

    feats_ = utils.basic.pack_seqdim(feats, B)
    if valids is None:
        feats_pca_ = get_feat_pca(feats_)
    else:
        # pass the packed mask so it lines up with the packed features
        # (the unpacked one was passed before, leaving valids_ unused)
        valids_ = utils.basic.pack_seqdim(valids, B)
        feats_pca_ = get_feat_pca(feats_, valids_)
    feats_pca = utils.basic.unpack_seqdim(feats_pca_, B)

    return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
    """Summarise one feature map as an image.

    Args:
        name: tag passed to the writer.
        feat: B x C x D x W map, or a 5-D volume that is reduced over
            dim 2 (``fro`` without ``bev``) or dim 3 (otherwise).
        valid: optional mask for a masked mean in the 5-D case; also
            forwarded to get_feat_pca.
        pca: show a PCA colouring; otherwise the mean absolute activation.
        only_return, frame_id, frame_str: forwarded to the summary method
            that renders the result.

    Returns:
        None when this step is not being saved; otherwise the result of
        summ_oned (``pca=False``) or summ_rgb.
    """
    if not self.save_this:
        return None

    if feat.ndim==5: # B x C x D x H x W
        # bev wins over fro, and is also the default
        reduce_axis = 2 if (fro and not bev) else 3
        if valid is None:
            feat = torch.mean(feat, dim=reduce_axis)
        else:
            valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
            feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)

    B, C, D, W = list(feat.shape)

    if pca:
        feat_pca = get_feat_pca(feat, valid)
        return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)

    # feat becomes B x 1 x D x W
    feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
    return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
def summ_scalar(self, name, value):
    """Log a scalar, skipping NaNs and steps that are not due.

    Args:
        name: tag passed to the writer.
        value: Python / numpy number or a torch tensor (tensors are
            detached and converted to numpy first).
    """
    # a direct type check also lets numpy integers and arrays through,
    # which have no .type() method to inspect
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    if np.isnan(value):
        return

    if self.log_freq == 1 or self.save_this or self.save_scalar:
        self.writer.add_scalar(name, value, global_step=self.global_step)
def summ_seg(self, name, seg, only_return=False, frame_id=None, frame_str=None, colormap='tab20', label_colors=None):
    """Summarise an integer label map as a colour image.

    Args:
        name: tag passed to the writer.
        seg: B x H x W tensor of labels.
        only_return, frame_id, frame_str: forwarded to summ_rgb.
        colormap: name of a listed matplotlib colormap, used when
            ``label_colors`` is None.
        label_colors: optional sequence of (r, g, b) values indexed by label.

    Returns:
        None when this step is not being saved; otherwise the result of
        summ_rgb. Labels beyond the number of available colours stay black.
    """
    if not self.save_this:
        return

    B,H,W = seg.shape

    custom_label_colors = label_colors is not None
    if not custom_label_colors:
        # plt.get_cmap is available on all matplotlib releases;
        # cm.get_cmap was removed in matplotlib 3.9
        label_colors = plt.get_cmap(colormap).colors
        label_colors = [[int(i*255) for i in l] for l in label_colors]

    r = torch.zeros_like(seg,dtype=torch.uint8)
    g = torch.zeros_like(seg,dtype=torch.uint8)
    b = torch.zeros_like(seg,dtype=torch.uint8)

    for label in range(0,len(label_colors)):
        # colormap-derived palettes are cycled over their first 20 entries
        label_ = label if custom_label_colors else label % 20
        idx = (seg == label)
        r[idx] = label_colors[label_][0]
        g[idx] = label_colors[label_][1]
        b[idx] = label_colors[label_][2]

    rgb = torch.stack([r,g,b],axis=1)
    return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id, frame_str=frame_str)
def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
    """Draw growing 2-D trajectories onto a video and summarise it.

    Only the first batch element is visualised.

    Args:
        name: tag passed to the writer.
        trajs: B x S x N x 2 point tracks.
        rgbs: B x S x C x H x W frames.
        visibs: optional B x S x N visibilities (default: all ones); they
            are rounded and passed to ``draw_circ_on_images_py``.
        valids: optional B x S x N validity flags (default: all ones); a
            track is drawn on frame t only if it is valid there.
        frame_ids, frame_strs, only_return: forwarded to summ_rgbs.
        show_dots, linewidth: forwarded to the drawing helpers.
        cmap: colormap name; 'onediff' draws track 0 with 'spring' and all
            others with 'winter'.
        vals: optional per-track values (indexed as ``vals[0][i]``)
            forwarded to ``draw_traj_on_image_py``.
        max_show: if N exceeds this, ``max_show`` track indices are drawn
            with ``np.random.choice``.

    Returns:
        The result of summ_rgbs.
    """
    B, S, C, H, W = rgbs.shape
    B, S2, N, D = trajs.shape
    assert(S==S2)

    rgbs = rgbs[0] # S, C, H, W
    trajs = trajs[0] # S, N, 2
    valids = torch.ones_like(trajs[:,:,0]) if valids is None else valids[0] # S, N
    visibs = torch.ones_like(trajs[:,:,0]) if visibs is None else visibs[0] # S, N
    if vals is not None:
        vals = vals[0] # N

    if N > max_show:
        inds = np.random.choice(N, max_show)
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        visibs = visibs[:,inds]
        if vals is not None:
            vals = vals[inds]
        N = trajs.shape[1]

    trajs = trajs.clamp(-16, W+16)

    # one H x W x 3 numpy frame per time step
    frames = [
        np.transpose(back2color(rgb).detach().cpu().numpy(), [1, 2, 0])
        for rgb in rgbs
    ]

    def _cmap_for(index):
        if cmap != 'onediff':
            return cmap
        return 'spring' if index == 0 else 'winter'

    n_shown = min(N, max_show)

    # first pass: the trail of every track, up to and including frame t
    for i in range(n_shown):
        cmap_ = _cmap_for(i)
        traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
        valid = valids[:,i].long().detach().cpu().numpy() # S
        val = None if vals is None else vals[i].float().detach().cpu().numpy()
        for t in range(S):
            if valid[t]:
                frames[t] = self.draw_traj_on_image_py(frames[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)

    # second pass: the per-frame circles, drawn on top of the trails
    for i in range(n_shown):
        frames = self.draw_circ_on_images_py(frames, trajs[:,i], visibs[:,i].round(), S=S, show_dots=show_dots, cmap=_cmap_for(i), linewidth=linewidth)

    annotated = [
        preprocess_color(torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0))
        for frame in frames
    ]
    return self.summ_rgbs(name, annotated, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
    """Draw 2-D trajectories and visibility circles onto a video.

    Only the first batch element is visualised, and at most the first
    ``max_show`` tracks are drawn.

    Args:
        name: tag passed to the writer.
        trajs: B x S x N x 2 point tracks.
        visibles: B x S x N visibilities, passed to
            ``draw_circ_on_images_py``.
        rgbs: B x S x C x H x W frames.
        valids: optional B x S x N validity flags; converted but not
            otherwise used by this method.
        frame_ids, frame_strs, only_return: forwarded to summ_rgbs.
        show_dots, linewidth: forwarded to the drawing helpers.
        cmap: colormap for the trajectories; 'onediff' draws track 0 with
            'spring' and all others with 'winter'. The circles are always
            drawn with ``cmap=None``.

    Returns:
        The result of summ_rgbs.
    """
    B, S, C, H, W = rgbs.shape
    B, S2, N, D = trajs.shape
    assert(S==S2)

    rgbs = rgbs[0] # S, C, H, W
    trajs = trajs[0] # S, N, 2
    visibles = visibles[0] # S, N
    valids = torch.ones_like(trajs[:,:,0]) if valids is None else valids[0] # S, N

    # one H x W x 3 numpy frame per time step
    frames = [
        np.transpose(back2color(rgb).detach().cpu().numpy(), [1, 2, 0])
        for rgb in rgbs
    ]

    trajs = trajs.long().detach().cpu().numpy() # S, N, 2
    visibles = visibles.float().detach().cpu().numpy() # S, N
    valids = valids.long().detach().cpu().numpy() # S, N

    def _cmap_for(index):
        if cmap != 'onediff':
            return cmap
        return 'spring' if index == 0 else 'winter'

    n_shown = min(N, max_show)

    for i in range(n_shown):
        frames = self.draw_traj_on_images_py(frames, trajs[:,i], S=S, show_dots=show_dots, cmap=_cmap_for(i), linewidth=linewidth)

    for i in range(n_shown):
        frames = self.draw_circ_on_images_py(frames, trajs[:,i], visibles[:,i], S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)

    annotated = [
        preprocess_color(torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0))
        for frame in frames
    ]
    return self.summ_rgbs(name, annotated, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
    """Draw full trajectories onto a single image and summarize the result.

    Only the first batch element is visualized. A trajectory is drawn only if
    it is valid at its first timestep; its invalid timesteps are dropped
    before drawing.

    Args:
        name: summary tag, forwarded to `self.summ_rgb`.
        trajs: tensor shaped B, S, N, 2; column 0 is used as x and column 1 as y.
        rgb: tensor shaped B, C, H, W, in whatever range `back2color` expects.
        valids: optional tensor shaped B, S, N; defaults to all-valid.
        show_dots, show_lines, linewidth: forwarded to `draw_traj_on_image_py`.
        frame_id, frame_str, only_return: forwarded to `self.summ_rgb`.
        cmap: 'onediff' maps to 'spring' for the first trajectory and 'winter'
            for the rest; any other value is used as the colormap name.
        max_show: if N exceeds this, a random subset of this many distinct
            trajectories is drawn.

    Returns:
        Whatever `self.summ_rgb` returns.
    """
    # trajs is B, S, N, 2
    # rgb is B, C, H, W
    B, C, H, W = rgb.shape
    B, S, N, D = trajs.shape

    rgb = rgb[0] # C, H, W
    trajs = trajs[0] # S, N, 2
    if valids is None:
        valids = torch.ones_like(trajs[:,:,0]) # S, N
    else:
        valids = valids[0]

    rgb_color = back2color(rgb).detach().cpu().numpy()
    rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last

    # using maxdist will dampen the colors for short motions;
    # it is left as None here, so colors follow the timestep index
    maxdist = None

    trajs = trajs.long().detach().cpu().numpy() # S, N, 2
    valids = valids.long().detach().cpu().numpy() # S, N

    if N > max_show:
        # sample without replacement, so that no trajectory is drawn twice
        inds = np.random.choice(N, max_show, replace=False)
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        N = trajs.shape[1]

    for i in range(min(N, max_show)):
        if cmap=='onediff':
            cmap_ = 'spring' if i==0 else 'winter'
        else:
            cmap_ = cmap
        traj = trajs[:,i] # S, 2
        valid = valids[:,i] # S
        if valid[0]==1:
            traj = traj[valid>0]
            rgb_color = self.draw_traj_on_image_py(
                rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)

    rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
    rgb = preprocess_color(rgb_color)
    return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
    """Draw one trajectory onto a copy of an image, colored step by step.

    The color of each step is picked from `cmap` using, in priority order:
      1. `val`, if given (the same color for every step);
      2. the distance of the step from the first point divided by `maxdist`,
         clipped to [0, 1], if `maxdist` is given;
      3. otherwise the step index divided by max(1, S-2).

    Args:
        rgb: numpy array shaped H, W, 3; it is cast to uint8 and copied.
        traj: numpy array shaped S1, 2; column 0 is x and column 1 is y.
        S: nominal sequence length, used only to normalize the step index.
        linewidth: line thickness, also used as the dot radius.
        show_dots: draw a filled circle at every step.
        show_lines: draw a segment from every step to the next one.
        cmap: matplotlib colormap name.
        val: optional fixed colormap position.
        maxdist: optional distance that maps to the end of the colormap.

    Returns:
        The annotated uint8 image, shaped H, W, 3.
    """
    # all inputs are numpy tensors
    # rgb is H x W x 3
    # traj is S1 x 2
    H, W, C = rgb.shape
    assert(C==3)

    rgb = rgb.astype(np.uint8).copy()

    S1, D = traj.shape
    assert(D==2)

    color_map = cm.get_cmap(cmap)
    for s in range(S1):
        # keep the colormap position in its own local: writing it back into
        # `val` would make every later step reuse the value from step 0
        if val is not None:
            color_pos = val
        elif maxdist is not None:
            color_pos = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
        else:
            color_pos = (s)/max(1,float(S-2))
        color = np.array(color_map(color_pos)[:3]) * 255 # rgb

        if show_lines and s<(S1-1):
            cv2.line(rgb,
                     (int(traj[s,0]), int(traj[s,1])),
                     (int(traj[s+1,0]), int(traj[s+1,1])),
                     color,
                     linewidth,
                     cv2.LINE_AA)
        if show_dots:
            cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)

    return rgb
def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
    """Draw a growing polyline of one trajectory across a list of frames.

    Frame s receives the polyline through points 0..s. The line color is
    sampled from the first frame at the (clipped) starting point of the
    trajectory. `show_dots`, `cmap` and `maxdist` are accepted but not used.

    Args:
        rgbs: list of numpy images shaped H, W, 3; each is cast to uint8 and copied.
        traj: numpy array shaped S1, 2; column 0 is x and column 1 is y.
        S: number of frames to draw on; frames 0..S-1 are indexed.
        linewidth: polyline thickness.

    Returns:
        The list of annotated uint8 frames.
    """
    # all inputs are numpy tensors
    # rgbs is a list of H,W,3
    # traj is S1,2
    height, width, channels = rgbs[0].shape
    assert(channels==3)

    frames = [frame.astype(np.uint8).copy() for frame in rgbs]

    num_steps, dims = traj.shape
    assert(dims==2)

    # sample the line color from the first frame, at the trajectory's start
    x0 = int(np.clip(traj[0,0], 0, width-1))
    y0 = int(np.clip(traj[0,1], 0, height-1))
    pixel = frames[0][y0,x0]
    line_color = (int(pixel[0]), int(pixel[1]), int(pixel[2]))

    for end in range(1, S+1):
        # frame end-1 shows the path travelled so far
        cv2.polylines(frames[end-1],
                      [traj[:end]],
                      False,
                      line_color,
                      linewidth,
                      cv2.LINE_AA)
    return frames
def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
    """Draw one circle outline per point onto a copy of an image.

    Points are clipped to the image bounds before drawing.

    NOTE(review): `linewidth` is what gets passed to cv2.circle as the circle
    radius, the outline thickness is fixed at 3, and `radius`, `show_dots` and
    `maxdist` are not used. Callers depend on the current look, so this is
    documented rather than changed.

    Args:
        rgb: numpy array shaped H, W, 3; it is cast to uint8 and copied.
        xy: numpy array shaped N, 2; column 0 is x and column 1 is y.
        colors: optional sequence of N colors, each indexable with three
            components; defaults to `get_n_colors(N)`.
        linewidth: circle radius in pixels (see the note above).

    Returns:
        The annotated uint8 image, shaped H, W, 3.
    """
    # all inputs are numpy tensors
    # rgb is H,W,3
    # xy is N,2
    height, width, channels = rgb.shape
    assert(channels==3)

    canvas = rgb.astype(np.uint8).copy()

    num_pts, dims = xy.shape
    assert(dims==2)

    # clip in float, then snap to integer pixel coordinates
    pts = xy.astype(np.float32)
    pts[:,0] = np.clip(pts[:,0], 0, width-1)
    pts[:,1] = np.clip(pts[:,1], 0, height-1)
    pts = pts.astype(np.int32)

    palette = get_n_colors(num_pts) if colors is None else colors

    for n in range(num_pts):
        entry = palette[n]
        bgr_or_rgb = (int(entry[0]), int(entry[1]), int(entry[2]))
        center = (int(pts[n,0]), int(pts[n,1]))
        cv2.circle(canvas, center, linewidth, bgr_or_rgb, 3)

    return canvas
def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
    """Draw one trajectory as a dot per frame, with a visibility-coded center.

    On frame s, a filled disc of radius linewidth+2 is drawn at traj[s], and a
    filled gray disc of radius linewidth+1 is drawn on top of it, with gray
    level int(vis[s]*255).

    The outer color is:
      - with cmap=None: a single color for all frames, looked up in
        `ColorMap2d` at the first point's position divided by (W, H);
      - otherwise: `cmap` evaluated at s / max(1, S-2), so it varies over time.

    `show_dots` and `maxdist` are accepted but not used.

    Args:
        rgbs: list of numpy images shaped H, W, 3; each is cast to uint8 and copied.
        traj: numpy array shaped S1, 2; column 0 is x and column 1 is y.
        vis: per-frame visibility values, indexable by frame.
        S: number of frames to draw on; frames 0..S-1 are indexed.
        linewidth: base radius of the dots.
        cmap: optional matplotlib colormap name.

    Returns:
        The list of annotated uint8 frames.
    """
    # all inputs are numpy tensors
    # rgbs is a list of H,W,3
    # traj is S1,2
    H, W, C = rgbs[0].shape
    assert(C==3)

    rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]

    S1, D = traj.shape
    assert(D==2)

    if cmap is None:
        color_map = None
        bremm = ColorMap2d()
        traj_ = traj[0:1].astype(np.float32)
        traj_[:,0] /= float(W)
        traj_[:,1] /= float(H)
        color = bremm(traj_)
        # presumably bremm returns values in [0, 1] — verify against ColorMap2d
        color = (color[0]*255).astype(np.uint8)
        color = (int(color[0]),int(color[1]),int(color[2]))
    else:
        # the colormap object does not depend on the frame; look it up once
        color_map = cm.get_cmap(cmap)

    for s in range(S):
        if color_map is not None:
            color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
        center = (int(traj[s,0]), int(traj[s,1]))
        cv2.circle(rgbs[s], center, linewidth+2, color, -1)
        vis_color = int(np.squeeze(vis[s])*255)
        vis_color = (vis_color,vis_color,vis_color)
        cv2.circle(rgbs[s], center, linewidth+1, vis_color, -1)

    return rgbs
def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, frame_str=None, only_return=False, show_circ=True, trajs_g=None, valids_g=None, is_g=False, anchor_ind=None, ara=None, anchors=None, frame_ids=None, frame_strs=None):
    """Summarize a single trajectory as a grid of crops centered on the track.

    Every frame of the first batch element is zero-padded by 128 pixels on
    each side, and a 128x128 crop centered on the (clipped) estimated point is
    taken from it. The crops are tiled into one image with
    `torchvision.utils.make_grid` and passed to `self.summ_rgb`.

    Args:
        name: summary tag, forwarded to `self.summ_rgb`.
        trajs_e: tensor shaped B, S, 1, 2 with the estimated track (x, y).
        rgbs: tensor indexed as [0, s] to give a C, H, W frame, in whatever
            range `back2color` expects.
        frame_id, frame_str, only_return: forwarded to `self.summ_rgb`.
        show_circ: mark the estimated point; green if `is_g`, magenta otherwise.
        trajs_g: optional second track (same layout as `trajs_e`), marked in green.
        valids_g: optional validity for `trajs_g`, indexed as [0][s, n];
            defaults to all-valid.
        is_g: selects the marker color of the estimated point.
        anchor_ind: frame index that gets a 4-pixel green border (and an
            '(A)' label when `frame_ids` is given).
        ara: optional container of frame indices that get a 1-pixel green border.
        anchors: optional container of frame indices whose point additionally
            gets a larger white circle.
        frame_ids: optional per-frame ids drawn onto each crop.
        frame_strs: optional per-frame strings drawn onto each crop.

    Returns:
        Whatever `self.summ_rgb` returns.

    Raises:
        ValueError: if a crop comes out empty.
    """
    B, S, N, D = trajs_e.shape
    assert(N==1)
    assert(D==2)

    rgbs = back2color(rgbs).detach().cpu().byte().numpy()
    rgbs_vis = []
    n = 0
    pad_amount = 128
    green = (0, 255, 0)

    # shift coordinates into the padded frame; the addition also yields a
    # fresh array, so the in-place clipping below never touches the input
    trajs_e_py = trajs_e[0].detach().cpu().numpy()
    trajs_e_py = trajs_e_py + pad_amount

    if trajs_g is not None:
        trajs_g_py = trajs_g[0].detach().cpu().numpy()
        trajs_g_py = trajs_g_py + pad_amount
        if valids_g is not None:
            valids_g_py = valids_g[0].detach().cpu().numpy()
        else:
            # trajs_g_py is S, N, 2, so the validity default is S, N
            valids_g_py = np.ones_like(trajs_g_py[:,:,0])

    for s in range(S):
        rgb = rgbs[0,s]
        rgb = np.transpose(rgb,(1,2,0)) # H, W, 3
        rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0)))
        H, W, C = rgb.shape

        if trajs_g is not None:
            xy_g = trajs_g_py[s,n]
            xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount)
            xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount)
            if valids_g_py[s,n] > 0:
                rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)

        # keep the center inside the original image area of the padded frame
        xy_e = trajs_e_py[s,n]
        xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount)
        xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount)

        if show_circ:
            if (anchors is not None) and (s in anchors):
                rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,255,255)], linewidth=4, radius=8)
            if is_g:
                rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
            else:
                rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3)

        xmin = int(xy_e[0])-pad_amount//2
        xmax = int(xy_e[0])+pad_amount//2
        ymin = int(xy_e[1])-pad_amount//2
        ymax = int(xy_e[1])+pad_amount//2
        rgb_ = rgb[ymin:ymax, xmin:xmax]

        H_, W_ = rgb_.shape[:2]
        if H_==0 or W_==0:
            # fail loudly instead of dropping into an interactive debugger
            raise ValueError(
                'empty crop at frame %d: x range [%d, %d), y range [%d, %d), padded image is %d x %d'
                % (s, xmin, xmax, ymin, ymax, H, W))

        if (ara is not None) and (s in ara):
            # green border, one pixel wide
            rgb_[0,:] = green
            rgb_[-1,:] = green
            rgb_[:,0] = green
            rgb_[:,-1] = green

        if (anchor_ind is not None) and (s==anchor_ind):
            # inner green border
            border = 4
            rgb_[:border,:] = green
            rgb_[-border:,:] = green
            rgb_[:,:border] = green
            rgb_[:,-border:] = green

        rgb_ = rgb_.transpose(2,0,1)
        rgb_ = torch.from_numpy(rgb_)

        if frame_ids is not None:
            rgb_ = draw_frame_id_on_vis(rgb_.unsqueeze(0), frame_ids[s]).squeeze(0)
            if s==anchor_ind:
                rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), '(A)').squeeze(0)

        if frame_strs is not None:
            rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), frame_strs[s]).squeeze(0)

        rgbs_vis.append(rgb_)

    # make the grid somewhat wider than it is tall
    nrow = int(np.sqrt(S)*1.5)
    grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0)
    return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, frame_str=frame_str, only_return=only_return)
def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
    """Draw every valid point of every trajectory onto a single image.

    Only the first batch element is visualized. Visible points are drawn as
    filled circles, non-visible ones as outlines of thickness 2; invalid
    points are skipped. Unless `already_sorted` is set, trajectories are
    ordered by their mean y coordinate before colors are assigned.

    Args:
        name: summary tag, forwarded to `self.summ_rgb`.
        trajs: tensor shaped B, S, N, 2; column 0 is x and column 1 is y.
        rgb: tensor shaped B, C, H, W, in whatever range `back2color` expects.
        visibs: optional tensor shaped B, S, N; defaults to all-visible.
        valids: optional tensor shaped B, S, N; defaults to all-valid.
        frame_id, frame_str, only_return: forwarded to `self.summ_rgb`.
        show_dots: accepted but not used.
        colors: optional per-trajectory colors, indexed by position after
            subsampling and sorting; overrides `cmap`.
        cmap: matplotlib colormap name, sampled at i/(N-1) for trajectory i;
            'onediff' uses 'spring' for the first trajectory and 'winter'
            for the rest.
        linewidth: circle radius in pixels.
        max_show: if N exceeds this, a random subset of this many distinct
            trajectories is drawn.
        already_sorted: skip the sort by mean y.

    Returns:
        Whatever `self.summ_rgb` returns.
    """
    # trajs is B, S, N, 2
    # rgb is B, C, H, W
    B, C, H, W = rgb.shape
    B, S, N, D = trajs.shape

    rgb = rgb[0] # C, H, W
    trajs = trajs[0] # S, N, 2
    if valids is None:
        valids = torch.ones_like(trajs[:,:,0]) # S, N
    else:
        valids = valids[0]
    if visibs is None:
        visibs = torch.ones_like(trajs[:,:,0]) # S, N
    else:
        visibs = visibs[0]

    # keep far-away points just outside the canvas; x is bounded by the
    # width and y by the height, so a tall image does not pull points in
    trajs = torch.stack([
        trajs[:,:,0].clamp(-16, W+16),
        trajs[:,:,1].clamp(-16, H+16),
    ], dim=-1) # S, N, 2

    if N > max_show:
        # sample without replacement, so that no trajectory is drawn twice
        inds = np.random.choice(N, max_show, replace=False)
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        visibs = visibs[:,inds]
        N = trajs.shape[1]

    if not already_sorted:
        inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        visibs = visibs[:,inds]

    rgb = back2color(rgb).detach().cpu().numpy()
    rgb = np.transpose(rgb, [1, 2, 0]) # put channels last

    trajs = trajs.long().detach().cpu().numpy() # S, N, 2
    valids = valids.long().detach().cpu().numpy() # S, N
    visibs = visibs.long().detach().cpu().numpy() # S, N

    rgb = rgb.astype(np.uint8).copy()

    if colors is None:
        # resolve the colormap(s) once instead of once per trajectory
        if cmap=='onediff':
            first_map = cm.get_cmap('spring')
            other_map = cm.get_cmap('winter')
        else:
            first_map = other_map = cm.get_cmap(cmap)

    for i in range(min(N, max_show)):
        traj = trajs[:,i] # S,2
        valid = valids[:,i] # S
        visib = visibs[:,i] # S

        if colors is None:
            ii = i/(1e-4+N-1.0)
            color_map = first_map if i==0 else other_map
            color = np.array(color_map(ii)[:3]) * 255 # rgb
        else:
            color = np.array(colors[i]).astype(np.int64)
        color = (int(color[0]),int(color[1]),int(color[2]))

        for s in range(S):
            if valid[s]:
                # filled when visible, outlined otherwise
                thickness = -1 if visib[s] else 2
                cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)

    rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
    rgb = preprocess_color(rgb)
    return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
    """Draw each trajectory's point for frame s onto frame s of a video.

    Only the first batch element is visualized. Visible points are drawn as
    filled circles, non-visible ones as outlines of thickness 2; invalid
    points are skipped. Trajectories are ordered by their mean y coordinate
    before colors are assigned.

    Args:
        name: summary tag, forwarded to `self.summ_rgbs`.
        trajs: tensor shaped B, S, N, 2; column 0 is x and column 1 is y.
        rgbs: tensor shaped B, S, C, H, W, in whatever range `back2color` expects.
        visibs: optional tensor shaped B, S, N; defaults to all-visible.
        valids: optional tensor shaped B, S, N; defaults to all-valid.
        frame_ids, frame_strs, only_return: forwarded to `self.summ_rgbs`.
        show_dots: accepted but not used.
        cmap: matplotlib colormap name, sampled at i/(N-1) for trajectory i.
        colors: optional per-trajectory colors, indexed by position after
            subsampling and sorting; overrides `cmap`.
        linewidth: circle radius in pixels (converted with int()).
        max_show: if N exceeds this, a random subset of this many distinct
            trajectories is drawn.

    Returns:
        Whatever `self.summ_rgbs` returns.
    """
    # trajs is B, S, N, 2
    # rgbs is B, S, C, H, W
    B, S, C, H, W = rgbs.shape
    B, S2, N, D = trajs.shape
    assert(S==S2)

    rgbs = rgbs[0] # S, C, H, W
    trajs = trajs[0] # S, N, 2
    if valids is None:
        valids = torch.ones_like(trajs[:,:,0]) # S, N
    else:
        valids = valids[0]
    if visibs is None:
        visibs = torch.ones_like(trajs[:,:,0]) # S, N
    else:
        visibs = visibs[0]

    if N > max_show:
        # sample without replacement, so that no trajectory is drawn twice
        inds = np.random.choice(N, max_show, replace=False)
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        visibs = visibs[:,inds]
        N = trajs.shape[1]

    inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
    trajs = trajs[:,inds]
    valids = valids[:,inds]
    visibs = visibs[:,inds]

    # channels-last uint8 copies that cv2 can draw on in place
    rgbs_color = []
    for rgb in rgbs:
        rgb = back2color(rgb).detach().cpu().numpy()
        rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
        rgbs_color.append(rgb.astype(np.uint8).copy()) # each element H x W x C

    trajs = trajs.long().detach().cpu().numpy() # S, N, 2
    valids = valids.long().detach().cpu().numpy() # S, N
    visibs = visibs.long().detach().cpu().numpy() # S, N

    if colors is None:
        # the colormap is the same for every trajectory; resolve it once
        color_map = cm.get_cmap(cmap)

    for i in range(min(N, max_show)):
        traj = trajs[:,i] # S,2
        valid = valids[:,i] # S
        visib = visibs[:,i] # S

        if colors is None:
            ii = i/(1e-4+N-1.0)
            color = np.array(color_map(ii)[:3]) * 255 # rgb
        else:
            color = np.array(colors[i]).astype(np.int64)
        color = (int(color[0]),int(color[1]),int(color[2]))

        for s in range(S):
            if valid[s]:
                # filled when visible, outlined otherwise
                thickness = -1 if visib[s] else 2
                cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)

    rgbs = [
        preprocess_color(torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0))
        for rgb in rgbs_color
    ]
    return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
def erode2d(im, times=1):
    """Erode a soft/binary mask with a 3x3 box, `times` times, per channel.

    Erosion is computed as the complement of a dilation of the complement.
    The convolution zero-pads the complement, so pixels outside the image
    count as foreground and the image border alone does not erode the mask.

    Args:
        im: tensor shaped B, C, H, W with values in [0, 1].
        times: number of 3x3 erosion passes; 0 returns `im` unchanged.

    Returns:
        Tensor with the same shape, dtype and device as `im`.
    """
    channels = im.shape[1]
    # one 3x3 box filter per channel (depthwise), so channels never mix
    weights2d = torch.ones(channels, 1, 3, 3, dtype=im.dtype, device=im.device)
    for _ in range(times):
        im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1, groups=channels).clamp(0, 1)
    return im
def dilate2d(im, times=1):
    """Dilate a soft/binary mask with a (2*times+1)-wide box, per channel.

    A single box filter of side 2*times+1 is applied, and the result is
    clamped to [0, 1].

    Args:
        im: tensor shaped B, C, H, W with values in [0, 1].
        times: dilation radius in pixels; must be positive.

    Returns:
        Tensor with the same shape, dtype and device as `im`.
    """
    assert(times>0)
    channels = im.shape[1]
    dilation_kernel_size = times*2 + 1
    padding_size = dilation_kernel_size // 2
    # one box filter per channel, matching groups=channels, and in the
    # input's dtype so that non-float32 masks work as well
    dilation_kernel = torch.ones(
        (channels, 1, dilation_kernel_size, dilation_kernel_size),
        dtype=im.dtype, device=im.device)
    im = F.conv2d(im, dilation_kernel, padding=padding_size, groups=channels).clamp(0,1)
    return im