File size: 4,176 Bytes
6ffba01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import json
from CondRefAR.pipeline import CondRefARPipeline
from transformers import AutoTokenizer, T5EncoderModel

# Simplified text encoder: load flan-t5-xl directly from transformers and use
# its encoder stack to produce prompt embeddings.
def build_t5(device, dtype):
    """Load the flan-t5-xl tokenizer and encoder used for prompt embeddings.

    Args:
        device: Device the encoder is moved to (e.g. "cuda" or "cpu").
        dtype: torch dtype the encoder weights are loaded in.

    Returns:
        A ``(tokenizer, encoder)`` tuple. The encoder lives on ``device`` and
        has been switched to eval mode.
    """
    checkpoint = "google/flan-t5-xl"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=dtype).to(device)
    # Inference only: put the encoder in eval mode.
    encoder.eval()
    return tokenizer, encoder

def text_to_emb(prompt, tok, enc, device, dtype):
    """Encode a single prompt into T5 encoder hidden states.

    The prompt is tokenized as a batch of one, padded/truncated to a fixed
    length of 120 tokens, and run through ``enc`` without gradient tracking.

    Args:
        prompt: Text to encode.
        tok: Tokenizer callable (as returned by ``build_t5``).
        enc: Encoder whose output exposes ``"last_hidden_state"``.
        device: Device the token ids and attention mask are moved to.
        dtype: dtype of the returned embedding tensor.

    Returns:
        The encoder's ``last_hidden_state`` cast to ``dtype``, laid out as
        ``[B, T, D]`` (here B is 1).
    """
    batch = tok(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        max_length=120,
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)
    with torch.no_grad():
        hidden = enc(input_ids=ids, attention_mask=mask)["last_hidden_state"]
    return hidden.detach().to(dtype)

def build_pipeline():
    """Build the CondRef-AR pipeline together with the T5 prompt encoder.

    Uses CUDA with bfloat16 when a GPU is available, otherwise CPU with
    float32. The GPT and VQ configs are read from ``configs/`` and the
    pipeline weights are loaded from ``"."``; both paths are relative to the
    current working directory.

    Returns:
        ``(pipe, tok, enc)``: the pipeline, the T5 tokenizer and the T5 encoder.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.bfloat16 if use_cuda else torch.float32
    # Explicit encoding so config parsing does not depend on the platform locale.
    with open("configs/gpt_config.json", "r", encoding="utf-8") as f:
        gpt_cfg = json.load(f)
    with open("configs/vq_config.json", "r", encoding="utf-8") as f:
        vq_cfg = json.load(f)
    pipe = CondRefARPipeline.from_pretrained(".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype)
    tok, enc = build_t5(device, dtype)
    return pipe, tok, enc

# Load the pipeline and T5 encoder once at import time; `infer` reads these
# module-level objects on every call.
pipe, tok, enc = build_pipeline()

def infer(prompt, control_image, cfg_scale, temperature, top_k, top_p):
    """Generate an image from a text prompt and a control image.

    Args:
        prompt: Text description passed to the T5 encoder.
        control_image: Value of the ``gr.ImageEditor`` (``type="numpy"``): a
            dict whose ``"composite"`` entry holds the edited image array, or
            ``None`` when the editor is empty (e.g. after pressing "Clear").
        cfg_scale: Guidance scale (the "CFG scale" slider).
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff; coerced to ``int``.
        top_p: Top-p sampling cutoff.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no control image is available, so the UI shows a readable
            message instead of a raw ``TypeError`` traceback.
    """
    # The Clear button resets the editor to None, and the composite itself may
    # be None; indexing either directly would crash.
    composite = None if control_image is None else control_image.get("composite")
    if composite is None:
        raise gr.Error("Please upload or draw a control image first.")

    emb = text_to_emb(prompt, tok, enc, pipe.device, pipe.dtype)
    imgs = pipe(
        emb,
        # Keep only the first three channels, discarding any alpha channel.
        composite[:, :, :3],
        cfg_scale=cfg_scale,
        temperature=temperature,
        # Slider values may arrive as floats; top-k is a count.
        top_k=int(top_k),
        top_p=top_p,
    )
    return imgs[0]


# Example rows for gr.Examples. Each row supplies, in order, the values for
# [prompt, editor, cfg_scale, temperature, top_k, top_p]; the image entry is a
# file path that is loaded into the editor.
EXAMPLES = [
    [
        "Aerial view of a large industrial area with multiple buildings and roads. There are several roads and highways visible in the image, and there are several parking lots scattered throughout the area.",
        "assets/examples/example1.jpg",
        4.0, 1.0, 2000, 1.0,
    ],
    [
        "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building.",
        "assets/examples/example2.jpg",
        5.0, 0.95, 2500, 0.95,
    ],
]


with gr.Blocks(title="CondRef-AR", theme=gr.themes.Soft()) as demo:
    # Two-column layout: inputs on the left; sampling settings, the result and
    # clickable examples on the right. Generation goes through `infer`.
    gr.Markdown("## CondRef-AR: Controllable Aerial Image Generation")

    with gr.Row(equal_height=True):
        # Left column: prompt, control-image editor and action buttons.
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe the city...")
            editor = gr.ImageEditor(
                type="numpy", crop_size="1:1", canvas_size=(512, 512),
                label="Image"
            )
            with gr.Row():
                btn_gen = gr.Button("Generate", variant="primary")
                btn_clear = gr.Button("Clear")

        # Right column: sampling parameters, output image and examples.
        with gr.Column(scale=2):
            with gr.Accordion("Advanced settings", open=False):
                cfg_scale = gr.Slider(1, 8, value=4, step=0.5, label="CFG scale")
                temperature = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Temperature")
                top_k = gr.Slider(50, 4000, value=2000, step=50, label="top_k")
                top_p = gr.Slider(0.5, 1.0, value=1.0, step=0.01, label="top_p")

            output = gr.Image(type="pil", label="Result", height=512)

            # Clickable examples: clicking one fills the inputs and runs `infer`.
            # With cache_examples=False, Gradio only fills the inputs unless
            # run_on_click=True, so it is set explicitly.
            gr.Examples(
                examples=EXAMPLES,
                inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
                outputs=output,
                fn=infer,
                cache_examples=False,
                run_on_click=True,
                examples_per_page=2,
                label="Examples"
            )

    # Button wiring: Generate runs inference; Clear empties the editor and the result.
    btn_gen.click(
        infer,
        inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
        outputs=output
    )
    btn_clear.click(lambda: (None, None), outputs=[editor, output])

# Launch the Gradio app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()