File size: 4,891 Bytes
07a6bfd f5dc979 b031819 f5dc979 07a6bfd b031819 07a6bfd f5dc979 b031819 f5dc979 ead87b5 b031819 f5dc979 07a6bfd f5dc979 07a6bfd ead87b5 f5dc979 07a6bfd f5dc979 f40c282 f5dc979 ead87b5 f5dc979 ead87b5 f5dc979 07a6bfd f5dc979 8f189bc 07a6bfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import json
import logging
import os
import shutil
import uuid
from datetime import datetime

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
# JSON log that every captioning attempt (success or failure) is appended to.
# Relative path, so it resolves against the process's working directory.
OUTPUT_JSON_PATH = "captions.json"

# Configure the root handler at INFO so this module's progress messages appear.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Hugging Face repo id of the JoyCaption checkpoint (loaded as a LLaVA model).
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Wiping this model's cache entry before downloading is a workaround for
# corrupted files. It forces a full re-download on every start, so it can be
# disabled with JOYCAPTION_CLEAR_CACHE=0 (enabled by default).
CLEAR_MODEL_CACHE = os.environ.get("JOYCAPTION_CLEAR_CACHE", "1") != "0"
try:
    # Use the hub cache location huggingface_hub itself resolves (it honours
    # HF_HOME / HF_HUB_CACHE). A hard-coded ~/.cache/huggingface/hub can point
    # at a directory the download never uses.
    cache_dir = HF_HUB_CACHE
    model_cache = os.path.join(cache_dir, f"models--{MODEL_PATH.replace('/', '--')}")
    if CLEAR_MODEL_CACHE and os.path.exists(model_cache):
        shutil.rmtree(model_cache)
        logger.info(f"Cleared cache for {MODEL_PATH}")
    # Pre-download the snapshot. Loading from the returned local directory
    # guarantees from_pretrained reads exactly these files.
    model_dir = snapshot_download(repo_id=MODEL_PATH)
    logger.info(f"Downloaded model {MODEL_PATH}")
    processor = AutoProcessor.from_pretrained(model_dir)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype=torch.float32,  # CPU-compatible dtype
        low_cpu_mem_usage=True,  # Minimize memory usage while loading
        use_safetensors=True,  # Force safetensors weights
    ).to("cpu")
    model.eval()
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    # The app cannot serve requests without the model: log with the traceback,
    # then re-raise so startup fails loudly.
    logger.exception(f"Error loading model: {str(e)}")
    raise
def save_to_json(image_name, caption, caption_type, caption_length, error=None):
    """Append one captioning record to the JSON array stored at ``OUTPUT_JSON_PATH``.

    The whole file is read, the record is appended, and the file is rewritten.
    The function is best-effort: read/write problems are logged, not raised, so
    a logging failure never breaks caption generation.

    Args:
        image_name: Identifier of the image the record refers to.
        caption: Caption text (may be empty or an error message on failure).
        caption_type: Caption style the user requested.
        caption_length: Caption length the user requested.
        error: Error message for a failed request, or ``None`` on success.
    """
    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "caption_length": caption_length,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }
    data = []
    if os.path.exists(OUTPUT_JSON_PATH):
        try:
            with open(OUTPUT_JSON_PATH, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise ValueError(f"expected a JSON array, found {type(data).__name__}")
        except (OSError, ValueError) as e:
            logger.error(f"Error reading JSON file: {str(e)}")
            # Move the unreadable file aside instead of overwriting it with a
            # one-record list, so the earlier records are not destroyed.
            backup_path = f"{OUTPUT_JSON_PATH}.{datetime.now().strftime('%Y%m%d-%H%M%S')}.bak"
            try:
                os.replace(OUTPUT_JSON_PATH, backup_path)
                logger.warning(f"Moved unreadable {OUTPUT_JSON_PATH} to {backup_path}")
            except OSError as backup_err:
                # The old file could not be preserved: skip this write rather
                # than clobber it.
                logger.error(f"Could not back up {OUTPUT_JSON_PATH}: {backup_err}")
                return
            data = []
    data.append(result)
    # Write a sibling temp file, then swap it in with os.replace (atomic on
    # POSIX). A crash mid-write then leaves the previous log intact instead of
    # a truncated file.
    tmp_path = f"{OUTPUT_JSON_PATH}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, OUTPUT_JSON_PATH)
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Error writing to JSON file: {str(e)}")
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium") -> str:
    """Caption one image with the JoyCaption model and record the outcome.

    Args:
        input_image: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submitted without uploading anything.
        caption_type: Style word spliced into the prompt (the UI offers
            "descriptive", "casual", "social media").
        caption_length: Length word spliced into the prompt (the UI offers
            "short", "medium", "long").

    Returns:
        The generated caption, or a human-readable error message. Errors raised
        while captioning are caught and returned as that message. Every outcome
        is also passed to ``save_to_json``.
    """
    if input_image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    # Unique identifier for the log entry only; the image is never written to disk.
    image_name = f"image_{uuid.uuid4().hex}.jpg"
    try:
        # Uploads may be RGBA/LA/P (e.g. PNGs); normalize to 3-channel RGB first.
        # NOTE(review): the fixed 256x256 resize ignores aspect ratio, and the
        # processor presumably rescales to its own input size anyway; confirm
        # this pre-resize is still wanted.
        input_image = input_image.convert("RGB").resize((256, 256))
        prompt = f"Write a {caption_length} {caption_type} caption for this image."
        convo = [
            {
                "role": "system",
                "content": "You are a helpful assistant that generates accurate and relevant image captions."
            },
            {
                "role": "user",
                "content": prompt.strip()
            }
        ]
        # Render the conversation with the model's chat template, which is the
        # usage the model card shows. Passing only the bare prompt text drops the
        # system turn, the chat formatting and (presumably, via the template) the
        # image placeholder token the LLaVA processor needs to splice in image features.
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(images=[input_image], text=[convo_string], return_tensors="pt").to("cpu")
        with torch.no_grad():
            # temperature/top_p only apply in sampling mode; with do_sample=False
            # generate() falls back to greedy decoding and ignores them.
            output = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
        # generate() returns the prompt tokens followed by the new ones. Decode
        # only the continuation so the prompt text does not leak into the caption.
        prompt_length = inputs["input_ids"].shape[1]
        caption = processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
        save_to_json(image_name, caption, caption_type, caption_length, error=None)
        return caption
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        # logger.exception also records the traceback for debugging.
        logger.exception(error_msg)
        save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
        return error_msg
# Gradio UI. The three inputs map positionally onto
# generate_caption(input_image, caption_type, caption_length).
interface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive"),
        gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    title="Image Captioning with JoyCaption",
    # Built from the module constants so the text cannot drift from the model
    # and log file actually used (renders identically with the current values).
    description=(
        f"Upload an image to generate a caption using the {MODEL_PATH} model. "
        f"Results are saved to {OUTPUT_JSON_PATH}."
    ),
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    interface.launch()