davidberenstein1957's picture
refactor: improve code formatting and organization across multiple API and benchmark files
34046e2
raw
history blame contribute delete
792 Bytes
from pathlib import Path
from typing import Iterator, List, Tuple
from datasets import load_dataset
class PartiPrompts:
    """Reproducibly shuffled subset of the ``nateraw/parti-prompts`` dataset.

    Iterating over an instance yields ``(prompt, path)`` pairs, where
    ``path`` is a relative ``Path`` named after the prompt's position in
    the sample (``0.png``, ``1.png``, ...).
    """

    def __init__(self, *, num_prompts: int = 800, seed: int = 42) -> None:
        """Load the ``train`` split, shuffle it and keep the first rows.

        Args:
            num_prompts: Number of prompts to keep after shuffling. The
                default (800) matches the previously hard-coded sample size.
            seed: Seed passed to ``Dataset.shuffle`` so the selected sample
                is the same on every run. The default (42) matches the
                previously hard-coded seed.
        """
        train_split = load_dataset("nateraw/parti-prompts")["train"]
        # Shuffle before selecting so the subset is a seeded sample of the
        # split rather than simply its first `num_prompts` rows.
        sample = train_split.shuffle(seed=seed).select(range(num_prompts))
        self.prompts = [row["Prompt"] for row in sample]

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield each prompt together with an index-based ``.png`` path."""
        for index, prompt in enumerate(self.prompts):
            yield prompt, Path(f"{index}.png")

    @property
    def name(self) -> str:
        """Short identifier of this prompt set."""
        return "parti"

    @property
    def size(self) -> int:
        """Number of prompts held by this instance."""
        return len(self.prompts)

    @property
    def metrics(self) -> List[str]:
        """Names of the metrics associated with this prompt set.

        A new list is returned on every access.
        """
        return ["arniqa", "clip", "clip_iqa", "sharpness"]