File size: 4,256 Bytes
ec7f44e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
from pathlib import Path
import json
from train import SimpleTransformer, flat_corners_from_mockup
# --------------------
# Utility: order 4 corner points as TL, TR, BR, BL
# --------------------
def order_points_clockwise(pts):
    """Return four 2-D points ordered top-left, top-right, bottom-right, bottom-left.

    The two points with the smallest y form the top pair and the other two
    form the bottom pair. Within each pair, the point with the strictly
    smaller x is the left one; on an x tie, the later point of the pair is
    taken as the left one. With y growing downward (image coordinates), the
    order TL -> TR -> BR -> BL runs clockwise on screen.

    Args:
        pts: Array-like of four (x, y) pairs, i.e. shape (4, 2).

    Returns:
        A float32 array of shape (4, 2) in the order [TL, TR, BR, BL].
    """
    points = np.array(pts, dtype="float32")
    by_y = points[np.argsort(points[:, 1])]
    upper, lower = by_y[:2], by_y[2:]
    # Keep a pair as-is when it is already left-to-right; otherwise flip it.
    tl, tr = upper if upper[0][0] < upper[1][0] else upper[::-1]
    bl, br = lower if lower[0][0] < lower[1][0] else lower[::-1]
    return np.array([tl, tr, br, bl], dtype="float32")
# --------------------
# Utility: save XML prediction
# --------------------
def save_prediction_xml(pred_pts, out_path, img_w, img_h):
    """Write four corner points as a "FourPoint" visualization XML file.

    The points are first re-ordered with ``order_points_clockwise``. They are
    then emitted twice:

    * as named ``<point>`` children of a single ``FourPoint`` transform;
    * as ``type="Next"`` points of one overlay, in TL, TR, BR, BL order.

    A ``<ruler>`` element runs from the top-left to the bottom-right corner.

    Args:
        pred_pts: Four (x, y) points in any order. ``predict_one`` passes
            pixel coordinates.
        out_path: Destination accepted by ``ElementTree.write`` (a path or a
            writable binary file object). For ``str``/``Path`` destinations,
            missing parent directories are created first.
        img_w: Value written to the ``background`` ``width`` attribute.
        img_h: Value written to the ``background`` ``height`` attribute.
    """
    ordered = order_points_clockwise(pred_pts)
    TL, TR, BR, BL = ordered

    def fmt(value):
        # Use one formatter for every coordinate so <point> and <ruler>
        # attributes agree. str() of a numpy float32 and str() of the same
        # value as a Python float can print different digits.
        return str(float(value))

    root = ET.Element("visualization", version="1.0")
    ET.SubElement(root, "effects", surfacecolor="", iswood="0")
    ET.SubElement(root, "background",
                  width=str(img_w), height=str(img_h),
                  color1="#C4CDE4", color2="", color3="")
    transforms_node = ET.SubElement(root, "transforms")
    transform = ET.SubElement(
        transforms_node, "transform",
        type="FourPoint", offsetX="0", offsetY="0", offsetZ="0.0",
        rotationX="0.0", rotationY="0.0", rotationZ="0.0",
        name="Region", posCode="REGION", posName="Region",
        posDef="0", techCode="EMBF03", techName="Embroidery Fixed",
        techDef="0", areaWidth="100", areaHeight="100",
        maxColors="12", defaultLogoSize="100", sizeX="100", sizeY="100")
    corners = {"TopLeft": TL, "TopRight": TR, "BottomRight": BR, "BottomLeft": BL}
    for ptype, (x, y) in corners.items():
        ET.SubElement(transform, "point",
                      type=ptype, x=fmt(x), y=fmt(y),
                      z="0.0", warp="0", warpShift="0")
    overlays = ET.SubElement(root, "overlays")
    overlay = ET.SubElement(overlays, "overlay")
    for x, y in ordered:
        ET.SubElement(overlay, "point", type="Next", x=fmt(x), y=fmt(y), z="0.0")
    ET.SubElement(root, "ruler",
                  startX=fmt(TL[0]), startY=fmt(TL[1]),
                  stopX=fmt(BR[0]), stopY=fmt(BR[1]), value="100")
    if isinstance(out_path, (str, Path)):
        # Create the output directory (e.g. "Transformer/Prediction/") instead
        # of failing with FileNotFoundError after inference has already run.
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    ET.ElementTree(root).write(out_path, encoding="utf-8", xml_declaration=True)
# --------------------
# Predict one sample
# --------------------
def _load_model(model_ckpt, device):
    """Build a ``SimpleTransformer`` on *device*, load *model_ckpt* and set eval mode."""
    model = SimpleTransformer().to(device)
    # SECURITY NOTE(review): weights_only=False un-pickles arbitrary objects,
    # which can execute code embedded in the file. Only load checkpoints you
    # trust. weights_only=True may work if the resume checkpoint holds only
    # tensors and plain containers (TODO: confirm against train.py's save format).
    state = torch.load(model_ckpt, map_location=device, weights_only=False)
    if "model_state" in state:  # resume checkpoint format
        model.load_state_dict(state["model_state"])
    else:  # final model (bare state dict)
        model.load_state_dict(state)
    model.eval()
    return model


def predict_one(mockup_json, pers_img_path, model_ckpt, out_path="prediction.xml"):
    """Predict four corner points for one sample and save them as XML.

    The model output is reshaped to four (x, y) pairs. It is treated as
    normalized coordinates and scaled by the perspective image's width and
    height before being written with ``save_prediction_xml``.

    Args:
        mockup_json: Mockup file passed to ``flat_corners_from_mockup``. Its
            second return value is flattened into the model input.
        pers_img_path: Perspective image. Only its pixel width and height are
            read.
        model_ckpt: Checkpoint file. This is either a dict with a
            ``"model_state"`` entry or a bare state dict for
            ``SimpleTransformer``.
        out_path: Destination XML path passed to ``save_prediction_xml``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the image dimensions are needed. Reading them inside a context
    # manager avoids decoding/converting the pixels and closes the file.
    with Image.open(pers_img_path) as pers_img:
        orig_w, orig_h = pers_img.size

    # Flat (normalized) corner points from mockup.json -> a batch of one.
    _, flat_norm = flat_corners_from_mockup(mockup_json)
    flat_in = torch.tensor(flat_norm.flatten(), dtype=torch.float32).unsqueeze(0).to(device)  # presumably (1, 8)

    model = _load_model(model_ckpt, device)

    with torch.no_grad():
        pred = model(flat_in)
    pred = pred.view(4, 2).cpu().numpy()

    # Normalized -> pixel coordinates. Copy first: on CPU, .numpy() shares
    # memory with the tensor.
    pred_px = pred.copy()
    pred_px[:, 0] *= orig_w
    pred_px[:, 1] *= orig_h

    save_prediction_xml(pred_px, out_path, orig_w, orig_h)
    print(f"Saved prediction -> {out_path}")
# --------------------
# Example usage
# --------------------
def parse_args(argv=None):
    """Parse command-line options for a single prediction run.

    Every option defaults to the example sample previously hard-coded here,
    so running the script without arguments behaves as before.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``mockup_json``, ``pers_img``,
        ``model_ckpt`` and ``out_path``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the corner-prediction model on one sample and save the result as XML.")
    parser.add_argument(
        "--mockup-json",
        default="Transformer/test/100847_TD/front/LAS02/mockup.json",
        help="mockup.json holding the flat corner points")
    parser.add_argument(
        "--pers-img",
        default="Transformer/test/100847_TD/front/LAS02/4BC13E58-1D8A-4E5D-8A40-C1F4B1248893_visual.jpg",
        help="perspective image whose width/height scale the prediction")
    parser.add_argument(
        "--model-ckpt",
        default="Transformer/transformer_model.pth",
        help="model checkpoint (final state dict or resume checkpoint)")
    parser.add_argument(
        "--out-path",
        default="Transformer/Prediction/pred3.xml",
        help="destination XML file")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    predict_one(args.mockup_json, args.pers_img, args.model_ckpt, out_path=args.out_path)