import copy
import queue

import cv2
import numpy as np
import torch
import torch.nn.functional as F
def prep_frame_for_dino(img, scale_size=[192]):
    """
    Read a single frame (H, W, 3) and preprocess it for DINO: resize so the
    shorter side matches scale_size[0] (longer side snapped down to a multiple
    of 64), scale to [0, 1], reverse the channel order, and normalize with
    ImageNet statistics.
    """
    ori_h, ori_w, _ = img.shape
    if len(scale_size) == 1:
        if ori_h > ori_w:
            tw = scale_size[0]
            th = (tw * ori_h) / ori_w
            th = int((th // 64) * 64)
        else:
            th = scale_size[0]
            tw = (th * ori_w) / ori_h
            tw = int((tw // 64) * 64)
    else:
        th, tw = scale_size
    img = cv2.resize(img, (tw, th))
    img = img.astype(np.float32)
    img = img / 255.0
    img = img[:, :, ::-1]  # reverse channel order (the original DINO loader reads BGR via cv2)
    img = np.transpose(img.copy(), (2, 0, 1))  # HWC -> CHW
    img = torch.from_numpy(img).float()

    def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
        # in-place per-channel normalization with ImageNet statistics
        for t, m, s in zip(x, mean, std):
            t.sub_(m)
            t.div_(s)
        return x

    img = color_normalize(img)
    return img, ori_h, ori_w
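# A worked example of the resize rule above (illustrative; not part of the
# original file). For a 480x640 frame with the default scale_size=[192],
# ori_h < ori_w, so th = 192 and tw = int(((192 * 640) / 480 // 64) * 64) = 256:
#   img, ori_h, ori_w = prep_frame_for_dino(np.zeros((480, 640, 3), np.uint8))
#   assert img.shape == (3, 192, 256) and (ori_h, ori_w) == (480, 640)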
def get_feats_from_dino(model, frame):
    # Batched feature extraction: returns both the patch-token features and
    # a spatial feature map at the model's patch resolution.
    B = frame.shape[0]
    patch_size = model.patch_embed.patch_size
    h, w = int(frame.shape[2] / patch_size), int(frame.shape[3] / patch_size)
    out = model.get_intermediate_layers(frame.cuda(), n=1)[0]  # B, 1+h*w, dim
    dim = out.shape[-1]
    out = out[:, 1:, :]  # discard the [CLS] token
    outmap = out.permute(0, 2, 1).reshape(B, dim, h, w)  # B, dim, h, w
    return out, outmap, h, w
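# Shape check for the extractor above (illustrative; assumes the DINO ViT-S/8
# backbone, whose patch_size is 8 and embed dim is 384):
#   frames = torch.zeros(2, 3, 192, 256)
#   out, outmap, h, w = get_feats_from_dino(dino, frames)
#   # out: (2, 24*32, 384), outmap: (2, 384, 24, 32), h=24, w=32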
def restrict_neighborhood(h, w):
    size_mask_neighborhood = 12
    # We restrict the set of source nodes considered to a spatial neighborhood
    # of the query node (i.e. "local attention").
    mask = torch.zeros(h, w, h, w)
    for i in range(h):
        for j in range(w):
            for p in range(2 * size_mask_neighborhood + 1):
                for q in range(2 * size_mask_neighborhood + 1):
                    if i - size_mask_neighborhood + p < 0 or i - size_mask_neighborhood + p >= h:
                        continue
                    if j - size_mask_neighborhood + q < 0 or j - size_mask_neighborhood + q >= w:
                        continue
                    mask[i, j, i - size_mask_neighborhood + p, j - size_mask_neighborhood + q] = 1

    mask = mask.reshape(h * w, h * w)
    return mask.cuda(non_blocking=True)
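# An equivalent loop-free construction (a sketch; the original uses the
# quadruple loop above). mask[q, s] = 1 iff source patch s lies within
# `radius` of query patch q along both axes.
def restrict_neighborhood_vectorized(h, w, radius=12):
    qi = torch.arange(h).view(h, 1, 1, 1)  # query row
    qj = torch.arange(w).view(1, w, 1, 1)  # query col
    si = torch.arange(h).view(1, 1, h, 1)  # source row
    sj = torch.arange(w).view(1, 1, 1, w)  # source col
    mask = ((qi - si).abs() <= radius) & ((qj - sj).abs() <= radius)  # h, w, h, w
    return mask.float().reshape(h * w, h * w).cuda(non_blocking=True)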
def label_propagation(h, w, feat_tar, list_frame_feats, list_segs, mask_neighborhood=None):
    ncontext = len(list_frame_feats)
    feat_sources = torch.stack(list_frame_feats)  # nmb_context x dim x h*w

    feat_tar = F.normalize(feat_tar, dim=1, p=2)
    feat_sources = F.normalize(feat_sources, dim=1, p=2)

    feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)  # nmb_context x h*w x dim
    # softmax-style affinities with temperature 0.1
    aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1)  # nmb_context x h*w (tar) x h*w (source)

    size_mask_neighborhood = 12
    if size_mask_neighborhood > 0:
        if mask_neighborhood is None:
            mask_neighborhood = restrict_neighborhood(h, w)
            mask_neighborhood = mask_neighborhood.unsqueeze(0)  # broadcasts over the context frames
        aff *= mask_neighborhood

    aff = aff.transpose(2, 1).reshape(-1, h * w)  # nmb_context*h*w (source: keys) x h*w (tar: queries)
    topk = 5
    tk_val, _ = torch.topk(aff, dim=0, k=topk)
    tk_val_min, _ = torch.min(tk_val, dim=0)
    aff[aff < tk_val_min] = 0
    aff = aff / torch.sum(aff, keepdim=True, dim=0)

    list_segs = [s.cuda() for s in list_segs]
    segs = torch.cat(list_segs)
    nmb_context, C, h, w = segs.shape
    segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T  # C x nmb_context*h*w
    seg_tar = torch.mm(segs, aff)
    seg_tar = seg_tar.reshape(1, C, h, w)
    return seg_tar, mask_neighborhood
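# Shape walkthrough for label_propagation (illustrative): with ncontext=8,
# feat_tar (h*w, dim), and sources (8, dim, h*w), the affinity is
# (8, h*w_tar, h*w_src); after the transpose/reshape it is (8*h*w, h*w).
# Top-k sparsification and normalization run over dim 0 (all source patches
# across all context frames), so each target column sums to 1, and the
# propagated mask is segs (C, 8*h*w) @ aff (8*h*w, h*w) -> (1, C, h, w).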
def norm_mask(mask):
    # Independently rescale each channel of the mask to [0, 1].
    c, h, w = mask.size()
    for cnt in range(c):
        mask_cnt = mask[cnt, :, :]
        if mask_cnt.max() > 0:
            mask_cnt = mask_cnt - mask_cnt.min()
            mask_cnt = mask_cnt / mask_cnt.max()
            mask[cnt, :, :] = mask_cnt
    return mask
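# Note: norm_mask is applied to the (N, H, W) stack of propagated heatmaps so
# that each point's channel is on a comparable scale before the per-point
# weighted average in get_dino_output below.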
def get_dino_output(dino, rgbs, trajs_g, vis_g):
    B, S, C, H, W = rgbs.shape
    B1, S1, N, D = trajs_g.shape
    assert B1 == B
    assert S1 == S
    assert D == 2
    assert B == 1

    # The queue stores the n preceding frames
    n_last_frames = 7
    que = queue.Queue(n_last_frames)

    # run dino
    prep_rgbs = []
    for s in range(S):
        prep_rgb, ori_h, ori_w = prep_frame_for_dino(
            rgbs[0, s].permute(1, 2, 0).detach().cpu().numpy(), scale_size=[H])
        prep_rgbs.append(prep_rgb)
    prep_rgbs = torch.stack(prep_rgbs, dim=0)  # S, 3, H, W
    with torch.no_grad():
        bs = 8
        idx = 0
        featmaps = []
        while idx < S:
            end_id = min(S, idx + bs)
            _, featmaps_cur, h, w = get_feats_from_dino(dino, prep_rgbs[idx:end_id])  # S, C, h, w
            idx = end_id
            featmaps.append(featmaps_cur)
        featmaps = torch.cat(featmaps, dim=0)
    C = featmaps.shape[1]
    featmaps = featmaps.unsqueeze(0)  # 1, S, C, h, w

    # Initialize the segmentation with a one-hot channel per point, marking
    # the patch that contains the point in frame 0.
    xy0 = trajs_g[:, 0, :]  # B, N, 2
    patch_size = dino.patch_embed.patch_size
    first_seg = torch.zeros((1, N, H // patch_size, W // patch_size))
    for n in range(N):
        first_seg[0, n, (xy0[0, n, 1] / patch_size).long(), (xy0[0, n, 0] / patch_size).long()] = 1

    frame1_feat = featmaps[0, 0].reshape(C, h * w)  # dim x h*w

    mask_neighborhood = None
    trajs_e = torch.zeros_like(trajs_g)
    trajs_e[0, 0] = trajs_g[0, 0]
    for cnt in range(1, S):
        # we use the first frame plus the n last frames as context
        used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
        used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]

        feat_tar = featmaps[0, cnt].reshape(C, h * w)
        frame_tar_avg, mask_neighborhood = label_propagation(
            h, w, feat_tar.T, used_frame_feats, used_segs, mask_neighborhood)

        # pop out the oldest frame if necessary
        if que.qsize() == n_last_frames:
            que.get()
        # push the current results into the queue
        seg = copy.deepcopy(frame_tar_avg)
        que.put([feat_tar, seg])

        # upsample to pixel resolution & rescale each point's heatmap
        frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=patch_size,
                                      mode='bilinear', align_corners=False,
                                      recompute_scale_factor=False)[0]
        frame_tar_avg = norm_mask(frame_tar_avg)

        for n in range(N):
            if len(torch.nonzero(frame_tar_avg[n])) > 0:
                # weighted average of the heatmap's nonzero coordinates
                nz = torch.nonzero(frame_tar_avg[n])
                coord_e = torch.sum(frame_tar_avg[n][nz[:, 0], nz[:, 1]].reshape(-1, 1) * nz.float(), 0) / frame_tar_avg[n][nz[:, 0], nz[:, 1]].sum()  # 2
                coord_e = coord_e[[1, 0]]  # (row, col) -> (x, y)
            else:
                # stay where it was
                coord_e = trajs_e[0, cnt - 1, n]
            trajs_e[0, cnt, n] = coord_e
    return trajs_e
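# A minimal usage sketch (not part of the original file). It assumes the DINO
# ViT-S/8 checkpoint from torch.hub and random inputs; the tensor layouts
# follow the asserts in get_dino_output (B=1, (x, y) coordinates).
if __name__ == '__main__':
    dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').cuda().eval()
    B, S, N, H, W = 1, 8, 16, 256, 256
    rgbs = torch.randint(0, 256, (B, S, 3, H, W)).float()  # frames in [0, 255]
    trajs_g = (torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])).cuda()  # gt (x, y) tracks
    vis_g = torch.ones(B, S, N).cuda()  # per-point visibility flags
    trajs_e = get_dino_output(dino, rgbs, trajs_g, vis_g)
    print('trajs_e', trajs_e.shape)  # (1, S, N, 2)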