File size: 3,050 Bytes
2c50826
 
 
199a7d9
2c50826
 
199a7d9
 
2c50826
 
 
 
 
199a7d9
 
 
 
 
 
 
2c50826
 
199a7d9
2c50826
 
 
199a7d9
 
 
2c50826
 
 
 
 
 
 
 
 
 
 
 
 
199a7d9
2c50826
199a7d9
 
 
2c50826
 
 
199a7d9
 
 
 
2c50826
199a7d9
 
2c50826
 
199a7d9
2c50826
 
 
199a7d9
2c50826
 
 
 
 
199a7d9
 
 
2c50826
 
199a7d9
 
 
 
 
2c50826
 
 
 
 
199a7d9
 
 
 
2c50826
 
199a7d9
2c50826
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import contextlib
import os
from typing import Dict

import huggingface_hub
import torch
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from hpsv2.utils import hps_version_map, root_path
from PIL import Image


class HPSMetric:
    """Human Preference Score (HPSv2) metric for a single image/prompt pair.

    Builds an open_clip ``ViT-H-14`` model, loads the HPSv2 checkpoint
    downloaded from the ``xswu/HPSv2`` Hugging Face repo, and scores an image
    against a text prompt as the dot product of the model's image features
    and text features.
    """

    def __init__(self, hps_version: str = "v2.1"):
        """Pick a compute device and eagerly load the model.

        Args:
            hps_version: Key into ``hpsv2.utils.hps_version_map`` selecting
                which checkpoint file to download. Defaults to ``"v2.1"``.

        Raises:
            ValueError: If ``hps_version`` is not a key of ``hps_version_map``.
        """
        # Validate before the (slow) model construction and download.
        if hps_version not in hps_version_map:
            raise ValueError(
                f"Unknown HPS version {hps_version!r}; "
                f"expected one of {sorted(hps_version_map)}"
            )
        self.hps_version = hps_version
        # Prefer CUDA, then Apple MPS, then fall back to CPU.
        if torch.cuda.is_available():
            device_name = "cuda"
        elif torch.backends.mps.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)
        self.model_dict = {}
        self._initialize_model()

    def _initialize_model(self):
        """Create the CLIP model, load the HPS weights and the tokenizer.

        Populates ``self.model_dict`` with ``"model"`` and ``"preprocess_val"``
        and sets ``self.tokenizer``. Does nothing if ``model_dict`` is already
        populated.
        """
        if self.model_dict:
            return

        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision="amp",
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.model_dict["model"] = model
        self.model_dict["preprocess_val"] = preprocess_val

        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(root_path, exist_ok=True)
        cp = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", hps_version_map[self.hps_version]
        )

        # NOTE(review): torch.load unpickles the file; this is acceptable only
        # because the checkpoint comes from the fixed "xswu/HPSv2" repo.
        checkpoint = torch.load(cp, map_location=self.device)
        model.load_state_dict(checkpoint["state_dict"])
        self.tokenizer = get_tokenizer("ViT-H-14")
        # nn.Module.to moves in place, so model_dict["model"] stays current.
        model.to(self.device)
        model.eval()

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "hps"

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Score how well ``image`` matches ``prompt`` under HPSv2.

        Args:
            image: PIL image to score.
            prompt: Text prompt to score the image against.

        Returns:
            ``{"hps": score}``, where ``score`` is the dot product of the
            model's image features and text features for this pair.
        """
        model = self.model_dict["model"]
        preprocess_val = self.model_dict["preprocess_val"]

        # Mixed precision only applies on CUDA. Entering CUDA autocast on
        # MPS/CPU just warns and disables itself, so skip it explicitly there.
        autocast_ctx = (
            torch.autocast(device_type="cuda")
            if self.device.type == "cuda"
            else contextlib.nullcontext()
        )

        with torch.no_grad():
            # Batch of one image.
            image_tensor = (
                preprocess_val(image)
                .unsqueeze(0)
                .to(device=self.device, non_blocking=True)
            )
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
            with autocast_ctx:
                outputs = model(image_tensor, text)
                image_features = outputs["image_features"]
                text_features = outputs["text_features"]
                logits_per_image = image_features @ text_features.T
                # Diagonal pairs image i with prompt i; with one pair, take [0].
                hps_score = torch.diagonal(logits_per_image)[0].item()

        return {"hps": float(hps_score)}