import json

import gradio as gr
import torch
from transformers import AutoTokenizer, T5EncoderModel

from CondRefAR.pipeline import CondRefARPipeline

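
# Frozen FLAN-T5-XL text encoder: tokenizer + encoder used to embed prompts,
# moved to the target device and kept in eval mode.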
def build_t5(device, dtype):
    tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    enc = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype)
    enc = enc.to(device)
    enc.eval()
    return tok, enc

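
# Embed one prompt: pad/truncate to a fixed 120-token window and return the
# T5 last hidden state, cast to the pipeline dtype.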
def text_to_emb(prompt, tok, enc, device, dtype):
    inputs = tok(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        max_length=120,
    )
    with torch.no_grad():
        out = enc(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        )
    emb = out["last_hidden_state"].detach()
    return emb.to(dtype)

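
# Instantiate the pipeline from the current directory (".") with the GPT and
# VQ configs; prefer CUDA with bfloat16, else fall back to CPU with float32.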
def build_pipeline():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    with open("configs/gpt_config.json", "r") as f:
        gpt_cfg = json.load(f)
    with open("configs/vq_config.json", "r") as f:
        vq_cfg = json.load(f)
    pipe = CondRefARPipeline.from_pretrained(
        ".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype
    )
    tok, enc = build_t5(device, dtype)
    return pipe, tok, enc

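
# Build everything once at import time so all requests share the same models.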
pipe, tok, enc = build_pipeline()

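
# Gradio callback: the ImageEditor returns an EditorValue dict; its
# "composite" layer is the edited image, and only its RGB channels are
# passed to the pipeline as the control input.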
def infer(prompt, control_image, cfg_scale, temperature, top_k, top_p):
    emb = text_to_emb(prompt, tok, enc, pipe.device, pipe.dtype)
    imgs = pipe(
        emb,
        control_image["composite"][:, :, :3],
        cfg_scale=cfg_scale,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    return imgs[0]

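
# Each row matches the Examples inputs:
# [prompt, control image path, cfg_scale, temperature, top_k, top_p].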
EXAMPLES = [
    [
        "Aerial view of a large industrial area with multiple buildings. Several roads and highways are visible in the image, and parking lots are scattered throughout the area.",
        "assets/examples/example1.jpg",
        4.0, 1.0, 2000, 1.0,
    ],
    [
        "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building.",
        "assets/examples/example2.jpg",
        5.0, 0.95, 2500, 0.95,
    ],
]

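
# UI layout: prompt and editable control image on the left; sampling
# controls and the generated result on the right.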
with gr.Blocks(title="CondRef-AR", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## CondRef-AR: Controllable Aerial Image Generation")

    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe the city...")
            editor = gr.ImageEditor(
                type="numpy", crop_size="1:1", canvas_size=(512, 512),
                label="Image",
            )
            with gr.Row():
                btn_gen = gr.Button("Generate", variant="primary")
                btn_clear = gr.Button("Clear")

        with gr.Column(scale=2):
            with gr.Accordion("Advanced settings", open=False):
                cfg_scale = gr.Slider(1, 8, value=4, step=0.5, label="CFG scale")
                temperature = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Temperature")
                top_k = gr.Slider(50, 4000, value=2000, step=50, label="top_k")
                top_p = gr.Slider(0.5, 1.0, value=1.0, step=0.01, label="top_p")

            output = gr.Image(type="pil", label="Result", height=512)

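
    # Clickable example prompts and control images (outputs are not pre-cached).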
    gr.Examples(
        examples=EXAMPLES,
        inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
        outputs=output,
        fn=infer,
        cache_examples=False,
        examples_per_page=2,
        label="Examples",
    )

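
    # Wire the buttons: Generate runs infer; Clear resets the editor and result.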
    btn_gen.click(
        infer,
        inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
        outputs=output,
    )
    btn_clear.click(lambda: (None, None), outputs=[editor, output])

if __name__ == "__main__":
    demo.launch()