import os
from typing import Dict
import huggingface_hub
import torch
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from hpsv2.utils import hps_version_map, root_path
from PIL import Image
class HPSMetric:
    """Human Preference Score (HPS) v2.1 metric for text-to-image evaluation.

    Wraps the HPSv2 ViT-H-14 CLIP model: the image and prompt are embedded,
    and the diagonal of the image/text logit matrix is reported as the
    preference score (higher = stronger predicted human preference).
    """

    def __init__(self):
        # Version key into hps_version_map; selects which checkpoint to download.
        self.hps_version = "v2.1"
        # Device preference order: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        # Holds the loaded model and its eval-time preprocessing transform.
        self.model_dict = {}
        self._initialize_model()

    def _initialize_model(self):
        """Build ViT-H-14, download the HPSv2 checkpoint, and load its weights.

        Populates ``self.model_dict`` with ``"model"`` and ``"preprocess_val"``
        and sets ``self.tokenizer``. No-op if the model was already loaded.
        """
        if self.model_dict:
            return
        # The train transform is unused: scoring only needs the eval transform.
        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision="amp",
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.model_dict["model"] = model
        self.model_dict["preprocess_val"] = preprocess_val
        # exist_ok=True avoids the check-then-create race of a separate
        # os.path.exists() test; hf_hub_download caches across runs.
        os.makedirs(root_path, exist_ok=True)
        cp = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", hps_version_map[self.hps_version]
        )
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects. The checkpoint comes from a fixed HF repo, but consider
        # weights_only=True if the installed torch version supports it.
        checkpoint = torch.load(cp, map_location=self.device)
        model.load_state_dict(checkpoint["state_dict"])
        self.tokenizer = get_tokenizer("ViT-H-14")
        model = model.to(self.device)
        model.eval()

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "hps"

    def _forward_logits(self, model, image_tensor, text):
        """Run the model and return the image/text logit matrix.

        Kept as a helper so compute_score can run it with or without autocast
        while keeping all feature math under the same precision context.
        """
        outputs = model(image_tensor, text)
        image_features = outputs["image_features"]
        text_features = outputs["text_features"]
        return image_features @ text_features.T

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Score how well ``image`` matches ``prompt`` under HPSv2.

        Args:
            image: PIL image to evaluate.
            prompt: Text prompt the image should correspond to.

        Returns:
            ``{"hps": score}`` with the scalar preference score.
        """
        model = self.model_dict["model"]
        preprocess_val = self.model_dict["preprocess_val"]
        with torch.no_grad():
            # Preprocess the image into a single-item batch on the device.
            image_tensor = (
                preprocess_val(image)
                .unsqueeze(0)
                .to(device=self.device, non_blocking=True)
            )
            # Tokenize the prompt.
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
            # torch.cuda.amp.autocast is deprecated and merely warns + disables
            # itself on the MPS/CPU fallbacks selected in __init__, so only
            # enable mixed precision when actually running on CUDA.
            if self.device.type == "cuda":
                with torch.autocast(device_type="cuda"):
                    logits_per_image = self._forward_logits(model, image_tensor, text)
            else:
                logits_per_image = self._forward_logits(model, image_tensor, text)
            # Diagonal pairs image i with prompt i (single pair here).
            hps_score = torch.diagonal(logits_per_image).cpu().numpy()
        return {"hps": float(hps_score[0])}