# Source page header: suriya7's picture / "Update app.py" / commit 0eb573f (verified)
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.applications.xception import Xception, preprocess_input
import pickle
from tqdm import tqdm
import os
from PIL import Image
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Flatten, concatenate
import numpy as np
import gradio as gr
import tensorflow as tf
# Image feature extractor: start from a full Xception network ...
model = Xception()
# ... then re-wire it so the output is taken from the second-to-last layer
# instead of the final layer.
model = Model(
    inputs=model.inputs,
    outputs=model.layers[-2].output,
)
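# NOTE(review): the block below is disabled, yet it contains the only statement
# in this file that would bind `caption_model`, which caption_prediction() reads.
# The subclass accepts an `input_length` argument and stores it on the instance
# instead of forwarding it to Embedding.__init__. If this is re-enabled,
# from_config takes `cls` and therefore needs the @classmethod decorator.
# import tensorflow as tf
# class MyEmbedding(tf.keras.layers.Embedding):
#     def __init__(self, *args, input_length=None, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.input_length = input_length
#     def from_config(cls, config):
#         input_length = config.pop('input_length', None)
#         instance = cls(**config)
#         instance.input_length = input_length
#         return instance
# # Load the model with custom objects
# caption_model = tf.keras.models.load_model('model.h5', custom_objects={'MyEmbedding': MyEmbedding})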
# Read the caption file into one string, dropping its first line
# (presumably a CSV header row -- confirm against the data file).
# The encoding is spelled out so decoding does not depend on the platform
# locale.
with open('captions.txt', 'r', encoding='utf-8') as f:
    # The default keeps an empty file from raising StopIteration here.
    next(f, None)
    captions_doc = f.read()
# Group the captions by image: image ID -> list of caption strings.
mapping = {}
for line in tqdm(captions_doc.split('\n')):
    # Lines shorter than two characters carry no "<image>,<caption>" pair.
    if len(line) < 2:
        continue
    # First comma-separated field is the image file name; every remaining
    # field belongs to the caption.
    image_id, *caption = line.split(',')
    # Keep only the part of the file name before the first dot.
    image_id = image_id.partition('.')[0]
    # Re-join the caption fields with spaces (commas are not restored).
    caption = " ".join(caption)
    # Start a new list on first sight of an image, then record the caption.
    mapping.setdefault(image_id, []).append(caption)
def clean(mapping):
    """Normalise every caption in ``mapping`` in place.

    Each caption is lower-cased, split on whitespace, stripped of
    one-character tokens, re-joined with single spaces and wrapped in the
    ``startseq`` / ``endseq`` markers.

    Args:
        mapping: dict whose values are lists of caption strings. The lists
            are modified in place.

    Returns:
        None.
    """
    for captions in mapping.values():
        for i, caption in enumerate(captions):
            caption = caption.lower()
            # str.replace matches literal text, not a regular expression, so
            # this only rewrites the three characters backslash-s-plus. It is
            # kept (as a raw string, avoiding the invalid '\s' escape) so the
            # output stays identical to what this function always produced.
            # A literal '[^A-Za-z]' replace is not needed: lower-cased text
            # can never contain the upper-case letters of that pattern.
            caption = caption.replace(r'\s+', ' ')
            # split()/join collapses whitespace runs; one-character tokens
            # (e.g. "a", stray punctuation, single digits) are dropped.
            words = [word for word in caption.split() if len(word) > 1]
            # Wrap the caption in the start and end markers.
            captions[i] = 'startseq ' + " ".join(words) + ' endseq'
# Normalise the captions before building the vocabulary. clean() is what
# inserts the 'startseq' / 'endseq' markers that predict_caption() seeds the
# generation with and stops on, so the tokenizer has to be fitted on the
# cleaned text for those markers to exist in its vocabulary.
clean(mapping)
# Flatten the per-image caption lists into one list of strings.
all_captions = [caption for captions in mapping.values() for caption in captions]
# tokenize the text
tokenizer = Tokenizer()
tokenizer.fit_on_texts(all_captions)
# One more than the number of distinct words (presumably to leave room for
# index 0, which Keras' Tokenizer does not assign to a word -- confirm).
vocab_size = len(tokenizer.word_index) + 1
# Longest caption, counted in whitespace-separated words; 0 when there are
# no captions at all instead of raising on an empty sequence.
max_length = max((len(caption.split()) for caption in all_captions), default=0)
def extract_features(image):
    """Run one image through the module-level ``model`` and return its output.

    Args:
        image: Image location handed straight to ``load_img`` (the caller in
            this file passes a file path).

    Returns:
        Whatever ``model.predict`` returns for the single-image batch.
    """
    # Load at 299x299 and turn the pixels into a numpy array.
    pixels = img_to_array(load_img(image, target_size=(299, 299)))
    # Add a leading batch axis, then apply the Xception input preprocessing.
    batch = preprocess_input(np.expand_dims(pixels, axis=0))
    return model.predict(batch, verbose=0)
def idx_to_word(integer, tokenizer):
    """Return the vocabulary word whose index equals ``integer``.

    Args:
        integer: Word index to look up.
        tokenizer: Object exposing a ``word_index`` mapping of word -> index
            and, optionally, the inverse ``index_word`` mapping.

    Returns:
        The matching word, or ``None`` when no word has that index.
    """
    # Prefer the tokenizer's own inverse table: one hash lookup instead of a
    # scan over the whole vocabulary for every generated word.
    try:
        word = tokenizer.index_word.get(integer)
    except (AttributeError, TypeError):
        # No inverse table on this object, or an unhashable index value.
        word = None
    if word is not None:
        return word
    # Fallback: linear scan of word_index, first match wins.
    return next(
        (word for word, index in tokenizer.word_index.items() if index == integer),
        None,
    )
def save_image(img, save_dir="saved_images"):
    """Write ``img`` to ``<save_dir>/uploaded_image.png`` and return that path.

    Args:
        img: Object with a ``save(path)`` method (the caller in this file
            passes a PIL image).
        save_dir: Directory to write into; created when missing.

    Returns:
        The path of the written file.
    """
    # The file name is fixed, so every call overwrites the previous upload.
    target_path = os.path.join(save_dir, "uploaded_image.png")
    # Make sure the destination directory exists before writing.
    os.makedirs(save_dir, exist_ok=True)
    img.save(target_path)
    return target_path
# generate caption for an image
def predict_caption(model, image, tokenizer, max_length=35):
    """Generate a caption one word at a time, always taking the top-scoring word.

    Args:
        model: Object whose ``predict([image, sequence], verbose=0)`` returns
            scores over the vocabulary.
        image: Image features passed unchanged to ``model.predict``.
        tokenizer: Tokenizer used to encode the text so far and to map the
            predicted index back to a word.
        max_length: Maximum number of words to generate; also the length the
            encoded text is padded to.

    Returns:
        The generated text, beginning with ``startseq`` and ending with
        ``endseq`` when the model produced it before the loop ran out.
    """
    # The text so far, kept as a word list and joined whenever it is needed.
    words = ['startseq']
    for _ in range(max_length):
        # Encode the current text as word indices and pad it to max_length.
        encoded = tokenizer.texts_to_sequences([" ".join(words)])[0]
        padded = pad_sequences([encoded], max_length)
        # Score the next word and take the index with the highest score.
        scores = model.predict([image, padded], verbose=0)
        next_word = idx_to_word(np.argmax(scores), tokenizer)
        # An index without a vocabulary word ends the generation.
        if next_word is None:
            break
        words.append(next_word)
        # The end marker is kept in the text, then generation stops.
        if next_word == 'endseq':
            break
    return " ".join(words)
def caption_prediction(img):
    """Gradio handler: produce a caption for an uploaded image.

    Args:
        img: Image data accepted by ``PIL.Image.fromarray``.

    Returns:
        The generated caption without the ``startseq`` / ``endseq`` markers
        and without surrounding whitespace.
    """
    image = Image.fromarray(img)
    # The feature extractor loads from a path, so the upload is saved first.
    img_path = save_image(image)
    features = extract_features(img_path)
    # NOTE(review): `caption_model` is not assigned anywhere in the visible
    # file (its loader is commented out), so this call raises NameError until
    # the model is loaded at module level.
    words = predict_caption(caption_model, features, tokenizer).split()
    # Drop the markers only when they are really there. Fixed-width slicing
    # would cut real caption text whenever generation stopped without
    # emitting 'endseq', and would leave stray spaces at both ends.
    if words and words[0] == 'startseq':
        words = words[1:]
    if words and words[-1] == 'endseq':
        words = words[:-1]
    return " ".join(words)
# Web UI: one image input, one text output, served by caption_prediction.
demo = gr.Interface(
    fn=caption_prediction,
    inputs='image',
    outputs='text',
    title='caption generator',
)
# Start the app at import time, with debug mode and a share link requested.
demo.launch(
    debug=True,
    share=True,
)