Spaces:
Running
Running
Nils Fleischmann
feat: add aws canva + examples in readme + my current environment + disable HPS for now
5291ba9
from typing import Type | |
from benchmark.metrics.arniqa import ARNIQAMetric | |
from benchmark.metrics.clip import CLIPMetric | |
from benchmark.metrics.clip_iqa import CLIPIQAMetric | |
from benchmark.metrics.image_reward import ImageRewardMetric | |
from benchmark.metrics.sharpness import SharpnessMetric | |
from benchmark.metrics.vqa import VQAMetric | |
#from benchmark.metrics.hps import HPSMetric | |
def create_metric(
    metric_type: str,
) -> (
    ARNIQAMetric
    | CLIPMetric
    | CLIPIQAMetric
    | ImageRewardMetric
    | SharpnessMetric
    | VQAMetric
):
    """Create an instance of the metric registered under ``metric_type``.

    Args:
        metric_type: Key of the metric to create. Must be one of:
            - "arniqa"
            - "clip"
            - "clip_iqa"
            - "image_reward"
            - "sharpness"
            - "vqa"
            "hps" is currently disabled (its entry is commented out below),
            so passing it raises ``ValueError``.

    Returns:
        A new instance of the requested metric class, constructed with no
        arguments. Note this is an instance, not the class itself.

    Raises:
        ValueError: If ``metric_type`` is not one of the supported keys.
    """
    metric_map = {
        "arniqa": ARNIQAMetric,
        "clip": CLIPMetric,
        "clip_iqa": CLIPIQAMetric,
        "image_reward": ImageRewardMetric,
        "sharpness": SharpnessMetric,
        "vqa": VQAMetric,
        # Disabled for now; re-enable together with the HPSMetric import.
        # "hps": HPSMetric,
    }
    try:
        metric_cls = metric_map[metric_type]
    except KeyError:
        # Suppress the KeyError context: the ValueError already names the
        # bad key and lists the valid choices.
        raise ValueError(
            f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}"
        ) from None
    # Construct outside the try so errors raised by a metric's constructor
    # are not misreported as an invalid metric type.
    return metric_cls()