import torch
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
from pathlib import Path

from train import SimpleTransformer, flat_corners_from_mockup

# --------------------
# Utility: order 4 points clockwise (TL, TR, BR, BL)
# --------------------
def order_points_clockwise(pts):
    """Order 4 points as top-left, top-right, bottom-right, bottom-left.

    Splits the points into top/bottom pairs by y-coordinate, then orders
    each pair by x. Assumes a roughly upright quadrilateral.
    """
    pts = np.array(pts, dtype="float32")
    y_sorted = pts[np.argsort(pts[:, 1]), :]

    top_two = y_sorted[:2, :]
    bottom_two = y_sorted[2:, :]

    if top_two[0][0] < top_two[1][0]:
        tl, tr = top_two
    else:
        tr, tl = top_two

    if bottom_two[0][0] < bottom_two[1][0]:
        bl, br = bottom_two
    else:
        br, bl = bottom_two

    return np.array([tl, tr, br, bl], dtype="float32")
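
# Example (hypothetical input): a square given in scrambled order.
#   order_points_clockwise([(10, 0), (0, 0), (0, 10), (10, 10)])
#   -> [[ 0.  0.] [10.  0.] [10. 10.] [ 0. 10.]]   i.e. TL, TR, BR, BL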

# --------------------
# Utility: save XML prediction
# --------------------
def save_prediction_xml(pred_pts, out_path, img_w, img_h):
    """Write the 4 predicted corner points (pixel coordinates) to the visualization XML."""
    ordered = order_points_clockwise(pred_pts)
    TL, TR, BR, BL = ordered

    root = ET.Element("visualization", version="1.0")
    ET.SubElement(root, "effects", surfacecolor="", iswood="0")
    ET.SubElement(root, "background",
                  width=str(img_w), height=str(img_h),
                  color1="#C4CDE4", color2="", color3="")

    transforms_node = ET.SubElement(root, "transforms")
    transform = ET.SubElement(transforms_node, "transform",
                              type="FourPoint", offsetX="0", offsetY="0", offsetZ="0.0",
                              rotationX="0.0", rotationY="0.0", rotationZ="0.0",
                              name="Region", posCode="REGION", posName="Region",
                              posDef="0", techCode="EMBF03", techName="Embroidery Fixed",
                              techDef="0", areaWidth="100", areaHeight="100",
                              maxColors="12", defaultLogoSize="100", sizeX="100", sizeY="100")

    pts = {"TopLeft": TL, "TopRight": TR, "BottomRight": BR, "BottomLeft": BL}
    for ptype, (x, y) in pts.items():
        ET.SubElement(transform, "point",
                      type=ptype, x=str(float(x)), y=str(float(y)),
                      z="0.0", warp="0", warpShift="0")

    overlays = ET.SubElement(root, "overlays")
    overlay = ET.SubElement(overlays, "overlay")
    for (x, y) in ordered:
        ET.SubElement(overlay, "point", type="Next", x=str(float(x)), y=str(float(y)), z="0.0")

    ET.SubElement(root, "ruler",
                  startX=str(float(TL[0])), startY=str(float(TL[1])),
                  stopX=str(float(BR[0])), stopY=str(float(BR[1])), value="100")

    tree = ET.ElementTree(root)
    tree.write(out_path, encoding="utf-8", xml_declaration=True)
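
# For reference, the written file has roughly this shape (element/attribute
# names come from the code above; attribute values are placeholders):
#
#   <visualization version="1.0">
#     <effects .../>
#     <background width=".." height=".." .../>
#     <transforms>
#       <transform type="FourPoint" ...>
#         <point type="TopLeft" .../>  (plus TopRight, BottomRight, BottomLeft)
#       </transform>
#     </transforms>
#     <overlays><overlay> 4 x <point type="Next" .../> </overlay></overlays>
#     <ruler startX=".." startY=".." stopX=".." stopY=".." value="100"/>
#   </visualization>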


# --------------------
# Predict one sample
# --------------------
def predict_one(mockup_json, pers_img_path, model_ckpt, out_path="prediction.xml"):
    """Predict the perspective corner points for one sample and save them as XML."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load perspective image
    pers_img = Image.open(pers_img_path).convert("RGB")
    orig_w, orig_h = pers_img.size

    # Load flat points from mockup.json
    _, flat_norm = flat_corners_from_mockup(mockup_json)
    flat_in = torch.tensor(flat_norm.flatten(), dtype=torch.float32).unsqueeze(0).to(device)  # (1,8)

    # Load model
    model = SimpleTransformer().to(device)
    state = torch.load(model_ckpt, map_location=device, weights_only=False)
    if "model_state" in state:  # resume checkpoint format
        model.load_state_dict(state["model_state"])
    else:  # final model
        model.load_state_dict(state)
    model.eval()

    # Predict
    with torch.no_grad():
        pred = model(flat_in)  # (1,8)
        pred = pred.view(4, 2).cpu().numpy()

    # Convert normalized coords to pixel coords
    pred_px = pred.copy()
    pred_px[:, 0] *= orig_w
    pred_px[:, 1] *= orig_h

    # Save prediction
    save_prediction_xml(pred_px, out_path, orig_w, orig_h)
    print(f"Saved prediction -> {out_path}")
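

# --------------------
# Batch driver (sketch)
# --------------------
# A minimal sketch for running predict_one over a whole test tree. It assumes
# every sample directory holds a mockup.json next to a single *_visual.jpg
# (matching the layout in the example below); predict_dir and its output
# naming scheme are hypothetical additions, not part of the training code.
def predict_dir(test_root, model_ckpt, out_dir="Transformer/Prediction"):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for mockup in sorted(Path(test_root).rglob("mockup.json")):
        visuals = sorted(mockup.parent.glob("*_visual.jpg"))
        if not visuals:
            continue  # no perspective image for this sample
        # e.g. 100847_TD_front_LAS02_pred.xml from the last three path parts
        name = "_".join(mockup.parent.parts[-3:]) + "_pred.xml"
        # Note: predict_one reloads the checkpoint on every call, which is
        # fine for a handful of samples but wasteful for large test sets.
        predict_one(str(mockup), str(visuals[0]), model_ckpt, out_path=str(out / name))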


# --------------------
# Example usage
# --------------------
if __name__ == "__main__":
    mockup_json = "Transformer/test/100847_TD/front/LAS02/mockup.json"
    pers_img = "Transformer/test/100847_TD/front/LAS02/4BC13E58-1D8A-4E5D-8A40-C1F4B1248893_visual.jpg"
    model_ckpt = "Transformer/transformer_model.pth"
    predict_one(mockup_json, pers_img, model_ckpt, out_path="Transformer/Prediction/pred3.xml")