Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import json | |
import copy | |
from torch.utils.data import Dataset | |
import os | |
from PIL import Image | |
def geneval_save_fn(image, metadata, root_path):
    """Save one generated sample and its metadata in a per-folder layout.

    Writes the image to ``<root_path>/<metadata['filename']>/samples/<metadata['seed']>.png``
    and the metadata to ``<root_path>/<metadata['filename']>/metadata.jsonl``.

    Args:
        image: array passed to ``PIL.Image.fromarray`` (presumably an
            ``HxWxC`` ``uint8`` array -- TODO confirm against the caller).
        metadata: per-sample dict; must contain ``"filename"`` and ``"seed"``.
            Callable values (such as a ``"save_fn"`` hook) are not
            JSON-serializable and are omitted from the written metadata.
            The dict itself is not modified.
        root_path: output root directory; created if missing.
    """
    sample_dir = os.path.join(root_path, metadata["filename"])
    samples_dir = os.path.join(sample_dir, "samples")
    # makedirs creates sample_dir as a parent; exist_ok=True makes an
    # exists() pre-check unnecessary and avoids a check-then-create race.
    os.makedirs(samples_dir, exist_ok=True)

    image_path = os.path.join(samples_dir, f"{metadata['seed']}.png")
    Image.fromarray(image).save(image_path)

    # json.dump raises TypeError on callables (e.g. metadata["save_fn"]),
    # so serialize a filtered copy instead of the caller's dict.
    serializable = {key: value for key, value in metadata.items() if not callable(value)}
    metadata_path = os.path.join(sample_dir, "metadata.jsonl")
    # "w" overwrites: the file holds the metadata from the latest call for this folder.
    with open(metadata_path, "w", encoding="utf-8") as fp:
        json.dump(serializable, fp)
class GenEvalDataset(Dataset):
    """Dataset yielding ``(latent, prompt, metadata)`` for every sample of every prompt record.

    Each non-blank line of the JSONL file at ``meta_json_path`` is one prompt
    record (a dict that must contain a ``"prompt"`` key). Every record is
    expanded into ``num_samples_per_instance`` samples, so the dataset length
    is ``num_instances * num_samples_per_instance``.

    Args:
        meta_json_path: path to the JSONL metadata file.
        num_samples_per_instance: number of samples generated per record.
        latent_shape: shape passed to ``torch.randn`` for each initial latent.
    """

    def __init__(self, meta_json_path, num_samples_per_instance, latent_shape):
        self.latent_shape = latent_shape
        self.meta_json_path = meta_json_path
        with open(meta_json_path, encoding="utf-8") as fp:
            # Skip blank lines (e.g. a trailing empty line): json.loads rejects them.
            self.metadatas = [json.loads(line) for line in fp if line.strip()]
        self.num_instances = len(self.metadatas)
        self.num_samples_per_instance = num_samples_per_instance
        self.num_samples = self.num_instances * self.num_samples_per_instance

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(latent, prompt, metadata)`` for flat sample index ``idx``.

        ``idx`` is split into a record index and a per-record sample index.
        The sample index seeds the latent RNG, so sample ``k`` of every record
        starts from the same noise. ``metadata`` is a copy of the record with
        ``"seed"``, ``"filename"`` (the flat index, i.e. one output folder per
        sample) and ``"save_fn"`` added.
        """
        instance_idx, sample_idx = divmod(idx, self.num_samples_per_instance)
        # Deep copy so the per-sample keys added below never leak into the shared records.
        metadata = copy.deepcopy(self.metadatas[instance_idx])
        condition = metadata["prompt"]
        generator = torch.Generator().manual_seed(sample_idx)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        metadata["seed"] = sample_idx
        metadata["filename"] = f"{idx}"
        metadata["save_fn"] = geneval_save_fn
        return latent, condition, metadata