# Source: uploaded by saim1309 ("Upload 6 files", commit ec7f44e, verified).
import os
import json
import glob
import xml.etree.ElementTree as ET
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import torch.nn as nn
import torch.optim as optim
from shapely.geometry import Polygon
from pathlib import Path
# =====================
# Data Utils
# =====================
import numpy as np
import json
def flat_corners_from_mockup(mockup_path, area_index=0):
    """Compute the 4 corners of a print area described in a mockup JSON file.

    The print area is treated as a rectangle whose ``position`` is its
    top-left corner (before rotation), with size ``width`` x ``height``,
    rotated by ``rotation`` degrees about its own centre.

    Args:
        mockup_path: Path (``str`` or ``Path``) to the mockup JSON file.
        area_index: Which entry of ``printAreas`` to use (default: the first).

    Returns:
        A tuple ``(corners_px, corners_norm)`` of ``float32`` arrays of shape
        (4, 2), both ordered TL, TR, BR, BL. ``corners_px`` is in background
        pixels. ``corners_norm`` divides x by the background width and y by
        the background height. Its values lie in [0, 1] only when the rotated
        area stays inside the background.
    """
    data = json.loads(Path(mockup_path).read_text(encoding="utf-8"))
    bg_w = data["background"]["width"]
    bg_h = data["background"]["height"]
    area = data["printAreas"][area_index]

    x, y = area["position"]["x"], area["position"]["y"]
    w, h = area["width"], area["height"]
    cx, cy = x + w / 2.0, y + h / 2.0
    dx, dy = w / 2.0, h / 2.0

    # Corner offsets from the centre in TL, TR, BR, BL order (top = smaller y).
    offsets = np.array([[-dx, -dy], [dx, -dy], [dx, dy], [-dx, dy]], dtype=np.float32)
    theta = np.deg2rad(area["rotation"])
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rot_mat = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)

    # Offsets are row vectors, so right-multiply by the transpose to rotate them.
    corners_px = offsets @ rot_mat.T + np.array([cx, cy], dtype=np.float32)
    corners_norm = corners_px / np.array([bg_w, bg_h], dtype=np.float32)
    return corners_px.astype(np.float32), corners_norm.astype(np.float32)
def parse_xml_points(xml_path):
    """Parse the 4 corners of the first ``FourPoint`` transform in an XML file.

    Point coordinates are divided by the ``width`` / ``height`` attributes of
    the root's direct ``<background>`` child.

    Args:
        xml_path: Path to the XML file.

    Returns:
        ``float32`` array of shape (4, 2), ordered TL, TR, BR, BL.

    Raises:
        ValueError: if the ``<background>`` element is missing, no
            ``FourPoint`` transform exists, or that transform lacks one of
            the four corner points. Previously these cases produced an
            opaque ``AttributeError`` or an array with fewer than 4 rows.
    """
    root = ET.parse(xml_path).getroot()

    background = root.find("background")
    if background is None:
        raise ValueError(f"{xml_path}: missing <background> element")
    bg_w = int(background.get("width"))
    bg_h = int(background.get("height"))

    # Only the first FourPoint transform (searched among descendants) is used.
    transform = next(
        (t for t in root.findall(".//transform") if t.get("type") == "FourPoint"),
        None,
    )
    if transform is None:
        raise ValueError(f"{xml_path}: no FourPoint transform found")

    points = []
    for corner in ("TopLeft", "TopRight", "BottomRight", "BottomLeft"):
        node = transform.find(f".//point[@type='{corner}']")
        if node is None:
            raise ValueError(f"{xml_path}: FourPoint transform lacks a {corner} point")
        points.append([float(node.get("x")) / bg_w, float(node.get("y")) / bg_h])
    return np.array(points, dtype=np.float32)
class KP4Dataset(Dataset):
    """Dataset pairing perspective and flat mockup images with their corner points.

    Every ``*.xml`` file under ``root`` whose stem contains ``_visual`` is a
    candidate sample. It is kept only if all of the following exist in the
    same directory as the XML:

    * the perspective image: same stem, extension ``.png``/``.jpg``/``.jpeg``;
    * the flat image: stem with ``_visual`` replaced by ``_background``, with
      one of the same extensions;
    * a ``mockup.json`` file.

    Args:
        root: Directory searched recursively for samples.
        img_size: Both images are resized to ``(img_size, img_size)``.

    Raises:
        RuntimeError: if no valid sample is found under ``root``.
    """

    # Image extensions tried, in priority order, for both image kinds.
    _IMAGE_EXTS = (".png", ".jpg", ".jpeg")

    def __init__(self, root, img_size=512):
        self.root = Path(root)
        self.img_size = img_size
        # Resize to a square, convert to a tensor, then map [0, 1] -> [-1, 1].
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.samples = []
        # Sorted so sample indices do not depend on filesystem listing order.
        for xml_file in sorted(self.root.rglob("*.xml")):
            if "_visual" not in xml_file.stem:
                continue
            folder = xml_file.parent
            img_file = self._find_image(folder, xml_file.stem)
            if img_file is None:
                continue
            flat_img = self._find_image(
                folder, xml_file.stem.replace("_visual", "_background")
            )
            if flat_img is None:
                continue
            json_file = folder / "mockup.json"
            if not json_file.exists():
                continue
            self.samples.append((img_file, xml_file, flat_img, json_file))

        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    @classmethod
    def _find_image(cls, folder, stem):
        """Return the first existing ``folder/<stem><ext>`` over ``_IMAGE_EXTS``, else None."""
        for ext in cls._IMAGE_EXTS:
            candidate = folder / f"{stem}{ext}"
            if candidate.exists():
                return candidate
        return None

    def __len__(self):
        return len(self.samples)

    def _load_image(self, path):
        """Open ``path`` as RGB and apply ``self.transform``, closing the file afterwards."""
        with Image.open(path) as im:
            return self.transform(im.convert("RGB"))

    def __getitem__(self, idx):
        """Return a dict with both images, both corner sets, and the source paths."""
        img_file, xml_file, flat_img, json_file = self.samples[idx]
        img = self._load_image(img_file)
        flat = self._load_image(flat_img)

        # Only the normalized flat corners are used; pixel corners are discarded.
        _, flat_norm = flat_corners_from_mockup(json_file)
        persp_norm = parse_xml_points(xml_file)

        return {
            "persp_img": img,
            "flat_img": flat,
            "flat_pts": torch.tensor(flat_norm, dtype=torch.float32),
            "persp_pts": torch.tensor(persp_norm, dtype=torch.float32),
            "xml": str(xml_file),
            "json": str(json_file),
        }
# =====================
# Model
# =====================
class SimpleTransformer(nn.Module):
    """Regress perspective corner coordinates from flat corner coordinates.

    The input has ``num_points * 2`` coordinates per sample. It is projected
    to ``d_model``, passed through a ``TransformerEncoder`` as a sequence of
    length 1, and projected back to ``num_points * 2`` values.

    NOTE: with a single token, self-attention can only attend to that token,
    so each attention step reduces to a fixed linear map. Treating each
    corner as its own token would let attention mix corners, but it would
    change parameter shapes and break existing checkpoints.

    Args:
        d_model: Hidden width of the encoder.
        nhead: Number of attention heads (must divide ``d_model``).
        num_layers: Number of encoder layers.
        num_points: Number of (x, y) points in input and output. The default
            of 4 reproduces the original 8 -> 8 mapping and parameter shapes.
    """

    def __init__(self, d_model=128, nhead=4, num_layers=2, num_points=4):
        super().__init__()
        n_coords = num_points * 2
        self.fc_in = nn.Linear(n_coords, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, n_coords)

    def forward(self, x):
        """Map ``x`` of shape (B, num_points*2) to a prediction of the same shape."""
        x = self.fc_in(x).unsqueeze(1)    # (B, 2P) -> (B, d_model) -> (B, 1, d_model)
        x = self.transformer(x)           # (B, 1, d_model)
        return self.fc_out(x).squeeze(1)  # (B, 1, 2P) -> (B, 2P)
# =====================
# Metrics
# =====================
def mse_loss(pred, gt):
    """Return the mean of the element-wise squared difference ``(pred - gt) ** 2``."""
    residual = pred - gt
    return (residual ** 2).mean()
def mean_corner_error(pred, gt, img_w, img_h):
    """Average Euclidean distance, in pixels, between predicted and true points.

    ``pred`` and ``gt`` hold normalized (x, y) pairs along their last axis.
    x is scaled by ``img_w`` and y by ``img_h`` before measuring distances.

    Returns:
        The mean distance as a Python float.
    """
    def to_pixels(points):
        return points * torch.tensor([img_w, img_h], device=points.device)

    distances = torch.norm(to_pixels(pred) - to_pixels(gt), dim=-1)
    return distances.mean().item()
def iou_quad(pred, gt):
    """Intersection-over-union of two polygons given as vertex sequences.

    ``pred`` and ``gt`` are anything with ``.tolist()`` yielding (x, y)
    vertices. They are turned into shapely polygons in the given vertex order.

    Returns:
        The IoU. It is 0.0 when either polygon is invalid
        (e.g. self-intersecting) or when the union area is not positive.
    """
    pred_poly, gt_poly = (Polygon(quad.tolist()) for quad in (pred, gt))
    if not (pred_poly.is_valid and gt_poly.is_valid):
        return 0.0
    union_area = pred_poly.union(gt_poly).area
    if not union_area > 0:
        return 0.0
    return pred_poly.intersection(gt_poly).area / union_area
# =====================
# Training
# =====================
def _save_checkpoint(path, epoch, model, optimizer, **extra):
    """Write a resumable checkpoint (epoch, model/optimizer state, plus ``extra``) to ``path``."""
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        **extra,
    }, path)


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean per-batch MSE."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        flat_pts = batch["flat_pts"].to(device)
        persp_pts = batch["persp_pts"].to(device)
        # Batched (B, 4, 2) corner tensors are flattened to (B, 8) for the model.
        pred = model(flat_pts.reshape(flat_pts.size(0), -1))
        loss = mse_loss(pred, persp_pts.reshape(persp_pts.size(0), -1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate(model, loader, device):
    """Return ``(mean MSE, mean corner error in px, mean IoU)`` over ``loader``.

    Each batch is reshaped to a single (4, 2) quad, so ``loader`` must yield
    batches of size 1. The corner error is measured in the resized image's
    pixels (``img_size``), not in the original image's pixels.
    """
    model.eval()
    mse_all, ce_all, iou_all = [], [], []
    with torch.no_grad():
        for batch in loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)
            pred = model(flat_pts.reshape(1, -1))
            mse_all.append(mse_loss(pred, persp_pts.reshape(1, -1)).item())

            pred_quad = pred.reshape(4, 2).cpu()
            gt_quad = persp_pts.reshape(4, 2).cpu()
            # persp_img is (B, C, H, W). The old code read shape[2], shape[1],
            # i.e. H and C, so "height" was the channel count (3).
            img_h, img_w = batch["persp_img"].shape[-2:]
            ce_all.append(mean_corner_error(pred_quad, gt_quad, img_w, img_h))
            iou_all.append(iou_quad(pred_quad, gt_quad))
    # Plain floats, so checkpoints stay loadable with torch.load(weights_only=True).
    return float(np.mean(mse_all)), float(np.mean(ce_all)), float(np.mean(iou_all))


def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None,
    save_every=100,
):
    """Train ``SimpleTransformer`` to map flat print-area corners to perspective corners.

    Args:
        train_root: Root directory of the training ``KP4Dataset``.
        test_root: Root directory of the validation ``KP4Dataset``.
        epochs: Total number of epochs (counted from 0, including resumed ones).
        batch_size: Training batch size. Validation always uses batch size 1.
        lr: Adam learning rate.
        img_size: Image resize passed to ``KP4Dataset``.
        save_dir: Directory for ``epoch_N.pth``, ``best_model.pth`` and
            ``final_model.pth``.
        resume_path: Optional checkpoint written by this function
            (``epoch_N.pth`` or ``best_model.pth``) to resume from.
        save_every: Save ``epoch_N.pth`` every this many epochs. 0 or None
            disables periodic checkpoints.

    Returns:
        The trained model (state after the last epoch).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # _evaluate reshapes each batch to one quad, so validation batch size must be 1.
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pth")

    start_epoch = 0
    best_iou = -1.0
    if resume_path is not None:
        if os.path.exists(resume_path):
            print(f"Loading checkpoint from {resume_path}")
            checkpoint = torch.load(resume_path, map_location=device)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            start_epoch = checkpoint["epoch"]
            # Restore the best score so a worse epoch right after resuming does
            # not overwrite best_model.pth.
            best_iou = float(checkpoint.get("best_iou", best_iou))
            print(f"Resumed from epoch {start_epoch}")
        else:
            print(f"Resume checkpoint {resume_path} not found; training from scratch")

    for epoch in range(start_epoch, epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}")

        val_mse, val_ce, val_iou = _evaluate(model, val_loader, device)
        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")

        if save_every and (epoch + 1) % save_every == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            _save_checkpoint(
                checkpoint_path, epoch + 1, model, optimizer,
                val_iou=val_iou, best_iou=max(best_iou, val_iou),
            )
            print(f"Checkpoint saved: {checkpoint_path}")

        if val_iou > best_iou:
            best_iou = val_iou
            _save_checkpoint(best_model_path, epoch + 1, model, optimizer, best_iou=best_iou)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

    final_path = os.path.join(save_dir, "final_model.pth")
    # NOTE: this is a bare state_dict (no "model_state" key), so it cannot be
    # passed back as resume_path. Use epoch_N.pth or best_model.pth instead.
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")
    return model
# =====================
# Main
# =====================
if __name__ == "__main__":
    # Long from-scratch training run on the default Transformer/ data layout.
    run_config = {
        "train_root": "Transformer/train",
        "test_root": "Transformer/test",
        "epochs": 3000,
        "batch_size": 4,
        "lr": 1e-3,
        "img_size": 256,
        "resume_path": None,
    }
    model = train_model(**run_config)