import os
import json
import glob
import xml.etree.ElementTree as ET
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from shapely.geometry import Polygon
from torch.utils.data import Dataset, DataLoader

# =====================
# Data Utils
# =====================

# Corner order shared by the flat (mockup.json) and perspective (XML) parsers.
_CORNER_ORDER = ("TopLeft", "TopRight", "BottomRight", "BottomLeft")


def flat_corners_from_mockup(mockup_path):
    """Return the 4 corners of the first print area described by a mockup.json.

    The print area is a ``width`` x ``height`` box whose unrotated top-left
    corner sits at ``position`` and which is rotated by ``rotation`` degrees
    (0 when the key is absent) about its centre.

    Args:
        mockup_path: Path or str of a mockup.json containing
            ``background.width``, ``background.height`` and a non-empty
            ``printAreas`` list.

    Returns:
        Tuple ``(corners_px, corners_norm)`` of float32 arrays of shape (4, 2),
        ordered TL, TR, BR, BL (order of the unrotated box). ``corners_norm``
        is ``corners_px`` divided by the background width/height.
    """
    d = json.loads(Path(mockup_path).read_text(encoding="utf-8"))
    bg_w = float(d["background"]["width"])
    bg_h = float(d["background"]["height"])

    area = d["printAreas"][0]
    x, y = area["position"]["x"], area["position"]["y"]
    w, h = area["width"], area["height"]
    angle = area.get("rotation", 0.0)

    cx, cy = x + w / 2.0, y + h / 2.0
    dx, dy = w / 2.0, h / 2.0
    # Corners relative to the box centre, ordered TL, TR, BR, BL.
    corners = np.array(
        [[-dx, -dy], [dx, -dy], [dx, dy], [-dx, dy]], dtype=np.float32
    )

    theta = np.deg2rad(angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    R = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)
    rot = (corners @ R.T) + np.array([cx, cy], dtype=np.float32)

    norm = rot / np.array([bg_w, bg_h], dtype=np.float32)
    return rot.astype(np.float32), norm.astype(np.float32)


def parse_xml_points(xml_path):
    """Parse the 4 corners of the first FourPoint transform in an XML file.

    Expects a ``<background width=.. height=..>`` child of the root element
    and at least one ``<transform type="FourPoint">`` element whose
    descendants include ``<point type="TopLeft|TopRight|BottomRight|BottomLeft"
    x=.. y=..>``.

    Returns:
        float32 array of shape (4, 2) with (x, y) ordered TL, TR, BR, BL and
        normalised by the background width/height.

    Raises:
        ValueError: if the background element, a FourPoint transform, or any
            of its four corner points is missing.
    """
    root = ET.parse(xml_path).getroot()

    background = root.find("background")
    if background is None:
        raise ValueError(f"{xml_path}: missing <background> element")
    bg_w = float(background.get("width"))
    bg_h = float(background.get("height"))

    for transform in root.findall(".//transform"):
        if transform.get("type") != "FourPoint":
            continue
        points = []
        for corner in _CORNER_ORDER:
            node = transform.find(f".//point[@type='{corner}']")
            if node is None:
                # A partial quad would yield < 4 points and only fail later,
                # far from the cause (collation / reshaping to (4, 2)).
                raise ValueError(
                    f"{xml_path}: FourPoint transform lacks a '{corner}' point"
                )
            points.append(
                [float(node.get("x")) / bg_w, float(node.get("y")) / bg_h]
            )
        # Only the first FourPoint transform is used.
        return np.array(points, dtype=np.float32)

    raise ValueError(f"{xml_path}: no FourPoint transform found")


def _first_existing(paths):
    """Return the first path in *paths* that exists, or None."""
    return next((p for p in paths if p.exists()), None)


def _load_rgb(path):
    """Load an image as RGB and close the underlying file handle."""
    with Image.open(path) as im:
        return im.convert("RGB")


class KP4Dataset(Dataset):
    """Flat/perspective image pairs with their 4 normalised corner points.

    ``root`` is searched recursively for XML files whose stem contains
    ``"_visual"``. For each such ``X.xml`` a sample needs, in the same
    directory:

      * ``X.png`` / ``X.jpg`` / ``X.jpeg`` -- the perspective image;
      * the same stem with ``"_visual"`` replaced by ``"_background"`` plus
        ``.png`` / ``.jpg`` / ``.jpeg`` -- the flat image;
      * ``mockup.json`` -- source of the flat corners;
      * a complete FourPoint transform in ``X.xml`` -- the perspective corners.

    Incomplete samples are skipped. Samples are sorted by XML path so the
    order does not depend on filesystem iteration order.

    Raises:
        RuntimeError: if no valid sample is found under ``root``.
    """

    _PERSP_EXTS = (".png", ".jpg", ".jpeg")
    _FLAT_EXTS = (".png", ".jpg", ".jpeg")

    def __init__(self, root, img_size=512):
        self.root = Path(root)
        self.img_size = img_size
        self.samples = []

        # Resize to a square, convert to a tensor, scale channels to [-1, 1].
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        n_bad_xml = 0
        for xml_file in sorted(self.root.rglob("*.xml")):
            stem = xml_file.stem
            if "_visual" not in stem:
                continue

            img_file = _first_existing(
                xml_file.with_suffix(ext) for ext in self._PERSP_EXTS
            )
            if img_file is None:
                continue

            flat_stem = stem.replace("_visual", "_background")
            flat_img = _first_existing(
                xml_file.parent / (flat_stem + ext) for ext in self._FLAT_EXTS
            )
            if flat_img is None:
                continue

            json_file = xml_file.parent / "mockup.json"
            if not json_file.exists():
                continue

            # Validate the corner annotation up front so a malformed XML is
            # skipped like any other incomplete sample instead of crashing
            # training mid-epoch.
            try:
                parse_xml_points(xml_file)
            except (ValueError, TypeError, ET.ParseError):
                n_bad_xml += 1
                continue

            self.samples.append((img_file, xml_file, flat_img, json_file))

        if n_bad_xml:
            print(f"Skipped {n_bad_xml} sample(s) with invalid FourPoint XML under {root}")
        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict with both images, both (4, 2) corner tensors and the
        source XML / JSON paths as strings."""
        img_file, xml_file, flat_img, json_file = self.samples[idx]

        img = self.transform(_load_rgb(img_file))
        flat = self.transform(_load_rgb(flat_img))

        # Flat corners come from mockup.json, perspective corners from the XML.
        _, flat_norm = flat_corners_from_mockup(json_file)
        persp_norm = parse_xml_points(xml_file)

        return {
            "persp_img": img,
            "flat_img": flat,
            "flat_pts": torch.tensor(flat_norm, dtype=torch.float32),
            "persp_pts": torch.tensor(persp_norm, dtype=torch.float32),
            "xml": str(xml_file),
            "json": str(json_file),
        }


# =====================
# Model
# =====================

class SimpleTransformer(nn.Module):
    """Regress 4 perspective corners from 4 flat corners (8 values -> 8 values).

    NOTE: the 8 input values are embedded as a single token, so each
    self-attention layer attends over a length-1 sequence and reduces to a
    linear projection (the encoder behaves like a residual MLP). The layout is
    kept unchanged so existing checkpoints (same parameter names and shapes)
    keep loading.
    """

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.fc_in = nn.Linear(8, d_model)  # 4 corners * (x, y)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 8)  # 4 corners * (x, y)

    def forward(self, x):
        x = self.fc_in(x).unsqueeze(1)  # (B, 8) -> (B, 1, d_model)
        x = self.transformer(x)         # (B, 1, d_model)
        return self.fc_out(x).squeeze(1)  # (B, 1, 8) -> (B, 8)


# =====================
# Metrics
# =====================

def mse_loss(pred, gt):
    """Mean squared error over all elements."""
    return ((pred - gt) ** 2).mean()


def mean_corner_error(pred, gt, img_w, img_h):
    """Mean Euclidean corner distance after scaling normalised coords to pixels.

    Args:
        pred, gt: tensors on the same device whose last dimension is (x, y)
            in normalised units, e.g. (4, 2) quads.
        img_w, img_h: width and height used to convert to pixels.

    Returns:
        The mean distance as a Python float.
    """
    scale = torch.tensor([img_w, img_h], dtype=pred.dtype, device=pred.device)
    return torch.norm((pred - gt) * scale, dim=-1).mean().item()


def iou_quad(pred, gt):
    """Intersection-over-union of two quadrilaterals.

    ``pred`` and ``gt`` must provide ``.tolist()`` yielding a list of (x, y)
    vertices (e.g. (4, 2) tensors). Returns 0.0 when either polygon is
    invalid (e.g. self-intersecting) or the union area is zero.
    """
    pred_poly = Polygon(pred.tolist())
    gt_poly = Polygon(gt.tolist())
    if not pred_poly.is_valid or not gt_poly.is_valid:
        return 0.0
    inter = pred_poly.intersection(gt_poly).area
    union = pred_poly.union(gt_poly).area
    return inter / union if union > 0 else 0.0


# =====================
# Training
# =====================

def _evaluate(model, loader, device):
    """Return (mean MSE, mean corner error in px, mean IoU) over *loader*.

    *loader* must yield one sample per batch. Pixel errors are measured in the
    resized image tensor's resolution, not the original image size.
    """
    model.eval()
    mse_all, ce_all, iou_all = [], [], []
    with torch.no_grad():
        for batch in loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            pred = model(flat_pts.reshape(1, -1))
            target = persp_pts.reshape(1, -1)
            mse_all.append(mse_loss(pred, target).item())

            pred_quad = pred.view(4, 2).cpu()
            gt_quad = persp_pts.view(4, 2).cpu()
            # Image tensors are (B, C, H, W): take H and W from the last two
            # dims (dim 1 is the channel count, not the height).
            h, w = batch["persp_img"].shape[-2:]
            ce_all.append(mean_corner_error(pred_quad, gt_quad, int(w), int(h)))
            iou_all.append(iou_quad(pred_quad, gt_quad))

    return float(np.mean(mse_all)), float(np.mean(ce_all)), float(np.mean(iou_all))


def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None,
    save_every=100,
):
    """Train SimpleTransformer to map flat corners to perspective corners.

    Args:
        train_root, test_root: dataset roots (see KP4Dataset).
        epochs: total number of epochs (including those already done when
            resuming).
        batch_size: training batch size (validation always uses 1).
        lr: Adam learning rate.
        img_size: side length images are resized to.
        save_dir: directory for checkpoints; created if missing.
        resume_path: optional checkpoint to resume model/optimizer/epoch from.
        save_every: write ``epoch_<n>.pth`` every this many epochs (falsy
            disables periodic checkpoints).

    Returns:
        The trained model (last-epoch weights).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # _evaluate relies on one sample per batch.
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = 0
    best_iou = -1.0
    best_model_path = os.path.join(save_dir, "best_model.pth")
    os.makedirs(save_dir, exist_ok=True)

    # Resume training
    if resume_path is not None and os.path.exists(resume_path):
        print(f"Loading checkpoint from {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"]
        # Restore the best IoU so the first post-resume epoch cannot overwrite
        # best_model.pth with a worse model. Older checkpoints lack the key.
        best_iou = float(checkpoint.get("best_iou", best_iou))
        print(f"Resumed from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        # -------- Training --------
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            # (B, 4, 2) -> (B, 8)
            pred = model(flat_pts.reshape(flat_pts.size(0), -1))
            loss = mse_loss(pred, persp_pts.reshape(persp_pts.size(0), -1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {total_loss/len(train_loader):.6f}")

        # -------- Validation --------
        val_mse, val_ce, val_iou = _evaluate(model, val_loader, device)
        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")

        # -------- Save Best Model --------
        # Metrics are Python floats (not numpy scalars) so checkpoints load
        # under torch.load's weights_only mode.
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }, best_model_path)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

        # -------- Save Periodic Checkpoint --------
        if save_every and (epoch + 1) % save_every == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_iou": val_iou,
                "best_iou": best_iou,
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

    # Save final model weights
    final_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")

    return model


# =====================
# Main
# =====================

if __name__ == "__main__":
    model = train_model(
        train_root="Transformer/train",
        test_root="Transformer/test",
        epochs=3000,
        batch_size=4,
        lr=1e-3,
        img_size=256,
        resume_path=None,
    )