from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL.Image import Image as PILImage
from skimage.metrics import structural_similarity as calculate_ssim
from torch import Tensor
from transformers import (
    CLIPImageProcessor, CLIPVisionModelWithProjection,
    CLIPTokenizer, CLIPTextModelWithProjection,
)

import ImageReward as RM
from kiui.lpips import LPIPS
class TextConditionMetrics:
    """Text-conditioned metrics: CLIP similarity, CLIP R-Precision, and ImageReward."""

    def __init__(self,
        clip_name: str = "openai/clip-vit-base-patch32",
        rm_name: str = "ImageReward-v1.0",
        device_idx: int = 0,
    ):
        self.image_processor = CLIPImageProcessor.from_pretrained(clip_name)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_name).to(f"cuda:{device_idx}").eval()
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(f"cuda:{device_idx}").eval()
        self.rm_model = RM.load(rm_name)
        self.device = f"cuda:{device_idx}"

    @torch.no_grad()  # metrics only: skip autograd graph construction
    def evaluate(self,
        image: Union[PILImage, List[PILImage]],
        text: Union[str, List[str]],
    ) -> Tuple[float, float, float]:
        if isinstance(image, PILImage):
            image = [image]
        if isinstance(text, str):
            text = [text]
        assert len(image) == len(text)

        # Encode images into L2-normalized CLIP embeddings.
        image_inputs = self.image_processor(image, return_tensors="pt").pixel_values.to(self.device)
        image_embeds = self.image_encoder(image_inputs).image_embeds.float()  # (N, D)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # Encode texts into L2-normalized CLIP embeddings.
        text_inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        text_embeds = self.text_encoder(text_input_ids).text_embeds.float()  # (N, D)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        assert image_embeds.shape == text_embeds.shape
        clip_scores = image_embeds @ text_embeds.T  # (N, N) cosine similarities

        # 1. CLIP similarity: mean cosine similarity of the matched image-text pairs.
        clip_sim = clip_scores.diag().mean().item()

        # 2. CLIP R-Precision: fraction of images whose best-matching text, among
        #    all N prompts in the batch, is the ground-truth one.
        clip_rprec = (clip_scores.argmax(dim=1) == torch.arange(len(text), device=self.device)).float().mean().item()

        # 3. ImageReward: learned human-preference score, averaged over pairs.
        rm_scores = []
        for img, txt in zip(image, text):
            rm_scores.append(self.rm_model.score(txt, img))
        rm_score = torch.tensor(rm_scores).mean().item()

        return clip_sim, clip_rprec, rm_score
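
# Example usage for TextConditionMetrics (a minimal sketch; the file name and
# prompts below are illustrative placeholders, and a CUDA device is required
# since the encoders are moved to cuda:{device_idx}):
#
#   from PIL import Image
#   metrics = TextConditionMetrics()
#   images = [Image.open("gen_0.png").convert("RGB"), Image.open("gen_1.png").convert("RGB")]
#   clip_sim, clip_rprec, rm_score = metrics.evaluate(images, ["a red sports car", "a bowl of ramen"])
#
# Note that CLIP R-Precision uses the other prompts in the batch as retrieval
# distractors, so it is only meaningful for batches with N > 1 (with a single
# pair the argmax trivially matches and the score is always 1.0).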
class ImageConditionMetrics:
    """Image-conditioned (reference-based) metrics: PSNR, SSIM, and LPIPS."""

    def __init__(self,
        lpips_net: str = "vgg",
        lpips_res: int = 256,
        device_idx: int = 0,
    ):
        self.lpips_loss = LPIPS(net=lpips_net).to(f"cuda:{device_idx}").eval()
        self.lpips_res = lpips_res
        self.device = f"cuda:{device_idx}"

    @torch.no_grad()  # metrics only: skip autograd graph construction
    def evaluate(self,
        image: Union[Tensor, PILImage, List[PILImage]],
        gt: Union[Tensor, PILImage, List[PILImage]],
        chunk_size: Optional[int] = None,
        input_tensor: bool = False,
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
        if not input_tensor:
            if isinstance(image, PILImage):
                image = [image]
            if isinstance(gt, PILImage):
                gt = [gt]
            assert len(image) == len(gt)

            def image_to_tensor(img: PILImage):
                # HWC uint8 in [0, 255] -> (1, 3, H, W) float in [0, 1]; assumes RGB input
                return torch.tensor(np.array(img).transpose(2, 0, 1) / 255., device=self.device).unsqueeze(0).float()

            image_pt = torch.cat([image_to_tensor(img) for img in image], dim=0)
            gt_pt = torch.cat([image_to_tensor(img) for img in gt], dim=0)
        else:
            image_pt = image.to(device=self.device)
            gt_pt = gt.to(device=self.device)
        if chunk_size is None:
            # Default to a single chunk; set this outside the branch above so the
            # tensor path does not crash when chunk_size is left as None.
            chunk_size = len(image)

        # 1. LPIPS: resize to a fixed resolution, map [0, 1] -> [-1, 1], and
        #    evaluate in chunks to bound peak GPU memory.
        lpips = []
        for i in range(0, len(image), chunk_size):
            _lpips = self.lpips_loss(
                F.interpolate(
                    image_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False,
                ),
                F.interpolate(
                    gt_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False,
                ),
            )
            lpips.append(_lpips)
        lpips = torch.cat(lpips)
        lpips_mean, lpips_std = lpips.mean().item(), lpips.std().item()

        # 2. PSNR: per-image MSE over [0, 1] pixels, so PSNR = -10 * log10(MSE).
        psnr = -10. * torch.log10((gt_pt - image_pt).pow(2).mean(dim=[1, 2, 3]))
        psnr_mean, psnr_std = psnr.mean().item(), psnr.std().item()

        # 3. SSIM: computed per image on uint8 arrays, channel-first (channel_axis=0).
        ssim = []
        for i in range(len(image)):
            _ssim = calculate_ssim(
                (image_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                (gt_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                channel_axis=0,
            )
            ssim.append(_ssim)
        ssim = np.array(ssim)
        ssim_mean, ssim_std = ssim.mean(), ssim.std()

        return (psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std)
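
# Smoke test: a minimal sketch using synthetic noise images. The random inputs
# and the chunk_size value here are illustrative assumptions, not part of the
# evaluation protocol; running this downloads the CLIP, ImageReward, and LPIPS
# weights on first use. Predictions and ground truths must share the same H x W
# for the PSNR/SSIM computations.
if __name__ == "__main__":
    from PIL import Image

    rng = np.random.default_rng(0)
    preds = [Image.fromarray(rng.integers(0, 256, (256, 256, 3), dtype=np.uint8)) for _ in range(2)]
    gts = [Image.fromarray(rng.integers(0, 256, (256, 256, 3), dtype=np.uint8)) for _ in range(2)]

    img_metrics = ImageConditionMetrics()
    (psnr_m, psnr_s), (ssim_m, ssim_s), (lpips_m, lpips_s) = img_metrics.evaluate(preds, gts, chunk_size=1)
    print(f"PSNR {psnr_m:.2f}+/-{psnr_s:.2f} | SSIM {ssim_m:.4f}+/-{ssim_s:.4f} | LPIPS {lpips_m:.4f}+/-{lpips_s:.4f}")

    txt_metrics = TextConditionMetrics()
    clip_sim, clip_rprec, rm_score = txt_metrics.evaluate(preds, ["random noise", "more random noise"])
    print(f"CLIP sim {clip_sim:.4f} | R-Precision {clip_rprec:.4f} | ImageReward {rm_score:.4f}")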