import itertools
import json
import os
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import huggingface_hub


class HPSPrompts:
    """HPS benchmark prompts taken from the ``zhwang/HPDv2`` Hugging Face dataset.

    Constructing an instance downloads the per-category prompt files from the
    dataset's ``benchmark`` subfolder via ``huggingface_hub.hf_hub_download``,
    links them into ``downloads/hps`` (relative to the current working
    directory), and keeps up to ``max_prompts_per_category`` prompts from each
    category. Each prompt is keyed by the image filename it belongs to, e.g.
    ``"anime_000.png"``.

    Args:
        max_prompts_per_category: Maximum number of prompts kept from each
            category file (the first ones in file order). Defaults to 100,
            the previously hard-coded limit. ``None`` keeps every prompt.
            Must be ``None`` or non-negative (``itertools.islice`` raises
            ``ValueError`` for negative values).
    """

    # Local folder (relative to the current working directory) that holds
    # symlinks to the files downloaded into the Hugging Face cache.
    _DOWNLOAD_DIR = Path("downloads/hps")

    def __init__(self, max_prompts_per_category: Optional[int] = 100):
        super().__init__()
        self.hps_prompt_files = [
            "anime.json",
            "concept-art.json",
            "paintings.json",
            "photo.json",
        ]
        self._download_benchmark_prompts()

        # Maps output image filename -> prompt; filled category by category,
        # in file order within each category.
        self.prompts: Dict[str, str] = {}
        self._size = 0
        for file in self.hps_prompt_files:
            category = file.replace(".json", "")
            # JSON interchange is UTF-8; do not depend on the platform locale.
            with open(self._DOWNLOAD_DIR / file, "r", encoding="utf-8") as f:
                prompts = json.load(f)
            # islice stops after the cap; a None cap yields every entry.
            for i, prompt in enumerate(
                itertools.islice(prompts, max_prompts_per_category)
            ):
                filename = f"{category}_{i:03d}.png"
                self.prompts[filename] = prompt
                self._size += 1

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield ``(prompt, Path(filename))`` pairs in insertion order."""
        for filename, prompt in self.prompts.items():
            yield prompt, Path(filename)

    @property
    def name(self) -> str:
        """Identifier of this prompt set: ``"hps"``."""
        return "hps"

    @property
    def size(self) -> int:
        """Number of prompts loaded across all categories."""
        return self._size

    def _download_benchmark_prompts(self) -> None:
        """Fetch each prompt file and symlink it into the local download folder.

        ``hf_hub_download`` returns the path of the file in the local Hugging
        Face cache. A link to it is created at ``downloads/hps/<file>`` unless
        a working file or link is already there.
        """
        folder = self._DOWNLOAD_DIR
        folder.mkdir(parents=True, exist_ok=True)
        for file in self.hps_prompt_files:
            cached_path = huggingface_hub.hf_hub_download(
                "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
            )
            link = folder / file
            # A dangling symlink (e.g. the HF cache was cleared or moved)
            # reports exists() == False while the link entry itself is still
            # present, so os.symlink would raise FileExistsError. Remove it so
            # it can be recreated pointing at the fresh cache path.
            if link.is_symlink() and not link.exists():
                link.unlink()
            if not link.exists():
                os.symlink(cached_path, link)

    @property
    def metrics(self) -> List[str]:
        """Names of the metrics evaluated on this prompt set."""
        return ["hps"]