import numpy as np
import six
import torch
import torch.nn as nn


def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  # `nn.functional.dropout` takes the drop probability directly, matching
  # `dropout_prob` here.
  output = nn.functional.dropout(input_tensor, p=dropout_prob)
  return output
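

# Illustrative usage sketch (not part of the original module): with
# dropout_prob=0.1, roughly 10% of values are zeroed and the survivors are
# scaled by 1 / (1 - 0.1) so the expected sum is preserved.
#   x = torch.ones(2, 4)
#   y = dropout(x, 0.1)   # some entries 0.0, the others ~1.111
#   y = dropout(x, 0.0)   # returned unchanged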


def create_look_ahead_mask(seq_length, batch_size=0):
  """Create a look-ahead mask for a given sequence length.

  Args:
    seq_length: int, the length of the sequence.
    batch_size: if batch_size is provided (> 0), the mask is tiled along a
      leading batch dimension.

  Returns:
    The mask of shape ((batch_size,) seq_length, seq_length).
  """
  mask = 1 - torch.tril(torch.ones((seq_length, seq_length)))
  if batch_size > 0:
    mask = torch.unsqueeze(mask, dim=0).repeat(batch_size, 1, 1)
  return mask
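

# Illustrative usage sketch (not part of the original module): positions a
# query is *not* allowed to attend to are marked with 1.
#   create_look_ahead_mask(4)
#   # tensor([[0., 1., 1., 1.],
#   #         [0., 0., 1., 1.],
#   #         [0., 0., 0., 1.],
#   #         [0., 0., 0., 0.]])
#   create_look_ahead_mask(4, batch_size=2).shape   # torch.Size([2, 4, 4])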


def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor)
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask)
  to_seq_length = to_shape[1]

  to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

  # `from_tensor` is not assumed to be a mask. Whether we attend *from*
  # padding positions does not matter (only *to* padding positions), so a
  # tensor of ones is used on the `from` side.
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = torch.ones((batch_size, from_seq_length, 1)).float()

  # Broadcast along the two singleton dimensions to form the mask.
  mask = broadcast_ones * to_mask

  return mask
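

# Illustrative usage sketch (not part of the original module): each row of the
# result repeats `to_mask` for every `from` position.
#   inputs = torch.zeros(1, 3, 8)            # [batch, from_seq, width]
#   to_mask = torch.tensor([[1, 1, 0]])      # last position is padding
#   create_attention_mask_from_input_mask(inputs, to_mask)
#   # tensor([[[1., 1., 0.],
#   #          [1., 1., 0.],
#   #          [1., 1., 0.]]])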


def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the ReLU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + torch.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))))
  return x * cdf
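

# Illustrative values (tanh approximation; not part of the original module):
#   gelu(torch.tensor([-3.0, 0.0, 3.0]))
#   # roughly tensor([-0.0036, 0.0000, 2.9964]); negative inputs are damped
#   # smoothly rather than clipped to zero as in ReLU.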


def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `nn.functional.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """
  # Anything that is not a string is assumed to already be an activation
  # function and is returned unchanged.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return nn.functional.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return torch.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)
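

# Illustrative usage sketch (not part of the original module):
#   act_fn = get_activation("gelu")       # the `gelu` function defined above
#   act_fn = get_activation("linear")     # None (no activation applied)
#   act_fn = get_activation(torch.tanh)   # non-strings are passed through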


def get_shape_list(tensor):
  """Returns a list of the dimensions of `tensor`.

  Args:
    tensor: A torch.Tensor object to find the shape of.

  Returns:
    A list of the dimensions of the shape of `tensor`, as Python integers.

  Raises:
    ValueError: If any dimension is not statically known.
  """
  shape = list(tensor.size())

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  raise ValueError(
      "Tensor has non-static dimensions at indices: %s" % non_static_indexes)
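

# Illustrative usage sketch (not part of the original module):
#   get_shape_list(torch.zeros(2, 5, 8))   # [2, 5, 8]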


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = get_shape_list(sequence_tensor)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  # Offset each example's positions into a flattened
  # [batch_size * seq_length, width] view of the sequence tensor.
  flat_offsets = torch.reshape(
      torch.arange(0, batch_size).long() * seq_length, (-1, 1))
  flat_positions = torch.reshape(positions + flat_offsets, (-1,))
  flat_sequence_tensor = torch.reshape(sequence_tensor,
                                       (batch_size * seq_length, width))
  output_tensor = torch.index_select(flat_sequence_tensor, 0, flat_positions)
  output_tensor = torch.reshape(output_tensor, (batch_size, -1, width))
  return output_tensor
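

# Illustrative usage sketch (not part of the original module): picking two
# token vectors from each example in a batch of two.
#   seq = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
#   positions = torch.tensor([[0, 2], [1, 3]])
#   gather_indexes(seq, positions).shape   # torch.Size([2, 2, 3])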


def split_heads(x, batch_size, seq_length, num_joints, num_attention_heads,
                model_depth):
  """Split the embedding vector for different heads for the spatial attention.

  Args:
    x: the embedding vector (batch_size, seq_len, num_joints, model_depth) or
      (batch_size, seq_len, model_depth)
    batch_size: the batch size
    seq_length: the sequence length
    num_joints: the number of joints
    num_attention_heads: the number of attention heads
    model_depth: the model depth

  Returns:
    the split vector (batch_size, seq_len, num_heads, num_joints, depth) or
    (batch_size, num_heads, seq_len, depth)
  """
  depth = model_depth // num_attention_heads
  if x.dim() == 4:
    # (batch_size, seq_len, num_joints, num_heads, depth)
    x = torch.reshape(
        x, (batch_size, seq_length, num_joints, num_attention_heads, depth))
    return x.permute(0, 1, 3, 2, 4)
  elif x.dim() == 3:
    # (batch_size, seq_len, num_heads, depth)
    x = torch.reshape(x, (batch_size, seq_length, num_attention_heads, depth))
    return x.permute(0, 2, 1, 3)
  else:
    raise ValueError("Unsupported input tensor dimension.")
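

# Illustrative usage sketch (not part of the original module):
#   x = torch.zeros(2, 10, 128)              # (batch, seq, depth)
#   split_heads(x, 2, 10, 0, 8, 128).shape   # torch.Size([2, 8, 10, 16])
#   x = torch.zeros(2, 10, 21, 128)          # (batch, seq, joints, depth)
#   split_heads(x, 2, 10, 21, 8, 128).shape  # torch.Size([2, 10, 8, 21, 16])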


def scaled_dot_product_attention(q, k, v, mask):
  """The scaled dot product attention mechanism.

  Attn(Q, K, V) = softmax(QK^T / sqrt(depth) + mask * -1e9)V.

  Args:
    q: the query vectors matrix (..., attn_dim, d_model/num_heads)
    k: the key vector matrix (..., attn_dim, d_model/num_heads)
    v: the value vector matrix (..., attn_dim, d_model/num_heads)
    mask: a mask for attention

  Returns:
    the updated encoding and the attention weights matrix
  """
  # (..., attn_dim, attn_dim)
  matmul_qk = q @ k.transpose(-2, -1)

  # Scale by sqrt(depth) to keep the logits in a stable range.
  dk = float(k.shape[-1])
  scaled_attention_logits = matmul_qk / np.sqrt(dk)

  # Masked positions receive a large negative logit so their softmax weight
  # is effectively zero.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # Attention weights sum to 1 over the last axis.
  attention_weights = nn.functional.softmax(
      scaled_attention_logits, dim=-1)

  output = attention_weights @ v

  return output, attention_weights
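

# Illustrative usage sketch (not part of the original module): single-head
# attention over a length-4 sequence with a causal mask.
#   q = k = v = torch.rand(1, 4, 16)
#   mask = create_look_ahead_mask(4)
#   out, weights = scaled_dot_product_attention(q, k, v, mask)
#   out.shape       # torch.Size([1, 4, 16])
#   weights.shape   # torch.Size([1, 4, 4]); weights above the diagonal are ~0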