from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.applications.xception import Xception, preprocess_input
import pickle
import re
from tqdm import tqdm
import os
from PIL import Image
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Flatten, concatenate
import numpy as np
import gradio as gr
import tensorflow as tf

# Xception as the image encoder: drop the classification head and keep the
# feature vector from the second-to-last layer
model = Xception()
model = Model(inputs=model.inputs, outputs=model.layers[-2].output)

# Older .h5 saves serialize Embedding with an input_length argument that newer
# Keras versions no longer accept; this shim strips it during deserialization.
class MyEmbedding(tf.keras.layers.Embedding):
    def __init__(self, *args, input_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_length = input_length

    @classmethod
    def from_config(cls, config):
        input_length = config.pop('input_length', None)
        instance = cls(**config)
        instance.input_length = input_length
        return instance

# Load the trained caption model with the custom Embedding shim
caption_model = tf.keras.models.load_model('model.h5', custom_objects={'MyEmbedding': MyEmbedding})

# read the captions file, skipping the header line
with open('captions.txt', 'r') as f:
    next(f)
    captions_doc = f.read()

# create mapping of image ID to its captions
mapping = {}
for line in tqdm(captions_doc.split('\n')):
    # split the line by comma: "<image>.jpg,<caption...>"
    tokens = line.split(',')
    if len(tokens) < 2:
        continue
    image_id, caption = tokens[0], tokens[1:]
    # remove extension from image ID
    image_id = image_id.split('.')[0]
    # re-join the caption text (it may itself contain commas)
    caption = " ".join(caption)
    # create list if needed, then store the caption
    if image_id not in mapping:
        mapping[image_id] = []
    mapping[image_id].append(caption)

def clean(mapping):
    for key, captions in mapping.items():
        for i in range(len(captions)):
            # take one caption at a time
            caption = captions[i]
            # lowercase, delete digits and special chars, collapse extra spaces
            caption = caption.lower()
            caption = re.sub(r'[^a-z ]', '', caption)
            caption = re.sub(r'\s+', ' ', caption)
            # drop single-character tokens and add start/end tags
            caption = 'startseq ' + " ".join([word for word in caption.split() if len(word) > 1]) + ' endseq'
            captions[i] = caption

# clean the captions in place so the tokenizer sees the same text
# (including the startseq/endseq tags) that is used during generation
clean(mapping)

all_captions = []
for key in mapping:
    for caption in mapping[key]:
        all_captions.append(caption)

# tokenize the text
tokenizer = Tokenizer()
tokenizer.fit_on_texts(all_captions)
vocab_size = len(tokenizer.word_index) + 1

# get maximum length of the captions available
max_length = max(len(caption.split()) for caption in all_captions)

def extract_features(image):
    # load and resize to Xception's expected 299x299 input
    image = load_img(image, target_size=(299, 299))
    # convert image pixels to a numpy array and add a batch dimension
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)
    image = preprocess_input(image)
    feature = model.predict(image, verbose=0)
    return feature

def idx_to_word(integer, tokenizer):
    # reverse lookup: word index -> word
    for word, index in tokenizer.word_index.items():
        if index == integer:
            return word
    return None

def save_image(img, save_dir="saved_images"):
    # create the directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)
    # save the uploaded image to disk so extract_features can load it by path
    img_name = os.path.join(save_dir, "uploaded_image.png")
    img.save(img_name)
    return img_name
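# The layer imports above (Embedding, LSTM, Dense, Dropout, concatenate) are
# otherwise unused here and hint at how model.h5 was likely built. The function
# below is only a hypothetical sketch of such a merge-style encoder-decoder,
# assuming a 2048-d Xception feature input; it is never called, and the actual
# saved model may differ.
def build_caption_model(vocab_size, max_length):
    # image branch: Xception feature vector -> dense embedding
    inputs1 = tf.keras.Input(shape=(2048,))
    fe1 = Dropout(0.4)(inputs1)
    fe2 = Dense(256, activation='relu')(fe1)
    # text branch: padded word-index sequence -> LSTM encoding
    inputs2 = tf.keras.Input(shape=(max_length,))
    se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)
    se2 = Dropout(0.4)(se1)
    se3 = LSTM(256)(se2)
    # merge both branches and predict the next word over the vocabulary
    merged = concatenate([fe2, se3])
    decoder = Dense(256, activation='relu')(merged)
    outputs = Dense(vocab_size, activation='softmax')(decoder)
    m = Model(inputs=[inputs1, inputs2], outputs=outputs)
    m.compile(loss='categorical_crossentropy', optimizer='adam')
    return m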
# generate a caption for an image, one word at a time
def predict_caption(model, image, tokenizer, max_length=35):
    # start tag seeds the generation process
    in_text = 'startseq'
    # iterate up to the maximum sequence length
    for i in range(max_length):
        # encode and pad the input sequence
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        # predict the next word and take the highest-probability index
        yhat = model.predict([image, sequence], verbose=0)
        yhat = np.argmax(yhat)
        # convert index to word; stop if the word is not found
        word = idx_to_word(yhat, tokenizer)
        if word is None:
            break
        # append word to the input for generating the next word
        in_text += " " + word
        # stop if we reach the end tag
        if word == 'endseq':
            break
    return in_text

def caption_prediction(img):
    # Gradio hands in a numpy array; save it so extract_features can reload it
    image = Image.fromarray(img)
    img_path = save_image(image)
    features = extract_features(img_path)
    # generate, then strip the startseq/endseq tags from the output
    y_pred = predict_caption(caption_model, features, tokenizer, max_length)
    return y_pred.replace('startseq', '').replace('endseq', '').strip()

demo = gr.Interface(fn=caption_prediction, inputs='image', outputs='text', title='Caption Generator')
demo.launch(debug=True, share=True)
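# pickle is imported above but never used. In a deployment script like this, a
# typical use would be persisting the fitted tokenizer and max_length so the
# app can start without re-reading captions.txt. A minimal sketch follows (the
# 'tokenizer.pkl' path is our own choice, not part of the original script); it
# would have to run before demo.launch(), so it is left commented out here:
#
# with open('tokenizer.pkl', 'wb') as f:
#     pickle.dump({'tokenizer': tokenizer, 'max_length': max_length}, f)
#
# with open('tokenizer.pkl', 'rb') as f:
#     saved = pickle.load(f)
# tokenizer, max_length = saved['tokenizer'], saved['max_length']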