Spaces:
Sleeping
Sleeping
File size: 4,311 Bytes
bf40ad0 5226632 f53fb95 7cfa686 f53fb95 7cfa686 f53fb95 7cfa686 f53fb95 7cfa686 f53fb95 7cfa686 9a25740 7cfa686 bf40ad0 5226632 9a25740 f53fb95 5226632 7cfa686 9a25740 f53fb95 7cfa686 9a25740 f53fb95 7cfa686 f53fb95 7cfa686 5226632 7cfa686 5226632 7cfa686 9a25740 7cfa686 9a25740 7cfa686 9a25740 7cfa686 9a25740 7cfa686 bf40ad0 7cfa686 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
from gradio_image_annotation import image_annotator
from diffusers import EulerDiscreteScheduler
import torch
import os
import random
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from migc.migc_utils import seed_everything, load_migc
from huggingface_hub import hf_hub_download
# Download (or fetch from local cache) the model checkpoints from the HF Hub.
# hf_hub_download returns a local filesystem path to the cached file.
migc_ckpt_path = hf_hub_download(repo_id="limuloo1999/MIGC", filename="MIGC_SD14.ckpt")
# Photorealistic base checkpoint.
RV_path = hf_hub_download(repo_id="SG161222/Realistic_Vision_V6.0_B1_noVAE", filename="Realistic_Vision_V6.0_NV_B1.safetensors")
# Anime-style base checkpoint.
anime_path = hf_hub_download(repo_id="ckpt/cetus-mix", filename="cetusMix_v4.safetensors")
# -------- Style switcher --------
class StyleSwitcher:
    """Lazily loads and caches one MIGC pipeline per visual style.

    Only a single pipeline is kept in memory at a time; switching styles
    releases the previous pipeline before loading the new checkpoint.
    """

    def __init__(self):
        self.pipe = None                    # currently loaded pipeline, or None
        self.attn_store = AttentionStore()  # shared attention store, reused across loads
        # Map UI style name -> local checkpoint path.
        self.styles = {
            "realistic": RV_path,
            "anime": anime_path,
        }
        self.current_style = None           # style name of the loaded pipeline, or None

    def load_model(self, style):
        """Return a pipeline for *style*, reusing the cached one when possible.

        Raises:
            KeyError: if *style* is not a known style name.
        """
        # Cache hit: same style AND a pipeline actually loaded.
        # (Original only compared the style name; an explicit `is not None`
        # check guards against returning a missing pipeline.)
        if style == self.current_style and self.pipe is not None:
            return self.pipe
        if self.pipe is not None:
            # Drop the reference (instead of `del self.pipe`, which removes the
            # attribute and would leave the instance broken if the load below
            # failed), then reclaim GPU memory before loading the next model.
            self.pipe = None
            torch.cuda.empty_cache()
            print(f"[Info] Switched from {self.current_style} to {style}.")
        # Invalidate the cache marker first so a failed load cannot leave a
        # stale (current_style, pipe=None) pairing behind.
        self.current_style = None
        model_path = self.styles[style]
        print(f"[Info] Loading {style} model...")
        self.pipe = StableDiffusionMIGCPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float32
        )
        self.pipe.safety_checker = None  # demo app: NSFW filter disabled
        self.pipe.attention_store = self.attn_store
        load_migc(self.pipe.unet, self.attn_store, migc_ckpt_path, attn_processor=MIGCProcessor)
        self.pipe = self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
        self.current_style = style
        return self.pipe
# Module-level singleton shared by every request handler below.
style_switcher = StyleSwitcher()
# Callback for the UI's "Random Seed" button.
def generate_random_seed():
    """Pick a uniformly random 32-bit seed."""
    upper_bound = 2 ** 32 - 1
    return random.randint(0, upper_bound)
# Generation callback: turn the annotated boxes into a MIGC layout prompt
# and run the currently selected style's pipeline.
def get_boxes_json(annotations, seed_value, edit_mode, style_selection):
    """Generate an image from the user's box annotations.

    Args:
        annotations: value of the image_annotator component; a dict with
            "image" (HxWxC array) and "boxes" (list of dicts with pixel
            coords "xmin"/"ymin"/"xmax"/"ymax" and a "label" string).
        seed_value: RNG seed; None (field left empty in the UI) falls back
            to a freshly drawn random seed.
        edit_mode: forwarded as ``use_sa_preserve`` to keep unedited regions
            consistent between runs.
        style_selection: key into the style switcher ("realistic"/"anime").

    Returns:
        The first generated image from the pipeline output.
    """
    if seed_value is None:
        # gr.Number defaults to None when empty; seed_everything needs an int.
        seed_value = generate_random_seed()
    seed_everything(seed_value)
    pipe = style_switcher.load_model(style_selection)

    image = annotations["image"]
    height, width = image.shape[0], image.shape[1]

    # Normalize box coordinates to [0, 1] WITHOUT mutating the caller's
    # annotation dicts (the original divided in place, which would
    # double-normalize the boxes if the same dict were ever reused).
    labels = []
    norm_boxes = []
    for box in annotations["boxes"]:
        labels.append(box["label"])
        norm_boxes.append([
            box["xmin"] / width,
            box["ymin"] / height,
            box["xmax"] / width,
            box["ymax"] / height,
        ])

    # MIGC prompt format: first entry is the full scene prompt (labels joined
    # with commas), followed by one phrase per box.
    prompt_final = [[", ".join(labels)] + labels]
    bboxes = [norm_boxes]
    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    output_image = pipe(prompt_final, bboxes, num_inference_steps=30, guidance_scale=7.5,
                        MIGCsteps=15, aug_phase_with_and=False, negative_prompt=negative_prompt,
                        sa_preserve=True, use_sa_preserve=edit_mode).images[0]
    return output_image
# Initial annotator state: a bundled background image with no boxes yet.
example_annotation = {
    "image": os.path.join(os.path.dirname(__file__), "background.png"),
    "boxes": [],
}
# ------------- Gradio UI -------------
# NOTE(review): the original file's indentation was lost; the widget nesting
# below is reconstructed (controls placed at tab level, after the image row)
# — confirm against the deployed layout.
with gr.Blocks() as demo:
    with gr.Tab("DreamRenderer", id="DreamRenderer"):
        with gr.Row():
            with gr.Column(scale=1):
                # Left: annotation canvas where the user draws labeled boxes.
                annotator = image_annotator(example_annotation, height=512, width=512)
            with gr.Column(scale=1):
                # Right: output image.
                generated_image = gr.Image(label="Generated Image", height=512, width=512)
        seed_input = gr.Number(label="Seed (Optional)", precision=0)
        seed_random_btn = gr.Button("🎲 Random Seed")
        edit_mode_toggle = gr.Checkbox(label="Edit Mode")
        style_selector = gr.Radio(choices=["realistic", "anime"], label="风格选择", value="realistic")
        button_get = gr.Button("生成图像")
        # Wire the generate button to the generation callback.
        button_get.click(
            fn=get_boxes_json,
            inputs=[annotator, seed_input, edit_mode_toggle, style_selector],
            outputs=generated_image
        )
        # Random-seed button fills the seed field.
        seed_random_btn.click(fn=generate_random_seed, inputs=[], outputs=seed_input)
if __name__ == "__main__":
    demo.launch(share=True)
|