File size: 1,432 Bytes
2c50826
 
 
 
 
 
 
 
5291ba9
2c50826
 
 
 
 
 
 
 
 
 
 
 
 
4f41410
2c50826
 
 
 
 
 
 
 
 
 
 
 
 
5291ba9
2c50826
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Type

from benchmark.metrics.arniqa import ARNIQAMetric
from benchmark.metrics.clip import CLIPMetric
from benchmark.metrics.clip_iqa import CLIPIQAMetric
from benchmark.metrics.image_reward import ImageRewardMetric
from benchmark.metrics.sharpness import SharpnessMetric
from benchmark.metrics.vqa import VQAMetric
#from benchmark.metrics.hps import HPSMetric

def create_metric(
    metric_type: str,
) -> ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric:
    """
    Factory function to create metric instances.

    The metric class registered under ``metric_type`` is looked up and
    instantiated with no arguments; a new instance is built on every call.

    Args:
        metric_type (str): The type of metric to create. Must be one of:
            - "arniqa"
            - "clip"
            - "clip_iqa"
            - "image_reward"
            - "sharpness"
            - "vqa"
            ("hps" is currently disabled and is rejected like any other
            unknown name.)

    Returns:
        An instance (not the class) of the requested metric implementation.

    Raises:
        ValueError: If an invalid metric type is provided
    """
    metric_map = {
        "arniqa": ARNIQAMetric,
        "clip": CLIPMetric,
        "clip_iqa": CLIPIQAMetric,
        "image_reward": ImageRewardMetric,
        "sharpness": SharpnessMetric,
        "vqa": VQAMetric,
        # NOTE: HPS is disabled; its import is commented out at the top of the file.
        #"hps": HPSMetric,
    }

    # Resolve the class first so that an error raised inside a metric's
    # constructor is never mistaken for an unknown metric name.
    metric_cls = metric_map.get(metric_type)
    if metric_cls is None:
        raise ValueError(f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}")

    return metric_cls()