import spaces
import os
import random

from PIL import Image
import torch
import gradio as gr
import dotenv

from adapter import load_ip_adapter_model, get_file_path
from example import EXAMPLES

dotenv.load_dotenv(".env.local")

# Adapter weights/config must be provided via environment variables
# (.env.local is loaded above if present); the base model has defaults below.
ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID")
ADAPTER_MODEL_PATH = os.environ.get("ADAPTER_MODEL_PATH")
ADAPTER_CONFIG_PATH = os.environ.get("ADAPTER_CONFIG_PATH")
assert ADAPTER_REPO_ID is not None
assert ADAPTER_MODEL_PATH is not None
assert ADAPTER_CONFIG_PATH is not None

BASE_MODEL_REPO_ID = os.environ.get(
    "BASE_MODEL_REPO_ID", "p1atdev/animagine-xl-4.0-bnb-nf4"
)
BASE_MODEL_PATH = os.environ.get(
    "BASE_MODEL_PATH", "animagine-xl-4.0-opt.bnb_nf4.safetensors"
)
INITIAL_BATCH_SIZE = int(os.environ.get("INITIAL_BATCH_SIZE", 1))

# Resolve local checkpoint paths and build the IP-Adapter pipeline once at import time.
adapter_model_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_MODEL_PATH)
adapter_config_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_CONFIG_PATH)
base_model_path = get_file_path(BASE_MODEL_REPO_ID, BASE_MODEL_PATH)

model = load_ip_adapter_model(
    model_path=base_model_path,
    config_path=adapter_config_path,
    adapter_path=adapter_model_path,
)
model.to("cuda:0")


@spaces.GPU
def on_generate(
    prompt: str,
    negative_prompt: str,
    image: Image.Image | None,
    width: int,
    height: int,
    steps: int,
    cfg_scale: float,
    seed: int,
    randomize_seed: bool = True,
    num_images: int = 4,
    progress=gr.Progress(track_tqdm=True),
):
    if image is not None:
        image = image.convert("RGB")

    # Draw a fresh seed when the user asked for a random one.
    if randomize_seed:
        seed = random.randint(0, 2147483647)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        images = model.generate(
            prompt=[prompt] * num_images,  # repeat the prompt once per requested image
            negative_prompt=negative_prompt,
            reference_image=image,
            num_inference_steps=steps,
            cfg_scale=cfg_scale,
            width=width,
            height=height,
            seed=seed,
            do_offloading=False,
            device="cuda:0",
            max_token_length=225,
            execution_dtype=torch.bfloat16,
        )
        # Release cached VRAM so repeated requests don't accumulate allocations.
        torch.cuda.empty_cache()

    return images, seed


def main():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    value="masterpiece, best quality",
                    placeholder="masterpiece, best quality",
                    interactive=True,
                )
                input_image = gr.Image(
                    label="Reference Image",
                    type="pil",
                    height=600,
                )
                with gr.Accordion("Negative Prompt", open=False):
                    negative_prompt = gr.TextArea(
                        label="Negative Prompt",
                        show_label=False,
                        value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                        interactive=True,
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=896,
                        interactive=True,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=1152,
                        interactive=True,
                    )
                with gr.Accordion("Advanced options", open=False):
                    num_images = gr.Slider(
                        label="Number of images to generate",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=INITIAL_BATCH_SIZE,
                        interactive=True,
                    )
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=2147483647,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                            interactive=True,
                            scale=1,
                        )
                    steps = gr.Slider(
                        label="Inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                        interactive=True,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG scale",
                        minimum=3.0,
                        maximum=8.0,
                        step=0.5,
                        value=5.0,
                        interactive=True,
                    )
            with gr.Column():
                generate_button = gr.Button(
                    "Generate",
                    variant="primary",
                )
                output_image = gr.Gallery(
                    label="Generated images",
                    type="pil",
                    rows=2,
                    height="768px",
                    preview=True,
                    show_label=True,
                )
                comment = gr.Markdown(
                    label="Comment",
                    visible=False,
                )
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=[input_image, prompt, width, height, comment],
                    cache_examples=False,
                )

        gr.on(
            triggers=[generate_button.click],
            fn=on_generate,
            inputs=[
                prompt,
                negative_prompt,
                input_image,
                width,
                height,
                steps,
                cfg_scale,
                seed,
                randomize_seed,
                num_images,
            ],
            outputs=[output_image, seed],
        )

    demo.launch()


if __name__ == "__main__":
    main()