# Diffsplat: src/utils/metrics.py
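"""Evaluation metrics for text- and image-conditioned generation:
CLIP similarity, CLIP R-Precision, ImageReward, and PSNR / SSIM / LPIPS."""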
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL.Image import Image as PILImage
from skimage.metrics import structural_similarity as calculate_ssim
from torch import Tensor
from transformers import (
    CLIPImageProcessor, CLIPVisionModelWithProjection,
    CLIPTokenizer, CLIPTextModelWithProjection,
)

import ImageReward as RM
from kiui.lpips import LPIPS


class TextConditionMetrics:
    """Text-conditioned metrics: CLIP similarity, CLIP R-Precision, ImageReward."""

    def __init__(
        self,
        clip_name: str = "openai/clip-vit-base-patch32",
        rm_name: str = "ImageReward-v1.0",
        device_idx: int = 0,
    ):
        self.device = f"cuda:{device_idx}"

        # CLIP vision and text towers (with projection heads) for image-text similarity
        self.image_processor = CLIPImageProcessor.from_pretrained(clip_name)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_name).to(self.device).eval()
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(self.device).eval()

        # ImageReward model for learned human-preference scoring
        self.rm_model = RM.load(rm_name)

    @torch.no_grad()
    def evaluate(
        self,
        image: Union[PILImage, List[PILImage]],
        text: Union[str, List[str]],
    ) -> Tuple[float, float, float]:
        if isinstance(image, PILImage):
            image = [image]
        if isinstance(text, str):
            text = [text]
        assert len(image) == len(text)

        # Encode images and L2-normalize the embeddings
        image_inputs = self.image_processor(image, return_tensors="pt").pixel_values.to(self.device)
        image_embeds = self.image_encoder(image_inputs).image_embeds.float()  # (N, D)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # Encode texts and L2-normalize the embeddings
        text_inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        text_embeds = self.text_encoder(text_input_ids).text_embeds.float()  # (N, D)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        assert image_embeds.shape == text_embeds.shape
        clip_scores = image_embeds @ text_embeds.T  # (N, N) cosine similarities

        # 1. CLIP similarity: mean cosine similarity of matched image-text pairs
        clip_sim = clip_scores.diag().mean().item()

        # 2. CLIP R-Precision: fraction of images whose best-matching text is their own
        clip_rprec = (clip_scores.argmax(dim=1) == torch.arange(len(text), device=self.device)).float().mean().item()

        # 3. ImageReward: learned human-preference score per (text, image) pair
        rm_scores = [self.rm_model.score(txt, img) for img, txt in zip(image, text)]
        rm_score = torch.tensor(rm_scores).mean().item()

        return clip_sim, clip_rprec, rm_score
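

# Usage sketch for TextConditionMetrics (illustrative only; the image paths and
# prompts below are hypothetical, not part of the original code):
#
#     from PIL import Image
#     metrics = TextConditionMetrics(device_idx=0)
#     clip_sim, clip_rprec, rm_score = metrics.evaluate(
#         image=[Image.open("render_0.png"), Image.open("render_1.png")],
#         text=["a red chair", "a wooden table"],
#     )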


class ImageConditionMetrics:
    """Image-conditioned reconstruction metrics: PSNR, SSIM, LPIPS."""

    def __init__(
        self,
        lpips_net: str = "vgg",
        lpips_res: int = 256,
        device_idx: int = 0,
    ):
        self.device = f"cuda:{device_idx}"

        # LPIPS backbone ("vgg" or "alex"); inputs are resized to lpips_res before scoring
        self.lpips_loss = LPIPS(net=lpips_net).to(self.device).eval()
        self.lpips_res = lpips_res

    @torch.no_grad()
    def evaluate(
        self,
        image: Union[Tensor, PILImage, List[PILImage]],
        gt: Union[Tensor, PILImage, List[PILImage]],
        chunk_size: Optional[int] = None,
        input_tensor: bool = False,
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
        if not input_tensor:
            if isinstance(image, PILImage):
                image = [image]
            if isinstance(gt, PILImage):
                gt = [gt]
            assert len(image) == len(gt)

            def image_to_tensor(img: PILImage) -> Tensor:
                # HWC uint8 in [0, 255] -> (1, 3, H, W) float in [0, 1]
                return torch.tensor(
                    np.array(img).transpose(2, 0, 1) / 255., device=self.device
                ).unsqueeze(0).float()

            image_pt = torch.cat([image_to_tensor(img) for img in image], dim=0)
            gt_pt = torch.cat([image_to_tensor(img) for img in gt], dim=0)
        else:
            image_pt = image.to(device=self.device)
            gt_pt = gt.to(device=self.device)

        if chunk_size is None:  # default: process everything in one chunk
            chunk_size = len(image)

        # 1. LPIPS: computed chunk-wise at a fixed resolution, inputs rescaled to [-1, 1]
        lpips = []
        for i in range(0, len(image), chunk_size):
            _lpips = self.lpips_loss(
                F.interpolate(
                    image_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                ),
                F.interpolate(
                    gt_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                ),
            )
            lpips.append(_lpips)
        lpips = torch.cat(lpips)
        lpips_mean, lpips_std = lpips.mean().item(), lpips.std().item()

        # 2. PSNR: per-image, assuming pixel values in [0, 1] (peak signal = 1)
        psnr = -10. * torch.log10((gt_pt - image_pt).pow(2).mean(dim=[1, 2, 3]))
        psnr_mean, psnr_std = psnr.mean().item(), psnr.std().item()

        # 3. SSIM: per-image on uint8 arrays, channels-first (channel_axis=0)
        ssim = []
        for i in range(len(image)):
            _ssim = calculate_ssim(
                (image_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                (gt_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                channel_axis=0,
            )
            ssim.append(_ssim)
        ssim = np.array(ssim)
        ssim_mean, ssim_std = ssim.mean(), ssim.std()

        return (psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std)
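

# Minimal smoke test (a sketch, not part of the original file): compares an
# image with itself, so PSNR should print as inf and LPIPS should be ~0.
# Assumes a CUDA device and the skimage / kiui dependencies are available;
# the command-line argument is a hypothetical path to any RGB image.
if __name__ == "__main__":
    import sys
    from PIL import Image

    img = Image.open(sys.argv[1]).convert("RGB")

    metrics = ImageConditionMetrics(device_idx=0)
    (psnr, _), (ssim, _), (lpips, _) = metrics.evaluate([img], [img])
    print(f"PSNR: {psnr:.2f}  SSIM: {ssim:.4f}  LPIPS: {lpips:.4f}")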