"""
Gradio interface for the LLaVA model.
"""

import sys
import traceback

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

from ..configs.settings import (
    GRADIO_THEME,
    GRADIO_TITLE,
    GRADIO_DESCRIPTION,
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    API_HOST,
    API_PORT,
    API_WORKERS,
    API_RELOAD
)
from ..models.llava_model import LLaVAModel
from ..utils.logging import setup_logging, get_logger

# Set up logging
setup_logging()
logger = get_logger(__name__)

# Initialize FastAPI app
app = FastAPI(title="LLaVA Web Interface")

# Configure CORS (wide-open defaults; restrict allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model handle; populated by initialize_model()
model = None

def initialize_model() -> bool:
    """Load the model into the global handle. Returns True on success."""
    global model
    try:
        logger.info("Initializing LLaVA model...")
        # Use smaller model variants and enable memory optimizations
        model = LLaVAModel(
            vision_model_path="openai/clip-vit-base-patch32",  # Smaller vision model
            language_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # Smaller language model
            device="cpu",  # Force CPU for Hugging Face Spaces
            projection_hidden_dim=2048  # Reduce projection layer size
        )

        # Free any cached GPU memory (no-op on CPU-only hosts)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if hasattr(model, 'language_model'):
            # Disabling the KV cache trades slower generation for lower memory use
            model.language_model.config.use_cache = False
        
        logger.info(f"Model initialized on {model.device}")
        return True
    except Exception as e:
        error_msg = f"Error initializing model: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        print(error_msg, file=sys.stderr)
        return False

def process_image(
    image: Image.Image,
    prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P
) -> str:
    """
    Process an image with the LLaVA model.
    
    Args:
        image: Input image
        prompt: Text prompt
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        
    Returns:
        str: Model response
    """
    if model is None:
        error_msg = "Error: Model not initialized"
        logger.error(error_msg)
        return error_msg
    
    if image is None:
        error_msg = "Error: No image provided"
        logger.error(error_msg)
        return error_msg
    
    if not prompt or not prompt.strip():
        error_msg = "Error: No prompt provided"
        logger.error(error_msg)
        return error_msg
    
    try:
        logger.info(f"Processing image with prompt: {prompt[:100]}...")

        # Free any cached GPU memory before generating (no-op on CPU)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Generate response with reduced memory usage
        with torch.inference_mode():  # More memory efficient than no_grad
            response = model(
                image=image,
                prompt=prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p
            )

        logger.info("Successfully generated response")
        return response

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        print(error_msg, file=sys.stderr)
        return f"Error: {str(e)}"
    
    finally:
        # Free any cached GPU memory after processing (no-op on CPU)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def create_interface() -> gr.Blocks:
    """Create and return the Gradio interface."""
    with gr.Blocks(theme=GRADIO_THEME) as interface:
        gr.Markdown(f"""# {GRADIO_TITLE}

{GRADIO_DESCRIPTION}

## Example Prompts

Try these prompts to get started:
- "What can you see in this image?"
- "Describe this scene in detail"
- "What emotions does this image convey?"
- "What's happening in this picture?"
- "Can you identify any objects or people in this image?"

## Usage Instructions

1. Upload an image using the image uploader
2. Enter your prompt in the text box
3. (Optional) Adjust the generation parameters
4. Click "Generate Response" to get LLaVA's analysis
""")
        
        with gr.Row():
            with gr.Column():
                # Input components
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="What can you see in this image?",
                    lines=3
                )
                
                with gr.Accordion("Generation Parameters", open=False):
                    max_tokens = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=64,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=DEFAULT_TEMPERATURE,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=DEFAULT_TOP_P,
                        step=0.1,
                        label="Top P"
                    )
                
                generate_btn = gr.Button("Generate Response", variant="primary")
            
            with gr.Column():
                # Output component
                output = gr.Textbox(
                    label="Response",
                    lines=10,
                    show_copy_button=True
                )
        
        # Wire the button to the handler; api_name also exposes it as an API endpoint
        generate_btn.click(
            fn=process_image,
            inputs=[
                image_input,
                prompt_input,
                max_tokens,
                temperature,
                top_p
            ],
            outputs=output,
            api_name="process_image"
        )
    
    return interface
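
# The `api_name` on the click handler above also exposes process_image through
# Gradio's client API. A minimal sketch of a remote call, assuming the server
# is reachable at http://localhost:8000/ and that "cat.png" exists (both
# hypothetical):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:8000/")
#   result = client.predict(
#       handle_file("cat.png"),              # image_input
#       "What can you see in this image?",   # prompt_input
#       256, 0.7, 0.9,                       # max_tokens, temperature, top_p
#       api_name="/process_image",
#   )
#   print(result)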

# Create Gradio app
demo = create_interface()

# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/")
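
# When the module is served directly (e.g. `uvicorn app:app` on Hugging Face
# Spaces), main() never runs and the model would stay uninitialized. This
# startup hook covers that path; a minimal sketch, assuming eager loading at
# boot is acceptable for this deployment.
@app.on_event("startup")
async def _ensure_model_loaded() -> None:
    if model is None and not initialize_model():
        logger.error("Model failed to initialize at startup")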

def main():
    """Run the FastAPI application."""
    import uvicorn
    
    # Initialize model
    if not initialize_model():
        logger.error("Failed to initialize model. Exiting...")
        sys.exit(1)
    
    # Start the server. Note: uvicorn ignores `reload` and `workers` when given
    # an app object rather than an import string ("package.module:app").
    uvicorn.run(
        app,
        host=API_HOST,
        port=API_PORT,
        workers=API_WORKERS,
        reload=API_RELOAD,
        log_level="info"
    )

if __name__ == "__main__":
    main()