""" | |
Gradio interface for the LLaVA model. | |
""" | |
import sys
import traceback

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from ..configs.settings import (
    GRADIO_THEME,
    GRADIO_TITLE,
    GRADIO_DESCRIPTION,
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    API_HOST,
    API_PORT,
    API_WORKERS,
    API_RELOAD,
)
from ..models.llava_model import LLaVAModel
from ..utils.logging import setup_logging, get_logger
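# For reference, ..configs.settings is assumed to export constants along these
# lines (illustrative placeholder values, not necessarily the project's actual ones):
#
#     GRADIO_THEME = "default"
#     GRADIO_TITLE = "LLaVA Demo"
#     GRADIO_DESCRIPTION = "Ask questions about an uploaded image."
#     DEFAULT_MAX_NEW_TOKENS = 256
#     DEFAULT_TEMPERATURE = 0.7
#     DEFAULT_TOP_P = 0.9
#     API_HOST, API_PORT = "0.0.0.0", 7860
#     API_WORKERS, API_RELOAD = 1, False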
# Set up logging
setup_logging()
logger = get_logger(__name__)

# Initialize FastAPI app
app = FastAPI(title="LLaVA Web Interface")
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
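# Note: allow_origins=["*"] is fine for a public demo, but in production you
# would typically restrict origins to the domains that actually serve the frontend.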
# Initialize model
model = None

def initialize_model():
    global model
    try:
        logger.info("Initializing LLaVA model...")
        # Use smaller model variants and enable memory optimizations
        model = LLaVAModel(
            vision_model_path="openai/clip-vit-base-patch32",  # Smaller vision model
            language_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # Smaller language model
            device="cpu",  # Force CPU for Hugging Face Spaces
            projection_hidden_dim=2048  # Reduce projection layer size
        )
        # Enable memory optimizations
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear any cached memory (no-op on CPU-only hardware)
        if hasattr(model, 'language_model'):
            model.language_model.config.use_cache = False  # Disable KV cache
        logger.info(f"Model initialized on {model.device}")
        return True
    except Exception as e:
        error_msg = f"Error initializing model: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        print(error_msg, file=sys.stderr)
        return False
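# Note: when this module is served by an external ASGI server (e.g.
# `uvicorn ...:app`), main() below never runs, so the model would stay
# uninitialized. A startup hook is one way to cover that path (a sketch;
# the hook name is illustrative, and newer FastAPI versions prefer lifespan
# handlers over on_event):
@app.on_event("startup")
async def _ensure_model_loaded() -> None:
    # Initialize lazily if main() did not run.
    if model is None:
        initialize_model()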
def process_image(
    image: Image.Image,
    prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P
) -> str:
    """
    Process an image with the LLaVA model.

    Args:
        image: Input image
        prompt: Text prompt
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Top-p sampling parameter

    Returns:
        str: Model response
    """
    if not model:
        error_msg = "Error: Model not initialized"
        logger.error(error_msg)
        return error_msg
    if image is None:
        error_msg = "Error: No image provided"
        logger.error(error_msg)
        return error_msg
    if not prompt or not prompt.strip():
        error_msg = "Error: No prompt provided"
        logger.error(error_msg)
        return error_msg

    try:
        logger.info(f"Processing image with prompt: {prompt[:100]}...")
        # Clear memory before processing (no-op on CPU-only hardware)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Generate response; inference_mode is more memory efficient than no_grad
        with torch.inference_mode():
            response = model(
                image=image,
                prompt=prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p
            )
        logger.info("Successfully generated response")
        return response
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        print(error_msg, file=sys.stderr)
        return f"Error: {str(e)}"
    finally:
        # Clear memory after processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def create_interface() -> gr.Blocks:
    """Create and return the Gradio interface."""
    with gr.Blocks(theme=GRADIO_THEME) as interface:
        gr.Markdown(f"""# {GRADIO_TITLE}

{GRADIO_DESCRIPTION}

## Example Prompts
Try these prompts to get started:
- "What can you see in this image?"
- "Describe this scene in detail."
- "What emotions does this image convey?"
- "What's happening in this picture?"
- "Can you identify any objects or people in this image?"

## Usage Instructions
1. Upload an image using the image uploader.
2. Enter your prompt in the text box.
3. (Optional) Adjust the generation parameters.
4. Click "Generate Response" to get LLaVA's analysis.
""")
        with gr.Row():
            with gr.Column():
                # Input components
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="What can you see in this image?",
                    lines=3
                )
                with gr.Accordion("Generation Parameters", open=False):
                    max_tokens = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=64,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=DEFAULT_TEMPERATURE,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=DEFAULT_TOP_P,
                        step=0.1,
                        label="Top P"
                    )
                generate_btn = gr.Button("Generate Response", variant="primary")
            with gr.Column():
                # Output component
                output = gr.Textbox(
                    label="Response",
                    lines=10,
                    show_copy_button=True
                )
        # Set up event handlers
        generate_btn.click(
            fn=process_image,
            inputs=[
                image_input,
                prompt_input,
                max_tokens,
                temperature,
                top_p
            ],
            outputs=output,
            api_name="process_image"
        )
    return interface
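# The Blocks app can also be launched standalone, without FastAPI, e.g. for
# quick local testing (a sketch; server/share options depend on your setup):
#
#     create_interface().launch()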
# Create Gradio app
demo = create_interface()

# Mount Gradio app at the root path
app = gr.mount_gradio_app(app, demo, path="/")
def main():
    """Initialize the model and run the FastAPI application."""
    import uvicorn

    # Initialize model
    if not initialize_model():
        logger.error("Failed to initialize model. Exiting...")
        sys.exit(1)

    # Start the server. Note: uvicorn ignores `workers` and `reload` when given
    # an app instance; pass an import string (e.g. "package.module:app") if you
    # need either of those options.
    uvicorn.run(
        app,
        host=API_HOST,
        port=API_PORT,
        workers=API_WORKERS,
        reload=API_RELOAD,
        log_level="info"
    )

if __name__ == "__main__":
    main()