File size: 5,452 Bytes
4bf9661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from typing import List, Optional
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModel
from safetensors.torch import load_file
import os
from typing import Union, List
from .config import MODEL_PATHS

class MLP(torch.nn.Module):
    """Regression head: a stack of linear layers with dropout, ending in one scalar.

    No activation function is applied between the linear layers, so in eval
    mode (dropout disabled) the network is a pure composition of affine maps.
    The position of every module inside ``self.layers`` determines the
    state-dict keys (``layers.0.*``, ``layers.2.*``, ...), so the order and
    count of modules must not change if saved weights are to keep loading.

    Args:
        input_size: Width of the input feature vector.
        xcol: Key under which a batch dict stores the input features.
        ycol: Key under which a batch dict stores the regression targets.
    """

    def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol

        nn = torch.nn
        # Linear modules sit at indices 0, 2, 4, 6 and 7; dropout at 1, 3, 5.
        stack = [
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

    def _batch_mse(self, batch: dict) -> torch.Tensor:
        """Return the MSE between the stack's output and the batch targets."""
        features = batch[self.xcol]
        # Targets are reshaped to a column so they line up with the (N, 1) output.
        targets = batch[self.ycol].reshape(-1, 1)
        predictions = self.layers(features)
        return torch.nn.functional.mse_loss(predictions, targets)

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """Compute the mean-squared-error loss for one training batch.

        ``batch_idx`` is accepted but not used.
        """
        return self._batch_mse(batch)

    def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """Compute the mean-squared-error loss for one validation batch.

        ``batch_idx`` is accepted but not used.
        """
        return self._batch_mse(batch)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters with learning rate 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class AestheticScore(torch.nn.Module):
    """Score images by passing normalized image embeddings through an MLP head.

    Two models are held:

    * ``self.model`` -- an ``MLP`` with input size 768 whose weights are read
      from the file at ``path["aesthetic_predictor"]``.
    * ``self.model2`` / ``self.processor`` -- the CLIP model and its processor,
      loaded with ``from_pretrained`` from ``path["clip-large"]``.

    Args:
        device: Device that both models and the preprocessed inputs are moved to.
        path: Mapping with the keys ``"aesthetic_predictor"`` and
            ``"clip-large"``. Defaults to the package-level ``MODEL_PATHS``.

    Raises:
        ValueError: If the MLP weights cannot be loaded.
    """

    def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
        super().__init__()
        self.device = device
        self.aes_model_path = path.get("aesthetic_predictor")
        # Load the MLP model
        # NOTE(review): 768 presumably matches the embedding width of the CLIP
        # checkpoint configured under "clip-large" -- confirm if that changes.
        self.model = MLP(768)
        try:
            if self.aes_model_path.endswith(".safetensors"):
                state_dict = load_file(self.aes_model_path)
            else:
                # Load onto CPU first so a checkpoint saved from a CUDA device
                # still loads on hosts without that device; the module is
                # moved to `device` below.
                # NOTE: torch.load unpickles the file -- only point this at
                # checkpoints from a trusted source.
                state_dict = torch.load(self.aes_model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        except Exception as e:
            raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}") from e

        self.model.to(device)
        self.model.eval()

        # Load the CLIP model and processor
        clip_model_name = path.get('clip-large')
        self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
        self.processor = AutoProcessor.from_pretrained(clip_model_name)

    def _calculate_score(self, image: torch.Tensor) -> float:
        """Calculate the aesthetic score for a single image.

        Args:
            image (torch.Tensor): The processed image tensor. It must describe
                exactly one image, because the result is extracted with
                ``.item()``, which requires a one-element tensor.

        Returns:
            float: The aesthetic score.
        """
        with torch.no_grad():
            # Get image embeddings and scale each one to unit L2 norm
            image_embs = self.model2.get_image_features(image)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

            # Compute score
            score = self.model(image_embs).cpu().flatten().item()

        return score

    def _score_pil(self, pil_image: Image.Image) -> float:
        """Preprocess one PIL image and return its aesthetic score."""
        # NOTE(review): padding / truncation / max_length look like tokenizer
        # options and probably have no effect on image-only input; they are
        # kept unchanged -- confirm before removing.
        image_inputs = self.processor(
            images=pil_image,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.device)
        return self._calculate_score(image_inputs["pixel_values"])

    def _score_one(self, image: Union[str, Image.Image]) -> float:
        """Score one image given either as a file path or as a PIL image.

        Raises:
            TypeError: If ``image`` is neither a ``str`` nor a PIL image.
        """
        if isinstance(image, str):
            # The context manager closes the underlying file once the image
            # has been preprocessed, so scoring many paths does not leak
            # file handles.
            with Image.open(image) as opened:
                return self._score_pil(opened)
        if isinstance(image, Image.Image):
            # Caller-owned image: use it as-is and leave it open.
            return self._score_pil(image)
        raise TypeError("The type of parameter images is illegal.")

    @torch.no_grad()
    def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
        """Score the images based on their aesthetic quality.

        Args:
            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
            prompt (str): Not used by this scorer; accepted and ignored.

        Returns:
            List[float]: List of scores for the images, in input order. A
            single image yields a one-element list; an empty list yields ``[]``.

        Raises:
            RuntimeError: On any failure, including an unsupported type for
                ``images`` or one of its elements. The original exception is
                attached as ``__cause__``.
        """
        try:
            if isinstance(images, (str, Image.Image)):
                # Single image
                return [self._score_one(images)]
            if isinstance(images, list):
                # Multiple images, scored one at a time in order
                return [self._score_one(one_image) for one_image in images]
            raise TypeError("The type of parameter images is illegal.")
        except Exception as e:
            raise RuntimeError(f"Error in scoring images: {e}") from e