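"""Gradio demo for the Mesh Key Point Transformer.

Given a flat (mockup) image, a perspective image, a mockup-corner JSON and a
ground-truth XML, the app runs keypoint prediction with a pretrained checkpoint
and shows ground-truth vs. predicted feature matches side by side.
"""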
import gradio as gr
import os
import tempfile
import cv2
from test import predict_one
from plot import (
autocrop, get_json_corners, extract_points_from_xml,
draw_feature_matching, stack_images_side_by_side
)
# Hard-coded model checkpoint path
MODEL_CKPT = "best_model.pth"
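# Optional (hypothetical) override: read the checkpoint path from an environment
# variable instead of hard-coding it. Not part of the original pipeline.
# MODEL_CKPT = os.environ.get("MODEL_CKPT", "best_model.pth")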
# --------------------
# Pipeline
# --------------------
def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
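    """Run keypoint prediction on one sample and build a side-by-side comparison.

    Args (all Gradio filepaths):
        flat_img: flat/mockup image.
        pers_img: perspective image.
        mockup_json: mockup JSON with corner points.
        xml_gt: ground-truth XML annotation.

    Returns:
        (result_path, xml_pred_path): the stacked comparison image and the
        predicted XML, both written to a temporary directory.
    """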
    # Temp dir for prediction + result
    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    # Run prediction
    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)
    # --- Visualization ---
    img_json = autocrop(cv2.cvtColor(cv2.imread(flat_img), cv2.COLOR_BGR2RGB))
    img_xml = autocrop(cv2.cvtColor(cv2.imread(pers_img), cv2.COLOR_BGR2RGB))

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    color = (0, 255, 0)   # Green (RGB) for ground-truth matches
    color2 = (0, 0, 255)  # Blue (RGB) for predicted matches

    match_json_gt = draw_feature_matching(img_json.copy(), json_pts, img_xml.copy(), gt_pts, color, draw_boxes=True)
    match_json_pred = draw_feature_matching(img_json.copy(), json_pts, img_xml.copy(), pred_pts, color2, draw_boxes=True)
    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)
    # Add vertical center line
    h, w, _ = stacked.shape
    center_x = w // 2
    cv2.line(stacked, (center_x, 0), (center_x, h), (255, 0, 0), 4)  # red line (image is RGB here)

    # Add text labels
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(stacked, "Ground Truth", (50, 50), font, 2, (0, 255, 0), 3, cv2.LINE_AA)
    cv2.putText(stacked, "Our Result", (center_x + 50, 50), font, 2, (0, 0, 255), 3, cv2.LINE_AA)

    # Save result (convert back to BGR for cv2.imwrite)
    cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))
    return result_path, xml_pred_path
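
# Minimal programmatic usage sketch (hypothetical sample paths; assumes the
# inputs and best_model.pth are present next to this script):
#
#   result_png, pred_xml = run_pipeline(
#       "sample_flat.png",          # flat/mockup image
#       "sample_perspective.png",   # perspective image
#       "sample_mockup.json",       # mockup corner JSON
#       "sample_gt.xml",            # ground-truth XML
#   )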
# --------------------
# Gradio UI
# --------------------
with gr.Blocks() as demo:
    gr.Markdown("## Mesh Key Point Transformer Demo")

    with gr.Row():
        flat_in = gr.Image(type="filepath", label="Flat Image", width=300, height=300)
        pers_in = gr.Image(type="filepath", label="Perspective Image", width=300, height=300)

    with gr.Row():
        mockup_json_in = gr.File(type="filepath", label="Mockup JSON")
        xml_gt_in = gr.File(type="filepath", label="Ground Truth XML")

    run_btn = gr.Button("Run Prediction + Visualization")

    with gr.Row():
        out_img = gr.Image(type="filepath", label="Comparison Output", width=800, height=600)
        out_xml = gr.File(type="filepath", label="Predicted XML")

    run_btn.click(
        fn=run_pipeline,
        inputs=[flat_in, pers_in, mockup_json_in, xml_gt_in],
        outputs=[out_img, out_xml],
    )
if __name__ == "__main__":
    demo.launch(share=True)