InferBench / benchmark / draw_bench.py
davidberenstein1957's picture
fix: correct dataset loading in DrawBenchPrompts to ensure proper retrieval of test data
5cac937
raw
history blame contribute delete
598 Bytes
from pathlib import Path
from typing import Iterator, List, Tuple
from datasets import load_dataset
class DrawBenchPrompts:
    """Prompt source for the DrawBench text-to-image benchmark.

    Wraps the ``test`` split of the Hugging Face dataset ``shunk031/DrawBench``
    and iterates over it as ``(prompt, filename)`` pairs. Each filename is the
    row's zero-based position in the split followed by ``.png``.

    Attributes:
        dataset: The loaded ``test`` split. Rows are read by the ``"prompts"``
            key; any other columns are ignored by this class.
    """

    def __init__(self):
        # Called without ``split=``, load_dataset returns a mapping of splits;
        # only the "test" split is kept.
        splits = load_dataset("shunk031/DrawBench")
        self.dataset = splits["test"]

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield ``(prompt, Path("<index>.png"))`` for each row, in dataset order.

        The index is zero-based, so the first prompt maps to ``0.png``.
        """
        for position, record in enumerate(self.dataset):
            text = record["prompts"]
            target = Path(f"{position}.png")
            yield text, target

    @property
    def name(self) -> str:
        """Identifier of this benchmark: ``"draw_bench"``."""
        return "draw_bench"

    @property
    def size(self) -> int:
        """Number of rows in the loaded ``test`` split."""
        return len(self.dataset)

    @property
    def metrics(self) -> List[str]:
        """Metric names associated with this benchmark.

        A new list is built on every access, so mutating the result does not
        affect later reads.
        """
        return ["image_reward"]