COLE / src /metrics /metric_factory.py
Yurhu's picture
Initial snapshot upload
75ec748 verified
raw
history blame contribute delete
841 Bytes
import logging
from fastapi import HTTPException
from src.metrics.fquad_metric import FQuAD
from src.metrics.metrics_wrapper import (
PearsonCorrelation,
AccuracyWrapper,
Metric,
F1Score,
)
def metric_factory(metric_name: str) -> Metric:
"""
Factory method to create a Metric based on a metric name.
We support the "acc" (Accuracy) and "pearsonr" (Pearson correlation) metrics.
"""
match metric_name:
case "accuracy":
return AccuracyWrapper()
case "pearson":
return PearsonCorrelation()
case "f1":
return F1Score()
case "fquad":
return FQuAD()
case _:
error = f"Unknown metric {metric_name}."
logging.error(error)
raise HTTPException(200, error)