from pathlib import Path
from typing import Iterator, List, Tuple

from datasets import load_dataset


class DrawBenchPrompts:
    """Prompt set backed by the ``shunk031/DrawBench`` dataset.

    Iterating yields ``(prompt, output_path)`` pairs, where ``output_path`` is
    a relative file name of the form ``"<row index>.png"``.
    """

    def __init__(self):
        # Loads the dataset via ``datasets.load_dataset`` and keeps only its
        # "test" split.
        self.dataset = load_dataset("shunk031/DrawBench")["test"]

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield ``(prompt, Path("<index>.png"))`` for every row, in dataset order."""
        for i, row in enumerate(self.dataset):
            # The dataset's column is named "prompts" (plural); per the return
            # annotation each value is a single prompt string.
            yield row["prompts"], Path(f"{i}.png")

    def __len__(self) -> int:
        """Return the number of prompts, so ``len(prompts)`` works (same as ``size``)."""
        return len(self.dataset)

    @property
    def name(self) -> str:
        """Short identifier for this prompt set."""
        return "draw_bench"

    @property
    def size(self) -> int:
        """Number of prompts (rows) in the loaded split."""
        return len(self.dataset)

    @property
    def metrics(self) -> List[str]:
        """Names of the metrics associated with this prompt set."""
        return ["image_reward"]