Spaces:
Runtime error
Runtime error
File size: 3,763 Bytes
6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd 8a74fe1 87c4a7b edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd edb1d95 6216ecd 34eb6c0 edb1d95 34eb6c0 edb1d95 34eb6c0 edb1d95 34eb6c0 edb1d95 34eb6c0 edb1d95 34eb6c0 edb1d95 34eb6c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from src.utils.config_loader import constants
from huggingface_hub import snapshot_download
from zipfile import ZipFile
import numpy as np
import os, shutil
import matplotlib.pyplot as plt
import cv2
import math
def download_hf_dataset(repo_id, allow_patterns=None):
    """Download a public Hugging Face *dataset* repository.

    Files are written under ``constants.RAW_DATASET_DIR``.

    Args:
        repo_id: Dataset repository id, e.g. ``"owner/name"``.
        allow_patterns: Optional glob pattern(s) forwarded to
            ``snapshot_download`` to restrict which files are fetched;
            ``None`` applies no filter.
    """
    download_kwargs = dict(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=constants.RAW_DATASET_DIR,
        allow_patterns=allow_patterns,
    )
    snapshot_download(**download_kwargs)
def download_personal_hf_dataset(name):
    """Download a single folder of the project's own Hugging Face dataset repo.

    Only files matching ``<name>/*`` in
    ``Anuj-Panthri/Image-Colorization-Datasets`` are fetched (via
    ``download_hf_dataset``).

    Args:
        name: Top-level folder in the dataset repository to download.
    """
    personal_repo = "Anuj-Panthri/Image-Colorization-Datasets"
    folder_pattern = f"{name}/*"
    download_hf_dataset(repo_id=personal_repo, allow_patterns=folder_pattern)
def unzip_file(file_path, destination_dir):
    """Extract a zip archive into a freshly emptied directory.

    If ``destination_dir`` already exists it is deleted (with all of its
    contents), recreated, and then populated with the archive's members.

    Args:
        file_path: Path to the ``.zip`` archive.
        destination_dir: Directory to extract into; wiped if it exists.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist. In that case
            ``destination_dir`` is left untouched.
        zipfile.BadZipFile: If ``file_path`` is not a valid zip archive. In
            that case ``destination_dir`` is left untouched.
    """
    # Open (and thereby validate) the archive *before* touching the
    # destination, so a missing or corrupt archive raises without first
    # wiping existing data. Named `archive` to avoid shadowing builtin `zip`.
    with ZipFile(file_path, "r") as archive:
        if os.path.exists(destination_dir):
            # Start from an empty directory so stale files from an earlier
            # extraction don't mix with the new archive's contents.
            shutil.rmtree(destination_dir)
        os.makedirs(destination_dir)
        archive.extractall(destination_dir)
def is_bw(img: np.ndarray, threshold: float = 10) -> bool:
    """Return True if an image is black and white (achromatic).

    For the channel pairs (0-1, 1-2, 0-2) the absolute differences are summed
    over the whole image; the image counts as black and white when the mean
    of those three sums is below ``threshold``. The sums are not normalised
    by pixel count, so the test is stricter for larger images.

    Args:
        img: Image array of shape (H, W, C). Only the first three channels
            are compared. A 2-D array or an (H, W, 1) array has no colour
            information and is reported as black and white.
        threshold: Cut-off for the averaged channel-difference sum
            (default 10, the previously hard-coded value).

    Returns:
        ``True`` if the image is black and white, else ``False``.
    """
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        return True
    # Work in float64: subtracting unsigned integer channels (e.g. uint8 from
    # cv2.imread) wraps around, turning a difference of -1 into 255.
    pixels = img.astype(np.float64)
    c0, c1, c2 = pixels[:, :, 0], pixels[:, :, 1], pixels[:, :, 2]
    diff_sums = [
        np.abs(c0 - c1).sum(),
        np.abs(c1 - c2).sum(),
        np.abs(c0 - c2).sum(),
    ]
    return bool(np.mean(diff_sums) < threshold)
def print_title(msg: str, max_chars=105):
    """Print ``msg`` upper-cased and centred in a banner of ``=`` characters.

    The banner is exactly ``max_chars`` characters wide. When the padding
    cannot be split evenly, the extra ``=`` goes on the right. A title
    longer than ``max_chars`` is printed without any padding.

    Args:
        msg: Title text.
        max_chars: Total banner width in characters.
    """
    title = msg.upper()
    # Measure the upper-cased text: upper() can change the length
    # (e.g. "ß" -> "SS").
    padding = max(max_chars - len(title), 0)
    left = padding // 2
    right = padding - left
    print("=" * left, title, "=" * right, sep="")
def scale_L(L):
    """Scale lightness values down by a factor of 100 (inverse of ``rescale_L``).

    Works on Python numbers and NumPy arrays alike. The factor presumably
    maps OpenCV float-LAB lightness in [0, 100] to [0, 1].
    """
    lightness_range = 100
    return L / lightness_range
def rescale_L(L):
    """Scale lightness values up by a factor of 100 (inverse of ``scale_L``).

    Works on Python numbers and NumPy arrays alike.
    """
    lightness_range = 100
    return L * lightness_range
def scale_AB(AB):
    """Scale a/b colour-channel values down by a factor of 128.

    Inverse of ``rescale_AB``. Works on Python numbers and NumPy arrays.
    """
    ab_half_range = 128
    return AB / ab_half_range
def rescale_AB(AB):
    """Scale a/b colour-channel values up by a factor of 128.

    Inverse of ``scale_AB``. Works on Python numbers and NumPy arrays.
    """
    ab_half_range = 128
    return AB * ab_half_range
def show_images_from_paths(
    image_paths: list[str],
    image_size=64,
    cols=4,
    row_size=5,
    col_size=5,
    show_BW=False,
    title=None,
    save=False,
    label="",
):
    """Load images from disk and plot them in a grid of subplots.

    Each image is read with OpenCV, converted from BGR to RGB, resized to
    ``image_size`` x ``image_size`` and drawn in its own subplot.

    Args:
        image_paths: Paths of the images to display.
        image_size: Side length in pixels every image is resized to.
        cols: Number of grid columns.
        row_size: Figure height per grid row (matplotlib figsize units).
        col_size: Figure width per grid column (matplotlib figsize units).
        show_BW: If True, draw a grayscale copy to the left of each image.
        title: Optional title drawn above the grid.
        save: If True, also save the figure to
            ``constants.ARTIFACT_DATASET_VISUALIZATION_DIR/{label}_image.png``.
        label: Filename prefix used when ``save`` is True.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of ``image_paths``.
    """
    n = len(image_paths)
    rows = math.ceil(n / cols)
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    # Hide the frame/ticks of the figure-wide axes pyplot creates on demand
    # (it carries the title); the per-image subplots below are separate.
    plt.axis("off")
    for i, path in enumerate(image_paths):
        fig.add_subplot(rows, cols, i + 1)
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (image_size, image_size))
        if show_BW:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            # cvtColor returns a 2-D (H, W) array. Add a channel axis before
            # tiling so the result is (H, W, 3) rather than (1, H, 3W) and
            # can be placed next to the colour image.
            gray = np.tile(gray[:, :, np.newaxis], (1, 1, 3))
            img = np.concatenate([gray, img], axis=1)
        plt.imshow(img.astype("uint8"))
    if save:
        os.makedirs(constants.ARTIFACT_DATASET_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_DATASET_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()
def see_batch(
    L_batch,
    AB_batch,
    show_L=False,
    cols=4,
    row_size=5,
    col_size=5,
    title=None,
    save=False,
    label="",
):
    """Plot a batch of scaled LAB samples as RGB images in a grid.

    Each sample's L and AB parts are rescaled with ``rescale_L`` and
    ``rescale_AB``, concatenated on the last axis into one LAB image,
    converted to RGB with OpenCV and drawn in its own subplot.

    Args:
        L_batch: Batch of scaled lightness channels, indexed as
            ``L_batch[i]``. Items are presumably (H, W, 1) (TODO confirm),
            since they are joined with ``AB_batch[i]`` into a 3-channel image.
        AB_batch: Batch of scaled a/b channels, presumably (N, H, W, 2).
        show_L: If True, draw the lightness image (as gray) to the left of
            each colour image.
        cols: Number of grid columns.
        row_size: Figure height per grid row (matplotlib figsize units).
        col_size: Figure width per grid column (matplotlib figsize units).
        title: Optional title drawn above the grid.
        save: If True, also save the figure to
            ``constants.ARTIFACT_RESULT_VISUALIZATION_DIR/{label}_image.png``.
        label: Filename prefix used when ``save`` is True.
    """
    n = L_batch.shape[0]
    rows = math.ceil(n / cols)
    fig = plt.figure(figsize=(col_size * cols, row_size * rows))
    if title:
        plt.title(title)
    # Hide the frame/ticks of the figure-wide axes pyplot creates on demand
    # (it carries the title); the per-sample subplots below are separate.
    plt.axis("off")
    for i in range(n):
        fig.add_subplot(rows, cols, i + 1)
        L = rescale_L(L_batch[i])
        AB = rescale_AB(AB_batch[i])
        # cv2.cvtColor accepts 8-bit or float32 input but not float64, which
        # is what rescaling float64 arrays produces; the cast is a no-op for
        # float32 batches.
        lab = np.concatenate([L, AB], axis=-1).astype(np.float32)
        img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) * 255
        if show_L:
            gray = np.tile(L, (1, 1, 3)) / 100 * 255
            img = np.concatenate([gray, img], axis=1)
        # Clip before the uint8 cast so any out-of-range value saturates
        # instead of wrapping around.
        plt.imshow(np.clip(img, 0, 255).astype("uint8"))
    if save:
        os.makedirs(constants.ARTIFACT_RESULT_VISUALIZATION_DIR, exist_ok=True)
        plt.savefig(
            os.path.join(
                constants.ARTIFACT_RESULT_VISUALIZATION_DIR, f"{label}_image.png"
            )
        )
    plt.show()
|