Spaces:
Running
Running
File size: 863 Bytes
2c50826 199a7d9 2c50826 199a7d9 2c50826 199a7d9 2c50826 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import os
import tempfile
from typing import Dict
import ImageReward as RM
import torch
from PIL import Image
class ImageRewardMetric:
    """Score how well an image matches a text prompt with the ImageReward model.

    The ``"ImageReward-v1.0"`` checkpoint is loaded once, at construction time,
    on the first available torch device in the order CUDA, Apple MPS, CPU.
    """

    def __init__(self):
        # Prefer CUDA, then Apple-silicon MPS, and finally fall back to the CPU.
        if torch.cuda.is_available():
            device_type = "cuda"
        elif torch.backends.mps.is_available():
            device_type = "mps"
        else:
            device_type = "cpu"
        self.device = torch.device(device_type)
        # RM.load receives the device as a string (e.g. "cuda").
        self.model = RM.load("ImageReward-v1.0", device=str(self.device))

    @property
    def name(self) -> str:
        """Identifier of this metric; also the key used in ``compute_score``'s result."""
        return "image_reward"

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Return the ImageReward score of *image* for *prompt*.

        The image is written to a temporary PNG file, and the model is called
        with a one-element list containing that file's path.

        Args:
            image: The image to evaluate.
            prompt: The text prompt the image is supposed to depict.

        Returns:
            ``{"image_reward": score}``, where ``score`` is the value returned
            by ``self.model.score`` (presumably a float, per the annotation —
            confirm against the ImageReward version in use).
        """
        # A temporary directory is used instead of NamedTemporaryFile(delete=False)
        # plus a manual os.unlink: the directory (and the PNG inside it) is
        # removed even if saving or scoring raises, and no still-open temp file
        # has to be reopened by name, which is not portable (fails on Windows
        # in older Python versions).
        with tempfile.TemporaryDirectory() as tmp_dir:
            # The ".png" suffix lets PIL infer the output format from the path.
            image_path = os.path.join(tmp_dir, "image.png")
            image.save(image_path)
            score = self.model.score(prompt, [image_path])
        return {"image_reward": score}
|