import gradio as gr
from PIL import Image
import os
import tempfile
import sys
import time
from inferencer import Inferencer
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download

model_dir = snapshot_download(repo_id="Skywork/Skywork-UniPic-1.5B")
model_path = os.path.join(model_dir, "pytorch_model.bin")
ckpt_name = "UniPic"

inferencer = Inferencer(
    config_file="qwen2_5_1_5b_kl16_mar_h.py",
    model_path=model_path,
    image_size=1024,
    # cfg_prompt="Generate an image.",
)

TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")


def save_temp_image(pil_img):
    """Save a PIL image into the temp directory and return its path."""
    # Editing only supports 512 -> 1024; the resize step is left disabled.
    # img_resized = pil_img.resize((512, 512))
    # Nanosecond timestamp avoids filename collisions when several images
    # are saved within the same second (e.g. multi-image generation).
    path = os.path.join(TEMP_DIR, f"temp_{time.time_ns()}.png")
    pil_img.save(path, format="PNG")
    return path


def handle_image_upload(file, history):
    """Persist an uploaded image and append it to the chat history."""
    if file is None:
        return None, history
    file_path = file.name if hasattr(file, "name") else file
    pil_img = Image.open(file_path)
    saved_path = save_temp_image(pil_img)
    return saved_path, history + [((saved_path,), None)]


def clear_all():
    """Delete all temp files and reset the chat state."""
    for file in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, file)
        try:
            if os.path.isfile(path):
                os.remove(path)
        except Exception as e:
            print(f"Failed to delete temp file: {path}, error: {e}")
    return [], None, "Understand Image"


def extract_assistant_reply(full_text):
    """Strip the chat template and return only the assistant's reply."""
    if "assistant" in full_text:
        parts = full_text.strip().split("assistant")
        return parts[-1].lstrip(":").strip()
    return full_text.replace("<|im_end|>", "").strip()


def on_submit(history, user_msg, img_path, mode, grid_size=1):
    """Dispatch a user message to image understanding, generation, or editing."""
    # Convert every tuple in the history into a list so it can be mutated.
    updated_history = [list(item) for item in history]
    user_msg = user_msg.strip()
    updated_history.append([user_msg, None])
    # set_seed(42)
    try:
        if mode == "Understand Image":
            if img_path is None:
                updated_history.append([None, "⚠️ Please upload or generate an image first."])
                return updated_history, "", img_path
            raw = (
                inferencer.query_image(Image.open(img_path), user_msg)
                if img_path
                else inferencer.query_text(user_msg)
            )
            reply = extract_assistant_reply(raw)
            updated_history.append([None, reply])
            return updated_history, "", img_path

        elif mode == "Generate Image":
            if not user_msg:
                updated_history.append([None, "⚠️ Please enter a prompt."])
                return updated_history, "", img_path
            imgs = inferencer.gen_image(
                raw_prompt=user_msg,
                images_to_generate=grid_size**2,
                cfg=3.0,
                num_iter=48,
                cfg_schedule="constant",
                temperature=1.0,
            )
            paths = [save_temp_image(img) for img in imgs]
            # Multiple images must be passed to the Chatbot as a list.
            updated_history.append([None, paths])
            return updated_history, "", paths[-1]

        elif mode == "Edit Image":
            if img_path is None:
                updated_history.append([None, "⚠️ Please upload or generate an image first."])
                return updated_history, "", img_path
            if not user_msg:
                updated_history.append([None, "⚠️ Please enter an edit instruction."])
                return updated_history, "", img_path
            img = Image.open(img_path)
            imgs = inferencer.edit_image(
                source_image=img,
                prompt=user_msg,
                cfg=3.0,
                cfg_prompt="repeat this image.",
                cfg_schedule="constant",
                temperature=0.85,
                grid_size=grid_size,
                num_iter=48,
            )
            paths = [save_temp_image(img) for img in imgs]
            updated_history.append([None, paths])
            return updated_history, "", paths[-1]

    except Exception as e:
        updated_history.append([None, f"⚠️ Failed to process: {e}"])
        return updated_history, "", img_path


CSS = """
/* Overall layout: chat area on top, input row at the bottom */
.gradio-container {
    display: flex !important;
    flex-direction: column;
    height: 100vh;
    margin: 0;
    padding: 0;
}
.gr-tabs {
    /* Make sure the tabs inherit the container height */
    flex: 1 1 auto;
    display: flex;
    flex-direction: column;
}
/* Chat tabs */
#tab_item_4,
#tab_item_5 {
    display: flex;
    flex-direction: column;
    flex: 1 1 auto;
    overflow: hidden;  /* prevent a double scrollbar */
    padding: 8px;
}
/* Let the Chatbot fill the available space */
#chatbot1, #chatbot2 {
    flex-grow: 1 !important;
    max-height: 66vh !important;   /* cap the chat box at ~2/3 of the viewport height */
    overflow-y: auto !important;   /* show a scrollbar when content overflows */
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 8px;
}
/* Enlarge image messages */
#chatbot1 img, #chatbot2 img {
    max-width: 80vw !important;
    height: auto !important;
    border-radius: 4px;
}
/* Bottom input area: fixed height */
.input-row {
    flex: 0 0 auto;
    display: flex;
    align-items: center;
    padding: 8px;
    border-top: 1px solid #eee;
    background: #fafafa;
}
/* Textbox and button layout */
.input-row .textbox-col {
    flex: 5;
}
.input-row .upload-col,
.input-row .clear-col {
    flex: 1;
    margin-left: 8px;
}
/* Textbox styling */
.gr-text-input {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    border: 1px solid #ddd !important;
    font-size: 16px !important;
}
/* Button and upload-widget styling */
.gr-button, .gr-upload {
    width: 100% !important;
    border-radius: 18px !important;
    padding: 8px 16px !important;
    font-size: 16px !important;
}
"""

with gr.Blocks(css=CSS) as demo:
    img_state = gr.State(value=None)
    mode_state = gr.State(value="Understand Image")

    with gr.Tabs():
        with gr.Tab("Skywork UniPic Chatbot", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=(
                    "user.png",
                    "ai.png",
                ),
            )
            with gr.Row():
                mode_selector = gr.Radio(
                    choices=["Generate Image", "Edit Image", "Understand Image"],
                    value="Generate Image",
                    label="Mode",
                    interactive=True,
                )
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")

            user_input.submit(
                on_submit,
                [chatbot, user_input, img_state, mode_selector],
                [chatbot, user_input, img_state],
            )
            image_input.upload(
                handle_image_upload, [image_input, chatbot], [img_state, chatbot]
            )
            clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])

# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", share=True, debug=True, server_port=7689)
demo.launch()