"""Train a single-channel UNet on PNG slices using intensity-threshold pseudo-masks.

The targets are not loaded from disk: each slice's mask is derived from the
slice itself (see ``SliceDataset``). Run as a script with ``--data``,
``--epochs``, ``--out``, ``--batch-size`` and ``--lr``.
"""
import argparse
import glob
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from models.unet import UNet


class SliceDataset(Dataset):
    """Grayscale PNG slices paired with binary masks derived from each slice.

    Every ``*.png`` directly inside *folder* (non-recursive, sorted by path)
    is one sample. The image is converted to Pillow ``"L"`` mode and scaled
    to [0, 1] by dividing by 255. The mask is 1.0 where a pixel is greater
    than the slice's own mean intensity plus *threshold_offset*, else 0.0.

    Args:
        folder: Directory containing the PNG slices.
        threshold_offset: Margin above the per-slice mean that a pixel must
            exceed to be labelled foreground. Defaults to 0.25.
    """

    def __init__(self, folder, threshold_offset=0.25):
        self.paths = sorted(glob.glob(os.path.join(folder, "*.png")))
        self.threshold_offset = threshold_offset

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as float32 tensors of shape (1, H, W)."""
        # Context manager closes the underlying file deterministically.
        with Image.open(self.paths[idx]) as pil_img:
            img = np.array(pil_img.convert("L"), dtype=np.float32) / 255.0
        mask = (img > img.mean() + self.threshold_offset).astype(np.float32)
        # Prepend a channel axis so each sample is (1, H, W) for in_channels=1.
        # Both arrays are freshly allocated, so sharing memory via from_numpy is safe.
        return (
            torch.from_numpy(img[np.newaxis, ...]),
            torch.from_numpy(mask[np.newaxis, ...]),
        )


def train(folder, epochs=3, out="models/unet_best.pt", batch_size=4, lr=1e-3):
    """Train a UNet on the slices in *folder* and save the final weights.

    Note that the weights saved are those after the last epoch; no
    validation-based "best" selection is performed despite the default
    file name.

    Args:
        folder: Directory of ``*.png`` slices (see ``SliceDataset``).
        epochs: Number of passes over the data.
        out: Path for the saved ``state_dict``. Its parent directory is
            created if needed.
        batch_size: Mini-batch size for the shuffled ``DataLoader``.
        lr: Adam learning rate.

    Raises:
        ValueError: If *folder* contains no ``*.png`` files.
    """
    ds = SliceDataset(folder)
    if len(ds) == 0:
        # Fail early with a clear message instead of DataLoader's sampler error.
        raise ValueError(f"No *.png files found in {folder!r}")
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True)

    model = UNet(in_channels=1, out_channels=1)
    opt = optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes
    # UNet's output ends in a sigmoid. Confirm in models/unet.py; if it
    # returns raw logits, switch to nn.BCEWithLogitsLoss.
    loss_fn = nn.BCELoss()

    for epoch in range(epochs):
        model.train()
        total = 0.0
        for x, y in dl:
            loss = loss_fn(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        print(f"Epoch {epoch+1}, loss {total/len(dl):.4f}")

    out_dir = os.path.dirname(out)
    if out_dir:  # os.makedirs("") raises, e.g. when out is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), out)
    print("Saved model to", out)


if __name__ == '__main__':
    p = argparse.ArgumentParser(description="Train a UNet on PNG slices.")
    p.add_argument('--data', default='examples/synthetic_phantom',
                   help="folder of *.png slices")
    p.add_argument('--epochs', type=int, default=3,
                   help="number of training epochs")
    p.add_argument('--out', default='models/unet_best.pt',
                   help="where to save the trained state_dict")
    p.add_argument('--batch-size', type=int, default=4,
                   help="mini-batch size")
    p.add_argument('--lr', type=float, default=1e-3,
                   help="Adam learning rate")
    args = p.parse_args()
    train(args.data, epochs=args.epochs, out=args.out,
          batch_size=args.batch_size, lr=args.lr)