"""Gradio demo for the Mesh Key Point Transformer.

Runs the model on a mockup JSON + perspective image, then renders the
ground-truth and predicted keypoint matches side by side for comparison.
"""

import os
import tempfile

import cv2
import gradio as gr

# Project-local modules. NOTE: `test` is this project's test.py; it is found
# before the stdlib `test` package because the script's directory comes first
# on sys.path when run as `python <this file>`.
from test import predict_one
from plot import (
    autocrop,
    get_json_corners,
    extract_points_from_xml,
    draw_feature_matching,
    stack_images_side_by_side,
)

# Model checkpoint path. Defaults to the original hard-coded value; set the
# MODEL_CKPT environment variable to point the demo at another checkpoint.
MODEL_CKPT = os.environ.get("MODEL_CKPT", "best_model.pth")

# Drawing colours. Every canvas below is RGB (converted from OpenCV's BGR on
# load and back to BGR on save), so these tuples are (R, G, B).
GT_COLOR = (0, 255, 0)  # green: ground-truth matches and "Ground Truth" label
PRED_COLOR = (0, 0, 255)  # blue: predicted matches and "Our Result" label
DIVIDER_COLOR = (255, 0, 0)  # red: vertical separator between the two panels


# --------------------
# Pipeline
# --------------------
def _load_rgb(path, label):
    """Read the image at *path* and return it as an RGB array.

    cv2.imread signals a missing/undecodable file by returning None rather
    than raising; without this check the failure surfaces later as an opaque
    cv2.cvtColor error. Raise a user-visible gr.Error instead.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise gr.Error(f"Could not read the {label}: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _annotate_comparison(stacked):
    """Draw the centre divider and the two panel titles onto *stacked* in place.

    The divider is placed at w // 2, which presumes both panels produced by
    stack_images_side_by_side have the same width — TODO confirm.
    """
    h, w = stacked.shape[:2]
    center_x = w // 2
    cv2.line(stacked, (center_x, 0), (center_x, h), DIVIDER_COLOR, 4)

    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(
        stacked, "Ground Truth", (50, 50), font, 2, GT_COLOR, 3, cv2.LINE_AA
    )
    cv2.putText(
        stacked, "Our Result", (center_x + 50, 50), font, 2, PRED_COLOR, 3,
        cv2.LINE_AA,
    )


def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Predict keypoints and render a ground-truth vs. prediction comparison.

    Args:
        flat_img: Path to the flat image (left side of each match panel).
        pers_img: Path to the perspective image; also passed to predict_one.
        mockup_json: Path to the mockup JSON; passed to predict_one and used
            for the reference corners via get_json_corners.
        xml_gt: Path to the ground-truth keypoint XML.

    Returns:
        (result_path, xml_pred_path): path of the rendered comparison PNG and
        path of the predicted-keypoint XML written by predict_one.

    Raises:
        gr.Error: if an input is missing, an image cannot be decoded, or the
            comparison image cannot be written.
    """
    inputs = {
        "Flat Image": flat_img,
        "Perspective Image": pers_img,
        "Mockup JSON": mockup_json,
        "Ground Truth XML": xml_gt,
    }
    missing = [label for label, value in inputs.items() if not value]
    if missing:
        raise gr.Error("Please provide: " + ", ".join(missing))

    # Decode the images before running the model so bad uploads fail fast
    # instead of after an expensive prediction.
    img_json = autocrop(_load_rgb(flat_img, "Flat Image"))
    img_xml = autocrop(_load_rgb(pers_img, "Perspective Image"))

    # A fresh directory per call keeps concurrent requests from overwriting
    # each other's outputs.
    # NOTE(review): the directory is never removed; Gradio serves the returned
    # paths after this function returns, so cleanup must happen later — TODO.
    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    # Run prediction; its keypoints are written to xml_pred_path.
    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    # Pass copies so each panel is drawn onto clean images.
    match_json_gt = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), gt_pts, GT_COLOR,
        draw_boxes=True,
    )
    match_json_pred = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), pred_pts, PRED_COLOR,
        draw_boxes=True,
    )

    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)
    _annotate_comparison(stacked)

    # The canvas is RGB; cv2.imwrite expects BGR. imwrite reports failure via
    # its return value rather than raising, so check it explicitly.
    if not cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR)):
        raise gr.Error(f"Failed to write comparison image: {result_path}")

    return result_path, xml_pred_path


# --------------------
# Gradio UI
# --------------------
with gr.Blocks() as demo:
    gr.Markdown("## Mesh Key Point Transformer Demo")

    with gr.Row():
        flat_in = gr.Image(
            type="filepath", label="Flat Image", width=300, height=300
        )
        pers_in = gr.Image(
            type="filepath", label="Perspective Image", width=300, height=300
        )

    with gr.Row():
        mockup_json_in = gr.File(type="filepath", label="Mockup JSON")
        xml_gt_in = gr.File(type="filepath", label="Ground Truth XML")

    run_btn = gr.Button("Run Prediction + Visualization")

    with gr.Row():
        out_img = gr.Image(
            type="filepath", label="Comparison Output", width=800, height=600
        )
        out_xml = gr.File(type="filepath", label="Predicted XML")

    run_btn.click(
        fn=run_pipeline,
        inputs=[flat_in, pers_in, mockup_json_in, xml_gt_in],
        outputs=[out_img, out_xml],
    )


if __name__ == "__main__":
    demo.launch(share=True)