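"""Metrics for conditional image generation.

TextConditionMetrics  : CLIP similarity, CLIP R-Precision, and ImageReward
                        for text-conditioned generation.
ImageConditionMetrics : PSNR, SSIM, and LPIPS against ground-truth images
                        for image-conditioned generation.
"""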
from typing import List, Optional, Tuple, Union
from PIL.Image import Image as PILImage
from torch import Tensor

import numpy as np
from skimage.metrics import structural_similarity as calculate_ssim
import torch
import torch.nn.functional as F
from transformers import (
    CLIPImageProcessor, CLIPVisionModelWithProjection,
    CLIPTokenizer, CLIPTextModelWithProjection,
)
import ImageReward as RM
from kiui.lpips import LPIPS


class TextConditionMetrics:
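    """Text-conditioned metrics: CLIP similarity, CLIP R-Precision, and ImageReward."""
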
    def __init__(self,
        clip_name: str = "openai/clip-vit-base-patch32",
        rm_name: str = "ImageReward-v1.0",
        device_idx: int = 0,
    ):
        self.device = f"cuda:{device_idx}"

        self.image_processor = CLIPImageProcessor.from_pretrained(clip_name)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_name).to(self.device).eval()

        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(self.device).eval()

        # Pass the target device explicitly; otherwise ImageReward loads onto
        # its own default device, which may not match `device_idx`.
        self.rm_model = RM.load(rm_name, device=self.device)

    @torch.no_grad()
    def evaluate(self,
        image: Union[PILImage, List[PILImage]],
        text: Union[str, List[str]],
    ) -> Tuple[float, float, float]:
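        """Score generated images against their text prompts.

        Returns (clip_similarity, clip_r_precision, image_reward), each
        averaged over the batch.
        """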
        if isinstance(image, PILImage):
            image = [image]
        if isinstance(text, str):
            text = [text]

        assert len(image) == len(text), "image and text batches must have the same length"

        image_inputs = self.image_processor(image, return_tensors="pt").pixel_values.to(self.device)
        image_embeds = self.image_encoder(image_inputs).image_embeds.float()  # (N, D)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        text_inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        text_embeds = self.text_encoder(text_input_ids).text_embeds.float()  # (N, D)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        assert image_embeds.shape == text_embeds.shape

        clip_scores = image_embeds @ text_embeds.T  # (N, N)

        # 1. CLIP similarity
        clip_sim = clip_scores.diag().mean().item()

        # 2. CLIP R-Precision: fraction of images whose highest-scoring text
        # within the batch is their own prompt.
        clip_rprec = (clip_scores.argmax(dim=1) == torch.arange(len(text)).to(self.device)).float().mean().item()

        # 3. ImageReward
        rm_scores = []
        for img, txt in zip(image, text):
            rm_scores.append(self.rm_model.score(txt, img))
        rm_scores = torch.tensor(rm_scores, device=self.device)
        rm_score = rm_scores.mean().item()

        return clip_sim, clip_rprec, rm_score


class ImageConditionMetrics:
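    """Image-conditioned (reconstruction) metrics: PSNR, SSIM, and LPIPS."""
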
    def __init__(self,
        lpips_net: str = "vgg",
        lpips_res: int = 256,
        device_idx: int = 0,
    ):
        self.lpips_loss = LPIPS(net=lpips_net).to(f"cuda:{device_idx}").eval()

        self.lpips_res = lpips_res
        self.device = f"cuda:{device_idx}"

    @torch.no_grad()
    def evaluate(self,
        image: Union[Tensor, PILImage, List[PILImage]],
        gt: Union[Tensor, PILImage, List[PILImage]],
        chunk_size: Optional[int] = None,
        input_tensor: bool = False,
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
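        """Compare predicted images against ground-truth images.

        When `input_tensor` is True, `image` and `gt` are (N, 3, H, W)
        tensors with values in [0, 1]. Returns ((psnr_mean, psnr_std),
        (ssim_mean, ssim_std), (lpips_mean, lpips_std)) over the batch.
        """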
        if not input_tensor:
            if isinstance(image, PILImage):
                image = [image]
            if isinstance(gt, PILImage):
                gt = [gt]

            assert len(image) == len(gt), "image and gt batches must have the same length"

            def image_to_tensor(img: PILImage):
                # (1, 3, H, W), values in [0, 1]
                return torch.tensor(np.array(img).transpose(2, 0, 1) / 255., device=self.device).unsqueeze(0).float()
            image_pt = torch.cat([image_to_tensor(img) for img in image], dim=0)
            gt_pt = torch.cat([image_to_tensor(img) for img in gt], dim=0)
        else:
            image_pt = image.to(device=self.device)
            gt_pt = gt.to(device=self.device)

        # Resolve the chunk size for both input paths, so tensor inputs with
        # chunk_size=None do not break the LPIPS loop below.
        if chunk_size is None:
            chunk_size = len(image_pt)

        # 1. LPIPS: resize to a fixed resolution and rescale from [0, 1] to
        # [-1, 1], as the LPIPS network expects; chunking bounds peak memory.
        lpips = []
        for i in range(0, len(image_pt), chunk_size):
            _lpips = self.lpips_loss(
                F.interpolate(
                    image_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                ),
                F.interpolate(
                    gt_pt[i:i+chunk_size] * 2. - 1.,
                    (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False
                )
            )
            lpips.append(_lpips)
        lpips = torch.cat(lpips)
        lpips_mean, lpips_std = lpips.mean().item(), lpips.std().item()

        # 2. PSNR, assuming pixel values in [0, 1] (so the peak value is 1).
        psnr = -10. * torch.log10((gt_pt - image_pt).pow(2).mean(dim=[1, 2, 3]))
        psnr_mean, psnr_std = psnr.mean().item(), psnr.std().item()

        # 3. SSIM, computed per image on uint8 arrays (data_range inferred as
        # 255); channel_axis=0 because tensors are laid out as (C, H, W).
        ssim = []
        for i in range(len(image_pt)):
            _ssim = calculate_ssim(
                (image_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                (gt_pt[i].cpu().float().numpy() * 255.).astype(np.uint8),
                channel_axis=0,
            )
            ssim.append(_ssim)
        ssim = np.array(ssim)
        ssim_mean, ssim_std = ssim.mean(), ssim.std()

        return (psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std)
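

# Example usage (a minimal sketch; the file names and prompt below are
# hypothetical placeholders, not part of this module):
#
#   from PIL import Image
#
#   text_metrics = TextConditionMetrics(device_idx=0)
#   images = [Image.open("sample_0.png").convert("RGB")]
#   prompts = ["a photo of a red chair"]
#   clip_sim, clip_rprec, rm_score = text_metrics.evaluate(images, prompts)
#
#   image_metrics = ImageConditionMetrics(device_idx=0)
#   preds = [Image.open("pred_0.png").convert("RGB")]
#   gts = [Image.open("gt_0.png").convert("RGB")]
#   (psnr, psnr_std), (ssim, ssim_std), (lpips, lpips_std) = image_metrics.evaluate(preds, gts)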