Spaces:
Running
Running
# py_backend/app/routers/caption.py | |
import asyncio
import logging
import os
from typing import List

from fastapi import APIRouter, HTTPException, Depends, Form, Request
from sqlalchemy.orm import Session

from .. import crud, database, schemas, storage
from ..services.vlm_service import vlm_manager
from ..services.schema_validator import schema_validator
from ..config import settings
# FastAPI router for the caption handlers defined in this module.
# NOTE(review): no @router.* decorators are visible on the handlers in this
# copy of the file — confirm each handler is actually registered on `router`.
router = APIRouter()
def get_db():
    """FastAPI dependency: yield a fresh database session and always close it.

    The session is created from ``database.SessionLocal`` and closed in a
    ``finally`` clause, so it is released even if the request handler raises.
    """
    session = database.SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Seconds to wait for the HTTP fallback download of an image. ``requests``
# applies no timeout by default, so without this a stalled server could hold
# the worker thread (and the request) indefinitely.
_IMAGE_FETCH_TIMEOUT_S = 30


def _read_image_bytes(file_key: str) -> bytes:
    """Return the stored bytes for ``file_key``.

    First tries a direct read. It reads from S3 when ``storage`` exposes an
    ``s3`` client and ``settings.STORAGE_PROVIDER`` is not ``"local"``.
    Otherwise it reads from ``settings.STORAGE_DIR`` on disk. If the direct
    read fails, it downloads the URL returned by ``storage.get_object_url``.
    A root-relative URL under local storage is prefixed with
    ``http://localhost:8000``.

    This performs blocking I/O; async callers should run it through
    ``asyncio.to_thread``.

    Raises:
        HTTPException: 500 when both the direct read and the URL fallback
            fail. The detail message carries the direct-read error.
    """
    log = logging.getLogger(__name__)
    try:
        if hasattr(storage, "s3") and settings.STORAGE_PROVIDER != "local":
            response = storage.s3.get_object(
                Bucket=settings.S3_BUCKET,
                Key=file_key,
            )
            return response["Body"].read()
        file_path = os.path.join(settings.STORAGE_DIR, file_key)
        with open(file_path, "rb") as f:
            return f.read()
    except Exception as read_error:
        log.warning("Error reading image file %r: %s", file_key, read_error)
        # Keep a reference: the `except` target name is unbound after the clause.
        primary_error = read_error

    # Fallback: fetch the object through its presigned/public URL.
    try:
        import requests  # lazy import, as before: only this fallback path needs it

        url = storage.get_object_url(file_key)
        if url.startswith("/") and settings.STORAGE_PROVIDER == "local":
            url = f"http://localhost:8000{url}"
        resp = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT_S)
        resp.raise_for_status()
        return resp.content
    except Exception as fallback_error:
        log.error("Fallback also failed: %s", fallback_error)
        raise HTTPException(
            500, f"Could not read image file: {primary_error}"
        ) from fallback_error


def _merge_raw_annotations(raw, annotations: dict):
    """Return ``raw`` with ``annotations`` merged in, without mutating ``raw``.

    If there is nothing to add, ``raw`` is returned unchanged. A dict is
    shallow-copied and extended. Any other value (e.g. a string or ``None``)
    is wrapped as ``{"raw_response": raw, **annotations}``. This means
    annotating a non-dict response cannot raise and throw away an otherwise
    usable caption.
    """
    if not annotations:
        return raw
    if isinstance(raw, dict):
        return {**raw, **annotations}
    return {"raw_response": raw, **annotations}


async def create_caption(
    image_id: str,
    title: str = Form(...),
    prompt: str | None = Form(None),  # optional; will use active prompts if not provided
    model_name: str | None = Form(None),
    db: Session = Depends(get_db),
):
    """Generate a caption for an image with the VLM service and store it.

    Args:
        image_id: id of the image to caption.
        title: caption title (required form field).
        prompt: prompt code or label (form field). It is looked up with
            ``crud.get_prompt``, then ``crud.get_prompt_by_label``. When
            omitted, the active prompt for the image's ``image_type`` is used.
        model_name: optional model name forwarded to ``vlm_manager``.
        db: request-scoped session from ``get_db``.

    Returns:
        ``schemas.CaptionOut`` for the newly created caption.

    Raises:
        HTTPException: 404 if the image does not exist, 400 if no prompt can
            be resolved, 500 if the image bytes cannot be read.

    VLM or validation failures do not raise. Instead a placeholder caption
    is stored with ``model_code="STUB_MODEL"`` and the error recorded in
    ``raw_json``.
    """
    log = logging.getLogger(__name__)
    log.debug(
        "Received request - image_id: %s, title: %s, prompt: %s, model_name: %s",
        image_id, title, prompt, model_name,
    )

    img = crud.get_image(db, image_id)
    if not img:
        raise HTTPException(404, "image not found")

    # Resolve the prompt: an explicit code/label if given, else the active
    # prompt for this image type.
    if prompt:
        log.debug("Looking for prompt: %r", prompt)
        prompt_obj = crud.get_prompt(db, prompt) or crud.get_prompt_by_label(db, prompt)
    else:
        log.debug("Looking for active prompt for image type: %s", img.image_type)
        prompt_obj = crud.get_active_prompt_by_image_type(db, img.image_type)
    log.debug("Prompt lookup result: %s", prompt_obj)
    if not prompt_obj:
        raise HTTPException(
            400,
            f"No prompt found (requested: '{prompt}' or active for type '{img.image_type}')",
        )

    # NOTE(review): the prompt's ``label`` is what gets sent to the VLM as the
    # prompt text — confirm that is the intended field.
    prompt_text = prompt_obj.label
    metadata_instructions = prompt_obj.metadata_instructions or ""
    log.debug("Using prompt text: %r", prompt_text)
    log.debug("Using metadata instructions: %r...", metadata_instructions[:100])

    # Storage reads and the HTTP fallback block; run them in a worker thread
    # so they don't stall the event loop. Only the plain file key crosses the
    # thread boundary, so the ORM object stays on this thread.
    img_bytes = await asyncio.to_thread(_read_image_bytes, img.file_key)

    try:
        log.debug("Calling VLM service with model_name: %s", model_name)
        result = await vlm_manager.generate_caption(
            image_bytes=img_bytes,
            prompt=prompt_text,
            metadata_instructions=metadata_instructions,
            model_name=model_name,
            db_session=db,
        )
        log.debug("VLM service result: %s", result)

        raw = result.get("raw_response", {})
        image_type = img.image_type
        log.debug(
            "Validating data for image type %s; raw keys: %s",
            image_type,
            list(raw.keys()) if isinstance(raw, dict) else "Not a dict",
        )
        cleaned_data, is_valid, validation_error = schema_validator.clean_and_validate_data(raw, image_type)

        # Extra keys to store alongside the raw response in raw_json.
        annotations = {}
        if is_valid:
            log.debug("Schema validation passed for %s", image_type)
            text = cleaned_data.get("analysis", "")
            metadata = cleaned_data.get("metadata", {})
        else:
            log.warning("Schema validation failed for %s: %s", image_type, validation_error)
            text = result.get("caption", "This is a fallback caption due to schema validation error.")
            metadata = result.get("metadata", {})
            annotations["validation_error"] = validation_error
            annotations["validation_failed"] = True

        used_model = result.get("model", model_name) or "STUB_MODEL"
        if used_model == "random":
            # A reported "random" model is stored as the stub marker instead.
            log.warning("VLM service returned 'random' as model name, using STUB_MODEL fallback")
            used_model = "STUB_MODEL"

        # Record which model actually answered when the service fell back.
        if result.get("fallback_used"):
            annotations["fallback_info"] = {
                "original_model": result.get("original_model"),
                "fallback_model": used_model,
                "reason": result.get("fallback_reason"),
            }

        raw = _merge_raw_annotations(raw, annotations)
    except Exception as e:
        # Deliberate best-effort: any VLM/validation failure still yields a
        # stored placeholder caption rather than a failed request.
        log.exception("VLM error, using fallback: %s", e)
        text = "This is a fallback caption due to VLM service error."
        used_model = "STUB_MODEL"
        raw = {"error": str(e), "fallback": True}
        metadata = {}

    caption = crud.create_caption(
        db,
        image_id=image_id,
        title=title,
        prompt=prompt_obj.p_code,
        model_code=used_model,
        raw_json=raw,
        text=text,
        metadata=metadata,
    )
    db.refresh(caption)
    log.debug("Caption created: caption_id=%s", caption.caption_id)
    return schemas.CaptionOut.from_orm(caption)
def get_all_captions_legacy_format(
    request: Request,
    db: Session = Depends(get_db),
):
    """Get all images with captions in the old format for backward compatibility.

    Emits one ``schemas.ImageOut`` per (caption, image) pair. Each image is
    converted with ``convert_image_to_dict``, with a file URL of the form
    ``{base_url}/api/images/{image_id}/file`` built from the request's base
    URL. The caption's fields are then overlaid on top, so caption values win
    on any shared key. Captions without images contribute nothing.
    """
    # Function-level import kept from the original (which ran it per image).
    # NOTE(review): presumably this avoids a circular import with .upload — verify.
    from .upload import convert_image_to_dict

    log = logging.getLogger(__name__)
    log.debug("Fetching all captions in legacy format...")
    captions = crud.get_all_captions_with_images(db)
    log.debug("Found %d captions", len(captions))

    # Identical for every image in this request, so compute it once.
    base_url = str(request.base_url).rstrip("/")

    result = []
    for caption in captions:
        db.refresh(caption)
        images = caption.images
        if not images:
            continue
        # Legacy shape: these caption fields are flattened onto each image dict.
        caption_fields = {
            "title": caption.title,
            "prompt": caption.prompt,
            "model": caption.model,
            "schema_id": caption.schema_id,
            "raw_json": caption.raw_json,
            "generated": caption.generated,
            "edited": caption.edited,
            "accuracy": caption.accuracy,
            "context": caption.context,
            "usability": caption.usability,
            "starred": caption.starred,
            "created_at": caption.created_at,
            "updated_at": caption.updated_at,
        }
        for image in images:
            url = f"{base_url}/api/images/{image.image_id}/file"
            img_dict = convert_image_to_dict(image, url)
            img_dict.update(caption_fields)
            result.append(schemas.ImageOut(**img_dict))

    log.debug("Returning %d legacy format results", len(result))
    return result
def get_all_captions_with_images(
    db: Session = Depends(get_db),
):
    """Return every caption as a ``schemas.CaptionOut``.

    Each caption row is refreshed from the database before serialization.
    """
    print("DEBUG: Fetching all captions...")
    rows = crud.get_all_captions_with_images(db)
    print(f"DEBUG: Found {len(rows)} captions")

    def _serialize(row):
        # Log, refresh, then convert one caption (same order as before).
        print(f"DEBUG: Processing caption {row.caption_id}, title: {row.title}, generated: {row.generated}, model: {row.model}")
        db.refresh(row)
        return schemas.CaptionOut.from_orm(row)

    formatted = [_serialize(row) for row in rows]
    print(f"DEBUG: Returning {len(formatted)} formatted results")
    return formatted
def get_captions_by_image(
    image_id: str,
    db: Session = Depends(get_db),
):
    """Return all captions attached to ``image_id`` as ``CaptionOut`` models.

    Each caption is refreshed from the database before it is serialized.
    """
    def _fresh(row):
        db.refresh(row)
        return schemas.CaptionOut.from_orm(row)

    return [_fresh(row) for row in crud.get_captions_by_image(db, image_id)]
def get_caption(
    caption_id: str,
    db: Session = Depends(get_db),
):
    """Return the caption with ``caption_id``, refreshed from the database.

    Raises:
        HTTPException: 404 when ``crud.get_caption`` finds nothing.
    """
    record = crud.get_caption(db, caption_id)
    if record:
        db.refresh(record)
        return schemas.CaptionOut.from_orm(record)
    raise HTTPException(404, "caption not found")
def update_caption(
    caption_id: str,
    update: schemas.CaptionUpdate,
    db: Session = Depends(get_db),
):
    """Apply ``update`` to the caption with ``caption_id`` and return it.

    Raises:
        HTTPException: 404 when ``crud.update_caption`` returns nothing.
    """
    updated = crud.update_caption(db, caption_id, update)
    if updated:
        db.refresh(updated)
        return schemas.CaptionOut.from_orm(updated)
    raise HTTPException(404, "caption not found")
def update_caption_by_image(
    image_id: str,
    update: schemas.CaptionUpdate,
    db: Session = Depends(get_db),
):
    """Update the first caption for an image (for backward compatibility).

    "First" is index 0 of ``img.captions``.
    NOTE(review): that order depends on the relationship definition — confirm
    it is deterministic.

    Raises:
        HTTPException: 404 if the image is missing, has no captions, or the
            caption update returns nothing.
    """
    image = crud.get_image(db, image_id)
    if not image:
        raise HTTPException(404, "image not found")

    existing = image.captions
    if not existing:
        raise HTTPException(404, "no captions found for this image")

    target_id = str(existing[0].caption_id)
    updated = crud.update_caption(db, target_id, update)
    if not updated:
        raise HTTPException(404, "caption not found")

    db.refresh(updated)
    return schemas.CaptionOut.from_orm(updated)
def delete_caption(
    caption_id: str,
    db: Session = Depends(get_db),
):
    """Delete the caption with ``caption_id`` and return a confirmation message.

    Raises:
        HTTPException: 404 when ``crud.delete_caption`` reports nothing was deleted.
    """
    deleted = crud.delete_caption(db, caption_id)
    if deleted:
        return {"message": "Caption deleted successfully"}
    raise HTTPException(404, "caption not found")