File size: 841 Bytes
75ec748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging

from fastapi import HTTPException

from src.metrics.fquad_metric import FQuAD
from src.metrics.metrics_wrapper import (
    PearsonCorrelation,
    AccuracyWrapper,
    Metric,
    F1Score,
)


def metric_factory(metric_name: str) -> Metric:
    """

    Factory method to create a Metric based on a metric name.

    We support the "acc" (Accuracy) and "pearsonr" (Pearson correlation) metrics.

    """
    match metric_name:
        case "accuracy":
            return AccuracyWrapper()
        case "pearson":
            return PearsonCorrelation()
        case "f1":
            return F1Score()
        case "fquad":
            return FQuAD()
        case _:
            error = f"Unknown metric {metric_name}."
            logging.error(error)
            raise HTTPException(200, error)