# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
"""Utils for the few shot vid2vid model."""
import random

import numpy as np
import torch
import torch.nn.functional as F


def resample(image, flow):
    r"""Resamples an image using the provided flow.

    Args:
        image (NxCxHxW tensor) : Image to resample.
        flow (Nx2xHxW tensor) : Optical flow to resample the image.
    Returns:
        output (NxCxHxW tensor) : Resampled image.
    """
    assert flow.shape[1] == 2
    b, c, h, w = image.size()
    grid = get_grid(b, (h, w))
    flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),
                      flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
    final_grid = (grid + flow).permute(0, 2, 3, 1)
    try:
        output = F.grid_sample(image, final_grid, mode='bilinear',
                               padding_mode='border', align_corners=True)
    except Exception:
        output = F.grid_sample(image, final_grid, mode='bilinear',
                               padding_mode='border')
    return output


def get_grid(batchsize, size, minval=-1.0, maxval=1.0):
    r"""Get a grid ranging [-1, 1] of 2D/3D coordinates.

    Args:
        batchsize (int) : Batch size.
        size (tuple) : (height, width) or (depth, height, width).
        minval (float) : minimum value in returned grid.
        maxval (float) : maximum value in returned grid.
    Returns:
        t_grid (4D tensor) : Grid of coordinates.
    """
    if len(size) == 2:
        rows, cols = size
    elif len(size) == 3:
        deps, rows, cols = size
    else:
        raise ValueError('Dimension can only be 2 or 3.')
    x = torch.linspace(minval, maxval, cols)
    x = x.view(1, 1, 1, cols)
    x = x.expand(batchsize, 1, rows, cols)

    y = torch.linspace(minval, maxval, rows)
    y = y.view(1, 1, rows, 1)
    y = y.expand(batchsize, 1, rows, cols)

    t_grid = torch.cat([x, y], dim=1)

    if len(size) == 3:
        z = torch.linspace(minval, maxval, deps)
        z = z.view(1, 1, deps, 1, 1)
        z = z.expand(batchsize, 1, deps, rows, cols)

        t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols)
        t_grid = torch.cat([t_grid, z], dim=1)

    t_grid.requires_grad = False
    return t_grid.to('cuda')


def pick_image(images, idx):
    r"""Pick the image among images according to idx.

    Args:
        images (B x N x C x H x W tensor or list of tensors) : N images.
        idx (B tensor) : indices to select.
    Returns:
        image (B x C x H x W) : Selected images.
    """
    if type(images) == list:
        return [pick_image(r, idx) for r in images]
    if idx is None:
        return images[:, 0]
    elif type(idx) == int:
        return images[:, idx]
    idx = idx.long().view(-1, 1, 1, 1, 1)
    image = images.gather(1, idx.expand_as(images)[:, 0:1])[:, 0]
    return image
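

# The sketch below is illustrative and not part of the original module. It
# documents the coordinate convention assumed by `resample` / `get_grid`:
# flow values are in pixels and are normalized by (w - 1) / 2 and (h - 1) / 2,
# so an all-zero flow reproduces the input. The helper name and toy tensors
# are made up for illustration; a CUDA device is assumed because `get_grid`
# places the grid on 'cuda'.
def _example_resample_identity():
    r"""Illustrative sketch: warping with an all-zero flow returns the input
    (up to interpolation error)."""
    image = torch.rand(2, 3, 8, 8, device='cuda')
    flow = torch.zeros(2, 2, 8, 8, device='cuda')
    warped = resample(image, flow)
    assert torch.allclose(warped, image, atol=1e-4)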


def crop_face_from_data(cfg, is_inference, data):
    r"""Crop the face regions in input data and resize to the target size.
    This is for training face datasets.

    Args:
        cfg (obj): Data configuration.
        is_inference (bool): Is doing inference or not.
        data (dict): Input data.
    Returns:
        data (dict): Cropped data.
    """
    label = data['label'] if 'label' in data else None
    image = data['images']
    landmarks = data['landmarks-dlib68_xy']
    ref_labels = data['few_shot_label'] if 'few_shot_label' in data else None
    ref_images = data['few_shot_images']
    ref_landmarks = data['few_shot_landmarks-dlib68_xy']
    img_size = image.shape[-2:]
    h, w = cfg.output_h_w.split(',')
    h, w = int(h), int(w)

    # When doing inference, need to sync common attributes like crop
    # coordinates between different workers, so all workers crop the same
    # region.
    if 'common_attr' in data and 'crop_coords' in data['common_attr']:
        # Has been computed before, reuse the previous one.
        crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
    else:
        # Is the first frame, need to compute the bbox.
        ref_crop_coords, scale = get_face_bbox_for_data(
            ref_landmarks[0], img_size, None, is_inference)
        crop_coords, _ = get_face_bbox_for_data(
            landmarks[0], img_size, scale, is_inference)

    # Crop the images according to the bbox and resize them to target size.
    label, image = crop_and_resize([label, image], crop_coords, (h, w))
    ref_labels, ref_images = crop_and_resize([ref_labels, ref_images],
                                             ref_crop_coords, (h, w))

    data['images'], data['few_shot_images'] = image, ref_images
    if label is not None:
        data['label'], data['few_shot_label'] = label, ref_labels
    if is_inference:
        if 'common_attr' not in data:
            data['common_attr'] = dict()
        data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
    return data


def get_face_bbox_for_data(keypoints, orig_img_size, scale, is_inference):
    r"""Get the bbox coordinates for the face region.

    Args:
        keypoints (Nx2 tensor): Facial landmarks.
        orig_img_size (int tuple): Height and width of the input image size.
        scale (float): When training, randomly scale the crop size for
            augmentation.
        is_inference (bool): Is doing inference or not.
    Returns:
        crop_coords (list of int): bbox for face region.
        scale (float): Also returns scale to ensure reference and target
            frames are cropped using the same scale.
    """
    min_y, max_y = int(keypoints[:, 1].min()), int(keypoints[:, 1].max())
    min_x, max_x = int(keypoints[:, 0].min()), int(keypoints[:, 0].max())
    x_cen, y_cen = (min_x + max_x) // 2, (min_y + max_y) // 2
    H, W = orig_img_size
    w = h = (max_x - min_x)
    if not is_inference:
        # During training, randomly jitter the cropping position by offset
        # amount for augmentation.
        offset_max = 0.2
        offset = [np.random.uniform(-offset_max, offset_max),
                  np.random.uniform(-offset_max, offset_max)]

        # Also augment the crop size.
        if scale is None:
            scale_max = 0.2
            scale = [np.random.uniform(1 - scale_max, 1 + scale_max),
                     np.random.uniform(1 - scale_max, 1 + scale_max)]
        w *= scale[0]
        h *= scale[1]
        x_cen += int(offset[0] * w)
        y_cen += int(offset[1] * h)

    # Get the cropping coordinates.
    x_cen = max(w, min(W - w, x_cen))
    y_cen = max(h * 1.25, min(H - h * 0.75, y_cen))

    min_x = x_cen - w
    min_y = y_cen - h * 1.25
    max_x = min_x + w * 2
    max_y = min_y + h * 2

    crop_coords = [min_y, max_y, min_x, max_x]
    return [int(x) for x in crop_coords], scale
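

# Illustrative sketch, not part of the original module: a worked example of
# the inference-time crop geometry of `get_face_bbox_for_data`. The crop is a
# square of side 2 * (max_x - min_x) that extends 1.25 * h above the face
# center and 0.75 * h below it. The landmark values below are synthetic.
def _example_face_crop_geometry():
    r"""Illustrative sketch: the inference-time face crop is a square of side
    2 * (max_x - min_x), shifted upward around the face center."""
    # Synthetic 68-point landmarks centered at (400, 300), spanning 100 px
    # horizontally and 80 px vertically.
    keypoints = np.stack([np.linspace(350, 450, 68),
                          np.linspace(260, 340, 68)], axis=1)
    coords, scale = get_face_bbox_for_data(
        keypoints, orig_img_size=(720, 1280), scale=None, is_inference=True)
    min_y, max_y, min_x, max_x = coords
    assert (max_x - min_x, max_y - min_y) == (200, 200)
    assert scale is None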


def crop_person_from_data(cfg, is_inference, data):
    r"""Crop the person regions in data and resize to the target size.
    This is for training full body datasets.

    Args:
        cfg (obj): Data configuration.
        is_inference (bool): Is doing inference or not.
        data (dict): Input data.
    Returns:
        data (dict): Cropped data.
    """
    label = data['label']
    image = data['images']
    use_few_shot = 'few_shot_label' in data
    if use_few_shot:
        ref_labels = data['few_shot_label']
        ref_images = data['few_shot_images']

    img_size = image.shape[-2:]
    output_h, output_w = cfg.output_h_w.split(',')
    output_h, output_w = int(output_h), int(output_w)
    output_aspect_ratio = output_w / output_h

    if 'human_instance_maps' in data:
        # Remove other people in the DensePose map except for the current
        # target.
        label = remove_other_ppl(label, data['human_instance_maps'])
        if use_few_shot:
            ref_labels = remove_other_ppl(
                ref_labels, data['few_shot_human_instance_maps'])

    # Randomly jitter the crop position by offset amount for augmentation.
    offset = ref_offset = None
    if not is_inference:
        offset = np.random.randn(2) * 0.05
        offset = np.minimum(1, np.maximum(-1, offset))
        ref_offset = np.random.randn(2) * 0.02
        ref_offset = np.minimum(1, np.maximum(-1, ref_offset))

    # Randomly scale the crop size for augmentation.
    # Final cropped size = person height * scale.
    scale = ref_scale = 1.5
    if not is_inference:
        scale = min(2, max(1, scale + np.random.randn() * 0.05))
        ref_scale = min(2, max(1, ref_scale + np.random.randn() * 0.02))

    # When doing inference, need to sync common attributes like crop
    # coordinates between different workers, so all workers crop the same
    # region.
    if 'common_attr' in data:
        # Has been computed before, reuse the previous one.
        crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
    else:
        # Is the first frame, need to compute the bbox.
        crop_coords = get_person_bbox_for_data(label, img_size, scale,
                                               output_aspect_ratio, offset)
        if use_few_shot:
            ref_crop_coords = get_person_bbox_for_data(
                ref_labels, img_size, ref_scale, output_aspect_ratio,
                ref_offset)
        else:
            ref_crop_coords = None

    # Crop the images according to the bbox and resize them to target size.
    label = crop_and_resize(label, crop_coords, (output_h, output_w),
                            'nearest')
    image = crop_and_resize(image, crop_coords, (output_h, output_w))
    if use_few_shot:
        ref_labels = crop_and_resize(ref_labels, ref_crop_coords,
                                     (output_h, output_w), 'nearest')
        ref_images = crop_and_resize(ref_images, ref_crop_coords,
                                     (output_h, output_w))

    data['label'], data['images'] = label, image
    if use_few_shot:
        data['few_shot_label'], data['few_shot_images'] = \
            ref_labels, ref_images
    if 'human_instance_maps' in data:
        del data['human_instance_maps']
    if 'few_shot_human_instance_maps' in data:
        del data['few_shot_human_instance_maps']

    if is_inference:
        data['common_attr'] = dict()
        data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
    return data


def get_person_bbox_for_data(pose_map, orig_img_size, scale=1.5,
                             crop_aspect_ratio=1, offset=None):
    r"""Get the bbox (pixel coordinates) to crop for the person body region.

    Args:
        pose_map (NxCxHxW tensor): Input pose map.
        orig_img_size (int tuple): Height and width of the input image size.
        scale (float): When training, randomly scale the crop size for
            augmentation.
        crop_aspect_ratio (float): Output aspect ratio.
        offset (list of float): Offset for crop position.
    Returns:
        crop_coords (list of int): bbox for body region.
    """
    H, W = orig_img_size
    assert pose_map.dim() == 4
    nonzero_indices = (pose_map[:, :3] > 0).nonzero(as_tuple=False)
    if nonzero_indices.size(0) == 0:
        bw = int(H * crop_aspect_ratio // 2)
        return [0, H, W // 2 - bw, W // 2 + bw]

    y_indices, x_indices = nonzero_indices[:, 2], nonzero_indices[:, 3]
    y_min, y_max = y_indices.min().item(), y_indices.max().item()
    x_min, x_max = x_indices.min().item(), x_indices.max().item()
    y_cen = int(y_min + y_max) // 2
    x_cen = int(x_min + x_max) // 2
    y_len = y_max - y_min
    x_len = x_max - x_min

    # bh, bw: half of the height / width of the final cropped size.
    bh = int(min(H, max(H // 2, y_len * scale))) // 2
    bh = max(bh, int(x_len * scale / crop_aspect_ratio) // 2)
    bw = int(bh * crop_aspect_ratio)

    # Randomly offset the cropped position for augmentation.
    if offset is not None:
        x_cen += int(offset[0] * bw)
        y_cen += int(offset[1] * bh)
    x_cen = max(bw, min(W - bw, x_cen))
    y_cen = max(bh, min(H - bh, y_cen))

    return [(y_cen - bh), (y_cen + bh), (x_cen - bw), (x_cen + bw)]
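

# Illustrative sketch, not part of the original module: the bbox returned by
# `get_person_bbox_for_data` follows the requested `crop_aspect_ratio`
# (width = height * ratio, up to integer rounding) and is at least half the
# image height tall. The pose blob below is synthetic.
def _example_person_bbox_aspect_ratio():
    r"""Illustrative sketch: the person crop respects the output aspect
    ratio and covers at least half of the image height."""
    pose_map = torch.zeros(1, 3, 256, 256)
    pose_map[:, :, 100:200, 120:140] = 1.0  # A synthetic person blob.
    ys, ye, xs, xe = get_person_bbox_for_data(
        pose_map, orig_img_size=(256, 256), scale=1.5,
        crop_aspect_ratio=0.5, offset=None)
    assert (xe - xs) == (ye - ys) // 2
    assert (ye - ys) >= 256 // 2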


def crop_and_resize(img, coords, size=None, method='bilinear'):
    r"""Crop the image using the given coordinates and resize to target size.

    Args:
        img (tensor or list of tensors): Input image.
        coords (list of int): Pixel coordinates to crop.
        size (list of int): Output size.
        method (str): Interpolation method.
    Returns:
        img (tensor or list of tensors): Output image.
    """
    if isinstance(img, list):
        return [crop_and_resize(x, coords, size, method) for x in img]
    if img is None:
        return None
    min_y, max_y, min_x, max_x = coords
    img = img[:, :, min_y:max_y, min_x:max_x]
    if size is not None:
        if method == 'nearest':
            img = F.interpolate(img, size=size, mode=method)
        else:
            img = F.interpolate(img, size=size, mode=method,
                                align_corners=False)
    return img


def remove_other_ppl(labels, densemasks):
    r"""Remove other people in the label map except for the current target
    by looking at the id in the densemask map.

    Args:
        labels (NxCxHxW tensor): Input labels.
        densemasks (Nx1xHxW tensor): Densemask maps.
    Returns:
        labels (NxCxHxW tensor): Output labels.
    """
    densemasks = densemasks[:, 0:1] * 255
    for idx in range(labels.shape[0]):
        label, densemask = labels[idx], densemasks[idx]
        # Get OpenPose and find the person id in Densemask that has the most
        # overlap with the person in the OpenPose result.
        openpose = label[3:]
        valid = (openpose[0] > 0) | (openpose[1] > 0) | (openpose[2] > 0)
        dp_valid = densemask[valid.unsqueeze(0)]
        if dp_valid.shape[0]:
            ind = np.bincount(dp_valid).argmax()
            # Remove all other people that have different indices.
            label = label * (densemask == ind).float()
        labels[idx] = label
    return labels


def select_object(data, obj_indices=None):
    r"""Select the object/person in the dict according to the object index.
    Currently it's used to select the target person in the OpenPose dict.

    Args:
        data (dict): Input data.
        obj_indices (list of int): Indices for the objects to select.
    Returns:
        data (dict): Output data.
    """
    op_keys = ['poses-openpose', 'captions-clip']
    for op_key in op_keys:
        if op_key in data:
            for i in range(len(data[op_key])):
                # data[op_key] is a list of dicts for different frames.
                # people = data[op_key][i]['people']
                people = data[op_key][i]
                # "people" is a list of people dicts found by OpenPose. We
                # will use the obj_index to get the target person from the
                # list, and write it back to the dict.
                # data[op_key][i]['people'] = [people[obj_indices[i]]]
                if obj_indices is not None:
                    data[op_key][i] = people[obj_indices[i]]
                else:
                    if op_key == 'poses-openpose':
                        data[op_key][i] = people[0]
                    else:
                        idx = random.randint(0, len(people) - 1)
                        data[op_key][i] = people[idx]
    return data


def concat_frames(prev, now, n_frames):
    r"""Concatenate previous and current frames and only keep the latest
    $(n_frames). If the concatenated frames are longer than $(n_frames),
    drop the oldest one.

    Args:
        prev (NxTxCxHxW tensor): Tensor for previous frames.
        now (NxCxHxW tensor): Tensor for current frame.
        n_frames (int): Max number of frames to store.
    Returns:
        result (NxTxCxHxW tensor): Updated tensor.
    """
    now = now.unsqueeze(1)
    if prev is None:
        return now
    if prev.shape[1] == n_frames:
        prev = prev[:, 1:]
    return torch.cat([prev, now], dim=1)


def combine_fg_mask(fg_mask, ref_fg_mask, has_fg):
    r"""Get the union of target and reference foreground masks.

    Args:
        fg_mask (tensor): Foreground mask for the target image.
        ref_fg_mask (tensor): Foreground mask for the reference image.
        has_fg (bool): Whether the image can be classified into fg/bg.
    Returns:
        output (tensor or int): Combined foreground mask.
    """
    return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1
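

# Illustrative sketch, not part of the original module: `concat_frames` acts
# as a FIFO buffer over the time dimension, keeping at most `n_frames`
# frames. The tensor shapes below are arbitrary.
def _example_concat_frames_buffer():
    r"""Illustrative sketch: ``concat_frames`` keeps at most ``n_frames``
    frames and drops the oldest one when full."""
    buffer = None
    for t in range(5):
        frame = torch.full((1, 3, 4, 4), float(t))
        buffer = concat_frames(buffer, frame, n_frames=3)
    # After 5 frames with a capacity of 3, only frames 2, 3 and 4 remain.
    assert buffer.shape[1] == 3
    assert buffer[0, :, 0, 0, 0].tolist() == [2.0, 3.0, 4.0]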


def get_fg_mask(densepose_map, has_fg):
    r"""Obtain the foreground mask for pose sequences, which only includes
    the human. This is done by looking at the body part map from DensePose.

    Args:
        densepose_map (NxCxHxW tensor): DensePose map.
        has_fg (bool): Whether data has foreground or not.
    Returns:
        mask (Nx1xHxW tensor): fg mask.
    """
    if type(densepose_map) == list:
        return [get_fg_mask(label, has_fg) for label in densepose_map]
    if not has_fg or densepose_map is None:
        return 1
    if densepose_map.dim() == 5:
        densepose_map = densepose_map[:, 0]
    # Get the body part map from DensePose.
    mask = densepose_map[:, 2:3]

    # Make the mask slightly larger.
    mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask)
    mask = (mask > -1).float()
    return mask


def get_part_mask(densepose_map):
    r"""Obtain mask of different body parts of humans. This is done by
    looking at the body part map from DensePose.

    Args:
        densepose_map (NxCxHxW tensor): DensePose map.
    Returns:
        mask (NxKxHxW tensor): Body part mask, where K is the number of parts.
    """
    # Groups of body parts. Each group contains IDs of body part labels in
    # DensePose. The 9 groups here are: background, torso, hands, feet,
    # upper legs, lower legs, upper arms, lower arms, head.
    part_groups = [[0], [1, 2], [3, 4], [5, 6], [7, 9, 8, 10],
                   [11, 13, 12, 14], [15, 17, 16, 18], [19, 21, 20, 22],
                   [23, 24]]
    n_parts = len(part_groups)

    need_reshape = densepose_map.dim() == 4
    if need_reshape:
        bo, t, h, w = densepose_map.size()
        densepose_map = densepose_map.view(-1, h, w)
    b, h, w = densepose_map.size()
    part_map = (densepose_map / 2 + 0.5) * 24
    assert (part_map >= 0).all() and (part_map < 25).all()

    mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0)
    for i in range(n_parts):
        for j in part_groups[i]:
            # Account for numerical errors.
            mask[:, i] = mask[:, i] | (
                (part_map > j - 0.1) & (part_map < j + 0.1)).byte()
    if need_reshape:
        mask = mask.view(bo, t, -1, h, w)
    return mask.float()


def get_face_mask(densepose_map):
    r"""Obtain mask of faces.

    Args:
        densepose_map (3D or 4D tensor): DensePose map.
    Returns:
        mask (3D or 4D tensor): Face mask.
    """
    need_reshape = densepose_map.dim() == 4
    if need_reshape:
        bo, t, h, w = densepose_map.size()
        densepose_map = densepose_map.view(-1, h, w)
    b, h, w = densepose_map.size()
    part_map = (densepose_map / 2 + 0.5) * 24
    assert (part_map >= 0).all() and (part_map < 25).all()

    if densepose_map.is_cuda:
        mask = torch.cuda.ByteTensor(b, h, w).fill_(0)
    else:
        mask = torch.ByteTensor(b, h, w).fill_(0)
    for j in [23, 24]:
        mask = mask | ((part_map > j - 0.1) & (part_map < j + 0.1)).byte()
    if need_reshape:
        mask = mask.view(bo, t, h, w)
    return mask.float()
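

# Illustrative sketch, not part of the original module: DensePose labels are
# stored in [-1, 1] and mapped back to part IDs in [0, 24] via
# (x / 2 + 0.5) * 24; IDs 23 and 24 are the head, which is what
# `get_face_mask` extracts. The tiny map below is synthetic; the CPU branch
# of `get_face_mask` is used (`get_part_mask` would require CUDA).
def _example_face_mask_from_part_ids():
    r"""Illustrative sketch: only DensePose part IDs 23 and 24 (the head)
    survive in the face mask."""
    dp = torch.full((1, 4, 4), -1.0)     # Background everywhere (ID 0).
    dp[0, 0, 0] = 23.0 / 24.0 * 2 - 1    # Head pixel (ID 23).
    dp[0, 0, 1] = 1.0                    # Head pixel (ID 24).
    mask = get_face_mask(dp)
    assert mask.sum().item() == 2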


def extract_valid_pose_labels(pose_map, pose_type, remove_face_labels,
                              do_remove=True):
    r"""Remove some labels (e.g. face regions) in the pose map if necessary.

    Args:
        pose_map (3D, 4D or 5D tensor): Input pose map.
        pose_type (str): 'both' or 'open'.
        remove_face_labels (bool): Whether to remove labels for the face
            region.
        do_remove (bool): Do remove face labels.
    Returns:
        pose_map (3D, 4D or 5D tensor): Output pose map.
    """
    if pose_map is None:
        return pose_map
    if type(pose_map) == list:
        return [extract_valid_pose_labels(p, pose_type, remove_face_labels,
                                          do_remove) for p in pose_map]
    orig_dim = pose_map.dim()
    assert (orig_dim >= 3 and orig_dim <= 5)
    if orig_dim == 3:
        pose_map = pose_map.unsqueeze(0).unsqueeze(0)
    elif orig_dim == 4:
        pose_map = pose_map.unsqueeze(0)

    if pose_type == 'open':
        # If the input is only openpose, remove the densepose part.
        pose_map = pose_map[:, :, 3:]
    elif remove_face_labels and do_remove:
        # Remove the face part from the densepose input.
        densepose, openpose = pose_map[:, :, :3], pose_map[:, :, 3:]
        face_mask = get_face_mask(pose_map[:, :, 2]).unsqueeze(2)
        pose_map = torch.cat([densepose * (1 - face_mask) - face_mask,
                              openpose], dim=2)

    if orig_dim == 3:
        pose_map = pose_map[0, 0]
    elif orig_dim == 4:
        pose_map = pose_map[0]
    return pose_map


def normalize_faces(keypoints, ref_keypoints,
                    dist_scale_x=None, dist_scale_y=None):
    r"""Normalize face keypoints w.r.t. the reference face keypoints.

    Args:
        keypoints (Kx2 numpy array): Target facial keypoints.
        ref_keypoints (Kx2 numpy array): Reference facial keypoints.
        dist_scale_x (list of float): Per-part horizontal scales computed on
            a previous call; if None, they are computed from the inputs.
        dist_scale_y (list of float): Per-part vertical scales computed on a
            previous call; if None, they are computed from the inputs.
    Returns:
        keypoints (Kx2 numpy array): Normalized facial keypoints.
        scales (list): The per-part [dist_scale_x, dist_scale_y] used, so
            subsequent frames can reuse the same scales.
    """
    if keypoints.shape[0] == 68:
        central_keypoints = [8]
        add_upper_face = False
        part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11],
                     [6, 10], [7, 9, 8],
                     [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
                     [27], [28], [29], [30], [31, 35], [32, 34], [33],
                     [36, 45], [37, 44], [38, 43], [39, 42], [40, 47],
                     [41, 46],
                     [48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58],
                     [57], [60, 64], [61, 63], [62], [65, 67], [66]
                     ]
        if add_upper_face:
            part_list += [[68, 82], [69, 81], [70, 80], [71, 79], [72, 78],
                          [73, 77], [74, 76, 75]]
    elif keypoints.shape[0] == 126:
        central_keypoints = [16]
        part_list = [[i] for i in range(126)]
    else:
        raise ValueError('Input keypoints type not supported.')

    face_cen = np.mean(keypoints[central_keypoints, :], axis=0)
    ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0)

    def get_mean_dists(pts, face_cen):
        r"""Get the mean xy distances of keypoints w.r.t. the face center."""
        mean_dists_x, mean_dists_y = [], []
        pts_cen = np.mean(pts, axis=0)
        for p, pt in enumerate(pts):
            mean_dists_x.append(np.linalg.norm(pt - pts_cen))
            mean_dists_y.append(np.linalg.norm(pts_cen - face_cen))
        mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3
        mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3
        return mean_dist_x, mean_dist_y

    if dist_scale_x is None:
        dist_scale_x, dist_scale_y = [None] * len(part_list), \
            [None] * len(part_list)

    for i, pts_idx in enumerate(part_list):
        pts = keypoints[pts_idx]
        if dist_scale_x[i] is None:
            ref_pts = ref_keypoints[pts_idx]
            mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen)
            ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen)
            dist_scale_x[i] = ref_dist_x / mean_dist_x
            dist_scale_y[i] = ref_dist_y / mean_dist_y

        pts_cen = np.mean(pts, axis=0)
        pts = (pts - pts_cen) * dist_scale_x[i] + \
              (pts_cen - face_cen) * dist_scale_y[i] + face_cen
        keypoints[pts_idx] = pts

    return keypoints, [dist_scale_x, dist_scale_y]
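

# Illustrative sketch, not part of the original module: when the target and
# reference keypoints coincide, every per-part scale computed by
# `normalize_faces` is 1 and the keypoints come back unchanged. Note that the
# function modifies its first argument in place, hence the copies below.
def _example_normalize_faces_identity():
    r"""Illustrative sketch: identical target and reference keypoints are
    returned unchanged (all per-part scales are 1)."""
    keypoints = np.random.RandomState(0).rand(68, 2) * 100
    ref_keypoints = keypoints.copy()
    normalized, scales = normalize_faces(keypoints.copy(), ref_keypoints)
    assert np.allclose(normalized, keypoints)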
""" if type(image) == list: return [crop_face_from_output(data_cfg, im, input_label, crop_smaller) for im in image] output = None face_size = image.shape[-2] // 32 * 8 for i in range(input_label.size(0)): ys, ye, xs, xe = get_face_bbox_for_output(data_cfg, input_label[i:i + 1], crop_smaller=crop_smaller) output_i = F.interpolate(image[i:i + 1, -3:, ys:ye, xs:xe], size=(face_size, face_size), mode='bilinear', align_corners=True) # output_i = image[i:i + 1, -3:, ys:ye, xs:xe] output = torch.cat([output, output_i]) if i != 0 else output_i return output def get_face_bbox_for_output(data_cfg, pose, crop_smaller=0): r"""Get pixel coordinates of the face bounding box. Args: data_cfg (obj): Data configuration. pose (NxCxHxW tensor): Pose label map. crop_smaller (int): Number of pixels to crop slightly smaller region. Returns: output (list of int): Face bbox. """ if pose.dim() == 3: pose = pose.unsqueeze(0) elif pose.dim() == 5: pose = pose[-1, -1:] _, _, h, w = pose.size() use_openpose = 'pose_maps-densepose' not in data_cfg.input_labels if use_openpose: # Use openpose face keypoints to identify face region. for input_type in data_cfg.input_types: if 'poses-openpose' in input_type: num_ch = input_type['poses-openpose'].num_channels if num_ch > 3: face = (pose[:, -1] > 0).nonzero(as_tuple=False) else: raise ValueError('Not implemented yet.') else: # Use densepose labels. face = (pose[:, 2] > 0.9).nonzero(as_tuple=False) ylen = xlen = h // 32 * 8 if face.size(0): y, x = face[:, 1], face[:, 2] ys, ye = y.min().item(), y.max().item() xs, xe = x.min().item(), x.max().item() if use_openpose: xc, yc = (xs + xe) // 2, (ys * 3 + ye * 2) // 5 ylen = int((xe - xs) * 2.5) else: xc, yc = (xs + xe) // 2, (ys + ye) // 2 ylen = int((ye - ys) * 1.25) ylen = xlen = min(w, max(32, ylen)) yc = max(ylen // 2, min(h - 1 - ylen // 2, yc)) xc = max(xlen // 2, min(w - 1 - xlen // 2, xc)) else: yc = h // 4 xc = w // 2 ys, ye = yc - ylen // 2, yc + ylen // 2 xs, xe = xc - xlen // 2, xc + xlen // 2 if crop_smaller != 0: # Crop slightly smaller region inside face. ys += crop_smaller xs += crop_smaller ye -= crop_smaller xe -= crop_smaller return [ys, ye, xs, xe] def crop_hand_from_output(data_cfg, image, input_label): r"""Crop out the hand region of the image. Args: data_cfg (obj): Data configuration. image (NxC1xHxW tensor or list of tensors): Image to crop. input_label (NxC2xHxW tensor): Input label map. Returns: output (NxC1xHxW tensor or list of tensors): Cropped image. """ if type(image) == list: return [crop_hand_from_output(data_cfg, im, input_label) for im in image] output = None for i in range(input_label.size(0)): coords = get_hand_bbox_for_output(data_cfg, input_label[i:i + 1]) if coords: for coord in coords: ys, ye, xs, xe = coord output_i = image[i:i + 1, -3:, ys:ye, xs:xe] output = torch.cat([output, output_i]) \ if output is not None else output_i return output def get_hand_bbox_for_output(data_cfg, pose): r"""Get coordinates of the hand bounding box. Args: data_cfg (obj): Data configuration. pose (NxCxHxW tensor): Pose label map. Returns: output (list of int): Hand bbox. """ if pose.dim() == 3: pose = pose.unsqueeze(0) elif pose.dim() == 5: pose = pose[-1, -1:] _, _, h, w = pose.size() ylen = xlen = h // 64 * 8 coords = [] colors = [[0.95, 0.5, 0.95], [0.95, 0.95, 0.5]] for i, color in enumerate(colors): if pose.shape[1] > 6: # Using one-hot encoding for openpose. 


def get_hand_bbox_for_output(data_cfg, pose):
    r"""Get the coordinates of the hand bounding box.

    Args:
        data_cfg (obj): Data configuration.
        pose (NxCxHxW tensor): Pose label map.
    Returns:
        output (list of int): Hand bbox.
    """
    if pose.dim() == 3:
        pose = pose.unsqueeze(0)
    elif pose.dim() == 5:
        pose = pose[-1, -1:]
    _, _, h, w = pose.size()
    ylen = xlen = h // 64 * 8

    coords = []
    colors = [[0.95, 0.5, 0.95], [0.95, 0.95, 0.5]]
    for i, color in enumerate(colors):
        if pose.shape[1] > 6:
            # Using one-hot encoding for openpose.
            idx = -3 if i == 0 else -2
            hand = (pose[:, idx] == 1).nonzero(as_tuple=False)
        else:
            raise ValueError('Not implemented yet.')
        if hand.size(0):
            y, x = hand[:, 1], hand[:, 2]
            ys, ye, xs, xe = y.min().item(), y.max().item(), \
                x.min().item(), x.max().item()
            xc, yc = (xs + xe) // 2, (ys + ye) // 2
            yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
            xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
            ys, ye, xs, xe = yc - ylen // 2, yc + ylen // 2, \
                xc - xlen // 2, xc + xlen // 2
            coords.append([ys, ye, xs, xe])
    return coords


def pre_process_densepose(pose_cfg, pose_map, is_infer=False):
    r"""Pre-process the DensePose part of the input label map.

    Args:
        pose_cfg (obj): Pose data configuration.
        pose_map (NxCxHxW tensor): Pose label map.
        is_infer (bool): Is doing inference.
    Returns:
        pose_map (NxCxHxW tensor): Processed pose label map.
    """
    part_map = pose_map[:, :, 2] * 255  # should be within [0, 24].
    assert (part_map >= 0).all() and (part_map < 25).all()

    # Randomly drop some body part during training.
    if not is_infer:
        random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0)
    else:
        random_drop_prob = 0
    if random_drop_prob > 0:
        densepose_map = pose_map[:, :, :3]
        for part_id in range(1, 25):
            if (random.random() < random_drop_prob):
                part_mask = abs(part_map - part_id) < 0.1
                densepose_map[part_mask.unsqueeze(2).expand_as(
                    densepose_map)] = 0
        pose_map[:, :, :3] = densepose_map

    # Renormalize the DensePose part index channel so the part IDs span the
    # full [0, 1] range before the final normalization below.
    pose_map[:, :, 2] = pose_map[:, :, 2] * (255 / 24)
    # Normalize from [0, 1] to [-1, 1].
    pose_map = pose_map * 2 - 1
    return pose_map


def random_roll(tensors):
    r"""Randomly roll the input tensors along the x and y dimensions. Also
    randomly flip the tensors.

    Args:
        tensors (list of 4D tensors): Input tensors.
    Returns:
        output (list of 4D tensors): Rolled tensors.
    """
    h, w = tensors[0].shape[2:]
    ny = np.random.choice([np.random.randint(h // 16),
                           h - np.random.randint(h // 16)])
    nx = np.random.choice([np.random.randint(w // 16),
                           w - np.random.randint(w // 16)])
    flip = np.random.rand() > 0.5
    return [roll(t, ny, nx, flip) for t in tensors]


def roll(t, ny, nx, flip=False):
    r"""Roll and flip the tensor by the specified amounts.

    Args:
        t (4D tensor): Input tensor.
        ny (int): Amount to roll along the y dimension.
        nx (int): Amount to roll along the x dimension.
        flip (bool): Whether to flip the input.
    Returns:
        t (4D tensor): Output tensor.
    """
    t = torch.cat([t[:, :, -ny:], t[:, :, :-ny]], dim=2)
    t = torch.cat([t[:, :, :, -nx:], t[:, :, :, :-nx]], dim=3)
    if flip:
        t = torch.flip(t, dims=[3])
    return t


def detach(output):
    r"""Detach tensors in the dict.

    Args:
        output (dict): Output dict.
    Returns:
        output (dict): Detached output dict.
    """
    if type(output) == dict:
        new_dict = dict()
        for k, v in output.items():
            new_dict[k] = detach(v)
        return new_dict
    elif type(output) == torch.Tensor:
        return output.detach()
    return output
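

# Illustrative sketch, not part of the original module: `roll` is equivalent
# to `torch.roll` along the two spatial dimensions, followed by an optional
# horizontal flip. The shift amounts below are arbitrary.
def _example_roll_matches_torch_roll():
    r"""Illustrative sketch: ``roll`` matches ``torch.roll`` on the spatial
    dimensions."""
    t = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).view(2, 3, 8, 8)
    rolled = roll(t, ny=3, nx=5)
    assert torch.equal(rolled, torch.roll(t, shifts=(3, 5), dims=(2, 3)))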