# xray-reg/scripts/trainer.py
# Author: SuperSecureHuman — commit 465d7e4 ("Upload 59 files", verified)
from lightning import LightningModule
import torch
from models.model_loader import create_model
from dataset.xray_loader import XrayData
import wandb
class XrayReg(LightningModule):
    """LightningModule for per-image scalar regression over ``XrayData`` batches.

    ``config`` is a nested dict; the keys read here are
    ``config["model"]["name"]``, ``config["dataset"]["root_dir" | "label_csv" |
    "batch_size" | "val_split" | "apply_equalization"]`` and (in
    ``configure_optimizers``) ``config["training"]["learning_rate"]``.
    The whole dict is stored with ``save_hyperparameters`` and read back via
    ``self.hparams``.
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(config)
        model_config = config["model"]
        dataset_config = config["dataset"]
        self.model = create_model(model_config["name"])
        # The data wrapper is owned by this module and set up eagerly so the
        # *_dataloader hooks below can simply delegate to it.
        self.data = XrayData(
            dataset_config["root_dir"],
            dataset_config["label_csv"],
            dataset_config["batch_size"],
            val_split=dataset_config["val_split"],
            apply_equalization=dataset_config["apply_equalization"],
        )
        self.data.setup()
        # Per-sample rows filled by ``test_step`` and consumed by
        # ``save_test_results_to_wandb``; reset at the start of each test run.
        self.test_results = []

    def forward(self, x):
        """Run the wrapped model on a batch of images."""
        return self.model(x)

    def _batch_mse(self, batch):
        """Return the mean-squared error between predictions and targets for one batch."""
        # Batches are (images, targets, filenames); filenames are not needed here.
        x, y, _ = batch
        # squeeze(1) drops a singleton output dim so predictions line up with
        # ``y`` — presumably the model emits (batch, 1); TODO confirm.
        y_hat = self(x).squeeze(1)
        return torch.nn.functional.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Return the batch MSE for optimisation.

        The value logged as ``train_loss`` is its square root (RMSE), not the
        MSE that is actually minimised.
        """
        loss = self._batch_mse(batch)
        self.log("train_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch MSE and log its square root (RMSE) as ``val_loss``.

        ``val_loss`` is also the metric monitored by the LR scheduler.
        """
        loss = self._batch_mse(batch)
        # NOTE(review): Lightning averages per-batch logged values over the
        # epoch, so this is a mean of batch RMSEs rather than the dataset
        # RMSE — confirm that is the intended metric.
        self.log("val_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def on_test_epoch_start(self):
        """Clear previously collected rows.

        Without this, calling ``trainer.test`` more than once would accumulate
        stale or duplicate entries in ``self.test_results``.
        """
        self.test_results = []

    def test_step(self, batch, batch_idx):
        """Append one row per sample: image, filename, prediction, target and absolute error."""
        x, y, filenames = batch
        y_hat = self(x).squeeze(1)
        for img, pred, file, gt in zip(x, y_hat, filenames, y):
            # Convert each scalar tensor to a Python number once.
            pred_value = pred.item()
            gt_value = gt.item()
            self.test_results.append({
                # transpose(1, 2, 0) moves channels last for wandb
                # (assumes channel-first image tensors — TODO confirm).
                "image": wandb.Image(img.cpu().numpy().transpose(1, 2, 0)),
                "prediction": pred_value,
                "filename": file,
                "ground_truth": gt_value,
                "delta": abs(pred_value - gt_value),
            })
        return None

    def configure_optimizers(self):
        """Build Adam plus a ReduceLROnPlateau scheduler.

        The scheduler is stepped once per epoch on the logged ``val_loss``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams["training"]["learning_rate"])
        # Cut the LR tenfold after 5 epochs without val_loss improvement.
        # NOTE(review): the ``verbose`` argument is deprecated in recent
        # PyTorch releases; consider dropping it when upgrading.
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=0.1,
                                                         patience=5,
                                                         verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': sch,
                'monitor': 'val_loss',
                "frequency": 1,
                "interval": "epoch",
            }
        }

    def train_dataloader(self):
        """Delegate to the ``XrayData`` training loader."""
        return self.data.train_dataloader()

    def val_dataloader(self):
        """Delegate to the ``XrayData`` validation loader."""
        return self.data.val_dataloader()

    def test_dataloader(self):
        """Delegate to the ``XrayData`` test loader."""
        return self.data.test_dataloader()

    def save_test_results_to_wandb(self):
        """Log the rows gathered by ``test_step`` as a ``wandb.Table``.

        This calls ``wandb.log`` directly, so it presumably needs an active
        wandb run.
        """
        columns = ["image", "filename", "prediction", "ground_truth", "delta"]
        wandb_table = wandb.Table(columns=columns)
        for result in self.test_results:
            # Emit values in the same order as ``columns``.
            wandb_table.add_data(*(result[column] for column in columns))
        wandb.log({"test_results": wandb_table})