"""Gradio app: caption images with JoyCaption (LLaVA), extract keywords with
spaCy, append every result to a JSON file, and search captions/keywords."""

import json
import logging
import os
import shutil
import tempfile
import threading
import uuid
import zipfile
from datetime import datetime

import gradio as gr
import spacy
import torch
from PIL import Image
from spacy.cli import download
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define paths
OUTPUT_JSON_PATH = "captions.json"
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Model / generation / UI settings
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
SYSTEM_PROMPT = "You are a helpful assistant that generates accurate and relevant image captions."
CAPTION_STYLES = ["descriptive", "poetic", "humorous"]
DEFAULT_CAPTION_STYLE = "descriptive"
PROMPT_PLACEHOLDER = "e.g., 'Write a poetic caption'"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
MAX_NEW_TOKENS = 50
MAX_KEYWORDS = 5

# Gradio runs event handlers on worker threads, so single, batch and search
# requests can touch OUTPUT_JSON_PATH at the same time; this lock serialises
# the read-modify-write in save_to_json and the reads in search_images.
_json_lock = threading.Lock()

# Load SpaCy model for keyword extraction (downloading it on first run).
try:
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        logger.info("Downloading en_core_web_sm model...")
        download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
except Exception as e:
    logger.error(f"Error loading SpaCy model: {str(e)}")
    raise

# Load LLAVA model and processor
try:
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_PATH, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
    )
    model.eval()
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise


def extract_keywords(text, max_keywords=MAX_KEYWORDS):
    """Return up to ``max_keywords`` distinct lower-cased nouns/adjectives in ``text``.

    Stop words are skipped. Keywords keep their first-occurrence order so the
    result is deterministic (the previous ``set`` ordering was not).
    Returns ``[]`` if spaCy raises.
    """
    try:
        doc = nlp(text)
        keywords = [
            token.text.lower()
            for token in doc
            if token.pos_ in ("NOUN", "ADJ") and not token.is_stop
        ]
        # dict.fromkeys de-duplicates while preserving order.
        return list(dict.fromkeys(keywords))[:max_keywords]
    except Exception as e:
        logger.error(f"Error extracting keywords: {str(e)}")
        return []


def _load_records():
    """Return the list of records stored in OUTPUT_JSON_PATH.

    Returns ``[]`` when the file is missing, unreadable, or does not hold a
    JSON list. Callers should hold ``_json_lock``.
    """
    if not os.path.exists(OUTPUT_JSON_PATH):
        return []
    try:
        with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"Error reading JSON file: {str(e)}")
        return []
    if not isinstance(data, list):
        logger.error(f"Error reading JSON file: expected a list in {OUTPUT_JSON_PATH}")
        return []
    return data


def save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None):
    """Append one caption record to OUTPUT_JSON_PATH.

    The whole list is rewritten through a temporary file plus ``os.replace``,
    so a crash mid-write cannot leave a truncated file. A truncated file would
    otherwise be read back as ``[]`` and wipe the history on the next save.
    Write errors are logged, not raised.
    """
    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "custom_prompt": custom_prompt,
        "keywords": keywords,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }
    with _json_lock:
        data = _load_records()
        data.append(result)
        tmp_path = f"{OUTPUT_JSON_PATH}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=4)
            os.replace(tmp_path, OUTPUT_JSON_PATH)
            logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
        except Exception as e:
            logger.error(f"Error writing to JSON file: {str(e)}")


def _build_prompt(caption_type, custom_prompt):
    """Return the user's custom prompt if non-blank, else a style-based default."""
    custom = (custom_prompt or "").strip()
    if custom:
        return custom
    return f"Write a {caption_type or DEFAULT_CAPTION_STYLE} caption for this image."


def _generate_caption(image, caption_type, custom_prompt):
    """Caption a PIL image with the LLaVA model and return the stripped text.

    Raises whatever the processor/model raise; callers handle errors.
    """
    convo = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": _build_prompt(caption_type, custom_prompt)},
    ]
    # The chat template formats the system/user turns for this model and, per
    # the JoyCaption model card, inserts the image placeholder token. Passing
    # the bare prompt text leaves no slot for the image features.
    prompt_text = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        images=[image.convert("RGB")], text=[prompt_text], return_tensors="pt"
    ).to(model.device)
    # The weights are float16; cast the float32 pixel tensor to match.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,  # temperature/top_p are ignored without sampling
            temperature=0.7,
            top_p=0.9,
        )
    # generate() returns prompt + completion; decode only the new tokens.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def process_single_image(image, caption_type, custom_prompt):
    """Caption one uploaded PIL image, store it under UPLOAD_DIR, and record the result.

    Returns a display string with the caption and keywords, or an error message.
    """
    if image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, custom_prompt, [], error=error_msg)
        return error_msg

    image_name = os.path.join(UPLOAD_DIR, f"image_{uuid.uuid4().hex}.jpg")
    try:
        # JPEG has no alpha channel: saving an RGBA/P-mode upload as .jpg raises.
        image = image.convert("RGB")
        image.save(image_name)
        caption = _generate_caption(image, caption_type, custom_prompt)
        keywords = extract_keywords(caption)
        save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None)
        return f"Caption: {caption}\nKeywords: {', '.join(keywords)}"
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, custom_prompt, [], error=error_msg)
        return error_msg


def _process_batch_member(image_path, display_name, caption_type, custom_prompt):
    """Copy one extracted image into UPLOAD_DIR, caption it, record it, and return its result text."""
    # Keep the original extension so PNGs are not stored under a .jpg name.
    ext = os.path.splitext(image_path)[1].lower()
    image_name = os.path.join(UPLOAD_DIR, f"image_{uuid.uuid4().hex}{ext}")
    try:
        shutil.copy(image_path, image_name)
        with Image.open(image_path) as img:  # close the file handle promptly
            image = img.convert("RGB")
        caption = _generate_caption(image, caption_type, custom_prompt)
        keywords = extract_keywords(caption)
        save_to_json(image_name, caption, caption_type, custom_prompt, keywords, error=None)
        return f"Image: {image_name}\nCaption: {caption}\nKeywords: {', '.join(keywords)}"
    except Exception as e:
        error_msg = f"Error processing {display_name}: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, custom_prompt, [], error=error_msg)
        return error_msg


def process_batch_images(zip_file, caption_type, custom_prompt):
    """Caption every .jpg/.jpeg/.png inside an uploaded zip archive.

    Returns one result block per image joined by blank lines, or an error message.
    """
    if zip_file is None:
        return "Please upload a zip file."
    # gr.File yields a tempfile wrapper (Gradio 3) or a plain path string (Gradio 4+).
    zip_path = getattr(zip_file, "name", zip_file)
    results = []
    try:
        # A fresh directory per request: concurrent batches no longer share one
        # folder, and it is removed even when extraction or captioning fails.
        with tempfile.TemporaryDirectory(prefix="temp_upload_") as temp_dir:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(temp_dir)
            for root, _, files in os.walk(temp_dir):
                for file in sorted(files):
                    # Skip macOS "._*" resource-fork entries, which carry image
                    # extensions but are not images.
                    if file.startswith("._") or not file.lower().endswith(IMAGE_EXTENSIONS):
                        continue
                    image_path = os.path.join(root, file)
                    display_name = os.path.relpath(image_path, temp_dir)
                    results.append(
                        _process_batch_member(image_path, display_name, caption_type, custom_prompt)
                    )
        if not results:
            return "No images found in the zip file."
        return "\n\n".join(results)
    except Exception as e:
        error_msg = f"Error processing batch: {str(e)}"
        logger.error(error_msg)
        return error_msg


def search_images(query):
    """Return ``(image_path, caption_text)`` pairs whose caption or keywords contain ``query``.

    Matching is a case-insensitive substring test. Only records whose image
    file still exists are returned, because gr.Gallery must load every path.
    The result is always a list (possibly empty), since gr.Gallery cannot
    render a plain string.
    """
    needle = (query or "").lower()
    with _json_lock:
        data = _load_records()
    results = []
    for entry in data:
        image_name = entry.get("image_name")
        if not image_name or not os.path.isfile(image_name):
            continue
        caption = entry.get("caption") or ""
        keywords = entry.get("keywords") or []
        if needle in caption.lower() or any(needle in kw.lower() for kw in keywords):
            results.append((image_name, f"Caption: {caption}\nKeywords: {', '.join(keywords)}"))
    return results


def _search_for_ui(query):
    """Run search_images for the UI, returning (gallery_items, status_message)."""
    if not os.path.exists(OUTPUT_JSON_PATH):
        return [], "No captions available."
    try:
        results = search_images(query)
    except Exception as e:
        logger.error(f"Error searching images: {str(e)}")
        return [], f"Error searching images: {str(e)}"
    if not results:
        return [], "No matches found."
    return results, f"Found {len(results)} match(es)."


# Gradio interface: gr.Interface accepts a single function, so the three
# workflows are wired up as separate tabs in a Blocks layout.
with gr.Blocks(title="Image Captioning with LLAVA") as interface:
    gr.Markdown(
        "# Image Captioning with LLAVA\n"
        "Upload single or batch images, generate captions with custom styles, and search by "
        f"captions or keywords. Results are saved to {OUTPUT_JSON_PATH}."
    )
    with gr.Tab("Single Image"):
        single_image = gr.Image(label="Upload Single Image", type="pil")
        single_style = gr.Dropdown(choices=CAPTION_STYLES, label="Caption Style", value=DEFAULT_CAPTION_STYLE)
        single_prompt = gr.Textbox(label="Custom Prompt (optional)", placeholder=PROMPT_PLACEHOLDER)
        single_button = gr.Button("Generate Caption")
        single_result = gr.Textbox(label="Single Image Result")
        single_button.click(
            process_single_image,
            inputs=[single_image, single_style, single_prompt],
            outputs=single_result,
        )
    with gr.Tab("Batch (Zip)"):
        batch_zip = gr.File(label="Upload Zip File for Batch Processing", file_types=[".zip"])
        batch_style = gr.Dropdown(choices=CAPTION_STYLES, label="Caption Style", value=DEFAULT_CAPTION_STYLE)
        batch_prompt = gr.Textbox(label="Custom Prompt (optional)", placeholder=PROMPT_PLACEHOLDER)
        batch_button = gr.Button("Process Batch")
        batch_result = gr.Textbox(label="Batch Processing Results")
        batch_button.click(
            process_batch_images,
            inputs=[batch_zip, batch_style, batch_prompt],
            outputs=batch_result,
        )
    with gr.Tab("Search"):
        search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'beach'")
        search_button = gr.Button("Search")
        search_status = gr.Textbox(label="Search Status")
        search_gallery = gr.Gallery(label="Search Results")
        search_button.click(_search_for_ui, inputs=search_query, outputs=[search_gallery, search_status])
        search_query.submit(_search_for_ui, inputs=search_query, outputs=[search_gallery, search_status])

if __name__ == "__main__":
    interface.launch()