import os
import tempfile
from typing import Dict

import ImageReward as RM
import torch
from PIL import Image


class ImageRewardMetric:
    """Metric wrapper around the ``ImageReward-v1.0`` model.

    The model is loaded once at construction time and reused for every
    ``compute_score`` call.
    """

    def __init__(self):
        """Pick a torch device and load the ImageReward model onto it.

        Device preference order is CUDA, then MPS, then CPU.
        """
        self.device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model = RM.load("ImageReward-v1.0", device=str(self.device))

    @property
    def name(self) -> str:
        """Return the identifier of this metric (``"image_reward"``)."""
        return "image_reward"

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Score ``image`` against ``prompt`` with the loaded model.

        The image is written to a temporary PNG file whose path is handed to
        ``self.model.score`` inside a one-element list. The temporary file is
        removed afterwards, even when saving or scoring raises.

        Args:
            image: Image to evaluate; it must support ``save(path)``.
            prompt: Text prompt passed to the model together with the image.

        Returns:
            A dict with the single key ``"image_reward"`` mapped to the value
            returned by ``self.model.score``.
            NOTE(review): presumably a float for a single image -- confirm
            against the ImageReward API.

        Raises:
            Whatever ``image.save``, ``self.model.score`` or ``os.unlink``
            raise; nothing is swallowed here.
        """
        # Only the path is needed, so the handle is closed right away;
        # delete=False keeps the file on disk after the ``with`` block exits.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name

        try:
            image.save(tmp_path)
            score = self.model.score(prompt, [tmp_path])
        finally:
            # Always clean up so a failed save/score does not leak the file.
            os.unlink(tmp_path)

        return {"image_reward": score}