import math
import os
import shutil
from zipfile import ZipFile

import cv2
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import snapshot_download

from src.utils.config_loader import constants


def download_hf_dataset(repo_id, allow_patterns=None):
    """Download a public Hugging Face dataset repo into the raw dataset dir.

    Args:
        repo_id: Hugging Face repository id, e.g. ``"user/dataset-name"``.
        allow_patterns: Optional glob pattern(s) forwarded to
            ``snapshot_download`` to restrict which files are fetched.

    Files are written under ``constants.RAW_DATASET_DIR``.
    """
    snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=constants.RAW_DATASET_DIR,
        allow_patterns=allow_patterns,
    )


def download_personal_hf_dataset(name):
    """Download one sub-folder of the project's own Hugging Face dataset repo.

    Args:
        name: Folder name inside ``Anuj-Panthri/Image-Colorization-Datasets``;
            only files matching ``"{name}/*"`` are downloaded.
    """
    download_hf_dataset(
        repo_id="Anuj-Panthri/Image-Colorization-Datasets",
        allow_patterns=f"{name}/*",
    )


def unzip_file(file_path, destination_dir):
    """Extract the zip archive at ``file_path`` into a fresh ``destination_dir``.

    WARNING: if ``destination_dir`` already exists it is deleted recursively
    before extraction, so any previous contents are lost.
    """
    if os.path.exists(destination_dir):
        shutil.rmtree(destination_dir)
    os.makedirs(destination_dir)
    # Named ``archive`` rather than ``zip`` to avoid shadowing the builtin.
    with ZipFile(file_path, "r") as archive:
        archive.extractall(destination_dir)


def is_bw(img: np.ndarray):
    """Check whether an RGB image is (effectively) black and white.

    The absolute differences between each pair of channels (R-G, G-B, R-B)
    are summed over the whole image; the image counts as black and white
    when the mean of those three totals is below 10.

    Args:
        img: Array indexed as ``img[row, col, channel]`` with at least three
            channels.

    Returns:
        A boolean (NumPy bool) that is true for black-and-white images.
    """
    pixels = np.asarray(img)
    if np.issubdtype(pixels.dtype, np.integer):
        # Widen before subtracting: with uint8 input a difference of -1 would
        # wrap around to 255 and np.abs() could not undo it.
        pixels = pixels.astype(np.int64)

    red, green, blue = pixels[:, :, 0], pixels[:, :, 1], pixels[:, :, 2]
    totals = [
        np.abs(red - green).sum(),
        np.abs(green - blue).sum(),
        np.abs(red - blue).sum(),
    ]
    return np.mean(totals) < 10


def print_title(msg: str, max_chars=105):
    """Print ``msg`` upper-cased and centred between runs of ``=``.

    Each side gets ``(max_chars - len(msg)) // 2`` characters of padding; a
    message longer than ``max_chars`` is printed without any padding.
    """
    n = (max_chars - len(msg)) // 2
    print("=" * n, msg.upper(), "=" * n, sep="")


def scale_L(L):
    """Scale the L channel by dividing by 100 (inverse of ``rescale_L``)."""
    return L / 100


def rescale_L(L):
    """Undo ``scale_L`` by multiplying by 100."""
    return L * 100


def scale_AB(AB):
    """Scale the AB channels by dividing by 128 (inverse of ``rescale_AB``)."""
    return AB / 128


def rescale_AB(AB):
    """Undo ``scale_AB`` by multiplying by 128."""
    return AB * 128


def show_images_from_paths(
    image_paths: list[str],
    image_size=64,
    cols=4,
    row_size=5,
    col_size=5,
    show_BW=False,
    title=None,
    save=False,
    label="",
):
    """Display the images at ``image_paths`` in a grid, optionally saving it.

    Args:
        image_paths: Paths of the image files to show.
        image_size: Each image is resized to ``image_size x image_size``.
        cols: Number of grid columns; rows are derived from the image count.
        row_size: Figure height (inches) allotted per grid row.
        col_size: Figure width (inches) allotted per grid column.
        show_BW: If true, a grayscale copy is placed to the left of each image.
        title: Optional title for the figure.
        save: If true, the figure is written to
            ``constants.ARTIFACT_DATASET_VISUALIZATION_DIR`` as
            ``"{label}_image.png"``.
        label: Filename prefix used when ``save`` is true.

    Raises:
        OSError: If an image cannot be read by ``cv2.imread``.
    """
    n = len(image_paths)
    # At least one row so an empty list does not request a zero-height figure.
    rows = max(1, math.ceil(n / cols))
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    plt.axis("off")

    for i, path in enumerate(image_paths):
        fig.add_subplot(rows, cols, i + 1)

        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f"could not read image: {path!r}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (image_size, image_size))

        if show_BW:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            # The gray image is 2-D (H, W); add a channel axis and repeat it so
            # it can sit side by side with the (H, W, 3) colour image.
            BW = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
            img = np.concatenate([BW, img], axis=1)

        plt.imshow(img.astype("uint8"))

    if save:
        os.makedirs(constants.ARTIFACT_DATASET_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_DATASET_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()


def see_batch(
    L_batch,
    AB_batch,
    show_L=False,
    cols=4,
    row_size=5,
    col_size=5,
    title=None,
    save=False,
    label="",
):
    """Display a batch of LAB images (given as L and AB parts) as an RGB grid.

    Each sample is passed through ``rescale_L`` / ``rescale_AB``, joined along
    the last axis, and converted from LAB to RGB for display.

    Args:
        L_batch: Batch of scaled L channels; expected to be indexable along
            axis 0 with the channel on the last axis (presumably
            ``(N, H, W, 1)`` -- confirm against the data pipeline).
        AB_batch: Batch of scaled AB channels matching ``L_batch``
            (presumably ``(N, H, W, 2)``).
        show_L: If true, the L channel is shown as a grayscale image to the
            left of each colour image.
        cols: Number of grid columns; rows are derived from the batch size.
        row_size: Figure height (inches) allotted per grid row.
        col_size: Figure width (inches) allotted per grid column.
        title: Optional title for the figure.
        save: If true, the figure is written to
            ``constants.ARTIFACT_RESULT_VISUALIZATION_DIR`` as
            ``"{label}_image.png"``.
        label: Filename prefix used when ``save`` is true.
    """
    n = L_batch.shape[0]
    # At least one row so an empty batch does not request a zero-height figure.
    rows = max(1, math.ceil(n / cols))
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    plt.axis("off")

    for i in range(n):
        fig.add_subplot(rows, cols, i + 1)
        L, AB = rescale_L(L_batch[i]), rescale_AB(AB_batch[i])

        # cv2.cvtColor accepts float32 (not float64) for floating-point LAB.
        lab = np.concatenate([L, AB], axis=-1).astype(np.float32)
        img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) * 255

        if show_L:
            # L is in [0, 100] after rescaling; map it to [0, 255] gray.
            L = np.tile(L, (1, 1, 3)) / 100 * 255
            img = np.concatenate([L, img], axis=1)

        # Clip before the uint8 cast so out-of-range values do not wrap.
        plt.imshow(np.clip(img, 0, 255).astype("uint8"))

    if save:
        os.makedirs(constants.ARTIFACT_RESULT_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_RESULT_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()