Spaces:
Running
Running
File size: 10,487 Bytes
3cf9fa0 401b092 09ecaf7 1686de5 3cf9fa0 09ecaf7 d7291ef ba5edb0 ab3b988 d7291ef 09ecaf7 d7291ef f503159 d7291ef 3cf9fa0 1686de5 d7291ef d25db6b 5778774 fe5d98f d25db6b 5778774 3cf9fa0 4f6cbcc 3cf9fa0 4f6cbcc 3cf9fa0 5778774 fe5d98f 3cf9fa0 5778774 3cf9fa0 d25db6b 351d460 f359373 779c5c3 5778774 3cf9fa0 5778774 d7291ef 5778774 fe5d98f d7291ef ba5edb0 d7291ef 5778774 351d460 5778774 351d460 3cf9fa0 5778774 3cf9fa0 5778774 3cf9fa0 5778774 d25db6b 5778774 f503159 d7291ef fe5d98f d7291ef 65933cd f503159 1686de5 f503159 65933cd 1686de5 f503159 401b092 1686de5 f503159 65933cd c57d64b 3cf9fa0 f503159 65933cd 1686de5 f503159 1686de5 f503159 1686de5 f503159 65933cd f503159 65933cd f503159 65933cd 23d1df7 f503159 d7291ef f503159 d7291ef f503159 d7291ef f503159 1686de5 23d1df7 65933cd f503159 3cf9fa0 f503159 1686de5 f503159 1686de5 f503159 1686de5 f503159 1686de5 |
|
# py_backend/app/routers/caption.py
from fastapi import APIRouter, HTTPException, Depends, Form, Request
from sqlalchemy.orm import Session
from typing import List
from .. import crud, database, schemas, storage
from ..services.vlm_service import vlm_manager
from ..services.schema_validator import schema_validator
from ..config import settings
router = APIRouter()
def get_db():
    """Yield a database session and close it when the consumer is done.

    Intended for use as a FastAPI dependency (``Depends(get_db)``), which is
    how every route in this module receives its ``db`` argument. The session
    is created from ``database.SessionLocal`` and is always closed, whether
    the consumer finished normally or raised.
    """
    session = database.SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal completion and on errors alike.
        session.close()
# Upper bound (seconds) for the HTTP fallback download so that a stalled
# connection cannot hang the request indefinitely.
_IMAGE_FETCH_TIMEOUT_SECONDS = 30


def _load_image_bytes(file_key):
    """Return the stored bytes for ``file_key`` (S3 or local disk).

    The primary read uses ``storage.s3`` when it exists and the configured
    provider is not ``"local"``; otherwise the file is read from
    ``settings.STORAGE_DIR``. If that fails for any reason, the object URL
    returned by ``storage.get_object_url`` is downloaded over HTTP instead.

    Raises:
        HTTPException: 500 when both the primary read and the HTTP fallback
            fail. The detail carries the primary error; the fallback error is
            attached as the exception cause.
    """
    try:
        if hasattr(storage, 's3') and settings.STORAGE_PROVIDER != "local":
            response = storage.s3.get_object(
                Bucket=settings.S3_BUCKET,
                Key=file_key,
            )
            return response["Body"].read()
        import os
        file_path = os.path.join(settings.STORAGE_DIR, file_key)
        with open(file_path, 'rb') as f:
            return f.read()
    except Exception as e:
        print(f"Error reading image file: {e}")
        # fallback: try presigned/public URL
        try:
            url = storage.get_object_url(file_key)
            if url.startswith('/') and settings.STORAGE_PROVIDER == "local":
                url = f"http://localhost:8000{url}"
            import requests
            # Without a timeout a stalled server would block this call forever.
            resp = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT_SECONDS)
            resp.raise_for_status()
            return resp.content
        except Exception as fallback_error:
            print(f"Fallback also failed: {fallback_error}")
            raise HTTPException(500, f"Could not read image file: {e}") from fallback_error


def _ensure_raw_dict(raw):
    """Return ``raw`` itself if it is a dict, else wrap it under ``"raw_response"``."""
    return raw if isinstance(raw, dict) else {"raw_response": raw}


@router.post(
    "/images/{image_id}/caption",
    response_model=schemas.CaptionOut,
)
async def create_caption(
    image_id: str,
    title: str = Form(...),
    prompt: str | None = Form(None),  # optional; will use active prompts if not provided
    model_name: str | None = Form(None),
    db: Session = Depends(get_db),
):
    """Generate a caption for an image with the VLM service and store it.

    Steps: look up the image, resolve the prompt (the explicit ``prompt`` is
    tried via ``crud.get_prompt`` and then ``crud.get_prompt_by_label``;
    without one the active prompt for the image type is used), load the image
    bytes, call ``vlm_manager.generate_caption``, validate the raw response
    with ``schema_validator`` and persist the result via
    ``crud.create_caption``.

    Any exception raised while calling or post-processing the VLM response is
    caught: a fixed fallback caption is stored with model ``"STUB_MODEL"`` and
    the error text is recorded in the raw JSON.

    Raises:
        HTTPException: 404 if the image does not exist, 400 if no prompt can
            be resolved, 500 if the image bytes cannot be read.
    """
    import asyncio

    print(f"DEBUG: Received request - image_id: {image_id}, title: {title}, prompt: {prompt}, model_name: {model_name}")
    img = crud.get_image(db, image_id)
    if not img:
        raise HTTPException(404, "image not found")

    # Get the prompt (explicit by code/label, or active for image type)
    if prompt:
        print(f"Looking for prompt: '{prompt}' (type: {type(prompt)})")
        prompt_obj = crud.get_prompt(db, prompt) or crud.get_prompt_by_label(db, prompt)
    else:
        print(f"Looking for active prompt for image type: {img.image_type}")
        prompt_obj = crud.get_active_prompt_by_image_type(db, img.image_type)
    print(f"Prompt lookup result: {prompt_obj}")
    if not prompt_obj:
        raise HTTPException(400, f"No prompt found (requested: '{prompt}' or active for type '{img.image_type}')")

    prompt_text = prompt_obj.label
    metadata_instructions = prompt_obj.metadata_instructions or ""
    print(f"Using prompt text: '{prompt_text}'")
    print(f"Using metadata instructions: '{metadata_instructions[:100]}...'")

    # Load image bytes (S3 or local). The blocking file/S3/HTTP work runs in a
    # worker thread so the event loop stays free; in local mode the fallback
    # requests this server's own URL, which could never be answered while the
    # loop is blocked.
    print(f"DEBUG: About to call VLM service with model_name: {model_name}")
    img_bytes = await asyncio.to_thread(_load_image_bytes, img.file_key)

    metadata = {}
    try:
        result = await vlm_manager.generate_caption(
            image_bytes=img_bytes,
            prompt=prompt_text,
            metadata_instructions=metadata_instructions,
            model_name=model_name,
            db_session=db,
        )
        print(f"DEBUG: VLM service result: {result}")
        print(f"DEBUG: Result model field: {result.get('model', 'NOT_FOUND')}")
        raw = result.get("raw_response", {})

        # Validate and clean the data using schema validation
        image_type = img.image_type
        print(f"DEBUG: Validating data for image type: {image_type}")
        print(f"DEBUG: Raw data structure: {list(raw.keys()) if isinstance(raw, dict) else 'Not a dict'}")
        cleaned_data, is_valid, validation_error = schema_validator.clean_and_validate_data(raw, image_type)
        if is_valid:
            print(f"✓ Schema validation passed for {image_type}")
            text = cleaned_data.get("analysis", "")
            metadata = cleaned_data.get("metadata", {})
        else:
            print(f"⚠ Schema validation failed for {image_type}: {validation_error}")
            text = result.get("caption", "This is a fallback caption due to schema validation error.")
            metadata = result.get("metadata", {})
            # A non-dict payload cannot take the annotations below; wrap it so
            # the validation failure is recorded instead of raising TypeError
            # and being misreported as a VLM service error.
            raw = _ensure_raw_dict(raw)
            raw["validation_error"] = validation_error
            raw["validation_failed"] = True

        used_model = result.get("model", model_name) or "STUB_MODEL"
        if used_model == "random":
            print(f"WARNING: VLM service returned 'random' as model name, using STUB_MODEL fallback")
            used_model = "STUB_MODEL"

        # Fallback info (if any)
        if result.get("fallback_used"):
            raw = _ensure_raw_dict(raw)
            raw["fallback_info"] = {
                "original_model": result.get("original_model"),
                "fallback_model": used_model,
                "reason": result.get("fallback_reason"),
            }
    except Exception as e:
        # Deliberate best-effort: a caption row is still stored on VLM failure.
        print(f"VLM error, using fallback: {e}")
        text = "This is a fallback caption due to VLM service error."
        used_model = "STUB_MODEL"
        raw = {"error": str(e), "fallback": True}
        metadata = {}

    caption = crud.create_caption(
        db,
        image_id=image_id,
        title=title,
        prompt=prompt_obj.p_code,
        model_code=used_model,
        raw_json=raw,
        text=text,
        metadata=metadata,
    )
    db.refresh(caption)
    print(f"DEBUG: Caption created, caption object: {caption}")
    print(f"DEBUG: caption_id: {caption.caption_id}")
    return schemas.CaptionOut.from_orm(caption)
@router.get(
    "/captions/legacy",
    response_model=List[schemas.ImageOut],
)
def get_all_captions_legacy_format(
    request: Request,
    db: Session = Depends(get_db),
):
    """Get all images with captions in the old format for backward compatibility.

    Every caption returned by ``crud.get_all_captions_with_images`` is
    refreshed, and one ``schemas.ImageOut`` is produced per image attached to
    it: the image dict from ``convert_image_to_dict`` with the caption's
    fields overlaid. Captions without images contribute nothing.
    """
    # Kept as a function-local import, as in the original code (presumably to
    # avoid a circular import with the upload router - TODO confirm). It is
    # done once per call rather than once per image.
    from .upload import convert_image_to_dict

    print(f"DEBUG: Fetching all captions in legacy format...")
    captions = crud.get_all_captions_with_images(db)
    print(f"DEBUG: Found {len(captions)} captions")

    # The base URL is the same for every image in this request.
    base_url = str(request.base_url).rstrip('/')

    result = []
    for caption in captions:
        db.refresh(caption)
        if not caption.images:
            continue

        # Overlay caption fields (legacy shape); identical for all images of
        # this caption, so build the mapping once.
        caption_fields = {
            "title": caption.title,
            "prompt": caption.prompt,
            "model": caption.model,
            "schema_id": caption.schema_id,
            "raw_json": caption.raw_json,
            "generated": caption.generated,
            "edited": caption.edited,
            "accuracy": caption.accuracy,
            "context": caption.context,
            "usability": caption.usability,
            "starred": caption.starred,
            "created_at": caption.created_at,
            "updated_at": caption.updated_at,
        }
        for image in caption.images:
            url = f"{base_url}/api/images/{image.image_id}/file"
            print(f"DEBUG: Generated image URL: {url}")
            img_dict = convert_image_to_dict(image, url)
            img_dict.update(caption_fields)
            result.append(schemas.ImageOut(**img_dict))

    print(f"DEBUG: Returning {len(result)} legacy format results")
    return result
@router.get(
    "/captions",
    response_model=List[schemas.CaptionOut],
)
def get_all_captions_with_images(
    db: Session = Depends(get_db),
):
    """Get all captions.

    Each record returned by ``crud.get_all_captions_with_images`` is logged,
    refreshed from the database and converted to ``schemas.CaptionOut``.
    """
    print("DEBUG: Fetching all captions...")
    records = crud.get_all_captions_with_images(db)
    print(f"DEBUG: Found {len(records)} captions")

    formatted = []
    for record in records:
        print(f"DEBUG: Processing caption {record.caption_id}, title: {record.title}, generated: {record.generated}, model: {record.model}")
        # Reload the row before serialising it.
        db.refresh(record)
        formatted.append(schemas.CaptionOut.from_orm(record))

    print(f"DEBUG: Returning {len(formatted)} formatted results")
    return formatted
@router.get(
    "/images/{image_id}/captions",
    response_model=List[schemas.CaptionOut],
)
def get_captions_by_image(
    image_id: str,
    db: Session = Depends(get_db),
):
    """Get all captions for a specific image.

    The records come from ``crud.get_captions_by_image``; each one is
    refreshed and then converted to ``schemas.CaptionOut``.
    """

    def _reload_and_serialize(record):
        # Refresh first so the serialised values reflect the reloaded row.
        db.refresh(record)
        return schemas.CaptionOut.from_orm(record)

    records = crud.get_captions_by_image(db, image_id)
    return [_reload_and_serialize(record) for record in records]
@router.get(
    "/captions/{caption_id}",
    response_model=schemas.CaptionOut,
)
def get_caption(
    caption_id: str,
    db: Session = Depends(get_db),
):
    """Return a single caption by its id.

    Raises:
        HTTPException: 404 when ``crud.get_caption`` returns a falsy value.
    """
    record = crud.get_caption(db, caption_id)
    if record:
        db.refresh(record)
        return schemas.CaptionOut.from_orm(record)
    raise HTTPException(status_code=404, detail="caption not found")
@router.put(
    "/captions/{caption_id}",
    response_model=schemas.CaptionOut,
)
def update_caption(
    caption_id: str,
    update: schemas.CaptionUpdate,
    db: Session = Depends(get_db),
):
    """Apply ``update`` to the caption with the given id and return it.

    Raises:
        HTTPException: 404 when ``crud.update_caption`` returns a falsy value.
    """
    updated = crud.update_caption(db, caption_id, update)
    if updated:
        db.refresh(updated)
        return schemas.CaptionOut.from_orm(updated)
    raise HTTPException(status_code=404, detail="caption not found")
@router.put(
    "/images/{image_id}/caption",
    response_model=schemas.CaptionOut,
)
def update_caption_by_image(
    image_id: str,
    update: schemas.CaptionUpdate,
    db: Session = Depends(get_db),
):
    """Update the first caption for an image (for backward compatibility).

    Raises:
        HTTPException: 404 when the image is not found, when it has no
            captions, or when ``crud.update_caption`` returns a falsy value.
    """
    image = crud.get_image(db, image_id)
    if not image:
        raise HTTPException(status_code=404, detail="image not found")

    existing = image.captions
    if not existing:
        raise HTTPException(status_code=404, detail="no captions found for this image")

    # Only the first caption in the image's list is updated.
    first_caption_id = str(existing[0].caption_id)
    updated = crud.update_caption(db, first_caption_id, update)
    if not updated:
        raise HTTPException(status_code=404, detail="caption not found")

    db.refresh(updated)
    return schemas.CaptionOut.from_orm(updated)
@router.delete(
    "/captions/{caption_id}",
)
def delete_caption(
    caption_id: str,
    db: Session = Depends(get_db),
):
    """Delete caption data for a caption.

    Returns a confirmation message on success.

    Raises:
        HTTPException: 404 when ``crud.delete_caption`` returns a falsy value.
    """
    deleted = crud.delete_caption(db, caption_id)
    if deleted:
        return {"message": "Caption deleted successfully"}
    raise HTTPException(status_code=404, detail="caption not found")
|