import torch
from PIL import Image
from typing import List, Union
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode
from safetensors.torch import load_file

from .BLIP.blip_pretrain import BLIP_Pretrain
from .config import MODEL_PATHS

BICUBIC = InterpolationMode.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
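    """CLIP-style preprocessing: bicubic resize, center crop, RGB conversion, tensor, normalize."""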
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
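        # OpenAI CLIP's RGB normalization statistics (per-channel mean / std)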
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

class MLP(torch.nn.Module):
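    """MLP head that maps a feature vector to a scalar reward score."""
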
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 1024),
            #nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(1024, 128),
            #nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            #nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, 16),
            #nn.ReLU(),
            torch.nn.Linear(16, 1)
        )
        
        # Initialize MLP parameters
        for name, param in self.layers.named_parameters():
            if 'weight' in name:
                torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
            if 'bias' in name:
                torch.nn.init.constant_(param, val=0)
        
    def forward(self, input):
        return self.layers(input)

class ImageReward(torch.nn.Module):
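    """BLIP backbone plus MLP head; scores how well an image matches a prompt."""
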
    def __init__(self, med_config, device='cpu', bert_model_path=""):
        super().__init__()
        self.device = device
        
        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
        self.preprocess = _transform(224)
        self.mlp = MLP(768)
        
        # Fixed statistics used to standardize the raw MLP output into the final score
        self.mean = 0.16717362830052426
        self.std = 1.0333394966054072

    def score_grad(self, prompt_ids, prompt_attention_mask, image):
        """Calculate the score with gradient for a single image and prompt.

        Args:
            prompt_ids (torch.Tensor): Tokenized prompt IDs.
            prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
            image (torch.Tensor): The processed image tensor.

        Returns:
            torch.Tensor: The reward score.
        """
        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(
            prompt_ids,
            attention_mask=prompt_attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        txt_features = text_output.last_hidden_state[:, 0, :]
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std
        return rewards
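
    # Hedged sketch (not part of the original API surface): using score_grad in a
    # reward-guided fine-tuning step. Assumes `model` is an ImageReward instance and
    # `gen_images` is a differentiable batch already resized/normalized to 224x224.
    #
    #   tok = model.blip.tokenizer(prompt, padding='max_length', truncation=True,
    #                              max_length=35, return_tensors="pt").to(model.device)
    #   reward = model.score_grad(tok.input_ids, tok.attention_mask, gen_images)
    #   loss = -reward.mean()  # gradient ascent on the reward
    #   loss.backward()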

    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
        """Score the images based on the prompt.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.

        Returns:
            List[float]: List of scores for the images.
        """
        if isinstance(images, (str, Image.Image)):
            # Single image
            if isinstance(images, str):
                pil_image = Image.open(images)
            else:
                pil_image = images
            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            return [self._calculate_score(prompt, image).item()]
        elif isinstance(images, list):
            # Multiple images
            scores = []
            for one_image in images:
                if isinstance(one_image, str):
                    pil_image = Image.open(one_image)
                elif isinstance(one_image, Image.Image):
                    pil_image = one_image
                else:
                    raise TypeError("Each element of images must be a file path (str) or a PIL.Image.")
                image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
                scores.append(self._calculate_score(prompt, image).item())
            return scores
        else:
            raise TypeError("images must be a str path, a PIL.Image, or a list of these.")

    def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
        """Calculate the score for a single image and prompt.

        Args:
            prompt (str): The prompt text.
            image (torch.Tensor): The processed image tensor.

        Returns:
            torch.Tensor: The reward score.
        """
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
        image_embeds = self.blip.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(
            text_input.input_ids,
            attention_mask=text_input.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        txt_features = text_output.last_hidden_state[:, 0, :].float()
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std
        return rewards

    def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
        """Rank the images based on the prompt.

        Args:
            prompt (str): The prompt text.
            generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.

        Returns:
            tuple: (indices, rewards) where indices[i] is the 1-based rank of the i-th image (1 = highest score) and rewards are the standardized scores.
        """
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
        txt_set = []
        for generation in generations_list:
            if isinstance(generation, str):
                pil_image = Image.open(generation)
            elif isinstance(generation, Image.Image):
                pil_image = generation
            else:
                raise TypeError("Each element of generations_list must be a file path (str) or a PIL.Image.")
            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            image_embeds = self.blip.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
            text_output = self.blip.text_encoder(
                text_input.input_ids,
                attention_mask=text_input.attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            txt_set.append(text_output.last_hidden_state[:, 0, :])
        txt_features = torch.cat(txt_set, 0).float()
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std
        rewards = torch.squeeze(rewards)
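        # Double argsort: the first sort orders images by descending reward; the second
        # recovers, for each image, its position in that ordering (its 0-based rank).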
        _, rank = torch.sort(rewards, dim=0, descending=True)
        _, indices = torch.sort(rank, dim=0)
        indices = indices + 1
        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()


class ImageRewardScore(torch.nn.Module):
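    """Loader/wrapper: builds ImageReward from MODEL_PATHS and exposes a no-grad score()."""
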
    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
        super().__init__()
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        model_path = path.get("imagereward")
        med_config = path.get("med_config")
        state_dict = load_file(model_path)
        self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
        self.model.load_state_dict(state_dict, strict=False)  # strict=False tolerates key mismatches with the checkpoint
        self.model.eval()

    @torch.no_grad()
    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
        """Score the images based on the prompt.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): The prompt text.

        Returns:
            List[float]: List of scores for the images.
        """
        return self.model.score(images, prompt)
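

# Hedged usage sketch (illustrative only; "sample.png" is a placeholder, and this
# module's relative imports mean it should be imported from its parent package):
#
#   scorer = ImageRewardScore(device="cuda" if torch.cuda.is_available() else "cpu")
#   scores = scorer.score("sample.png", prompt="a photo of an astronaut riding a horse")
#   print(scores)  # one standardized reward per image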