Spaces:
Configuration error
Configuration error
File size: 4,775 Bytes
0034848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from __future__ import absolute_import
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple
import numpy as np
from .serialization import Serializable
def get_shape(img: Any) -> Tuple[int, int]:
    """Return the ``(rows, cols)`` spatial size of an image.

    A ``numpy.ndarray`` is read from its first two dimensions (``shape[:2]``);
    a ``torch.Tensor`` is read from its last two dimensions (``shape[-2:]``).
    torch is imported lazily, so it is only needed when a tensor is passed.

    Raises:
        RuntimeError: if ``img`` is neither a ``numpy.ndarray`` nor, when torch
            is importable, a ``torch.Tensor``.
    """
    if isinstance(img, np.ndarray):
        height, width = img.shape[:2]
        return height, width

    # torch is an optional dependency; fall through to the error if it is absent.
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.is_tensor(img):
        height, width = img.shape[-2:]
        return height, width

    raise RuntimeError(
        f"Albumentations supports only numpy.ndarray and torch.Tensor data type for image. Got: {type(img)}"
    )
def format_args(args_dict: Dict):
    """Render a mapping as a comma-separated ``key=value`` argument string.

    String values are wrapped in single quotes; any other value is rendered
    through f-string formatting. Pairs keep the mapping's iteration order and
    are joined with ``", "``. An empty mapping yields ``""``.
    """

    def _render(value: Any) -> str:
        # Quote strings so the output reads like keyword-argument source text.
        return f"'{value}'" if isinstance(value, str) else f"{value}"

    return ", ".join(f"{key}={_render(value)}" for key, value in args_dict.items())
class Params(Serializable, ABC):
    """Base parameter holder: a data format name plus optional label field names.

    Args:
        format: name of the data format (``DataProcessor`` compares it against
            ``"albumentations"`` to decide whether conversion is needed).
        label_fields: optional names of data keys whose values are attached to,
            and later split off from, the processed items.
    """

    def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None):
        self.format = format
        self.label_fields = label_fields

    def _to_dict(self) -> Dict[str, Any]:
        """Return the constructor arguments as a plain dict.

        Presumably consumed by ``Serializable`` — confirm against that class.
        """
        return dict(format=self.format, label_fields=self.label_fields)
class DataProcessor(ABC):
    """Base class that prepares per-target annotation data around a transform.

    ``preprocess`` attaches the configured label fields to every item and
    validates/converts each data field into the ``"albumentations"`` format.
    ``postprocess`` filters the items, converts them back to
    ``params.format`` and splits the label fields off again. Subclasses supply
    ``default_data_name`` and the ``filter``/``check``/conversion routines.

    Args:
        params: format and label-field configuration.
        additional_targets: optional mapping ``{extra_key: target_type}``; every
            key whose target type equals ``default_data_name`` is processed
            together with the default data field.
    """

    def __init__(self, params: Params, additional_targets: Optional[Dict[str, str]] = None):
        self.params = params
        self.data_fields = [self.default_data_name]
        if additional_targets is not None:
            self.data_fields.extend(
                key for key, target_type in additional_targets.items() if target_type == self.default_data_name
            )

    @property
    @abstractmethod
    def default_data_name(self) -> str:
        """Key of the primary data field handled by this processor.

        Also matched against ``additional_targets`` values.
        """
        raise NotImplementedError

    def ensure_data_valid(self, data: Dict[str, Any]) -> None:
        """Hook for validating input data; the base implementation does nothing."""

    def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
        """Hook for validating the transform list; the base implementation does nothing."""

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Filter each data field, convert it back to ``params.format`` and detach labels.

        ``data`` is modified in place and also returned.
        """
        rows, cols = get_shape(data["image"])
        for data_name in self.data_fields:
            data[data_name] = self.filter(data[data_name], rows, cols)
            data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from")
        return self.remove_label_fields_from_data(data)

    def preprocess(self, data: Dict[str, Any]) -> None:
        """Attach label fields and convert every data field to the albumentations format.

        ``data`` is modified in place.
        """
        data = self.add_label_fields_to_data(data)
        # Measure the image exactly as postprocess does. A raw ``shape[:2]`` reads
        # (C, H) for channel-first torch tensors and gives an opaque
        # AttributeError for unsupported types.
        rows, cols = get_shape(data["image"])
        for data_name in self.data_fields:
            data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to")

    def check_and_convert(self, data: Sequence, rows: int, cols: int, direction: str = "to") -> Sequence:
        """Validate or convert ``data`` relative to the albumentations format.

        If ``params.format`` is already ``"albumentations"`` the data is only
        checked and returned unchanged. Otherwise ``direction="to"`` converts
        into that format and ``direction="from"`` converts back out of it.

        Raises:
            ValueError: if conversion is needed and ``direction`` is neither
                ``"to"`` nor ``"from"``.
        """
        if self.params.format == "albumentations":
            self.check(data, rows, cols)
            return data

        if direction == "to":
            return self.convert_to_albumentations(data, rows, cols)
        elif direction == "from":
            return self.convert_from_albumentations(data, rows, cols)
        else:
            raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`")

    @abstractmethod
    def filter(self, data: Sequence, rows: int, cols: int) -> Sequence:
        """Return the subset of ``data`` to keep for an image of ``rows`` x ``cols``."""

    @abstractmethod
    def check(self, data: Sequence, rows: int, cols: int) -> None:
        """Validate ``data`` that is already in the albumentations format."""

    @abstractmethod
    def convert_to_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
        """Convert ``data`` from ``params.format`` into the albumentations format."""

    @abstractmethod
    def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
        """Convert ``data`` from the albumentations format back into ``params.format``."""

    def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Append each label field's value to the matching item of every data field.

        For every data field, item ``i`` becomes a list with
        ``data[field][i]`` appended, once per name in ``params.label_fields``
        (in order). ``data`` is modified in place and returned. This is a no-op
        when ``label_fields`` is None.
        """
        if self.params.label_fields is None:
            return data
        for data_name in self.data_fields:
            for field in self.params.label_fields:
                # NOTE(review): ``assert`` is stripped under ``python -O``; a length
                # mismatch would then be silently truncated by ``zip`` below.
                assert len(data[data_name]) == len(data[field]), (
                    f"Length of `{data_name}` ({len(data[data_name])}) does not match "
                    f"length of label field `{field}` ({len(data[field])})"
                )
                data[data_name] = [list(item) + [value] for item, value in zip(data[data_name], data[field])]
        return data

    def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Split the trailing label values off every item and restore the label fields.

        This inverts ``add_label_fields_to_data``. The last ``len(label_fields)``
        entries of each item are moved back into ``data[field]``, and the items
        are truncated. ``data`` is modified in place and returned. This is a
        no-op when ``label_fields`` is None or empty.
        """
        if self.params.label_fields is None:
            return data
        label_fields_len = len(self.params.label_fields)
        if not label_fields_len:
            # Nothing was appended; also avoids ``item[:-0]`` wiping every item.
            return data
        for data_name in self.data_fields:
            for idx, field in enumerate(self.params.label_fields):
                # NOTE(review): with several data fields, ``data[field]`` is
                # overwritten on each pass, so it ends up holding the values taken
                # from the last data field only.
                data[field] = [item[-label_fields_len + idx] for item in data[data_name]]
            data[data_name] = [item[:-label_fields_len] for item in data[data_name]]
        return data
|