Spaces: Running on Zero
from lightning import LightningModule | |
import torch | |
from models.model_loader import create_model | |
from dataset.xray_loader import XrayData | |
import wandb | |
class XrayReg(LightningModule):
    """Lightning module that regresses one scalar per X-ray image.

    The network is built by ``create_model`` from ``config["model"]["name"]``.
    The data pipeline is an ``XrayData`` object configured from
    ``config["dataset"]``, and the learning rate is read from
    ``config["training"]["learning_rate"]``.

    Batches are unpacked as ``(images, targets, filenames)``. The optimised
    loss is mean-squared error. The logged ``train_loss`` / ``val_loss``
    metrics are its square root (RMSE).

    Args:
        config: Nested configuration dict. Keys read here: ``model.name``,
            ``dataset.root_dir``, ``dataset.label_csv``,
            ``dataset.batch_size``, ``dataset.val_split``,
            ``dataset.apply_equalization`` and ``training.learning_rate``.
    """

    def __init__(self, config):
        super().__init__()
        # Persist the full config as hyperparameters; configure_optimizers
        # reads the learning rate back from ``self.hparams``.
        self.save_hyperparameters(config)
        model_config = config["model"]
        dataset_config = config["dataset"]

        self.model = create_model(model_config["name"])
        self.data = XrayData(
            dataset_config["root_dir"],
            dataset_config["label_csv"],
            dataset_config["batch_size"],
            val_split=dataset_config["val_split"],
            apply_equalization=dataset_config["apply_equalization"],
        )
        # Set up eagerly so the *_dataloader methods can delegate directly.
        self.data.setup()
        # Per-sample rows filled by test_step and consumed by
        # save_test_results_to_wandb.
        self.test_results = []

    def forward(self, x):
        """Run the wrapped model on a batch of images."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Return the MSE loss for ``batch`` and log its RMSE as ``f"{stage}_loss"``.

        The model output is squeezed along dim 1 so a per-sample prediction
        lines up with the target. This assumes the model emits shape (B, 1);
        TODO confirm against ``create_model``.
        """
        x, y, _filenames = batch
        y_hat = self(x).squeeze(1)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        # Log RMSE (same units as the target) while optimising plain MSE.
        self.log(f"{stage}_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training MSE loss; logs ``train_loss`` (RMSE)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation MSE loss; logs ``val_loss`` (RMSE).

        ``val_loss`` is the metric monitored by the LR scheduler.
        """
        return self._shared_step(batch, "val")

    def on_test_epoch_start(self):
        """Clear rows from a previous test run so they are not logged twice."""
        self.test_results = []

    def test_step(self, batch, batch_idx):
        """Predict on a test batch and record one result row per sample.

        Each row holds the input as a ``wandb.Image``, the prediction, the
        filename, the ground truth and the absolute error. Nothing is sent to
        wandb here; call ``save_test_results_to_wandb`` after testing.
        """
        x, y, filenames = batch
        y_hat = self(x).squeeze(1)
        for img, pred, file, gt in zip(x, y_hat, filenames, y):
            pred_value = pred.item()
            gt_value = gt.item()
            self.test_results.append({
                # transpose(1, 2, 0) moves the first axis last. This presumably
                # converts channel-first (C, H, W) tensors to the (H, W, C)
                # layout wandb.Image expects for arrays; confirm with XrayData.
                "image": wandb.Image(img.cpu().numpy().transpose(1, 2, 0)),
                "prediction": pred_value,
                "filename": file,
                "ground_truth": gt_value,
                "delta": abs(pred_value - gt_value),
            })
        return None

    def configure_optimizers(self):
        """Create Adam with a ReduceLROnPlateau schedule driven by ``val_loss``.

        The learning rate is multiplied by 0.1 once ``val_loss`` has not
        improved for more than 5 consecutive epochs.
        """
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams["training"]["learning_rate"])
        # ``verbose`` is deprecated for LR schedulers since PyTorch 2.2.
        # To track the LR, add Lightning's LearningRateMonitor callback to
        # the Trainer instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.1,
            patience=5,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1,
                "interval": "epoch",
            },
        }

    def train_dataloader(self):
        """Delegate to the data object's training loader."""
        return self.data.train_dataloader()

    def val_dataloader(self):
        """Delegate to the data object's validation loader."""
        return self.data.val_dataloader()

    def test_dataloader(self):
        """Delegate to the data object's test loader."""
        return self.data.test_dataloader()

    def save_test_results_to_wandb(self):
        """Log the collected test rows as a ``wandb.Table`` under ``test_results``.

        Intended to be called after ``trainer.test``. Assumes a wandb run has
        already been initialised.
        """
        columns = ["image", "filename", "prediction", "ground_truth", "delta"]
        wandb_table = wandb.Table(columns=columns)
        for result in self.test_results:
            wandb_table.add_data(
                result["image"],
                result["filename"],
                result["prediction"],
                result["ground_truth"],
                result["delta"],
            )
        wandb.log({"test_results": wandb_table})