import gradio as gr
from gradio_image_annotation import image_annotator
from diffusers import EulerDiscreteScheduler
import torch
import os
import random
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from migc.migc_utils import seed_everything, load_migc
from huggingface_hub import hf_hub_download

# Download model checkpoints from the Hugging Face Hub
migc_ckpt_path = hf_hub_download(repo_id="limuloo1999/MIGC", filename="MIGC_SD14.ckpt")
RV_path = hf_hub_download(repo_id="SG161222/Realistic_Vision_V6.0_B1_noVAE",
                          filename="Realistic_Vision_V6.0_NV_B1.safetensors")
anime_path = hf_hub_download(repo_id="ckpt/cetus-mix", filename="cetusMix_v4.safetensors")


# -------- Style switcher --------
class StyleSwitcher:
    """Lazily loads one MIGC pipeline at a time and swaps it when the style changes."""

    def __init__(self):
        self.pipe = None
        self.attn_store = AttentionStore()
        self.styles = {
            "realistic": RV_path,
            "anime": anime_path,
        }
        self.current_style = None

    def load_model(self, style):
        # Reuse the cached pipeline if the requested style is already loaded.
        if style == self.current_style:
            return self.pipe
        # Release the previous pipeline before loading a new one.
        if self.pipe is not None:
            del self.pipe
            torch.cuda.empty_cache()
            print(f"[Info] Switched from {self.current_style} to {style}.")
        model_path = self.styles[style]
        print(f"[Info] Loading {style} model...")
        self.pipe = StableDiffusionMIGCPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float32,
        )
        self.pipe.safety_checker = None
        self.pipe.attention_store = self.attn_store
        # Attach the MIGC attention processors to the UNet.
        load_migc(self.pipe.unet, self.attn_store, migc_ckpt_path, attn_processor=MIGCProcessor)
        self.pipe = self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
        self.current_style = style
        return self.pipe


style_switcher = StyleSwitcher()


# Helper: return a random seed for the seed input
def generate_random_seed():
    return random.randint(0, 2**32 - 1)


# Generation function
def get_boxes_json(annotations, seed_value, edit_mode, style_selection):
    # An empty seed field arrives as None; fall back to a random seed.
    if seed_value is None:
        seed_value = generate_random_seed()
    seed_everything(seed_value)
    pipe = style_switcher.load_model(style_selection)
    image = annotations["image"]
    width = image.shape[1]
    height = image.shape[0]
    boxes = annotations["boxes"]
    prompt_final = [[]]
    bboxes = [[]]
    # Normalize pixel coordinates to [0, 1] and collect one prompt per box.
    for box in boxes:
        box["xmin"] /= width
        box["xmax"] /= width
        box["ymin"] /= height
        box["ymax"] /= height
        prompt_final[0].append(box["label"])
        bboxes[0].append([box["xmin"], box["ymin"], box["xmax"], box["ymax"]])
    # MIGC expects the global prompt first, followed by the instance prompts.
    prompt = ", ".join(prompt_final[0])
    prompt_final[0].insert(0, prompt)
    negative_prompt = "worst quality, low quality, bad anatomy, watermark, text, blurry"
    output_image = pipe(prompt_final, bboxes,
                        num_inference_steps=30,
                        guidance_scale=7.5,
                        MIGCsteps=15,
                        aug_phase_with_and=False,
                        negative_prompt=negative_prompt,
                        sa_preserve=True,
                        use_sa_preserve=edit_mode).images[0]
    return output_image


# Example annotation image shown in the annotator on startup
example_annotation = {
    "image": os.path.join(os.path.dirname(__file__), "background.png"),
    "boxes": [],
}

# ------------- Gradio UI -------------
with gr.Blocks() as demo:
    with gr.Tab("DreamRenderer", id="DreamRenderer"):
        with gr.Row():
            with gr.Column(scale=1):
                annotator = image_annotator(example_annotation, height=512, width=512)
            with gr.Column(scale=1):
                generated_image = gr.Image(label="Generated Image", height=512, width=512)
                seed_input = gr.Number(label="Seed (Optional)", precision=0)
                seed_random_btn = gr.Button("🎲 Random Seed")
                edit_mode_toggle = gr.Checkbox(label="Edit Mode")
                style_selector = gr.Radio(choices=["realistic", "anime"],
                                          label="Style", value="realistic")
                button_get = gr.Button("Generate Image")

        button_get.click(
            fn=get_boxes_json,
            inputs=[annotator, seed_input, edit_mode_toggle, style_selector],
            outputs=generated_image,
        )
        seed_random_btn.click(fn=generate_random_seed, inputs=[], outputs=seed_input)

if __name__ == "__main__":
    demo.launch(share=True)