""" This files runs and saves the outputs for all example prompts. """ import os import hashlib import pickle, json from dataclasses import asdict from src.smc.inference import ( infer_pretrained, infer_smc_grad, infer_ft, PretrainedInferenceConfig, SMCGradInferenceConfig, FTInferenceConfig, InferenceOutput, ) examples = [ "A photo of a yellow bird and a black motorcycle", "A green stop sign in a red field", "A pink bicycle leaning against a fence near a river", "A cat in the style of Van Gogh’s Starry Night", "A stylish dog wearing sunglasses", "A photo of a blue clock and a white cup", "A dog on the moon", ] EXAMPLES_DIR = "examples" def short_hash(s): return hashlib.md5(s.encode()).hexdigest()[:8] def dataclass_to_json(obj, pretty=False): """Convert a dataclass instance to a JSON string.""" if not hasattr(obj, "__dataclass_fields__"): raise TypeError("Object must be a dataclass instance") # Convert to dict and sort keys to ensure stable serialization data = asdict(obj) if pretty: return json.dumps(data, indent=4, sort_keys=True) else: return json.dumps(data, separators=(",", ":"), sort_keys=True) def hash_dataclass(obj, algo="blake2s", digest_size=8): """Compute a deterministic hash for a dataclass instance.""" s = dataclass_to_json(obj) h = hashlib.new(algo) h.update(s.encode()) return h.hexdigest()[:digest_size * 2] # 2 hex chars per byte def does_out_exist(out_dir): return os.path.exists(os.path.join(out_dir, "out.pickle")) def save_out(out_dir, out: InferenceOutput): pickle.dump(out, open(os.path.join(out_dir, "out.pickle"), "wb")) for i, img in enumerate(out.images): img.save(os.path.join(out_dir, f"{i}.png")) def get_out_if_exists(method, config): out_dir = os.path.join(EXAMPLES_DIR, short_hash(config.prompt), method, hash_dataclass(config)) if does_out_exist(out_dir): return pickle.load(open(os.path.join(out_dir, "out.pickle"), "rb")) else: return None def main(): for prompt in examples: prompt_hash = short_hash(prompt) prompt_dir = os.path.join(EXAMPLES_DIR, prompt_hash) os.makedirs(prompt_dir, exist_ok=True) print(f"Running prompt: {prompt}") # Save prompt in file with open(os.path.join(prompt_dir, "prompt.txt"), "w") as f: f.write(prompt) config = PretrainedInferenceConfig(prompt=prompt) out_dir = os.path.join(prompt_dir, "pretrained", hash_dataclass(config)) if not does_out_exist(out_dir): os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "config.json"), "w") as f: f.write(dataclass_to_json(config, pretty=True)) out = infer_pretrained(config, device="cuda") save_out(out_dir, out) config = SMCGradInferenceConfig(prompt=prompt) out_dir = os.path.join(prompt_dir, "smc_grad", hash_dataclass(config)) if not does_out_exist(out_dir): os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "config.json"), "w") as f: f.write(dataclass_to_json(config, pretty=True)) out = infer_smc_grad(config, device="cuda") save_out(out_dir, out) config = FTInferenceConfig(prompt=prompt) out_dir = os.path.join(prompt_dir, "ft", hash_dataclass(config)) if not does_out_exist(out_dir): os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "config.json"), "w") as f: f.write(dataclass_to_json(config)) out = infer_ft(config, device="cuda") save_out(out_dir, out) if __name__ == "__main__": main()