Spaces:
Configuration error
Configuration error
File size: 2,204 Bytes
0a3dbb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import abc
from copy import deepcopy
import cv2
import numpy as np
from sklearn.decomposition import PCA
from typing_extensions import Protocol
class TransformerInterface(Protocol):
    """Structural interface for the transformers accepted by ``DomainAdapter``.

    Implementations expose ``fit``, ``transform`` and ``inverse_transform``
    with the signatures below.
    """

    @abc.abstractmethod
    def fit(self, X: np.ndarray, y=None):
        """Learn the transformation from the sample array ``X``."""

    @abc.abstractmethod
    def transform(self, X: np.ndarray, y=None) -> np.ndarray:
        """Encode ``X`` into the learned representation."""

    @abc.abstractmethod
    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
        """Decode a learned representation ``X`` back to the input space."""
class DomainAdapter:
    """Re-colour images so their pixel distribution follows a reference image.

    The transformer passed in is fitted on the reference image's pixels and
    used as the *target* transformer. A deep copy taken before that fit is
    the *source* transformer. On each call the source is refitted on the
    input image, the input pixels are encoded with ``source.transform`` and
    decoded with ``target.inverse_transform``.

    Args:
        transformer: Object exposing ``fit``, ``transform`` and
            ``inverse_transform``. NOTE: this very object is fitted on
            ``ref_img`` and kept as the target transformer, so it is mutated
            in place (and, for PCA, may have its ``components_`` sign-flipped
            on later calls).
        ref_img: Reference image. After the optional ``color_in`` conversion
            it is reshaped to ``(-1, 3)``, so it must have 3 channels.
            Pixels are divided by 255, so 8-bit range is presumably expected
            (TODO confirm with callers).
        color_conversions: ``(color_in, color_out)`` pair of ``cv2.cvtColor``
            codes. ``color_in`` is applied before fitting/encoding and
            ``color_out`` after decoding. ``None`` skips that conversion.
    """

    def __init__(
        self,
        transformer: TransformerInterface,
        ref_img: np.ndarray,
        color_conversions=(None, None),
    ):
        self.color_in, self.color_out = color_conversions
        # Copy before fitting so the source transformer starts out unfitted.
        self.source_transformer = deepcopy(transformer)
        self.target_transformer = transformer
        self.target_transformer.fit(self.flatten(ref_img))

    def to_colorspace(self, img):
        """Apply the input colour conversion ``color_in``, if one is set."""
        if self.color_in is None:
            return img
        return cv2.cvtColor(img, self.color_in)

    def from_colorspace(self, img):
        """Apply the output colour conversion ``color_out`` (on uint8), if one is set."""
        if self.color_out is None:
            return img
        return cv2.cvtColor(img.astype('uint8'), self.color_out)

    def flatten(self, img):
        """Return ``img`` as an ``(N, 3)`` float32 pixel array scaled by 1/255.

        The input colour conversion is applied first.
        """
        img = self.to_colorspace(img)
        img = img.astype('float32') / 255.
        return img.reshape(-1, 3)

    def reconstruct(self, pixels, h, w):
        """Convert ``(h*w, 3)`` pixels back into an ``(h, w, 3)`` uint8 image.

        Values are clipped to [0, 1], scaled to [0, 255] and rounded to the
        nearest integer. The output colour conversion is applied last.
        """
        # Round before casting: ``astype('uint8')`` truncates, which would drop
        # e.g. 0.999 * 255 (254.7) to 254. It would also lose a level whenever
        # float error lands just below an integer, biasing the image darker.
        pixels = np.rint(np.clip(pixels, 0, 1) * 255).astype('uint8')
        return self.from_colorspace(pixels.reshape(h, w, 3))

    @staticmethod
    def _pca_sign(x):
        """Return the sign of the trace of a fitted PCA's ``components_``.

        Used as a coarse orientation fingerprint when comparing two PCA fits.
        """
        return np.sign(np.trace(x.components_))

    def __call__(self, image: np.ndarray):
        """Return ``image`` re-coloured towards the reference distribution.

        ``image`` must be ``(h, w, 3)``. The source transformer is refitted
        on it at every call.
        """
        h, w, _ = image.shape
        pixels = self.flatten(image)
        self.source_transformer.fit(pixels)
        if isinstance(self.target_transformer, PCA):
            # HACK: PCA component signs are arbitrary. If the source and target
            # fits disagree, flip the target's components so colours are not
            # inverted. The flip is in place and persists across calls.
            if self._pca_sign(self.target_transformer) != self._pca_sign(self.source_transformer):
                self.target_transformer.components_ *= -1
        representation = self.source_transformer.transform(pixels)
        result = self.target_transformer.inverse_transform(representation)
        return self.reconstruct(result, h, w)
|