import os
import time
from pathlib import Path
from typing import Any

import replicate
from dotenv import load_dotenv

from api.flux import FluxAPI


class PrunaAPI(FluxAPI):
    """Image-generation backend that runs Pruna's "juiced" FLUX.1 model on Replicate.

    The Replicate API token is read from the ``REPLICATE_API_TOKEN`` environment
    variable (after loading a ``.env`` file via ``load_dotenv``).
    """

    # Pinned Replicate model version used for every generation request.
    _MODEL = (
        "prunaai/flux.1-juiced:"
        "58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410"
    )

    def __init__(self, speed_mode: str):
        """Configure the backend for a given Replicate ``speed_mode``.

        Args:
            speed_mode: Value forwarded verbatim as the model's ``speed_mode``
                input. Its first whitespace-separated word, lower-cased, is used
                to build this backend's ``name``.

        Raises:
            ValueError: if ``REPLICATE_API_TOKEN`` is not set.
        """
        self._speed_mode = speed_mode
        # Only the first word identifies the mode (e.g. "Juiced ..." -> "juiced").
        # split() also tolerates leading whitespace; empty input yields "".
        words = speed_mode.split()
        self._speed_mode_name = words[0].lower() if words else ""
        load_dotenv()
        self._api_key = os.getenv("REPLICATE_API_TOKEN")
        if not self._api_key:
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    @property
    def name(self) -> str:
        """Identifier for this backend, e.g. ``pruna_<first word of speed_mode>``."""
        return f"pruna_{self._speed_mode_name}"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate one 1:1 PNG for *prompt*, save it to *save_path*, return latency.

        The returned value is the number of seconds spent in the
        ``replicate.run`` call only. Saving the image, including reading the
        returned output, happens after the timer stops.

        Raises:
            RuntimeError: if Replicate returns an empty result.
        """
        # perf_counter is monotonic and high-resolution: the right clock for durations.
        start_time = time.perf_counter()
        result = replicate.run(
            self._MODEL,
            input={
                "seed": 0,  # fixed seed for reproducible outputs across runs
                "prompt": prompt,
                "guidance": 3.5,
                "num_outputs": 1,
                "aspect_ratio": "1:1",
                "output_format": "png",
                "speed_mode": self._speed_mode,
                "num_inference_steps": 28,
            },
        )
        elapsed = time.perf_counter() - start_time

        if not result:
            raise RuntimeError("No result returned from Replicate API")
        self._save_image_from_result(result, save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path) -> None:
        """Write the image bytes from a Replicate output to *save_path*.

        *result* is expected to be a file-like output exposing ``read()``, or a
        list/tuple of such outputs, in which case the first one is saved.
        NOTE(review): which shape ``replicate.run`` returns depends on the
        model's output schema and the client version; confirm for this model.

        Raises:
            RuntimeError: if *result* is an empty list/tuple.
        """
        if isinstance(result, (list, tuple)):
            if not result:
                raise RuntimeError("No result returned from Replicate API")
            result = result[0]
        # Read before touching the destination so a failed read/download does
        # not leave an empty or truncated image file behind.
        data = result.read()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(data)