import copy
import queue

import cv2
import numpy as np
import torch
import torch.nn.functional as F
def prep_frame_for_dino(img, scale_size=[192]):
    """
    Preprocess a single frame for DINO: resize so that each side is a
    multiple of 64, scale to [0, 1], flip BGR -> RGB, and normalize.
    """
    ori_h, ori_w, _ = img.shape
    if len(scale_size) == 1:
        if ori_h > ori_w:
            tw = scale_size[0]
            th = (tw * ori_h) / ori_w
            th = int((th // 64) * 64)
        else:
            th = scale_size[0]
            tw = (th * ori_w) / ori_h
            tw = int((tw // 64) * 64)
    else:
        th, tw = scale_size
    img = cv2.resize(img, (tw, th))
    img = img.astype(np.float32)
    img = img / 255.0
    img = img[:, :, ::-1]  # BGR -> RGB (assumes an OpenCV-style BGR input)
    img = np.transpose(img.copy(), (2, 0, 1))  # HWC -> CHW
    img = torch.from_numpy(img).float()

    def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
        # Per-channel normalization in place (statistics as used by DINO's eval code).
        for t, m, s in zip(x, mean, std):
            t.sub_(m)
            t.div_(s)
        return x

    img = color_normalize(img)
    return img, ori_h, ori_w
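
# A minimal usage sketch for prep_frame_for_dino, assuming an OpenCV-style BGR
# frame; 'frame.png' below is a hypothetical placeholder path.
def _demo_prep_frame():
    bgr = cv2.imread('frame.png')  # hypothetical path; H x W x 3, BGR, uint8
    tensor, ori_h, ori_w = prep_frame_for_dino(bgr, scale_size=[192])
    # tensor is (3, th, tw) with th, tw rounded down to multiples of 64
    print(tensor.shape, ori_h, ori_w)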
def get_feats_from_dino(model, frame):
    # Extract patch features for a batch of frames in one forward pass.
    B = frame.shape[0]
    patch_size = model.patch_embed.patch_size
    h, w = frame.shape[2] // patch_size, frame.shape[3] // patch_size
    out = model.get_intermediate_layers(frame.cuda(), n=1)[0]  # B, 1+h*w, dim
    dim = out.shape[-1]
    out = out[:, 1:, :]  # discard the [CLS] token
    outmap = out.permute(0, 2, 1).reshape(B, dim, h, w)
    return out, outmap, h, w
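
# For reference, a backbone with the interface assumed above (patch_embed.patch_size,
# get_intermediate_layers) can be loaded from the official DINO torch.hub entry.
# The ViT-S/8 variant below is an assumed choice, not mandated by this code.
def _demo_load_dino():
    dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
    dino = dino.cuda().eval()
    frame = torch.randn(1, 3, 192, 192)  # stand-in for a preprocessed frame
    tokens, featmap, h, w = get_feats_from_dino(dino, frame)
    print(featmap.shape)  # (1, 384, 24, 24) for ViT-S/8 at 192x192
    return dino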
def restrict_neighborhood(h, w):
    size_mask_neighborhood = 12
    # Restrict the set of source nodes considered to a spatial neighborhood
    # of the query node (i.e. "local attention").
    mask = torch.zeros(h, w, h, w)
    for i in range(h):
        for j in range(w):
            for p in range(2 * size_mask_neighborhood + 1):
                for q in range(2 * size_mask_neighborhood + 1):
                    if i - size_mask_neighborhood + p < 0 or i - size_mask_neighborhood + p >= h:
                        continue
                    if j - size_mask_neighborhood + q < 0 or j - size_mask_neighborhood + q >= w:
                        continue
                    mask[i, j, i - size_mask_neighborhood + p, j - size_mask_neighborhood + q] = 1
    mask = mask.reshape(h * w, h * w)
    return mask.cuda(non_blocking=True)
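
# Sketch of an equivalent vectorized construction (an alternative to the quadruple
# loop above, which is O(h^2 * w^2) in Python; not part of the original code).
def restrict_neighborhood_vectorized(h, w, radius=12):
    ii = torch.arange(h)
    jj = torch.arange(w)
    di = (ii[:, None] - ii[None, :]).abs() <= radius  # h x h: |i - i'| <= radius
    dj = (jj[:, None] - jj[None, :]).abs() <= radius  # w x w: |j - j'| <= radius
    # mask[i, j, i', j'] = 1 iff both row and column offsets are within the radius
    mask = (di[:, None, :, None] & dj[None, :, None, :]).float()
    return mask.reshape(h * w, h * w).cuda(non_blocking=True)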
def label_propagation(h, w, feat_tar, list_frame_feats, list_segs, mask_neighborhood=None):
    ncontext = len(list_frame_feats)
    feat_sources = torch.stack(list_frame_feats)  # nmb_context x dim x h*w

    feat_tar = F.normalize(feat_tar, dim=1, p=2)
    feat_sources = F.normalize(feat_sources, dim=1, p=2)

    feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
    aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1)  # nmb_context x h*w (tar) x h*w (source)

    size_mask_neighborhood = 12
    if size_mask_neighborhood > 0:
        if mask_neighborhood is None:
            mask_neighborhood = restrict_neighborhood(h, w)
            mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
        aff *= mask_neighborhood

    aff = aff.transpose(2, 1).reshape(-1, h * w)  # nmb_context*h*w (source: keys) x h*w (tar: queries)
    topk = 5
    tk_val, _ = torch.topk(aff, dim=0, k=topk)
    tk_val_min, _ = torch.min(tk_val, dim=0)
    aff[aff < tk_val_min] = 0  # keep only the top-k source affinities per target
    aff = aff / torch.sum(aff, keepdim=True, dim=0)

    list_segs = [s.cuda() for s in list_segs]
    segs = torch.cat(list_segs)
    nmb_context, C, h, w = segs.shape
    segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T  # C x nmb_context*h*w
    seg_tar = torch.mm(segs, aff)
    seg_tar = seg_tar.reshape(1, C, h, w)
    return seg_tar, mask_neighborhood
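
# Toy illustration of the top-k sparsification used above (values are made up):
# per target column, keep the k strongest source affinities, then renormalize.
def _demo_topk_sparsify():
    aff = torch.tensor([[0.9, 0.1],
                        [0.5, 0.8],
                        [0.2, 0.7]])  # 3 sources x 2 targets
    tk_val, _ = torch.topk(aff, dim=0, k=2)
    aff[aff < tk_val.min(dim=0).values] = 0  # weakest source per target dropped
    aff = aff / aff.sum(dim=0, keepdim=True)  # each column now sums to 1
    print(aff)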
def norm_mask(mask):
    # Min-max normalize each channel of the mask to [0, 1] independently.
    c, h, w = mask.size()
    for cnt in range(c):
        mask_cnt = mask[cnt, :, :]
        if mask_cnt.max() > 0:
            mask_cnt = mask_cnt - mask_cnt.min()
            mask_cnt = mask_cnt / mask_cnt.max()
            mask[cnt, :, :] = mask_cnt
    return mask
def get_dino_output(dino, rgbs, trajs_g, vis_g):
    B, S, C, H, W = rgbs.shape
    B1, S1, N, D = trajs_g.shape
    assert(B1 == B)
    assert(S1 == S)
    assert(D == 2)
    assert(B == 1)

    xy0 = trajs_g[:, 0]  # B, N, 2

    # The queue stores the n preceding frames
    n_last_frames = 7
    que = queue.Queue(n_last_frames)

    # run dino
    prep_rgbs = []
    for s in range(S):
        prep_rgb, ori_h, ori_w = prep_frame_for_dino(
            rgbs[0, s].permute(1, 2, 0).detach().cpu().numpy(), scale_size=[H])
        prep_rgbs.append(prep_rgb)
    prep_rgbs = torch.stack(prep_rgbs, dim=0)  # S, 3, H, W
    with torch.no_grad():
        bs = 8
        idx = 0
        featmaps = []
        while idx < S:
            end_id = min(S, idx + bs)
            _, featmaps_cur, h, w = get_feats_from_dino(dino, prep_rgbs[idx:end_id])  # bs, C, h, w
            idx = end_id
            featmaps.append(featmaps_cur)
        featmaps = torch.cat(featmaps, dim=0)
    C = featmaps.shape[1]
    featmaps = featmaps.unsqueeze(0)  # 1, S, C, h, w
    # featmaps = F.normalize(featmaps, dim=2, p=2)

    # Mark each query point as a one-hot "segment" on the patch grid.
    xy0 = trajs_g[:, 0, :]  # B, N, 2
    patch_size = dino.patch_embed.patch_size
    first_seg = torch.zeros((1, N, H // patch_size, W // patch_size))
    for n in range(N):
        first_seg[0, n, (xy0[0, n, 1] / patch_size).long(), (xy0[0, n, 0] / patch_size).long()] = 1

    frame1_feat = featmaps[0, 0].reshape(C, h * w)  # dim x h*w
    mask_neighborhood = None
    trajs_e = torch.zeros_like(trajs_g)
    trajs_e[0, 0] = trajs_g[0, 0]
    for cnt in range(1, S):
        # Propagate labels from the first frame plus the recent-frame queue.
        used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
        used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
        feat_tar = featmaps[0, cnt].reshape(C, h * w)
        frame_tar_avg, mask_neighborhood = label_propagation(
            h, w, feat_tar.T, used_frame_feats, used_segs, mask_neighborhood)

        # pop out the oldest frame if necessary
        if que.qsize() == n_last_frames:
            que.get()
        # push current results into the queue
        seg = copy.deepcopy(frame_tar_avg)
        que.put([feat_tar, seg])

        # upsampling & argmax
        frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=patch_size, mode='bilinear',
                                      align_corners=False, recompute_scale_factor=False)[0]
        frame_tar_avg = norm_mask(frame_tar_avg)
        _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)

        for n in range(N):
            vis = vis_g[0, cnt, n]  # visibility is available but unused here
            nz = torch.nonzero(frame_tar_avg[n])
            if len(nz) > 0:
                # weighted average of the nonzero heatmap locations
                coord_e = torch.sum(frame_tar_avg[n][nz[:, 0], nz[:, 1]].reshape(-1, 1) * nz.float(), 0) \
                    / frame_tar_avg[n][nz[:, 0], nz[:, 1]].sum()  # (row, col)
                coord_e = coord_e[[1, 0]]  # -> (x, y)
            else:
                # stay where it was
                coord_e = trajs_e[0, cnt - 1, n]
            trajs_e[0, cnt, n] = coord_e
    return trajs_e
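
# Smoke-test sketch with synthetic inputs (shapes follow the assertions in
# get_dino_output; the data is random and the hub variant is an assumed choice).
if __name__ == '__main__':
    dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').cuda().eval()
    B, S, N, H, W = 1, 8, 16, 256, 256
    rgbs = torch.rand(B, S, 3, H, W) * 255.0
    trajs_g = torch.rand(B, S, N, 2) * torch.tensor([W - 1., H - 1.])
    vis_g = torch.ones(B, S, N)
    trajs_e = get_dino_output(dino, rgbs, trajs_g, vis_g)
    print(trajs_e.shape)  # (1, 8, 16, 2)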