from typing import Type

from api.aws import AWSBedrockAPI
from api.baseline import BaselineAPI
from api.fal import FalAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.pruna_dev import PrunaDevAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI

__all__ = [
    "create_api",
    "FluxAPI",
    "BaselineAPI",
    "FireworksAPI",
    "PrunaAPI",
    "ReplicateAPI",
    "TogetherAPI",
    "FalAPI",
    "PrunaDevAPI",
    "AWSBedrockAPI",
]

# Any api_type starting with this prefix (other than "pruna_dev") selects
# PrunaAPI; the remainder of the string is passed through as the speed mode.
_PRUNA_PREFIX = "pruna_"


def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev"
            - "pruna_<speed_mode>" (where <speed_mode> is the desired speed
              mode; it is passed unchanged to ``PrunaAPI``)
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided, including the bare
            prefix "pruna_" with no speed mode after it
    """
    # Must be checked before the generic prefix below; otherwise "dev" would
    # be treated as a PrunaAPI speed mode.
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    if api_type.startswith(_PRUNA_PREFIX):
        speed_mode = api_type[len(_PRUNA_PREFIX):]
        # Reject "pruna_" on its own instead of constructing PrunaAPI("").
        if not speed_mode:
            raise ValueError(
                f"Invalid API type: {api_type}. A speed mode must follow the "
                f"'{_PRUNA_PREFIX}' prefix (e.g. '{_PRUNA_PREFIX}<speed_mode>')"
            )
        return PrunaAPI(speed_mode)

    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }

    if api_type not in api_map:
        raise ValueError(
            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
        )

    return api_map[api_type]()