import math

import torch
import torch.nn as nn
import torch.nn.functional as F

					
						
class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolution: a depthwise convolution followed by a 1x1 pointwise
    convolution, which uses fewer parameters than a standard convolution with the same kernel size.

    :Examples:
        >>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
        >>> input_tensor = torch.randn(32, 20, 300)  # (N, L, D) with D == in_ch
        >>> output = m(input_tensor)  # (32, 20, 200)
    """

    def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
        """
        :param in_ch: input hidden dimension size
        :param out_ch: output hidden dimension size
        :param k: kernel size
        :param dim: 1 for 1D convolution, 2 for 2D convolution (default 1)
        :param relu: whether to apply ReLU after the pointwise convolution (default True)
        """
        super(DepthwiseSeparableConv, self).__init__()
        self.relu = relu
        if dim == 1:
            self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch,
                                            kernel_size=k, groups=in_ch, padding=k // 2)
            self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
                                            kernel_size=1, padding=0)
        elif dim == 2:
            self.depthwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch,
                                            kernel_size=k, groups=in_ch, padding=k // 2)
            self.pointwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                                            kernel_size=1, padding=0)
        else:
            raise ValueError("Incorrect dimension: dim must be 1 or 2, got {}".format(dim))

    def forward(self, x):
        """
        :Input: (N, L_in, D)
        :Output: (N, L_out, D)
        """
        x = x.transpose(1, 2)  # (N, D, L_in): convolutions expect channels first
        if self.relu:
            out = F.relu(self.pointwise_conv(self.depthwise_conv(x)), inplace=True)
        else:
            out = self.pointwise_conv(self.depthwise_conv(x))
        return out.transpose(1, 2)  # back to (N, L_out, D)

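# Illustrative sanity check (not part of the original module; the helper name and defaults are
# arbitrary): it compares the parameter count of the depthwise separable conv above against a
# standard Conv1d with the same channel sizes and kernel size, backing up the "fewer parameters"
# note in the docstring.
def _compare_conv_param_counts(in_ch=300, out_ch=200, k=5):
    separable = DepthwiseSeparableConv(in_ch, out_ch, k, dim=1)
    standard = nn.Conv1d(in_ch, out_ch, kernel_size=k, padding=k // 2)
    n_separable = sum(p.numel() for p in separable.parameters())  # in_ch*k + in_ch + in_ch*out_ch + out_ch
    n_standard = sum(p.numel() for p in standard.parameters())    # in_ch*out_ch*k + out_ch
    return n_separable, n_standard  # e.g. 62,000 vs. 300,200 for the defaults
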
					
						
class ConvEncoder(nn.Module):
    def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
        super(ConvEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(n_filters)
        self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size, relu=True)

    def forward(self, x, mask):
        """
        :param x: (N, L, D)
        :param mask: (N, L), not used
        :return: (N, L, D)
        """
        # residual connection around the convolution, followed by layer normalization
        return self.layer_norm(self.dropout(self.conv(x)) + x)

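# Illustrative usage sketch (not part of the original code; the helper name is arbitrary):
# ConvEncoder is a residual block, so the output has the same shape as the input, and the
# feature dimension D must equal n_filters. The mask argument is accepted but ignored.
def _demo_conv_encoder():
    encoder = ConvEncoder(kernel_size=7, n_filters=128, dropout=0.1)
    x = torch.randn(4, 20, 128)    # (N, L, D) with D == n_filters
    mask = torch.ones(4, 20)       # (N, L), unused by ConvEncoder.forward
    return encoder(x, mask).shape  # torch.Size([4, 20, 128])
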
					
						
class TrainablePositionalEncoding(nn.Module):
    """Add learned position embeddings to the input features, followed by LayerNorm and dropout."""

    def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
        super(TrainablePositionalEncoding, self).__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_feat):
        """
        Args:
            input_feat: (N, L, D)
        """
        bsz, seq_length = input_feat.shape[:2]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
        position_ids = position_ids.unsqueeze(0).repeat(bsz, 1)  # (N, L)

        position_embeddings = self.position_embeddings(position_ids)  # (N, L, D)

        embeddings = self.LayerNorm(input_feat + position_embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

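# Illustrative usage sketch (not part of the original code; the helper name is arbitrary):
# the learned embedding table must cover the longest expected sequence, i.e. L must not
# exceed max_position_embeddings.
def _demo_trainable_positional_encoding():
    pos_enc = TrainablePositionalEncoding(max_position_embeddings=100, hidden_size=768)
    feat = torch.randn(2, 50, 768)  # (N, L, D) with L <= max_position_embeddings
    return pos_enc(feat).shape      # torch.Size([2, 50, 768])
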
					
						
class PositionEncoding(nn.Module):
    """
    Add positional information to the input tensor.

    :Examples:
        >>> model = PositionEncoding(n_filters=6, max_len=10)
        >>> test_input1 = torch.zeros(3, 10, 6)
        >>> output1 = model(test_input1)
        >>> output1.size()
        >>> test_input2 = torch.zeros(5, 3, 9, 6)
        >>> output2 = model(test_input2)
        >>> output2.size()
    """

    def __init__(self, n_filters=128, max_len=500, pe_type="cosine"):
        """
        :param n_filters: same as the input hidden size
        :param max_len: maximum sequence length
        :param pe_type: "cosine", "linear" or "none"
        """
        super(PositionEncoding, self).__init__()
        self.pe_type = pe_type
        if pe_type != "none":
            position = torch.arange(0, max_len).float().unsqueeze(1)  # (max_len, 1)
            if pe_type == "cosine":
                # standard sinusoidal encoding: sin on even dimensions, cos on odd dimensions
                pe = torch.zeros(max_len, n_filters)  # (max_len, D)
                div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters))
                pe[:, 0::2] = torch.sin(position * div_term)
                pe[:, 1::2] = torch.cos(position * div_term)
            elif pe_type == "linear":
                pe = position / max_len  # (max_len, 1), position index scaled to [0, 1)
            else:
                raise ValueError("Unknown pe_type {}, expected 'cosine', 'linear' or 'none'".format(pe_type))
            self.register_buffer("pe", pe)

    def forward(self, x):
        """
        :Input: (*, L, D)
        :Output: (*, L, D), the same size as the input
        """
        if self.pe_type != "none":
            pe = self.pe.data[:x.size(-2), :]  # (L, D) for cosine, (L, 1) for linear
            extra_dim = len(x.size()) - 2
            for _ in range(extra_dim):
                pe = pe.unsqueeze(0)
            x = x + pe
        return x

					
						
class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

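# Illustrative usage sketch (not part of the original code; the helper name is arbitrary):
# LinearLayer only transforms the last dimension, mapping (N, L, in_hsz) -> (N, L, out_hsz).
def _demo_linear_layer():
    proj = LinearLayer(in_hsz=1024, out_hsz=768, layer_norm=True, dropout=0.1, relu=True)
    x = torch.randn(2, 20, 1024)
    return proj(x).shape  # torch.Size([2, 20, 768])
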
					
						
# NOTE: the Bert* modules below read these values via attribute access (config.hidden_size,
# config.num_attention_heads, ...), so this dict is expected to be wrapped in an attribute-style
# container (e.g. easydict or types.SimpleNamespace) before being passed in as `config`.
bert_config = dict(
    hidden_size=768,
    intermediate_size=768,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    num_attention_heads=4,
)

					
						
class BertLayer(nn.Module):
    def __init__(self, config, use_self_attention=True):
        super(BertLayer, self).__init__()
        self.use_self_attention = use_self_attention
        if use_self_attention:
            self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        """
        Args:
            hidden_states: (N, L, D)
            attention_mask: mask with 1 indicating valid and 0 indicating invalid positions;
                any shape broadcastable to (N, Lq, L), e.g. (N, 1, L) or (N, Lq, L)
        Returns:
            (N, L, D)
        """
        if self.use_self_attention:
            attention_output = self.attention(hidden_states, attention_mask)
        else:
            attention_output = hidden_states
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

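# Illustrative usage sketch (not part of the original code; the helper name is arbitrary).
# Since the Bert* modules read the config via attribute access, the plain bert_config dict is
# wrapped in types.SimpleNamespace here; the surrounding project may use a different container
# (e.g. easydict). The mask is broadcast inside BertSelfAttention, so (N, 1, L) or (N, Lq, L)
# shapes both work.
def _demo_bert_layer():
    from types import SimpleNamespace
    config = SimpleNamespace(**bert_config)
    layer = BertLayer(config, use_self_attention=True)
    x = torch.randn(2, 16, config.hidden_size)  # (N, L, D)
    mask = torch.ones(2, 1, 16)                 # 1 = valid position, 0 = padding
    return layer(x, mask).shape                 # torch.Size([2, 16, 768])
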
					
						
class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """
        Args:
            input_tensor: (N, L, D)
            attention_mask: broadcastable to (N, Lq, L), with 1 for valid and 0 for invalid positions
        Returns:
            (N, L, D)
        """
        self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output

					
						
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(inplace=True))

    def forward(self, hidden_states):
        return self.dense(hidden_states)

					
						
class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # residual connection followed by LayerNorm
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

					
						
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """(N, L, D) -> (N, num_heads, L, head_size)"""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query_states, key_states, value_states, attention_mask):
        """
        Args:
            query_states: (N, Lq, D)
            key_states: (N, L, D)
            value_states: (N, L, D)
            attention_mask: (N, Lq, L), 1 for valid and 0 for invalid positions
        Returns:
            (N, Lq, D)
        """
        # convert the 0/1 mask to an additive mask: valid -> 0, invalid -> -10000,
        # so that masked positions get (near-)zero probability after the softmax
        attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000.  # (N, 1, Lq, L)
        mixed_query_layer = self.query(query_states)
        mixed_key_layer = self.key(key_states)
        mixed_value_layer = self.value(value_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)  # (N, num_heads, Lq, head_size)
        key_layer = self.transpose_for_scores(mixed_key_layer)      # (N, num_heads, L, head_size)
        value_layer = self.transpose_for_scores(mixed_value_layer)  # (N, num_heads, L, head_size)

        # scaled dot-product attention scores, (N, num_heads, Lq, L)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        # normalize the scores to attention probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)

        # dropout on the attention probabilities, as in the original BERT implementation
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)  # (N, num_heads, Lq, head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)  # (N, Lq, D)
        return context_layer

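# Tiny illustrative sketch (not part of the original code; the helper name is arbitrary) of the
# additive mask used in BertSelfAttention.forward: invalid positions are pushed to a large
# negative score so they receive (near-)zero weight after the softmax.
def _demo_additive_attention_mask():
    mask = torch.tensor([[1., 1., 1., 0.]])      # (N=1, L=4), last position is padding
    additive = (1 - mask) * -10000.              # [[0., 0., 0., -10000.]]
    scores = torch.zeros(1, 4)                   # pretend raw attention scores
    return F.softmax(scores + additive, dim=-1)  # ~[1/3, 1/3, 1/3, 0]
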
					
						
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states