File size: 2,897 Bytes
ec7f44e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21e587f
 
 
 
ec7f44e
 
21e587f
 
 
 
 
 
 
 
 
ec7f44e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import os
import tempfile
import cv2
from test import predict_one
from plot import (
    autocrop, get_json_corners, extract_points_from_xml,
    draw_feature_matching, stack_images_side_by_side
)

# Hard-coded model checkpoint path
MODEL_CKPT = "best_model.pth"


# --------------------
# Pipeline
# --------------------
def _read_rgb(path):
    """Load the image at *path* with OpenCV and return it converted to RGB.

    Raises:
        gr.Error: If OpenCV cannot read/decode the file. ``cv2.imread``
            returns ``None`` on failure instead of raising, which would
            otherwise surface as an opaque error inside ``cv2.cvtColor``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise gr.Error(f"Could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Predict key points for one sample and render a GT-vs-prediction comparison.

    Runs ``predict_one`` on the mockup JSON and perspective image, writing the
    predicted XML into a fresh temporary directory. It then draws feature
    matches between the flat and perspective images, once with the
    ground-truth points and once with the predicted points. The two panels
    are stacked side by side, labelled, and saved as a PNG.

    Args:
        flat_img: Path to the flat image (the image the JSON corners refer to).
        pers_img: Path to the perspective image.
        mockup_json: Path to the mockup JSON file.
        xml_gt: Path to the ground-truth XML file.

    Returns:
        ``(result_path, xml_pred_path)``: paths to the comparison PNG and the
        predicted XML, both inside a temporary directory created for this call.

    Raises:
        gr.Error: If an input is missing, an image cannot be decoded, the
            prediction XML was not written, or the result PNG cannot be saved.
    """
    # Fail fast with a message shown in the UI instead of a deep traceback
    # from predict_one / cv2 when the user forgot to provide an input.
    missing = [
        label
        for label, value in (
            ("Flat Image", flat_img),
            ("Perspective Image", pers_img),
            ("Mockup JSON", mockup_json),
            ("Ground Truth XML", xml_gt),
        )
        if not value
    ]
    if missing:
        raise gr.Error("Missing input(s): " + ", ".join(missing))

    # Decode both images before the model call so an unreadable upload is
    # reported without running prediction first.
    flat_rgb = _read_rgb(flat_img)
    pers_rgb = _read_rgb(pers_img)

    # mkdtemp gives each call its own directory, so requests do not overwrite
    # each other's outputs. NOTE(review): these directories are never removed.
    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    # Run prediction; the predicted points are read back from xml_pred_path.
    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)
    if not os.path.isfile(xml_pred_path):
        raise gr.Error("Prediction did not produce an XML file.")

    # --- Visualization ---
    img_json = autocrop(flat_rgb)
    img_xml = autocrop(pers_rgb)

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    # The canvas is RGB (converted on load, converted back to BGR on save),
    # so colour tuples below are (R, G, B).
    gt_color = (0, 255, 0)    # green: ground-truth panel and label
    pred_color = (0, 0, 255)  # blue: prediction panel and label
    match_json_gt = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), gt_pts, gt_color,
        draw_boxes=True,
    )
    match_json_pred = draw_feature_matching(
        img_json.copy(), json_pts, img_xml.copy(), pred_pts, pred_color,
        draw_boxes=True,
    )

    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)

    # Divider between the two panels. The w // 2 position assumes both panels
    # have equal width, which holds because both are drawn from the same pair
    # of images. (255, 0, 0) is red on this RGB canvas.
    # NOTE(review): if a blue divider was intended, use (0, 0, 255).
    h, w = stacked.shape[:2]
    center_x = w // 2
    cv2.line(stacked, (center_x, 0), (center_x, h), (255, 0, 0), 4)

    # Panel labels, coloured to match their panel's match colour.
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(stacked, "Ground Truth", (50, 50), font, 2, gt_color, 3, cv2.LINE_AA)
    cv2.putText(stacked, "Our Result", (center_x + 50, 50), font, 2, pred_color, 3, cv2.LINE_AA)

    # cv2.imwrite expects BGR and reports failure via its return value.
    if not cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR)):
        raise gr.Error("Failed to write the comparison image.")

    return result_path, xml_pred_path


# --------------------
# Gradio UI
# --------------------
# Gradio UI layout, top to bottom:
#   - title;
#   - two image inputs;
#   - two annotation-file inputs;
#   - the run button;
#   - the comparison image and the predicted XML.
# The click handler maps the inputs positionally onto run_pipeline's
# parameters and its return tuple onto the two outputs.
with gr.Blocks() as demo:
    gr.Markdown("## Mesh Key Point Transformer Demo")

    # Row 1: input images as fixed-size previews (created left to right).
    with gr.Row():
        flat_in, pers_in = (
            gr.Image(type="filepath", label=caption, width=300, height=300)
            for caption in ("Flat Image", "Perspective Image")
        )

    # Row 2: annotation files consumed by the pipeline.
    with gr.Row():
        mockup_json_in, xml_gt_in = (
            gr.File(type="filepath", label=caption)
            for caption in ("Mockup JSON", "Ground Truth XML")
        )

    run_btn = gr.Button("Run Prediction + Visualization")

    # Row 3: pipeline outputs.
    with gr.Row():
        out_img = gr.Image(
            type="filepath",
            label="Comparison Output",
            width=800,
            height=600,
        )
        out_xml = gr.File(type="filepath", label="Predicted XML")

    # Order must match run_pipeline(flat_img, pers_img, mockup_json, xml_gt)
    # and its (result_path, xml_pred_path) return value.
    pipeline_inputs = [flat_in, pers_in, mockup_json_in, xml_gt_in]
    pipeline_outputs = [out_img, out_xml]
    run_btn.click(fn=run_pipeline, inputs=pipeline_inputs, outputs=pipeline_outputs)


def main():
    """Start the Gradio demo server with Gradio's public share link enabled."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()