import contextlib
import os
from typing import Dict

import huggingface_hub
import torch
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from hpsv2.utils import hps_version_map, root_path
from PIL import Image


class HPSMetric:
    """Human Preference Score (HPS) metric for a single image/prompt pair.

    Builds an open_clip ``ViT-H-14`` model, loads HPS weights downloaded from
    the ``xswu/HPSv2`` Hugging Face repository, and scores an image against a
    text prompt via :meth:`compute_score`.
    """

    def __init__(self):
        # HPS checkpoint version; mapped to a checkpoint filename via hps_version_map.
        self.hps_version = "v2.1"
        # Prefer CUDA, then Apple MPS, and fall back to CPU.
        self.device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        # Populated by _initialize_model with "model" and "preprocess_val".
        self.model_dict = {}
        self._initialize_model()

    def _initialize_model(self):
        """Create the model, load the HPS checkpoint, and set up the tokenizer.

        Does nothing if ``self.model_dict`` is already populated.
        """
        if self.model_dict:
            return

        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision="amp",
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.model_dict["model"] = model
        self.model_dict["preprocess_val"] = preprocess_val

        # exist_ok avoids the check-then-create race when several processes
        # initialize the metric at the same time.
        os.makedirs(root_path, exist_ok=True)
        cp = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", hps_version_map[self.hps_version]
        )
        checkpoint = torch.load(cp, map_location=self.device)
        model.load_state_dict(checkpoint["state_dict"])
        # The weights now live in the model; release the checkpoint copy.
        del checkpoint

        self.tokenizer = get_tokenizer("ViT-H-14")
        model = model.to(self.device)
        model.eval()

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "hps"

    def _autocast(self):
        """Return the mixed-precision context to use for inference.

        Autocast (float16 by default) is enabled only on CUDA. On MPS and CPU a
        no-op context is returned so inference runs in full precision without
        emitting "CUDA is not available" warnings on every call.
        """
        if self.device.type == "cuda":
            return torch.autocast(device_type="cuda")
        return contextlib.nullcontext()

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Score how well ``image`` matches ``prompt`` under the HPS model.

        Args:
            image: PIL image to score; passed through the model's validation
                preprocessing transform.
            prompt: Text prompt the image is compared against.

        Returns:
            ``{"hps": score}``, where ``score`` is the dot product of the
            model's image and text features for this pair.
        """
        model = self.model_dict["model"]
        preprocess_val = self.model_dict["preprocess_val"]

        with torch.no_grad():
            # Add a batch dimension of 1 and move to the model's device.
            image_tensor = (
                preprocess_val(image)
                .unsqueeze(0)
                .to(device=self.device, non_blocking=True)
            )
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)

            with self._autocast():
                outputs = model(image_tensor, text)
                image_features = outputs["image_features"]
                text_features = outputs["text_features"]
                logits_per_image = image_features @ text_features.T
                # With a batch of one, the diagonal has the single pair's score.
                hps_score = torch.diagonal(logits_per_image)

        return {"hps": float(hps_score[0].item())}