File size: 1,211 Bytes
bad8293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch.nn as nn
from rouge_score import rouge_scorer
from six.moves import zip_longest
import numpy as np


class Rouge(nn.Module):
    """Per-pair ROUGE F-measure computed with ``rouge_score``.

    Wraps a ``rouge_scorer.RougeScorer`` (built with ``use_stemmer=True``) in
    an ``nn.Module``. :meth:`forward` reports the F-measure of the *first*
    configured ROUGE type only.

    Args:
        rouges: ROUGE type names passed to ``RougeScorer`` (e.g. ``'rouge1'``,
            ``'rouge2'``, ``'rougeL'``). Any ``'rougel'`` substring is
            rewritten to ``'rougeL'`` first. Entries after the first are
            handed to the scorer but are not reported by :meth:`forward`.
        **kwargs: Accepted and ignored.
    """

    # Fill value for zip_longest. Unlike ``None`` it cannot occur in caller
    # data, so seeing it means one input ran out before the other.
    _MISSING = object()

    def __init__(self, rouges, **kwargs):
        super().__init__()
        # Accept the lowercase spelling 'rougel' (this also turns
        # 'rougelsum' into 'rougeLsum').
        rouges = [r.replace('rougel', 'rougeL') for r in rouges]
        self.scorer = rouge_scorer.RougeScorer(rouges, use_stemmer=True)
        self.rouges = rouges

    def forward(self, refs, hyps):
        """Score aligned reference/hypothesis pairs.

        Args:
            refs: Iterable of reference texts.
            hyps: Iterable of hypothesis texts, aligned one-to-one with
                ``refs``.

        Returns:
            ``(mean_f1, per_pair_f1)``: ``per_pair_f1`` is a list holding the
            ``fmeasure`` of ``self.rouges[0]`` for each pair, and ``mean_f1``
            is its NumPy mean. For empty input ``mean_f1`` is NaN and
            ``per_pair_f1`` is ``[]``.

        Raises:
            ValueError: if ``refs`` and ``hyps`` have different lengths, or if
                any reference or hypothesis is ``None``.
        """
        missing = self._MISSING
        f1_rouge = []
        for target_rec, prediction_rec in zip_longest(refs, hyps,
                                                      fillvalue=missing):
            if target_rec is missing or prediction_rec is missing:
                raise ValueError("Must have equal number of lines across target and "
                                 "prediction.")
            if target_rec is None or prediction_rec is None:
                raise ValueError("Reference and prediction entries must not be "
                                 "None.")
            score = self.scorer.score(target_rec, prediction_rec)
            f1_rouge.append(score[self.rouges[0]].fmeasure)
        if not f1_rouge:
            # np.mean([]) would also be NaN but warns "Mean of empty slice".
            return np.float64(np.nan), f1_rouge
        return np.mean(f1_rouge), f1_rouge


class Rouge1(Rouge):
    """``Rouge`` preconfigured with the single ROUGE type ``'rouge1'``."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but deliberately not forwarded.
        super().__init__(['rouge1'])


class Rouge2(Rouge):
    """``Rouge`` preconfigured with the single ROUGE type ``'rouge2'``."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but deliberately not forwarded.
        super().__init__(['rouge2'])


class RougeL(Rouge):
    """``Rouge`` preconfigured with the single ROUGE type ``'rougeL'``."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but deliberately not forwarded.
        super().__init__(['rougeL'])