nifleisch committed
Commit 2c50826 · Parent(s): 61029d0
feat: add core logic for project
- api/__init__.py +45 -0
- api/baseline.py +53 -0
- api/fal.py +48 -0
- api/fireworks.py +53 -0
- api/flux.py +35 -0
- api/pruna.py +51 -0
- api/replicate.py +48 -0
- api/together.py +47 -0
- benchmark/__init__.py +39 -0
- benchmark/draw_bench.py +25 -0
- benchmark/genai_bench.py +39 -0
- benchmark/geneval.py +44 -0
- benchmark/hps.py +49 -0
- benchmark/metrics/__init__.py +42 -0
- benchmark/metrics/arniqa.py +26 -0
- benchmark/metrics/clip.py +23 -0
- benchmark/metrics/clip_iqa.py +28 -0
- benchmark/metrics/hps.py +77 -0
- benchmark/metrics/image_reward.py +26 -0
- benchmark/metrics/sharpness.py +24 -0
- benchmark/metrics/vqa.py +25 -0
- benchmark/parti.py +28 -0
- evaluate.py +106 -0
- evaluation_results/.gitkeep +0 -0
- images/.gitkeep +0 -0
- pyproject.toml +22 -0
- sample.py +110 -0
- uv.lock +0 -0
api/__init__.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Type
+
+from api.baseline import BaselineAPI
+from api.fireworks import FireworksAPI
+from api.flux import FluxAPI
+from api.pruna import PrunaAPI
+from api.replicate import ReplicateAPI
+from api.together import TogetherAPI
+from api.fal import FalAPI
+
+def create_api(api_type: str) -> FluxAPI:
+    """
+    Factory function to create API instances.
+
+    Args:
+        api_type (str): The type of API to create. Must be one of:
+            - "baseline"
+            - "fireworks"
+            - "pruna_<speed_mode>" (where <speed_mode> is the desired speed mode)
+            - "replicate"
+            - "together"
+            - "fal"
+
+    Returns:
+        FluxAPI: An instance of the requested API implementation
+
+    Raises:
+        ValueError: If an invalid API type is provided
+    """
+    if api_type.startswith("pruna_"):
+        speed_mode = api_type[6:]  # Remove the "pruna_" prefix
+        return PrunaAPI(speed_mode)
+
+    api_map: dict[str, Type[FluxAPI]] = {
+        "baseline": BaselineAPI,
+        "fireworks": FireworksAPI,
+        "replicate": ReplicateAPI,
+        "together": TogetherAPI,
+        "fal": FalAPI,
+    }
+
+    if api_type not in api_map:
+        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")
+
+    return api_map[api_type]()
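For orientation, a minimal sketch of how this factory is meant to be consumed; the API name, prompt, and output path below are illustrative placeholders, not part of this commit:

from pathlib import Path

from api import create_api

api = create_api("fal")  # any api_map key, or "pruna_<speed_mode>"
seconds = api.generate_image("a watercolor fox", Path("images/fal/demo.png"))
print(f"{api.name} took {seconds:.2f}s")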
api/baseline.py
ADDED
@@ -0,0 +1,53 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import replicate
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class BaselineAPI(FluxAPI):
+    """
+    As our baseline, we use the Replicate API with go_fast=False.
+    """
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "baseline"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "prompt": prompt,
+                "go_fast": False,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+                "seed": 0,
+            },
+        )
+        end_time = time.time()
+
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/fal.py
ADDED
@@ -0,0 +1,48 @@
+import time
+from io import BytesIO
+from pathlib import Path
+from typing import Any
+
+import fal_client
+import requests
+from PIL import Image
+
+from api.flux import FluxAPI
+
+
+class FalAPI(FluxAPI):
+    @property
+    def name(self) -> str:
+        return "fal"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = fal_client.subscribe(
+            "fal-ai/flux/dev",
+            arguments={
+                "seed": 0,
+                "prompt": prompt,
+                "image_size": "square_hd",  # 1024x1024 image
+                "num_images": 1,
+                "guidance_scale": 3.5,
+                "num_inference_steps": 28,
+                "enable_safety_checker": True,
+            },
+        )
+        end_time = time.time()
+
+        url = result["images"][0]["url"]
+        self._save_image_from_url(url, save_path)
+
+        return end_time - start_time
+
+    def _save_image_from_url(self, url: str, save_path: Path):
+        response = requests.get(url)
+        image = Image.open(BytesIO(response.content))
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        image.save(save_path)
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.content)
api/fireworks.py
ADDED
@@ -0,0 +1,53 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import requests
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class FireworksAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("FIREWORKS_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("FIREWORKS_API_TOKEN not found in environment variables")
+        self._url = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image"
+
+    @property
+    def name(self) -> str:
+        return "fireworks_fp8"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "image/jpeg",
+            "Authorization": f"Bearer {self._api_key}",
+        }
+        data = {
+            "prompt": prompt,
+            "aspect_ratio": "1:1",
+            "guidance_scale": 3.5,
+            "num_inference_steps": 28,
+            "seed": 0,
+        }
+        result = requests.post(self._url, headers=headers, json=data)
+
+        end_time = time.time()
+
+        if result.status_code == 200:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception(f"Error: {result.status_code} {result.text}")
+
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.content)
api/flux.py
ADDED
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+
+class FluxAPI(ABC):
+    """
+    Abstract base class for Flux API implementations.
+
+    This class defines the common interface for all Flux API implementations.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """
+        The name of the API implementation.
+
+        Returns:
+            str: The name of the specific API implementation
+        """
+        pass
+
+    @abstractmethod
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        """
+        Generate an image based on the prompt and save it to the specified path.
+
+        Args:
+            prompt (str): The text prompt to generate the image from
+            save_path (Path): The path where the generated image should be saved
+
+        Returns:
+            float: The time taken for the API call in seconds
+        """
+        pass
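For reference, a minimal sketch of what a conforming implementation of this interface looks like; the DummyAPI class is hypothetical and not part of this commit:

import time
from pathlib import Path

from api.flux import FluxAPI


class DummyAPI(FluxAPI):  # hypothetical example, not part of this commit
    @property
    def name(self) -> str:
        return "dummy"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        start_time = time.time()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(b"")  # a real implementation writes the generated image here
        return time.time() - start_time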
api/pruna.py
ADDED
@@ -0,0 +1,51 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+import replicate
+
+from api.flux import FluxAPI
+
+
+class PrunaAPI(FluxAPI):
+    def __init__(self, speed_mode: str):
+        self._speed_mode = speed_mode
+        self._speed_mode_name = speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return f"pruna_{self._speed_mode_name}"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "prunaai/flux.1-juiced:58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "speed_mode": self._speed_mode,
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+
+        if result:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/replicate.py
ADDED
@@ -0,0 +1,48 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+import replicate
+
+from api.flux import FluxAPI
+
+
+class ReplicateAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "replicate_go_fast"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "go_fast": True,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/together.py
ADDED
@@ -0,0 +1,47 @@
+import base64
+import io
+import time
+from pathlib import Path
+from typing import Any
+
+from dotenv import load_dotenv
+from PIL import Image
+from together import Together
+
+from api.flux import FluxAPI
+
+
+class TogetherAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._client = Together()
+
+    @property
+    def name(self) -> str:
+        return "together"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = self._client.images.generate(
+            prompt=prompt,
+            model="black-forest-labs/FLUX.1-dev",
+            width=1024,
+            height=1024,
+            steps=28,
+            n=1,
+            seed=0,
+            response_format="b64_json",
+        )
+        end_time = time.time()
+        if result and hasattr(result, 'data') and len(result.data) > 0:
+            self._save_image_from_result(result, save_path)
+        else:
+            raise Exception("No result returned from Together API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        b64_str = result.data[0].b64_json
+        image_data = base64.b64decode(b64_str)
+        image = Image.open(io.BytesIO(image_data))
+        image.save(save_path)
benchmark/__init__.py
ADDED
@@ -0,0 +1,39 @@
+from typing import Type
+
+from benchmark.draw_bench import DrawBenchPrompts
+from benchmark.genai_bench import GenAIBenchPrompts
+from benchmark.geneval import GenEvalPrompts
+from benchmark.hps import HPSPrompts
+from benchmark.parti import PartiPrompts
+
+
+def create_benchmark(benchmark_type: str) -> DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts:
+    """
+    Factory function to create benchmark instances.
+
+    Args:
+        benchmark_type (str): The type of benchmark to create. Must be one of:
+            - "draw_bench"
+            - "genai_bench"
+            - "geneval"
+            - "hps"
+            - "parti"
+
+    Returns:
+        An instance of the requested benchmark implementation
+
+    Raises:
+        ValueError: If an invalid benchmark type is provided
+    """
+    benchmark_map: dict[str, Type] = {
+        "draw_bench": DrawBenchPrompts,
+        "genai_bench": GenAIBenchPrompts,
+        "geneval": GenEvalPrompts,
+        "hps": HPSPrompts,
+        "parti": PartiPrompts,
+    }
+
+    if benchmark_type not in benchmark_map:
+        raise ValueError(f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}")
+
+    return benchmark_map[benchmark_type]()
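A short sketch of how a benchmark instance is consumed downstream; the benchmark name is chosen arbitrarily:

from benchmark import create_benchmark

benchmark = create_benchmark("draw_bench")
print(benchmark.name, benchmark.size, benchmark.metrics)
for prompt, image_path in benchmark:
    # most benchmarks yield (prompt, relative image path) pairs;
    # geneval instead yields (metadata dict, folder name) pairs
    print(prompt, image_path)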
benchmark/draw_bench.py
ADDED
@@ -0,0 +1,25 @@
+from pathlib import Path
+from typing import Iterator, List, Tuple
+
+from datasets import load_dataset
+
+
+class DrawBenchPrompts:
+    def __init__(self):
+        self.dataset = load_dataset("shunk031/DrawBench")["test"]
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for i, row in enumerate(self.dataset):
+            yield row["prompts"], Path(f"{i}.png")
+
+    @property
+    def name(self) -> str:
+        return "draw_bench"
+
+    @property
+    def size(self) -> int:
+        return len(self.dataset)
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["image_reward"]
benchmark/genai_bench.py
ADDED
@@ -0,0 +1,39 @@
+from pathlib import Path
+from typing import Iterator, List, Tuple
+
+import requests
+
+
+class GenAIBenchPrompts:
+    def __init__(self):
+        super().__init__()
+        self._download_genai_bench_files()
+        prompts_path = Path('downloads/genai_bench/prompts.txt')
+        with open(prompts_path, 'r') as f:
+            self.prompts = [line.strip() for line in f if line.strip()]
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for i, prompt in enumerate(self.prompts):
+            yield prompt, Path(f"{i}.png")
+
+    def _download_genai_bench_files(self) -> None:
+        folder_name = Path('downloads/genai_bench')
+        folder_name.mkdir(parents=True, exist_ok=True)
+        prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
+        prompts_path = folder_name / "prompts.txt"
+        if not prompts_path.exists():
+            response = requests.get(prompts_url)
+            with open(prompts_path, 'w') as f:
+                f.write(response.text)
+
+    @property
+    def name(self) -> str:
+        return "genai_bench"
+
+    @property
+    def size(self) -> int:
+        return len(self.prompts)
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["vqa"]
benchmark/geneval.py
ADDED
@@ -0,0 +1,44 @@
+import json
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Tuple
+
+import requests
+
+
+class GenEvalPrompts:
+    def __init__(self):
+        super().__init__()
+        self._download_geneval_file()
+        metadata_path = Path('downloads/geneval/evaluation_metadata.jsonl')
+        self.entries: List[Dict[str, Any]] = []
+        with open(metadata_path, 'r') as f:
+            for line in f:
+                if line.strip():
+                    self.entries.append(json.loads(line))
+
+    def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
+        for i, entry in enumerate(self.entries):
+            folder_name = f"{i:05d}"
+            yield entry, Path(folder_name)
+
+    def _download_geneval_file(self) -> None:
+        folder_name = Path('downloads/geneval')
+        folder_name.mkdir(parents=True, exist_ok=True)
+        metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
+        metadata_path = folder_name / "evaluation_metadata.jsonl"
+        if not metadata_path.exists():
+            response = requests.get(metadata_url)
+            with open(metadata_path, 'w') as f:
+                f.write(response.text)
+
+    @property
+    def name(self) -> str:
+        return "geneval"
+
+    @property
+    def size(self) -> int:
+        return len(self.entries)
+
+    @property
+    def metrics(self) -> List[str]:
+        raise NotImplementedError("GenEval requires custom evaluation, see README.md")
benchmark/hps.py
ADDED
@@ -0,0 +1,49 @@
+import json
+import os
+from pathlib import Path
+from typing import Dict, Iterator, List, Tuple
+
+import huggingface_hub
+
+
+class HPSPrompts:
+    def __init__(self):
+        super().__init__()
+        self.hps_prompt_files = ['anime.json', 'concept-art.json', 'paintings.json', 'photo.json']
+        self._download_benchmark_prompts()
+        self.prompts: Dict[str, str] = {}
+        self._size = 0
+        for file in self.hps_prompt_files:
+            category = file.replace('.json', '')
+            with open(os.path.join('downloads/hps', file), 'r') as f:
+                prompts = json.load(f)
+            for i, prompt in enumerate(prompts):
+                if i == 100:
+                    break
+                filename = f"{category}_{i:03d}.png"
+                self.prompts[filename] = prompt
+                self._size += 1
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for filename, prompt in self.prompts.items():
+            yield prompt, Path(filename)
+
+    @property
+    def name(self) -> str:
+        return "hps"
+
+    @property
+    def size(self) -> int:
+        return self._size
+
+    def _download_benchmark_prompts(self) -> None:
+        folder_name = Path('downloads/hps')
+        folder_name.mkdir(parents=True, exist_ok=True)
+        for file in self.hps_prompt_files:
+            file_name = huggingface_hub.hf_hub_download("zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset")
+            if not os.path.exists(os.path.join(folder_name, file)):
+                os.symlink(file_name, os.path.join(folder_name, file))
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["hps"]
benchmark/metrics/__init__.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Type
+
+from benchmark.metrics.arniqa import ARNIQAMetric
+from benchmark.metrics.clip import CLIPMetric
+from benchmark.metrics.clip_iqa import CLIPIQAMetric
+from benchmark.metrics.hps import HPSMetric
+from benchmark.metrics.image_reward import ImageRewardMetric
+from benchmark.metrics.sharpness import SharpnessMetric
+from benchmark.metrics.vqa import VQAMetric
+
+
+def create_metric(metric_type: str) -> ARNIQAMetric | CLIPMetric | CLIPIQAMetric | HPSMetric | ImageRewardMetric | SharpnessMetric | VQAMetric:
+    """
+    Factory function to create metric instances.
+
+    Args:
+        metric_type (str): The type of metric to create. Must be one of:
+            - "arniqa"
+            - "clip"
+            - "clip_iqa"
+            - "hps"
+            - "image_reward"
+            - "sharpness"
+            - "vqa"
+
+    Returns:
+        An instance of the requested metric implementation
+
+    Raises:
+        ValueError: If an invalid metric type is provided
+    """
+    metric_map: dict[str, Type] = {
+        "arniqa": ARNIQAMetric,
+        "clip": CLIPMetric,
+        "clip_iqa": CLIPIQAMetric,
+        "hps": HPSMetric,
+        "image_reward": ImageRewardMetric,
+        "sharpness": SharpnessMetric,
+        "vqa": VQAMetric,
+    }
+
+    if metric_type not in metric_map:
+        raise ValueError(f"Invalid metric type: {metric_type}. Must be one of {list(metric_map.keys())}")
+
+    return metric_map[metric_type]()
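As a quick sketch, a metric instance is applied to an in-memory image; the file path, prompt, and printed value here are placeholders:

from PIL import Image

from benchmark.metrics import create_metric

metric = create_metric("sharpness")  # cheapest metric; no model download
image = Image.open("images/fal/parti/0.png")  # placeholder path
print(metric.compute_score(image, "a watercolor fox"))  # e.g. {"sharpness": 123.4}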
benchmark/metrics/arniqa.py
ADDED
@@ -0,0 +1,26 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from torchmetrics.image.arniqa import ARNIQA
+
+
+class ARNIQAMetric:
+    def __init__(self):
+        self.metric = ARNIQA(
+            regressor_dataset="koniq10k",
+            reduction="mean",
+            normalize=True,
+            autocast=False
+        )
+
+    @property
+    def name(self) -> str:
+        return "arniqa"
+
+    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
+        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+        image_tensor = image_tensor.unsqueeze(0)
+        score = self.metric(image_tensor)
+        return {"arniqa": score.item()}
benchmark/metrics/clip.py
ADDED
@@ -0,0 +1,23 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from torchmetrics.multimodal.clip_score import CLIPScore
+
+
+class CLIPMetric:
+    def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.metric = CLIPScore(model_name_or_path=model_name_or_path)
+        self.metric.to(self.device)
+
+    @property
+    def name(self) -> str:
+        return "clip"
+
+    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
+        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
+        image_tensor = image_tensor.to(self.device)
+        score = self.metric(image_tensor, prompt)
+        return {"clip": score.item()}
benchmark/metrics/clip_iqa.py
ADDED
@@ -0,0 +1,28 @@
+from typing import Dict
+
+import numpy as np
+import torch
+from PIL import Image
+from torchmetrics.multimodal import CLIPImageQualityAssessment
+
+
+class CLIPIQAMetric:
+    def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.metric = CLIPImageQualityAssessment(
+            model_name_or_path="clip_iqa",
+            data_range=255.0,
+            prompts=["quality"]
+        )
+        self.metric.to(self.device)
+
+    @property
+    def name(self) -> str:
+        return "clip_iqa"
+
+    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
+        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
+        image_tensor = image_tensor.unsqueeze(0)
+        image_tensor = image_tensor.to(self.device)
+        scores = self.metric(image_tensor)
+        return {"clip_iqa": scores.item()}
benchmark/metrics/hps.py
ADDED
@@ -0,0 +1,77 @@
+import os
+from typing import Dict
+
+import torch
+from PIL import Image
+from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
+import huggingface_hub
+from hpsv2.utils import root_path, hps_version_map
+
+
+class HPSMetric:
+    def __init__(self):
+        self.hps_version = "v2.1"
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.model_dict = {}
+        self._initialize_model()
+
+    def _initialize_model(self):
+        if not self.model_dict:
+            model, preprocess_train, preprocess_val = create_model_and_transforms(
+                'ViT-H-14',
+                'laion2B-s32B-b79K',
+                precision='amp',
+                device=self.device,
+                jit=False,
+                force_quick_gelu=False,
+                force_custom_text=False,
+                force_patch_dropout=False,
+                force_image_size=None,
+                pretrained_image=False,
+                image_mean=None,
+                image_std=None,
+                light_augmentation=True,
+                aug_cfg={},
+                output_dict=True,
+                with_score_predictor=False,
+                with_region_predictor=False
+            )
+            self.model_dict['model'] = model
+            self.model_dict['preprocess_val'] = preprocess_val
+
+            # Load checkpoint
+            if not os.path.exists(root_path):
+                os.makedirs(root_path)
+            cp = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[self.hps_version])
+
+            checkpoint = torch.load(cp, map_location=self.device)
+            model.load_state_dict(checkpoint['state_dict'])
+            self.tokenizer = get_tokenizer('ViT-H-14')
+            model = model.to(self.device)
+            model.eval()
+
+    @property
+    def name(self) -> str:
+        return "hps"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        model = self.model_dict['model']
+        preprocess_val = self.model_dict['preprocess_val']
+
+        with torch.no_grad():
+            # Process the image
+            image_tensor = preprocess_val(image).unsqueeze(0).to(device=self.device, non_blocking=True)
+            # Process the prompt
+            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
+            # Calculate the HPS
+            with torch.cuda.amp.autocast():
+                outputs = model(image_tensor, text)
+                image_features, text_features = outputs["image_features"], outputs["text_features"]
+                logits_per_image = image_features @ text_features.T
+                hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+
+        return {"hps": float(hps_score[0])}
benchmark/metrics/image_reward.py
ADDED
@@ -0,0 +1,26 @@
+import os
+import tempfile
+from typing import Dict
+
+import ImageReward as RM
+from PIL import Image
+
+
+class ImageRewardMetric:
+    def __init__(self):
+        self.model = RM.load("ImageReward-v1.0")
+
+    @property
+    def name(self) -> str:
+        return "image_reward"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+            image.save(tmp.name)
+        score = self.model.score(prompt, [tmp.name])
+        os.unlink(tmp.name)
+        return {"image_reward": score}
benchmark/metrics/sharpness.py
ADDED
@@ -0,0 +1,24 @@
+from typing import Dict
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+class SharpnessMetric:
+    def __init__(self):
+        self.kernel_size = 3
+
+    @property
+    def name(self) -> str:
+        return "sharpness"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        img = np.array(image.convert('L'))
+        laplacian = cv2.Laplacian(img, cv2.CV_64F, ksize=self.kernel_size)
+        sharpness = laplacian.var()
+        return {"sharpness": float(sharpness)}
benchmark/metrics/vqa.py
ADDED
@@ -0,0 +1,25 @@
+import os
+import tempfile
+from typing import Dict
+
+import t2v_metrics
+from PIL import Image
+
+class VQAMetric:
+    def __init__(self):
+        self.metric = t2v_metrics.VQAScore(model="clip-flant5-xxl")
+
+    @property
+    def name(self) -> str:
+        return "vqa"
+
+    def compute_score(
+        self,
+        image: Image.Image,
+        prompt: str,
+    ) -> Dict[str, float]:
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+            image.save(tmp.name)
+        score = self.metric(images=[tmp.name], texts=[prompt])
+        os.unlink(tmp.name)
+        return {"vqa": score[0][0].item()}
benchmark/parti.py
ADDED
@@ -0,0 +1,28 @@
+from pathlib import Path
+from typing import Iterator, List, Tuple
+
+from datasets import load_dataset
+
+
+class PartiPrompts:
+    def __init__(self):
+        dataset = load_dataset("nateraw/parti-prompts")["train"]
+        shuffled_dataset = dataset.shuffle(seed=42)
+        selected_dataset = shuffled_dataset.select(range(800))
+        self.prompts = [row["Prompt"] for row in selected_dataset]
+
+    def __iter__(self) -> Iterator[Tuple[str, Path]]:
+        for i, prompt in enumerate(self.prompts):
+            yield prompt, Path(f"{i}.png")
+
+    @property
+    def name(self) -> str:
+        return "parti"
+
+    @property
+    def size(self) -> int:
+        return len(self.prompts)
+
+    @property
+    def metrics(self) -> List[str]:
+        return ["arniqa", "clip", "clip_iqa", "sharpness"]
evaluate.py
ADDED
@@ -0,0 +1,106 @@
+import argparse
+import json
+from pathlib import Path
+from typing import Dict
+
+from benchmark import create_benchmark
+from benchmark.metrics import create_metric
+from PIL import Image
+
+
+def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
+    """
+    Evaluate a benchmark's images using its specific metrics.
+
+    Args:
+        benchmark_type (str): Type of benchmark to evaluate
+        api_type (str): Type of API used to generate images
+        images_dir (Path): Base directory containing generated images
+
+    Returns:
+        Dict containing evaluation results
+    """
+    benchmark = create_benchmark(benchmark_type)
+
+    benchmark_dir = images_dir / api_type / benchmark_type
+    metadata_file = benchmark_dir / "metadata.jsonl"
+
+    if not metadata_file.exists():
+        raise FileNotFoundError(f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first.")
+
+    metadata = []
+    with open(metadata_file, "r") as f:
+        for line in f:
+            metadata.append(json.loads(line))
+
+    metrics = {metric_type: create_metric(metric_type) for metric_type in benchmark.metrics}
+
+    results = {
+        "api": api_type,
+        "benchmark": benchmark_type,
+        "metrics": {metric: 0.0 for metric in benchmark.metrics},
+        "avg_inference_time": 0.0,
+        "total_images": len(metadata)
+    }
+
+    for entry in metadata:
+        image_path = benchmark_dir / entry["filepath"]
+        if not image_path.exists():
+            continue
+
+        image = Image.open(image_path)
+
+        for metric_type, metric in metrics.items():
+            try:
+                score = metric.compute_score(image, entry["prompt"])
+                results["metrics"][metric_type] += score[metric_type]
+            except Exception as e:
+                print(f"Error computing {metric_type} for {image_path}: {str(e)}")
+
+        results["avg_inference_time"] += entry["inference_time"]
+
+    for metric in results["metrics"]:
+        results["metrics"][metric] /= len(metadata)
+    results["avg_inference_time"] /= len(metadata)
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate generated images using benchmark-specific metrics")
+    parser.add_argument("api_type", help="Type of API to evaluate")
+    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to evaluate")
+
+    args = parser.parse_args()
+
+    results_dir = Path("evaluation_results")
+    results_dir.mkdir(exist_ok=True)
+
+    results_file = results_dir / f"{args.api_type}.jsonl"
+    existing_results = set()
+
+    if results_file.exists():
+        with open(results_file, "r") as f:
+            for line in f:
+                result = json.loads(line)
+                existing_results.add(result["benchmark"])
+
+    for benchmark_type in args.benchmarks:
+        if benchmark_type in existing_results:
+            print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
+            continue
+
+        try:
+            print(f"Evaluating {args.api_type}/{benchmark_type}")
+            results = evaluate_benchmark(benchmark_type, args.api_type)
+
+            # Append results to file
+            with open(results_file, "a") as f:
+                f.write(json.dumps(results) + "\n")
+
+        except Exception as e:
+            print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
+
+
+if __name__ == "__main__":
+    main()
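The script is driven from the command line (API name first, then one or more benchmarks). The equivalent programmatic call, with placeholder arguments, would look like:

from evaluate import evaluate_benchmark

# Assumes sample.py has already produced images/fal/parti/metadata.jsonl
results = evaluate_benchmark("parti", "fal")
print(results["metrics"], results["avg_inference_time"])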
evaluation_results/.gitkeep
ADDED
File without changes
images/.gitkeep
ADDED
File without changes
pyproject.toml
ADDED
@@ -0,0 +1,22 @@
+[project]
+name = "inferbench"
+version = "0.1.0"
+requires-python = ">=3.12"
+dependencies = [
+    "datasets>=3.5.0",
+    "fal-client>=0.5.9",
+    "hpsv2>=1.2.0",
+    "huggingface-hub>=0.30.2",
+    "image-reward>=1.5",
+    "numpy>=2.2.5",
+    "opencv-python>=4.11.0.86",
+    "pillow>=11.2.1",
+    "python-dotenv>=1.1.0",
+    "replicate>=1.0.4",
+    "requests>=2.32.3",
+    "t2v-metrics>=1.2",
+    "together>=1.5.5",
+    "torch>=2.7.0",
+    "torchmetrics>=1.7.1",
+    "tqdm>=4.67.1",
+]
sample.py
ADDED
@@ -0,0 +1,110 @@
+import argparse
+import json
+from pathlib import Path
+from typing import List
+
+from tqdm import tqdm
+
+from api import create_api
+from benchmark import create_benchmark
+
+
+def generate_images(api_type: str, benchmarks: List[str]):
+    images_dir = Path("images")
+    api = create_api(api_type)
+
+    api_dir = images_dir / api_type
+    api_dir.mkdir(parents=True, exist_ok=True)
+
+    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
+        print(f"\nProcessing benchmark: {benchmark_type}")
+
+        benchmark = create_benchmark(benchmark_type)
+
+        if benchmark_type == "geneval":
+            benchmark_dir = api_dir / benchmark_type
+            benchmark_dir.mkdir(parents=True, exist_ok=True)
+
+            metadata_file = benchmark_dir / "metadata.jsonl"
+            existing_metadata = {}
+            if metadata_file.exists():
+                with open(metadata_file, "r") as f:
+                    for line in f:
+                        entry = json.loads(line)
+                        existing_metadata[entry["filepath"]] = entry
+
+            for metadata, folder_name in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
+                sample_path = benchmark_dir / folder_name
+                samples_path = sample_path / "samples"
+                samples_path.mkdir(parents=True, exist_ok=True)
+                image_path = samples_path / "0000.png"
+
+                if image_path.exists():
+                    continue
+
+                try:
+                    inference_time = api.generate_image(metadata["prompt"], image_path)
+
+                    metadata_entry = {
+                        "filepath": str(image_path),
+                        "prompt": metadata["prompt"],
+                        "inference_time": inference_time
+                    }
+
+                    existing_metadata[str(image_path)] = metadata_entry
+
+                except Exception as e:
+                    print(f"\nError generating image for prompt: {metadata['prompt']}")
+                    print(f"Error: {str(e)}")
+                    continue
+        else:
+            benchmark_dir = api_dir / benchmark_type
+            benchmark_dir.mkdir(parents=True, exist_ok=True)
+
+            metadata_file = benchmark_dir / "metadata.jsonl"
+            existing_metadata = {}
+            if metadata_file.exists():
+                with open(metadata_file, "r") as f:
+                    for line in f:
+                        entry = json.loads(line)
+                        existing_metadata[entry["filepath"]] = entry
+
+            for prompt, image_path in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
+                full_image_path = benchmark_dir / image_path
+
+                if full_image_path.exists():
+                    continue
+
+                try:
+                    inference_time = api.generate_image(prompt, full_image_path)
+
+                    metadata_entry = {
+                        "filepath": str(image_path),
+                        "prompt": prompt,
+                        "inference_time": inference_time
+                    }
+
+                    existing_metadata[str(image_path)] = metadata_entry
+
+                except Exception as e:
+                    print(f"\nError generating image for prompt: {prompt}")
+                    print(f"Error: {str(e)}")
+                    continue
+
+        with open(metadata_file, "w") as f:
+            for entry in existing_metadata.values():
+                f.write(json.dumps(entry) + "\n")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate images for specified benchmarks using a given API")
+    parser.add_argument("api_type", help="Type of API to use for image generation")
+    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
+
+    args = parser.parse_args()
+
+    generate_images(args.api_type, args.benchmarks)
+
+
+if __name__ == "__main__":
+    main()
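A minimal sketch of driving the generation step programmatically; the API and benchmark choices are placeholders (equivalent to running `python sample.py fal parti hps` from the shell):

from sample import generate_images

# Any api_map key and any benchmark_map keys should work here
generate_images("fal", ["parti", "hps"])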
uv.lock
ADDED
The diff for this file is too large to render.