import re

import keras
import numpy as np
import tensorflow as tf

from src.components.model import (
    get_cnn_model,
    TransformerEncoderBlock,
    TransformerDecoderBlock,
    ImageCaptioningModel,
    image_augmentation,
    LRSchedule,
)
# These values must match the configuration used when the model was trained
SEQ_LENGTH = 25
VOCAB_SIZE = 10000
IMAGE_SIZE = (299, 299)
print("loading model...")
loaded_model = keras.saving.load_model(
    "./artifacts/caption_model.keras", compile=True
)
print("model loaded...")

vocab = np.load("./artifacts/vocabulary.npy")
print("vocab loaded...")

data_txt = np.load("./artifacts/data_txt.npy").tolist()
print("vectorization data loaded...")

index_lookup = dict(zip(range(len(vocab)), vocab))
print("index lookup loaded...")

max_decoded_sentence_length = SEQ_LENGTH - 1
strip_chars = "!\"#$%&'()*+,-./:;=?@[\\]^_`{|}~"


def custom_standardization(input_string):
    # Lowercase the caption and strip punctuation, mirroring the
    # standardization applied when the vectorizer was adapted for training.
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, f"[{re.escape(strip_chars)}]", "")
vectorization = keras.layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(data_txt)
print("vectorization adapted...")
def decode_and_resize(image):
    # Accept either a path to a JPEG file or an in-memory NumPy array
    if isinstance(image, str):
        img = tf.io.read_file(image)
        img = tf.image.decode_jpeg(img, channels=3)
    elif isinstance(image, np.ndarray):
        img = tf.constant(image)
    img = tf.image.resize(img, IMAGE_SIZE)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img
def generate_caption(image):
    sample_img = decode_and_resize(image)

    # Pass the image through the CNN to extract feature maps
    img = tf.expand_dims(sample_img, 0)
    img = loaded_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = loaded_model.encoder(img, training=False)

    # Generate the caption token by token with the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = loaded_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        # Greedy decoding: pick the most likely token at position i
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    # Strip the start/end markers before returning the caption
    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    return decoded_caption
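

# A minimal usage sketch, not part of the original script: the image path
# below is a hypothetical example and assumes a local JPEG exists at that
# location. generate_caption also accepts an in-memory NumPy array, since
# decode_and_resize handles both cases.
if __name__ == "__main__":
    sample_path = "./artifacts/sample.jpg"  # hypothetical path, adjust as needed
    print(generate_caption(sample_path))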