import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
from accelerate import Accelerator
import os
import time
import math
import json
from torchvision import transforms
from safetensors.torch import load_file
from networks import asylora_flux as lora_flux
from library import flux_utils, strategy_flux
import flux_minimal_inference_asylora as flux_train_utils
import logging
from huggingface_hub import login
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

accelerator = Accelerator(mixed_precision='bf16', device_placement=True)

hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)

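# Generation domains and their 1-based ids; the id picks the matching asymmetric LoRA up-branch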
domain_index = {
    'LEGO': 1, 'Cook': 2, 'Painting': 3, 'Icon': 4, 'Landscape illustration': 5,
    'Portrait': 6, 'Transformer': 7, 'Sand art': 8, 'Illustration': 9, 'Sketch': 10,
    'Clay toys': 11, 'Clay sculpture': 12, 'Zbrush Modeling': 13, 'Wood sculpture': 14,
    'Ink painting': 15, 'Pencil sketch': 16, 'Fabric toys': 17, 'Oil painting': 18,
    'Jade Carving': 19, 'Line draw': 20, 'Emoji': 21
}

lora_paths = {
    "9 frame": "asymmetric_lora/asymmetric_lora_9f_general.safetensors",
    "4 frame": "asymmetric_lora/asymmetric_lora_4f_general.safetensors"
}

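# Hub repos and filenames for the FP8 FLUX checkpoint, text encoders, autoencoder, and asymmetric LoRA weights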
flux_repo_id = "Kijai/flux-fp8"
flux_file = "flux1-dev-fp8.safetensors"
lora_repo_id = "showlab/makeanything"
clip_repo_id = "comfyanonymous/flux_text_encoders"
t5xxl_file = "t5xxl_fp16.safetensors"
clip_l_file = "clip_l.safetensors"
ae_repo_id = "black-forest-labs/FLUX.1-dev"
ae_file = "ae.safetensors"

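# Model handles, populated by load_target_model()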
model = None
clip_l = None
t5xxl = None
ae = None
lora_model = None


def download_file(repo_id, file_name):
    return hf_hub_download(repo_id=repo_id, filename=file_name)


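# Download the required checkpoints, load the base models on CPU, and attach the domain-specific asymmetric LoRA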
def load_target_model(frame, domain):
    global model, clip_l, t5xxl, ae, lora_model

    BASE_FLUX_CHECKPOINT = download_file(flux_repo_id, flux_file)
    CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
    T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
    AE_PATH = download_file(ae_repo_id, ae_file)
    LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])

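    # Load the FP8 flow model plus bf16 CLIP-L, T5-XXL, and autoencoder onto CPU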
logger.info("Loading models...") |
|
_, model = flux_utils.load_flow_model( |
|
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False |
|
) |
|
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False) |
|
clip_l.eval() |
|
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False) |
|
t5xxl.eval() |
|
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False) |
|
logger.info("Models loaded successfully.") |
|
|
|
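    # Rebuild the asymmetric LoRA network from the checkpoint and select the chosen domain's up-branch in every UNet LoRA module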
    multiplier = 1.0
    weights_sd = load_file(LORA_WEIGHTS_PATH)
    lora_ups_num = 10 if frame == "9 frame" else 21
    lora_model, _ = lora_flux.create_network_from_weights(
        multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num=lora_ups_num
    )
    for sub_lora in lora_model.unet_loras:
        sub_lora.set_lora_up_cur(domain_index[domain] - 1)

    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
    lora_model.eval()

    logger.info("Models loaded successfully.")
    return "Models loaded successfully. Using Frame: {}, Domain: {}".format(frame, domain)


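# ZeroGPU entry point: text encoding, denoising, and decoding for a single prompt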
@spaces.GPU(duration=180)
def infer(prompt, frame, seed=0):
    global model, clip_l, t5xxl, ae, lora_model
    if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
        logger.error("Models not loaded. Please load the models first.")
        return None

    frame_num = int(frame[0:1])  # "4 frame" -> 4, "9 frame" -> 9

    logger.info(f"Started generating image with prompt: {prompt}")

    lora_model.to("cuda")

    model.eval()
    clip_l.eval()
    t5xxl.eval()
    ae.eval()

    logger.info(f"Using seed: {seed}")

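    # Encode the prompt with CLIP-L and T5-XXL on the GPU; keep the autoencoder on CPU for now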
    ae.to("cpu")
    clip_l.to(device)
    t5xxl.to(device)

    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
    text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
        tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True
    )

    logger.debug("Prompt encoded.")

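    # Resolution depends on the frame count: 1024 px for 4-frame grids, 1056 px for 9-frame grids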
    width = 1024 if frame_num == 4 else 1056
    height = 1024 if frame_num == 4 else 1056

    packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16)

    torch.manual_seed(seed)
    noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
    logger.debug("Noise prepared.")

    timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True)
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)

    t5_attn_mask = t5_attn_mask.to(device)

    logger.debug("Image generation parameters set.")

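    # Minimal stand-in for an argparse-style namespace; only the frame count is attached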
    args = lambda: None
    args.frame_num = frame_num

    # Move the text encoders back to CPU and free cached memory before the flow model goes on the GPU
    clip_l.to("cpu")
    t5xxl.to("cpu")

    torch.cuda.empty_cache()
    model.to(device)

    print(f"Model device: {model.device}")
    print(f"Noise device: {noise.device}")
    print(f"Image IDs device: {img_ids.device}")
    print(f"T5 output device: {t5_out.device}")
    print(f"Text IDs device: {txt_ids.device}")
    print(f"L pooled device: {l_pooled.device}")

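    # Run the denoising loop without gradients under bf16 autocast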
    with accelerator.autocast(), torch.no_grad():
        x = flux_train_utils.denoise(
            model,
            noise,
            img_ids,
            t5_out,
            txt_ids,
            l_pooled,
            timesteps,
            guidance=4.0,
            t5_attn_mask=t5_attn_mask,
            cfg_scale=1.0,
        )
    logger.debug("Denoising process completed.")

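    # Unpack the latents and decode them with the autoencoder, which is moved to the GPU only for this step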
    x = x.float()
    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
    model.to("cpu")
    ae.to(device)
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    logger.debug("Latents decoded into image.")
    ae.to("cpu")

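    # Map [-1, 1] outputs to 8-bit RGB and return the single batch element as a PIL image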
    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    logger.info("Image generation completed.")
    torch.cuda.empty_cache()

    return generated_image


def update_domains(frame):
    domains_dict = {
        "4 frame": [
            "LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
            "Portrait", "Transformer", "Sand art", "Illustration", "Sketch",
            "Clay toys", "Clay sculpture", "Zbrush Modeling", "Wood sculpture", "Ink painting",
            "Pencil sketch", "Fabric toys", "Oil painting", "Jade Carving", "Line draw", "Emoji"
        ],
        "9 frame": [
            "LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
            "Portrait", "Transformer", "Sand art", "Illustration", "Sketch"
        ]
    }
    return gr.Dropdown(choices=domains_dict[frame], label="Select Domains")


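# Gradio UI: model/domain selection, prompt and seed inputs, generated image output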
with gr.Blocks() as demo:
    gr.Markdown("## Asymmetric LoRA Generation")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                with gr.Column(scale=1):
                    frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Model")
                with gr.Column(scale=2):
                    domain_selector = gr.Dropdown(
                        choices=[
                            "LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
                            "Portrait", "Transformer", "Sand art", "Illustration", "Sketch",
                            "Clay toys", "Clay sculpture", "Zbrush Modeling", "Wood sculpture", "Ink painting",
                            "Pencil sketch", "Fabric toys", "Oil painting", "Jade Carving", "Line draw", "Emoji"
                        ],
                        label="Select Domains"
                    )

            load_button = gr.Button("Load Model")

        with gr.Column(scale=1):
            status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=8)
            with gr.Row():
                seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=42)
            run_button = gr.Button("Generate Image")

        with gr.Column(scale=1):
            result_image = gr.Image(label="Generated Image", interactive=False)

    # Refresh the domain list whenever the frame layout changes
    frame_selector.change(update_domains, inputs=frame_selector, outputs=domain_selector)

    load_button.click(fn=load_target_model, inputs=[frame_selector, domain_selector], outputs=[status_box])

    run_button.click(fn=infer, inputs=[prompt, frame_selector, seed], outputs=[result_image])

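    # Bundled examples in the order (frame, domain, prompt, seed)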
gr.Markdown("### Examples") |
|
examples = [ |
|
[ |
|
"9 frame", |
|
"LEGO", |
|
"sks1, 3*3 puzzle of 9 sub-images, step-by-step construction process of a LEGO model,<image-1> Lay down a gray plate as a road surface.<image-2> Position two red 2x4 bricks side by side to start forming a sports car’s chassis.<image-3> Attach black slope bricks at the front, shaping a sleek hood.<image-4> Insert transparent pieces at the front for headlights.<image-5> Clip on black wheel assemblies at each corner.<image-6> Add a windshield piece and a small black steering wheel inside.<image-7> Place smooth tiles on top to create a glossy roof.<image-8> Add side mirrors and a spoiler at the back.<image-9> Conclude by placing a minifigure driver behind the wheel, ready to race.", |
|
1855705978 |
|
], |
|
[ |
|
"9 frame", |
|
"Portrait", |
|
"sks6, 3*3 puzzle of 9 sub-images, step-by-step portrait painting process, woman with blonde curly hair", |
|
1062070717 |
|
], |
|
[ |
|
"9 frame", |
|
"Sand art", |
|
"sks8, 3*3 puzzle of 9 sub-images, step-by-step description of sand art creation, <image-1>: The outline of a classic pirate ship is drawn, capturing its sails and hull. <image-2>: Basic shapes of the ship’s structure and masts are added, defining its adventurous form. <image-3>: Details of the sails and rigging begin to appear, adding complexity. <image-4>: Shadows and highlights enhance the ship’s three-dimensional appearance. <image-5>: The ship’s deck and cannons are refined, giving it character. <image-6>: Additional elements like waves and seagulls are added for movement. <image-7>: A backdrop of a stormy sea with dark clouds is introduced, adding drama. <image-8>: Further details like lightning and crashing waves are sketched for intensity. <image-9>: Final touches include vibrant blues and grays, completing the thrilling pirate ship scene.", |
|
641262478 |
|
], |
|
] |
|
|
|
gr.Examples( |
|
examples=examples, |
|
inputs=[frame_selector, domain_selector, prompt, seed], |
|
outputs=[result_image], |
|
cache_examples=False |
|
) |
|
|
|
|
|
demo.launch()