import math
from collections.abc import Mapping
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthwiseSeparableConv(nn.Module):
    """
    Depth-wise separable convolution: a per-channel (grouped) convolution followed by a 1x1
    point-wise convolution, which uses fewer parameters than a single dense convolution.

    :Examples:
    >>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
    >>> input_tensor = torch.randn(32, 20, 300)  # (N, L, D) with D == in_ch
    >>> output = m(input_tensor)
    >>> output.shape
    torch.Size([32, 20, 200])
    """

    def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
        """
        :param in_ch: input hidden dimension size
        :param out_ch: output hidden dimension size
        :param k: kernel size. Padding is k // 2, so an odd k keeps the sequence length while
            an even k makes the output one step longer than the input.
        :param dim: default 1. 1 for 1D conv, 2 for 2D conv.
        :param relu: if True, apply ReLU to the output.
        :raises ValueError: if dim is neither 1 nor 2 (a subclass of the Exception raised before).
        """
        super().__init__()
        self.relu = relu
        if dim == 1:
            conv_cls = nn.Conv1d
        elif dim == 2:
            conv_cls = nn.Conv2d
        else:
            raise ValueError("Incorrect dimension!")
        # groups=in_ch makes each filter see a single input channel (the depth-wise part)
        self.depthwise_conv = conv_cls(in_channels=in_ch, out_channels=in_ch, kernel_size=k,
                                       groups=in_ch, padding=k // 2)
        # 1x1 convolution mixes channels (the point-wise part)
        self.pointwise_conv = conv_cls(in_channels=in_ch, out_channels=out_ch, kernel_size=1,
                                       padding=0)

    def forward(self, x):
        """
        :Input: (N, L_in, D) for dim=1 (channels last, D == in_ch)
        :Output: (N, L_out, out_ch)

        NOTE(review): only axes 1 and 2 are swapped around the convolution, so for dim=2 the
        channel axis must sit at position 2 of a 4-D input (e.g. (N, H, C, W)); confirm with
        callers that this is the intended layout.
        """
        x = x.transpose(1, 2)  # move channels to axis 1, as nn.ConvNd expects
        out = self.pointwise_conv(self.depthwise_conv(x))
        if self.relu:
            out = F.relu(out, inplace=True)
        return out.transpose(1, 2)  # (N, L, D)


class ConvEncoder(nn.Module):
    """Residual conv block: depth-wise separable conv -> dropout -> add input -> LayerNorm."""

    def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
        """
        :param kernel_size: convolution kernel size; must be odd so the conv output length
            matches the input for the residual add (an even size lengthens it by one).
        :param n_filters: hidden size D of both input and output
        :param dropout: dropout probability applied to the conv output
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(n_filters)
        self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size,
                                           relu=True)

    def forward(self, x, mask):
        """
        :param x: (N, L, D)
        :param mask: (N, L), is not used.
        :return: (N, L, D)
        """
        return self.layer_norm(self.dropout(self.conv(x)) + x)  # (N, L, D)


class TrainablePositionalEncoding(nn.Module):
    """Add learned absolute position embeddings to the input, then apply LayerNorm and dropout."""

    def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_feat):
        """
        Args:
            input_feat: (N, L, D), with L <= max_position_embeddings and D == hidden_size
        Returns:
            (N, L, D)
        """
        seq_length = input_feat.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
        # (L, D); broadcasting adds it to every batch element, so no (N, L) id tensor is needed
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = self.LayerNorm(input_feat + position_embeddings)
        return self.dropout(embeddings)


class PositionEncoding(nn.Module):
    """
    Add a fixed (non-trainable) positional signal to the input tensor.

    :Examples:
    >>> model = PositionEncoding(n_filters=6, max_len=10)
    >>> test_input1 = torch.zeros(3, 10, 6)
    >>> output1 = model(test_input1)
    >>> output1.size()
    torch.Size([3, 10, 6])
    >>> test_input2 = torch.zeros(5, 3, 9, 6)
    >>> output2 = model(test_input2)
    >>> output2.size()
    torch.Size([5, 3, 9, 6])
    """

    def __init__(self, n_filters=128, max_len=500, pe_type="cosine"):
        """
        :param n_filters: same with input hidden size
        :param max_len: maximum sequence length
        :param pe_type: "cosine" (sinusoidal), "linear" (position / max_len, shared across all
            hidden units) or "none" / None (no encoding)
        :raises ValueError: for any other pe_type
        """
        super().__init__()
        if pe_type is None:
            # the documented None option; normalised so forward() only has to check "none"
            pe_type = "none"
        self.pe_type = pe_type
        if pe_type != "none":
            position = torch.arange(0, max_len).float().unsqueeze(1)  # (L, 1)
            if pe_type == "cosine":
                # Compute the positional encodings once in log space.
                pe = torch.zeros(max_len, n_filters)  # (L, D)
                div_term = torch.exp(torch.arange(0, n_filters, 2).float()
                                     * -(math.log(10000.0) / n_filters))  # (ceil(D/2),)
                pe[:, 0::2] = torch.sin(position * div_term)
                # odd columns number one fewer than even ones when n_filters is odd
                pe[:, 1::2] = torch.cos(position * div_term[: n_filters // 2])
            elif pe_type == "linear":
                pe = position / max_len  # (L, 1); broadcast over the hidden dim in forward()
            else:
                raise ValueError("Unknown pe_type %r; expected 'cosine', 'linear' or 'none'"
                                 % (pe_type,))
            self.register_buffer("pe", pe)  # buffer: saved with the module but not trained

    def forward(self, x):
        """
        :Input: (*, L, D) with L <= max_len
        :Output: (*, L, D) the same size as input
        """
        if self.pe_type == "none":
            return x
        # (L, D) or (L, 1); broadcasting aligns trailing dims, so any number of leading dims works
        return x + self.pe[: x.size(-2)]


class LinearLayer(nn.Module):
    """Linear layer configurable with layer normalization, dropout, ReLU.

    Order: LayerNorm (optional) -> Dropout -> Linear -> ReLU (optional).
    """

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super().__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, in_hsz) -> (N, L, out_hsz)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, out_hsz)


# Default hyper-parameters for the Bert* modules below. It is a plain dict; the Bert* classes
# accept mappings directly because they convert them with _as_config().
bert_config = dict(
    hidden_size=768,
    intermediate_size=768,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    num_attention_heads=4,
)


def _as_config(config):
    """Return config unchanged, or a SimpleNamespace copy of it when it is a mapping.

    The Bert* modules read hyper-parameters as attributes (config.hidden_size), while
    bert_config is a plain dict; this lets either form be passed.
    """
    if isinstance(config, Mapping):
        return SimpleNamespace(**config)
    return config


class BertLayer(nn.Module):
    """One Transformer encoder layer: (optional) self-attention followed by a feed-forward block."""

    def __init__(self, config, use_self_attention=True):
        """
        :param config: object with attributes hidden_size, intermediate_size,
            hidden_dropout_prob, attention_probs_dropout_prob and num_attention_heads, or a
            mapping with those keys (e.g. bert_config).
        :param use_self_attention: if False the attention sub-layer is skipped and the input
            goes straight into the feed-forward block.
        """
        super().__init__()
        config = _as_config(config)
        self.use_self_attention = use_self_attention
        if use_self_attention:
            self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        """
        Args:
            hidden_states: (N, L, D)
            attention_mask: (N, L), (N, 1, L) or (N, L, L) with 1 indicate valid, 0 indicates
                invalid; unused when use_self_attention is False
        Returns:
            (N, L, D)
        """
        if self.use_self_attention:
            attention_output = self.attention(hidden_states, attention_mask)
        else:
            attention_output = hidden_states
        intermediate_output = self.intermediate(attention_output)
        return self.output(intermediate_output, attention_output)


class BertAttention(nn.Module):
    """Self-attention followed by output projection, dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        config = _as_config(config)
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """
        Args:
            input_tensor: (N, L, D)
            attention_mask: (N, L), (N, 1, L) or (N, L, L); 1 = valid, 0 = invalid
        Returns:
            (N, L, D)
        """
        self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
        return self.output(self_output, input_tensor)


class BertIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden_size -> intermediate_size) + ReLU."""

    def __init__(self, config):
        super().__init__()
        config = _as_config(config)
        self.dense = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(True))

    def forward(self, hidden_states):
        return self.dense(hidden_states)


class BertOutput(nn.Module):
    """Linear(intermediate_size -> hidden_size), dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        config = _as_config(config)
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)


class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (without the output projection)."""

    def __init__(self, config):
        """
        :raises ValueError: if hidden_size is not divisible by num_attention_heads
        """
        super().__init__()
        config = _as_config(config)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (N, L, D) -> (N, nh, L, dh)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)  # (N, L, nh, dh)
        return x.permute(0, 2, 1, 3)  # (N, nh, L, dh)

    def forward(self, query_states, key_states, value_states, attention_mask):
        """
        Args:
            query_states: (N, Lq, D)
            key_states: (N, L, D)
            value_states: (N, L, D)
            attention_mask: (N, Lq, L), (N, 1, L) or (N, L); 1 marks keys that may be attended
                to, 0 marks keys to ignore. A (N, L) mask is applied to every query.
        Returns:
            (N, Lq, D)
        """
        if attention_mask.dim() == 2:
            # (N, L) -> (N, 1, L). Without this, the head axis inserted below would line up
            # with the batch axis and broadcasting would mix masks across examples.
            attention_mask = attention_mask.unsqueeze(1)
        if not torch.is_floating_point(attention_mask):
            attention_mask = attention_mask.float()  # `1 - mask` is undefined for bool tensors
        # only need to mask the dimension where the softmax (last dim) is applied, as another
        # dim (second last) will be ignored in future computation anyway.
        # A large negative bias drives the softmax probability of masked keys to ~0.
        attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000.  # (N, 1, Lq or 1, L)

        query_layer = self.transpose_for_scores(self.query(query_states))  # (N, nh, Lq, dh)
        key_layer = self.transpose_for_scores(self.key(key_states))  # (N, nh, L, dh)
        value_layer = self.transpose_for_scores(self.value(value_states))  # (N, nh, L, dh)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # (N, nh, Lq, L)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = F.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)  # (N, nh, Lq, dh)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # (N, Lq, nh, dh)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)  # (N, Lq, D)


class BertSelfOutput(nn.Module):
    """Linear(hidden_size -> hidden_size), dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        config = _as_config(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)