import torch
import numpy as np
import math
import torch.nn.functional as F
import utils.basic
from typing import Tuple, Union


def standardize_test_data(rgbs, trajs, visibs, valids, S_cap=600, only_first=False, seq_len=None):
    """Filter, fill and trim a test sequence of point tracks.

    Args:
        rgbs: frames, indexed along the first axis.
        trajs: (S, N, 2) point coordinates.
        visibs: (S, N) visibility flags.
        valids: (S, N) validity flags.
        S_cap: upper bound on the window length used when ``only_first`` is set.
        only_first: if True, pick the start frame that keeps the most tracks
            that are visible+valid on that frame and on at least two frames of
            the window, slice that window out, and keep only tracks visible on
            its first frame.
        seq_len: optional maximum number of frames to keep.

    Returns:
        (rgbs, trajs, visibs, valids); the last three are float32. Only tracks
        that are visible and valid on at least two remaining frames are kept.
    """
    trajs = trajs.astype(np.float32)  # S,N,2
    visibs = visibs.astype(np.float32)  # S,N
    valids = valids.astype(np.float32)  # S,N

    # keep only tracks that are visible and valid on at least two frames
    visval_ok = np.sum(valids * visibs, axis=0) > 1
    trajs = trajs[:, visval_ok]
    visibs = visibs[:, visval_ok]
    valids = valids[:, visval_ok]

    # fill in missing data: copy each invalid coordinate from the nearest valid
    # frame, using this module's own helper (only utils.basic is imported here)
    N = trajs.shape[1]
    for ni in range(N):
        trajs[:, ni] = data_replace_with_nearest(trajs[:, ni], valids[:, ni])

    # window length: at most seq_len (if given) and at most S_cap
    if seq_len is not None:
        S = min(len(rgbs), seq_len)
    else:
        S = len(rgbs)
    S = min(S, S_cap)

    if only_first:
        # find the start frame that keeps the most usable tracks;
        # candidate start frames are limited to range(len(rgbs) - 64)
        best_count = 0
        best_ind = 0
        for si in range(0, len(rgbs) - 64):
            end = min(si + S, len(rgbs) + 1)
            visibs_ = visibs[si:end]  # S,N
            valids_ = valids[si:end]  # S,N
            visval_ok0 = (visibs_[0] * valids_[0]) > 0  # N: usable on the start frame
            visval_okA = np.sum(visibs_ * valids_, axis=0) > 1  # N: usable on >=2 frames
            count = np.sum(visval_ok0 & visval_okA)
            if count > best_count:
                best_count = count
                best_ind = si

        si = best_ind
        rgbs = rgbs[si:si + S]
        trajs = trajs[si:si + S]
        visibs = visibs[si:si + S]
        valids = valids[si:si + S]

        # keep only tracks visible on the first frame of the chosen window
        vis_ok0 = visibs[0] > 0  # N
        trajs = trajs[:, vis_ok0]
        visibs = visibs[:, vis_ok0]
        valids = valids[:, vis_ok0]

    if seq_len is not None:
        rgbs = rgbs[:seq_len]
        trajs = trajs[:seq_len]
        # visibs must be trimmed as well, otherwise visibs*valids below
        # fails to broadcast whenever the sequence is longer than seq_len
        visibs = visibs[:seq_len]
        valids = valids[:seq_len]

    # require two visible+valid timesteps (after the seq_len trim)
    visval_ok = np.sum(visibs * valids, axis=0) > 1
    trajs = trajs[:, visval_ok]
    valids = valids[:, visval_ok]
    visibs = visibs[:, visval_ok]

    return rgbs, trajs, visibs, valids


def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
    """
    This function initializes a grid and generates a 2D positional embedding
    using sine and cosine functions. It is a wrapper of
    get_2d_sincos_pos_embed_from_grid.

    Args:
    - embed_dim: The embedding dimension (must be divisible by 4).
    - grid_size: The grid size, either an int (square grid) or an (H, W) tuple.

    Returns:
    - pos_embed: The generated 2D positional embedding, shape (1, embed_dim, H, W).
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    # with indexing="xy" both outputs are (H, W); grid[0] holds the column (w)
    # index and grid[1] the row (h) index
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)


def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using
    sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension (must be divisible by 4).
    - grid: The grid to generate the embedding from; grid[0] and grid[1] are
      the two coordinate maps.

    Returns:
    - emb: The generated 2D positional embedding, shape (1, M, embed_dim) where
      M is the number of grid positions. The first half of the channels encodes
      grid[0], the second half grid[1].
    """
    assert embed_dim % 2 == 0

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (1, M, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (1, M, D/2)

    emb = torch.cat([emb_h, emb_w], dim=2)  # (1, M, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using
    sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The positions to generate the embedding from (flattened to M values).

    Returns:
    - emb: The generated 1D positional embedding, float32, shape (1, M, embed_dim);
      the first half of the channels are sines, the second half cosines.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()


def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates
    using sine and cosine functions.

    Args:
    - xy: (B, N, 2) coordinates to generate the embedding from.
    - C: The per-coordinate embedding size (must be even).
    - cat_coords: A flag to indicate whether to prepend the original
      coordinates to the embedding.

    Returns:
    - pe: The generated 2D positional embedding, shape (B, N, 2*C), or
      (B, N, 2*C + 2) when cat_coords is True.
    """
    B, N, D = xy.shape
    assert D == 2

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    # interleave sin (even channels) and cos (odd channels)
    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
    return pe


# from pips2
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    """Sin/cos embedding of 2D coordinates.

    Args:
        xy: (B, S, 2) tensor of (x, y) coordinates.
        C: embedding size; must be a multiple of 4.
        temperature: base of the geometric frequency progression.
        dtype: NOTE(review): ignored -- the output takes xy's dtype (kept as-is
            for backward compatibility).
        cat_coords: if True, append the raw xy to the embedding.

    Returns:
        (B, S, C) tensor laid out as [sin(x), cos(x), sin(y), cos(y)], or
        (B, S, C+2) when cat_coords is True.
    """
    device = xy.device
    dtype = xy.dtype  # overrides the dtype argument
    B, S, D = xy.shape
    assert(D == 2)
    x = xy[:, :, 0]
    y = xy[:, :, 1]
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # with C == 4 there is a single frequency; guard the denominator so omega
    # does not become 0/0 = NaN
    omega = torch.arange(C // 4, device=device) / max(C // 4 - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B, S, C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,S,C+2
    return pe
def iou(bbox1, bbox2):
    """Intersection-over-union of two axis-aligned boxes given as [x1, y1, x2, y2].

    Returns 0.0 when the union area is not positive (e.g. both boxes are
    degenerate) instead of raising ZeroDivisionError.
    """
    x1, y1, x2, y2 = bbox1
    x1_, y1_, x2_, y2_ = bbox2

    inter_w = max(0, min(x2, x2_) - max(x1, x1_))
    inter_h = max(0, min(y2, y2_) - max(y1, y1_))
    inter_area = inter_w * inter_h

    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2_ - x1_) * (y2_ - y1_)
    union = area1 + area2 - inter_area
    if union <= 0:
        return 0.0
    return inter_area / union
class SimplePool:
    """Fixed-capacity FIFO pool of items (torch tensors for 'pt', numpy values for 'np').

    Once ``pool_size`` items are held, each new item evicts the oldest one.
    """

    def __init__(self, pool_size, version='pt', min_size=1):
        self.pool_size = pool_size
        self.version = version
        self.items = []
        self.min_size = min_size

        if not (version == 'pt' or version == 'np'):
            print('version = %s; please choose pt or np' % version)
            assert(False)  # please choose pt or np

    def __len__(self):
        return len(self.items)

    def mean(self, min_size=None):
        """Mean of all pooled values, or NaN if the pool holds too few items.

        min_size: None uses self.min_size; 'half' requires pool_size/2 items;
            anything else is used directly as the threshold.
        """
        if min_size is None:
            pool_size_thresh = self.min_size
        elif min_size == 'half':
            pool_size_thresh = self.pool_size / 2
        else:
            pool_size_thresh = min_size

        if self.version == 'np':
            if len(self.items) >= pool_size_thresh:
                return np.sum(self.items) / float(len(self.items))
            else:
                return np.nan
        if self.version == 'pt':
            if len(self.items) >= pool_size_thresh:
                # torch.sum does not accept a Python list; stack first (as fetch does)
                return torch.sum(torch.stack(self.items)) / float(len(self.items))
            else:
                # torch.from_numpy needs an ndarray, not a float
                return torch.tensor(np.nan)

    def sample(self, with_replacement=True):
        """Return one random item; remove it from the pool if not with_replacement."""
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        else:
            return self.items.pop(idx)

    def fetch(self, num=None):
        """Return all items stacked, or `num` items drawn at random with replacement."""
        if self.version == 'pt':
            item_array = torch.stack(self.items)
        elif self.version == 'np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)

            # if there are not that many elements just return however many there are
            # (only reachable when asserts are disabled, e.g. under python -O)
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        full = len(self.items) == self.pool_size
        return full

    def empty(self):
        self.items = []

    def have_min_size(self):
        return len(self.items) >= self.min_size

    def update(self, items):
        """Append items, evicting the oldest once the pool is full; returns the item list."""
        for item in items:
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
            else:
                # the pool is full: pop from the front, add to the back
                self.items.pop(0)
                self.items.append(item)
        return self.items


class SimpleHeap:
    """Keeps up to ``pool_size`` items paired with scalar values.

    When full, a new item replaces the current lowest-valued item only if the
    new value is larger, so the pool tends toward the highest-valued items.
    ``self.items`` and ``self.vals`` are parallel lists.
    """

    def __init__(self, pool_size, version='pt'):
        self.pool_size = pool_size
        self.version = version
        self.items = []
        self.vals = []

        if not (version == 'pt' or version == 'np'):
            print('version = %s; please choose pt or np' % version)
            assert(False)  # please choose pt or np

    def __len__(self):
        return len(self.items)

    def sample(self, random=True, with_replacement=True, semirandom=False):
        """Return one item.

        random: pick uniformly at random.
        semirandom (when not random): pick uniformly from the higher-valued half.
        otherwise: pick the highest-valued item.
        If not with_replacement, the item and its value are removed.
        """
        vals_arr = np.stack(self.vals)
        if random:
            ind = np.random.randint(len(self.items))
        else:
            if semirandom and len(vals_arr) > 1:
                # choose from the harder (higher-valued) half
                inds = np.argsort(vals_arr)  # ascending
                inds = inds[len(vals_arr) // 2:]
                ind = np.random.choice(inds)
            else:
                # find the most valuable element
                ind = np.argmax(vals_arr)
        if with_replacement:
            return self.items[ind]
        else:
            item = self.items.pop(ind)
            self.vals.pop(ind)
            return item

    def fetch(self, num=None):
        """Return all items stacked, or `num` items drawn at random with replacement."""
        if self.version == 'pt':
            item_array = torch.stack(self.items)
        elif self.version == 'np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)

            # if there are not that many elements just return however many there are
            # (only reachable when asserts are disabled, e.g. under python -O)
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        full = len(self.items) == self.pool_size
        return full

    def empty(self):
        # reset both parallel lists; clearing only items would leave stale
        # values whose indices no longer match self.items
        self.items = []
        self.vals = []

    def update(self, vals, items):
        """Insert (val, item) pairs, replacing the minimum when full and val is larger."""
        for val, item in zip(vals, items):
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
                self.vals.append(val)
            else:
                # the pool is full: find our least-valuable element
                # and see if we should replace it
                vals_arr = np.stack(self.vals)
                ind = np.argmin(vals_arr)
                if vals_arr[ind] < val:
                    # pop the min, add to the back
                    self.items.pop(ind)
                    self.vals.pop(ind)
                    self.items.append(item)
                    self.vals.append(val)
        return self.items


def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples
        include_ends: if True, the first two samples are indices 0 and N-1
        deterministic: if True, start from index 0 instead of a random index
    Return:
        inds: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    xyz = xyz.float()
    inds = torch.zeros(B, npoint, dtype=torch.long, device=device)
    distance = torch.ones((B, N), dtype=torch.float32, device=device) * 1e10
    if deterministic:
        farthest = torch.randint(0, 1, (B,), dtype=torch.long, device=device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        if include_ends:
            if i == 0:
                farthest = 0
            elif i == 1:
                farthest = N - 1
        inds[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # distance[b, n] = squared distance from point n to its nearest sample so far
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]

        if npoint > N:
            # if we need more samples, make them random
            distance += torch.randn_like(distance)
    return inds


def farthest_point_sample_py(xyz, npoint, deterministic=False):
    """NumPy farthest-point sampling of an (N, C) array; returns int32 indices [npoint]."""
    N, C = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    distance = np.ones(N) * 1e10
    if deterministic:
        farthest = 0
    else:
        farthest = np.random.randint(0, N, dtype=np.int32)
    for i in range(npoint):
        inds[i] = farthest
        centroid = xyz[farthest, :].reshape(1, C)
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
        if npoint > N:
            # if we need more samples, make them random
            distance += np.random.randn(*distance.shape)
    return inds


def balanced_ce_loss(pred, gt, pos_weight=0.5, valid=None, dim=None, return_both=False, use_halfmask=False, H=64, W=64):
    """Logistic loss averaged separately over positive and negative elements.

    Args:
        pred: logits (any shape; flattened).
        gt: targets with the same number of elements; > 0.95 is positive,
            < 0.05 is negative, anything in between is ignored.
        pos_weight: weight of the positive mean; negatives get 1 - pos_weight.
        valid: optional mask with the same number of elements; zeros are ignored.
        return_both: if True, return (pos_loss, neg_loss) instead of the
            weighted combination.
        dim, use_halfmask, H, W: accepted for backward compatibility; they do
            not affect the result (the halfmask "boost" path was unreachable
            and has been removed).

    Returns:
        Scalar balanced loss, or (pos_loss, neg_loss) when return_both is True.
        A side with no elements contributes a zero tensor.
    """
    pred = pred.reshape(-1)
    gt = gt.reshape(-1)
    device = pred.device

    if valid is not None:
        valid = valid.reshape(-1)
        for (a, b) in zip(pred.size(), valid.size()):
            assert(a == b)  # some shape mismatch!
    else:
        valid = torch.ones_like(gt)

    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()

    # per-element logistic loss softplus(-label * pred) for label in {+1, -1},
    # written in a numerically stable form
    label = pos * 2.0 - 1.0
    a = -label * pred
    b = F.relu(a)
    loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))

    if torch.sum(pos * valid) > 0:
        pos_loss = loss[(pos * valid) > 0].mean()
    else:
        pos_loss = torch.tensor(0.0, requires_grad=True, device=device)

    if torch.sum(neg * valid) > 0:
        neg_loss = loss[(neg * valid) > 0].mean()
    else:
        neg_loss = torch.tensor(0.0, requires_grad=True, device=device)

    if return_both:
        return pos_loss, neg_loss
    balanced_loss = pos_weight * pos_loss + (1 - pos_weight) * neg_loss
    return balanced_loss


def dice_loss(pred, gt):
    """Per-batch-element soft Dice loss on logits; gt values in (0.05, 0.95) are ignored."""
    # gt has ignores at 0.5
    # pred and gt are the same shape
    for (a, b) in zip(pred.size(), gt.size()):
        assert(a == b)  # some shape mismatch!

    prob = pred.sigmoid()

    # flatten everything except batch
    prob = prob.flatten(1)
    gt = gt.flatten(1)

    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()
    valid = (pos + neg).float().clamp(0, 1)

    numerator = 2 * (prob * pos * valid).sum(1)
    denominator = (prob * valid).sum(1) + (pos * valid).sum(1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


def sigmoid_focal_loss(pred, gt, alpha=0.25, gamma=2):
    """Per-batch-element focal loss on logits; gt values in (0.05, 0.95) are ignored.

    The loss is summed over valid elements and divided by (1 + number of valid elements).
    """
    # gt has ignores at 0.5
    # pred and gt are the same shape
    for (a, b) in zip(pred.size(), gt.size()):
        assert(a == b)  # some shape mismatch!

    # flatten everything except batch
    pred = pred.flatten(1)
    gt = gt.flatten(1)

    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()
    valid = (pos + neg).float().clamp(0, 1)

    prob = pred.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(pred, pos, reduction="none")
    p_t = prob * pos + (1 - prob) * (1 - pos)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * pos + (1 - alpha) * (1 - pos)
        loss = alpha_t * loss

    loss = (loss * valid).sum(1) / (1 + valid.sum(1))
    return loss


def data_replace_with_nearest(xys, valids):
    """Replace entries whose flag is 0 with the nearest (by index) entry whose flag is 1.

    Ties go to the earlier index. Modifies ``xys`` in place and returns it; if
    no entry is valid, ``xys`` is returned unchanged.
    """
    invalid_idx = np.where(valids == 0)[0]
    valid_idx = np.where(valids == 1)[0]
    if len(valid_idx) == 0:
        # nothing to copy from (np.argmin of an empty array would raise)
        return xys
    for idx in invalid_idx:
        nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
        xys[idx] = xys[nearest]
    return xys


def data_get_traj_from_masks(masks):
    """Summarize a sequence of masks (S,H,W or S,H,W,C) as per-frame points.

    Masks are thresholded at 0.1. Returns (xy_means, xy_rands, valids, fills):
    the (S,2) mask centroid, an (S,2) randomly chosen mask pixel, an (S,)
    flag for non-empty frames, and the (S,) fraction of the mask's bounding
    box that is filled. Empty frames get zeros.
    """
    if masks.ndim == 4:
        masks = masks[..., 0]
    S, H, W = masks.shape
    masks = (masks > 0.1).astype(np.float32)
    fills = np.zeros((S))
    xy_means = np.zeros((S, 2))
    xy_rands = np.zeros((S, 2))
    valids = np.zeros((S))
    for si, mask in enumerate(masks):
        if np.sum(mask) > 0:
            ys, xs = np.where(mask)
            inds = np.random.permutation(len(xs))
            xs, ys = xs[inds], ys[inds]
            x0, x1 = np.min(xs), np.max(xs) + 1
            y0, y1 = np.min(ys), np.max(ys) + 1
            xy_means[si] = np.array([xs.mean(), ys.mean()])
            xy_rands[si] = np.array([xs[0], ys[0]])
            valids[si] = 1
            crop = mask[y0:y1, x0:x1]
            fill = np.mean(crop)
            fills[si] = fill
    return xy_means, xy_rands, valids, fills


def _smooth_inplace(arr, iters):
    """Repeatedly average each step along axis 0 with its neighbours, in place.

    Endpoints average with their single neighbour. Assignment keeps the dtype
    of ``arr``, so integer arrays are truncated at every step.
    """
    n = len(arr)
    for _ in range(iters):
        for ii in range(n):
            if ii == 0:
                arr[ii] = (arr[ii] + arr[ii + 1]) / 2.0
            elif ii == n - 1:
                arr[ii] = (arr[ii - 1] + arr[ii]) / 2.0
            else:
                arr[ii] = (arr[ii - 1] + arr[ii] + arr[ii + 1]) / 3.0


def data_zoom(zoom, xys, visibs, rgbs, valids=None, masks=None, masks2=None, masks3=None, masks4=None):
    """Randomly crop a (W//zoom, H//zoom) window per frame, simulating camera motion.

    Args:
        zoom: zoom factor; the crop is int(W // zoom) x int(H // zoom).
        xys: (S, N, 2) point coordinates; shifted into crop coordinates IN PLACE.
        visibs: (S, N) visibility; points leaving the crop are set to 0 IN PLACE.
        rgbs: (S, H, W, 3) frames.
        valids: optional (S, N) validity, used only to choose the static anchor.
        masks, masks2, masks3, masks4: accepted for backward compatibility but
            not used. NOTE(review): earlier code cropped these and then discarded
            the result; return them if callers need cropped masks.

    Returns:
        (xys, visibs, rgbs, valids) if valids is given, else (xys, visibs, rgbs).
    """
    _, H, W, C = rgbs.shape
    S, N, D = xys.shape
    assert(C == 3)

    crop_W = int(W // zoom)
    crop_H = int(H // zoom)
    # bounds that keep a crop centred at (x, y) fully inside the frame
    lo = [crop_W // 2, crop_H // 2]
    hi = [W - crop_W // 2, H - crop_H // 2]

    if np.random.rand() < 0.25:  # follow-crop
        # start from one random xy trajectory, made inbounds
        smooth_xys = xys[:, np.random.randint(N)].reshape(S, 1, 2)
        smooth_xys = np.clip(smooth_xys, lo, hi)
        # smooth it out, to remove info about the traj, and simulate camera motion
        _smooth_inplace(smooth_xys, S * 3)
    else:  # static (no-hint) crop
        # zero-vel on the mean point of a frame with above-average visibility
        if valids is not None:
            visval = np.sum(visibs * valids, axis=1)  # S
        else:
            visval = np.sum(visibs, axis=1)  # S
        anchor_inds = np.nonzero(visval >= np.mean(visval))[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        smooth_xys = xys[ind:ind + 1].repeat(S, axis=0)
        smooth_xys = smooth_xys.mean(axis=1, keepdims=True)

    smooth_xys = np.clip(smooth_xys, lo, hi)

    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H // 8, crop_H // 8, (S, 1, 2))
        _smooth_inplace(alt_xys, 4)
        smooth_xys = smooth_xys + alt_xys
        smooth_xys = np.clip(smooth_xys, lo, hi)

    rgbs_crop = []
    for si in range(S):
        xy_mid = smooth_xys[si].squeeze(0).round().astype(np.int32)  # 2
        xmid, ymid = xy_mid[0], xy_mid[1]

        x0, x1 = np.clip(xmid - crop_W // 2, 0, W), np.clip(xmid + crop_W // 2, 0, W)
        y0, y1 = np.clip(ymid - crop_H // 2, 0, H), np.clip(ymid + crop_H // 2, 0, H)
        offset = np.array([x0, y0]).reshape(1, 2)

        rgbs_crop.append(rgbs[si, y0:y1, x0:x1])
        xys[si] -= offset

    rgbs = np.stack(rgbs_crop, axis=0)

    # update visibility annotations: points outside the crop become invisible
    oob = np.logical_or(
        np.logical_or(xys[..., 0] < 0, xys[..., 0] > crop_W - 1),
        np.logical_or(xys[..., 1] < 0, xys[..., 1] > crop_H - 1))
    visibs[oob] = 0

    if valids is not None:
        return xys, visibs, rgbs, valids
    else:
        return xys, visibs, rgbs


def data_zoom_bbox(zoom, bboxes, visibs, rgbs):
    """Randomly crop a (W//zoom, H//zoom) window per frame around a box track.

    Args:
        zoom: zoom factor; the crop is int(W // zoom) x int(H // zoom).
        bboxes: (S, 4) boxes [x0, y0, x1, y1]; shifted into crop coordinates IN PLACE.
        visibs: per-frame visibility with S entries; frames whose box centre
            falls within 1px of the crop edge (or outside it) are set to 0 IN PLACE.
        rgbs: (S, H, W, 3) frames.

    Returns:
        (bboxes, visibs, rgbs), with bboxes clamped to the crop.
    """
    S, H, W, C = rgbs.shape
    assert(C == 3)

    crop_W = int(W // zoom)
    crop_H = int(H // zoom)
    lo = [crop_W // 2, crop_H // 2]
    hi = [W - crop_W // 2, H - crop_H // 2]

    xys = bboxes[:, 0:2] * 0.5 + bboxes[:, 2:4] * 0.5  # S,2 box centres

    if np.random.rand() < 0.25:  # follow-crop
        # start with xy traj, made inbounds
        smooth_xys = np.clip(xys.copy(), lo, hi)
        # smooth it out, to remove info about the traj, and simulate camera motion
        for _ in range(S * 3):
            for ii in range(1, S - 1):
                smooth_xys[ii] = (smooth_xys[ii - 1] + smooth_xys[ii] + smooth_xys[ii + 1]) / 3.0
    else:  # static (no-hint) crop
        # zero-vel on a random visible frame's box centre
        anchor_inds = np.nonzero(visibs.reshape(-1) > 0.5)[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        smooth_xys = xys[ind:ind + 1].repeat(S, axis=0)

    smooth_xys = np.clip(smooth_xys, lo, hi)

    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H // 8, crop_H // 8, (S, 2))
        for _ in range(3):
            for ii in range(1, S - 1):
                alt_xys[ii] = (alt_xys[ii - 1] + alt_xys[ii] + alt_xys[ii + 1]) / 3.0
        smooth_xys = smooth_xys + alt_xys
        smooth_xys = np.clip(smooth_xys, lo, hi)

    rgbs_crop = []
    for si in range(S):
        xy_mid = smooth_xys[si].round().astype(np.int32)
        xmid, ymid = xy_mid[0], xy_mid[1]

        x0, x1 = np.clip(xmid - crop_W // 2, 0, W), np.clip(xmid + crop_W // 2, 0, W)
        y0, y1 = np.clip(ymid - crop_H // 2, 0, H), np.clip(ymid + crop_H // 2, 0, H)
        offset = np.array([x0, y0]).reshape(2)

        rgbs_crop.append(rgbs[si, y0:y1, x0:x1])
        xys[si] -= offset
        bboxes[si, 0:2] -= offset
        bboxes[si, 2:4] -= offset

    rgbs = np.stack(rgbs_crop, axis=0)

    # coordinates are now relative to the crop, so bounds must use the crop
    # size rather than the original frame size
    crop_h, crop_w = rgbs.shape[1], rgbs.shape[2]

    # update visibility annotations, avoiding the 1px edge
    oob = np.logical_or(
        np.logical_or(xys[:, 0] < 1, xys[:, 0] > crop_w - 2),
        np.logical_or(xys[:, 1] < 1, xys[:, 1] > crop_h - 2))
    visibs[oob] = 0

    # clamp to crop bounds
    max_xy = np.array([crop_w, crop_h]) - 1
    xys0 = np.minimum(np.maximum(bboxes[:, 0:2], np.zeros((2,), dtype=int)), max_xy)  # S,2
    xys1 = np.minimum(np.maximum(bboxes[:, 2:4], np.zeros((2,), dtype=int)), max_xy)  # S,2
    bboxes = np.concatenate([xys0, xys1], axis=1)
    return bboxes, visibs, rgbs


def mask2bbox(mask):
    """Bounding box [x0, y0, x1, y1] (x1/y1 exclusive) of pixels > 0.4; zeros if empty.

    A 3-D mask uses only its first channel. Defined here (rather than imported
    from the datasets package) to avoid a circular import.
    """
    if mask.ndim == 3:
        mask = mask[..., 0]
    ys, xs = np.where(mask > 0.4)
    if ys.size == 0 or xs.size == 0:
        return np.array((0, 0, 0, 0), dtype=int)
    lt = np.array([np.min(xs), np.min(ys)])
    rb = np.array([np.max(xs), np.max(ys)]) + 1
    return np.concatenate([lt, rb])


def data_pad_if_necessary(rgbs, masks, masks2=None):
    """Pad frames (and masks) when the tracked object is large relative to the frame.

    Over frames whose mask area is > 0.5 of the largest mask area, if the mean
    bbox width is >= W/2, W//4 columns of zeros are added on each side; if the
    mean bbox height is >= H/2, H//4 rows are added on top and bottom.

    Returns:
        (rgbs, masks, masks2) if masks2 is given, else (rgbs, masks).
    """
    S, H, W, C = rgbs.shape

    mask_areas = (masks > 0).reshape(S, -1).sum(axis=1)
    mask_areas_norm = mask_areas / np.max(mask_areas)
    visibs = mask_areas_norm

    bboxes = np.stack([mask2bbox(mask) for mask in masks])
    whs = bboxes[:, 2:4] - bboxes[:, 0:2]
    whs = whs[visibs > 0.5]

    if np.mean(whs[:, 0]) >= W / 2:
        pad = ((0, 0), (0, 0), (W // 4, W // 4), (0, 0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        masks = np.pad(masks, pad[:3], mode="constant")
        if masks2 is not None:
            # NOTE(review): width padding fills masks2 with 0, but height padding
            # below fills it with 0.5 -- confirm which is intended
            masks2 = np.pad(masks2, pad[:3], mode="constant")

    if np.mean(whs[:, 1]) >= H / 2:
        pad = ((0, 0), (H // 4, H // 4), (0, 0), (0, 0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        masks = np.pad(masks, pad[:3], mode="constant")
        if masks2 is not None:
            masks2 = np.pad(masks2, pad[:3], mode="constant", constant_values=0.5)

    if masks2 is not None:
        return rgbs, masks, masks2
    return rgbs, masks


def data_pad_if_necessary_b(rgbs, bboxes, visibs):
    """Box-based variant of data_pad_if_necessary.

    Uses frames with visibs > 0.5. Pads W//4 columns per side if the mean box
    width is >= W/2, and H//4 rows per side if the mean box height is >= H/2.
    ``bboxes`` is shifted IN PLACE to match. Returns (rgbs, bboxes).
    """
    S, H, W, C = rgbs.shape
    whs = bboxes[:, 2:4] - bboxes[:, 0:2]
    whs = whs[visibs > 0.5]
    if np.mean(whs[:, 0]) >= W / 2:
        pad = ((0, 0), (0, 0), (W // 4, W // 4), (0, 0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        bboxes[:, 0] += W // 4
        bboxes[:, 2] += W // 4
    if np.mean(whs[:, 1]) >= H / 2:
        pad = ((0, 0), (H // 4, H // 4), (0, 0), (0, 0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        bboxes[:, 1] += H // 4
        bboxes[:, 3] += H // 4
    return rgbs, bboxes


def posenc(x, min_deg, max_deg):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].

    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).

    Args:
        x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
        min_deg: int, the minimum (inclusive) degree of the encoding.
        max_deg: int, the maximum (exclusive) degree of the encoding.

    Returns:
        encoded: torch.Tensor with last dimension D + 2*D*(max_deg - min_deg),
        laid out as [x, sin terms, cos terms]; x itself if min_deg == max_deg.
    """
    if min_deg == max_deg:
        return x
    scales = torch.tensor(
        [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
    )

    xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
    return torch.cat([x, four_feat], dim=-1)