InferBench / benchmark /metrics /image_reward.py
davidberenstein1957's picture
chore: update .gitignore, environment.yml, and README; remove nils_installs.txt
199a7d9
raw
history blame contribute delete
863 Bytes
import os
import tempfile
from typing import Dict, Optional, Union

import ImageReward as RM
import torch
from PIL import Image
class ImageRewardMetric:
    """Score how well an image matches a text prompt using ImageReward.

    The ImageReward checkpoint is loaded once, at construction time, and
    reused for every ``compute_score`` call.
    """

    def __init__(
        self,
        model_name: str = "ImageReward-v1.0",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Load the ImageReward model.

        Args:
            model_name: Checkpoint name passed to ``ImageReward.load``.
                Defaults to ``"ImageReward-v1.0"``.
            device: Device to load the model on. When ``None`` (the default),
                CUDA is used if available, otherwise Apple MPS if available,
                otherwise CPU.
        """
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = torch.device(device)
        # ImageReward.load takes the device as a string.
        self.model = RM.load(model_name, device=str(self.device))

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "image_reward"

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Score ``image`` against ``prompt`` with the loaded ImageReward model.

        The image is written to a temporary PNG file and the model is called
        with a one-element list holding that file's path. The temporary file
        is removed afterwards, even if saving or scoring raises.

        Args:
            image: Image to score.
            prompt: Text prompt the image is evaluated against.

        Returns:
            ``{"image_reward": score}``, where ``score`` is the value returned
            by ``self.model.score`` (presumably a float for a single image --
            TODO confirm against the installed ImageReward version).
        """
        # delete=False keeps the file after the handle is closed. The handle
        # is closed before the model reopens the file by name, because
        # reopening a still-open NamedTemporaryFile is not portable (it fails
        # on Windows).
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        try:
            with tmp:
                image.save(tmp, format="PNG")
            score = self.model.score(prompt, [tmp.name])
        finally:
            # Remove the file even when saving or scoring raised.
            os.unlink(tmp.name)
        return {"image_reward": score}