# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionCTCLoss(torch.nn.Module):
    """CTC loss that rewards attention walking through all keys in order.

    Each query step is treated as one CTC time step, and the CTC target for
    every sample is the ordered key sequence ``1..key_len``. Class 0 is an
    extra blank label with a fixed score. The loss is therefore low when the
    attention visits every valid key monotonically.
    """

    def __init__(self, blank_logprob=-1):
        """
        Args:
            blank_logprob: unnormalized score given to the blank label
                (class 0), which is prepended to every query step before the
                log-softmax.
        """
        super().__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=-1)
        self.blank_logprob = blank_logprob
        # zero_infinity=True: alignments CTC cannot realize (e.g. fewer query
        # steps than keys) contribute 0 instead of inf.
        self.CTCLoss = nn.CTCLoss(zero_infinity=True)

    def forward(self, attn_logprob, in_lens, out_lens):
        """Compute the alignment CTC loss.

        Args:
            attn_logprob: attention scores shaped (batch, 1, query_len,
                key_len). Dim 1 is squeezed away, and the remaining dims are
                reordered to (query_len, batch, key_len).
            in_lens: 1-D tensor of valid key counts per sample. These are
                the CTC target lengths.
            out_lens: 1-D tensor of valid query-step counts per sample. These
                are the CTC input lengths.

        Returns:
            Scalar loss, reduced with ``nn.CTCLoss``'s default ('mean').
        """
        key_lens = in_lens
        query_lens = out_lens
        max_key_len = attn_logprob.size(-1)

        # Reorder input to [query_len, batch_size, key_len]: nn.CTCLoss
        # expects time-major log-probabilities.
        attn_logprob = attn_logprob.squeeze(1).permute(1, 0, 2)

        # Add the blank label as class 0 of the last dimension.
        attn_logprob = F.pad(
            input=attn_logprob,
            pad=(1, 0, 0, 0, 0, 0),
            value=self.blank_logprob)

        # Mask out keys beyond each sample's key_len before normalizing.
        # Class index k corresponds to key k-1, so class k is padding
        # exactly when k > key_len.
        key_inds = torch.arange(
            max_key_len + 1, device=attn_logprob.device, dtype=torch.long)
        # Compare on the scores' device so length tensors kept on the CPU
        # (which nn.CTCLoss accepts) do not cause a device mismatch here.
        key_lens_dev = key_lens.to(attn_logprob.device)
        pad_mask = key_inds.view(1, 1, -1) > key_lens_dev.view(1, -1, 1)
        attn_logprob = attn_logprob.masked_fill(pad_mask, -float("inf"))
        attn_logprob = self.log_softmax(attn_logprob)

        # Targets: every sample gets keys 1..max_key_len in order.
        # target_lengths (= key_lens) truncates each row to its valid prefix.
        target_seqs = key_inds[1:].unsqueeze(0).repeat(key_lens.numel(), 1)

        cost = self.CTCLoss(
            attn_logprob, target_seqs,
            input_lengths=query_lens, target_lengths=key_lens)
        return cost


class AttentionBinarizationLoss(torch.nn.Module):
    """Pull a soft attention map toward a hard (binary) alignment.

    The loss is the negative sum of ``log(soft_attention)`` over positions
    where ``hard_attention == 1``, divided by ``hard_attention.sum()``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hard_attention, soft_attention, eps=1e-12):
        """
        Args:
            hard_attention: alignment tensor with the same shape as
                soft_attention. Positions equal to 1 are selected.
            soft_attention: soft attention values (presumably probabilities
                in [0, 1] -- confirm against the caller).
            eps: lower clamp applied before the log to avoid log(0).

        Returns:
            Scalar loss. If ``hard_attention`` sums to zero (nothing
            selected), returns 0 instead of the NaN produced by 0/0.
        """
        selected = soft_attention[hard_attention == 1]
        log_sum = torch.log(torch.clamp(selected, min=eps)).sum()
        # Guard only the empty case. Any nonzero normalizer is used
        # unchanged, and this path needs no host-device sync.
        denom = hard_attention.sum()
        denom = denom.masked_fill(denom == 0, 1)
        return -log_sum / denom