from typing import List, Optional, Tuple, Union

from PIL.Image import Image as PILImage
from torch import Tensor
import numpy as np
from skimage.metrics import structural_similarity as calculate_ssim
import torch
import torch.nn.functional as F
from transformers import (
    CLIPImageProcessor, CLIPVisionModelWithProjection,
    CLIPTokenizer, CLIPTextModelWithProjection,
)

import ImageReward as RM
from kiui.lpips import LPIPS


class TextConditionMetrics:
    """Text-conditioned metrics: CLIP similarity, CLIP R-Precision and ImageReward."""

    def __init__(self,
        clip_name: str = "openai/clip-vit-base-patch32",
        rm_name: str = "ImageReward-v1.0",
        device_idx: int = 0,
    ):
        # CLIP image and text towers (with projection heads) for image-text similarity.
        self.image_processor = CLIPImageProcessor.from_pretrained(clip_name)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_name).to(f"cuda:{device_idx}").eval()
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(f"cuda:{device_idx}").eval()
        # ImageReward model for human-preference scoring.
        self.rm_model = RM.load(rm_name)
        self.device = f"cuda:{device_idx}"
    @torch.no_grad()
    def evaluate(self,
        image: Union[PILImage, List[PILImage]],
        text: Union[str, List[str]],
    ) -> Tuple[float, float, float]:
        if isinstance(image, PILImage):
            image = [image]
        if isinstance(text, str):
            text = [text]
        assert len(image) == len(text)

        image_inputs = self.image_processor(image, return_tensors="pt").pixel_values.to(self.device)
        image_embeds = self.image_encoder(image_inputs).image_embeds.float()  # (N, D)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        text_inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        text_embeds = self.text_encoder(text_input_ids).text_embeds.float()  # (N, D)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        assert image_embeds.shape == text_embeds.shape
        # Both embeddings are L2-normalized, so the matrix product gives pairwise
        # cosine similarities between every image and every text prompt.
        clip_scores = image_embeds @ text_embeds.T  # (N, N)

        # 1. CLIP similarity: mean cosine similarity of each image with its own prompt.
        clip_sim = clip_scores.diag().mean().item()

        # 2. CLIP R-Precision: fraction of images whose best-matching prompt is their own.
        clip_rprec = (clip_scores.argmax(dim=1) == torch.arange(len(text)).to(self.device)).float().mean().item()

        # 3. ImageReward: learned human-preference score, averaged over the batch.
        rm_scores = []
        for img, txt in zip(image, text):
            rm_scores.append(self.rm_model.score(txt, img))
        rm_scores = torch.tensor(rm_scores, device=self.device)
        rm_score = rm_scores.mean().item()

        return clip_sim, clip_rprec, rm_score
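

# A minimal usage sketch for TextConditionMetrics (added for illustration, not part of
# the original module). The file name and prompt below are hypothetical, and a CUDA
# device plus downloaded CLIP / ImageReward weights are assumed:
#
#   from PIL import Image
#   metrics = TextConditionMetrics(device_idx=0)
#   images = [Image.open("render_0.png").convert("RGB")]
#   prompts = ["a corgi wearing a red scarf"]
#   clip_sim, clip_rprec, rm_score = metrics.evaluate(images, prompts)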


class ImageConditionMetrics:
    """Image-conditioned (reconstruction) metrics: PSNR, SSIM and LPIPS."""

    def __init__(self,
        lpips_net: str = "vgg",
        lpips_res: int = 256,
        device_idx: int = 0,
    ):
        # LPIPS is evaluated at a fixed square resolution (lpips_res x lpips_res).
        self.lpips_loss = LPIPS(net=lpips_net).to(f"cuda:{device_idx}").eval()
        self.lpips_res = lpips_res
        self.device = f"cuda:{device_idx}"
    @torch.no_grad()
    def evaluate(self,
        image: Union[Tensor, PILImage, List[PILImage]],
        gt: Union[Tensor, PILImage, List[PILImage]],
        chunk_size: Optional[int] = None,
        input_tensor: bool = False,
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
        if not input_tensor:
            if isinstance(image, PILImage):
                image = [image]
            if isinstance(gt, PILImage):
                gt = [gt]
            assert len(image) == len(gt)

            def image_to_tensor(img: PILImage):
                # (H, W, 3) uint8 -> (1, 3, H, W) float in [0, 1]
                return torch.tensor(np.array(img).transpose(2, 0, 1) / 255., device=self.device).unsqueeze(0).float()

            image_pt = torch.cat([image_to_tensor(img) for img in image], dim=0)
            gt_pt = torch.cat([image_to_tensor(img) for img in gt], dim=0)
        else:
            # Tensor inputs are expected to already be (N, 3, H, W) floats in [0, 1].
            image_pt = image.to(device=self.device)
            gt_pt = gt.to(device=self.device)
        if chunk_size is None:
            # Default to a single chunk covering the whole batch
            # (len() works for both lists of PIL images and batched tensors).
            chunk_size = len(image)

        # 1. LPIPS: computed in chunks at a fixed resolution, with inputs rescaled to [-1, 1].
        lpips = []
        for i in range(0, len(image), chunk_size):
            _lpips = self.lpips_loss(
                F.interpolate(
                    image_pt[i:min(len(image), i+chunk_size)] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                ),
                F.interpolate(
                    gt_pt[i:min(len(image), i+chunk_size)] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                )
            )
            lpips.append(_lpips)
        lpips = torch.cat(lpips)
        lpips_mean, lpips_std = lpips.mean().item(), lpips.std().item()

        # 2. PSNR: -10 * log10(MSE) per image, with pixel values in [0, 1] (peak value 1).
        psnr = -10. * torch.log10((gt_pt - image_pt).pow(2).mean(dim=[1, 2, 3]))
        psnr_mean, psnr_std = psnr.mean().item(), psnr.std().item()

        # 3. SSIM: computed per image on uint8 arrays, channels along axis 0.
        ssim = []
        for i in range(len(image)):
            _ssim = calculate_ssim(
                (image_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                (gt_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                channel_axis=0,
            )
            ssim.append(_ssim)
        ssim = np.array(ssim)
        ssim_mean, ssim_std = ssim.mean(), ssim.std()

        return (psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std)
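

if __name__ == "__main__":
    # A minimal smoke test (a sketch added for illustration, not part of the original
    # module). It assumes a CUDA device and the kiui package are available, and it
    # exercises the tensor input path with random images in [0, 1], so the absolute
    # metric values are meaningless.
    torch.manual_seed(0)
    fake_pred = torch.rand(4, 3, 256, 256)
    fake_gt = torch.rand(4, 3, 256, 256)

    image_metrics = ImageConditionMetrics(device_idx=0)
    (psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std) = image_metrics.evaluate(
        fake_pred, fake_gt, chunk_size=2, input_tensor=True
    )
    print(f"PSNR:  {psnr_mean:.3f} ± {psnr_std:.3f}")
    print(f"SSIM:  {ssim_mean:.3f} ± {ssim_std:.3f}")
    print(f"LPIPS: {lpips_mean:.3f} ± {lpips_std:.3f}")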