File size: 4,311 Bytes
bf40ad0
5226632
f53fb95
7cfa686
 
 
 
f53fb95
7cfa686
 
f53fb95
7cfa686
 
 
 
f53fb95
7cfa686
 
 
 
 
 
 
 
 
 
f53fb95
7cfa686
 
 
9a25740
7cfa686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf40ad0
5226632
 
 
 
9a25740
f53fb95
5226632
7cfa686
 
 
 
9a25740
f53fb95
7cfa686
9a25740
f53fb95
7cfa686
f53fb95
7cfa686
 
 
 
 
 
 
 
 
 
5226632
7cfa686
5226632
7cfa686
9a25740
 
7cfa686
9a25740
 
7cfa686
 
 
 
9a25740
7cfa686
 
 
 
 
 
9a25740
7cfa686
bf40ad0
 
7cfa686
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gc
import os
import random

import gradio as gr
import torch
from diffusers import EulerDiscreteScheduler
from gradio_image_annotation import image_annotator
from huggingface_hub import hf_hub_download

from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from migc.migc_utils import seed_everything, load_migc

# Fetch the MIGC checkpoint and the two style base checkpoints from the
# Hugging Face Hub; each call returns the path of the downloaded local file.
migc_ckpt_path = hf_hub_download(
    repo_id="limuloo1999/MIGC",
    filename="MIGC_SD14.ckpt",
)
RV_path = hf_hub_download(
    repo_id="SG161222/Realistic_Vision_V6.0_B1_noVAE",
    filename="Realistic_Vision_V6.0_NV_B1.safetensors",
)
anime_path = hf_hub_download(
    repo_id="ckpt/cetus-mix",
    filename="cetusMix_v4.safetensors",
)

# -------- Style switcher --------
class StyleSwitcher:
    """Lazily load one MIGC-patched Stable Diffusion pipeline per art style.

    Only one pipeline is kept at a time: switching styles releases the
    previous pipeline before the next checkpoint is loaded. A single
    ``AttentionStore`` instance is shared by every pipeline built here.
    """

    def __init__(self):
        self.pipe = None  # currently loaded pipeline, or None
        self.attn_store = AttentionStore()  # shared with every pipeline built here
        # Style name -> local checkpoint path (downloaded at module import).
        self.styles = {
            "realistic": RV_path,
            "anime": anime_path
        }
        self.current_style = None  # style of self.pipe; None when nothing is loaded

    def load_model(self, style):
        """Return a ready-to-use pipeline for *style*, loading it if needed.

        Args:
            style: Key of ``self.styles`` (``"realistic"`` or ``"anime"``).

        Returns:
            The cached pipeline when *style* is already loaded; otherwise a
            freshly built pipeline, which becomes the new cached one.

        Raises:
            KeyError: *style* is not a known style. The currently loaded
                pipeline is left untouched in that case.
        """
        if style == self.current_style and self.pipe is not None:
            return self.pipe

        # Validate before tearing anything down so a bad style name cannot
        # destroy the pipeline that is currently working.
        try:
            model_path = self.styles[style]
        except KeyError:
            raise KeyError(
                f"Unknown style {style!r}; expected one of {sorted(self.styles)}"
            ) from None

        if self.pipe is not None:
            previous_style = self.current_style
            # Reset bookkeeping first: if the load below fails, the object is
            # left in a consistent "nothing loaded" state rather than with a
            # missing attribute or a stale current_style.
            self.pipe = None
            self.current_style = None
            # Reclaim reference cycles so their tensors can actually be freed
            # before the CUDA caching allocator is asked to release memory.
            gc.collect()
            torch.cuda.empty_cache()
            print(f"[Info] Switched from {previous_style} to {style}.")

        print(f"[Info] Loading {style} model...")

        # Build into a local so self.pipe only ever refers to a fully
        # configured pipeline.
        pipe = StableDiffusionMIGCPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float32
        )
        pipe.safety_checker = None
        pipe.attention_store = self.attn_store
        # NOTE(review): load_migc presumably patches the UNet attention with
        # MIGCProcessor using the MIGC checkpoint — confirm in migc_utils.
        load_migc(pipe.unet, self.attn_store, migc_ckpt_path, attn_processor=MIGCProcessor)
        pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

        self.pipe = pipe
        self.current_style = style
        return pipe

# Module-level instance used by every generation request; it caches the
# pipeline for the most recently requested style.
style_switcher = StyleSwitcher()

# Callback for the "Random Seed" button: returns a fresh random seed.
def generate_random_seed():
    """Draw a seed uniformly from the unsigned 32-bit range [0, 2**32)."""
    upper_exclusive = 2 ** 32
    return random.randrange(upper_exclusive)

# Generation callback wired to the generate button.
def get_boxes_json(annotations, seed_value, edit_mode, style_selection):
    """Render an image from the user-drawn layout boxes.

    Args:
        annotations: Value of the ``image_annotator`` component: a dict with
            an ``"image"`` entry (its ``.shape`` is read as height, width) and
            a ``"boxes"`` list whose items carry pixel-space ``xmin``/``ymin``/
            ``xmax``/``ymax`` and a ``label``. The dicts are not modified.
        seed_value: Seed from the "Seed (Optional)" field. ``None`` (field left
            empty) is replaced by a freshly drawn random seed.
        edit_mode: Forwarded to the pipeline as ``use_sa_preserve``.
        style_selection: Style key passed to ``style_switcher.load_model``.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # The seed field is optional; an empty field presumably arrives as None.
    # Substitute a concrete seed rather than handing None to seed_everything.
    if seed_value is None:
        seed_value = generate_random_seed()
    seed_everything(int(seed_value))

    pipe = style_switcher.load_model(style_selection)

    image = annotations["image"]
    height, width = image.shape[:2]

    # Normalise pixel coordinates to fractions of the image size, building new
    # lists instead of dividing in place, so re-submitting the same annotation
    # dict can never normalise it twice.
    labels = []
    instance_boxes = []
    for box in annotations["boxes"]:
        labels.append(box["label"])
        instance_boxes.append([
            box["xmin"] / width,
            box["ymin"] / height,
            box["xmax"] / width,
            box["ymax"] / height,
        ])

    # Batch of one: the first prompt entry is the global prompt (all labels
    # comma-joined), followed by one label per box, aligned with bboxes[0].
    prompt_final = [[", ".join(labels), *labels]]
    bboxes = [instance_boxes]

    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    output_image = pipe(prompt_final, bboxes, num_inference_steps=30, guidance_scale=7.5,
                        MIGCsteps=15, aug_phase_with_and=False, negative_prompt=negative_prompt,
                        sa_preserve=True, use_sa_preserve=edit_mode).images[0]
    return output_image

# Initial annotator value: the bundled background image with no boxes drawn.
example_annotation = dict(
    image=os.path.join(os.path.dirname(__file__), "background.png"),
    boxes=[],
)

# ------------- Gradio UI -------------
# Component creation order inside the context managers defines the layout.
with gr.Blocks() as demo:
    with gr.Tab("DreamRenderer", id="DreamRenderer"):
        with gr.Row():
            # Left column: canvas on which the user draws and labels layout boxes.
            with gr.Column(scale=1):
                annotator = image_annotator(example_annotation, height=512, width=512)
            # Right column: generated output plus generation settings.
            with gr.Column(scale=1):
                generated_image = gr.Image(label="Generated Image", height=512, width=512)
                # NOTE(review): no default value; an empty field presumably
                # reaches get_boxes_json as None.
                seed_input = gr.Number(label="Seed (Optional)", precision=0)
                seed_random_btn = gr.Button("🎲 Random Seed")
                # Passed to get_boxes_json as edit_mode.
                edit_mode_toggle = gr.Checkbox(label="Edit Mode")
                # Label reads "style selection"; the choices must match the
                # keys of StyleSwitcher.styles.
                style_selector = gr.Radio(choices=["realistic", "anime"], label="风格选择", value="realistic")

        # Button label reads "generate image": runs get_boxes_json on the
        # annotator contents and the settings above.
        button_get = gr.Button("生成图像")
        button_get.click(
            fn=get_boxes_json,
            inputs=[annotator, seed_input, edit_mode_toggle, style_selector],
            outputs=generated_image
        )

        # Fill the seed field with a fresh random value.
        seed_random_btn.click(fn=generate_random_seed, inputs=[], outputs=seed_input)

if __name__ == "__main__":
    # share=True asks Gradio to create a public share link in addition to the
    # local server.
    # NOTE(review): that exposes the GPU-backed app to anyone holding the
    # link — confirm this is intended outside of demos.
    demo.launch(share=True)