import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

from .nn_module import fc_block, build_normalization


class Attention(nn.Module):
    """
    Overview:
        For each entry embedding, compute attention weights over all entries and sum the weighted values to
        obtain the output attention.
    Interfaces:
        ``__init__``, ``split``, ``forward``
    """

    def __init__(self, input_dim: int, head_dim: int, output_dim: int, head_num: int, dropout: nn.Module) -> None:
        """
        Overview:
            Initialize the Attention module with the provided dimensions and dropout layer.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input.
            - head_dim (:obj:`int`): The dimension of each head in the multi-head attention mechanism.
            - output_dim (:obj:`int`): The dimension of the output.
            - head_num (:obj:`int`): The number of heads in the multi-head attention mechanism.
            - dropout (:obj:`nn.Module`): The dropout layer used in the attention mechanism.
        """
        super(Attention, self).__init__()
        self.head_num = head_num
        self.head_dim = head_dim
        self.dropout = dropout
        self.attention_pre = fc_block(input_dim, head_dim * head_num * 3)  # query, key, value
        self.project = fc_block(head_dim * head_num, output_dim)

    def split(self, x: torch.Tensor, T: bool = False) -> torch.Tensor:
        """
        Overview:
            Split the input to get multi-head queries, keys, or values.
        Arguments:
            - x (:obj:`torch.Tensor`): The tensor to be split, which could be a query, key, or value.
            - T (:obj:`bool`, optional): If True, transpose the last two dimensions of the output. \
                Defaults to False.
        Returns:
            - x (:obj:`torch.Tensor`): The split tensor with shape ``(B, head_num, N, head_dim)``, or \
                ``(B, head_num, head_dim, N)`` if ``T`` is True.
        """
        B, N = x.shape[:2]
        x = x.view(B, N, self.head_num, self.head_dim)
        x = x.permute(0, 2, 1, 3).contiguous()  # B, head_num, N, head_dim
        if T:
            x = x.permute(0, 1, 3, 2).contiguous()
        return x

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Overview:
            Compute the attention from the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor for the forward computation.
            - mask (:obj:`Optional[torch.Tensor]`, optional): Optional boolean mask to exclude invalid entries. \
                Defaults to None.
        Returns:
            - attention (:obj:`torch.Tensor`): The computed attention tensor.
        """
        assert len(x.shape) == 3
        B, N = x.shape[:2]
        x = self.attention_pre(x)
        query, key, value = torch.chunk(x, 3, dim=2)
        query, key, value = self.split(query), self.split(key, T=True), self.split(value)
        score = torch.matmul(query, key)  # B, head_num, N, N
        score /= math.sqrt(self.head_dim)
        if mask is not None:
            # in-place fill so invalid entries get a very low score before the softmax
            score.masked_fill_(~mask, value=-1e9)
        score = F.softmax(score, dim=-1)
        score = self.dropout(score)
        attention = torch.matmul(score, value)  # B, head_num, N, head_dim
        attention = attention.permute(0, 2, 1, 3).contiguous()  # B, N, head_num, head_dim
        attention = self.project(attention.view(B, N, -1))  # B, N, output_dim
        return attention
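

# Hypothetical usage sketch, not part of the original module. It assumes ``fc_block`` maps the last dimension
# from its first argument to its second, so a (B, N, input_dim) input yields a (B, N, output_dim) output.
def _attention_usage_example() -> torch.Tensor:
    """Run ``Attention`` on dummy data: 2 samples, 10 entries, 64-dim features, 4 heads of 16 dims each."""
    attn = Attention(input_dim=64, head_dim=16, output_dim=64, head_num=4, dropout=nn.Dropout(0.1))
    x = torch.randn(2, 10, 64)  # (B, N, input_dim)
    return attn(x)  # (B, N, output_dim) == (2, 10, 64)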


class TransformerLayer(nn.Module):
    """
    Overview:
        A single Transformer layer: it first computes the attention among entries, then applies a feedforward
        (MLP) block, each followed by a residual connection and layer normalization.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self, input_dim: int, head_dim: int, hidden_dim: int, output_dim: int, head_num: int, mlp_num: int,
        dropout: nn.Module, activation: nn.Module
    ) -> None:
        """
        Overview:
            Initialize the TransformerLayer with the provided dimensions, dropout layer, and activation function.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input.
            - head_dim (:obj:`int`): The dimension of each head in the multi-head attention mechanism.
            - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP (Multi-Layer Perceptron).
            - output_dim (:obj:`int`): The dimension of the output.
            - head_num (:obj:`int`): The number of heads in the multi-head attention mechanism.
            - mlp_num (:obj:`int`): The number of layers in the MLP.
            - dropout (:obj:`nn.Module`): The dropout layer used in the attention mechanism and the MLP.
            - activation (:obj:`nn.Module`): The activation function used in the MLP.
        """
        super(TransformerLayer, self).__init__()
        self.attention = Attention(input_dim, head_dim, output_dim, head_num, dropout)
        self.layernorm1 = build_normalization('LN')(output_dim)
        self.dropout = dropout
        layers = []
        dims = [output_dim] + [hidden_dim] * (mlp_num - 1) + [output_dim]
        for i in range(mlp_num):
            layers.append(fc_block(dims[i], dims[i + 1], activation=activation))
            if i != mlp_num - 1:
                layers.append(self.dropout)
        layers.append(self.dropout)
        self.mlp = nn.Sequential(*layers)
        self.layernorm2 = build_normalization('LN')(output_dim)

    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Compute the forward pass through the Transformer layer.
        Arguments:
            - inputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): A tuple containing the input tensor `x` and \
                the mask tensor.
        Returns:
            - output (:obj:`Tuple[torch.Tensor, torch.Tensor]`): A tuple containing the output tensor and \
                the unchanged mask tensor, so that layers can be chained with ``nn.Sequential``.
        """
        x, mask = inputs
        a = self.dropout(self.attention(x, mask))
        x = self.layernorm1(x + a)
        m = self.dropout(self.mlp(x))
        x = self.layernorm2(x + m)
        return x, mask
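

# Hypothetical sketch (not part of the original module): TransformerLayer consumes and returns an ``(x, mask)``
# tuple precisely so that several layers can be chained with ``nn.Sequential``, as ``Transformer`` does below.
# It assumes ``input_dim == output_dim`` so the residual connections line up.
def _transformer_layer_usage_example() -> torch.Tensor:
    """Stack two TransformerLayer blocks and run them on dummy data without a mask."""
    dropout = nn.Dropout(0.1)
    layers = [
        TransformerLayer(
            input_dim=64, head_dim=16, hidden_dim=128, output_dim=64, head_num=4, mlp_num=2,
            dropout=dropout, activation=nn.ReLU()
        ) for _ in range(2)
    ]
    stack = nn.Sequential(*layers)
    x = torch.randn(2, 10, 64)  # (B, N, input_dim)
    out, _ = stack((x, None))  # the mask is optional; None disables masking in Attention.forward
    return out  # (B, N, output_dim) == (2, 10, 64)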


class Transformer(nn.Module):
    """
    Overview:
        Implementation of the Transformer model.

    .. note::
        For more details, refer to "Attention Is All You Need": http://arxiv.org/abs/1706.03762.

    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        input_dim: int,
        head_dim: int = 128,
        hidden_dim: int = 1024,
        output_dim: int = 256,
        head_num: int = 2,
        mlp_num: int = 2,
        layer_num: int = 3,
        dropout_ratio: float = 0.,
        activation: nn.Module = nn.ReLU(),
    ) -> None:
        """
        Overview:
            Initialize the Transformer with the provided dimensions, dropout ratio, activation function,
            and number of layers.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input.
            - head_dim (:obj:`int`): The dimension of each head in the multi-head attention mechanism.
            - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP (Multi-Layer Perceptron).
            - output_dim (:obj:`int`): The dimension of the output.
            - head_num (:obj:`int`): The number of heads in the multi-head attention mechanism.
            - mlp_num (:obj:`int`): The number of layers in the MLP.
            - layer_num (:obj:`int`): The number of Transformer layers.
            - dropout_ratio (:obj:`float`): The dropout ratio for the dropout layer.
            - activation (:obj:`nn.Module`): The activation function used in the MLP.
        """
        super(Transformer, self).__init__()
        self.embedding = fc_block(input_dim, output_dim, activation=activation)
        self.act = activation
        layers = []
        dims = [output_dim] + [output_dim] * layer_num
        self.dropout = nn.Dropout(dropout_ratio)
        for i in range(layer_num):
            layers.append(
                TransformerLayer(dims[i], head_dim, hidden_dim, dims[i + 1], head_num, mlp_num, self.dropout, self.act)
            )
        self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Overview:
            Perform the forward pass through the Transformer.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor, with shape `(B, N, C)`, where `B` is batch size, \
                `N` is the number of entries, and `C` is the feature dimension.
            - mask (:obj:`Optional[torch.Tensor]`, optional): The mask tensor (bool), used to mask out invalid \
                entries in attention. It has shape `(B, N)`, where `B` is batch size and `N` is the number of \
                entries. Defaults to None.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor from the Transformer, with shape `(B, N, output_dim)`.
        """
        if mask is not None:
            # expand the (B, N) mask to (B, 1, N, N) so it broadcasts over heads in the attention scores
            mask = mask.unsqueeze(dim=1).repeat(1, mask.shape[1], 1).unsqueeze(dim=1)
        x = self.embedding(x)
        x = self.dropout(x)
        x, mask = self.main((x, mask))
        return x
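

# Hypothetical end-to-end sketch (not part of the original module) showing the expected shapes: the boolean mask
# is (B, N) with True marking valid entries, and the output is (B, N, output_dim).
def _transformer_usage_example() -> torch.Tensor:
    """Run the full Transformer on a dummy padded batch."""
    model = Transformer(input_dim=32, head_dim=16, hidden_dim=64, output_dim=32, head_num=2, mlp_num=2, layer_num=2)
    x = torch.randn(4, 9, 32)  # (B, N, C)
    mask = torch.ones(4, 9, dtype=torch.bool)
    mask[:, -2:] = False  # pretend the last two entries of every sample are padding
    return model(x, mask)  # (B, N, output_dim) == (4, 9, 32)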


class ScaledDotProductAttention(nn.Module):
    """
    Overview:
        Implementation of Scaled Dot-Product Attention, a key component of Transformer models. It computes the
        dot product of the query and key tensors, scales the scores by the inverse square root of the key
        dimension (d_k), applies an optional mask and a softmax, applies dropout for regularization, and finally
        weights the value tensor with the resulting attention.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, d_k: int, dropout: float = 0.0) -> None:
        """
        Overview:
            Initialize the ScaledDotProductAttention module with the dimension of the key vector and the dropout
            rate.
        Arguments:
            - d_k (:obj:`int`): The dimension of the key vector. This will be used to scale the dot product of the \
                query and key.
            - dropout (:obj:`float`, optional): The dropout rate to be applied after the softmax operation. \
                Defaults to 0.0.
        """
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Overview:
            Perform the Scaled Dot-Product Attention operation on the query, key and value tensors.
        Arguments:
            - q (:obj:`torch.Tensor`): The query tensor.
            - k (:obj:`torch.Tensor`): The key tensor.
            - v (:obj:`torch.Tensor`): The value tensor.
            - mask (:obj:`Optional[torch.Tensor]`): An optional boolean mask applied to the attention scores. \
                Defaults to None.
        Returns:
            - output (:obj:`torch.Tensor`): The output tensor after the attention operation.
        """
        attn = torch.matmul(q / (self.d_k ** 0.5), k.transpose(2, 3))
        if mask is not None:
            # in-place fill so invalid entries get a very low score before the softmax
            attn.masked_fill_(~mask, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output
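

# Hypothetical sketch (not part of the original module): unlike ``Attention``, this block expects q, k and v to
# already be split into heads, i.e. shaped (B, head_num, N, d_k) (v may use a different last dimension), and the
# optional boolean mask must broadcast to the (B, head_num, N, N) score tensor.
def _scaled_dot_product_attention_usage_example() -> torch.Tensor:
    """Apply ScaledDotProductAttention to dummy multi-head tensors."""
    attention = ScaledDotProductAttention(d_k=16, dropout=0.1)
    q = torch.randn(2, 4, 10, 16)  # (B, head_num, N, d_k)
    k = torch.randn(2, 4, 10, 16)
    v = torch.randn(2, 4, 10, 16)
    mask = torch.ones(2, 1, 10, 10, dtype=torch.bool)  # True marks valid positions; broadcasts over heads
    return attention(q, k, v, mask)  # (B, head_num, N, d_k) == (2, 4, 10, 16)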