Nils Fleischmann
feat: add aws canva + examples in readme + my current environment + disable HPS for now
5291ba9
from typing import Type
from benchmark.metrics.arniqa import ARNIQAMetric
from benchmark.metrics.clip import CLIPMetric
from benchmark.metrics.clip_iqa import CLIPIQAMetric
from benchmark.metrics.image_reward import ImageRewardMetric
from benchmark.metrics.sharpness import SharpnessMetric
from benchmark.metrics.vqa import VQAMetric
#from benchmark.metrics.hps import HPSMetric
def create_metric(
    metric_type: str,
) -> ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric:
    """
    Factory function to create metric instances.

    Looks ``metric_type`` up in a fixed name-to-class table and instantiates
    the matching metric class with no constructor arguments.

    Args:
        metric_type (str): The type of metric to create. Must be one of:
            - "arniqa"
            - "clip"
            - "clip_iqa"
            - "image_reward"
            - "sharpness"
            - "vqa"
            ("hps" is currently disabled and is rejected like any unknown name.)

    Returns:
        A new instance (not the class) of the requested metric implementation.

    Raises:
        ValueError: If ``metric_type`` is not one of the supported names.
    """
    metric_map = {
        "arniqa": ARNIQAMetric,
        "clip": CLIPMetric,
        "clip_iqa": CLIPIQAMetric,
        "image_reward": ImageRewardMetric,
        "sharpness": SharpnessMetric,
        "vqa": VQAMetric,
        # HPS is disabled for now; re-enable this entry together with the
        # commented-out HPSMetric import at the top of the module.
        # "hps": HPSMetric,
    }
    # Single lookup (EAFP) instead of a membership test followed by indexing.
    try:
        metric_cls = metric_map[metric_type]
    except KeyError:
        # `from None` hides the internal KeyError so callers only see the
        # ValueError describing the accepted names.
        raise ValueError(
            f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}"
        ) from None
    return metric_cls()