import torch
import os
import cv2
import numpy as np
from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer

# Intermediate ResNet-50 modules whose activations are collected in 'layerstack' mode.
# NOTE(review): a standard torchvision ResNet-50 has six layer3 blocks, but only
# layer3[0]..layer3[3] are listed here -- confirm this subset is intentional.
_RESNET50_LAYERSTACK = (
    'resnet50.conv1',
    'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
    'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
    'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
    'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]',
)
# Module name handed to the single-layer ResNet extractor in 'pool' mode.
_RESNET50_POOL_LAYER = 'resnet50.avgpool'
# Patch size handed to the ViT extractor.
_VIT_PATCH_SIZE = 16
# Output-file suffix (replacing '.png') for each fragment type handled by process_patches.
_FRAGMENT_SUFFIXES = {
    'frame_diff': '_residual_imp.png',
    'optical_flow': '_residual_of_imp.png',
}


def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
    """Run one video frame through ``model`` via the matching extractor helper.

    Args:
        network_name: 'resnet50' or 'vit'.
        video_name, frame, frame_number, device: forwarded unchanged to the
            extractor's ``process_video_frame``.
        model: the network to run (forwarded unchanged).
        layer_name: for 'resnet50', either 'layerstack' (activations of the
            modules in ``_RESNET50_LAYERSTACK``) or 'pool' (the
            'resnet50.avgpool' module). Ignored for 'vit'.

    Returns:
        ``(activations_dict, total_flops, total_params)`` as reported by the
        extractor helper (its second return value is discarded).

    Raises:
        ValueError: for an unsupported network/layer combination (previously
            this surfaced as an UnboundLocalError at the return statement).
    """
    if network_name == 'resnet50':
        if layer_name == 'layerstack':
            # Pass a fresh list so the helper can never mutate the module constant.
            outputs = visualise_resnet.process_video_frame(
                video_name, frame, frame_number, list(_RESNET50_LAYERSTACK), model, device)
        elif layer_name == 'pool':
            # NOTE(review): an earlier comment described this as "before avg_pool", but the
            # module named is avgpool itself -- confirm which tensor the extractor captures.
            outputs = visualise_resnet_layer.process_video_frame(
                video_name, frame, frame_number, _RESNET50_POOL_LAYER, model, device)
        else:
            raise ValueError(f"unsupported layer_name for resnet50: {layer_name!r}")
    elif network_name == 'vit':
        outputs = visualise_vit_layer.process_video_frame(
            video_name, frame, frame_number, model, _VIT_PATCH_SIZE, device)
    else:
        raise ValueError(f"unsupported network_name: {network_name!r}")

    activations_dict, _, total_flops, total_params = outputs
    return activations_dict, total_flops, total_params


def _pooled_stats(x):
    """Return ``(mean, max, std)`` of ``x`` along dim 0."""
    return torch.mean(x, dim=0), torch.max(x, dim=0)[0], torch.std(x, dim=0)


def process_video_feature(video_feature, network_name, layer_name):
    """Reduce each frame's activations to one feature vector and stack them.

    Per mode, each element of ``video_feature`` is handled as follows:

    * 'vit': a tensor per frame; its mean, max and std over dim 0 (presumably
      the token axis -- TODO confirm) are concatenated with ``torch.hstack``.
    * 'resnet50' / 'layerstack': a dict per frame mapping layer name to a
      tensor; each tensor is averaged over dims 1 and 2 (presumably the spatial
      axes of a (C, H, W) map -- confirm) and the results are concatenated.
    * 'resnet50' / 'pool': a tensor or array-like per frame; after squeezing,
      the values themselves plus their mean, max and std over dim 0 are
      concatenated.

    Args:
        video_feature: iterable of per-frame activations (see above).
        network_name: 'vit' or 'resnet50'.
        layer_name: 'layerstack' or 'pool'; only consulted for 'resnet50'.

    Returns:
        A tensor of the per-frame vectors stacked along a new dim 0.

    Raises:
        ValueError: for an unsupported network/layer combination, or when
            ``video_feature`` yields no frames (both previously failed inside
            ``torch.hstack``/``torch.stack`` with an opaque RuntimeError).
    """
    if network_name == 'resnet50':
        if layer_name not in ('layerstack', 'pool'):
            raise ValueError(f"unsupported layer_name for resnet50: {layer_name!r}")
    elif network_name != 'vit':
        raise ValueError(f"unsupported network_name: {network_name!r}")

    processed_frames = []
    for frame in video_feature:
        if network_name == 'vit':
            frame_vector = torch.hstack(_pooled_stats(frame))
        elif layer_name == 'layerstack':
            frame_vector = torch.hstack(
                [torch.mean(layer_array, dim=(1, 2)) for layer_array in frame.values()])
        else:  # resnet50 / 'pool'
            # torch.tensor() on an existing tensor copies it and emits a UserWarning;
            # detach() yields the same values without the copy. Non-tensors (e.g. numpy
            # arrays) still go through torch.tensor() exactly as before.
            pooled = frame.detach() if torch.is_tensor(frame) else torch.tensor(frame)
            pooled = torch.squeeze(pooled)
            frame_vector = torch.hstack([pooled, *_pooled_stats(pooled)])
        processed_frames.append(frame_vector)

    if not processed_frames:
        raise ValueError("video_feature contains no frames")
    return torch.stack(processed_frames)


def flow_to_rgb(flow):
    """Render a dense 2-channel flow field as an 8-bit colour image.

    Hue encodes flow direction, value encodes min-max-normalised magnitude and
    saturation is fixed at 255. ``flow`` is presumably an (H, W, 2) float32
    array such as OpenCV's dense optical flow output -- confirm with callers.

    NOTE: despite the function name, the result is in OpenCV's BGR channel
    order (``cv2.COLOR_HSV2BGR``), i.e. ready for ``cv2.imwrite``. The order is
    kept as-is for compatibility with existing callers.
    """
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    # cartToPolar returns radians; OpenCV's 8-bit hue channel spans [0, 180).
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255
    # A single min-max normalisation suffices (it was previously applied twice).
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)


def get_patch_diff(residual_frame, patch_size):
    """Sum of absolute values of ``residual_frame`` over each patch_size x patch_size patch.

    ``residual_frame`` must be 4-D, assumed (1, C, H, W). Rows/columns beyond
    the last whole patch are ignored. The sum runs over the batch and channel
    dims as well.

    Returns:
        A (H // patch_size, W // patch_size) tensor on the input's device with
        the default floating dtype (same as the original ``torch.zeros`` buffer).
    """
    batch, channels, h, w = residual_frame.shape
    n_rows, n_cols = h // patch_size, w // patch_size
    cropped = residual_frame[:, :, :n_rows * patch_size, :n_cols * patch_size]
    # Split H and W into (patch index, offset within patch) and reduce in one shot
    # instead of launching one small sum per patch from a Python double loop.
    blocks = cropped.abs().reshape(batch, channels, n_rows, patch_size, n_cols, patch_size)
    return blocks.sum(dim=(0, 1, 3, 5)).to(torch.get_default_dtype())


def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
    """Select the ``top_n`` patches with the largest ``diff`` and tile them into one image.

    Args:
        residual_frame: 4-D tensor, assumed (1, C, H, W), that patches are cut from.
        diff: 2-D per-patch score grid, e.g. from ``get_patch_diff``.
        patch_size: side length of a patch in pixels.
        target_size: side length of the square output image.
        top_n: number of highest-scoring patches to keep.

    Returns:
        ``(imp_patches_img, positions)``: a (C, target_size, target_size) tensor
        with the selected patches laid out row-major in their original
        (row, col) order, and the list of their (row, col) patch-grid
        coordinates as Python ints.

    Raises:
        ValueError: if more patches are selected than fit in the output grid.
    """
    n_cols = diff.shape[1]
    ranked = torch.argsort(-diff.reshape(-1))[:top_n]
    # Ascending flat index == row-major order == sorting by (row, col).
    chosen = torch.sort(ranked).values.tolist()
    positions = [divmod(flat_idx, n_cols) for flat_idx in chosen]
    imp_patches_img = get_frame_patches(residual_frame, positions, patch_size, target_size)
    return imp_patches_img, positions


def get_frame_patches(frame, positions, patch_size, target_size):
    """Cut the patches at ``positions`` out of ``frame`` and tile them row-major.

    Args:
        frame: 4-D tensor, assumed (1, C, H, W).
        positions: iterable of (row, col) patch-grid coordinates.
        patch_size: side length of a patch in pixels.
        target_size: side length of the square output image.

    Returns:
        A (C, target_size, target_size) tensor (same dtype/device as ``frame``);
        unused grid slots stay zero.

    Raises:
        ValueError: if there are more positions than the
            (target_size // patch_size) ** 2 output slots.
    """
    positions = list(positions)
    patches_per_row = target_size // patch_size
    if len(positions) > patches_per_row * patches_per_row:
        raise ValueError(
            f"{len(positions)} patches do not fit a {target_size}x{target_size} "
            f"image of {patch_size}px patches")

    imp_patches_img = torch.zeros((frame.shape[1], target_size, target_size),
                                  dtype=frame.dtype, device=frame.device)
    for idx, (y, x) in enumerate(positions):
        src_y, src_x = y * patch_size, x * patch_size
        patch = frame[:, :, src_y:src_y + patch_size, src_x:src_x + patch_size]
        row_idx, col_idx = divmod(idx, patches_per_row)
        dst_y, dst_x = row_idx * patch_size, col_idx * patch_size
        imp_patches_img[:, dst_y:dst_y + patch_size, dst_x:dst_x + patch_size] = patch.squeeze(0)
    return imp_patches_img


def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
    """Build the important-patch image for a residual and derive its output path.

    Args:
        original_path: path of the source frame; its last '.png' is replaced by
            a fragment-specific suffix to form the returned path.
        frag_name: 'frame_diff' or 'optical_flow'.
        residual, patch_size, target_size, top_n: see ``extract_important_patches``.

    Returns:
        ``(frag_path, imp_patches, positions)``. Nothing is written to disk.

    Raises:
        ValueError: for an unknown ``frag_name`` (previously an UnboundLocalError).
    """
    try:
        suffix = _FRAGMENT_SUFFIXES[frag_name]
    except KeyError:
        raise ValueError(f"unsupported frag_name: {frag_name!r}") from None

    diff = get_patch_diff(residual, patch_size)
    imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)

    # Only rewrite the final '.png' so directory names containing '.png' are untouched;
    # a path without '.png' is returned unchanged, as before.
    head, sep, tail = original_path.rpartition('.png')
    frag_path = head + suffix + tail if sep else original_path
    return frag_path, imp_patches, positions


def merge_fragments(diff_fragment, flow_fragment, alpha=0.5):
    """Blend two fragments: ``alpha * diff_fragment + (1 - alpha) * flow_fragment``."""
    return diff_fragment * alpha + flow_fragment * (1 - alpha)


def concatenate_features(frame_feature, residual_feature):
    """Concatenate frame and residual features along the last dimension."""
    return torch.cat((frame_feature, residual_feature), dim=-1)