| | import json |
| | import logging |
| | import os |
| | import re |
| | from abc import ABC, abstractmethod |
| | from pathlib import Path |
| | from typing import (Any, Callable, Dict, Iterable, List, Optional, Pattern, |
| | Set, Tuple, Union) |
| |
|
| | import numpy as np |
| | import SimpleITK |
| | from evalutils.exceptions import FileLoaderError |
| | from evalutils.io import FileLoader, ImageLoader, SimpleITKLoader |
| | from evalutils.validators import (UniqueImagesValidator, |
| | UniquePathIndicesValidator) |
| | from pandas import DataFrame |
| |
|
logger = logging.getLogger(__name__)

# Fallbacks used when there is no .env file, or when it omits / leaves empty
# one of the keys below (the same values the container runs with).
_DEFAULT_TASK_TYPE = "mri"
_DEFAULT_INPUT_FOLDER = "/input"

# Local development: a .env file in the working directory overrides the
# container defaults. python-dotenv is only imported when that file exists,
# so it is not required inside the container.
if Path(".env").exists():
    from dotenv import dotenv_values

    config = dotenv_values(".env")

    # Use .get() + `or` so a missing key (KeyError before) or a key written
    # without a value (dotenv_values maps it to None) falls back to the
    # defaults instead of crashing or producing a "None" path segment.
    TASK_TYPE = config.get("TASK_TYPE") or _DEFAULT_TASK_TYPE
    INPUT_FOLDER = config.get("INPUT_FOLDER") or _DEFAULT_INPUT_FOLDER

    print("########## ENVIRONMENT VARIABLES ##########")
    print(f"TASK_TYPE: {TASK_TYPE}")
    print(f"INPUT_FOLDER: {INPUT_FOLDER}")
else:
    TASK_TYPE = _DEFAULT_TASK_TYPE
    INPUT_FOLDER = _DEFAULT_INPUT_FOLDER

# Inside the container (input mounted at /input) results go to /output;
# anywhere else they go to a local ./output folder. Comparing as paths makes
# "/input/" count as "/input" too.
OUTPUT_FOLDER = (
    "/output" if Path(INPUT_FOLDER) == Path(_DEFAULT_INPUT_FOLDER) else "./output"
)

DEFAULT_IMAGE_PATH = Path(INPUT_FOLDER) / "images" / TASK_TYPE
DEFAULT_REGION_PATH = Path(INPUT_FOLDER) / "region.json"
DEFAULT_MASK_PATH = Path(INPUT_FOLDER) / "images" / "body"
DEFAULT_OUTPUT_PATH = Path(OUTPUT_FOLDER) / "images" / "synthetic-ct"
DEFAULT_OUTPUT_FILE = Path(OUTPUT_FOLDER) / "results.json"
| |
|
| |
|
class BaseSynthradAlgorithm(ABC):
    """Base class for SynthRAD synthetic-CT algorithm containers.

    Subclasses implement :meth:`predict`. :meth:`process` then:

    1. loads every input image and body mask;
    2. pairs them in sorted filename order;
    3. writes one predicted image per pair to ``output_path``;
    4. writes a JSON summary of all cases to ``output_file``.
    """

    def __init__(
        self,
        input_path: Path = DEFAULT_IMAGE_PATH,
        mask_path: Path = DEFAULT_MASK_PATH,
        region_path: Path = DEFAULT_REGION_PATH,
        output_path: Path = DEFAULT_OUTPUT_PATH,
        output_file: Path = DEFAULT_OUTPUT_FILE,
        validators: Optional[Dict[str, Callable]] = None,
        file_loader: Optional[FileLoader] = None,
    ):
        """
        Parameters
        ----------
        All path defaults come from the module-level ``DEFAULT_*`` constants.
        The values shown below apply when no ``.env`` file is present.

        input_path
            Folder the input images are loaded from.
            Default: ``/input/images/mri``
        mask_path
            Folder the input masks are loaded from.
            Default: ``/input/images/body``
        region_path
            JSON file whose parsed content is passed to :meth:`predict` under
            the ``"region"`` key. Default: ``/input/region.json``
        output_path
            Folder the output images are written to (created if missing).
            Default: ``/output/images/synthetic-ct``
        output_file
            File the per-case results are written to as JSON.
            Default: ``/output/results.json``
        validators
            Accepted for API compatibility; currently not used by this class.
        file_loader
            Loader used for both images and masks. It must be an
            ``evalutils.io.ImageLoader`` for :meth:`process_case` to succeed.
            Default: a new ``evalutils.io.SimpleITKLoader`` per instance.
        """
        self._index_keys = ["image", "mask"]  # not read within this class
        self.input_path = input_path
        self.mask_path = mask_path
        self.region_path = region_path
        self.output_path = output_path
        self.output_file = output_file
        # Create the loader here rather than in the signature: a default
        # evaluated at definition time would be one object shared by every
        # instance of every subclass.
        self._file_loader = (
            file_loader if file_loader is not None else SimpleITKLoader()
        )

        self.cases = {}  # not read within this class
        self._case_results = []

    def load(self):
        """Load image and mask case entries and parse the region JSON file."""
        self.images = self._load_cases(
            folder=self.input_path, file_loader=self._file_loader
        )
        self.masks = self._load_cases(
            folder=self.mask_path, file_loader=self._file_loader
        )
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(self.region_path, "r", encoding="utf-8") as f:
            self.region = json.load(f)

    def _load_cases(
        self,
        folder: Path,
        file_loader: ImageLoader,
    ) -> List[Dict[str, Any]]:
        """Collect the case entries ``file_loader`` produces for ``folder``.

        Files are visited in sorted name order. A file the loader rejects
        with ``FileLoaderError`` is logged and skipped.

        Returns
        -------
        The concatenated lists returned by ``file_loader.load``. Elsewhere in
        this class each entry is read via its ``"path"`` and ``"hash"`` keys.

        Raises
        ------
        FileLoaderError
            If no file in ``folder`` could be loaded.
        """
        cases = []
        for fp in sorted(folder.glob("*")):
            try:
                new_cases = file_loader.load(fname=fp)
            except FileLoaderError:
                logger.warning("Could not load %s using %s.", fp.name, file_loader)
            else:
                cases.extend(new_cases)

        if not cases:
            raise FileLoaderError(
                f"Could not load any files in {folder} with {file_loader}."
            )
        return cases

    def validate(self):
        """TODO: Validate the loaded cases of each file loader separately.

        Currently a no-op.
        """

    def _validate_data_frame(self, df: DataFrame):
        """TODO: Validate the cases of a specific file loader. Currently a no-op."""

    def process_cases(self):
        """Run :meth:`process_case` on every (image, mask) pair.

        Images and masks are paired positionally. Both lists come from
        :meth:`_load_cases`, which visits each folder in sorted filename order.

        Raises
        ------
        ValueError
            If the numbers of loaded images and masks differ. ``zip`` would
            otherwise silently drop the unmatched cases.
        """
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} input images but {len(self.masks)} "
                f"masks; every image needs exactly one mask."
            )

        self._case_results = []
        for idx, case in enumerate(zip(self.images, self.masks)):
            self._case_results.append(self.process_case(idx=idx, case=case))

    def process_case(
        self, idx: int, case: Tuple[Dict[str, Any], Dict[str, Any]]
    ) -> Dict:
        """Predict one case, write the result, and describe it.

        Parameters
        ----------
        idx
            Position of the case; not used here.
        case
            ``(image_entry, mask_entry)`` as produced by :meth:`_load_cases`.

        Returns
        -------
        A dict with the output file, the input files (image, mask and region
        JSON), and an empty ``error_messages`` list.
        """
        images: Dict[str, Any] = {}
        images_file_paths: Dict[str, Path] = {}

        images["image"], images_file_paths["image"] = self._load_input_image(case[0])
        images["mask"], images_file_paths["mask"] = self._load_input_image(case[1])
        images["region"] = self.region

        out = self.predict(input_dict=images)

        # The output file reuses the input image's filename.
        self.output_path.mkdir(parents=True, exist_ok=True)
        out_path = self.output_path / images_file_paths["image"].name
        # Third argument is SimpleITK's useCompression flag.
        SimpleITK.WriteImage(out, str(out_path), True)

        return {
            "outputs": [dict(type="metaio_image", filename=str(out_path))],
            "inputs": [
                dict(type="metaio_image", filename=str(fn))
                for fn in images_file_paths.values()
            ]
            + [dict(type="String", filename=str(self.region_path))],
            "error_messages": [],
        }

    def _load_input_image(self, image) -> Tuple[SimpleITK.Image, Path]:
        """Load the image described by a case entry and verify its hash.

        Raises
        ------
        RuntimeError
            If the configured loader is not an ``ImageLoader``, or if the hash
            of the freshly loaded image differs from the one recorded at load
            time.
        """
        input_image_file_path = image["path"]
        input_image_file_loader = self._file_loader

        if not isinstance(input_image_file_loader, ImageLoader):
            raise RuntimeError("The used FileLoader was not of subclass ImageLoader")

        input_image = input_image_file_loader.load_image(input_image_file_path)

        # Guard against the file changing between load() and processing.
        if input_image_file_loader.hash_image(input_image) != image["hash"]:
            raise RuntimeError("Image hashes do not match")
        return input_image, input_image_file_path

    @abstractmethod
    def predict(self, *, input_dict: Dict[str, Any]) -> SimpleITK.Image:
        """Return the output image for one case.

        ``input_dict`` holds the loaded ``"image"`` and ``"mask"`` images
        plus the parsed ``"region"`` JSON content.
        """

    def save(self):
        """Write the collected per-case results to ``output_file`` as JSON.

        The parent folder is created first, so saving also works when it does
        not exist yet.
        """
        output_file = Path(self.output_file)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w") as f:
            json.dump(self._case_results, f)

    def process(self):
        """Run the full pipeline: load, validate, process all cases, save."""
        self.load()
        self.validate()
        self.process_cases()
        self.save()
| |
|