Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from PIL import Image | |
import os | |
import tempfile | |
import sys | |
import time | |
from inferencer import Inferencer | |
from accelerate.utils import set_seed | |
from huggingface_hub import snapshot_download | |
# Download the Skywork UniPic 1.5B snapshot from the Hugging Face Hub and build
# the single Inferencer instance shared by every chat handler in this file.
model_dir = snapshot_download(repo_id="Skywork/Skywork-UniPic-1.5B")
model_path = os.path.join(model_dir, "pytorch_model.bin")

# NOTE(review): ckpt_name is not referenced by any code visible here.
ckpt_name = "UniPic"

inferencer = Inferencer(
    model_path=model_path,
    config_file="qwen2_5_1_5b_kl16_mar_h.py",
    image_size=1024,
    # cfg_prompt="Generate an image.",
)
# Per-process scratch directory for uploaded, generated and edited images.
# The TemporaryDirectory object is kept alive at module level so that the
# directory and everything saved in it are removed automatically when the
# interpreter exits. A bare ``mkdtemp()`` left them on disk forever.
_TEMP_DIR_HANDLE = tempfile.TemporaryDirectory()
TEMP_DIR = _TEMP_DIR_HANDLE.name
print(f"Temporary directory created at: {TEMP_DIR}")
def save_temp_image(pil_img):
    """Save *pil_img* as a PNG inside ``TEMP_DIR`` and return the file path.

    Every call writes to a fresh, unique file. The old name was derived only
    from ``int(time.time())``, so images saved within the same second (e.g.
    the several outputs of one generate/edit request) overwrote each other,
    and every returned path showed the last image.

    Args:
        pil_img: Object with a PIL-style ``save(path, format=...)`` method.

    Returns:
        Absolute path of the written PNG file.
    """
    # Author's note (translated): editing only supports 512 -> 1024; the
    # former ``pil_img.resize((512, 512))`` step is intentionally disabled,
    # so images are saved at their incoming size.
    fd, path = tempfile.mkstemp(
        prefix=f"temp_{int(time.time())}_", suffix=".png", dir=TEMP_DIR
    )
    # mkstemp hands back an open OS-level descriptor; close it so PIL can
    # (re)open the file by path for writing.
    os.close(fd)
    try:
        pil_img.save(path, format="PNG")
    except Exception:
        # Don't leave an empty placeholder file behind on failure.
        os.remove(path)
        raise
    return path
def handle_image_upload(file, history):
    """Copy an uploaded image into the temp dir and show it in the chat.

    Args:
        file: Value from the ``gr.UploadButton``, either a path string or an
            object exposing ``.name``. ``None`` means nothing was uploaded.
        history: Current chatbot history (list of message pairs). ``None`` is
            treated as an empty history.

    Returns:
        ``(saved_path, new_history)``. ``saved_path`` becomes the current
        image state, and the image is appended as a user message
        ``((saved_path,), None)``. When *file* is ``None`` the state is
        cleared and *history* is returned unchanged.
    """
    if file is None:
        return None, history
    file_path = getattr(file, "name", file)
    # Image.open keeps the source file open; the context manager releases
    # the handle as soon as the copy has been written.
    with Image.open(file_path) as pil_img:
        saved_path = save_temp_image(pil_img)
    return saved_path, (history or []) + [((saved_path,), None)]
def clear_all():
    """Delete the files saved in ``TEMP_DIR`` and reset the chat UI.

    Deletion is best effort: a file that cannot be removed is reported on
    stdout and skipped, so the reset itself always succeeds.

    Returns:
        ``([], None, "Generate Image")``: an empty chatbot history, a cleared
        current-image state and the mode-selector value. The mode goes back to
        ``"Generate Image"``, the Radio's initial value. It previously jumped to
        ``"Understand Image"``, which cannot be used right after a reset
        because the image state has just been cleared.
    """
    with os.scandir(TEMP_DIR) as entries:
        for entry in entries:
            if not entry.is_file():
                continue
            try:
                os.remove(entry.path)
            except OSError as e:
                print(f"Failed to delete temp file: {entry.path}, error: {e}")
    return [], None, "Generate Image"
def extract_assistant_reply(full_text):
    """Return only the assistant's answer from a raw model transcript.

    Args:
        full_text: Decoded text from ``inferencer.query_image`` /
            ``query_text``. It may still contain the prompt turns and
            ``<|im_end|>`` tokens.

    Returns:
        The text after the last assistant marker, with a leading ``:`` and
        surrounding whitespace removed. With no marker, the whole text
        stripped. ``<|im_end|>`` tokens are removed in both cases; before,
        they leaked through whenever an assistant marker was present.
    """
    text = full_text.replace("<|im_end|>", "")
    # Prefer the explicit "<|im_start|>assistant" turn header. The
    # "<|im_end|>" token above suggests ChatML-style output (TODO confirm).
    # Splitting on the bare word "assistant" cuts off answers that
    # themselves mention "assistant", so it is only used as a fallback.
    for marker in ("<|im_start|>assistant", "assistant"):
        if marker in text:
            answer = text.rpartition(marker)[2]
            return answer.strip().lstrip(":").strip()
    return text.strip()
def _load_image(img_path):
    """Return an in-memory copy of the image stored at *img_path*.

    The copy is taken inside a ``with`` block so the underlying file handle
    is released immediately instead of lingering until garbage collection.
    """
    with Image.open(img_path) as src:
        return src.copy()


def on_submit(history, user_msg, img_path, mode, grid_size=1):
    """Handle one chat submission according to the selected mode.

    Args:
        history: Current chatbot history (sequence of message pairs). Pairs
            are converted to lists; ``None`` is treated as empty.
        user_msg: Text typed by the user. Surrounding whitespace is stripped.
        img_path: Path of the current image (from an upload or a previous
            generate/edit), or ``None``.
        mode: ``"Understand Image"``, ``"Generate Image"`` or
            ``"Edit Image"``.
        grid_size: Generation requests ``grid_size**2`` images. The value is
            also forwarded to ``edit_image``.

    Returns:
        ``(updated_history, "", new_img_path)``:
          * the history with the user message and the bot reply appended;
          * an empty string that clears the textbox;
          * the image path to keep as state (the last generated/edited image,
            otherwise *img_path* unchanged).
        Validation problems and inference errors are reported as a bot
        message rather than raised.
    """
    # Gradio can hand the history back as tuples; convert to mutable lists.
    updated_history = [list(item) for item in (history or [])]
    user_msg = (user_msg or "").strip()
    updated_history.append([user_msg, None])

    def reply(content, new_img_path=img_path):
        # Append the bot turn and build the (history, textbox, img_state) outputs.
        updated_history.append([None, content])
        return updated_history, "", new_img_path

    # set_seed(42)
    try:
        if mode == "Understand Image":
            if not img_path:
                return reply("⚠️ Please upload or generate an image first.")
            raw = inferencer.query_image(_load_image(img_path), user_msg)
            return reply(extract_assistant_reply(raw))

        if mode == "Generate Image":
            if not user_msg:
                return reply("⚠️ Please enter a prompt.")
            imgs = inferencer.gen_image(
                raw_prompt=user_msg,
                images_to_generate=grid_size**2,
                cfg=3.0,
                num_iter=48,
                cfg_schedule="constant",
                temperature=1.0,
            )
            paths = [save_temp_image(im) for im in imgs]
            # Multiple images must be passed to the Chatbot as a list.
            return reply(paths, paths[-1])

        if mode == "Edit Image":
            if not img_path:
                return reply("⚠️ Please upload or generate an image first.")
            if not user_msg:
                return reply("⚠️ Please enter an edit instruction.")
            imgs = inferencer.edit_image(
                source_image=_load_image(img_path),
                prompt=user_msg,
                cfg=3.0,
                cfg_prompt="repeat this image.",
                cfg_schedule="constant",
                temperature=0.85,
                grid_size=grid_size,
                num_iter=48,
            )
            paths = [save_temp_image(im) for im in imgs]
            return reply(paths, paths[-1])

        # An unrecognised mode used to fall through and return None, which
        # cannot be mapped onto the three Gradio outputs.
        return reply(f"⚠️ Unknown mode: {mode}")
    except Exception as e:
        # UI boundary: show the failure in the chat instead of failing the
        # request, and keep the full traceback in the server log.
        import traceback

        traceback.print_exc()
        return reply(f"⚠️ Failed to process: {e}")
# Custom stylesheet passed to gr.Blocks(css=...) below. English translation of
# the Chinese section comments inside the string, in order:
#   overall layout: two stacked areas (top/bottom);
#   added: make the tabs container inherit the available height;
#   chat tab (overflow hidden to prevent double scrollbars);
#   chatbot fills the space, capped at 2/3 of the viewport height and
#   scrolling when content overflows;
#   enlarge images inside chat messages;
#   bottom input area with a fixed height;
#   textbox / button layout; textbox style; button and upload-widget style.
# NOTE(review): #tab_item_5 and #chatbot2 have no matching elem_id in the UI
# code visible here, and the .gr-* class selectors may not match the DOM of
# newer Gradio releases. Confirm against the installed Gradio version.
CSS = """
/* 整体布局:上下两块 */
.gradio-container {
display: flex !important;
flex-direction: column;
height: 100vh;
margin: 0;
padding: 0;
}
.gr-tabs { /* ✅ 新增:确保 tab 能继承高度 */
flex: 1 1 auto;
display: flex;
flex-direction: column;
}
/* 聊天 tab */
#tab_item_4, #tab_item_5 {
display: flex;
flex-direction: column;
flex: 1 1 auto;
overflow: hidden; /* 防止出现双滚动条 */
padding: 8px;
}
/* Chatbot 撑满 */
#chatbot1, #chatbot2{
flex-grow: 1 !important;
max-height: 66vh !important; /* 限制聊天框最大高度为屏幕的2/3 */
overflow-y: auto !important; /* 当内容溢出时显示滚动条 */
border: 1px solid #ddd;
border-radius: 8px;
padding: 12px;
margin-bottom: 8px;
}
/* 图片消息放大 */
#chatbot1 img, #chatbot2 img {
max-width: 80vw !important;
height: auto !important;
border-radius: 4px;
}
/* 底部输入区:固定高度 */
.input-row {
flex: 0 0 auto;
display: flex;
align-items: center;
padding: 8px;
border-top: 1px solid #eee;
background: #fafafa;
}
/* 文本框和按钮排布 */
.input-row .textbox-col { flex: 5; }
.input-row .upload-col, .input-row .clear-col { flex: 1; margin-left: 8px; }
/* 文本框样式 */
.gr-text-input {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
border: 1px solid #ddd !important;
font-size: 16px !important;
}
/* 按钮和上传组件样式 */
.gr-button, .gr-upload {
width: 100% !important;
border-radius: 18px !important;
padding: 8px 16px !important;
font-size: 16px !important;
}
"""
# Build the Gradio UI: one chat tab containing a mode selector, a text box, an
# image-upload button and a clear button, wired to the handlers defined above.
with gr.Blocks(css=CSS) as demo:
    # Path of the "current" image that Understand/Edit operate on. It is set by
    # uploads (handle_image_upload) and by generate/edit results (on_submit).
    img_state = gr.State(value=None)
    # NOTE(review): mode_state is not wired to any event in the code visible
    # here; on_submit receives the mode from mode_selector instead.
    mode_state = gr.State(value="Understand Image")
    with gr.Tabs():
        # elem_id values below correspond to selectors in CSS.
        with gr.Tab("Skywork UniPic Chatbot", elem_id="tab_item_4"):
            chatbot = gr.Chatbot(
                elem_id="chatbot1",
                show_label=False,
                avatar_images=(
                    "user.png",
                    "ai.png",
                ),
            )
            with gr.Row():
                mode_selector = gr.Radio(
                    choices=["Generate Image","Edit Image","Understand Image"],
                    value="Generate Image",
                    label="Mode",
                    interactive=True,
                )
            with gr.Row(elem_classes="input-row"):
                with gr.Column(elem_classes="textbox-col"):
                    user_input = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        lines=1,
                    )
                with gr.Column(elem_classes="upload-col"):
                    image_input = gr.UploadButton(
                        "📷 Upload Image",
                        file_types=["image"],
                        file_count="single",
                        type="filepath",
                    )
                with gr.Column(elem_classes="clear-col"):
                    clear_btn = gr.Button("🧹 Clear History")
    # Enter in the textbox: on_submit(history, message, current image, mode)
    # updates the history, clears the textbox and updates the current image.
    user_input.submit(
        on_submit,
        [chatbot, user_input, img_state, mode_selector],
        [chatbot, user_input, img_state],
    )
    # An upload becomes the current image and is echoed into the chat.
    image_input.upload(
        handle_image_upload, [image_input, chatbot], [img_state, chatbot]
    )
    # Clear: clear_all deletes temp files and resets history, image state and mode.
    clear_btn.click(clear_all, outputs=[chatbot, img_state, mode_selector])
# NOTE(review): the app launches unconditionally at import time; the commented
# lines below are an alternative, guarded launch configuration.
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", share=True, debug=True, server_port=7689)
demo.launch()