File size: 3,429 Bytes
465d7e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from lightning import LightningModule
import torch
from models.model_loader import create_model
from dataset.xray_loader import XrayData

import wandb


class XrayReg(LightningModule):
    """LightningModule for X-ray regression.

    Wraps a backbone obtained from ``create_model`` and the ``XrayData``
    datamodule, optimizes mean-squared error, and logs RMSE (same units as
    the target) to the progress bar. Test-time predictions are collected
    per sample and can be uploaded to wandb as a table.
    """

    def __init__(self, config):
        """Build the model and data pipeline from a nested config dict.

        Args:
            config: dict with sections
                ``model``   — ``name`` of the backbone for ``create_model``;
                ``dataset`` — ``root_dir``, ``label_csv``, ``batch_size``,
                              ``val_split``, ``apply_equalization``;
                ``training`` — ``learning_rate`` (read in
                              ``configure_optimizers`` via ``self.hparams``).
        """
        super().__init__()
        # Persist the full config so checkpoints are self-describing and
        # self.hparams is available later (configure_optimizers reads it).
        self.save_hyperparameters(config)
        model_config = config["model"]
        dataset_config = config["dataset"]
        self.model = create_model(model_config["name"])

        self.data = XrayData(
            dataset_config["root_dir"],
            dataset_config["label_csv"],
            dataset_config["batch_size"],
            val_split=dataset_config["val_split"],
            apply_equalization=dataset_config["apply_equalization"],
        )
        # NOTE(review): calling setup() in __init__ runs it eagerly on every
        # process; Lightning normally drives datamodule setup itself.
        # Kept as-is for compatibility with the existing data class.
        self.data.setup()
        # Per-sample prediction records gathered during test_step().
        self.test_results = []

    def forward(self, x):
        """Run the backbone; returns raw predictions (presumably (B, 1),
        given the squeeze(1) in the step methods — confirm against model)."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute MSE loss for one batch and log its RMSE as ``{stage}_loss``.

        Shared by training_step and validation_step, which previously
        duplicated this logic verbatim.
        """
        x, y, _filenames = batch
        y_hat = self(x).squeeze(1)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        # Log RMSE (interpretable units) while optimizing plain MSE.
        self.log(f"{stage}_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def on_test_epoch_start(self):
        # Reset accumulated records so repeated trainer.test() calls do not
        # mix rows from earlier runs into the wandb table.
        self.test_results = []

    def test_step(self, batch, batch_idx):
        """Record one prediction entry per sample for later wandb logging."""
        x, y, filenames = batch
        y_hat = self(x).squeeze(1)
        for img, pred, file, gt in zip(x, y_hat, filenames, y):
            self.test_results.append({
                # CHW -> HWC so wandb renders the tensor as an image.
                "image": wandb.Image(img.cpu().numpy().transpose(1, 2, 0)),
                "prediction": pred.item(),
                "filename": file,
                "ground_truth": gt.item(),
                "delta": abs(pred.item() - gt.item()),
            })
        return None

    def configure_optimizers(self):
        """Adam + ReduceLROnPlateau on ``val_loss``, stepped once per epoch."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams["training"]["learning_rate"])

        # verbose=True removed: the argument was deprecated in PyTorch 2.0
        # and deleted in later releases; passing it breaks on current torch.
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': sch,
                'monitor': 'val_loss',
                "frequency": 1,
                "interval": "epoch",
            }
        }

    def train_dataloader(self):
        return self.data.train_dataloader()

    def val_dataloader(self):
        return self.data.val_dataloader()

    def test_dataloader(self):
        return self.data.test_dataloader()

    def save_test_results_to_wandb(self):
        """Upload all accumulated test predictions as a single wandb Table."""
        columns = ["image", "filename", "prediction", "ground_truth", "delta"]
        wandb_table = wandb.Table(columns=columns)

        for result in self.test_results:
            wandb_table.add_data(
                result["image"],
                result["filename"],
                result["prediction"],
                result["ground_truth"],
                result["delta"],
            )
        wandb.log({"test_results": wandb_table})