import math
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

import utils.basic

def standardize_test_data(rgbs, trajs, visibs, valids, S_cap=600, only_first=False, seq_len=None):
    trajs = trajs.astype(np.float32)  # S,N,2
    visibs = visibs.astype(np.float32)  # S,N
    valids = valids.astype(np.float32)  # S,N
    # keep points that are visible and valid on at least two timesteps
    visval_ok = np.sum(valids*visibs, axis=0) > 1
    trajs = trajs[:,visval_ok]
    visibs = visibs[:,visval_ok]
    valids = valids[:,visval_ok]
    # fill in missing data
    N = trajs.shape[1]
    for ni in range(N):
        trajs[:,ni] = data_replace_with_nearest(trajs[:,ni], valids[:,ni])  # local helper, defined below
    # use cap or seq_len
    if seq_len is not None:
        S = min(len(rgbs), seq_len)
    else:
        S = len(rgbs)
    S = min(S, S_cap)
    if only_first:
        # we'll find the best frame to start on
        best_count = 0
        best_ind = 0
        for si in range(0, len(rgbs)-64):
            # try this slice
            visibs_ = visibs[si:min(si+S, len(rgbs)+1)]  # S,N
            valids_ = valids[si:min(si+S, len(rgbs)+1)]  # S,N
            visval_ok0 = (visibs_[0]*valids_[0]) > 0  # N
            visval_okA = np.sum(visibs_*valids_, axis=0) > 1  # N
            all_ok = visval_ok0 & visval_okA
            # print('- slicing %d to %d; sum(ok) %d' % (si, min(si+S,len(rgbs)+1), np.sum(all_ok)))
            if np.sum(all_ok) > best_count:
                best_count = np.sum(all_ok)
                best_ind = si
        si = best_ind
        rgbs = rgbs[si:si+S]
        trajs = trajs[si:si+S]
        visibs = visibs[si:si+S]
        valids = valids[si:si+S]
        vis_ok0 = visibs[0] > 0  # N
        trajs = trajs[:,vis_ok0]
        visibs = visibs[:,vis_ok0]
        valids = valids[:,vis_ok0]
        # print('- best_count', best_count, 'best_ind', best_ind)
    if seq_len is not None:
        rgbs = rgbs[:seq_len]
        trajs = trajs[:seq_len]
        visibs = visibs[:seq_len]  # keep visibs in sync with the other trims
        valids = valids[:seq_len]
    # require two visible-and-valid timesteps (after the seq_len trim)
    visval_ok = np.sum(visibs*valids, axis=0) > 1
    trajs = trajs[:,visval_ok]
    valids = valids[:,visval_ok]
    visibs = visibs[:,visval_ok]
    return rgbs, trajs, visibs, valids
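# Example usage (a minimal sketch with synthetic data; the shapes follow the
# comments above and the values are only illustrative):
#   S, N = 32, 100
#   rgbs = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(S)]
#   trajs = np.random.rand(S, N, 2) * 256
#   visibs = (np.random.rand(S, N) > 0.2).astype(np.float32)
#   valids = np.ones((S, N), dtype=np.float32)
#   rgbs, trajs, visibs, valids = standardize_test_data(rgbs, trajs, visibs, valids)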

def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
    """
    This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
    It is a wrapper of get_2d_sincos_pos_embed_from_grid.
    Args:
    - embed_dim: The embedding dimension.
    - grid_size: The grid size.
    Returns:
    - pos_embed: The generated 2D positional embedding.
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
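# Example (a minimal sketch; values are illustrative):
#   pe = get_2d_sincos_pos_embed(256, (8, 12))
#   print(pe.shape)  # torch.Size([1, 256, 8, 12])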

def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using sine and cosine functions.
    Args:
    - embed_dim: The embedding dimension.
    - grid: The grid to generate the embedding from.
    Returns:
    - emb: The generated 2D positional embedding.
    """
    assert embed_dim % 2 == 0
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (1, H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (1, H*W, D/2)
    emb = torch.cat([emb_h, emb_w], dim=2)  # (1, H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.
    Args:
    - embed_dim: The embedding dimension.
    - pos: The position to generate the embedding from.
    Returns:
    - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()  # (1, M, D)
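# Example (a minimal sketch; values are illustrative):
#   pos = torch.arange(10, dtype=torch.float)
#   emb = get_1d_sincos_pos_embed_from_grid(64, pos)
#   print(emb.shape)  # torch.Size([1, 10, 64])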

def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
    Args:
    - xy: The coordinates to generate the embedding from.
    - C: The size of the embedding.
    - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
    Returns:
    - pe: The generated 2D positional embedding.
    """
    B, N, D = xy.shape
    assert D == 2
    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))
    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)
    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)
    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
    return pe
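# Example (a minimal sketch; values are illustrative):
#   xy = torch.rand(2, 5, 2) * 100.0
#   pe = get_2d_embedding(xy, 64, cat_coords=True)
#   print(pe.shape)  # torch.Size([2, 5, 130]), i.e., 2*64 + 2 coords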

# from datasets.dataset import mask2bbox

# from pips2
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
    device = xy.device
    dtype = xy.dtype  # note: this overrides the dtype argument with the input's dtype
    B, S, D = xy.shape
    assert(D==2)
    x = xy[:,:,0]
    y = xy[:,:,1]
    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
    omega = 1. / (temperature ** omega)
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    pe = pe.reshape(B,S,C).type(dtype)
    if cat_coords:
        pe = torch.cat([pe, xy], dim=2)  # B,N,C+2
    return pe
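# Example (a minimal sketch; values are illustrative):
#   xy = torch.rand(2, 16, 2) * 64.0
#   pe = posemb_sincos_2d_xy(xy, 128)
#   print(pe.shape)  # torch.Size([2, 16, 128])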

# defined locally (instead of importing from datasets.dataset) to prevent circular imports
def mask2bbox(mask):
    if mask.ndim == 3:
        mask = mask[..., 0]
    ys, xs = np.where(mask > 0.4)
    if ys.size == 0 or xs.size == 0:
        return np.array((0, 0, 0, 0), dtype=int)
    lt = np.array([np.min(xs), np.min(ys)])
    rb = np.array([np.max(xs), np.max(ys)]) + 1
    return np.concatenate([lt, rb])

# def get_stark_2d_embedding(H, W, C=64, device='cuda:0', temperature=10000, normalize=True):
#     scale = 2*math.pi
#     mask = torch.ones((1,H,W), dtype=torch.float32, device=device)
#     y_embed = mask.cumsum(1, dtype=torch.float32) # cumulative sum along axis 1 (h axis) --> (b, h, w)
#     x_embed = mask.cumsum(2, dtype=torch.float32) # cumulative sum along axis 2 (w axis) --> (b, h, w)
#     if normalize:
#         eps = 1e-6
#         y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale # 2pi * (y / sigma(y))
#         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale # 2pi * (x / sigma(x))
#     dim_t = torch.arange(C, dtype=torch.float32, device=device) # (0,1,2,...,d/2)
#     dim_t = temperature ** (2 * (dim_t // 2) / C)
#     pos_x = x_embed[:, :, :, None] / dim_t # (b,h,w,d/2)
#     pos_y = y_embed[:, :, :, None] / dim_t # (b,h,w,d/2)
#     pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2)
#     pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) # (b,h,w,d/2)
#     pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # (b,h,w,d)
#     return pos

# def get_1d_embedding(x, C, cat_coords=False):
#     B, N, D = x.shape
#     assert(D==1)
#     div_term = (torch.arange(0, C, 2, device=x.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
#     pe_x = torch.zeros(B, N, C, device=x.device, dtype=torch.float32)
#     pe_x[:, :, 0::2] = torch.sin(x * div_term)
#     pe_x[:, :, 1::2] = torch.cos(x * div_term)
#     if cat_coords:
#         pe_x = torch.cat([pe_x, x], dim=2) # B,N,C+1
#     return pe_x

# def posemb_sincos_2d(h, w, dim, temperature=10000, dtype=torch.float32, device='cuda:0'):
#     y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
#     assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
#     omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
#     omega = 1. / (temperature ** omega)
#     y = y.flatten()[:, None] * omega[None, :]
#     x = x.flatten()[:, None] * omega[None, :]
#     pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) # (h*w, dim)
#     return pe.type(dtype)

def iou(bbox1, bbox2):
    # bbox1, bbox2: [x1, y1, x2, y2]
    x1, y1, x2, y2 = bbox1
    x1_, y1_, x2_, y2_ = bbox2
    inter_x1 = max(x1, x1_)
    inter_y1 = max(y1, y1_)
    inter_x2 = min(x2, x2_)
    inter_y2 = min(y2, y2_)
    inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2_ - x1_) * (y2_ - y1_)
    iou = inter_area / (area1 + area2 - inter_area)
    return iou
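# Worked example (illustrative):
#   iou([0, 0, 10, 10], [5, 5, 15, 15])
#   -> intersection 5*5 = 25, union 100 + 100 - 25 = 175, IoU = 25/175 ~= 0.143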

# def get_2d_embedding(xy, C, cat_coords=False):
#     B, N, D = xy.shape
#     assert(D==2)
#     x = xy[:,:,0:1]
#     y = xy[:,:,1:2]
#     div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
#     pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
#     pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
#     pe_x[:, :, 0::2] = torch.sin(x * div_term)
#     pe_x[:, :, 1::2] = torch.cos(x * div_term)
#     pe_y[:, :, 0::2] = torch.sin(y * div_term)
#     pe_y[:, :, 1::2] = torch.cos(y * div_term)
#     pe = torch.cat([pe_x, pe_y], dim=2) # B,N,C*2
#     if cat_coords:
#         pe = torch.cat([pe, xy], dim=2) # B,N,C*2+2
#     return pe

# def get_3d_embedding(xyz, C, cat_coords=False):
#     B, N, D = xyz.shape
#     assert(D==3)
#     x = xyz[:,:,0:1]
#     y = xyz[:,:,1:2]
#     z = xyz[:,:,2:3]
#     div_term = (torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (10000.0 / C)).reshape(1, 1, int(C/2))
#     pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
#     pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
#     pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
#     pe_x[:, :, 0::2] = torch.sin(x * div_term)
#     pe_x[:, :, 1::2] = torch.cos(x * div_term)
#     pe_y[:, :, 0::2] = torch.sin(y * div_term)
#     pe_y[:, :, 1::2] = torch.cos(y * div_term)
#     pe_z[:, :, 0::2] = torch.sin(z * div_term)
#     pe_z[:, :, 1::2] = torch.cos(z * div_term)
#     pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
#     if cat_coords:
#         pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
#     return pe

class SimplePool():
    def __init__(self, pool_size, version='pt', min_size=1):
        self.pool_size = pool_size
        self.version = version
        self.items = []
        self.min_size = min_size
        if not (version=='pt' or version=='np'):
            print('version = %s; please choose pt or np' % version)
            assert(False)  # please choose pt or np

    def __len__(self):
        return len(self.items)

    def mean(self, min_size=None):
        if min_size is None:
            pool_size_thresh = self.min_size
        elif min_size=='half':
            pool_size_thresh = self.pool_size/2
        else:
            pool_size_thresh = min_size
        if self.version=='np':
            if len(self.items) >= pool_size_thresh:
                return np.sum(self.items)/float(len(self.items))
            else:
                return np.nan
        if self.version=='pt':
            if len(self.items) >= pool_size_thresh:
                # torch.sum does not accept a python list, so stack first
                return torch.stack(self.items).sum()/float(len(self.items))
            else:
                return torch.tensor(float('nan'))

    def sample(self, with_replacement=True):
        idx = np.random.randint(len(self.items))
        if with_replacement:
            return self.items[idx]
        else:
            return self.items.pop(idx)

    def fetch(self, num=None):
        if self.version=='pt':
            item_array = torch.stack(self.items)
        elif self.version=='np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)
            # if there are not that many elements just return however many there are
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        full = len(self.items)==self.pool_size
        return full

    def empty(self):
        self.items = []

    def have_min_size(self):
        return len(self.items) >= self.min_size

    def update(self, items):
        for item in items:
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
            else:
                # the pool is full
                # pop from the front
                self.items.pop(0)
                # add to the back
                self.items.append(item)
        return self.items
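# Example (a minimal sketch; values are illustrative):
#   pool = SimplePool(pool_size=100, version='np')
#   pool.update([1.0, 2.0, 3.0])
#   print(pool.mean())  # 2.0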

class SimpleHeap():
    def __init__(self, pool_size, version='pt'):
        self.pool_size = pool_size
        self.version = version
        self.items = []
        self.vals = []
        if not (version=='pt' or version=='np'):
            print('version = %s; please choose pt or np' % version)
            assert(False)  # please choose pt or np

    def __len__(self):
        return len(self.items)

    def sample(self, random=True, with_replacement=True, semirandom=False):
        vals_arr = np.stack(self.vals)
        if random:
            ind = np.random.randint(len(self.items))
        else:
            if semirandom and len(vals_arr)>1:
                # choose from the harder half
                inds = np.argsort(vals_arr)  # ascending
                inds = inds[len(vals_arr)//2:]
                ind = np.random.choice(inds)
            else:
                # find the most valuable element
                ind = np.argmax(vals_arr)
        if with_replacement:
            return self.items[ind]
        else:
            item = self.items.pop(ind)
            val = self.vals.pop(ind)  # keep vals in sync with items
            return item

    def fetch(self, num=None):
        if self.version=='pt':
            item_array = torch.stack(self.items)
        elif self.version=='np':
            item_array = np.stack(self.items)
        if num is not None:
            # there better be some items
            assert(len(self.items) >= num)
            # if there are not that many elements just return however many there are
            if len(self.items) < num:
                return item_array
            else:
                idxs = np.random.randint(len(self.items), size=num)
                return item_array[idxs]
        else:
            return item_array

    def is_full(self):
        full = len(self.items)==self.pool_size
        return full

    def empty(self):
        self.items = []
        self.vals = []  # keep vals in sync with items

    def update(self, vals, items):
        for val, item in zip(vals, items):
            if len(self.items) < self.pool_size:
                # the pool is not full, so let's add this in
                self.items.append(item)
                self.vals.append(val)
            else:
                # the pool is full
                # find our least-valuable element
                # and see if we should replace it
                vals_arr = np.stack(self.vals)
                ind = np.argmin(vals_arr)
                if vals_arr[ind] < val:
                    # pop the min
                    self.items.pop(ind)
                    self.vals.pop(ind)
                    # add to the back
                    self.items.append(item)
                    self.vals.append(val)
        return self.items
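# Example (a minimal sketch; values are illustrative):
#   heap = SimpleHeap(pool_size=2, version='np')
#   heap.update(vals=[0.1, 0.9, 0.5], items=[np.zeros(2), np.ones(2), np.full(2, 2.0)])
#   # only the two highest-value items survive;
#   # heap.sample(random=False) returns the item whose val is 0.9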

def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
    """
    Input:
        xyz: pointcloud data, [B, N, C], where C is probably 3
        npoint: number of samples
    Return:
        inds: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    xyz = xyz.float()
    inds = torch.zeros(B, npoint, dtype=torch.long, device=device)
    distance = torch.ones((B, N), dtype=torch.float32, device=device) * 1e10
    if deterministic:
        farthest = torch.randint(0, 1, (B,), dtype=torch.long, device=device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        if include_ends:
            if i==0:
                farthest = 0
            elif i==1:
                farthest = N-1
        inds[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
        if npoint > N:
            # if we need more samples, make them random
            distance += torch.randn_like(distance)
    return inds
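# Example (a minimal sketch; values are illustrative):
#   xyz = torch.rand(2, 1024, 3)
#   inds = farthest_point_sample(xyz, 128, deterministic=True)  # (2, 128)
#   sampled = torch.gather(xyz, 1, inds.unsqueeze(-1).expand(-1, -1, 3))  # (2, 128, 3)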

def farthest_point_sample_py(xyz, npoint, deterministic=False):
    N, C = xyz.shape
    inds = np.zeros(npoint, dtype=np.int32)
    distance = np.ones(N) * 1e10
    if deterministic:
        farthest = 0
    else:
        farthest = np.random.randint(0, N, dtype=np.int32)
    for i in range(npoint):
        inds[i] = farthest
        centroid = xyz[farthest, :].reshape(1, C)
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
        if npoint > N:
            # if we need more samples, make them random
            distance += np.random.randn(*distance.shape)
    return inds
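# Example (a minimal sketch; values are illustrative):
#   xyz = np.random.rand(500, 3)
#   inds = farthest_point_sample_py(xyz, 16, deterministic=True)
#   sampled = xyz[inds]  # (16, 3)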

def balanced_ce_loss(pred, gt, pos_weight=0.5, valid=None, dim=None, return_both=False, use_halfmask=False, H=64, W=64):
    # # pred and gt are the same shape
    # for (a,b) in zip(pred.size(), gt.size()):
    #     if not a==b:
    #         print('mismatch: pred, gt', pred.shape, gt.shape)
    #     assert(a==b) # some shape mismatch!
    pred = pred.reshape(-1)
    gt = gt.reshape(-1)
    device = pred.device
    if valid is not None:
        valid = valid.reshape(-1)
        for (a, b) in zip(pred.size(), valid.size()):
            assert(a==b)  # some shape mismatch!
    else:
        valid = torch.ones_like(gt)
    pos = (gt > 0.95).float()
    if use_halfmask:
        pos_wide = (gt >= 0.5).float()
        halfmask = (gt == 0.5).float()
    else:
        pos_wide = pos
    neg = (gt < 0.05).float()
    # numerically-stable binary cross-entropy on logits, with labels in {-1,1}
    label = pos_wide*2.0 - 1.0
    a = -label * pred
    b = F.relu(a)
    loss = b + torch.log(torch.exp(-b)+torch.exp(a-b))
    if torch.sum(pos*valid)>0:
        pos_loss = loss[(pos*valid) > 0].mean()
    else:
        pos_loss = torch.tensor(0.0, requires_grad=True, device=device)
    if torch.sum(neg*valid)>0:
        neg_loss = loss[(neg*valid) > 0].mean()
    else:
        neg_loss = torch.tensor(0.0, requires_grad=True, device=device)
    balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
    return balanced_loss
    # NOTE: the early return above leaves the experimental use_halfmask / return_both
    # logic below unreachable; that code also assumes unflattened B,H,W inputs.
    # pos_loss = utils.basic.reduce_masked_mean(loss, pos*valid, dim=dim)
    # neg_loss = utils.basic.reduce_masked_mean(loss, neg*valid, dim=dim)
    if use_halfmask:
        # here we will find the pixels which are already leaning positive,
        # and encourage them to be more positive
        B = loss.shape[0]
        loss_ = loss.reshape(B, -1)
        mask_ = halfmask.reshape(B, -1) * valid.reshape(B, -1)
        # to avoid the issue where spikes become spikier,
        # we will only apply this loss on batch els where we predicted zero positives
        pred_sig_ = torch.sigmoid(pred).reshape(B, -1)
        no_pred_ = torch.max(pred_sig_.round(), axis=1)[0] < 1  # B
        # and only on batch els where we have negatives available
        have_neg_ = torch.sum(neg, dim=1) > 0  # B
        loss_ = loss_[no_pred_ & have_neg_]  # N,H*W
        mask_ = mask_[no_pred_ & have_neg_]  # N,H*W
        N = loss_.shape[0]
        if N > 0:
            # we want:
            # in the neg pixels,
            # set them to the max loss of the pos pixels,
            # so that they do not contribute to the min
            loss__ = loss_.reshape(-1)
            mask__ = mask_.reshape(-1)
            if torch.sum(mask__) > 0:
                # print('loss_', loss_.shape, 'mask_', mask_.shape, 'loss__', loss__.shape, 'mask__', mask__.shape)
                mloss__ = loss__.detach()
                mloss__[mask__==0] = torch.max(loss__[mask__==1])
                mloss_ = mloss__.reshape(N, H*W)
                # now, in each batch el, take a tiny region around the argmin, so we can boost this region
                minloss_mask_ = torch.zeros_like(mloss_).scatter(1, mloss_.argmin(1, True), value=1)
                minloss_mask_ = utils.improc.dilate2d(minloss_mask_.view(N, 1, H, W), times=3).reshape(N, H*W)  # requires utils.improc if this path is re-enabled
                loss__ = loss_.reshape(-1)
                minloss_mask__ = minloss_mask_.reshape(-1)
                half_loss = loss__[minloss_mask__>0].mean()
                # print('N', N, 'half_loss', half_loss)
                pos_loss = pos_loss + half_loss
        # if False:
        #     min_pos = 8
        #     # only apply the loss when we have some negatives available,
        #     # otherwise it's a whole "ignore" frame, which may mean
        #     # we are unsure if the target is even there
        #     if torch.sum(mask__==0) > 0: # negatives available
        #         # only apply the loss when the halfmask is larger area than
        #         # min_pos (the number of pixels we want to boost),
        #         # so that indexing will work
        #         if torch.all(torch.sum(mask_==1, dim=1) >= min_pos): # topk indexing will work
        #             # in the pixels we will not use,
        #             # set them to the max of the pixels we may use,
        #             # so that they do not contribute to the min
        #             loss__[mask__==0] = torch.max(loss__[mask__==1])
        #             loss_ = loss__.reshape(B,-1)
        #             half_loss = torch.mean(torch.topk(loss_, min_pos, dim=1, largest=False)[0], dim=1) # B
        #             have_neg = (torch.sum(neg, dim=1)>0).float() # B
        #             pos_loss = pos_loss + half_loss*have_neg
        # half_loss = []
        # for b in range(B):
        #     loss_b = loss_[b]
        #     mask_b = mask_[b]
        #     if torch.sum(mask_b):
        #         inds = torch.nonzero(mask_b).reshape(-1)
        #         half_loss.append(torch.min(loss_b[inds]))
        # if len(half_loss):
        #     # # half_loss_ = half_loss.reshape(-1)
        #     # half_loss = torch.min(half_loss, dim=1)[0] # B
        #     # half_loss = torch.mean(torch.topk(half_loss, 4, dim=1, largest=False)[0], dim=1) # B
        #     pos_loss = pos_loss + torch.stack(half_loss).mean()
    if return_both:
        return pos_loss, neg_loss
    balanced_loss = pos_weight*pos_loss + (1-pos_weight)*neg_loss
    return balanced_loss
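# Example (a minimal sketch; values are illustrative):
#   logits = torch.randn(4, 1, 64, 64)
#   target = (torch.rand(4, 1, 64, 64) > 0.5).float()
#   loss = balanced_ce_loss(logits, target)  # scalar, weighting pos and neg terms equally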

def dice_loss(pred, gt):
    # gt has ignores at 0.5
    # pred and gt are the same shape
    for (a, b) in zip(pred.size(), gt.size()):
        assert(a==b)  # some shape mismatch!
    prob = pred.sigmoid()
    # flatten everything except batch
    prob = prob.flatten(1)
    gt = gt.flatten(1)
    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()
    valid = (pos+neg).float().clamp(0, 1)
    numerator = 2 * (prob * pos * valid).sum(1)
    denominator = (prob*valid).sum(1) + (pos*valid).sum(1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss
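# Example (a minimal sketch; values are illustrative):
#   pred = torch.randn(4, 1, 64, 64)
#   gt = (torch.rand(4, 1, 64, 64) > 0.5).float()
#   d = dice_loss(pred, gt)  # (4,) per-batch-element losses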

def sigmoid_focal_loss(pred, gt, alpha=0.25, gamma=2):  # , use_halfmask=False):
    # gt has ignores at 0.5
    # pred and gt are the same shape
    for (a, b) in zip(pred.size(), gt.size()):
        assert(a==b)  # some shape mismatch!
    # flatten everything except batch
    pred = pred.flatten(1)
    gt = gt.flatten(1)
    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()
    # if use_halfmask:
    #     pos_wide = (gt >= 0.5).float()
    #     halfmask = (gt == 0.5).float()
    # else:
    #     pos_wide = pos
    valid = (pos+neg).float().clamp(0, 1)
    prob = pred.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(pred, pos, reduction="none")
    p_t = prob * pos + (1 - prob) * (1 - pos)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * pos + (1 - alpha) * (1 - pos)
        loss = alpha_t * loss
    loss = (loss*valid).sum(1) / (1 + valid.sum(1))
    return loss
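# Example (a minimal sketch; values are illustrative):
#   pred = torch.randn(4, 1, 64, 64)
#   gt = (torch.rand(4, 1, 64, 64) > 0.5).float()
#   fl = sigmoid_focal_loss(pred, gt).mean()  # (4,) per-batch losses -> scalar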

# def dice_loss(inputs, targets, normalizer=1):
#     inputs = inputs.sigmoid()
#     inputs = inputs.flatten(1)
#     numerator = 2 * (inputs * targets).sum(1)
#     denominator = inputs.sum(-1) + targets.sum(-1)
#     loss = 1 - (numerator + 1) / (denominator + 1)
#     return loss.sum() / normalizer

# def sigmoid_focal_loss(inputs, targets, normalizer=1, alpha=0.25, gamma=2):
#     prob = inputs.sigmoid()
#     ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
#     p_t = prob * targets + (1 - prob) * (1 - targets)
#     loss = ce_loss * ((1 - p_t) ** gamma)
#     if alpha >= 0:
#         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
#         loss = alpha_t * loss
#     return loss.mean(1).sum() / normalizer

def data_replace_with_nearest(xys, valids):
    # replace invalid xys with nearby ones
    invalid_idx = np.where(valids==0)[0]
    valid_idx = np.where(valids==1)[0]
    for idx in invalid_idx:
        nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
        xys[idx] = xys[nearest]
    return xys
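# Example (a minimal sketch; values are illustrative):
#   xys = np.array([[0., 0.], [5., 5.], [9., 9.]])
#   valids = np.array([1, 0, 1])
#   data_replace_with_nearest(xys, valids)
#   # row 1 is overwritten with row 0, the nearest valid frame (ties go to the earlier index)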

def data_get_traj_from_masks(masks):
    if masks.ndim==4:
        masks = masks[...,0]
    S, H, W = masks.shape
    masks = (masks > 0.1).astype(np.float32)
    fills = np.zeros((S))
    xy_means = np.zeros((S, 2))
    xy_rands = np.zeros((S, 2))
    valids = np.zeros((S))
    for si, mask in enumerate(masks):
        if np.sum(mask) > 0:
            ys, xs = np.where(mask)
            inds = np.random.permutation(len(xs))
            xs, ys = xs[inds], ys[inds]
            x0, x1 = np.min(xs), np.max(xs)+1
            y0, y1 = np.min(ys), np.max(ys)+1
            # if (x1-x0)>0 and (y1-y0)>0:
            xy_means[si] = np.array([xs.mean(), ys.mean()])
            xy_rands[si] = np.array([xs[0], ys[0]])
            valids[si] = 1
            crop = mask[y0:y1, x0:x1]
            fill = np.mean(crop)
            fills[si] = fill
    # print('fills', fills)
    return xy_means, xy_rands, valids, fills
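# Example (a minimal sketch; values are illustrative):
#   masks = np.zeros((8, 64, 64)); masks[:, 20:40, 20:40] = 1.0
#   xy_means, xy_rands, valids, fills = data_get_traj_from_masks(masks)
#   # xy_means[0] ~= [29.5, 29.5]; fills[0] == 1.0 since the bbox is fully filled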

def data_zoom(zoom, xys, visibs, rgbs, valids=None, masks=None, masks2=None, masks3=None, masks4=None):
    S, N, D = xys.shape
    S, H, W, C = rgbs.shape
    assert(C==3)
    crop_W = int(W//zoom)
    crop_H = int(H//zoom)
    if np.random.rand() < 0.25:  # follow-crop
        # start with xy traj
        # smooth_xys = xys.copy()
        smooth_xys = xys[:, np.random.randint(N)].reshape(S, 1, 2)
        # make it inbounds
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
        # smooth it out, to remove info about the traj, and simulate camera motion
        for _ in range(S*3):
            for ii in range(S):
                if ii==0:
                    smooth_xys[ii] = (smooth_xys[ii] + smooth_xys[ii+1])/2.0
                elif ii==S-1:
                    smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii])/2.0
                else:
                    smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
    else:  # static (no-hint) crop
        # zero-vel on random available coordinate
        if valids is not None:
            visval = visibs*valids  # S,N
            visval = np.sum(visval, axis=1)  # S
        else:
            visval = np.sum(visibs, axis=1)  # S
        anchor_inds = np.nonzero(visval >= np.mean(visval))[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        # print('ind', ind)
        smooth_xys = xys[ind:ind+1].repeat(S, axis=0)
        smooth_xys = smooth_xys.mean(axis=1, keepdims=True)
        # xmid = np.random.randint(crop_W//2, W-crop_W//2)
        # ymid = np.random.randint(crop_H//2, H-crop_H//2)
        # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S, 1, 2))
        for _ in range(4):  # smooth out
            for ii in range(S):
                if ii==0:
                    alt_xys[ii] = (alt_xys[ii] + alt_xys[ii+1])/2.0
                elif ii==S-1:
                    alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii])/2.0
                else:
                    alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
        smooth_xys = smooth_xys + alt_xys
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    rgbs_crop = []
    if masks is not None:
        masks_crop = []
    if masks2 is not None:
        masks2_crop = []
    if masks3 is not None:
        masks3_crop = []
    if masks4 is not None:
        masks4_crop = []
    offsets = []
    for si in range(S):
        xy_mid = smooth_xys[si].squeeze(0).round().astype(np.int32)  # 2
        xmid, ymid = xy_mid[0], xy_mid[1]
        x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
        y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
        offset = np.array([x0, y0]).reshape(1, 2)
        rgbs_crop.append(rgbs[si, y0:y1, x0:x1])
        if masks is not None:
            masks_crop.append(masks[si, y0:y1, x0:x1])
        if masks2 is not None:
            masks2_crop.append(masks2[si, y0:y1, x0:x1])
        if masks3 is not None:
            masks3_crop.append(masks3[si, y0:y1, x0:x1])
        if masks4 is not None:
            masks4_crop.append(masks4[si, y0:y1, x0:x1])
        xys[si] -= offset
        offsets.append(offset)
    rgbs = np.stack(rgbs_crop, axis=0)
    if masks is not None:
        masks = np.stack(masks_crop, axis=0)
    if masks2 is not None:
        masks2 = np.stack(masks2_crop, axis=0)
    if masks3 is not None:
        masks3 = np.stack(masks3_crop, axis=0)
    if masks4 is not None:
        masks4 = np.stack(masks4_crop, axis=0)
    # update visibility annotations
    for si in range(S):
        oob_inds = np.logical_or(
            np.logical_or(xys[si,:,0] < 0, xys[si,:,0] > crop_W-1),
            np.logical_or(xys[si,:,1] < 0, xys[si,:,1] > crop_H-1))
        visibs[si, oob_inds] = 0
    # if masks4 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2, masks3, masks4
    # if masks3 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2, masks3
    # if masks2 is not None:
    #     return xys, visibs, valids, rgbs, masks, masks2
    # if masks is not None:
    #     return xys, visibs, valids, rgbs, masks
    # else:
    #     return xys, visibs, valids, rgbs
    if valids is not None:
        return xys, visibs, rgbs, valids
    else:
        return xys, visibs, rgbs
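# Example (a minimal sketch; values are illustrative):
#   S, N, H, W = 16, 8, 256, 384
#   rgbs = np.zeros((S, H, W, 3), dtype=np.uint8)
#   xys = np.random.rand(S, N, 2) * [W, H]
#   visibs = np.ones((S, N), dtype=np.float32)
#   xys, visibs, rgbs = data_zoom(2.0, xys, visibs, rgbs)  # rgbs -> (S, H//2, W//2, 3)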

def data_zoom_bbox(zoom, bboxes, visibs, rgbs):  # , valids=None):
    S, H, W, C = rgbs.shape
    assert(C==3)
    crop_W = int(W//zoom)
    crop_H = int(H//zoom)
    xys = bboxes[:,0:2]*0.5 + bboxes[:,2:4]*0.5
    if np.random.rand() < 0.25:  # follow-crop
        # start with xy traj
        smooth_xys = xys.copy()
        # make it inbounds
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
        # smooth it out, to remove info about the traj, and simulate camera motion
        for _ in range(S*3):
            for ii in range(1, S-1):
                smooth_xys[ii] = (smooth_xys[ii-1] + smooth_xys[ii] + smooth_xys[ii+1])/3.0
    else:  # static (no-hint) crop
        # zero-vel on random available coordinate
        anchor_inds = np.nonzero(visibs.reshape(-1) > 0.5)[0]
        ind = anchor_inds[np.random.randint(len(anchor_inds))]
        smooth_xys = xys[ind:ind+1].repeat(S, axis=0)
        # xmid = np.random.randint(crop_W//2, W-crop_W//2)
        # ymid = np.random.randint(crop_H//2, H-crop_H//2)
        # smooth_xys = np.stack([xmid, ymid], axis=-1).reshape(1,1,2).repeat(S, axis=0) # S,1,2
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    # print('xys', xys)
    # print('smooth_xys', smooth_xys)
    if np.random.rand() < 0.5:
        # add a random alternate trajectory, to help push us off center
        alt_xys = np.random.randint(-crop_H//8, crop_H//8, (S, 2))
        for _ in range(3):
            for ii in range(1, S-1):
                alt_xys[ii] = (alt_xys[ii-1] + alt_xys[ii] + alt_xys[ii+1])/3.0
        smooth_xys = smooth_xys + alt_xys
        smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
    rgbs_crop = []
    offsets = []
    for si in range(S):
        xy_mid = smooth_xys[si].round().astype(np.int32)
        xmid, ymid = xy_mid[0], xy_mid[1]
        x0, x1 = np.clip(xmid-crop_W//2, 0, W), np.clip(xmid+crop_W//2, 0, W)
        y0, y1 = np.clip(ymid-crop_H//2, 0, H), np.clip(ymid+crop_H//2, 0, H)
        offset = np.array([x0, y0]).reshape(2)
        rgbs_crop.append(rgbs[si, y0:y1, x0:x1])
        xys[si] -= offset
        bboxes[si,0:2] -= offset
        bboxes[si,2:4] -= offset
        offsets.append(offset)
    rgbs = np.stack(rgbs_crop, axis=0)
    # update visibility annotations; coordinates are now in the crop frame,
    # so test against crop_W/crop_H rather than the original W/H
    for si in range(S):
        # avoid 1px edge
        oob = (xys[si,0] < 1) or (xys[si,0] > crop_W-2) or (xys[si,1] < 1) or (xys[si,1] > crop_H-2)
        if oob:
            visibs[si] = 0
    # clamp to the crop bounds
    xys0 = np.minimum(np.maximum(bboxes[:,0:2], np.zeros((2,), dtype=int)), np.array([crop_W, crop_H]) - 1)  # S,2
    xys1 = np.minimum(np.maximum(bboxes[:,2:4], np.zeros((2,), dtype=int)), np.array([crop_W, crop_H]) - 1)  # S,2
    bboxes = np.concatenate([xys0, xys1], axis=1)
    return bboxes, visibs, rgbs

def data_pad_if_necessary(rgbs, masks, masks2=None):
    S, H, W, C = rgbs.shape
    mask_areas = (masks > 0).reshape(S, -1).sum(axis=1)
    mask_areas_norm = mask_areas / np.max(mask_areas)
    visibs = mask_areas_norm
    bboxes = np.stack([mask2bbox(mask) for mask in masks])
    whs = bboxes[:,2:4] - bboxes[:,0:2]
    whs = whs[visibs > 0.5]
    # print('mean wh', np.mean(whs[:,0]), np.mean(whs[:,1]))
    if np.mean(whs[:,0]) >= W/2:
        # print('padding w')
        pad = ((0,0),(0,0),(W//4,W//4),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        masks = np.pad(masks, pad[:3], mode="constant")
        if masks2 is not None:
            masks2 = np.pad(masks2, pad[:3], mode="constant")
        # print('rgbs', rgbs.shape)
        # print('masks', masks.shape)
    if np.mean(whs[:,1]) >= H/2:
        # print('padding h')
        pad = ((0,0),(H//4,H//4),(0,0),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        masks = np.pad(masks, pad[:3], mode="constant")
        if masks2 is not None:
            masks2 = np.pad(masks2, pad[:3], mode="constant", constant_values=0.5)
    if masks2 is not None:
        return rgbs, masks, masks2
    return rgbs, masks


def data_pad_if_necessary_b(rgbs, bboxes, visibs):
    S, H, W, C = rgbs.shape
    whs = bboxes[:,2:4] - bboxes[:,0:2]
    whs = whs[visibs > 0.5]
    if np.mean(whs[:,0]) >= W/2:
        pad = ((0,0),(0,0),(W//4,W//4),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        bboxes[:,0] += W//4
        bboxes[:,2] += W//4
    if np.mean(whs[:,1]) >= H/2:
        pad = ((0,0),(H//4,H//4),(0,0),(0,0))
        rgbs = np.pad(rgbs, pad, mode="constant")
        bboxes[:,1] += H//4
        bboxes[:,3] += H//4
    return rgbs, bboxes

def posenc(x, min_deg, max_deg):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
    Args:
        x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
        min_deg: int, the minimum (inclusive) degree of the encoding.
        max_deg: int, the maximum (exclusive) degree of the encoding.
    Returns:
        encoded: torch.Tensor, encoded variables.
    """
    if min_deg == max_deg:
        return x
    scales = torch.tensor(
        [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
    )
    xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
    return torch.cat([x] + [four_feat], dim=-1)
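# Example (a minimal sketch; values are illustrative):
#   x = torch.rand(3, 2) * 2 * torch.pi - torch.pi  # values in [-pi, pi]
#   enc = posenc(x, 0, 4)  # (3, 2 + 2*2*4) = (3, 18)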