# salimshakeel
# upload files
# d2542a3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .attention import SelfAttention
class MultiAttention(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters.

        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos
            (None disables local attention, otherwise must be >= 2).
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion, case-insensitive,
            one of "add", "mult", "avg", "max" (None disables the local/global fusion).
        """
        super(MultiAttention, self).__init__()

        # Global Attention, considering differences among all frames
        self.attention = SelfAttention(input_size=input_size, output_size=output_size,
                                       freq=freq, pos_enc=pos_enc, heads=heads)

        self.num_segments = num_segments
        if self.num_segments is not None:
            assert self.num_segments >= 2, "num_segments must be None or 2+"
            self.local_attention = nn.ModuleList()
            for _ in range(self.num_segments):
                # Local Attention, considering differences among the same segment with reduced hidden size
                self.local_attention.append(SelfAttention(input_size=input_size,
                                                          output_size=output_size // num_segments,
                                                          freq=freq, pos_enc=pos_enc, heads=4))
        self.permitted_fusions = ["add", "mult", "avg", "max"]
        self.fusion = fusion
        if self.fusion is not None:
            self.fusion = self.fusion.lower()
            assert self.fusion in self.permitted_fusions, f"Fusion method must be: {*self.permitted_fusions,}"

    def forward(self, x):
        """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms.

        The sequence is split into ``num_segments`` contiguous chunks of ``ceil(T / num_segments)`` frames.
        For each chunk, the global and local features are L2-normalized (per frame) and then fused.

        :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
        :return: A tuple of:
            weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
            attn_weights: Tensor with shape [T, T] containing the (global) attention weights.
        :raises ValueError: If ``self.fusion`` is not one of the permitted fusion methods.
        """
        weighted_value, attn_weights = self.attention(x)  # global attention

        if self.num_segments is not None and self.fusion is not None:
            num_frames = x.shape[0]
            segment_size = math.ceil(num_frames / self.num_segments)
            for segment in range(self.num_segments):
                left_pos = segment * segment_size
                if left_pos >= num_frames:
                    # ceil() can push trailing segments past the end of a short sequence
                    # (e.g. T=4 with 3 segments -> sizes 2, 2, 0); every later segment is empty too,
                    # so there is nothing left to attend to or fuse.
                    break
                right_pos = left_pos + segment_size
                local_x = x[left_pos:right_pos]
                weighted_local_value, _ = self.local_attention[segment](local_x)  # local attentions

                # Normalize the features vectors. The global slice is cloned because the same
                # region of `weighted_value` is overwritten below, and normalize() needs its input
                # intact for the backward pass.
                global_part = F.normalize(weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
                local_part = F.normalize(weighted_local_value, p=2, dim=1)

                # Fuse out-of-place and write the result back once, instead of chaining
                # in-place ops on a slice of an autograd output.
                if self.fusion == "add":
                    fused = global_part + local_part
                elif self.fusion == "mult":
                    fused = global_part * local_part
                elif self.fusion == "avg":
                    fused = (global_part + local_part) / 2
                elif self.fusion == "max":
                    fused = torch.max(global_part, local_part)
                else:
                    # Reachable if the __init__ assert was stripped (python -O) or fusion was reassigned.
                    raise ValueError(f"Fusion method must be: {*self.permitted_fusions,}")
                weighted_value[left_pos:right_pos] = fused

        return weighted_value, attn_weights
class PGL_SUM(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Build the PGL-SUM frame-scoring model.

        A MultiAttention module (global attention plus optional local, per-segment attention) is
        followed by a residual connection and a two-layer regressor producing one score per frame.

        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion.
        """
        super().__init__()
        # NOTE: submodules are registered in a fixed order; the Linear layers draw their initial
        # weights from the global RNG in this order, and it also fixes parameters()/state_dict() order.
        self.attention = MultiAttention(input_size=input_size, output_size=output_size, freq=freq,
                                        pos_enc=pos_enc, num_segments=num_segments, heads=heads,
                                        fusion=fusion)
        hidden_size = input_size  # the regressor keeps the feature width in its hidden layer
        self.linear_1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=1)
        self.drop = nn.Dropout(p=0.5)
        self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6)
        self.norm_linear = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, frame_features):
        """ Score every frame of a video from its features.

        :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by
            using the pool5 layer of GoogleNet.
        :return: A tuple of:
            y: Tensor with shape [1, T] containing the frames importance scores in [0, 1].
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        weighted_value, attn_weights = self.attention(frame_features)

        # Residual connection -> dropout -> layer normalization
        hidden = self.norm_y(self.drop(weighted_value + frame_features))

        # 2-layer NN (Regressor Network): Linear -> ReLU -> Dropout -> LayerNorm -> Linear -> Sigmoid
        hidden = self.norm_linear(self.drop(self.relu(self.linear_1(hidden))))
        scores = self.sigmoid(self.linear_2(hidden))

        # [T, 1] -> [1, T]
        return scores.view(1, -1), attn_weights
if __name__ == '__main__':
    # Quick proof of concept (needs a CUDA device because of the .cuda() calls); uncomment to run:
    # model = PGL_SUM(input_size=256, output_size=256, num_segments=3, fusion="Add").cuda()
    # _input = torch.randn(500, 256).cuda()  # [seq_len, hidden_size]
    # output, weights = model(_input)
    # print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
    pass