|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
from .bleu_scorer import BleuScorer |
|
|
|
|
|
class Bleu(nn.Module):
    """BLEU metric wrapped as a ``torch.nn.Module``.

    The n-gram statistics are computed by ``BleuScorer``; this wrapper feeds it
    position-aligned (hypothesis, reference) pairs and returns the entry for the
    configured order ``n`` (default 4).
    """

    def __init__(self, n=4, **kwargs):
        """Create the metric.

        Args:
            n: N-gram order passed to ``BleuScorer``. ``compute_score`` returns
                index ``n - 1`` of the scorer's results.
            **kwargs: Accepted and ignored, so call sites that pass extra
                options keep working.
        """
        super().__init__()
        self._n = n

    def forward(self, gts, res):
        """Alias for ``compute_score`` so the module can be called directly."""
        return self.compute_score(gts, res)

    def compute_score(self, gts, res):
        """Score hypotheses ``res`` against references ``gts``.

        Args:
            gts: Iterable of references, aligned by position with ``res``. Each
                element is handed to ``BleuScorer`` wrapped in a one-item list,
                i.e. as the sole reference entry for its hypothesis.
            res: Iterable of hypotheses, one per element of ``gts``.

        Returns:
            Tuple ``(score, scores)``: element ``n - 1`` of each of the two
            values returned by ``BleuScorer.compute_score`` (presumably the
            corpus-level BLEU-n and the per-sample BLEU-n values -- confirm
            against bleu_scorer).

        Raises:
            ValueError: If ``gts`` and ``res`` have different lengths.
        """
        gts = list(gts)
        res = list(res)
        # Hypotheses are paired with references by position. A mismatch used to
        # surface as an opaque KeyError (res shorter) or silently drop the
        # surplus hypotheses from the score (res longer).
        if len(gts) != len(res):
            raise ValueError(
                "gts and res must have the same length, "
                f"got {len(gts)} and {len(res)}"
            )

        bleu_scorer = BleuScorer(n=self._n)
        for hypo, ref in zip(res, gts):
            # One hypothesis, with its reference wrapped as a single-entry list.
            bleu_scorer += (hypo, [ref])

        # NOTE(review): option='closest' presumably selects the closest
        # reference length for the brevity penalty -- confirm in bleu_scorer.
        score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
        return score[self._n - 1], scores[self._n - 1]

    def method(self):
        """Return the metric's display name."""
        return "Bleu"
|
|