nifleisch committed
Commit 2c50826 · 1 Parent(s): 61029d0

feat: add core logic for project
api/__init__.py ADDED
@@ -0,0 +1,45 @@
+from typing import Type
+
+from api.baseline import BaselineAPI
+from api.fireworks import FireworksAPI
+from api.flux import FluxAPI
+from api.pruna import PrunaAPI
+from api.replicate import ReplicateAPI
+from api.together import TogetherAPI
+from api.fal import FalAPI
+
+def create_api(api_type: str) -> FluxAPI:
+    """
+    Factory function to create API instances.
+
+    Args:
+        api_type (str): The type of API to create. Must be one of:
+            - "baseline"
+            - "fireworks"
+            - "pruna_<speed_mode>" (where <speed_mode> is the desired speed mode)
+            - "replicate"
+            - "together"
+            - "fal"
+
+    Returns:
+        FluxAPI: An instance of the requested API implementation
+
+    Raises:
+        ValueError: If an invalid API type is provided
+    """
+    if api_type.startswith("pruna_"):
+        speed_mode = api_type[6:]  # Remove "pruna_" prefix
+        return PrunaAPI(speed_mode)
+
+    api_map: dict[str, Type[FluxAPI]] = {
+        "baseline": BaselineAPI,
+        "fireworks": FireworksAPI,
+        "replicate": ReplicateAPI,
+        "together": TogetherAPI,
+        "fal": FalAPI,
+    }
+
+    if api_type not in api_map:
+        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")
+
+    return api_map[api_type]()
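
Not part of the commit: a minimal usage sketch of the factory above. The "pruna_juiced" value is a hypothetical speed mode used only to illustrate the "pruna_" prefix convention; everything else comes from the code in this diff.

from pathlib import Path

from api import create_api

# A plain provider name maps directly to one implementation class.
api = create_api("fal")
print(api.name)  # "fal"

# Anything starting with "pruna_" is routed to PrunaAPI, with the rest used as the speed mode.
pruna_api = create_api("pruna_juiced")  # "juiced" is illustrative, not a documented mode

# Every implementation exposes the same call: generate an image, get back the seconds taken.
seconds = api.generate_image("a red bicycle on a beach", Path("out/test.png"))
print(f"inference took {seconds:.2f}s")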
api/baseline.py ADDED
@@ -0,0 +1,53 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class BaselineAPI(FluxAPI):
+    """
+    As our baseline, we use the Replicate API with go_fast=False.
+    """
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "baseline"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        import replicate
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "prompt": prompt,
+                "go_fast": False,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+                "seed": 0,
+            },
+        )
+        end_time = time.time()
+
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/fal.py ADDED
@@ -0,0 +1,48 @@
+import time
+from io import BytesIO
+from pathlib import Path
+from typing import Any
+
+import fal_client
+import requests
+from PIL import Image
+
+from api.flux import FluxAPI
+
+
+class FalAPI(FluxAPI):
+    @property
+    def name(self) -> str:
+        return "fal"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = fal_client.subscribe(
+            "fal-ai/flux/dev",
+            arguments={
+                "seed": 0,
+                "prompt": prompt,
+                "image_size": "square_hd",  # 1024x1024 image
+                "num_images": 1,
+                "guidance_scale": 3.5,
+                "num_inference_steps": 28,
+                "enable_safety_checker": True,
+            },
+        )
+        end_time = time.time()
+
+        url = result["images"][0]["url"]
+        self._save_image_from_url(url, save_path)
+
+        return end_time - start_time
+
+    def _save_image_from_url(self, url: str, save_path: Path):
+        response = requests.get(url)
+        image = Image.open(BytesIO(response.content))
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        image.save(save_path)
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.content)
api/fireworks.py ADDED
@@ -0,0 +1,53 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import requests
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class FireworksAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("FIREWORKS_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("FIREWORKS_API_TOKEN not found in environment variables")
+        self._url = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image"
+
+    @property
+    def name(self) -> str:
+        return "fireworks_fp8"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "image/jpeg",
+            "Authorization": f"Bearer {self._api_key}",
+        }
+        data = {
+            "prompt": prompt,
+            "aspect_ratio": "1:1",
+            "guidance_scale": 3.5,
+            "num_inference_steps": 28,
+            "seed": 0,
+        }
+        result = requests.post(self._url, headers=headers, json=data)
+
+        end_time = time.time()
+
+        if result.status_code == 200:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception(f"Error: {result.status_code} {result.text}")
+
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.content)
api/flux.py ADDED
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+
+class FluxAPI(ABC):
+    """
+    Abstract base class for Flux API implementations.
+
+    This class defines the common interface for all Flux API implementations.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """
+        The name of the API implementation.
+
+        Returns:
+            str: The name of the specific API implementation
+        """
+        pass
+
+    @abstractmethod
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        """
+        Generate an image based on the prompt and save it to the specified path.
+
+        Args:
+            prompt (str): The text prompt to generate the image from
+            save_path (Path): The path where the generated image should be saved
+
+        Returns:
+            float: The time taken for the API call in seconds
+        """
+        pass
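
For context (not in the commit): this is the contract a new provider would implement. A minimal sketch with a hypothetical DummyAPI that writes a blank image instead of calling a real service; the class and its behavior are assumptions for illustration only.

import time
from pathlib import Path

from PIL import Image

from api.flux import FluxAPI


class DummyAPI(FluxAPI):
    """Hypothetical implementation used only to illustrate the interface."""

    @property
    def name(self) -> str:
        return "dummy"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        start_time = time.time()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # A real implementation would call its provider here; we just save a blank 1024x1024 image.
        Image.new("RGB", (1024, 1024)).save(save_path)
        return time.time() - start_time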
api/pruna.py ADDED
@@ -0,0 +1,51 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+import replicate
+
+from api.flux import FluxAPI
+
+
+class PrunaAPI(FluxAPI):
+    def __init__(self, speed_mode: str):
+        self._speed_mode = speed_mode
+        self._speed_mode_name = speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return f"pruna_{self._speed_mode_name}"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "prunaai/flux.1-juiced:58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "speed_mode": self._speed_mode,
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+
+        if result:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/replicate.py ADDED
@@ -0,0 +1,48 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+import replicate
+
+from api.flux import FluxAPI
+
+
+class ReplicateAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "replicate_go_fast"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "go_fast": True,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/together.py ADDED
@@ -0,0 +1,47 @@
+import base64
+import io
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+from PIL import Image
+from together import Together
+
+from api.flux import FluxAPI
+
+
+class TogetherAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._client = Together()
+
+    @property
+    def name(self) -> str:
+        return "together"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = self._client.images.generate(
+            prompt=prompt,
+            model="black-forest-labs/FLUX.1-dev",
+            width=1024,
+            height=1024,
+            steps=28,
+            n=1,
+            seed=0,
+            response_format="b64_json",
+        )
+        end_time = time.time()
+        if result and hasattr(result, 'data') and len(result.data) > 0:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception("No result returned from Together API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        b64_str = result.data[0].b64_json
+        image_data = base64.b64decode(b64_str)
+        image = Image.open(io.BytesIO(image_data))
+        image.save(save_path)
benchmark/__init__.py ADDED
@@ -0,0 +1,37 @@
+from benchmark.draw_bench import DrawBenchPrompts
+from benchmark.genai_bench import GenAIBenchPrompts
+from benchmark.geneval import GenEvalPrompts
+from benchmark.hps import HPSPrompts
+from benchmark.parti import PartiPrompts
+
+
+def create_benchmark(benchmark_type: str) -> DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts:
+    """
+    Factory function to create benchmark instances.
+
+    Args:
+        benchmark_type (str): The type of benchmark to create. Must be one of:
+            - "draw_bench"
+            - "genai_bench"
+            - "geneval"
+            - "hps"
+            - "parti"
+
+    Returns:
+        An instance of the requested benchmark implementation
+
+    Raises:
+        ValueError: If an invalid benchmark type is provided
+    """
+    benchmark_map = {
+        "draw_bench": DrawBenchPrompts,
+        "genai_bench": GenAIBenchPrompts,
+        "geneval": GenEvalPrompts,
+        "hps": HPSPrompts,
+        "parti": PartiPrompts,
+    }
+
+    if benchmark_type not in benchmark_map:
+        raise ValueError(f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}")
+
+    return benchmark_map[benchmark_type]()
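
Not part of the commit: a short sketch of how the benchmark objects are consumed. The iteration protocol below mirrors what sample.py does for the non-geneval benchmarks; the choice of "parti" is just an example.

from benchmark import create_benchmark

benchmark = create_benchmark("parti")
print(benchmark.name, benchmark.size, benchmark.metrics)

# Each benchmark yields (prompt, relative image path) pairs.
for prompt, image_path in benchmark:
    print(image_path, "<-", prompt)
    break  # just show the first pair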
benchmark/draw_bench.py ADDED
@@ -0,0 +1,25 @@
+from pathlib import Path
+from typing import Iterator, List, Tuple
+
+from datasets import load_dataset
+
+
+class DrawBenchPrompts:
+    def __init__(self):
+        self.dataset = load_dataset("shunk031/DrawBench")["test"]
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for i, row in enumerate(self.dataset):
+            yield row["prompts"], Path(f"{i}.png")
+
+    @property
+    def name(self) -> str:
+        return "draw_bench"
+
+    @property
+    def size(self) -> int:
+        return len(self.dataset)
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["image_reward"]
benchmark/genai_bench.py ADDED
@@ -0,0 +1,39 @@
+from pathlib import Path
+from typing import Iterator, List, Tuple
+
+import requests
+
+
+class GenAIBenchPrompts:
+    def __init__(self):
+        super().__init__()
+        self._download_genai_bench_files()
+        prompts_path = Path('downloads/genai_bench/prompts.txt')
+        with open(prompts_path, 'r') as f:
+            self.prompts = [line.strip() for line in f if line.strip()]
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for i, prompt in enumerate(self.prompts):
+            yield prompt, Path(f"{i}.png")
+
+    def _download_genai_bench_files(self) -> None:
+        folder_name = Path('downloads/genai_bench')
+        folder_name.mkdir(parents=True, exist_ok=True)
+        prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
+        prompts_path = folder_name / "prompts.txt"
+        if not prompts_path.exists():
+            response = requests.get(prompts_url)
+            with open(prompts_path, 'w') as f:
+                f.write(response.text)
+
+    @property
+    def name(self) -> str:
+        return "genai_bench"
+
+    @property
+    def size(self) -> int:
+        return len(self.prompts)
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["vqa"]
benchmark/geneval.py ADDED
@@ -0,0 +1,44 @@
+import json
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Tuple
+
+import requests
+
+
+class GenEvalPrompts:
+    def __init__(self):
+        super().__init__()
+        self._download_geneval_file()
+        metadata_path = Path('downloads/geneval/evaluation_metadata.jsonl')
+        self.entries: List[Dict[str, Any]] = []
+        with open(metadata_path, 'r') as f:
+            for line in f:
+                if line.strip():
+                    self.entries.append(json.loads(line))
+
+    def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
+        for i, entry in enumerate(self.entries):
+            folder_name = f"{i:05d}"
+            yield entry, Path(folder_name)
+
+    def _download_geneval_file(self) -> None:
+        folder_name = Path('downloads/geneval')
+        folder_name.mkdir(parents=True, exist_ok=True)
+        metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
+        metadata_path = folder_name / "evaluation_metadata.jsonl"
+        if not metadata_path.exists():
+            response = requests.get(metadata_url)
+            with open(metadata_path, 'w') as f:
+                f.write(response.text)
+
+    @property
+    def name(self) -> str:
+        return "geneval"
+
+    @property
+    def size(self) -> int:
+        return len(self.entries)
+
+    @property
+    def metrics(self) -> List[str]:
+        raise NotImplementedError("GenEval requires custom evaluation, see README.md")
benchmark/hps.py ADDED
@@ -0,0 +1,49 @@
+import json
+import os
+from pathlib import Path
+from typing import Dict, Iterator, List, Tuple
+
+import huggingface_hub
+
+
+class HPSPrompts:
+    def __init__(self):
+        super().__init__()
+        self.hps_prompt_files = ['anime.json', 'concept-art.json', 'paintings.json', 'photo.json']
+        self._download_benchmark_prompts()
+        self.prompts: Dict[str, str] = {}
+        self._size = 0
+        for file in self.hps_prompt_files:
+            category = file.replace('.json', '')
+            with open(os.path.join('downloads/hps', file), 'r') as f:
+                prompts = json.load(f)
+            for i, prompt in enumerate(prompts):
+                if i == 100:
+                    break
+                filename = f"{category}_{i:03d}.png"
+                self.prompts[filename] = prompt
+                self._size += 1
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for filename, prompt in self.prompts.items():
+            yield prompt, Path(filename)
+
+    @property
+    def name(self) -> str:
+        return "hps"
+
+    @property
+    def size(self) -> int:
+        return self._size
+
+    def _download_benchmark_prompts(self) -> None:
+        folder_name = Path('downloads/hps')
+        folder_name.mkdir(parents=True, exist_ok=True)
+        for file in self.hps_prompt_files:
+            file_name = huggingface_hub.hf_hub_download("zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset")
+            if not os.path.exists(os.path.join(folder_name, file)):
+                os.symlink(file_name, os.path.join(folder_name, file))
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["hps"]
benchmark/metrics/__init__.py ADDED
@@ -0,0 +1,43 @@
+from benchmark.metrics.arniqa import ARNIQAMetric
+from benchmark.metrics.clip import CLIPMetric
+from benchmark.metrics.clip_iqa import CLIPIQAMetric
+from benchmark.metrics.hps import HPSMetric
+from benchmark.metrics.image_reward import ImageRewardMetric
+from benchmark.metrics.sharpness import SharpnessMetric
+from benchmark.metrics.vqa import VQAMetric
+
+
+def create_metric(metric_type: str) -> ARNIQAMetric | CLIPMetric | CLIPIQAMetric | HPSMetric | ImageRewardMetric | SharpnessMetric | VQAMetric:
+    """
+    Factory function to create metric instances.
+
+    Args:
+        metric_type (str): The type of metric to create. Must be one of:
+            - "arniqa"
+            - "clip"
+            - "clip_iqa"
+            - "hps"
+            - "image_reward"
+            - "sharpness"
+            - "vqa"
+
+    Returns:
+        An instance of the requested metric implementation
+
+    Raises:
+        ValueError: If an invalid metric type is provided
+    """
+    metric_map = {
+        "arniqa": ARNIQAMetric,
+        "clip": CLIPMetric,
+        "clip_iqa": CLIPIQAMetric,
+        "hps": HPSMetric,
+        "image_reward": ImageRewardMetric,
+        "sharpness": SharpnessMetric,
+        "vqa": VQAMetric,
+    }
+
+    if metric_type not in metric_map:
+        raise ValueError(f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}")
+
+    return metric_map[metric_type]()
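
Not part of the commit: a minimal sketch of the metric interface. The image path below is a hypothetical file that sample.py would have produced; the prompt string is illustrative.

from PIL import Image

from benchmark.metrics import create_metric

metric = create_metric("sharpness")
image = Image.open("images/fal/parti/0.png")  # hypothetical path produced by sample.py

# Every metric returns a dict keyed by the metric name the factory uses.
scores = metric.compute_score(image, prompt="a red bicycle on a beach")
print(scores)  # e.g. {"sharpness": 123.4}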
benchmark/metrics/arniqa.py ADDED
@@ -0,0 +1,26 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from torchmetrics.image.arniqa import ARNIQA
+
+
+class ARNIQAMetric:
+    def __init__(self):
+        self.metric = ARNIQA(
+            regressor_dataset="koniq10k",
+            reduction="mean",
+            normalize=True,
+            autocast=False
+        )
+
+    @property
+    def name(self) -> str:
+        return "arniqa"
+
+    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
+        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+        image_tensor = image_tensor.unsqueeze(0)
+        score = self.metric(image_tensor)
+        return {"arniqa": score.item()}
benchmark/metrics/clip.py ADDED
@@ -0,0 +1,23 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from torchmetrics.multimodal.clip_score import CLIPScore
+
+
+class CLIPMetric:
+    def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.metric = CLIPScore(model_name_or_path=model_name_or_path)
+        self.metric.to(self.device)
+
+    @property
+    def name(self) -> str:
+        return "clip"
+
+    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
+        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
+        image_tensor = image_tensor.to(self.device)
+        score = self.metric(image_tensor, prompt)
+        return {"clip": score.item()}
benchmark/metrics/clip_iqa.py ADDED
@@ -0,0 +1,28 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from torchmetrics.multimodal import CLIPImageQualityAssessment
+
+
+class CLIPIQAMetric:
+    def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.metric = CLIPImageQualityAssessment(
+            model_name_or_path="clip_iqa",
+            data_range=255.0,
+            prompts=["quality"]
+        )
+        self.metric.to(self.device)
+
+    @property
+    def name(self) -> str:
+        return "clip_iqa"
+
+    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
+        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
+        image_tensor = image_tensor.unsqueeze(0)
+        image_tensor = image_tensor.to(self.device)
+        scores = self.metric(image_tensor)
+        return {"clip_iqa": scores.item()}
benchmark/metrics/hps.py ADDED
@@ -0,0 +1,77 @@
+import os
+from typing import Dict
+
+import torch
+from PIL import Image
+from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+import huggingface_hub
+from hpsv2.utils import root_path, hps_version_map
+
+
+class HPSMetric:
+    def __init__(self):
+        self.hps_version = "v2.1"
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.model_dict = {}
+        self._initialize_model()
+
+    def _initialize_model(self):
+        if not self.model_dict:
+            model, preprocess_train, preprocess_val = create_model_and_transforms(
+                'ViT-H-14',
+                'laion2B-s32B-b79K',
+                precision='amp',
+                device=self.device,
+                jit=False,
+                force_quick_gelu=False,
+                force_custom_text=False,
+                force_patch_dropout=False,
+                force_image_size=None,
+                pretrained_image=False,
+                image_mean=None,
+                image_std=None,
+                light_augmentation=True,
+                aug_cfg={},
+                output_dict=True,
+                with_score_predictor=False,
+                with_region_predictor=False
+            )
+            self.model_dict['model'] = model
+            self.model_dict['preprocess_val'] = preprocess_val
+
+            # Load checkpoint
+            if not os.path.exists(root_path):
+                os.makedirs(root_path)
+            cp = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[self.hps_version])
+
+            checkpoint = torch.load(cp, map_location=self.device)
+            model.load_state_dict(checkpoint['state_dict'])
+            self.tokenizer = get_tokenizer('ViT-H-14')
+            model = model.to(self.device)
+            model.eval()
+
+    @property
+    def name(self) -> str:
+        return "hps"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        model = self.model_dict['model']
+        preprocess_val = self.model_dict['preprocess_val']
+
+        with torch.no_grad():
+            # Process the image
+            image_tensor = preprocess_val(image).unsqueeze(0).to(device=self.device, non_blocking=True)
+            # Process the prompt
+            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
+            # Calculate the HPS
+            with torch.cuda.amp.autocast():
+                outputs = model(image_tensor, text)
+                image_features, text_features = outputs["image_features"], outputs["text_features"]
+                logits_per_image = image_features @ text_features.T
+                hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+
+        return {"hps": float(hps_score[0])}
benchmark/metrics/image_reward.py ADDED
@@ -0,0 +1,26 @@
+import os
+import tempfile
+from typing import Dict
+
+import ImageReward as RM
+from PIL import Image
+
+
+class ImageRewardMetric:
+    def __init__(self):
+        self.model = RM.load("ImageReward-v1.0")
+
+    @property
+    def name(self) -> str:
+        return "image_reward"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+            image.save(tmp.name)
+            score = self.model.score(prompt, [tmp.name])
+        os.unlink(tmp.name)
+        return {"image_reward": score}
benchmark/metrics/sharpness.py ADDED
@@ -0,0 +1,24 @@
+from typing import Dict
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+class SharpnessMetric:
+    def __init__(self):
+        self.kernel_size = 3
+
+    @property
+    def name(self) -> str:
+        return "sharpness"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        img = np.array(image.convert('L'))
+        laplacian = cv2.Laplacian(img, cv2.CV_64F, ksize=self.kernel_size)
+        sharpness = laplacian.var()
+        return {"sharpness": float(sharpness)}
benchmark/metrics/vqa.py ADDED
@@ -0,0 +1,25 @@
+import os
+import tempfile
+from typing import Dict
+
+import t2v_metrics
+from PIL import Image
+
+class VQAMetric:
+    def __init__(self):
+        self.metric = t2v_metrics.VQAScore(model="clip-flant5-xxl")
+
+    @property
+    def name(self) -> str:
+        return "vqa"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+            image.save(tmp.name)
+            score = self.metric(images=[tmp.name], texts=[prompt])
+        os.unlink(tmp.name)
+        return {"vqa": score[0][0].item()}
benchmark/parti.py ADDED
@@ -0,0 +1,28 @@
+from pathlib import Path
+from typing import Iterator, List, Tuple
+
+from datasets import load_dataset
+
+
+class PartiPrompts:
+    def __init__(self):
+        dataset = load_dataset("nateraw/parti-prompts")["train"]
+        shuffled_dataset = dataset.shuffle(seed=42)
+        selected_dataset = shuffled_dataset.select(range(800))
+        self.prompts = [row["Prompt"] for row in selected_dataset]
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for i, prompt in enumerate(self.prompts):
+            yield prompt, Path(f"{i}.png")
+
+    @property
+    def name(self) -> str:
+        return "parti"
+
+    @property
+    def size(self) -> int:
+        return len(self.prompts)
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["arniqa", "clip", "clip_iqa", "sharpness"]
evaluate.py ADDED
@@ -0,0 +1,106 @@
+import argparse
+import json
+from pathlib import Path
+from typing import Dict
+
+from benchmark import create_benchmark
+from benchmark.metrics import create_metric
+from PIL import Image
+
+
+def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
+    """
+    Evaluate a benchmark's images using its specific metrics.
+
+    Args:
+        benchmark_type (str): Type of benchmark to evaluate
+        api_type (str): Type of API used to generate images
+        images_dir (Path): Base directory containing generated images
+
+    Returns:
+        Dict containing evaluation results
+    """
+    benchmark = create_benchmark(benchmark_type)
+
+    benchmark_dir = images_dir / api_type / benchmark_type
+    metadata_file = benchmark_dir / "metadata.jsonl"
+
+    if not metadata_file.exists():
+        raise FileNotFoundError(f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first.")
+
+    metadata = []
+    with open(metadata_file, "r") as f:
+        for line in f:
+            metadata.append(json.loads(line))
+
+    metrics = {metric_type: create_metric(metric_type) for metric_type in benchmark.metrics}
+
+    results = {
+        "api": api_type,
+        "benchmark": benchmark_type,
+        "metrics": {metric: 0.0 for metric in benchmark.metrics},
+        "avg_inference_time": 0.0,
+        "total_images": len(metadata)
+    }
+
+    for entry in metadata:
+        image_path = benchmark_dir / entry["filepath"]
+        if not image_path.exists():
+            continue
+
+        image = Image.open(image_path)
+
+        for metric_type, metric in metrics.items():
+            try:
+                score = metric.compute_score(image, entry["prompt"])
+                results["metrics"][metric_type] += score[metric_type]
+            except Exception as e:
+                print(f"Error computing {metric_type} for {image_path}: {str(e)}")
+
+        results["avg_inference_time"] += entry["inference_time"]
+
+    for metric in results["metrics"]:
+        results["metrics"][metric] /= len(metadata)
+    results["avg_inference_time"] /= len(metadata)
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate generated images using benchmark-specific metrics")
+    parser.add_argument("api_type", help="Type of API to evaluate")
+    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to evaluate")
+
+    args = parser.parse_args()
+
+    results_dir = Path("evaluation_results")
+    results_dir.mkdir(exist_ok=True)
+
+    results_file = results_dir / f"{args.api_type}.jsonl"
+    existing_results = set()
+
+    if results_file.exists():
+        with open(results_file, "r") as f:
+            for line in f:
+                result = json.loads(line)
+                existing_results.add(result["benchmark"])
+
+    for benchmark_type in args.benchmarks:
+        if benchmark_type in existing_results:
+            print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
+            continue
+
+        try:
+            print(f"Evaluating {args.api_type}/{benchmark_type}")
+            results = evaluate_benchmark(benchmark_type, args.api_type)
+
+            # Append results to file
+            with open(results_file, "a") as f:
+                f.write(json.dumps(results) + "\n")
+
+        except Exception as e:
+            print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
+
+
+if __name__ == "__main__":
+    main()
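
Not part of the commit: evaluate.py is normally run from the command line with an API type and a list of benchmarks. The sketch below calls the same function directly, assuming sample.py has already produced images and metadata for a hypothetical fal/parti run.

import json

from evaluate import evaluate_benchmark

# Assumes images/fal/parti/metadata.jsonl already exists (written by sample.py).
results = evaluate_benchmark("parti", "fal")
print(json.dumps(results, indent=2))  # per-metric averages plus avg_inference_time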
evaluation_results/.gitkeep ADDED
File without changes
images/.gitkeep ADDED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,22 @@
+[project]
+name = "inferbench"
+version = "0.1.0"
+requires-python = ">=3.12"
+dependencies = [
+    "datasets>=3.5.0",
+    "fal-client>=0.5.9",
+    "hpsv2>=1.2.0",
+    "huggingface-hub>=0.30.2",
+    "image-reward>=1.5",
+    "numpy>=2.2.5",
+    "opencv-python>=4.11.0.86",
+    "pillow>=11.2.1",
+    "python-dotenv>=1.1.0",
+    "replicate>=1.0.4",
+    "requests>=2.32.3",
+    "t2v-metrics>=1.2",
+    "together>=1.5.5",
+    "torch>=2.7.0",
+    "torchmetrics>=1.7.1",
+    "tqdm>=4.67.1",
+]
sample.py ADDED
@@ -0,0 +1,110 @@
+import argparse
+import json
+from pathlib import Path
+from typing import List
+
+from tqdm import tqdm
+
+from api import create_api
+from benchmark import create_benchmark
+
+
+def generate_images(api_type: str, benchmarks: List[str]):
+    images_dir = Path("images")
+    api = create_api(api_type)
+
+    api_dir = images_dir / api_type
+    api_dir.mkdir(parents=True, exist_ok=True)
+
+    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
+        print(f"\nProcessing benchmark: {benchmark_type}")
+
+        benchmark = create_benchmark(benchmark_type)
+
+        if benchmark_type == "geneval":
+            benchmark_dir = api_dir / benchmark_type
+            benchmark_dir.mkdir(parents=True, exist_ok=True)
+
+            metadata_file = benchmark_dir / "metadata.jsonl"
+            existing_metadata = {}
+            if metadata_file.exists():
+                with open(metadata_file, "r") as f:
+                    for line in f:
+                        entry = json.loads(line)
+                        existing_metadata[entry["filepath"]] = entry
+
+            for metadata, folder_name in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
+                sample_path = benchmark_dir / folder_name
+                samples_path = sample_path / "samples"
+                samples_path.mkdir(parents=True, exist_ok=True)
+                image_path = samples_path / "0000.png"
+
+                if image_path.exists():
+                    continue
+
+                try:
+                    inference_time = api.generate_image(metadata["prompt"], image_path)
+
+                    metadata_entry = {
+                        "filepath": str(image_path),
+                        "prompt": metadata["prompt"],
+                        "inference_time": inference_time
+                    }
+
+                    existing_metadata[str(image_path)] = metadata_entry
+
+                except Exception as e:
+                    print(f"\nError generating image for prompt: {metadata['prompt']}")
+                    print(f"Error: {str(e)}")
+                    continue
+        else:
+            benchmark_dir = api_dir / benchmark_type
+            benchmark_dir.mkdir(parents=True, exist_ok=True)
+
+            metadata_file = benchmark_dir / "metadata.jsonl"
+            existing_metadata = {}
+            if metadata_file.exists():
+                with open(metadata_file, "r") as f:
+                    for line in f:
+                        entry = json.loads(line)
+                        existing_metadata[entry["filepath"]] = entry
+
+            for prompt, image_path in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
+                full_image_path = benchmark_dir / image_path
+
+                if full_image_path.exists():
+                    continue
+
+                try:
+                    inference_time = api.generate_image(prompt, full_image_path)
+
+                    metadata_entry = {
+                        "filepath": str(image_path),
+                        "prompt": prompt,
+                        "inference_time": inference_time
+                    }
+
+                    existing_metadata[str(image_path)] = metadata_entry
+
+                except Exception as e:
+                    print(f"\nError generating image for prompt: {prompt}")
+                    print(f"Error: {str(e)}")
+                    continue
+
+        with open(metadata_file, "w") as f:
+            for entry in existing_metadata.values():
+                f.write(json.dumps(entry) + "\n")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate images for specified benchmarks using a given API")
+    parser.add_argument("api_type", help="Type of API to use for image generation")
+    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
+
+    args = parser.parse_args()
+
+    generate_images(args.api_type, args.benchmarks)
+
+
+if __name__ == "__main__":
+    main()
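
Not part of the commit: sample.py is invoked from the command line with an API type followed by benchmark names; the sketch below calls the same entry point programmatically. The API and benchmark names are examples taken from the factories above.

from sample import generate_images

# Generates images for two benchmarks with the fal API; outputs land in
# images/fal/<benchmark>/ together with a metadata.jsonl per benchmark.
generate_images("fal", ["parti", "draw_bench"])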
uv.lock ADDED
The diff for this file is too large to render.