Spaces:
Sleeping
Sleeping
from tensorflow.keras.preprocessing.image import load_img, img_to_array | |
from tensorflow.keras.preprocessing.text import Tokenizer | |
from tensorflow.keras.preprocessing.sequence import pad_sequences | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.applications.xception import Xception, preprocess_input | |
import pickle | |
from tqdm import tqdm | |
import os | |
from PIL import Image | |
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Flatten, concatenate | |
import numpy as np | |
import gradio as gr | |
import tensorflow as tf | |
# Image encoder: ImageNet Xception with its final classification layer cut
# off, so predict() returns the activations of the layer just before the
# classifier (used as the image feature vector by extract_features()).
model = Xception()
model = Model(inputs=model.inputs, outputs=model.layers[-2].output)


class CompatEmbedding(tf.keras.layers.Embedding):
    """Embedding layer that accepts (and ignores) a legacy ``input_length``.

    The saved caption model presumably serialized its Embedding layer with
    an ``input_length`` key, which newer Keras versions reject as an unknown
    keyword argument. This subclass takes that argument itself instead of
    forwarding it to the base class.
    """

    def __init__(self, *args, input_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Kept only so the attribute exists; layer behavior does not use it.
        self.input_length = input_length


# Caption decoder; caption_prediction() reads this module-level name, so it
# must be defined here. Mapping the built-in "Embedding" class name to
# CompatEmbedding routes the saved layers through the tolerant subclass.
# compile=False: the model is only used for inference (predict), so no
# optimizer/loss needs to be restored.
caption_model = tf.keras.models.load_model(
    'model.h5',
    custom_objects={'Embedding': CompatEmbedding},
    compile=False,
)
# Read the raw annotation file; the first line is a header row and is skipped.
with open('captions.txt', 'r') as f:
    next(f)
    captions_doc = f.read()

# Group captions by image ID (the image file name without its extension).
mapping = {}
for line in tqdm(captions_doc.split('\n')):
    # Empty or one-character lines (e.g. after the final newline) hold no data.
    if len(line) < 2:
        continue
    tokens = line.split(',')
    image_id = tokens[0].split('.')[0]
    # Everything after the first comma is caption text; any commas inside
    # the caption are turned into spaces by this join.
    caption = " ".join(tokens[1:])
    mapping.setdefault(image_id, []).append(caption)
def clean(mapping):
    """Normalize every caption in *mapping* in place.

    Each caption is lower-cased, split on whitespace, stripped of
    one-character tokens, and wrapped as ``'startseq <words> endseq'`` so
    the decoder has explicit start/stop markers.

    Parameters
    ----------
    mapping : dict
        Image ID -> list of caption strings. Each list is rewritten in
        place; the function returns ``None``.

    Notes
    -----
    No regex-based punctuation/digit stripping is performed (multi-character
    tokens such as ``"dog,"`` survive). This matches what the previous
    ``str.replace`` calls actually did, since ``str.replace`` matches
    literal text, not regex patterns. Changing the cleaning would change
    the fitted tokenizer vocabulary. Presumably that vocabulary must stay
    identical to the one the caption model was trained with; confirm before
    altering.
    """
    for captions in mapping.values():
        for i, raw in enumerate(captions):
            # str.split() with no argument already collapses runs of whitespace.
            words = [word for word in raw.lower().split() if len(word) > 1]
            captions[i] = 'startseq ' + " ".join(words) + ' endseq'
# Normalize the captions (lower-case, drop 1-char tokens, add the
# startseq/endseq markers) before fitting. predict_caption() seeds decoding
# with 'startseq' and stops on 'endseq', so both must be in the vocabulary.
# Presumably the caption model was trained on captions cleaned this way;
# skipping this step shifts every word index relative to the model's output.
clean(mapping)

# Flatten all captions, in mapping order, for tokenizer fitting.
all_captions = [caption for captions in mapping.values() for caption in captions]

# Tokenize the text.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(all_captions)
# +1 because the tokenizer assigns word indices starting at 1.
vocab_size = len(tokenizer.word_index) + 1

# Longest caption in words. Note: caption_prediction() does not pass this to
# predict_caption(), which uses its own max_length default instead.
max_length = max(len(caption.split()) for caption in all_captions)
def extract_features(image):
    """Encode one image file with the module-level Xception extractor.

    Parameters
    ----------
    image : str
        Path of the image file to encode (as passed by caption_prediction).

    Returns
    -------
    numpy.ndarray
        ``model.predict`` output for a batch of one image (presumably
        ``(1, 2048)`` pooled Xception features; confirm).
    """
    # Resize to the 299x299 input size used here for Xception.
    pixels = img_to_array(load_img(image, target_size=(299, 299)))
    # Prepend the batch axis, then apply Xception's input scaling.
    batch = preprocess_input(np.expand_dims(pixels, axis=0))
    return model.predict(batch, verbose=0)
def idx_to_word(integer, tokenizer):
    """Map a predicted vocabulary index back to its word.

    Parameters
    ----------
    integer : int
        Vocabulary index, e.g. the ``np.argmax`` of the caption model output.
    tokenizer : Tokenizer
        A fitted Keras ``Tokenizer``. Only ``index_word`` (if present) or
        ``word_index`` is read.

    Returns
    -------
    str or None
        The word with that index, or ``None`` when no word has it.
    """
    # A fitted Keras Tokenizer keeps the reverse mapping. Using it makes
    # each lookup O(1) instead of scanning the whole vocabulary per decoded
    # word. NumPy integer scalars hash like Python ints, so argmax results
    # work as keys.
    index_word = getattr(tokenizer, 'index_word', None)
    if index_word is not None:
        return index_word.get(integer)
    # Fallback for tokenizer-like objects that only expose word_index.
    for word, index in tokenizer.word_index.items():
        if index == integer:
            return word
    return None
def save_image(img, save_dir="saved_images", file_name="uploaded_image.png"):
    """Write *img* into *save_dir* and return the saved file's path.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to store; only its ``save(path)`` method is used.
    save_dir : str
        Target directory, created (with parents) if it does not exist.
    file_name : str
        File name inside *save_dir*. The default is a fixed name, so every
        call overwrites the previous upload; it is NOT unique per call. Pass
        a distinct name when several saved images must coexist.

    Returns
    -------
    str
        ``os.path.join(save_dir, file_name)``.
    """
    os.makedirs(save_dir, exist_ok=True)
    img_name = os.path.join(save_dir, file_name)
    img.save(img_name)
    return img_name
def predict_caption(model, image, tokenizer, max_length=35):
    """Greedily decode a caption for pre-extracted image features.

    Starting from ``'startseq'``, each step feeds the image features and the
    padded partial caption to *model* and appends the highest-scoring word.
    Decoding stops after *max_length* steps, when the predicted index maps
    to no word, or right after ``'endseq'`` is appended.

    Parameters
    ----------
    model : keras.Model
        Caption decoder taking ``[image_features, padded_sequence]``.
    image : numpy.ndarray
        Features from ``extract_features``.
    tokenizer : Tokenizer
        Fitted tokenizer used for encoding and index-to-word lookup.
    max_length : int
        Step limit and padding length. Presumably must equal the decoder's
        sequence input length (35 here); confirm against the model.

    Returns
    -------
    str
        The decoded text, including the leading ``'startseq'`` and, if it
        was produced, the trailing ``'endseq'``.
    """
    caption = 'startseq'
    for _ in range(max_length):
        padded = pad_sequences(tokenizer.texts_to_sequences([caption]), maxlen=max_length)
        scores = model.predict([image, padded], verbose=0)
        word = idx_to_word(np.argmax(scores), tokenizer)
        if word is None:
            break
        caption = caption + " " + word
        if word == 'endseq':
            break
    return caption
def caption_prediction(img):
    """Gradio handler: generate a caption for an uploaded image.

    Parameters
    ----------
    img : numpy.ndarray or None
        Image array from the ``'image'`` input (converted with
        ``Image.fromarray``); ``None`` when submitted without an image.

    Returns
    -------
    str
        The caption words without the ``startseq``/``endseq`` markers, or an
        empty string when no image was provided.
    """
    if img is None:
        return ""
    # extract_features() loads from a file path, so round-trip via disk.
    img_path = save_image(Image.fromarray(img))
    features = extract_features(img_path)
    words = predict_caption(caption_model, features, tokenizer).split()
    # Remove the control tokens by word, not by character offset. This
    # leaves no stray spaces and keeps the last real word when decoding
    # stopped at the length limit without emitting 'endseq'.
    if words and words[0] == 'startseq':
        words = words[1:]
    if words and words[-1] == 'endseq':
        words = words[:-1]
    return " ".join(words)
# Gradio UI: image in, caption text out.
demo = gr.Interface(fn=caption_prediction, inputs='image', outputs='text', title='caption generator')

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a blocking web server.
if __name__ == "__main__":
    # debug=True blocks the main thread and surfaces errors in the console;
    # share=True requests a public share link when run outside Spaces.
    demo.launch(debug=True, share=True)