from lightning import LightningModule
import torch
from models.model_loader import create_model
from dataset.xray_loader import XrayData
import wandb


class XrayReg(LightningModule):
    """LightningModule that regresses one scalar target per X-ray image.

    The network is built by ``create_model`` and the data by ``XrayData``,
    both configured from the nested ``config`` dict given to ``__init__``.
    Training and validation minimise MSE but log its square root (RMSE)
    under ``train_loss`` / ``val_loss``. ``test_step`` collects per-sample
    results in ``self.test_results``, which ``save_test_results_to_wandb``
    uploads as a wandb table.
    """

    def __init__(self, config):
        """Build the model and data module from ``config``.

        Args:
            config: Nested dict. Keys read here: ``config["model"]["name"]``
                and ``config["dataset"]`` entries ``root_dir``, ``label_csv``,
                ``batch_size``, ``val_split``, ``apply_equalization``.
                ``config["training"]["learning_rate"]`` is read later in
                ``configure_optimizers`` via ``self.hparams``.
        """
        super().__init__()
        # Keeps the config in self.hparams (checkpointed, and read back by
        # configure_optimizers).
        self.save_hyperparameters(config)
        model_config = config["model"]
        dataset_config = config["dataset"]
        self.model = create_model(model_config["name"])
        self.data = XrayData(
            dataset_config["root_dir"],
            dataset_config["label_csv"],
            dataset_config["batch_size"],
            val_split=dataset_config["val_split"],
            apply_equalization=dataset_config["apply_equalization"],
        )
        # NOTE(review): the data module is set up eagerly, at construction time.
        self.data.setup()
        # One dict per test sample; cleared at the start of each test epoch.
        self.test_results = []

    def forward(self, x):
        """Run the wrapped model on a batch of images."""
        return self.model(x)

    def _mse(self, batch):
        """Return the MSE loss between predictions and targets for ``batch``."""
        x, y, _filenames = batch
        # squeeze(1) drops the single-output dimension so y_hat lines up with
        # y; presumably the model outputs shape (batch, 1) — TODO confirm.
        y_hat = self(x).squeeze(1)
        return torch.nn.functional.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Return the MSE loss to optimise; log its square root as ``train_loss``."""
        loss = self._mse(batch)
        self.log("train_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the MSE loss; log its square root as ``val_loss``.

        ``val_loss`` is the metric monitored by the LR scheduler configured in
        ``configure_optimizers``.
        """
        loss = self._mse(batch)
        self.log("val_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def on_test_epoch_start(self):
        """Reset collected results so repeated test runs do not accumulate duplicates."""
        self.test_results = []

    def test_step(self, batch, batch_idx):
        """Record image, prediction, target and absolute error for each sample.

        Results are appended to ``self.test_results``; nothing is returned.
        """
        x, y, filenames = batch
        y_hat = self(x).squeeze(1)
        for img, pred, file, gt in zip(x, y_hat, filenames, y):
            pred_value = pred.item()
            gt_value = gt.item()
            # wandb.Image takes channels-last arrays; this assumes img is
            # (C, H, W) — TODO confirm for single-channel inputs.
            image = wandb.Image(img.detach().cpu().numpy().transpose(1, 2, 0))
            self.test_results.append({
                "image": image,
                "prediction": pred_value,
                "filename": file,
                "ground_truth": gt_value,
                "delta": abs(pred_value - gt_value),
            })
        return None

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau on ``val_loss`` (checked once per epoch).

        The learning rate is read from ``self.hparams["training"]["learning_rate"]``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams["training"]["learning_rate"])
        # `verbose` is deprecated for torch LR schedulers (it only printed LR
        # changes); use Lightning's LearningRateMonitor callback to track the LR.
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': sch,
                'monitor': 'val_loss',
                "frequency": 1,
                "interval": "epoch",
            },
        }

    def train_dataloader(self):
        """Delegate to the data module's training loader."""
        return self.data.train_dataloader()

    def val_dataloader(self):
        """Delegate to the data module's validation loader."""
        return self.data.val_dataloader()

    def test_dataloader(self):
        """Delegate to the data module's test loader."""
        return self.data.test_dataloader()

    def save_test_results_to_wandb(self):
        """Upload the collected test results as a ``wandb.Table`` under ``test_results``."""
        columns = ["image", "filename", "prediction", "ground_truth", "delta"]
        wandb_table = wandb.Table(columns=columns)
        for result in self.test_results:
            wandb_table.add_data(
                result["image"],
                result["filename"],
                result["prediction"],
                result["ground_truth"],
                result["delta"],
            )
        wandb.log({"test_results": wandb_table})