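"""Gradio app that captions uploaded images with the JoyCaption LLaVA model on CPU.

Every result (or error) is appended as a JSON record to captions.json.
"""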
import gradio as gr
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
import logging
import json
import os
from datetime import datetime
import uuid
from huggingface_hub import snapshot_download
import shutil

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define output JSON file path
OUTPUT_JSON_PATH = "captions.json"

# Clear the Hugging Face cache, then download and load the model.
# Note: wiping the cache forces a full re-download on every startup; it is a
# blunt safeguard against partially downloaded or corrupted files.
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
try:
    # Clear cache to avoid corrupted files
    cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
    model_cache = os.path.join(cache_dir, f"models--{MODEL_PATH.replace('/', '--')}")
    if os.path.exists(model_cache):
        shutil.rmtree(model_cache)
        logger.info(f"Cleared cache for {MODEL_PATH}")

    # Pre-download model to ensure integrity
    snapshot_download(repo_id=MODEL_PATH)
    logger.info(f"Downloaded model {MODEL_PATH}")

    # Load processor and model
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,  # CPU-compatible dtype
        low_cpu_mem_usage=True,    # Minimize memory usage
        use_safetensors=True       # Force safetensors
    ).to("cpu")
    model.eval()
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise

# Function to append results to JSON file
def save_to_json(image_name, caption, caption_type, caption_length, error=None):
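    """Append a single captioning result (or error) to OUTPUT_JSON_PATH."""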
    result = {
        "image_name": image_name,
        "caption": caption,
        "caption_type": caption_type,
        "caption_length": caption_length,
        "timestamp": datetime.now().isoformat(),
        "error": error
    }
    try:
        if os.path.exists(OUTPUT_JSON_PATH):
            with open(OUTPUT_JSON_PATH, "r") as f:
                data = json.load(f)
        else:
            data = []
    except Exception as e:
        logger.error(f"Error reading JSON file: {str(e)}")
        data = []
    
    data.append(result)
    try:
        with open(OUTPUT_JSON_PATH, "w") as f:
            json.dump(data, f, indent=4)
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
    except Exception as e:
        logger.error(f"Error writing to JSON file: {str(e)}")

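# A record appended by save_to_json looks like this (values are illustrative):
# {
#     "image_name": "image_3f2a9c...jpg",
#     "caption": "A dog running along a beach at sunset.",
#     "caption_type": "descriptive",
#     "caption_length": "medium",
#     "timestamp": "2025-01-01T12:00:00.000000",
#     "error": null
# }
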
# Define the captioning function
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium") -> str:
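    """Caption an uploaded image with the loaded model and log the result to JSON."""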
    if input_image is None:
        error_msg = "Please upload an image."
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    
    # Generate a unique image name
    image_name = f"image_{uuid.uuid4().hex}.jpg"
    
    try:
        # Convert to RGB (uploads may be RGBA) and downscale to reduce memory
        # usage on CPU; note a fixed 256x256 resize distorts the aspect ratio
        input_image = input_image.convert("RGB").resize((256, 256))
        
        # Prepare the prompt
        prompt = f"Write a {caption_length} {caption_type} caption for this image."
        convo = [
            {
                "role": "system",
                "content": "You are a helpful assistant that generates accurate and relevant image captions."
            },
            {
                "role": "user",
                "content": prompt.strip()
            }
        ]
        
        # Render the conversation with the chat template so the processor can
        # insert the image tokens the LLaVA model expects
        convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[convo_string], images=[input_image], return_tensors="pt").to("cpu")
        
        # Generate with few max tokens for CPU speed; do_sample=True is
        # required for temperature and top_p to take effect
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
        
        # Decode only the newly generated tokens; output[0] also contains the
        # prompt, which would otherwise be echoed into the caption
        generated_ids = output[0][inputs["input_ids"].shape[1]:]
        caption = processor.decode(generated_ids, skip_special_tokens=True).strip()
        
        # Save to JSON
        save_to_json(image_name, caption, caption_type, caption_length, error=None)
        return caption
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
        return error_msg

# Create the Gradio interface
interface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive"),
        gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    title="Image Captioning with JoyCaption",
    description="Upload an image to generate a caption using the fancyfeast/llama-joycaption-beta-one-hf-llava model. Results are saved to captions.json."
)

if __name__ == "__main__":
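    # launch() serves locally by default; pass share=True for a temporary public URL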
    interface.launch()