from pathlib import Path
from typing import Iterator, List, Tuple

from datasets import load_dataset


class PartiPrompts:
    """Fixed, reproducible subset of the ``nateraw/parti-prompts`` dataset.

    The ``train`` split is shuffled with a fixed seed and the first
    ``num_prompts`` rows are kept; only the ``"Prompt"`` column is retained.
    Iterating the object yields ``(prompt, file_name)`` pairs.
    """

    def __init__(self, num_prompts: int = 800, seed: int = 42):
        """Load the dataset and materialize the selected prompts.

        Args:
            num_prompts: Number of rows to keep after shuffling. Defaults to
                800, the value that was previously hard-coded.
            seed: Seed passed to ``Dataset.shuffle``. Defaults to 42, the
                value that was previously hard-coded.
        """
        dataset = load_dataset("nateraw/parti-prompts")["train"]
        # Shuffle before selecting so the subset is a seeded sample of the
        # whole split rather than its first rows.
        selected_dataset = dataset.shuffle(seed=seed).select(range(num_prompts))
        self.prompts = [row["Prompt"] for row in selected_dataset]

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield ``(prompt, Path("<index>.png"))`` for each stored prompt.

        The index is the prompt's position in ``self.prompts``.
        NOTE(review): the path is presumably the output file name for the
        image generated from the prompt; confirm against the caller.
        """
        for i, prompt in enumerate(self.prompts):
            yield prompt, Path(f"{i}.png")

    @property
    def name(self) -> str:
        """Short identifier of this prompt set."""
        return "parti"

    @property
    def size(self) -> int:
        """Number of prompts held by this instance."""
        return len(self.prompts)

    @property
    def metrics(self) -> List[str]:
        """Names of the metrics associated with this prompt set.

        A new list is returned on every access, so callers may mutate it.
        """
        return ["arniqa", "clip", "clip_iqa", "sharpness"]