|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from transformers import LlavaForConditionalGeneration, AutoProcessor |
|
import logging |
|
import json |
|
import os |
|
from datetime import datetime |
|
import uuid |
|
from huggingface_hub import snapshot_download |
|
import shutil |
|
|
|
|
|
# App-wide logging: INFO and above goes to the root handler; this module
# writes through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# JSON array file that save_to_json appends one record to per caption request.
OUTPUT_JSON_PATH = "captions.json"
|
|
|
|
|
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"

# The local model cache is wiped on every start by default, forcing a fresh
# download. Set JOYCAPTION_CLEAR_CACHE=0 to reuse an existing download instead.
CLEAR_MODEL_CACHE = os.environ.get("JOYCAPTION_CLEAR_CACHE", "1") != "0"


def _hf_hub_cache_dir():
    """Return the Hugging Face hub cache directory.

    Resolution order: HF_HUB_CACHE, legacy HUGGINGFACE_HUB_CACHE,
    $HF_HOME/hub, $XDG_CACHE_HOME/huggingface/hub, ~/.cache/huggingface/hub.
    NOTE(review): this mirrors huggingface_hub.constants; confirm against the
    installed huggingface_hub version if its defaults change.
    """
    explicit = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return os.path.expanduser(explicit)
    hf_home = os.environ.get("HF_HOME")
    if not hf_home:
        hf_home = os.path.join(os.environ.get("XDG_CACHE_HOME") or "~/.cache", "huggingface")
    return os.path.join(os.path.expanduser(hf_home), "hub")


def _clear_model_cache(model_path):
    """Delete the cached snapshot directory for *model_path*, if present."""
    model_cache = os.path.join(_hf_hub_cache_dir(), f"models--{model_path.replace('/', '--')}")
    if os.path.exists(model_cache):
        shutil.rmtree(model_cache)
        logger.info(f"Cleared cache for {model_path}")


def _load_model(model_path=MODEL_PATH, clear_cache=CLEAR_MODEL_CACHE):
    """Download *model_path* and load its processor and model on CPU in float32.

    Returns:
        (processor, model) with the model switched to eval mode.
    """
    if clear_cache:
        _clear_model_cache(model_path)

    # Load from the exact snapshot that was just downloaded rather than
    # resolving the repo id against the hub a second time.
    local_dir = snapshot_download(repo_id=model_path)
    logger.info(f"Downloaded model {model_path}")

    loaded_processor = AutoProcessor.from_pretrained(local_dir)
    loaded_model = LlavaForConditionalGeneration.from_pretrained(
        local_dir,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to("cpu")
    loaded_model.eval()
    return loaded_processor, loaded_model


try:
    processor, model = _load_model()
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    # Keep the traceback in the log; the app cannot run without the model.
    logger.exception(f"Error loading model: {str(e)}")
    raise
|
|
|
|
|
def save_to_json(image_name, caption, caption_type, caption_length, error=None, path=None):
    """Append one caption record to the JSON log file.

    The log file holds a JSON array of records. Problems are logged and never
    raised, so a logging failure cannot hide the caption shown to the user.

    Args:
        image_name: Identifier recorded for the image.
        caption: Caption text (may be empty or an error message on failure).
        caption_type: Requested caption style.
        caption_length: Requested caption length.
        error: Error message, or None on success.
        path: File to append to; defaults to OUTPUT_JSON_PATH.
    """
    if path is None:
        path = OUTPUT_JSON_PATH

    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "caption_length": caption_length,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }

    data = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if not isinstance(loaded, list):
                raise ValueError(f"expected a JSON array, got {type(loaded).__name__}")
            data = loaded
        except (OSError, ValueError) as e:
            logger.error(f"Error reading JSON file: {str(e)}")
            # Starting a fresh list would overwrite (and lose) every earlier
            # record, so move the unreadable file aside before writing.
            backup = f"{path}.corrupt-{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
            try:
                os.replace(path, backup)
            except OSError as move_err:
                logger.error(f"Could not back up unreadable JSON file, result not saved: {move_err}")
                return
            logger.warning(f"Moved unreadable JSON file to {backup}")

    data.append(result)

    # Write a sibling temp file and atomically rename it over the target so an
    # interrupted write cannot leave a truncated, unparseable log behind.
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, path)
        logger.info(f"Saved result to {path}")
    except Exception as e:
        logger.error(f"Error writing to JSON file: {str(e)}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
|
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium") -> str:
    """Caption *input_image* with the JoyCaption model and log the result.

    Args:
        input_image: Uploaded PIL image, or None when nothing was uploaded.
        caption_type: Caption style inserted into the prompt.
        caption_length: Caption length word inserted into the prompt.

    Returns:
        The generated caption. On failure, returns an error message instead
        of raising, so the Gradio textbox displays it. Every outcome is
        recorded through save_to_json.
    """
    if input_image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg

    # The upload itself is not stored; this name only labels the log record.
    image_name = f"image_{uuid.uuid4().hex}.jpg"

    try:
        # The vision encoder expects 3-channel input, and uploads may be RGBA,
        # palette, or greyscale. No manual resize: the processor resizes to the
        # model's own input size, and a prior square resize only distorted the
        # aspect ratio and discarded detail.
        image = input_image.convert("RGB")

        prompt = f"Write a {caption_length} {caption_type} caption for this image."
        convo = [
            {
                "role": "system",
                "content": "You are a helpful assistant that generates accurate and relevant image captions.",
            },
            {
                "role": "user",
                "content": prompt.strip(),
            },
        ]

        # Render the conversation through the model's chat template so the
        # system turn, image placeholder and assistant prefix are present (this
        # is the usage shown on the JoyCaption model card).
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(images=image, text=convo_string, return_tensors="pt").to("cpu")

        with torch.no_grad():
            # do_sample=True is required for temperature/top_p to take effect.
            # NOTE(review): 50 new tokens may truncate "long" captions.
            output = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)

        # generate() returns prompt + completion; keep only the new tokens so
        # the prompt text is not echoed into the caption.
        prompt_len = inputs["input_ids"].shape[1]
        caption = processor.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

        save_to_json(image_name, caption, caption_type, caption_length, error=None)
        return caption
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.exception(error_msg)
        save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
        return error_msg
|
|
|
|
|
# Gradio UI: image plus two option dropdowns in, caption text out.
interface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive"),
        gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    title="Image Captioning with JoyCaption",
    # Built from the module constants so the text cannot drift from the
    # model actually loaded or the file actually written.
    description=(
        f"Upload an image to generate a caption using the {MODEL_PATH} model. "
        f"Results are saved to {OUTPUT_JSON_PATH}."
    ),
)


if __name__ == "__main__":
    interface.launch()