|
import gradio as gr |
|
import os |
|
import tempfile |
|
import cv2 |
|
from test import predict_one |
|
from plot import ( |
|
autocrop, get_json_corners, extract_points_from_xml, |
|
draw_feature_matching, stack_images_side_by_side |
|
) |
|
|
|
|
|
# Checkpoint file handed to predict_one() on every run_pipeline() call.
MODEL_CKPT = "best_model.pth"
|
|
|
|
|
|
|
|
|
|
|
def _load_rgb_cropped(path, label):
    """Read the image at *path* and return it auto-cropped in RGB order.

    Args:
        path: Filesystem path of the image to read.
        label: Human-readable input name used in the error message.

    Raises:
        ValueError: If OpenCV cannot read/decode the file.
    """
    bgr = cv2.imread(path)
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface later as an opaque cvtColor error.
    if bgr is None:
        raise ValueError(f"Could not read {label} from {path!r}.")
    return autocrop(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))


def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Predict key points and render a ground-truth vs. prediction comparison.

    Runs ``predict_one`` on the mockup JSON and the perspective image, then
    draws two feature-matching panels (flat image against ground-truth points,
    flat image against predicted points), stacks them side by side, and
    annotates the result with a divider line and captions.

    Args:
        flat_img: Path to the flat image.
        pers_img: Path to the perspective image.
        mockup_json: Path to the mockup JSON file.
        xml_gt: Path to the ground-truth XML file.

    Returns:
        A ``(result_path, xml_pred_path)`` tuple: the path of the rendered
        comparison PNG and the path of the predicted XML. Both live in a fresh
        temporary directory that this function does not delete, because the
        returned paths must stay valid for the caller.

    Raises:
        ValueError: If an input is missing or an image cannot be read.
        OSError: If the comparison image cannot be written.
    """
    # An input left empty in the UI is passed in as None; reject that here with
    # a clear message instead of failing somewhere inside the helpers.
    required = (
        ("Flat Image", flat_img),
        ("Perspective Image", pers_img),
        ("Mockup JSON", mockup_json),
        ("Ground Truth XML", xml_gt),
    )
    missing = [label for label, value in required if not value]
    if missing:
        raise ValueError("Missing required input(s): " + ", ".join(missing))

    # Load both images before running the model so an unreadable file fails
    # fast, without creating a temp directory or paying for inference.
    img_json = _load_rgb_cropped(flat_img, "Flat Image")
    img_xml = _load_rgb_cropped(pers_img, "Perspective Image")

    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    # Writes the predicted key points to xml_pred_path.
    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    # The working images are RGB here (converted back to BGR only on write),
    # so these tuples are green for ground truth and blue for the prediction.
    gt_color = (0, 255, 0)
    pred_color = (0, 0, 255)

    # Copies keep the two panels from drawing over each other's source images.
    match_json_gt = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), gt_pts, gt_color,
        draw_boxes=True,
    )
    match_json_pred = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), pred_pts, pred_color,
        draw_boxes=True,
    )

    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)

    # Vertical divider at the horizontal midpoint of the stacked image.
    h, w, _ = stacked.shape
    center_x = w // 2
    cv2.line(stacked, (center_x, 0), (center_x, h), (255, 0, 0), 4)

    # Caption each half, using the same colour as its matching lines.
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(stacked, "Ground Truth", (50, 50), font, 2, gt_color, 3,
                cv2.LINE_AA)
    cv2.putText(stacked, "Our Result", (center_x + 50, 50), font, 2,
                pred_color, 3, cv2.LINE_AA)

    # cv2.imwrite returns False on failure instead of raising; without this
    # check the caller would receive a path to a file that does not exist.
    if not cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write comparison image to {result_path!r}.")

    return result_path, xml_pred_path
|
|
|
|
|
|
|
|
|
|
|
# Gradio front end. Components are created in display order: title, the two
# image inputs, the two file inputs, the run button, then the outputs.
with gr.Blocks() as demo:
    gr.Markdown(value="## Mesh Key Point Transformer Demo")

    # Image inputs; type="filepath" hands run_pipeline paths, not arrays.
    with gr.Row():
        flat_in = gr.Image(
            label="Flat Image", type="filepath", height=300, width=300
        )
        pers_in = gr.Image(
            label="Perspective Image", type="filepath", height=300, width=300
        )

    # Annotation inputs, also delivered as file paths.
    with gr.Row():
        mockup_json_in = gr.File(label="Mockup JSON", type="filepath")
        xml_gt_in = gr.File(label="Ground Truth XML", type="filepath")

    run_btn = gr.Button(value="Run Prediction + Visualization")

    # Outputs: the rendered comparison image and the predicted XML file.
    with gr.Row():
        out_img = gr.Image(
            label="Comparison Output", type="filepath", height=600, width=800
        )
        out_xml = gr.File(label="Predicted XML", type="filepath")

    # The input order must match run_pipeline's parameter order, and the
    # output order must match its (result_path, xml_pred_path) return value.
    run_btn.click(
        run_pipeline,
        inputs=[flat_in, pers_in, mockup_json_in, xml_gt_in],
        outputs=[out_img, out_xml],
    )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|
|