|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .attention import SelfAttention


class MultiAttention(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters.

        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion.
        """
        super(MultiAttention, self).__init__()

        self.attention = SelfAttention(input_size=input_size, output_size=output_size,
                                       freq=freq, pos_enc=pos_enc, heads=heads)
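        # The global attention above attends over all frames; when num_segments is set,
        # one local multi-head attention module per segment is built below and its
        # output is fused with the global representation in forward().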
|
|
|
        self.num_segments = num_segments
        if self.num_segments is not None:
            assert self.num_segments >= 2, "num_segments must be None or 2+"
            self.local_attention = nn.ModuleList()
            for _ in range(self.num_segments):
                self.local_attention.append(SelfAttention(input_size=input_size, output_size=output_size // num_segments,
                                                          freq=freq, pos_enc=pos_enc, heads=4))

        self.permitted_fusions = ["add", "mult", "avg", "max"]
        self.fusion = fusion
        if self.fusion is not None:
            self.fusion = self.fusion.lower()
            assert self.fusion in self.permitted_fusions, f"Fusion method must be one of {self.permitted_fusions}"
    def forward(self, x):
        """ Compute the weighted frame features, based on the global and local (multi-head) attention mechanisms.

        :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
        :return: A tuple of:
            weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        weighted_value, attn_weights = self.attention(x)  # global attention over all T frames
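        # When both segments and a fusion method are configured, each segment is
        # refined by its own local attention head and fused with the global output.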
|
|
|
        if self.num_segments is not None and self.fusion is not None:
            segment_size = math.ceil(x.shape[0] / self.num_segments)  # last segment may be shorter
            for segment in range(self.num_segments):
                left_pos = segment * segment_size
                right_pos = (segment + 1) * segment_size
                local_x = x[left_pos:right_pos]
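                # Run this segment's local attention head, then L2-normalize both the
                # global and the local features before fusing them.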
|
                weighted_local_value, attn_local_weights = self.local_attention[segment](local_x)

                weighted_value[left_pos:right_pos] = F.normalize(weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
                weighted_local_value = F.normalize(weighted_local_value, p=2, dim=1)
                if self.fusion == "add":
                    weighted_value[left_pos:right_pos] += weighted_local_value
                elif self.fusion == "mult":
                    weighted_value[left_pos:right_pos] *= weighted_local_value
                elif self.fusion == "avg":
                    weighted_value[left_pos:right_pos] += weighted_local_value
                    weighted_value[left_pos:right_pos] /= 2
                elif self.fusion == "max":
                    weighted_value[left_pos:right_pos] = torch.max(weighted_value[left_pos:right_pos].clone(),
                                                                   weighted_local_value)
        return weighted_value, attn_weights


class PGL_SUM(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Class wrapping the PGL-SUM model; its key modules and parameters.

        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion.
        """
        super(PGL_SUM, self).__init__()

        self.attention = MultiAttention(input_size=input_size, output_size=output_size, freq=freq,
                                        pos_enc=pos_enc, num_segments=num_segments, heads=heads, fusion=fusion)
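        # Layers for regularization and frame scoring: dropout, layer normalization,
        # and a two-layer regressor ending in a sigmoid that yields one score per frame.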
|
        self.linear_1 = nn.Linear(in_features=input_size, out_features=input_size)
        self.linear_2 = nn.Linear(in_features=self.linear_1.out_features, out_features=1)

        self.drop = nn.Dropout(p=0.5)
        self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6)
        self.norm_linear = nn.LayerNorm(normalized_shape=self.linear_1.out_features, eps=1e-6)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    def forward(self, frame_features):
        """ Produce frame importance scores from the frame features, using the PGL-SUM model.

        :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by
                                            the pool5 layer of GoogLeNet.
        :return: A tuple of:
            y: Tensor with shape [1, T] containing the frame importance scores in [0, 1].
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        residual = frame_features
        weighted_value, attn_weights = self.attention(frame_features)
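        # Residual (skip) connection around the attention block, followed by
        # dropout and layer normalization.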
|
        y = weighted_value + residual
        y = self.drop(y)
        y = self.norm_y(y)
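        # Frame scoring head: Linear -> ReLU -> Dropout -> LayerNorm, then a final
        # Linear + Sigmoid producing one importance score per frame, reshaped to [1, T].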
|
|
|
|
|
        y = self.linear_1(y)
        y = self.relu(y)
        y = self.drop(y)
        y = self.norm_linear(y)

        y = self.linear_2(y)
        y = self.sigmoid(y)
        y = y.view(1, -1)

        return y, attn_weights


if __name__ == '__main__':
    pass
    """Uncomment for a quick proof of concept
    model = PGL_SUM(input_size=256, output_size=256, num_segments=3, fusion="Add").cuda()
    _input = torch.randn(500, 256).cuda()  # [seq_len, hidden_size]
    output, weights = model(_input)
    print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
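    # A second, CPU-only sketch; the sizes and settings below are illustrative
    # assumptions, not the official PGL-SUM configuration.
    model = PGL_SUM(input_size=1024, output_size=1024, num_segments=4, fusion="avg",
                    pos_enc="absolute", heads=8)
    _input = torch.randn(320, 1024)  # [T, input_size] frame features
    scores, weights = model(_input)
    print(scores.shape)   # expected per the docstring: torch.Size([1, 320]) -- frame importance scores
    print(weights.shape)  # expected per the docstring: torch.Size([320, 320]) -- global attention weights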
|
""" |
|
|