# CondRef-AR / app.py
import gradio as gr
import torch
import json
from CondRefAR.pipeline import CondRefARPipeline
from transformers import AutoTokenizer, T5EncoderModel

# Simplification: use flan-t5-xl from transformers directly to extract the text embeddings
def build_t5(device, dtype):
    tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    enc = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype)
    enc = enc.to(device)
    enc.eval()
    return tok, enc

def text_to_emb(prompt, tok, enc, device, dtype):
    # Tokenize to a fixed length of 120 tokens and return the T5 encoder's last hidden states.
    inputs = tok([prompt], return_tensors="pt", padding="max_length", truncation=True,
                 return_attention_mask=True, add_special_tokens=True, max_length=120)
    with torch.no_grad():
        out = enc(input_ids=inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device))
    emb = out["last_hidden_state"].detach()  # [B, T, D]
    return emb.to(dtype)

def build_pipeline():
    # Choose device/dtype, load the GPT and VQ configs, and assemble the CondRef-AR pipeline.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    with open("configs/gpt_config.json", "r") as f:
        gpt_cfg = json.load(f)
    with open("configs/vq_config.json", "r") as f:
        vq_cfg = json.load(f)
    pipe = CondRefARPipeline.from_pretrained(".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype)
    tok, enc = build_t5(device, dtype)
    return pipe, tok, enc
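
# Instantiate the pipeline and text encoder once at import time so they are shared across requests.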
pipe, tok, enc = build_pipeline()

def infer(prompt, control_image, cfg_scale, temperature, top_k, top_p):
    # The ImageEditor value is a dict; use the RGB channels of its flattened "composite" canvas.
    emb = text_to_emb(prompt, tok, enc, pipe.device, pipe.dtype)
    imgs = pipe(emb, control_image["composite"][:, :, :3], cfg_scale=cfg_scale, temperature=temperature,
                top_k=int(top_k), top_p=top_p)  # top_k cast to int: sliders yield floats
    return imgs[0]

EXAMPLES = [
    [
        "Aerial view of a large industrial area with multiple buildings and roads. There are several roads and highways visible in the image, and there are several parking lots scattered throughout the area.",
        "assets/examples/example1.jpg",
        4.0, 1.0, 2000, 1.0,
    ],
    [
        "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building.",
        "assets/examples/example2.jpg",
        5.0, 0.95, 2500, 0.95,
    ],
]
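
# Gradio UI: prompt and sketch editor on the left; sampling controls, output, and examples on the right.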
with gr.Blocks(title="CondRef-AR", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## CondRef-AR: Controllable Aerial Image Generation")
    with gr.Row(equal_height=True):
        # Left column: prompt and control-image inputs
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe the city...")
            editor = gr.ImageEditor(
                type="numpy", crop_size="1:1", canvas_size=(512, 512),
                label="Image"
            )
            with gr.Row():
                btn_gen = gr.Button("Generate", variant="primary")
                btn_clear = gr.Button("Clear")
        # Right column: sampling parameters, output, and examples
        with gr.Column(scale=2):
            with gr.Accordion("Advanced settings", open=False):
                cfg_scale = gr.Slider(1, 8, value=4, step=0.5, label="CFG scale")
                temperature = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Temperature")
                top_k = gr.Slider(50, 4000, value=2000, step=50, label="top_k")
                top_p = gr.Slider(0.5, 1.0, value=1.0, step=0.01, label="top_p")
            output = gr.Image(type="pil", label="Result", height=512)
            # Clickable examples; with run_on_click=True, selecting one auto-fills the inputs and runs inference
            gr.Examples(
                examples=EXAMPLES,
                inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
                outputs=output,
                fn=infer,
                cache_examples=False,
                run_on_click=True,
                examples_per_page=2,
                label="Examples",
            )

    # Button events
    btn_gen.click(
        infer,
        inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
        outputs=output,
    )
    btn_clear.click(lambda: (None, None), outputs=[editor, output])

if __name__ == "__main__":
    demo.launch()