Spaces:
Running
Running
import os | |
from typing import Dict | |
import huggingface_hub | |
import torch | |
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer | |
from hpsv2.utils import hps_version_map, root_path | |
from PIL import Image | |
class HPSMetric:
    """Human Preference Score (HPS v2.x) for a single image/prompt pair.

    Wraps an open_clip ``ViT-H-14`` model whose weights are replaced by an
    HPSv2 checkpoint downloaded from the ``xswu/HPSv2`` Hugging Face repo.
    The model, its validation transform and the tokenizer are built once,
    from ``__init__``.
    """

    def __init__(self, hps_version: str = "v2.1"):
        """Select a device and load the HPS model.

        Args:
            hps_version: Key into ``hpsv2.utils.hps_version_map`` choosing
                which checkpoint file to download. Defaults to ``"v2.1"``.
        """
        self.hps_version = hps_version
        # Prefer CUDA, then Apple MPS, and fall back to CPU.
        if torch.cuda.is_available():
            device_name = "cuda"
        elif torch.backends.mps.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)
        # Populated by _initialize_model with "model" and "preprocess_val".
        self.model_dict = {}
        self._initialize_model()

    def _initialize_model(self):
        """Build the model, load the HPS checkpoint and create the tokenizer.

        Does nothing if ``self.model_dict`` is already populated.
        """
        if self.model_dict:
            return

        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision="amp",
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.model_dict["model"] = model
        self.model_dict["preprocess_val"] = preprocess_val

        # exist_ok: no failure if the directory exists or is created
        # concurrently by another process.
        os.makedirs(root_path, exist_ok=True)
        # NOTE(review): no cache_dir/local_dir is passed, so huggingface_hub
        # stores the file in its own default cache rather than under
        # root_path — confirm whether root_path is still needed here.
        checkpoint_path = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", hps_version_map[self.hps_version]
        )
        # NOTE(review): torch.load unpickles the file. It comes from a fixed
        # Hugging Face repo, but consider weights_only=True if the checkpoint
        # format allows it.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["state_dict"])
        self.tokenizer = get_tokenizer("ViT-H-14")
        model = model.to(self.device)
        model.eval()

    def name(self) -> str:
        """Return the key under which this metric reports its score."""
        return "hps"

    def compute_score(
        self,
        image: "Image.Image",
        prompt: str,
    ) -> Dict[str, float]:
        """Score how well ``image`` matches ``prompt``.

        Args:
            image: Image to score; it is passed through the model's
                validation transform.
            prompt: Text the image is compared against.

        Returns:
            ``{"hps": score}``, where score is the dot product of the image
            and text feature vectors returned by the model.
        """
        model = self.model_dict["model"]
        preprocess_val = self.model_dict["preprocess_val"]
        # Mixed precision only applies on CUDA. With enabled=False this is a
        # silent no-op on MPS/CPU, unlike the deprecated
        # torch.cuda.amp.autocast(), which warned on every call there.
        use_cuda_amp = self.device.type == "cuda"
        with torch.no_grad():
            # Preprocess the image and add a batch dimension of 1.
            image_tensor = (
                preprocess_val(image)
                .unsqueeze(0)
                .to(device=self.device, non_blocking=True)
            )
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
            with torch.autocast(device_type="cuda", enabled=use_cuda_amp):
                outputs = model(image_tensor, text)
                image_features = outputs["image_features"]
                text_features = outputs["text_features"]
                # NOTE(review): if the model returns L2-normalized features,
                # this is a cosine similarity — not verified here.
                logits_per_image = image_features @ text_features.T
                hps_score = torch.diagonal(logits_per_image)[0]
        return {"hps": float(hps_score)}