Spaces:
Sleeping
Sleeping
import pathlib | |
import keras | |
import tensorflow as tf | |
import os | |
import numpy as np | |
import re | |
# --- Pipeline configuration -------------------------------------------------
# Directory that is joined with each image file name from the token file.
IMAGES_PATH = "Flicker8k_Dataset"

# Target (height, width) passed to tf.image.resize.
IMAGE_SIZE = (299, 299)

# Upper bound on the vocabulary kept by the text vectorizer.
VOCAB_SIZE = 10000

# Fixed token length every vectorized caption is padded/truncated to.
SEQ_LENGTH = 25

# Number of examples per batch in the tf.data pipelines.
BATCH_SIZE = 64

# Let tf.data choose parallelism and prefetch buffer sizes.
AUTOTUNE = tf.data.AUTOTUNE
# Working directory used both as download target and as the place the
# caption file is later read from.
path = pathlib.Path(".")

# Download and unpack the image archive and the caption/text archive.
for archive_url in (
    'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',
    'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',
):
    keras.utils.get_file(
        origin=archive_url,
        cache_dir='.',
        cache_subdir=path,
        extract=True)
# Read the caption file: each record is "<image name>#<index>\t<caption>".
dataset = pathlib.Path(path, "Flickr8k.token.txt").read_text(
    encoding='utf-8').splitlines()
# Split on the first tab only so a caption that itself contains a tab is kept
# whole, and skip lines without a tab (e.g. blank lines) that would otherwise
# break the two-element unpacking below.
dataset = [line.split('\t', 1) for line in dataset if '\t' in line]
# Drop the "#<index>" suffix and turn the image name into a path under
# IMAGES_PATH.
dataset = [
    [os.path.join(IMAGES_PATH, fname.split('#')[0].strip()), caption]
    for (fname, caption) in dataset
]

caption_mapping = {}  # image path -> list of its raw captions
text_data = []        # "<start> ... <end>" captions, used to fit the vocabulary
X_en_data = []        # image paths (one entry per caption)
X_de_data = []        # "<start> ..." captions
Y_data = []           # "... <end>" captions
for img_name, caption in dataset:
    # Only entries whose image name ends in "jpg" are used.
    if not img_name.endswith("jpg"):
        continue
    # Normalise once instead of repeating the strip/replace for every list.
    cleaned = caption.strip().replace(".", "")
    X_de_data.append("<start> " + cleaned)
    Y_data.append(cleaned + " <end>")
    text_data.append("<start> " + cleaned + " <end>")
    X_en_data.append(img_name)
    caption_mapping.setdefault(img_name, []).append(caption)
# Fraction of examples used for training; rebound to the matching example
# count once the data has been shuffled.
train_size = 0.8
shuffle = True  # not read anywhere in this section
np.random.seed(42)  # fixed seed so the shuffle below is repeatable

# Shuffle the three parallel sequences together so that each
# (image path, "<start> ..." caption, "... <end>" caption) triple stays aligned.
zipped = list(zip(X_en_data, X_de_data, Y_data))
np.random.shuffle(zipped)
X_en_data, X_de_data, Y_data = zip(*zipped)

train_size = int(len(X_en_data)*train_size)

# Everything before the cut-off index is training data, the rest validation.
X_train_en, X_valid_en = list(X_en_data[:train_size]), list(X_en_data[train_size:])
X_train_de, X_valid_de = list(X_de_data[:train_size]), list(X_de_data[train_size:])
Y_train, Y_valid = list(Y_data[:train_size]), list(Y_data[train_size:])
# Punctuation removed from captions. "<" and ">" are not listed, so the
# "<start>" / "<end>" markers survive standardization. The backslash is
# written as "\\" to avoid an invalid "\]" escape sequence in the literal.
strip_chars = "!\"#$%&'()*+,-./:;=?@[\\]^_`{|}~"

# Character class matching any single character of strip_chars. Without the
# surrounding brackets the pattern would only match the whole punctuation run
# as one literal sequence and effectively never strip anything.
strip_pattern = f"[{re.escape(strip_chars)}]"


def custom_standardization(input_string):
    """Lowercase the input and delete every character listed in strip_chars.

    Args:
        input_string: String tensor (or value accepted by tf.strings.lower).

    Returns:
        The lowercased string tensor with the punctuation removed.
    """
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, strip_pattern, '')
# Text vectorizer mapping captions to fixed-length integer token sequences.
vectorization = keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
# Build the vocabulary from the full "<start> ... <end>" captions.
vectorization.adapt(text_data)

# Persist the learned vocabulary. np.save does not create missing parent
# directories, so make sure the target folder exists first.
vocab = np.array(vectorization.get_vocabulary())
os.makedirs('./artifacts', exist_ok=True)
np.save('./artifacts/vocabulary.npy', vocab)
def decode_and_resize(img_path):
    """Load a JPEG file and return it as a 3-channel image of IMAGE_SIZE.

    Args:
        img_path: Path of the JPEG file to read.

    Returns:
        The decoded image resized to IMAGE_SIZE, passed through
        tf.image.convert_image_dtype with tf.float32.
    """
    raw_bytes = tf.io.read_file(img_path)
    rgb_image = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(rgb_image, IMAGE_SIZE)
    # NOTE(review): tf.image.resize presumably already yields float32 here, in
    # which case this conversion does not rescale pixel values to [0, 1] —
    # confirm the value range the downstream model expects.
    return tf.image.convert_image_dtype(resized, tf.float32)
def process_input(img_cap, y_captions):
    """Turn one raw example into model-ready tensors.

    Args:
        img_cap: Pair of (image path, input caption).
        y_captions: Target caption.

    Returns:
        ((image, vectorized input caption), vectorized target caption).
    """
    img_path, x_captions = img_cap
    image = decode_and_resize(img_path)
    input_tokens = vectorization(x_captions)
    target_tokens = vectorization(y_captions)
    return ((image, input_tokens), target_tokens)
def make_dataset(images, x_captions, y_captions):
    """Build a batched, prefetched tf.data pipeline from parallel sequences.

    Args:
        images: Image paths.
        x_captions: Input captions, aligned with images.
        y_captions: Target captions, aligned with images.

    Returns:
        A tf.data.Dataset whose elements are produced by process_input,
        grouped into batches of BATCH_SIZE.
    """
    raw_examples = tf.data.Dataset.from_tensor_slices(
        ((images, x_captions), y_captions))
    return (
        raw_examples
        .map(process_input, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )
# Input pipelines for the training and validation splits.
train_dataset = make_dataset(
    X_train_en,
    X_train_de,
    Y_train,
)
valid_dataset = make_dataset(
    X_valid_en,
    X_valid_de,
    Y_valid,
)