from pathlib import Path
from typing import Dict, Optional, Union

import t2v_metrics
import torch


class VQAMetric:
    """Image-text alignment metric backed by ``t2v_metrics.VQAScore``.

    Scores how well an image matches a text prompt. The underlying model is
    configurable and defaults to ``"clip-flant5-xxl"``.
    """

    def __init__(
        self,
        model: str = "clip-flant5-xxl",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Load the VQAScore model.

        Args:
            model: Model identifier forwarded to ``t2v_metrics.VQAScore``.
                Defaults to ``"clip-flant5-xxl"``.
            device: Device to run the model on. When ``None`` (the default),
                CUDA is used if available, otherwise Apple MPS if available,
                otherwise CPU.
        """
        self.device = (
            torch.device(device) if device is not None else self._default_device()
        )
        # t2v_metrics takes the device as a string, e.g. "cuda" / "mps" / "cpu".
        self.metric = t2v_metrics.VQAScore(model=model, device=str(self.device))

    @staticmethod
    def _default_device() -> torch.device:
        """Return the preferred available device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "vqa_score"

    def compute_score(
        self,
        image_path: Path,
        prompt: str,
    ) -> Dict[str, float]:
        """Score a single image against a single text prompt.

        Args:
            image_path: Path to the image file; passed to the model as a string.
            prompt: Text prompt to compare the image against.

        Returns:
            A dict with a single key ``"vqa"`` mapping to the score as a float.
        """
        # Inference only: disable autograd so no computation graph is kept
        # around (a no-op if the library already does this internally).
        with torch.no_grad():
            scores = self.metric(images=[str(image_path)], texts=[prompt])
        # One image x one text -> take the single entry. Presumably `scores` is
        # indexed (image, text); TODO confirm against t2v_metrics docs.
        return {"vqa": scores[0][0].item()}