# UniPic / app.py
# yichenchenchen's picture
# Update app.py
# da489db verified
import gradio as gr
from PIL import Image
import os
import tempfile
import sys
import time
from inferencer import Inferencer
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
# --- Module-level setup: runs once at import time ---------------------------
# Obtain the "Skywork/Skywork-UniPic-1.5B" repository through
# huggingface_hub.snapshot_download and point at the weights file inside it.
model_dir = snapshot_download(repo_id="Skywork/Skywork-UniPic-1.5B")
model_path = os.path.join(model_dir,"pytorch_model.bin")
# NOTE(review): ckpt_name is not referenced anywhere else in this section.
ckpt_name = "UniPic"
# Single shared Inferencer instance used by every request handler below.
inferencer = Inferencer(
    config_file="qwen2_5_1_5b_kl16_mar_h.py",
    model_path=model_path,
    image_size=1024,
    #cfg_prompt="Generate an image.",
)
# Scratch directory shared by all sessions: save_temp_image() writes PNGs
# into it and clear_all() deletes the files it contains.
TEMP_DIR = tempfile.mkdtemp()
print(f"Temporary directory created at: {TEMP_DIR}")
def save_temp_image(pil_img, directory=None):
    """Write ``pil_img`` to a uniquely named PNG file and return its path.

    The file name is reserved with ``tempfile.mkstemp`` so that every call
    gets its own file. The earlier ``temp_<unix-seconds>.png`` naming gave
    the same path to all images saved within one second, which made the
    images of a single batch (or of concurrent requests) overwrite each
    other.

    Args:
        pil_img: Image-like object exposing ``save(path, format=...)``
            (a ``PIL.Image.Image`` in this app).
        directory: Folder to write into. Defaults to the module-level
            ``TEMP_DIR``.

    Returns:
        The path of the PNG file that was written.
    """
    if directory is None:
        directory = TEMP_DIR
    # Original author's note: only 512 -> 1024 editing is supported; the
    # resize to (512, 512) that used to happen here is disabled.
    fd, path = tempfile.mkstemp(
        prefix=f"temp_{int(time.time())}_",
        suffix=".png",
        dir=directory,
    )
    # mkstemp hands back an open descriptor; release it and let the image
    # object write to the reserved path itself.
    os.close(fd)
    pil_img.save(path, format="PNG")
    return path
def handle_image_upload(file, history):
    """Copy an uploaded image into the temp directory and show it in the chat.

    Args:
        file: Value delivered by the upload button: ``None``, a path string,
            or an object whose ``name`` attribute holds the path.
        history: Current chatbot history (list of ``(user, bot)`` pairs).

    Returns:
        ``(saved_path, new_history)``. ``saved_path`` is the PNG copy made
        by ``save_temp_image``; ``new_history`` is ``history`` plus one
        user-side entry ``((saved_path,), None)``. When ``file`` is ``None``
        the result is ``(None, history)`` with the history untouched.
    """
    if file is None:
        return None, history
    file_path = file.name if hasattr(file, "name") else file
    # Use the image as a context manager so the uploaded file's handle is
    # released as soon as the PNG copy has been written.
    with Image.open(file_path) as pil_img:
        saved_path = save_temp_image(pil_img)
    # Tolerate a missing history instead of failing on None + list.
    return saved_path, (history or []) + [((saved_path,), None)]
def clear_all():
    """Remove the regular files in ``TEMP_DIR`` and return reset UI values.

    Deletion is best-effort: any error for one entry is printed and the
    remaining entries are still processed. Sub-directories are left alone.

    Returns:
        The tuple ``([], None, "Understand Image")``.
    """
    for entry in os.listdir(TEMP_DIR):
        path = os.path.join(TEMP_DIR, entry)
        try:
            if not os.path.isfile(path):
                continue
            os.remove(path)
        except Exception as e:
            print(f"Failed to delete temp file: {path}, error: {e}")
    return [], None, "Understand Image"
def extract_assistant_reply(full_text):
    """Return only the assistant's answer from a decoded model transcript.

    If the text contains the word ``assistant``, everything up to and
    including its last occurrence is dropped, together with any ``:``
    characters that directly follow it. The ``<|im_end|>`` marker is removed
    in every case; previously it was only removed when no ``assistant``
    marker was present, so it could leak into the displayed reply.

    Known limitation (unchanged): a reply that itself contains the word
    ``assistant`` is cut at that word's last occurrence.

    Args:
        full_text: Raw text returned by the model; ``None`` or an empty
            string yields ``""``.

    Returns:
        The cleaned reply with surrounding whitespace stripped.
    """
    if not full_text:
        return ""
    text = full_text.strip()
    if "assistant" in text:
        text = text.split("assistant")[-1].lstrip(":")
    return text.replace("<|im_end|>", "").strip()
def on_submit(history, user_msg, img_path, mode, grid_size=1):
    """Handle one chat submission in the selected mode.

    The user's message is appended to the history first; the assistant's
    answer (text, image paths, or a warning) is appended as a second row.

    Args:
        history: Chatbot history as a list of ``(user, bot)`` pairs; ``None``
            is treated as an empty history.
        user_msg: Text typed by the user; ``None`` is treated as ``""``.
        img_path: Path of the image currently in focus, or ``None``.
        mode: ``"Understand Image"``, ``"Generate Image"`` or
            ``"Edit Image"``.
        grid_size: Side length of the image grid. Generation asks for
            ``grid_size ** 2`` images; editing forwards the value as is.

    Returns:
        ``(updated_history, "", image_path)``. The empty string clears the
        textbox. ``image_path`` is the last newly created image after a
        successful generate/edit, otherwise the unchanged ``img_path``.
        Any exception raised while processing is reported as a chat message
        instead of propagating. An unrecognized ``mode`` also produces a
        chat message; it used to fall through and return ``None``.
    """
    # Rows may arrive as tuples; turn them into lists before appending.
    updated_history = [list(item) for item in (history or [])]
    user_msg = (user_msg or "").strip()
    updated_history.append([user_msg, None])
    # NOTE: seeding via set_seed(42) is disabled here.
    try:
        if mode == "Understand Image":
            if img_path is None:
                updated_history.append([None, "⚠️ Please upload or generate an image first."])
                return updated_history, "", img_path
            if img_path:
                # Close the image file once the model has answered.
                with Image.open(img_path) as source:
                    raw = inferencer.query_image(source, user_msg)
            else:
                raw = inferencer.query_text(user_msg)
            reply = extract_assistant_reply(raw)
            updated_history.append([None, reply])
            return updated_history, "", img_path
        elif mode == "Generate Image":
            if not user_msg:
                updated_history.append([None, "⚠️ Please enter a prompt."])
                return updated_history, "", img_path
            imgs = inferencer.gen_image(
                raw_prompt=user_msg,
                images_to_generate=grid_size**2,
                cfg=3.0,
                num_iter=48,
                cfg_schedule="constant",
                temperature=1.0,
            )
            paths = [save_temp_image(generated) for generated in imgs]
            # Original author's note: multiple images must be passed as a list.
            updated_history.append([None, paths])
            return updated_history, "", paths[-1]
        elif mode == "Edit Image":
            if img_path is None:
                updated_history.append([None, "⚠️ Please upload or generate an image first."])
                return updated_history, "", img_path
            if not user_msg:
                updated_history.append([None, "⚠️ Please enter an edit instruction."])
                return updated_history, "", img_path
            # Keep the source image open only for the duration of the edit.
            with Image.open(img_path) as source:
                imgs = inferencer.edit_image(
                    source_image=source,
                    prompt=user_msg,
                    cfg=3.0,
                    cfg_prompt="repeat this image.",
                    cfg_schedule="constant",
                    temperature=0.85,
                    grid_size=grid_size,
                    num_iter=48,
                )
            paths = [save_temp_image(edited) for edited in imgs]
            updated_history.append([None, paths])
            return updated_history, "", paths[-1]
        else:
            # Always return the three values the UI expects, even for a
            # mode string this handler does not know.
            updated_history.append([None, f"⚠️ Unsupported mode: {mode}"])
            return updated_history, "", img_path
    except Exception as e:
        # UI boundary: surface the failure in the chat rather than crash.
        updated_history.append([None, f"⚠️ Failed to process: {e}"])
        return updated_history, "", img_path
# Stylesheet passed to gr.Blocks(css=CSS) below. The CSS comments inside the
# string are part of the runtime value and are left in their original
# language; in English, top to bottom, they read:
#   - overall layout: two stacked sections
#   - (added) make sure the tab inherits the height
#   - chat tab; "overflow: hidden" avoids a double scrollbar
#   - chatbot fills the space; max height limited to 2/3 of the screen,
#     scrollbar shown when content overflows
#   - enlarge image messages
#   - bottom input area: fixed height
#   - textbox / button arrangement
#   - textbox style
#   - button and upload component style
# NOTE(review): the selectors #tab_item_5 and #chatbot2 have no matching
# elem_id in the UI built in this section; only tab_item_4 / chatbot1 do.
CSS = """
/* 整体布局:上下两块 */
.gradio-container {
display: flex !important;
flex-direction: column;
height: 100vh;
margin: 0;
padding: 0;
}
.gr-tabs { /* ✅ 新增:确保 tab 能继承高度 */
flex: 1 1 auto;
display: flex;
flex-direction: column;
}
/* 聊天 tab */
#tab_item_4, #tab_item_5 {
display: flex;
flex-direction: column;
flex: 1 1 auto;
overflow: hidden; /* 防止出现双滚动条 */
padding: 8px;
}
/* Chatbot 撑满 */
#chatbot1, #chatbot2{
flex-grow: 1 !important;
max-height: 66vh !important; /* 限制聊天框最大高度为屏幕的2/3 */
overflow-y: auto !important; /* 当内容溢出时显示滚动条 */
border: 1px solid #ddd;
border-radius: 8px;
padding: 12px;
margin-bottom: 8px;
}
/* 图片消息放大 */
#chatbot1 img, #chatbot2 img {
max-width: 80vw !important;
height: auto !important;
border-radius: 4px;
}
/* 底部输入区:固定高度 */
.input-row {
flex: 0 0 auto;
display: flex;
align-items: center;
padding: 8px;
border-top: 1px solid #eee;
background: #fafafa;
}
/* 文本框和按钮排布 */
.input-row .textbox-col { flex: 5; }
.input-row .upload-col, .input-row .clear-col { flex: 1; margin-left: 8px; }
/* 文本框样式 */
.gr-text-input {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
border: 1px solid #ddd !important;
font-size: 16px !important;
}
/* 按钮和上传组件样式 */
.gr-button, .gr-upload {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
font-size: 16px !important;
}
"""
# Build the chat UI: one tab holding the chatbot, a mode selector, and an
# input row (textbox, image upload button, clear button), then launch it.
with gr.Blocks(css=CSS) as demo:
    # Path of the image currently in focus; written by the upload handler,
    # by on_submit (last generated/edited image) and reset by clear_all.
    img_state = gr.State(value=None)
    # NOTE(review): mode_state is not passed to any event handler below;
    # the active mode is read from mode_selector instead.
    mode_state = gr.State(value="Understand Image")
    with gr.Tabs():
        with gr.Tab("Skywork UniPic Chatbot", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                # (user avatar, assistant avatar) image files
                avatar_images=(
                    "user.png",
                    "ai.png",
                ),
            )
            with gr.Row():
                # Starts on "Generate Image"; note that clear_all() sets it
                # to "Understand Image" when the history is cleared.
                mode_selector = gr.Radio(
                    choices=["Generate Image","Edit Image","Understand Image"],
                    value="Generate Image",
                    label="Mode",
                    interactive=True,
                )
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")
            # Submitting the textbox runs on_submit; grid_size is not among
            # the inputs, so on_submit uses its default of 1.
            user_input.submit(
                on_submit,
                [chatbot, user_input, img_state, mode_selector],
                [chatbot, user_input, img_state],
            )
            # An upload stores the saved copy's path in img_state and adds
            # the image to the chat.
            image_input.upload(
                handle_image_upload, [image_input, chatbot], [img_state, chatbot]
            )
            # Clearing empties the chat, forgets the image and resets the mode.
            clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])
# Alternative launch configuration kept by the author (disabled):
# if __name__ == "__main__":
# demo.launch(server_name="0.0.0.0", share=True, debug=True, server_port=7689)
demo.launch()