import os
import time
from pathlib import Path
from typing import Any

import replicate
from dotenv import load_dotenv

from api.flux import FluxAPI


class PrunaDevAPI(FluxAPI):
    """Generate images with the ``prunaai/flux.1-dev`` model on Replicate."""

    # Replicate model identifier, pinned to a specific version hash.
    _MODEL = (
        "prunaai/flux.1-dev:"
        "938a4eb31a87d65fb7b23fc300fb5b7ab88a36844bb26e54e1d1dec7acf4eefe"
    )

    def __init__(self):
        """Load ``.env`` and require ``REPLICATE_API_TOKEN`` to be set.

        Raises:
            ValueError: If ``REPLICATE_API_TOKEN`` is missing or empty.
        """
        load_dotenv()
        # NOTE(review): the token is only validated here and never passed to
        # replicate.run explicitly; presumably the replicate client reads the
        # same environment variable on its own — confirm.
        self._api_key = os.getenv("REPLICATE_API_TOKEN")
        if not self._api_key:
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    @property
    def name(self) -> str:
        """Short identifier for this provider."""
        return "pruna_dev"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate one image for ``prompt`` and write it to ``save_path``.

        Uses a fixed seed (0), guidance 3.5, 28 inference steps, a 1:1 PNG
        output and the "Juiced" speed mode.

        Args:
            prompt: Text prompt sent to the model.
            save_path: Destination file; parent directories are created.

        Returns:
            Seconds spent in the ``replicate.run`` call only (saving the
            image to disk is not included).

        Raises:
            RuntimeError: If Replicate returns an empty/falsy result.
        """
        # perf_counter is monotonic, unlike time.time(), so clock
        # adjustments cannot distort the measured latency.
        start_time = time.perf_counter()
        result = replicate.run(
            self._MODEL,
            input={
                "seed": 0,
                "prompt": prompt,
                "guidance": 3.5,
                "num_outputs": 1,
                "aspect_ratio": "1:1",
                "output_format": "png",
                "speed_mode": "Juiced 🔥 (default)",
                "num_inference_steps": 28,
            },
        )
        elapsed = time.perf_counter() - start_time

        if not result:
            raise RuntimeError("No result returned from Replicate API")
        self._save_image_from_result(result, save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path):
        """Write the image bytes from a Replicate output to ``save_path``.

        Args:
            result: A file-like output exposing ``read()``, or a list/tuple
                of such outputs, in which case only the first is saved.
            save_path: Destination file; parent directories are created.

        Raises:
            RuntimeError: If ``result`` is an empty list/tuple.
        """
        # Depending on the model and replicate client version, run() may
        # return a list of outputs (one per requested image) instead of a
        # single object; only one image is requested, so take the first.
        if isinstance(result, (list, tuple)):
            if not result:
                raise RuntimeError("No result returned from Replicate API")
            result = result[0]
        # Read before touching the file so a failed download does not leave
        # a truncated/empty image behind at save_path.
        data = result.read()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(data)