Spaces:
Running
Running
File size: 1,432 Bytes
2c50826 5291ba9 2c50826 4f41410 2c50826 5291ba9 2c50826 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from typing import Type
from benchmark.metrics.arniqa import ARNIQAMetric
from benchmark.metrics.clip import CLIPMetric
from benchmark.metrics.clip_iqa import CLIPIQAMetric
from benchmark.metrics.image_reward import ImageRewardMetric
from benchmark.metrics.sharpness import SharpnessMetric
from benchmark.metrics.vqa import VQAMetric
#from benchmark.metrics.hps import HPSMetric
def create_metric(
    metric_type: str,
) -> ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric:
    """
    Factory function to create metric instances.

    Looks up the metric class registered under ``metric_type`` and returns a
    new instance of it, constructed with no arguments.

    Args:
        metric_type (str): The type of metric to create. Must be one of:
            - "arniqa"
            - "clip"
            - "clip_iqa"
            - "image_reward"
            - "sharpness"
            - "vqa"
            ("hps" is currently disabled and is rejected like any other
            unknown name.)

    Returns:
        An instance (not the class) of the requested metric implementation.

    Raises:
        ValueError: If an invalid metric type is provided.
    """
    metric_map = {
        "arniqa": ARNIQAMetric,
        "clip": CLIPMetric,
        "clip_iqa": CLIPIQAMetric,
        "image_reward": ImageRewardMetric,
        "sharpness": SharpnessMetric,
        "vqa": VQAMetric,
        # "hps": HPSMetric,  # disabled: the HPSMetric import is commented out
    }

    if metric_type not in metric_map:
        raise ValueError(f"Invalid metric type: {metric_type}. Must be one of {list(metric_map)}")

    # Call the class so the caller receives a ready-to-use metric object.
    return metric_map[metric_type]()
|