Spaces:
Running
Running
from fastapi import APIRouter, UploadFile, Form, Depends, HTTPException, Response | |
from pydantic import BaseModel | |
import io | |
from sqlalchemy.orm import Session | |
from .. import crud, schemas, storage, database | |
from ..config import settings | |
from ..services.image_preprocessor import ImagePreprocessor | |
from ..services.thumbnail_service import ImageProcessingService | |
from typing import List, Optional | |
import boto3 | |
import time | |
import base64 | |
import datetime | |
router = APIRouter() | |
class CopyImageRequest(BaseModel): | |
source_image_id: str | |
source: str | |
event_type: str | |
countries: str = "" | |
epsg: str = "" | |
image_type: str = "crisis_map" | |
# Drone-specific fields (optional) | |
center_lon: Optional[float] = None | |
center_lat: Optional[float] = None | |
amsl_m: Optional[float] = None | |
agl_m: Optional[float] = None | |
heading_deg: Optional[float] = None | |
yaw_deg: Optional[float] = None | |
pitch_deg: Optional[float] = None | |
roll_deg: Optional[float] = None | |
rtk_fix: Optional[bool] = None | |
std_h_m: Optional[float] = None | |
std_v_m: Optional[float] = None | |
def get_db(): | |
db = database.SessionLocal() | |
try: | |
yield db | |
finally: | |
db.close() | |
def convert_image_to_dict(img, image_url): | |
"""Helper function to convert SQLAlchemy image model to dict for Pydantic""" | |
countries_list = [] | |
if hasattr(img, 'countries') and img.countries is not None: | |
try: | |
countries_list = [{"c_code": c.c_code, "label": c.label, "r_code": c.r_code} for c in img.countries] | |
except Exception as e: | |
print(f"Warning: Error processing countries for image {img.image_id}: {e}") | |
countries_list = [] | |
captions_list = [] | |
if hasattr(img, 'captions') and img.captions is not None: | |
try: | |
captions_list = [ | |
{ | |
"caption_id": c.caption_id, | |
"title": c.title, | |
"prompt": c.prompt, | |
"model": c.model, | |
"schema_id": c.schema_id, | |
"raw_json": c.raw_json, | |
"generated": c.generated, | |
"edited": c.edited, | |
"accuracy": c.accuracy, | |
"context": c.context, | |
"usability": c.usability, | |
"starred": c.starred if c.starred is not None else False, | |
"created_at": c.created_at, | |
"updated_at": c.updated_at | |
} for c in img.captions | |
] | |
except Exception as e: | |
print(f"Warning: Error processing captions for image {img.image_id}: {e}") | |
captions_list = [] | |
# Get starred status and other caption fields from first caption for backward compatibility | |
starred = False | |
title = None | |
prompt = None | |
model = None | |
schema_id = None | |
raw_json = None | |
generated = None | |
edited = None | |
accuracy = None | |
context = None | |
usability = None | |
created_at = None | |
updated_at = None | |
if captions_list: | |
first_caption = captions_list[0] | |
starred = first_caption.get("starred", False) | |
title = first_caption.get("title") | |
prompt = first_caption.get("prompt") | |
model = first_caption.get("model") | |
schema_id = first_caption.get("schema_id") | |
raw_json = first_caption.get("raw_json") | |
generated = first_caption.get("generated") | |
edited = first_caption.get("edited") | |
accuracy = first_caption.get("accuracy") | |
context = first_caption.get("context") | |
usability = first_caption.get("usability") | |
created_at = first_caption.get("created_at") | |
updated_at = first_caption.get("updated_at") | |
# Generate URLs for all image versions | |
thumbnail_url = None | |
detail_url = None | |
if hasattr(img, 'thumbnail_key') and img.thumbnail_key: | |
try: | |
thumbnail_url = storage.get_object_url(img.thumbnail_key) | |
except Exception as e: | |
print(f"Warning: Error generating thumbnail URL for image {img.image_id}: {e}") | |
if hasattr(img, 'detail_key') and img.detail_key: | |
try: | |
detail_url = storage.get_object_url(img.detail_key) | |
except Exception as e: | |
print(f"Warning: Error generating detail URL for image {img.image_id}: {e}") | |
img_dict = { | |
"image_id": img.image_id, | |
"file_key": img.file_key, | |
"sha256": img.sha256, | |
"thumbnail_key": getattr(img, 'thumbnail_key', None), | |
"thumbnail_sha256": getattr(img, 'thumbnail_sha256', None), | |
"thumbnail_url": thumbnail_url, | |
"detail_key": getattr(img, 'detail_key', None), | |
"detail_sha256": getattr(img, 'detail_sha256', None), | |
"detail_url": detail_url, | |
"source": img.source, | |
"event_type": img.event_type, | |
"epsg": img.epsg, | |
"image_type": img.image_type, | |
"image_url": image_url, | |
"countries": countries_list, | |
"captions": captions_list, | |
"starred": starred, # Backward compatibility | |
"captured_at": img.captured_at, | |
# Backward compatibility fields for legacy frontend | |
"title": title, | |
"prompt": prompt, | |
"model": model, | |
"schema_id": schema_id, | |
"raw_json": raw_json, | |
"generated": generated, | |
"edited": edited, | |
"accuracy": accuracy, | |
"context": context, | |
"usability": usability, | |
"created_at": created_at, | |
"updated_at": updated_at, | |
# Drone-specific fields | |
"center_lon": getattr(img, 'center_lon', None), | |
"center_lat": getattr(img, 'center_lat', None), | |
"amsl_m": getattr(img, 'amsl_m', None), | |
"agl_m": getattr(img, 'agl_m', None), | |
"heading_deg": getattr(img, 'heading_deg', None), | |
"yaw_deg": getattr(img, 'yaw_deg', None), | |
"pitch_deg": getattr(img, 'pitch_deg', None), | |
"roll_deg": getattr(img, 'roll_deg', None), | |
"rtk_fix": getattr(img, 'rtk_fix', None), | |
"std_h_m": getattr(img, 'std_h_m', None), | |
"std_v_m": getattr(img, 'std_v_m', None) | |
} | |
return img_dict | |
def list_images(db: Session = Depends(get_db)): | |
"""Get all images with their caption data""" | |
images = crud.get_images(db) | |
result = [] | |
for img in images: | |
img_dict = convert_image_to_dict(img, f"/api/images/{img.image_id}/file") | |
result.append(schemas.ImageOut(**img_dict)) | |
return result | |
def list_images_grouped( | |
page: int = 1, | |
limit: int = 10, | |
search: str = None, | |
source: str = None, | |
event_type: str = None, | |
region: str = None, | |
country: str = None, | |
image_type: str = None, | |
upload_type: str = None, | |
starred_only: bool = False, | |
db: Session = Depends(get_db) | |
): | |
"""Get images grouped by shared captions for multi-upload items with pagination and filtering""" | |
# Validate pagination parameters | |
if page < 1: | |
page = 1 | |
if limit < 1 or limit > 100: | |
limit = 10 | |
# Get all captions with their associated images | |
captions = crud.get_all_captions_with_images(db) | |
result = [] | |
for caption in captions: | |
if not caption.images: | |
continue | |
# Determine the effective image count for this caption | |
effective_image_count = caption.image_count if caption.image_count is not None and caption.image_count > 0 else len(caption.images) | |
# Apply filters | |
if search: | |
search_lower = search.lower() | |
if not (caption.title and search_lower in caption.title.lower()) and \ | |
not (caption.generated and search_lower in caption.generated.lower()): | |
continue | |
if starred_only and not caption.starred: | |
continue | |
if effective_image_count > 1: | |
# This is a multi-upload item, group them together | |
first_img = caption.images[0] | |
# Apply source filter | |
if source: | |
if not any(source in img.source for img in caption.images if img.source): | |
continue | |
# Apply event_type filter | |
if event_type: | |
if not any(event_type in img.event_type for img in caption.images if img.event_type): | |
continue | |
# Apply image_type filter | |
if image_type: | |
if not any(img.image_type == image_type for img in caption.images): | |
continue | |
# Apply upload_type filter | |
if upload_type: | |
if upload_type == 'single' and effective_image_count > 1: | |
continue | |
if upload_type == 'multiple' and effective_image_count <= 1: | |
continue | |
# Apply region/country filter | |
if region or country: | |
has_matching_country = False | |
for img in caption.images: | |
for img_country in img.countries: | |
if region and img_country.r_code == region: | |
has_matching_country = True | |
break | |
if country and img_country.c_code == country: | |
has_matching_country = True | |
break | |
if has_matching_country: | |
break | |
if not has_matching_country: | |
continue | |
# Combine metadata from all images | |
combined_source = set() | |
combined_event_type = set() | |
combined_epsg = set() | |
for img in caption.images: | |
if img.source: | |
combined_source.add(img.source) | |
if img.event_type: | |
combined_event_type.add(img.event_type) | |
if img.epsg: | |
combined_epsg.add(img.epsg) | |
# Create a combined image dict using the first image as a template | |
img_dict = convert_image_to_dict(first_img, f"/api/images/{first_img.image_id}/file") | |
# Override with combined metadata | |
img_dict["source"] = ", ".join(sorted(list(combined_source))) if combined_source else "OTHER" | |
img_dict["event_type"] = ", ".join(sorted(list(combined_event_type))) if combined_event_type else "OTHER" | |
img_dict["epsg"] = ", ".join(sorted(list(combined_epsg))) if combined_epsg else "OTHER" | |
# Update countries to include all unique countries | |
all_countries = [] | |
for img in caption.images: | |
for country_obj in img.countries: | |
if not any(c["c_code"] == country_obj.c_code for c in all_countries): | |
all_countries.append({"c_code": country_obj.c_code, "label": country_obj.label, "r_code": country_obj.r_code}) | |
img_dict["countries"] = all_countries | |
# Add all image IDs for reference | |
img_dict["all_image_ids"] = [str(img.image_id) for img in caption.images] | |
img_dict["image_count"] = effective_image_count | |
# Set caption-level fields | |
img_dict["title"] = caption.title | |
img_dict["prompt"] = caption.prompt | |
img_dict["model"] = caption.model | |
img_dict["schema_id"] = caption.schema_id | |
img_dict["raw_json"] = caption.raw_json | |
img_dict["generated"] = caption.generated | |
img_dict["edited"] = caption.edited | |
img_dict["accuracy"] = caption.accuracy | |
img_dict["context"] = caption.context | |
img_dict["usability"] = caption.usability | |
img_dict["starred"] = caption.starred | |
img_dict["created_at"] = caption.created_at | |
img_dict["updated_at"] = caption.updated_at | |
result.append(schemas.ImageOut(**img_dict)) | |
else: | |
# For single images, apply filters | |
img = caption.images[0] | |
if source and img.source != source: | |
continue | |
if event_type and img.event_type != event_type: | |
continue | |
if image_type and img.image_type != image_type: | |
continue | |
if upload_type == 'multiple': | |
continue | |
# Apply region/country filter | |
if region or country: | |
has_matching_country = False | |
for img_country in img.countries: | |
if region and img_country.r_code == region: | |
has_matching_country = True | |
break | |
if country and img_country.c_code == country: | |
has_matching_country = True | |
break | |
if not has_matching_country: | |
continue | |
img_dict = convert_image_to_dict(img, f"/api/images/{img.image_id}/file") | |
img_dict["all_image_ids"] = [str(img.image_id)] | |
img_dict["image_count"] = 1 | |
# Set caption-level fields | |
img_dict["title"] = caption.title | |
img_dict["prompt"] = caption.prompt | |
img_dict["model"] = caption.model | |
img_dict["schema_id"] = caption.schema_id | |
img_dict["raw_json"] = caption.raw_json | |
img_dict["generated"] = caption.generated | |
img_dict["edited"] = caption.edited | |
img_dict["accuracy"] = caption.accuracy | |
img_dict["context"] = caption.context | |
img_dict["usability"] = caption.usability | |
img_dict["starred"] = caption.starred | |
img_dict["created_at"] = caption.created_at | |
img_dict["updated_at"] = caption.updated_at | |
result.append(schemas.ImageOut(**img_dict)) | |
# Apply pagination | |
total_count = len(result) | |
start_index = (page - 1) * limit | |
end_index = start_index + limit | |
paginated_result = result[start_index:end_index] | |
return paginated_result | |
def get_images_grouped_count( | |
search: str = None, | |
source: str = None, | |
event_type: str = None, | |
region: str = None, | |
country: str = None, | |
image_type: str = None, | |
upload_type: str = None, | |
starred_only: bool = False, | |
db: Session = Depends(get_db) | |
): | |
"""Get total count of images for pagination""" | |
# Get all captions with their associated images | |
captions = crud.get_all_captions_with_images(db) | |
count = 0 | |
for caption in captions: | |
if not caption.images: | |
continue | |
# Determine the effective image count for this caption | |
effective_image_count = caption.image_count if caption.image_count is not None and caption.image_count > 0 else len(caption.images) | |
# Apply filters (same logic as above) | |
if search: | |
search_lower = search.lower() | |
if not (caption.title and search_lower in caption.title.lower()) and \ | |
not (caption.generated and search_lower in caption.generated.lower()): | |
continue | |
if starred_only and not caption.starred: | |
continue | |
if effective_image_count > 1: | |
# Multi-upload item | |
first_img = caption.images[0] | |
# Apply filters | |
if source: | |
if not any(source in img.source for img in caption.images if img.source): | |
continue | |
if event_type: | |
if not any(event_type in img.event_type for img in caption.images if img.event_type): | |
continue | |
if image_type: | |
if not any(img.image_type == image_type for img in caption.images): | |
continue | |
if upload_type: | |
if upload_type == 'single' and effective_image_count > 1: | |
continue | |
if upload_type == 'multiple' and effective_image_count <= 1: | |
continue | |
if region or country: | |
has_matching_country = False | |
for img in caption.images: | |
for img_country in img.countries: | |
if region and img_country.r_code == region: | |
has_matching_country = True | |
break | |
if country and img_country.c_code == country: | |
has_matching_country = True | |
break | |
if has_matching_country: | |
break | |
if not has_matching_country: | |
continue | |
count += 1 | |
else: | |
# Single image | |
img = caption.images[0] | |
if source and img.source != source: | |
continue | |
if event_type and img.event_type != event_type: | |
continue | |
if image_type and img.image_type != image_type: | |
continue | |
if upload_type == 'multiple': | |
continue | |
if region or country: | |
has_matching_country = False | |
for img_country in img.countries: | |
if region and img_country.r_code == region: | |
has_matching_country = True | |
break | |
if country and img_country.c_code == country: | |
has_matching_country = True | |
break | |
if not has_matching_country: | |
continue | |
count += 1 | |
return {"total_count": count} | |
def get_image(image_id: str, db: Session = Depends(get_db)): | |
"""Get a single image by ID""" | |
# Validate image_id before querying database | |
if not image_id or image_id in ['undefined', 'null', '']: | |
raise HTTPException(400, "Invalid image ID") | |
# Validate UUID format | |
import re | |
uuid_pattern = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', re.IGNORECASE) | |
if not uuid_pattern.match(image_id): | |
raise HTTPException(400, "Invalid image ID format") | |
img = crud.get_image(db, image_id) # This loads captions | |
if not img: | |
raise HTTPException(404, "Image not found") | |
img_dict = convert_image_to_dict(img, f"/api/images/{img.image_id}/file") | |
# Enhance img_dict with multi-upload specific fields if applicable | |
if img.captions: | |
# Assuming an image is primarily associated with one "grouping" caption for multi-uploads | |
# We take the first caption and check its linked images | |
main_caption = img.captions[0] | |
# Refresh the caption to ensure its images relationship is loaded if not already | |
db.refresh(main_caption) | |
if main_caption.images: | |
all_linked_image_ids = [str(linked_img.image_id) for linked_img in main_caption.images] | |
effective_image_count = main_caption.image_count if main_caption.image_count is not None and main_caption.image_count > 0 else len(main_caption.images) | |
if effective_image_count > 1: | |
img_dict["all_image_ids"] = all_linked_image_ids | |
img_dict["image_count"] = effective_image_count | |
else: | |
# Even for single images, explicitly set image_count to 1 | |
img_dict["image_count"] = 1 | |
img_dict["all_image_ids"] = [str(img.image_id)] # Ensure it's an array for consistency | |
else: | |
# If caption has no linked images (shouldn't happen for valid data, but for robustness) | |
img_dict["image_count"] = 1 | |
img_dict["all_image_ids"] = [str(img.image_id)] | |
else: | |
# If image has no captions, it's a single image by default | |
img_dict["image_count"] = 1 | |
img_dict["all_image_ids"] = [str(img.image_id)] | |
return schemas.ImageOut(**img_dict) | |
async def upload_image( | |
source: Optional[str] = Form(default=None), | |
event_type: str = Form(default="OTHER"), | |
countries: str = Form(default=""), | |
epsg: str = Form(default=""), | |
image_type: str = Form(default="crisis_map"), | |
file: UploadFile = Form(...), | |
title: str = Form(default=""), | |
model_name: Optional[str] = Form(default=None), | |
# Drone-specific fields (optional) | |
center_lon: Optional[float] = Form(default=None), | |
center_lat: Optional[float] = Form(default=None), | |
amsl_m: Optional[float] = Form(default=None), | |
agl_m: Optional[float] = Form(default=None), | |
heading_deg: Optional[float] = Form(default=None), | |
yaw_deg: Optional[float] = Form(default=None), | |
pitch_deg: Optional[float] = Form(default=None), | |
roll_deg: Optional[float] = Form(default=None), | |
rtk_fix: Optional[bool] = Form(default=None), | |
std_h_m: Optional[float] = Form(default=None), | |
std_v_m: Optional[float] = Form(default=None), | |
db: Session = Depends(get_db) | |
): | |
countries_list = [c.strip() for c in countries.split(',') if c.strip()] if countries else [] | |
if image_type == "drone_image": | |
if not event_type or event_type.strip() == "": | |
event_type = "OTHER" | |
if not epsg or epsg.strip() == "": | |
epsg = "OTHER" | |
else: | |
if not source or source.strip() == "": | |
source = "OTHER" | |
if not event_type or event_type.strip() == "": | |
event_type = "OTHER" | |
if not epsg or epsg.strip() == "": | |
epsg = "OTHER" | |
if not image_type or image_type.strip() == "": | |
image_type = "crisis_map" | |
if image_type != "drone_image": | |
center_lon = None | |
center_lat = None | |
amsl_m = None | |
agl_m = None | |
heading_deg = None | |
yaw_deg = None | |
pitch_deg = None | |
roll_deg = None | |
rtk_fix = None | |
std_h_m = None | |
std_v_m = None | |
content = await file.read() | |
# Preprocess image if needed | |
try: | |
processed_content, processed_filename, mime_type = ImagePreprocessor.preprocess_image( | |
content, | |
file.filename, | |
target_format='PNG', # Default to PNG for better quality | |
quality=95 | |
) | |
# Log preprocessing info | |
preprocessing_info = None | |
if processed_filename != file.filename: | |
print(f"Image preprocessed: {file.filename} -> {processed_filename} ({mime_type})") | |
preprocessing_info = { | |
"original_filename": file.filename, | |
"processed_filename": processed_filename, | |
"original_mime_type": ImagePreprocessor.detect_mime_type(content, file.filename), | |
"processed_mime_type": mime_type, | |
"was_preprocessed": True | |
} | |
else: | |
preprocessing_info = { | |
"original_filename": file.filename, | |
"processed_filename": file.filename, | |
"original_mime_type": mime_type, | |
"processed_mime_type": mime_type, | |
"was_preprocessed": False | |
} | |
except Exception as e: | |
print(f"Image preprocessing failed: {str(e)}") | |
# Fall back to original content if preprocessing fails | |
processed_content = content | |
processed_filename = file.filename | |
mime_type = 'image/png' # Default fallback | |
preprocessing_info = { | |
"original_filename": file.filename, | |
"processed_filename": file.filename, | |
"original_mime_type": "unknown", | |
"processed_mime_type": mime_type, | |
"was_preprocessed": False, | |
"error": str(e) | |
} | |
sha = crud.hash_bytes(processed_content) | |
key = storage.upload_fileobj(io.BytesIO(processed_content), processed_filename) | |
# Generate and upload all image resolutions | |
thumbnail_key = None | |
thumbnail_sha256 = None | |
detail_key = None | |
detail_sha256 = None | |
try: | |
# Process both thumbnail and detail versions | |
thumbnail_result, detail_result = ImageProcessingService.process_all_resolutions( | |
processed_content, | |
processed_filename | |
) | |
if thumbnail_result: | |
thumbnail_key, thumbnail_sha256 = thumbnail_result | |
print(f"Thumbnail generated and uploaded: key={thumbnail_key}, sha256={thumbnail_sha256}") | |
if detail_result: | |
detail_key, detail_sha256 = detail_result | |
print(f"Detail version generated and uploaded: key={detail_key}, sha256={detail_sha256}") | |
except Exception as e: | |
print(f"Image resolution processing failed: {str(e)}") | |
# Continue without processed versions if generation fails | |
try: | |
img = crud.create_image( | |
db, source, event_type, key, sha, countries_list, epsg, image_type, | |
center_lon, center_lat, amsl_m, agl_m, heading_deg, yaw_deg, pitch_deg, roll_deg, | |
rtk_fix, std_h_m, std_v_m, | |
thumbnail_key=thumbnail_key, thumbnail_sha256=thumbnail_sha256, | |
detail_key=detail_key, detail_sha256=detail_sha256 | |
) | |
except Exception as e: | |
raise HTTPException(500, f"Failed to save image to database: {str(e)}") | |
try: | |
url = storage.get_object_url(key) | |
except Exception as e: | |
url = f"/api/images/{img.image_id}/file" | |
# Create caption using VLM | |
prompt_obj = crud.get_active_prompt_by_image_type(db, image_type) | |
if not prompt_obj: | |
raise HTTPException(400, f"No active prompt found for image type '{image_type}'") | |
prompt_text = prompt_obj.label | |
metadata_instructions = prompt_obj.metadata_instructions or "" | |
try: | |
from ..services.vlm_service import vlm_manager | |
result = await vlm_manager.generate_caption( | |
image_bytes=processed_content, | |
prompt=prompt_text, | |
metadata_instructions=metadata_instructions, | |
model_name=model_name, | |
db_session=db, | |
) | |
raw = result.get("raw_response", {}) | |
text = result.get("caption", "") | |
metadata = result.get("metadata", {}) | |
actual_model = result.get("model", model_name) | |
final_model_name = actual_model if actual_model != "random" else "STUB_MODEL" | |
caption = crud.create_caption( | |
db, | |
image_id=img.image_id, | |
title=title, | |
prompt=prompt_obj.p_code, | |
model_code=final_model_name, | |
raw_json=raw, | |
text=text, | |
metadata=metadata, | |
image_count=1 | |
) | |
except Exception as e: | |
print(f"VLM caption generation failed: {str(e)}") | |
# Continue without caption if VLM fails | |
img_dict = convert_image_to_dict(img, url) | |
# Add preprocessing info to the response | |
img_dict['preprocessing_info'] = preprocessing_info | |
result = schemas.ImageOut(**img_dict) | |
return result | |
async def upload_multiple_images( | |
files: List[UploadFile] = Form(...), | |
source: Optional[str] = Form(default=None), | |
event_type: str = Form(default="OTHER"), | |
countries: str = Form(default=""), | |
epsg: str = Form(default=""), | |
image_type: str = Form(default="crisis_map"), | |
title: str = Form(...), | |
model_name: Optional[str] = Form(default=None), | |
# Drone-specific fields (optional) | |
center_lon: Optional[float] = Form(default=None), | |
center_lat: Optional[float] = Form(default=None), | |
amsl_m: Optional[float] = Form(default=None), | |
agl_m: Optional[float] = Form(default=None), | |
heading_deg: Optional[float] = Form(default=None), | |
yaw_deg: Optional[float] = Form(default=None), | |
pitch_deg: Optional[float] = Form(default=None), | |
roll_deg: Optional[float] = Form(default=None), | |
rtk_fix: Optional[bool] = Form(default=None), | |
std_h_m: Optional[float] = Form(default=None), | |
std_v_m: Optional[float] = Form(default=None), | |
db: Session = Depends(get_db) | |
): | |
"""Upload multiple images and create a single caption for all of them""" | |
if len(files) > 5: | |
raise HTTPException(400, "Maximum 5 images allowed") | |
if len(files) < 1: | |
raise HTTPException(400, "At least one image required") | |
countries_list = [c.strip() for c in countries.split(',') if c.strip()] if countries else [] | |
if image_type == "drone_image": | |
if not event_type or event_type.strip() == "": | |
event_type = "OTHER" | |
if not epsg or epsg.strip() == "": | |
epsg = "OTHER" | |
else: | |
if not source or source.strip() == "": | |
source = "OTHER" | |
if not event_type or event_type.strip() == "": | |
event_type = "OTHER" | |
if not epsg or epsg.strip() == "": | |
epsg = "OTHER" | |
if not image_type or image_type.strip() == "": | |
image_type = "crisis_map" | |
if image_type != "drone_image": | |
center_lon = None | |
center_lat = None | |
amsl_m = None | |
agl_m = None | |
heading_deg = None | |
yaw_deg = None | |
pitch_deg = None | |
roll_deg = None | |
rtk_fix = None | |
std_h_m = None | |
std_v_m = None | |
uploaded_images = [] | |
image_bytes_list = [] | |
# Process each file | |
for file in files: | |
content = await file.read() | |
# Preprocess image if needed | |
try: | |
processed_content, processed_filename, mime_type = ImagePreprocessor.preprocess_image( | |
content, | |
file.filename, | |
target_format='PNG', | |
quality=95 | |
) | |
except Exception as e: | |
print(f"Image preprocessing failed: {str(e)}") | |
processed_content = content | |
processed_filename = file.filename | |
mime_type = 'image/png' | |
sha = crud.hash_bytes(processed_content) | |
key = storage.upload_fileobj(io.BytesIO(processed_content), processed_filename) | |
# Create image record | |
img = crud.create_image( | |
db, source, event_type, key, sha, countries_list, epsg, image_type, | |
center_lon, center_lat, amsl_m, agl_m, heading_deg, yaw_deg, pitch_deg, roll_deg, | |
rtk_fix, std_h_m, std_v_m | |
) | |
uploaded_images.append(img) | |
image_bytes_list.append(processed_content) | |
# Get the first image for URL generation (they all share the same metadata) | |
first_img = uploaded_images[0] | |
try: | |
url = storage.get_object_url(first_img.file_key) | |
except Exception as e: | |
url = f"/api/images/{first_img.image_id}/file" | |
# Create caption for all images | |
# Use the model_name parameter from the request, or let VLM manager choose the best available model | |
prompt_obj = crud.get_active_prompt_by_image_type(db, image_type) | |
if not prompt_obj: | |
raise HTTPException(400, f"No active prompt found for image type '{image_type}'") | |
prompt_text = prompt_obj.label | |
metadata_instructions = prompt_obj.metadata_instructions or "" | |
# Add system instruction for multiple images | |
multi_image_instruction = f"\n\nIMPORTANT: You are analyzing {len(image_bytes_list)} images. Please provide a combined analysis that covers all images together. In your metadata section, provide separate metadata for each image:\n- 'title': ONE shared title for all images\n- 'metadata_images': an object containing individual metadata for each image:\n - 'image1': {{ 'source': 'data source', 'type': 'event type', 'countries': ['country codes'], 'epsg': 'spatial reference' }}\n - 'image2': {{ 'source': 'data source', 'type': 'event type', 'countries': ['country codes'], 'epsg': 'spatial reference' }}\n - etc. for each image\n\nEach image should have its own source, type, countries, and epsg values based on what that specific image shows." | |
metadata_instructions += multi_image_instruction | |
try: | |
from ..services.vlm_service import vlm_manager | |
result = await vlm_manager.generate_multi_image_caption( | |
image_bytes_list=image_bytes_list, | |
prompt=prompt_text, | |
metadata_instructions=metadata_instructions, | |
model_name=model_name, | |
db_session=db, | |
) | |
raw = result.get("raw_response", {}) | |
text = result.get("caption", "") | |
metadata = result.get("metadata", {}) | |
# Use the actual model that was used by the VLM service | |
actual_model = result.get("model", model_name) | |
# Update individual image metadata if VLM provided it | |
metadata_images = metadata.get("metadata_images", {}) | |
if metadata_images and isinstance(metadata_images, dict): | |
for i, img in enumerate(uploaded_images): | |
image_key = f"image{i+1}" | |
if image_key in metadata_images: | |
img_metadata = metadata_images[image_key] | |
if isinstance(img_metadata, dict): | |
# Update image with individual metadata | |
img.source = img_metadata.get("source", img.source) | |
img.event_type = img_metadata.get("type", img.event_type) | |
img.epsg = img_metadata.get("epsg", img.epsg) | |
img.countries = img_metadata.get("countries", img.countries) | |
# Ensure we never use 'random' as the model name in the database | |
final_model_name = actual_model if actual_model != "random" else "STUB_MODEL" | |
# Create caption linked to the first image | |
caption = crud.create_caption( | |
db, | |
image_id=first_img.image_id, | |
title=title, | |
prompt=prompt_obj.p_code, | |
model_code=final_model_name, | |
raw_json=raw, | |
text=text, | |
metadata=metadata, | |
image_count=len(image_bytes_list) | |
) | |
# Link caption to all images | |
for img in uploaded_images[1:]: | |
img.captions.append(caption) | |
db.commit() | |
except Exception as e: | |
print(f"VLM error: {e}") | |
# Create fallback caption | |
fallback_text = f"Analysis of {len(image_bytes_list)} images" | |
caption = crud.create_caption( | |
db, | |
image_id=first_img.image_id, | |
title=title, | |
prompt=prompt_obj.p_code, | |
model_code="FALLBACK", | |
raw_json={"error": str(e), "fallback": True}, | |
text=fallback_text, | |
metadata={}, | |
image_count=len(image_bytes_list) | |
) | |
# Link caption to all images | |
for img in uploaded_images[1:]: | |
img.captions.append(caption) | |
db.commit() | |
img_dict = convert_image_to_dict(first_img, url) | |
# Add all image IDs to the response for multi-image uploads | |
if len(uploaded_images) > 1: | |
img_dict["all_image_ids"] = [str(img.image_id) for img in uploaded_images] | |
img_dict["image_count"] = len(uploaded_images) | |
result = schemas.ImageOut(**img_dict) | |
return result | |
async def copy_image_for_contribution( | |
request: CopyImageRequest, | |
db: Session = Depends(get_db) | |
): | |
"""Copy an existing image for contribution purposes, creating a new image_id""" | |
source_img = crud.get_image(db, request.source_image_id) | |
if not source_img: | |
raise HTTPException(404, "Source image not found") | |
try: | |
if hasattr(storage, 's3') and settings.STORAGE_PROVIDER != "local": | |
response = storage.s3.get_object( | |
Bucket=settings.S3_BUCKET, | |
Key=source_img.file_key, | |
) | |
image_content = response["Body"].read() | |
else: | |
import os | |
file_path = os.path.join(settings.STORAGE_DIR, source_img.file_key) | |
with open(file_path, 'rb') as f: | |
image_content = f.read() | |
new_filename = f"contribution_{request.source_image_id}_{int(time.time())}.jpg" | |
new_key = storage.upload_fileobj(io.BytesIO(image_content), new_filename) | |
countries_list = [c.strip() for c in request.countries.split(',') if c.strip()] if request.countries else [] | |
new_img = crud.create_image( | |
db, | |
request.source, | |
request.event_type, | |
new_key, | |
source_img.sha256, | |
countries_list, | |
request.epsg, | |
request.image_type, | |
request.center_lon, request.center_lat, request.amsl_m, request.agl_m, | |
request.heading_deg, request.yaw_deg, request.pitch_deg, request.roll_deg, | |
request.rtk_fix, request.std_h_m, request.std_v_m | |
) | |
try: | |
url = storage.get_object_url(new_key) | |
except Exception as e: | |
url = f"/api/images/{new_img.image_id}/file" | |
img_dict = convert_image_to_dict(new_img, url) | |
result = schemas.ImageOut(**img_dict) | |
return result | |
except Exception as e: | |
raise HTTPException(500, f"Failed to copy image: {str(e)}") | |
async def get_image_file(image_id: str, db: Session = Depends(get_db)): | |
"""Serve the actual image file""" | |
print(f"🔍 Serving image file for image_id: {image_id}") | |
img = crud.get_image(db, image_id) | |
if not img: | |
print(f"❌ Image not found: {image_id}") | |
raise HTTPException(404, "Image not found") | |
print(f"✅ Found image: {img.image_id}, file_key: {img.file_key}") | |
try: | |
if hasattr(storage, 's3') and settings.STORAGE_PROVIDER != "local": | |
print(f"�� Using S3 storage - serving file content directly") | |
try: | |
response = storage.s3.get_object(Bucket=settings.S3_BUCKET, Key=img.file_key) | |
content = response['Body'].read() | |
print(f"✅ Read {len(content)} bytes from S3") | |
except Exception as e: | |
print(f"❌ Failed to get S3 object: {e}") | |
raise HTTPException(500, f"Failed to retrieve image from storage: {e}") | |
else: | |
print(f"🔍 Using local storage") | |
import os | |
file_path = os.path.join(settings.STORAGE_DIR, img.file_key) | |
print(f"📁 Reading from: {file_path}") | |
print(f"📁 File exists: {os.path.exists(file_path)}") | |
if not os.path.exists(file_path): | |
print(f"❌ File not found at: {file_path}") | |
raise FileNotFoundError(f"Image file not found: {file_path}") | |
with open(file_path, 'rb') as f: | |
content = f.read() | |
print(f"✅ Read {len(content)} bytes from file") | |
import mimetypes | |
content_type, _ = mimetypes.guess_type(img.file_key) | |
if not content_type: | |
content_type = 'application/octet-stream' | |
print(f"✅ Serving image with content-type: {content_type}, size: {len(content)} bytes") | |
return Response(content=content, media_type=content_type) | |
except Exception as e: | |
print(f"❌ Error serving image: {e}") | |
import traceback | |
print(f"🔍 Full traceback: {traceback.format_exc()}") | |
raise HTTPException(500, f"Failed to serve image file: {e}") | |
def update_image_metadata( | |
image_id: str, | |
metadata: schemas.ImageMetadataUpdate, | |
db: Session = Depends(get_db) | |
): | |
"""Update image metadata (source, type, epsg, image_type, countries)""" | |
print(f"DEBUG: Updating metadata for image {image_id}") | |
print(f"DEBUG: Metadata received: {metadata}") | |
img = crud.get_image(db, image_id) | |
if not img: | |
print(f"DEBUG: Image {image_id} not found in database") | |
raise HTTPException(404, "Image not found") | |
print(f"DEBUG: Found image {image_id} in database") | |
try: | |
if metadata.source is not None: | |
img.source = metadata.source | |
if metadata.event_type is not None: | |
img.event_type = metadata.event_type | |
if metadata.epsg is not None: | |
img.epsg = metadata.epsg | |
if metadata.image_type is not None: | |
img.image_type = metadata.image_type | |
# Handle starred field - update the first caption's starred status | |
if metadata.starred is not None: | |
if img.captions: | |
# Update the first caption's starred status | |
img.captions[0].starred = metadata.starred | |
else: | |
# If no captions exist, create a minimal caption with starred status | |
from app import models | |
caption = models.Captions( | |
title="", | |
starred=metadata.starred, | |
created_at=datetime.datetime.utcnow() | |
) | |
db.add(caption) | |
img.captions.append(caption) | |
# Update drone-specific fields | |
if metadata.center_lon is not None: | |
img.center_lon = metadata.center_lon | |
if metadata.center_lat is not None: | |
img.center_lat = metadata.center_lat | |
if metadata.amsl_m is not None: | |
img.amsl_m = metadata.amsl_m | |
if metadata.agl_m is not None: | |
img.agl_m = metadata.agl_m | |
if metadata.heading_deg is not None: | |
img.heading_deg = metadata.heading_deg | |
if metadata.yaw_deg is not None: | |
img.yaw_deg = metadata.yaw_deg | |
if metadata.pitch_deg is not None: | |
img.pitch_deg = metadata.pitch_deg | |
if metadata.roll_deg is not None: | |
img.roll_deg = metadata.roll_deg | |
if metadata.rtk_fix is not None: | |
img.rtk_fix = metadata.rtk_fix | |
if metadata.std_h_m is not None: | |
img.std_h_m = metadata.std_h_m | |
if metadata.std_v_m is not None: | |
img.std_v_m = metadata.std_v_m | |
if metadata.countries is not None: | |
print(f"DEBUG: Updating countries to: {metadata.countries}") | |
img.countries.clear() | |
for country_code in metadata.countries: | |
country = crud.get_country(db, country_code) | |
if country: | |
img.countries.append(country) | |
print(f"DEBUG: Added country: {country_code}") | |
db.commit() | |
db.refresh(img) | |
print(f"DEBUG: Metadata update successful for image {image_id}") | |
try: | |
url = storage.get_object_url(img.file_key) | |
except Exception: | |
url = f"/api/images/{img.image_id}/file" | |
img_dict = convert_image_to_dict(img, url) | |
return schemas.ImageOut(**img_dict) | |
except Exception as e: | |
db.rollback() | |
print(f"DEBUG: Metadata update failed for image {image_id}: {str(e)}") | |
raise HTTPException(500, f"Failed to update image metadata: {str(e)}") | |
def delete_image(image_id: str, db: Session = Depends(get_db), content_management: bool = False): | |
"""Delete an image and its associated caption data | |
Args: | |
image_id: The ID of the image to delete | |
content_management: If True, this is a content management delete (from map details) | |
If False, this is a user dissatisfaction delete (from upload flow) | |
""" | |
img = crud.get_image(db, image_id) | |
if not img: | |
raise HTTPException(404, "Image not found") | |
# Only increment delete count if this is NOT a content management delete | |
# Content management deletes (from map details) should not count against model performance | |
if not content_management and img.captions: | |
# Get model from the first caption | |
model_name = img.captions[0].model | |
if model_name: | |
from .. import crud as crud_module | |
model = crud_module.get_model(db, model_name) | |
if model: | |
model.delete_count += 1 | |
db.commit() | |
db.delete(img) | |
db.commit() | |
return {"message": "Image deleted successfully"} | |
async def preprocess_image_only( | |
file: UploadFile = Form(...), | |
preprocess_only: bool = Form(False) | |
): | |
"""Preprocess image without storing it - returns processed file data""" | |
try: | |
# Read file content | |
file_content = await file.read() | |
# Preprocess the image | |
processed_content, processed_filename, processed_mime_type = ImagePreprocessor.preprocess_image( | |
file_content, | |
file.filename or "unknown", | |
target_format='PNG', | |
quality=95 | |
) | |
# Check if preprocessing actually occurred | |
was_preprocessed = ( | |
processed_filename != (file.filename or "unknown") or | |
processed_mime_type != file.content_type | |
) | |
# Encode processed content as base64 for JSON response | |
processed_content_b64 = base64.b64encode(processed_content).decode('utf-8') | |
# Create preprocessing info | |
preprocessing_info = { | |
"original_filename": file.filename or "unknown", | |
"processed_filename": processed_filename, | |
"original_mime_type": file.content_type or "application/octet-stream", | |
"processed_mime_type": processed_mime_type, | |
"was_preprocessed": was_preprocessed | |
} | |
# Return processed file data | |
return { | |
"processed_content": processed_content_b64, | |
"processed_filename": processed_filename, | |
"processed_mime_type": processed_mime_type, | |
"preprocessing_info": preprocessing_info, | |
"was_preprocessed": was_preprocessed | |
} | |
except Exception as e: | |
raise HTTPException( | |
status_code=400, | |
detail=f"Preprocessing failed: {str(e)}" | |
) | |