JasonSmithSO's picture
Upload 777 files
0034848 verified
from __future__ import absolute_import
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple
import numpy as np
from .serialization import Serializable
def get_shape(img: Any) -> Tuple[int, int]:
    """Return the ``(rows, cols)`` size of an image.

    For a ``numpy.ndarray`` the first two entries of ``img.shape`` are used.
    For a ``torch.Tensor`` the last two entries of ``img.shape`` are used.
    ``torch`` is imported lazily, so it stays an optional dependency.

    Args:
        img: Image as a ``numpy.ndarray`` or ``torch.Tensor``.

    Returns:
        Tuple ``(rows, cols)``.

    Raises:
        RuntimeError: If ``img`` is neither a numpy array nor a torch tensor
            (including the case where ``torch`` is not installed).
    """
    if not isinstance(img, np.ndarray):
        try:
            import torch

            if torch.is_tensor(img):
                height, width = img.shape[-2:]
                return height, width
        except ImportError:
            # torch is optional; without it only numpy arrays are accepted.
            pass

        message = f"Albumentations supports only numpy.ndarray and torch.Tensor data type for image. Got: {type(img)}"
        raise RuntimeError(message)

    height, width = img.shape[:2]
    return height, width
def format_args(args_dict: Dict):
    """Render a mapping as a ``key=value, key=value`` argument string.

    String values are wrapped in single quotes; every other value is
    formatted with its default string formatting.

    Args:
        args_dict: Mapping of argument names to values.

    Returns:
        The ``", "``-joined ``name=value`` pairs, in the mapping's iteration
        order (an empty string for an empty mapping).
    """

    def render(value):
        # Only strings get quoted; other values are interpolated unchanged.
        return f"'{value}'" if isinstance(value, str) else value

    return ", ".join(f"{name}={render(value)}" for name, value in args_dict.items())
class Params(Serializable, ABC):
    """Base parameter container holding a data format name and optional label fields."""

    def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None):
        """Store the configuration.

        Args:
            format: Name of the data format. The value ``"albumentations"``
                is treated specially by ``DataProcessor.check_and_convert``.
            label_fields: Names of the data-dict keys that hold per-item
                labels, or ``None`` when there are none.
        """
        self.format = format
        self.label_fields = label_fields

    def _to_dict(self) -> Dict[str, Any]:
        """Return the stored parameters as a plain dictionary."""
        state: Dict[str, Any] = {}
        state["format"] = self.format
        state["label_fields"] = self.label_fields
        return state
class DataProcessor(ABC):
    """Abstract base for converting target data to and from the "albumentations" format.

    A processor handles every key listed in ``data_fields``: the subclass'
    ``default_data_name`` plus each ``additional_targets`` key that maps to
    it. ``preprocess`` converts those fields *to* the internal format, and
    ``postprocess`` filters them and converts them back.
    """

    def __init__(self, params: "Params", additional_targets: Optional[Dict[str, str]] = None):
        """Store the parameters and collect the data fields to process.

        Args:
            params: Configuration object; ``format`` and ``label_fields`` are
                read from it.
            additional_targets: Optional mapping of extra data-dict keys to
                target names. Keys whose value equals ``default_data_name``
                are processed as well.
        """
        self.params = params
        self.data_fields = [self.default_data_name]
        if additional_targets is not None:
            self.data_fields.extend(
                name for name, target in additional_targets.items() if target == self.default_data_name
            )

    @property
    @abstractmethod
    def default_data_name(self) -> str:
        """Data-dict key of the primary field handled by this processor."""
        raise NotImplementedError

    def ensure_data_valid(self, data: Dict[str, Any]) -> None:
        """Validation hook for the data dict; does nothing by default."""

    def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
        """Validation hook for the transforms; does nothing by default."""

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Filter every data field, convert it back, and split label fields off.

        Args:
            data: Data dict containing ``"image"`` and all ``data_fields``.
                It is modified in place.

        Returns:
            The same dict with converted data fields and restored label
            fields.
        """
        rows, cols = get_shape(data["image"])

        for data_name in self.data_fields:
            data[data_name] = self.filter(data[data_name], rows, cols)
            data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from")

        return self.remove_label_fields_from_data(data)

    def preprocess(self, data: Dict[str, Any]) -> None:
        """Attach label fields and convert every data field to the internal format.

        Args:
            data: Data dict containing ``"image"`` and all ``data_fields``
                (plus the label fields, if configured). Modified in place.
        """
        data = self.add_label_fields_to_data(data)

        # Use the same shape helper as ``postprocess`` so both directions
        # agree on (rows, cols) for every supported image type.
        rows, cols = get_shape(data["image"])
        for data_name in self.data_fields:
            data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to")

    def check_and_convert(self, data: Sequence, rows: int, cols: int, direction: str = "to") -> Sequence:
        """Validate or convert ``data`` depending on the configured format.

        Args:
            data: Items of one data field.
            rows: Image height.
            cols: Image width.
            direction: ``"to"`` converts into the albumentations format,
                ``"from"`` converts back out of it.

        Returns:
            ``data`` itself (after ``check``) when the configured format is
            already ``"albumentations"``, otherwise the converted data.

        Raises:
            ValueError: If a conversion is needed and ``direction`` is
                neither ``"to"`` nor ``"from"``.
        """
        if self.params.format == "albumentations":
            # Already in the internal format: only validate.
            self.check(data, rows, cols)
            return data

        if direction == "to":
            return self.convert_to_albumentations(data, rows, cols)
        if direction == "from":
            return self.convert_from_albumentations(data, rows, cols)
        raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`")

    @abstractmethod
    def filter(self, data: Sequence, rows: int, cols: int) -> Sequence:
        """Filter ``data`` for an image of ``rows`` x ``cols``; implemented by subclasses."""

    @abstractmethod
    def check(self, data: Sequence, rows: int, cols: int) -> None:
        """Validate ``data`` that is already in the albumentations format; implemented by subclasses."""

    @abstractmethod
    def convert_to_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
        """Convert ``data`` from the configured format to the albumentations format."""

    @abstractmethod
    def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
        """Convert ``data`` from the albumentations format to the configured format."""

    def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Append each label field's values to the items of every data field.

        After the call, every item of a data field is a list whose trailing
        entries are the item's labels, in ``label_fields`` order.

        Args:
            data: Data dict; modified in place.

        Returns:
            The same dict.

        Raises:
            AssertionError: If a label field and a data field differ in
                length.
        """
        label_fields = self.params.label_fields
        if label_fields is None:
            return data

        for data_name in self.data_fields:
            for field in label_fields:
                items, labels = data[data_name], data[field]
                # Raised explicitly (not via ``assert``) so the check also
                # runs under ``python -O``; otherwise ``zip`` would silently
                # drop the unmatched items.
                if len(items) != len(labels):
                    raise AssertionError(
                        f"Length mismatch: `{data_name}` has {len(items)} items "
                        f"but label field `{field}` has {len(labels)}."
                    )
                data[data_name] = [list(item) + [label] for item, label in zip(items, labels)]
        return data

    def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Move the trailing label entries of each item back into their label fields.

        This is the inverse of ``add_label_fields_to_data``.

        Args:
            data: Data dict; modified in place.

        Returns:
            The same dict, with label values stored under their field names
            and stripped from the data-field items.
        """
        label_fields = self.params.label_fields
        if label_fields is None:
            return data

        label_count = len(label_fields)
        if not label_count:
            # Nothing was appended, so there is nothing to strip.
            return data

        for data_name in self.data_fields:
            # Labels occupy the last ``label_count`` positions of each item,
            # so field ``idx`` sits at negative offset ``idx - label_count``.
            for offset, field in enumerate(label_fields, start=-label_count):
                data[field] = [item[offset] for item in data[data_name]]
            data[data_name] = [item[:-label_count] for item in data[data_name]]
        return data