|
import os |
|
import json |
|
import glob |
|
import xml.etree.ElementTree as ET |
|
import numpy as np |
|
from PIL import Image |
|
from torch.utils.data import Dataset, DataLoader |
|
import torchvision.transforms as T |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from shapely.geometry import Polygon |
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import json |
|
|
|
def flat_corners_from_mockup(mockup_path):
    """Compute the 4 print-area corner points described by a mockup.json file.

    Reads the first entry of ``printAreas``, builds axis-aligned corner
    offsets around the area centre, rotates them by the area's ``rotation``
    (degrees), and translates them back to the centre.

    Returns a tuple ``(pixels, normalized)`` of float32 arrays of shape
    (4, 2), ordered TL, TR, BR, BL; ``normalized`` divides by the background
    width/height so coordinates lie in [0, 1].
    """
    spec = json.loads(Path(mockup_path).read_text())
    background = spec["background"]
    area = spec["printAreas"][0]

    half_w = area["width"] / 2.0
    half_h = area["height"] / 2.0
    center = np.array(
        [area["position"]["x"] + half_w, area["position"]["y"] + half_h],
        dtype=np.float32,
    )

    # Corner offsets around the centre, ordered TL, TR, BR, BL.
    offsets = np.array(
        [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]],
        dtype=np.float32,
    )

    theta = np.deg2rad(area["rotation"])
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)

    pixels = offsets @ rotation.T + center

    normalized = np.empty_like(pixels)
    normalized[:, 0] = pixels[:, 0] / background["width"]
    normalized[:, 1] = pixels[:, 1] / background["height"]
    return pixels.astype(np.float32), normalized.astype(np.float32)
|
|
|
def parse_xml_points(xml_path):
    """Extract the FourPoint corner coordinates from an annotation XML.

    Reads the background size, then takes the first ``transform`` element
    whose ``type`` attribute is ``"FourPoint"`` and collects its TopLeft,
    TopRight, BottomRight and BottomLeft points (silently skipping any that
    are absent), normalizing each by the background width/height.

    Returns a float32 array of shape (n, 2), ordered TL, TR, BR, BL.
    """
    root = ET.parse(xml_path).getroot()

    background = root.find("background")
    width = int(background.get("width"))
    height = int(background.get("height"))

    corner_order = ("TopLeft", "TopRight", "BottomRight", "BottomLeft")
    coords = []
    for transform in root.findall(".//transform"):
        if transform.get("type") != "FourPoint":
            continue
        for corner in corner_order:
            node = transform.find(f".//point[@type='{corner}']")
            if node is None:
                continue
            coords.append([float(node.get("x")) / width,
                           float(node.get("y")) / height])
        # Only the first FourPoint transform is used.
        break

    return np.array(coords, dtype=np.float32)
|
|
|
class KP4Dataset(Dataset):
    """Dataset of (perspective image, flat image, corner point) samples.

    Scans ``root`` recursively for ``*_visual*.xml`` annotation files and
    keeps only those with a matching perspective image, a matching
    ``*_background*`` flat image, and a sibling ``mockup.json``.
    """

    def __init__(self, root, img_size=512):
        self.root = Path(root)
        self.img_size = img_size

        # Shared preprocessing: resize, tensorize, scale to roughly [-1, 1].
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.samples = [
            sample
            for sample in map(self._resolve_sample, self.root.rglob("*.xml"))
            if sample is not None
        ]

        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    def _resolve_sample(self, xml_file):
        """Return the (image, xml, flat image, json) path tuple for one
        annotation file, or None when any required file is missing."""
        if "_visual" not in xml_file.stem:
            return None

        # Perspective image shares the XML's basename; first extension wins.
        persp_img = None
        for ext in (".png", ".jpg", ".jpeg"):
            candidate = xml_file.with_suffix(ext)
            if candidate.exists():
                persp_img = candidate
                break
        if persp_img is None:
            return None

        # Flat image uses the "_background" naming scheme (png preferred).
        flat_name = xml_file.stem.replace("_visual", "_background")
        flat_img = xml_file.parent / (flat_name + ".png")
        if not flat_img.exists():
            flat_img = xml_file.parent / (flat_name + ".jpg")
            if not flat_img.exists():
                return None

        json_file = xml_file.parent / "mockup.json"
        if not json_file.exists():
            return None

        return persp_img, xml_file, flat_img, json_file

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        persp_path, xml_path, flat_path, json_path = self.samples[idx]

        persp_img = self.transform(Image.open(persp_path).convert("RGB"))
        flat_img = self.transform(Image.open(flat_path).convert("RGB"))

        # Canonical (flat) corners come from mockup.json; only the
        # normalized coordinates are used downstream.
        _, flat_norm = flat_corners_from_mockup(json_path)
        # Target (perspective) corners come from the annotation XML.
        persp_norm = parse_xml_points(xml_path)

        return {
            "persp_img": persp_img,
            "flat_img": flat_img,
            "flat_pts": torch.tensor(flat_norm, dtype=torch.float32),
            "persp_pts": torch.tensor(persp_norm, dtype=torch.float32),
            "xml": str(xml_path),
            "json": str(json_path),
        }
|
|
|
|
|
|
|
|
|
class SimpleTransformer(nn.Module):
    """Regress 8 perspective corner coordinates from 8 flat coordinates.

    The input 8-vector is embedded to ``d_model``, treated as a length-1
    token sequence by a TransformerEncoder, and projected back to 8 values.
    """

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.fc_in = nn.Linear(8, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 8)

    def forward(self, x):
        tokens = self.fc_in(x).unsqueeze(1)  # (B, 1, d_model)
        encoded = self.transformer(tokens)
        return self.fc_out(encoded).squeeze(1)  # (B, 8)
|
|
|
|
|
|
|
|
|
|
|
def mse_loss(pred, gt):
    """Mean squared error between prediction and ground-truth tensors."""
    diff = pred - gt
    return (diff * diff).mean()
|
|
|
def mean_corner_error(pred, gt, img_w, img_h):
    """Average Euclidean distance (in pixels) between predicted and
    ground-truth corners, given normalized coordinates and the image size."""
    scale_pred = torch.tensor([img_w, img_h], device=pred.device)
    scale_gt = torch.tensor([img_w, img_h], device=gt.device)
    distances = torch.norm(pred * scale_pred - gt * scale_gt, dim=-1)
    return distances.mean().item()
|
|
|
def iou_quad(pred, gt):
    """Intersection-over-union of two quadrilaterals.

    ``pred`` and ``gt`` are (4, 2) corner tensors. Returns 0.0 when either
    polygon is invalid (e.g. self-intersecting) or the union has zero area.
    """
    quad_a = Polygon(pred.tolist())
    quad_b = Polygon(gt.tolist())
    if not (quad_a.is_valid and quad_b.is_valid):
        return 0.0
    union_area = quad_a.union(quad_b).area
    if union_area <= 0:
        return 0.0
    return quad_a.intersection(quad_b).area / union_area
|
|
|
|
|
|
|
|
|
|
|
def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None
):
    """Train SimpleTransformer to map flat print-area corners to their
    perspective positions, validating each epoch.

    Args:
        train_root: directory scanned recursively for training samples.
        test_root: directory scanned for validation samples.
        epochs: total number of epochs; when resuming, training continues
            from the checkpoint's stored epoch up to this total.
        batch_size: training batch size (validation always uses batch size 1).
        lr: Adam learning rate.
        img_size: side length images are resized to by the dataset.
        save_dir: directory for periodic / best / final checkpoints.
        resume_path: optional checkpoint path to resume model and optimizer
            state from; ignored if the file does not exist.

    Returns:
        The trained model (final weights also saved under ``save_dir``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = 0

    os.makedirs(save_dir, exist_ok=True)

    # Optionally restore model/optimizer state and the epoch counter.
    if resume_path is not None and os.path.exists(resume_path):
        print(f"Loading checkpoint from {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"]
        print(f"Resumed from epoch {start_epoch}")

    best_iou = -1.0
    best_model_path = os.path.join(save_dir, "best_model.pth")

    for epoch in range(start_epoch, epochs):
        # ---- training ----
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            # The model maps the 8 flat coords directly to 8 perspective
            # coords; the images themselves are not fed to the network.
            flat_pts_in = flat_pts.view(flat_pts.size(0), -1)
            target = persp_pts.view(persp_pts.size(0), -1)

            pred = model(flat_pts_in)
            loss = mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {total_loss/len(train_loader):.6f}")

        # ---- validation ----
        model.eval()
        mse_all, ce_all, iou_all = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                flat_pts = batch["flat_pts"].to(device)
                persp_pts = batch["persp_pts"].to(device)

                flat_pts_in = flat_pts.view(1, -1)
                target = persp_pts.view(1, -1)

                pred = model(flat_pts_in)
                mse_all.append(mse_loss(pred, target).item())

                pred_quad = pred.view(4, 2).cpu()
                gt_quad = persp_pts.view(4, 2).cpu()

                # persp_img is batched as (1, C, H, W): height and width are
                # the last two axes.  (BUGFIX: shape[1] — the channel count,
                # i.e. 3 — was previously used as the image height, which
                # made the pixel corner error meaningless.)
                h, w = batch["persp_img"].shape[-2], batch["persp_img"].shape[-1]
                ce_all.append(mean_corner_error(pred_quad, gt_quad, w, h))
                iou_all.append(iou_quad(pred_quad, gt_quad))

        val_mse = np.mean(mse_all)
        val_ce = np.mean(ce_all)
        val_iou = np.mean(iou_all)

        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")

        # Periodic checkpoint every 100 epochs.
        if (epoch + 1) % 100 == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch+1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_iou": val_iou,
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

        # Keep the checkpoint with the best validation IoU seen so far.
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save({
                "epoch": epoch+1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }, best_model_path)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

    final_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")

    return model
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: train on the default "Transformer/{train,test}"
    # layout from scratch (resume_path=None).  epochs=3000 pairs with the
    # 100-epoch periodic checkpoint cadence inside train_model.
    model = train_model(
        train_root="Transformer/train",
        test_root="Transformer/test",
        epochs=3000,
        batch_size=4,
        lr=1e-3,
        img_size=256,
        resume_path=None
    )
|
|