Spaces:
Runtime error
Runtime error
from src.utils.config_loader import constants | |
from huggingface_hub import snapshot_download | |
from zipfile import ZipFile | |
import numpy as np | |
import os, shutil | |
import matplotlib.pyplot as plt | |
import cv2 | |
import math | |
def download_hf_dataset(repo_id, allow_patterns=None):
    """Download a snapshot of a public Hugging Face dataset repository.

    Files are written into ``constants.RAW_DATASET_DIR``.

    Args:
        repo_id: Id of the Hugging Face dataset repository.
        allow_patterns: Forwarded unchanged to ``snapshot_download`` as its
            ``allow_patterns`` filter (``None`` by default).
    """
    download_kwargs = dict(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=constants.RAW_DATASET_DIR,
        allow_patterns=allow_patterns,
    )
    snapshot_download(**download_kwargs)
def download_personal_hf_dataset(name):
    """Download one folder of the ``Anuj-Panthri/Image-Colorization-Datasets`` repo.

    Only paths matching ``<name>/*`` are requested. The actual download is
    delegated to ``download_hf_dataset``.

    Args:
        name: Top-level folder inside the dataset repository to fetch.
    """
    folder_pattern = f"{name}/*"
    download_hf_dataset(
        repo_id="Anuj-Panthri/Image-Colorization-Datasets",
        allow_patterns=folder_pattern,
    )
def unzip_file(file_path, destination_dir):
    """Extract a zip archive into a freshly emptied directory.

    If ``destination_dir`` already exists it is deleted (with all its
    contents) and recreated. Afterwards it holds only the archive's contents.

    Args:
        file_path: Path of the ``.zip`` archive to extract.
        destination_dir: Directory to extract into; created if missing.

    Raises:
        FileNotFoundError / zipfile.BadZipFile: If the archive cannot be
            opened. In that case ``destination_dir`` is left untouched.
    """
    # Open (and thereby validate) the archive *before* wiping the destination,
    # so a missing or corrupt zip does not destroy previously extracted data.
    # ``archive`` rather than ``zip`` avoids shadowing the builtin.
    with ZipFile(file_path, "r") as archive:
        if os.path.exists(destination_dir):
            shutil.rmtree(destination_dir)
        os.makedirs(destination_dir)
        archive.extractall(destination_dir)
def is_bw(img: np.ndarray, threshold: float = 10) -> bool:
    """Check whether an RGB image is effectively black and white.

    For each channel pair (R-G, G-B, R-B) the absolute per-pixel differences
    are summed over the whole image. The image counts as black and white when
    the mean of those three totals is below ``threshold``.

    Note: these are totals over all pixels, not per-pixel averages. With the
    default threshold only images whose channels are (almost) identical
    everywhere qualify.

    Args:
        img: Array indexed as ``img[row, col, channel]`` with at least three
            channels; only channels 0, 1 and 2 are compared.
        threshold: Cut-off for the mean of the summed channel differences.

    Returns:
        ``True`` if the image is considered black and white.
    """
    # Compute in float64: with unsigned integer images (e.g. uint8) a plain
    # ``a - b`` wraps around (100 - 101 == 255), hugely overstating tiny
    # channel differences.
    pixels = np.asarray(img, dtype=np.float64)
    red, green, blue = pixels[:, :, 0], pixels[:, :, 1], pixels[:, :, 2]
    rg = np.abs(red - green).sum()
    gb = np.abs(green - blue).sum()
    rb = np.abs(red - blue).sum()
    avg = (rg + gb + rb) / 3
    return bool(avg < threshold)
def print_title(msg: str, max_chars=105):
    """Print ``msg`` upper-cased and framed by runs of ``=`` characters.

    Each side gets ``(max_chars - len(msg)) // 2`` equals signs. No padding is
    added when ``msg`` is at least ``max_chars`` long.
    """
    padding = "=" * ((max_chars - len(msg)) // 2)
    print(padding + msg.upper() + padding)
def scale_L(L):
    """Scale a lightness channel down by 100 (e.g. ``[0, 100]`` -> ``[0, 1]``).

    Works element-wise on scalars and arrays alike.
    """
    l_range = 100
    return L / l_range
def rescale_L(L):
    """Undo ``scale_L``: multiply a lightness channel by 100.

    Works element-wise on scalars and arrays alike.
    """
    l_range = 100
    return L * l_range
def scale_AB(AB):
    """Scale colour (a/b) channels down by 128.

    Works element-wise on scalars and arrays alike.
    """
    ab_range = 128
    return AB / ab_range
def rescale_AB(AB):
    """Undo ``scale_AB``: multiply colour (a/b) channels by 128.

    Works element-wise on scalars and arrays alike.
    """
    ab_range = 128
    return AB * ab_range
def show_images_from_paths(
    image_paths: list[str],
    image_size=64,
    cols=4,
    row_size=5,
    col_size=5,
    show_BW=False,
    title=None,
    save=False,
    label="",
):
    """Plot images from disk in a grid, optionally saving the figure.

    Each image is read with OpenCV, converted from BGR to RGB and resized to
    ``image_size`` x ``image_size``. With ``show_BW`` a greyscale copy is
    placed to the left of the colour image in the same grid cell.

    Args:
        image_paths: Paths of the images to display, one grid cell each.
        image_size: Side length (pixels) every image is resized to.
        cols: Number of grid columns; the row count is derived from it.
        row_size: Figure height per grid row (matplotlib ``figsize`` units).
        col_size: Figure width per grid column (matplotlib ``figsize`` units).
        show_BW: If True, prepend a greyscale version of each image.
        title: Optional title drawn above the grid.
        save: If True, save the figure as ``{label}_image.png`` inside
            ``constants.ARTIFACT_DATASET_VISUALIZATION_DIR``.
        label: Filename prefix used when ``save`` is True.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of ``image_paths``.
    """
    n = len(image_paths)
    rows = math.ceil(n / cols)
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    plt.axis("off")
    for i, path in enumerate(image_paths):
        fig.add_subplot(rows, cols, i + 1)
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (image_size, image_size))
        if show_BW:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            # Replicate the 2-D grey image into (H, W, 3) so it can be joined
            # side by side with the RGB image. ``np.tile(gray, (1, 1, 3))``
            # would give (1, H, 3W) and make the concatenation fail.
            gray_rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
            img = np.concatenate([gray_rgb, img], axis=1)
        plt.imshow(img.astype("uint8"))
    if save:
        os.makedirs(constants.ARTIFACT_DATASET_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_DATASET_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()
def see_batch(
    L_batch,
    AB_batch,
    show_L=False,
    cols=4,
    row_size=5,
    col_size=5,
    title=None,
    save=False,
    label="",
):
    """Plot a batch of scaled Lab samples as RGB images in a grid.

    Each sample is un-scaled with ``rescale_L`` / ``rescale_AB``, joined into
    one Lab image along the last axis, converted to RGB with OpenCV and drawn
    in its own grid cell.

    Args:
        L_batch: Batch of scaled lightness channels. ``L_batch.shape[0]`` is the
            batch size, and each ``L_batch[i]`` is concatenated with
            ``AB_batch[i]`` on the last axis, so presumably it is ``(H, W, 1)``
            (TODO confirm against the data loader).
        AB_batch: Batch of scaled a/b channels, presumably ``(N, H, W, 2)``.
        show_L: If True, draw the lightness channel as a grey image to the
            left of each RGB result.
        cols: Number of grid columns; the row count is derived from it.
        row_size: Figure height per grid row (matplotlib ``figsize`` units).
        col_size: Figure width per grid column (matplotlib ``figsize`` units).
        title: Optional title drawn above the grid.
        save: If True, save the figure as ``{label}_image.png`` inside
            ``constants.ARTIFACT_RESULT_VISUALIZATION_DIR``.
        label: Filename prefix used when ``save`` is True.
    """
    n = L_batch.shape[0]
    rows = math.ceil(n / cols)
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    plt.axis("off")
    for i in range(n):
        fig.add_subplot(rows, cols, i + 1)
        # OpenCV's Lab->RGB conversion accepts 8-bit or float32 input but not
        # float64, which float64 batches would produce here, so force float32.
        L = np.asarray(rescale_L(L_batch[i]), dtype=np.float32)
        AB = np.asarray(rescale_AB(AB_batch[i]), dtype=np.float32)
        lab = np.concatenate([L, AB], axis=-1)
        img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) * 255
        if show_L:
            # Map L from [0, 100] to [0, 255] and replicate it to 3 channels.
            L_rgb = np.tile(L, (1, 1, 3)) / 100 * 255
            img = np.concatenate([L_rgb, img], axis=1)
        # Clip so out-of-range values saturate instead of overflowing in the
        # uint8 cast.
        plt.imshow(np.clip(img, 0, 255).astype("uint8"))
    if save:
        os.makedirs(constants.ARTIFACT_RESULT_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_RESULT_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()