import torch import numpy as np import utils.basic import utils.py from sklearn.decomposition import PCA from matplotlib import cm import matplotlib.pyplot as plt import cv2 import torch.nn.functional as F import torchvision EPS = 1e-6 from skimage.color import ( rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb, rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb) def _convert(input_, type_): return { 'float': input_.float(), 'double': input_.double(), }.get(type_, input_) def _generic_transform_sk_3d(transform, in_type='', out_type=''): def apply_transform_individual(input_): device = input_.device input_ = input_.cpu() input_ = _convert(input_, in_type) input_ = input_.permute(1, 2, 0).detach().numpy() transformed = transform(input_) output = torch.from_numpy(transformed).float().permute(2, 0, 1) output = _convert(output, out_type) return output.to(device) def apply_transform(input_): to_stack = [] for image in input_: to_stack.append(apply_transform_individual(image)) return torch.stack(to_stack) return apply_transform hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb) def preprocess_color_tf(x): import tensorflow as tf return tf.cast(x,tf.float32) * 1./255 - 0.5 def preprocess_color(x): if isinstance(x, np.ndarray): return x.astype(np.float32) * 1./255 - 0.5 else: return x.float() * 1./255 - 0.5 def pca_embed(emb, keep, valid=None): ## emb -- [S,H/2,W/2,C] ## keep is the number of principal components to keep ## Helper function for reduce_emb. emb = emb + EPS #emb is B x C x H x W emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C if valid: valid = valid.cpu().detach().numpy().reshape((H*W)) emb_reduced = list() B, H, W, C = np.shape(emb) for img in emb: if np.isnan(img).any(): emb_reduced.append(np.zeros([H, W, keep])) continue pixels_kd = np.reshape(img, (H*W, C)) if valid: pixels_kd_pca = pixels_kd[valid] else: pixels_kd_pca = pixels_kd P = PCA(keep) P.fit(pixels_kd_pca) if valid: pixels3d = P.transform(pixels_kd)*valid else: pixels3d = P.transform(pixels_kd) out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32) if np.isnan(out_img).any(): emb_reduced.append(np.zeros([H, W, keep])) continue emb_reduced.append(out_img) emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32) return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2) def pca_embed_together(emb, keep): ## emb -- [S,H/2,W/2,C] ## keep is the number of principal components to keep ## Helper function for reduce_emb. emb = emb + EPS #emb is B x C x H x W emb = emb.permute(0, 2, 3, 1).cpu().detach().float().numpy() #this is B x H x W x C B, H, W, C = np.shape(emb) if np.isnan(emb).any(): return torch.zeros(B, keep, H, W) pixelskd = np.reshape(emb, (B*H*W, C)) P = PCA(keep) P.fit(pixelskd) pixels3d = P.transform(pixelskd) out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32) if np.isnan(out_img).any(): return torch.zeros(B, keep, H, W) return torch.from_numpy(out_img).permute(0, 3, 1, 2) def reduce_emb(emb, valid=None, inbound=None, together=False): ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2] ## Reduce number of chans to 3 with PCA. For vis. # S,H,W,C = emb.shape.as_list() S, C, H, W = list(emb.size()) keep = 4 if together: reduced_emb = pca_embed_together(emb, keep) else: reduced_emb = pca_embed(emb, keep, valid) #not im reduced_emb = reduced_emb[:,1:] reduced_emb = utils.basic.normalize(reduced_emb) - 0.5 if inbound is not None: emb_inbound = emb*inbound else: emb_inbound = None return reduced_emb, emb_inbound def get_feat_pca(feat, valid=None): B, C, D, W = list(feat.size()) # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function. pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True) # pca is B x 3 x W x D return pca def gif_and_tile(ims, just_gif=False): S = len(ims) # each im is B x H x W x C # i want a gif in the left, and the tiled frames on the right # for the gif tool, this means making a B x S x H x W tensor # where the leftmost part is sequential and the rest is tiled gif = torch.stack(ims, dim=1) if just_gif: return gif til = torch.cat(ims, dim=2) til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1) im = torch.cat([gif, til], dim=3) return im def back2color(i, blacken_zeros=False): if blacken_zeros: const = torch.tensor([-0.5]) i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i) return back2color(i) else: return ((i+0.5)*255).type(torch.ByteTensor) def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False): # xy is B x N x 2, containing float x and y coordinates of N things # grid_xs and grid_ys are B x N x Y x X B, N, Y, X = list(grid_xs.shape) mu_x = xy[:,:,0].clone() mu_y = xy[:,:,1].clone() x_valid = (mu_x>-0.5) & (mu_x-0.5) & (mu_y 0.5).float() return prior def seq2color(im, norm=True, colormap='coolwarm'): B, S, H, W = list(im.shape) # S is sequential # prep a mask of the valid pixels, so we can blacken the invalids later mask = torch.max(im, dim=1, keepdim=True)[0] # turn the S dim into an explicit sequence coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S) # # increase the spacing from the center # coeffs[:int(S/2)] -= 2.0 # coeffs[int(S/2)+1:] += 2.0 coeffs = torch.from_numpy(coeffs).float().cuda() coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W) # scale each channel by the right coeff im = im * coeffs # now im is in [1/S, 1], except for the invalid parts which are 0 # keep the highest valid coeff at each pixel im = torch.max(im, dim=1, keepdim=True)[0] out = [] for b in range(B): im_ = im[b] # move channels out to last dim_ im_ = im_.detach().cpu().numpy() im_ = np.squeeze(im_) # im_ is H x W if colormap=='coolwarm': im_ = cm.coolwarm(im_)[:, :, :3] elif colormap=='PiYG': im_ = cm.PiYG(im_)[:, :, :3] elif colormap=='winter': im_ = cm.winter(im_)[:, :, :3] elif colormap=='spring': im_ = cm.spring(im_)[:, :, :3] elif colormap=='onediff': im_ = np.reshape(im_, (-1)) im0_ = cm.spring(im_)[:, :3] im1_ = cm.winter(im_)[:, :3] im1_[im_==1/float(S)] = im0_[im_==1/float(S)] im_ = np.reshape(im1_, (H, W, 3)) else: assert(False) # invalid colormap # move channels into dim 0 im_ = np.transpose(im_, [2, 0, 1]) im_ = torch.from_numpy(im_).float().cuda() out.append(im_) out = torch.stack(out, dim=0) # blacken the invalid pixels, instead of using the 0-color out = out*mask # out = out*255.0 # put it in [-0.5, 0.5] out = out - 0.5 return out def colorize(d): # this is actually just grayscale right now if d.ndim==2: d = d.unsqueeze(dim=0) else: assert(d.ndim==3) # color_map = cm.get_cmap('plasma') color_map = cm.get_cmap('inferno') # S1, D = traj.shape # print('d1', d.shape) C,H,W = d.shape assert(C==1) d = d.reshape(-1) d = d.detach().cpu().numpy() # print('d2', d.shape) color = np.array(color_map(d)) * 255 # rgba # print('color1', color.shape) color = np.reshape(color[:,:3], [H*W, 3]) # print('color2', color.shape) color = torch.from_numpy(color).permute(1,0).reshape(3,H,W) # # gather # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') # if cmap=='RdBu' or cmap=='RdYlGn': # colors = cm(np.arange(256))[:, :3] # else: # colors = cm.colors # colors = np.array(colors).astype(np.float32) # colors = np.reshape(colors, [-1, 3]) # colors = tf.constant(colors, dtype=tf.float32) # value = tf.gather(colors, indices) # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255) # copy to the three chans # d = d.repeat(3, 1, 1) return color def oned2inferno(d, norm=True, do_colorize=False): # convert a 1chan input to a 3chan image output # if it's just B x H x W, add a C dim if d.ndim==3: d = d.unsqueeze(dim=1) # d should be B x C x H x W, where C=1 B, C, H, W = list(d.shape) assert(C==1) if norm: d = utils.basic.normalize(d) if do_colorize: rgb = torch.zeros(B, 3, H, W) for b in list(range(B)): rgb[b] = colorize(d[b]) else: rgb = d.repeat(1, 3, 1, 1)*255.0 # rgb = (255.0*rgb).type(torch.ByteTensor) rgb = rgb.type(torch.ByteTensor) # rgb = tf.cast(255.0*rgb, tf.uint8) # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3]) # rgb = tf.expand_dims(rgb, axis=0) return rgb def oned2gray(d, norm=True): # convert a 1chan input to a 3chan image output # if it's just B x H x W, add a C dim if d.ndim==3: d = d.unsqueeze(dim=1) # d should be B x C x H x W, where C=1 B, C, H, W = list(d.shape) assert(C==1) if norm: d = utils.basic.normalize(d) rgb = d.repeat(1,3,1,1) rgb = (255.0*rgb).type(torch.ByteTensor) return rgb def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True): rgb = vis.detach().cpu().numpy()[0] rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) color = (255, 255, 255) # print('putting frame id', frame_id) frame_str = utils.basic.strnum(frame_id) text_color_bg = (0,0,0) font = cv2.FONT_HERSHEY_SIMPLEX text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) text_w, text_h = text_size if shadow: cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1) cv2.putText( rgb, frame_str, (left, top), # from left, from top font, scale, # font scale (float) color, 1) # font thickness (int) rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) return vis def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True): rgb = vis.detach().cpu().numpy()[0] rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) color = (255, 255, 255) text_color_bg = (0,0,0) font = cv2.FONT_HERSHEY_SIMPLEX text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) text_w, text_h = text_size if shadow: cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1) cv2.putText( rgb, frame_str, (left, top), # from left, from top font, scale, # font scale (float) color, 1) # font thickness (int) rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) return vis COLORMAP_FILE = "./utils/bremm.png" class ColorMap2d: def __init__(self, filename=None): self._colormap_file = filename or COLORMAP_FILE self._img = plt.imread(self._colormap_file) self._height = self._img.shape[0] self._width = self._img.shape[1] def __call__(self, X): assert len(X.shape) == 2 output = np.zeros((X.shape[0], 3)) for i in range(X.shape[0]): x, y = X[i, :] xp = int((self._width-1) * x) yp = int((self._height-1) * y) xp = np.clip(xp, 0, self._width-1) yp = np.clip(yp, 0, self._height-1) output[i, :] = self._img[yp, xp] return output def get_n_colors(N, sequential=False): label_colors = [] for ii in range(N): if sequential: rgb = cm.winter(ii/(N-1)) rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] else: # rgb = np.zeros(3) # while np.sum(rgb) < 128: # ensure min brightness # rgb = np.random.randint(0,256,3) rgb = cm.gist_rainbow(ii/(N-1)) rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] label_colors.append(rgb) return label_colors class Summ_writer(object): def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False): self.writer = writer self.global_step = global_step self.log_freq = log_freq self.scalar_freq = scalar_freq self.fps = fps self.just_gif = just_gif self.maxwidth = 10000 self.save_this = (self.global_step % self.log_freq == 0) self.scalar_freq = max(scalar_freq,1) self.save_scalar = (self.global_step % self.scalar_freq == 0) if self.save_this: self.save_scalar = True def summ_gif(self, name, tensor, blacken_zeros=False): # tensor should be in B x S x C x H x W assert tensor.dtype in {torch.uint8,torch.float32} shape = list(tensor.shape) if tensor.dtype == torch.float32: tensor = back2color(tensor, blacken_zeros=blacken_zeros) video_to_write = tensor[0:1] S = video_to_write.shape[1] if S==1: # video_to_write is 1 x 1 x C x H x W self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step) else: self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step) return video_to_write def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1): B, C, H, W = list(rgb.shape) assert(C==3) B2, N, D = list(boxlist.shape) assert(B2==B) assert(D==4) # ymin, xmin, ymax, xmax rgb = back2color(rgb) if scores is None: scores = torch.ones(B2, N).float() if tids is None: # tids = torch.arange(N).reshape(1,N).repeat(B2,1).long() tids = torch.zeros(B2, N).long() out = self.draw_boxlist2d_on_image_py( rgb[0].cpu().detach().numpy(), boxlist[0].cpu().detach().numpy(), scores[0].cpu().detach().numpy(), tids[0].cpu().detach().numpy(), linewidth=linewidth) out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1) out = torch.unsqueeze(out, dim=0) out = preprocess_color(out) out = torch.reshape(out, [1, C, H, W]) return out def draw_boxlist2d_on_image_py(self, rgb, boxlist, scorelist, tidlist, linewidth=1): # all inputs are numpy tensors # rgb is H x W x 3 # boxlist is N x 4 # scorelist is N # tidlist is N rgb = np.transpose(rgb, [1, 2, 0]) # put channels last # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) rgb = rgb.astype(np.uint8).copy() H, W, C = rgb.shape assert(C==3) N, D = boxlist.shape assert(D==4) M = scorelist.shape[0] assert(M==N) O = tidlist.shape[0] assert(M==O) # color_map = cm.get_cmap('Accent') # color_map = cm.get_cmap('Set3') color_map = cm.get_cmap('tab20') color_map = color_map.colors # print('color_map', color_map) # draw for (box, score, tid) in zip(boxlist, scorelist, tidlist): # box is 4 if not np.isclose(score, 0.0, atol=1e-3): # ymin, xmin, ymax, xmax = box xmin, ymin, xmax, ymax = box color = color_map[tid] color = np.array(color)*255.0 color = color.round() if not np.isclose(score, 1.0, atol=1e-3): cv2.putText(rgb, # '%d (%.2f)' % (tidlist[ind], scorelist[ind]), '%.2f' % (score), (int(xmin), int(ymin)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, # font size color), xmin = int(np.clip(xmin,0,W-1)) ymin = int(np.clip(ymin,0,H-1)) xmax = int(np.clip(xmax,0,W-1)) ymax = int(np.clip(ymax,0,H-1)) # print('xmin, xmax, ymin, ymax', xmin, xmax, ymin, ymax) cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_4) cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_4) cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_4) cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_4) # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) return rgb def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, frame_str=None, only_return=False, linewidth=1): B, C, H, W = list(rgb.shape) boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth) return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, frame_str=frame_str, only_return=only_return) def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False): if self.save_this: ims = gif_and_tile(ims, just_gif=self.just_gif) vis = ims assert vis.dtype in {torch.uint8,torch.float32} if vis.dtype == torch.float32: vis = back2color(vis, blacken_zeros) B, S, C, H, W = list(vis.shape) if frame_ids is not None: assert(len(frame_ids)==S) for s in range(S): vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) if frame_strs is not None: assert(len(frame_strs)==S) for s in range(S): vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s]) if int(W) > self.maxwidth: vis = vis[:,:,:,:self.maxwidth] if only_return: return vis else: return self.summ_gif(name, vis, blacken_zeros) def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True): if self.save_this: assert ims.dtype in {torch.uint8,torch.float32} if ims.dtype == torch.float32: ims = back2color(ims, blacken_zeros) #ims is B x C x H x W vis = ims[0:1] # just the first one B, C, H, W = list(vis.shape) if halfres: vis = F.interpolate(vis, scale_factor=0.5) if frame_id is not None: vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow) if frame_str is not None: vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow) if int(W) > self.maxwidth: vis = vis[:,:,:,:self.maxwidth] if only_return: return vis else: return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros) def flow2color(self, flow, clip=0.0): B, C, H, W = list(flow.size()) assert(C==2) flow = flow[0:1].detach() if False: flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2 if clip > 0: clip_flow = clip else: clip_flow = None im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True) # im = utils.py.flow_to_image(flow, convert_to_bgr=True) im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W im = torch.flip(im, dims=[1]).clone() # BGR # # i prefer black bkg # white_pixels = (im == 255).all(dim=1, keepdim=True) # im[white_pixels.expand(-1, 3, -1, -1)] = 0 return im # flow_abs = torch.abs(flow) # flow_mean = flow_abs.mean(dim=[1,2,3]) # flow_std = flow_abs.std(dim=[1,2,3]) if clip==0: clip = torch.max(torch.abs(flow)).item() # if clip: flow = torch.clamp(flow, -clip, clip)/clip # else: # # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2) # # flow_max = flow_mean + flow_std*2 + 1e-10 # # for b in range(B): # # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1) # flow_max = torch.max(flow_abs[b]) # for b in range(B): # flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1) radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W radius_clipped = torch.clamp(radius, 0.0, 1.0) angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0) # hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0) saturation = torch.ones_like(hue) * 0.75 value = radius_clipped hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W #flow = tf.image.hsv_to_rgb(hsv) flow = hsv_to_rgb(hsv) flow = (flow*255.0).type(torch.ByteTensor) # flow = torch.flip(flow, dims=[1]).clone() # BGR return flow def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True): # flow is B x C x D x W if self.save_this: return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow) else: return None def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False): if self.save_this: if bev: B, C, H, _, W = list(ims[0].shape) if reduce_max: ims = [torch.max(im, dim=3)[0] for im in ims] else: ims = [torch.mean(im, dim=3) for im in ims] elif fro: B, C, _, H, W = list(ims[0].shape) if reduce_max: ims = [torch.max(im, dim=2)[0] for im in ims] else: ims = [torch.mean(im, dim=2) for im in ims] if len(ims) != 1: # sequence im = gif_and_tile(ims, just_gif=self.just_gif) else: im = torch.stack(ims, dim=1) # single frame B, S, C, H, W = list(im.shape) if logvis and max_val: max_val = np.log(max_val) im = torch.log(torch.clamp(im, 0)+1.0) im = torch.clamp(im, 0, max_val) im = im/max_val norm = False elif max_val: im = torch.clamp(im, 0, max_val) im = im/max_val norm = False if norm: # normalize before oned2inferno, # so that the ranges are similar within B across S im = utils.basic.normalize(im) im = im.view(B*S, C, H, W) vis = oned2inferno(im, norm=norm, do_colorize=do_colorize) vis = vis.view(B, S, 3, H, W) if frame_ids is not None: assert(len(frame_ids)==S) for s in range(S): vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) if frame_strs is not None: assert(len(frame_strs)==S) for s in range(S): vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s]) if W > self.maxwidth: vis = vis[...,:self.maxwidth] if only_return: return vis else: self.summ_gif(name, vis) def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True): if self.save_this: if bev: B, C, H, _, W = list(im.shape) if max_along_y: im = torch.max(im, dim=3)[0] else: im = torch.mean(im, dim=3) elif fro: B, C, _, H, W = list(im.shape) if max_along_y: im = torch.max(im, dim=2)[0] else: im = torch.mean(im, dim=2) else: B, C, H, W = list(im.shape) im = im[0:1] # just the first one assert(C==1) if logvis and max_val: max_val = np.log(max_val) im = torch.log(im) im = torch.clamp(im, 0, max_val) im = im/max_val norm = False elif max_val: im = torch.clamp(im, 0, max_val)/max_val norm = False vis = oned2inferno(im, norm=norm) if W > self.maxwidth: vis = vis[...,:self.maxwidth] return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow) def summ_4chan(self, name, im, norm=True, frame_id=None, frame_str=None, only_return=False): if self.save_this: B, C, H, W = list(im.shape) im = im[0:1] # just the first one assert(C==4) # d = utils.basic.normalize(d) im0 = im[:,0:1] im1 = im[:,1:2] im2 = im[:,2:3] im3 = im[:,3:4] im0 = utils.basic.normalize(im0).round() im1 = utils.basic.normalize(im1).round() im2 = utils.basic.normalize(im2).round() im3 = utils.basic.normalize(im3).round() # kp_vis = sw.summ_rgbs('tff/2_kp_s%d' % s, kp.unbind(1), only_return=True) # kp_any = (torch.max(kp_vis, dim=2, keepdims=True)[0]).repeat(1, 1, 3, 1, 1) # kp_vis[kp_any==0] = fcp_vis[kp_any==0] # vis0 = oned2inferno(im0, norm=False) # vis1 = oned2inferno(im1, norm=False) # vis2 = oned2inferno(im2, norm=False) # vis3 = oned2inferno(im3, norm=False) # vis0 = self.summ_seg('', im0[:,0:1]*1, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20') # vis1 = self.summ_seg('', im1[:,0:1]*2, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20') # vis2 = self.summ_seg('', im2[:,0:1]*3, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20') # vis3 = self.summ_seg('', im3[:,0:1]*4, only_return=True, frame_id=frame_id, frame_str=frame_str, colormap='tab20') vis0 = self.summ_seg('', im0[:,0]*1, only_return=True, colormap='tab20') vis1 = self.summ_seg('', im1[:,0]*2, only_return=True, colormap='tab20') vis2 = self.summ_seg('', im2[:,0]*3, only_return=True, colormap='tab20') vis3 = self.summ_seg('', im3[:,0]*4, only_return=True, colormap='tab20') # vis_any = (torch.max(vis2, dim=2, keepdims=True)[0]).repeat(1, 1, 3, 1, 1) # vis3[vis_any==0] = fcp_vis[kp_any==0] vis0_any = (torch.max(vis0, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1) vis1_any = (torch.max(vis1, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1) vis2_any = (torch.max(vis2, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1) vis3_any = (torch.max(vis3, dim=1, keepdims=True)[0]).repeat(1, 3, 1, 1) vis0[vis0_any==0] = vis1[vis0_any==0] vis0[vis1_any==0] = vis2[vis1_any==0] vis0[vis2_any==0] = vis3[vis2_any==0] print('vis0', vis0.shape, vis0.device) vis0 = vis0.cpu() # vis = oned2inferno(im, norm=norm) # if W > self.maxwidth: # vis = vis[...,:self.maxwidth] return self.summ_rgb(name, vis0, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return) def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None): if self.save_this: if valids is not None: valids = torch.stack(valids, dim=1) feats = torch.stack(feats, dim=1) # feats leads with B x S x C if feats.ndim==6: # feats is B x S x C x D x H x W if fro: reduce_dim = 3 else: reduce_dim = 4 if valids is None: feats = torch.mean(feats, dim=reduce_dim) else: valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1) feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim) B, S, C, D, W = list(feats.size()) if not pca: # feats leads with B x S x C feats = torch.mean(torch.abs(feats), dim=2, keepdims=True) # feats leads with B x S x 1 feats = torch.unbind(feats, dim=1) return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs) else: __p = lambda x: utils.basic.pack_seqdim(x, B) __u = lambda x: utils.basic.unpack_seqdim(x, B) feats_ = __p(feats) if valids is None: feats_pca_ = get_feat_pca(feats_) else: valids_ = __p(valids) feats_pca_ = get_feat_pca(feats_, valids) feats_pca = __u(feats_pca_) return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs) def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None): if self.save_this: if feat.ndim==5: # B x C x D x H x W if bev: reduce_axis = 3 elif fro: reduce_axis = 2 else: # default to bev reduce_axis = 3 if valid is None: feat = torch.mean(feat, dim=reduce_axis) else: valid = valid.repeat(1, feat.size()[1], 1, 1, 1) feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis) B, C, D, W = list(feat.shape) if not pca: feat = torch.mean(torch.abs(feat), dim=1, keepdims=True) # feat is B x 1 x D x W return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str) else: feat_pca = get_feat_pca(feat, valid) return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str) def summ_scalar(self, name, value): if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()): value = value.detach().cpu().numpy() if not np.isnan(value): if (self.log_freq == 1): self.writer.add_scalar(name, value, global_step=self.global_step) elif self.save_this or self.save_scalar: self.writer.add_scalar(name, value, global_step=self.global_step) def summ_seg(self, name, seg, only_return=False, frame_id=None, frame_str=None, colormap='tab20', label_colors=None): if not self.save_this: return B,H,W = seg.shape if label_colors is None: custom_label_colors = False # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True) label_colors = cm.get_cmap(colormap).colors label_colors = [[int(i*255) for i in l] for l in label_colors] else: custom_label_colors = True # label_colors = matplotlib.cm.get_cmap(colormap).colors # label_colors = [[int(i*255) for i in l] for l in label_colors] # print('label_colors', label_colors) # label_colors = [ # (0, 0, 0), # None # (70, 70, 70), # Buildings # (190, 153, 153), # Fences # (72, 0, 90), # Other # (220, 20, 60), # Pedestrians # (153, 153, 153), # Poles # (157, 234, 50), # RoadLines # (128, 64, 128), # Roads # (244, 35, 232), # Sidewalks # (107, 142, 35), # Vegetation # (0, 0, 255), # Vehicles # (102, 102, 156), # Walls # (220, 220, 0) # TrafficSigns # ] r = torch.zeros_like(seg,dtype=torch.uint8) g = torch.zeros_like(seg,dtype=torch.uint8) b = torch.zeros_like(seg,dtype=torch.uint8) for label in range(0,len(label_colors)): if (not custom_label_colors):# and (N > 20): label_ = label % 20 else: label_ = label idx = (seg == label) r[idx] = label_colors[label_][0] g[idx] = label_colors[label_][1] b[idx] = label_colors[label_][2] rgb = torch.stack([r,g,b],axis=1) return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id, frame_str=frame_str) def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, S, C, H, W = rgbs.shape B, S2, N, D = trajs.shape assert(S==S2) rgbs = rgbs[0] # S, C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:,:,0]) # S, N else: valids = valids[0] if visibs is None: visibs = torch.ones_like(trajs[:,:,0]) # S, N else: visibs = visibs[0] if vals is not None: vals = vals[0] # N # print('vals', vals.shape) if N > max_show: inds = np.random.choice(N, max_show) trajs = trajs[:,inds] valids = valids[:,inds] visibs = visibs[:,inds] if vals is not None: vals = vals[inds] N = trajs.shape[1] trajs = trajs.clamp(-16, W+16) rgbs_color = [] for rgb in rgbs: rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgbs_color.append(rgb) # each element 3 x H x W for i in range(min(N, max_show)): if cmap=='onediff' and i==0: cmap_ = 'spring' elif cmap=='onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:,i].long().detach().cpu().numpy() # S, 2 valid = valids[:,i].long().detach().cpu().numpy() # S # print('traj', traj.shape) # print('valid', valid.shape) if vals is not None: # val = vals[:,i].float().detach().cpu().numpy() # [] val = vals[i].float().detach().cpu().numpy() # [] # print('val', val.shape) else: val = None for t in range(S): if valid[t]: rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth) for i in range(min(N, max_show)): if cmap=='onediff' and i==0: cmap_ = 'spring' elif cmap=='onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:,i] # S,2 vis = visibs[:,i].round() # S valid = valids[:,i] # S rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) rgbs = [] for rgb in rgbs_color: rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgbs.append(preprocess_color(rgb)) return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs) def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, S, C, H, W = rgbs.shape B, S2, N, D = trajs.shape assert(S==S2) rgbs = rgbs[0] # S, C, H, W trajs = trajs[0] # S, N, 2 visibles = visibles[0] # S, N if valids is None: valids = torch.ones_like(trajs[:,:,0]) # S, N else: valids = valids[0] rgbs_color = [] for rgb in rgbs: rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgbs_color.append(rgb) # each element 3 x H x W trajs = trajs.long().detach().cpu().numpy() # S, N, 2 visibles = visibles.float().detach().cpu().numpy() # S, N valids = valids.long().detach().cpu().numpy() # S, N for i in range(min(N, max_show)): if cmap=='onediff' and i==0: cmap_ = 'spring' elif cmap=='onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:,i] # S,2 vis = visibles[:,i] # S valid = valids[:,i] # S rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) for i in range(min(N, max_show)): if cmap=='onediff' and i==0: cmap_ = 'spring' elif cmap=='onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:,i] # S,2 vis = visibles[:,i] # S valid = valids[:,i] # S rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth) rgbs = [] for rgb in rgbs_color: rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgbs.append(preprocess_color(rgb)) return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs) def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024): # trajs is B, S, N, 2 # rgb is B, C, H, W B, C, H, W = rgb.shape B, S, N, D = trajs.shape rgb = rgb[0] # S, C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:,:,0]) else: valids = valids[0] rgb_color = back2color(rgb).detach().cpu().numpy() rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last # using maxdist will dampen the colors for short motions # norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N # maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy() maxdist = None trajs = trajs.long().detach().cpu().numpy() # S, N, 2 valids = valids.long().detach().cpu().numpy() # S, N if N > max_show: inds = np.random.choice(N, max_show) trajs = trajs[:,inds] valids = valids[:,inds] N = trajs.shape[1] for i in range(min(N, max_show)): if cmap=='onediff' and i==0: cmap_ = 'spring' elif cmap=='onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:,i] # S, 2 valid = valids[:,i] # S if valid[0]==1: traj = traj[valid>0] rgb_color = self.draw_traj_on_image_py( rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth) rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0) rgb = preprocess_color(rgb_color) return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str) def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None): # all inputs are numpy tensors # rgb is 3 x H x W # traj is S x 2 H, W, C = rgb.shape assert(C==3) rgb = rgb.astype(np.uint8).copy() S1, D = traj.shape assert(D==2) color_map = cm.get_cmap(cmap) S1, D = traj.shape for s in range(S1): if val is not None: color = np.array(color_map(val)[:3]) * 255 # rgb else: if maxdist is not None: val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1) color = np.array(color_map(val)[:3]) * 255 # rgb else: color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb if show_lines and s<(S1-1): cv2.line(rgb, (int(traj[s,0]), int(traj[s,1])), (int(traj[s+1,0]), int(traj[s+1,1])), color, linewidth, cv2.LINE_AA) if show_dots: cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1) # if maxdist is not None: # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1) # color = np.array(color_map(val)[:3]) * 255 # rgb # else: # # draw the endpoint of traj, using the next color (which may be the last color) # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb # # emphasize endpoint # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1) return rgb def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None): # all inputs are numpy tensors # rgbs is a list of H,W,3 # traj is S,2 H, W, C = rgbs[0].shape assert(C==3) rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] S1, D = traj.shape assert(D==2) x = int(np.clip(traj[0,0], 0, W-1)) y = int(np.clip(traj[0,1], 0, H-1)) color = rgbs[0][y,x] color = (int(color[0]),int(color[1]),int(color[2])) for s in range(S): # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1) cv2.polylines(rgbs[s], [traj[:s+1]], False, color, linewidth, cv2.LINE_AA) return rgbs def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None): # all inputs are numpy tensors # rgbs is a list of 3,H,W # xy is N,2 H, W, C = rgb.shape assert(C==3) rgb = rgb.astype(np.uint8).copy() N, D = xy.shape assert(D==2) xy = xy.astype(np.float32) xy[:,0] = np.clip(xy[:,0], 0, W-1) xy[:,1] = np.clip(xy[:,1], 0, H-1) xy = xy.astype(np.int32) if colors is None: colors = get_n_colors(N) for n in range(N): color = colors[n] # print('color', color) # color = (color[0]*255).astype(np.uint8) color = (int(color[0]),int(color[1]),int(color[2])) # x = int(np.clip(xy[0,0], 0, W-1)) # y = int(np.clip(xy[0,1], 0, H-1)) # color_ = rgbs[0][y,x] # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3) # vis_color = int(np.squeeze(vis[s])*255) # vis_color = (vis_color,vis_color,vis_color) # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1) return rgb def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None): # all inputs are numpy tensors # rgbs is a list of 3,H,W # traj is S,2 H, W, C = rgbs[0].shape assert(C==3) rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] S1, D = traj.shape assert(D==2) if cmap is None: bremm = ColorMap2d() traj_ = traj[0:1].astype(np.float32) traj_[:,0] /= float(W) traj_[:,1] /= float(H) color = bremm(traj_) # print('color', color) color = (color[0]*255).astype(np.uint8) color = (int(color[0]),int(color[1]),int(color[2])) for s in range(S): if cmap is not None: color_map = cm.get_cmap(cmap) # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb # color = color.astype(np.uint8) # color = (color[0], color[1], color[2]) # print('color', color) # import ipdb; ipdb.set_trace() cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1) vis_color = int(np.squeeze(vis[s])*255) vis_color = (vis_color,vis_color,vis_color) cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1) return rgbs def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, frame_str=None, only_return=False, show_circ=True, trajs_g=None, valids_g=None, is_g=False, anchor_ind=None, ara=None, anchors=None, frame_ids=None, frame_strs=None): B, S, N, D = trajs_e.shape assert(N==1) assert(D==2) rgbs = back2color(rgbs).detach().cpu().byte().numpy() rgbs_vis = [] n = 0 pad_amount = 128 trajs_e_py = trajs_e[0].detach().cpu().numpy() # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun trajs_e_py = trajs_e_py + pad_amount if trajs_g is not None: trajs_g_py = trajs_g[0].detach().cpu().numpy() trajs_g_py = trajs_g_py + pad_amount if valids_g is not None: valids_g_py = valids_g[0].detach().cpu().numpy() else: valids_g_py = np.ones_like(trajs_g_py[:,:,:,0]) for s in range(S): rgb = rgbs[0,s] # print('orig rgb', rgb.shape) rgb = np.transpose(rgb,(1,2,0)) # H, W, 3 rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0))) # print('pad rgb', rgb.shape) H, W, C = rgb.shape if trajs_g is not None: xy_g = trajs_g_py[s,n] xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount) xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount) if valids_g_py[s,n] > 0: rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) xy_e = trajs_e_py[s,n] xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount) xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount) if show_circ: # if (anchors is not None) and (s==anchor_ind): # rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,0,0)], linewidth=8, radius=12) if (anchors is not None) and (s in anchors): rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,255,255)], linewidth=4, radius=8) if is_g: rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) else: rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3) xmin = int(xy_e[0])-pad_amount//2 xmax = int(xy_e[0])+pad_amount//2 ymin = int(xy_e[1])-pad_amount//2 ymax = int(xy_e[1])+pad_amount//2 rgb_ = rgb[ymin:ymax, xmin:xmax] H_, W_ = rgb_.shape[:2] # if np.any(rgb_.shape==0): # input() if H_==0 or W_==0: import ipdb; ipdb.set_trace() if (ara is not None) and (s in ara): # green border rgb_[0,:,0] = 0 rgb_[0,:,1] = 255 rgb_[0,:,2] = 0 rgb_[-1,:,0] = 0 rgb_[-1,:,1] = 255 rgb_[-1,:,2] = 0 rgb_[:,0,0] = 0 rgb_[:,0,1] = 255 rgb_[:,0,2] = 0 rgb_[:,-1,0] = 0 rgb_[:,-1,1] = 255 rgb_[:,-1,2] = 0 if (anchor_ind is not None) and (s==anchor_ind): # inner green border pad = 4 rgb_[:pad,:,0] = 0 rgb_[:pad,:,1] = 255 rgb_[:pad,:,2] = 0 rgb_[-pad:,:,0] = 0 rgb_[-pad:,:,1] = 255 rgb_[-pad:,:,2] = 0 rgb_[:,:pad,0] = 0 rgb_[:,:pad,1] = 255 rgb_[:,:pad,2] = 0 rgb_[:,-pad:,0] = 0 rgb_[:,-pad:,1] = 255 rgb_[:,-pad:,2] = 0 rgb_ = rgb_.transpose(2,0,1) rgb_ = torch.from_numpy(rgb_) if frame_ids is not None: # if s==anchor_ind: # frame_ids[s] = frame_ids[s] + ' rgb_ = draw_frame_id_on_vis(rgb_.unsqueeze(0), frame_ids[s]).squeeze(0) if s==anchor_ind: rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), '(A)').squeeze(0) if frame_strs is not None: rgb_ = draw_frame_str_on_vis(rgb_.unsqueeze(0), frame_strs[s]).squeeze(0) rgbs_vis.append(rgb_) # nrow = int(np.sqrt(S)*(16.0/9)/2.0) nrow = int(np.sqrt(S)*1.5) grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0) # print('grid_img', grid_img.shape) return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, frame_str=frame_str, only_return=only_return) def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, C, H, W = rgb.shape B, S, N, D = trajs.shape rgb = rgb[0] # C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:,:,0]) # S, N else: valids = valids[0] if visibs is None: visibs = torch.ones_like(trajs[:,:,0]) # S, N else: visibs = visibs[0] trajs = trajs.clamp(-16, W+16) if N > max_show: inds = np.random.choice(N, max_show) trajs = trajs[:,inds] valids = valids[:,inds] visibs = visibs[:,inds] N = trajs.shape[1] if not already_sorted: inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0)) trajs = trajs[:,inds] valids = valids[:,inds] visibs = visibs[:,inds] rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last trajs = trajs.long().detach().cpu().numpy() # S, N, 2 valids = valids.long().detach().cpu().numpy() # S, N visibs = visibs.long().detach().cpu().numpy() # S, N rgb = rgb.astype(np.uint8).copy() for i in range(min(N, max_show)): if cmap=='onediff' and i==0: cmap_ = 'spring' elif cmap=='onediff': cmap_ = 'winter' else: cmap_ = cmap traj = trajs[:,i] # S,2 valid = valids[:,i] # S visib = visibs[:,i] # S if colors is None: ii = i/(1e-4+N-1.0) color_map = cm.get_cmap(cmap) color = np.array(color_map(ii)[:3]) * 255 # rgb else: color = np.array(colors[i]).astype(np.int64) color = (int(color[0]),int(color[1]),int(color[2])) for s in range(S): if valid[s]: if visib[s]: thickness = -1 else: thickness = 2 cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness) rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0) rgb = preprocess_color(rgb) return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str) def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None): # trajs is B, S, N, 2 # rgbs is B, S, C, H, W B, S, C, H, W = rgbs.shape B, S2, N, D = trajs.shape assert(S==S2) rgbs = rgbs[0] # S, C, H, W trajs = trajs[0] # S, N, 2 if valids is None: valids = torch.ones_like(trajs[:,:,0]) # S, N else: valids = valids[0] if visibs is None: visibs = torch.ones_like(trajs[:,:,0]) # S, N else: visibs = visibs[0] if N > max_show: inds = np.random.choice(N, max_show) trajs = trajs[:,inds] valids = valids[:,inds] visibs = visibs[:,inds] N = trajs.shape[1] inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0)) trajs = trajs[:,inds] valids = valids[:,inds] visibs = visibs[:,inds] rgbs_color = [] for rgb in rgbs: rgb = back2color(rgb).detach().cpu().numpy() rgb = np.transpose(rgb, [1, 2, 0]) # put channels last rgbs_color.append(rgb) # each element 3 x H x W trajs = trajs.long().detach().cpu().numpy() # S, N, 2 valids = valids.long().detach().cpu().numpy() # S, N visibs = visibs.long().detach().cpu().numpy() # S, N rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color] for i in range(min(N, max_show)): traj = trajs[:,i] # S,2 valid = valids[:,i] # S visib = visibs[:,i] # S if colors is None: ii = i/(1e-4+N-1.0) color_map = cm.get_cmap(cmap) color = np.array(color_map(ii)[:3]) * 255 # rgb else: color = np.array(colors[i]).astype(np.int64) color = (int(color[0]),int(color[1]),int(color[2])) for s in range(S): if valid[s]: if visib[s]: thickness = -1 else: thickness = 2 cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness) rgbs = [] for rgb in rgbs_color: rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) rgbs.append(preprocess_color(rgb)) return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs) def erode2d(im, times=1): device = im.device weights2d = torch.ones(1, 1, 3, 3, dtype=im.dtype, device=device) for time in range(times): im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1) return im def dilate2d(im, times=1): device = im.device assert(times>0) dilation_kernel_size = times*2 + 1 padding_size = dilation_kernel_size // 2 dilation_kernel = torch.ones((1, 1, dilation_kernel_size, dilation_kernel_size), device=device) im = F.conv2d(im, dilation_kernel, padding=padding_size, groups=im.shape[1]).clamp(0,1) return im