SCGR's picture
dynamic json schema
76f5d42
import io, hashlib
from typing import Optional, List
from sqlalchemy.orm import Session, joinedload
from . import models, schemas
from fastapi import HTTPException
def hash_bytes(data: bytes) -> str:
    """Return the SHA-256 digest of ``data`` as a hexadecimal string."""
    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()
def create_image(db: Session, src, type_code, key, sha, countries: list[str], epsg: Optional[str], image_type: str,
                 center_lon: Optional[float] = None, center_lat: Optional[float] = None,
                 amsl_m: Optional[float] = None, agl_m: Optional[float] = None,
                 heading_deg: Optional[float] = None, yaw_deg: Optional[float] = None,
                 pitch_deg: Optional[float] = None, roll_deg: Optional[float] = None,
                 rtk_fix: Optional[bool] = None, std_h_m: Optional[float] = None, std_v_m: Optional[float] = None,
                 thumbnail_key: Optional[str] = None, thumbnail_sha256: Optional[str] = None,
                 detail_key: Optional[str] = None, detail_sha256: Optional[str] = None):
    """Create an ``Images`` row, attach its countries, and commit.

    Defaults applied before the insert:
      * ``type_code`` and ``epsg`` fall back to ``"OTHER"`` when ``None``.
      * ``src`` falls back to ``"OTHER"`` only when ``image_type`` is not
        ``"drone_image"``.
      * For any ``image_type`` other than ``"drone_image"`` the position,
        altitude, orientation, RTK and accuracy arguments are discarded
        (stored as ``None``).

    Each code in ``countries`` is looked up with ``db.get``; codes that do
    not resolve to a ``Country`` row are skipped silently.

    Returns the refreshed ``Images`` instance.
    """
    # These two defaults apply to every image type.
    if type_code is None:
        type_code = "OTHER"
    if epsg is None:
        epsg = "OTHER"

    if image_type != "drone_image":
        if src is None:
            src = "OTHER"
        # Drone-only fields are never stored for other image types.
        center_lon = center_lat = None
        amsl_m = agl_m = None
        heading_deg = yaw_deg = pitch_deg = roll_deg = None
        rtk_fix = None
        std_h_m = std_v_m = None

    column_values = {
        "source": src,
        "event_type": type_code,
        "file_key": key,
        "sha256": sha,
        "thumbnail_key": thumbnail_key,
        "thumbnail_sha256": thumbnail_sha256,
        "detail_key": detail_key,
        "detail_sha256": detail_sha256,
        "epsg": epsg,
        "image_type": image_type,
        "center_lon": center_lon,
        "center_lat": center_lat,
        "amsl_m": amsl_m,
        "agl_m": agl_m,
        "heading_deg": heading_deg,
        "yaw_deg": yaw_deg,
        "pitch_deg": pitch_deg,
        "roll_deg": roll_deg,
        "rtk_fix": rtk_fix,
        "std_h_m": std_h_m,
        "std_v_m": std_v_m,
    }
    image_row = models.Images(**column_values)
    db.add(image_row)
    db.flush()

    for country_code in countries:
        matched_country = db.get(models.Country, country_code)
        if matched_country:
            image_row.countries.append(matched_country)

    db.commit()
    db.refresh(image_row)
    return image_row
def get_images(db: Session):
    """Return all images, eager-loading their countries and captions.

    Captions are loaded together with each caption's own ``images``
    relationship.
    """
    with_countries = joinedload(models.Images.countries)
    with_captions = joinedload(models.Images.captions).joinedload(models.Captions.images)
    image_query = db.query(models.Images).options(with_countries, with_captions)
    return image_query.all()
def get_image(db: Session, image_id: str):
    """Return the image whose ``image_id`` matches, or ``None`` if absent.

    Countries and captions (with each caption's ``images``) are eager-loaded.
    """
    with_countries = joinedload(models.Images.countries)
    with_captions = joinedload(models.Images.captions).joinedload(models.Captions.images)
    image_query = db.query(models.Images).options(with_countries, with_captions)
    matching = image_query.filter(models.Images.image_id == image_id)
    return matching.first()
def create_caption(db: Session, image_id, title, prompt, model_code, raw_json, text, metadata=None, image_count=None):
    """Create a caption, link it to an existing image, and commit.

    When ``metadata`` is truthy it is stored under the
    ``"extracted_metadata"`` key of the JSON payload. The payload is copied
    first so the caller's ``raw_json`` object is never mutated, and a
    ``None`` ``raw_json`` is treated as an empty payload in that case.

    The caption schema is ``"drone_caption@1.0.0"`` for images whose
    ``image_type`` is ``"drone_image"`` and ``"default_caption@1.0.0"``
    otherwise. ``text`` is stored as both the ``generated`` and the
    ``edited`` value.

    Raises:
        HTTPException: 404 when no image exists for ``image_id``.

    Returns the refreshed ``Captions`` instance.
    """
    print(f"Creating caption for image_id: {image_id}")
    print(f"Caption data: title={title}, prompt={prompt}, model={model_code}")
    print(f"Database session ID: {id(db)}")
    print(f"Database session is active: {db.is_active}")

    if metadata:
        # Build a new dict instead of writing into the caller's object.
        raw_json = {**(raw_json or {}), "extracted_metadata": metadata}

    img = db.get(models.Images, image_id)
    if not img:
        raise HTTPException(404, "Image not found")

    # The schema depends on the image type; everything but drone imagery
    # uses the default schema.
    if img.image_type == "drone_image":
        schema_id = "drone_caption@1.0.0"
    else:
        schema_id = "default_caption@1.0.0"

    caption = models.Captions(
        title=title,
        prompt=prompt,
        model=model_code,
        schema_id=schema_id,
        raw_json=raw_json,
        generated=text,
        edited=text,
        image_count=image_count,
    )
    db.add(caption)
    db.flush()

    # Link caption to image
    img.captions.append(caption)

    print("About to commit caption to database...")
    db.commit()
    print("Caption commit successful!")
    db.refresh(caption)
    print(f"Caption created successfully for image: {img.image_id}")
    return caption
def get_caption(db: Session, caption_id: str):
    """Look up a caption by primary key; returns ``None`` when not found."""
    caption = db.get(models.Captions, caption_id)
    return caption
def get_captions_by_image(db: Session, image_id: str):
    """Return the captions linked to an image, or ``[]`` if the image is missing."""
    image = db.get(models.Images, image_id)
    return image.captions if image else []
def get_all_captions_with_images(db: Session):
    """Return every caption, eager-loading its images and their countries."""
    with_images = joinedload(models.Captions.images).joinedload(models.Images.countries)
    caption_query = db.query(models.Captions).options(with_images)
    return caption_query.all()
def get_prompts(db: Session):
    """Return every row of the prompts table."""
    prompt_query = db.query(models.Prompts)
    return prompt_query.all()
def get_prompt(db: Session, p_code: str):
    """Return the first prompt whose ``p_code`` matches, or ``None``."""
    code_matches = models.Prompts.p_code == p_code
    return db.query(models.Prompts).filter(code_matches).first()
def get_prompt_by_label(db: Session, label: str):
    """Return the first prompt whose ``label`` equals the given text, or ``None``."""
    label_matches = models.Prompts.label == label
    return db.query(models.Prompts).filter(label_matches).first()
def get_active_prompt_by_image_type(db: Session, image_type: str):
    """Return the first active prompt for ``image_type``, or ``None`` if there is none."""
    criteria = (
        models.Prompts.image_type == image_type,
        # Column comparison builds SQL, so ``== True`` is intentional here.
        models.Prompts.is_active == True,  # noqa: E712
    )
    return db.query(models.Prompts).filter(*criteria).first()
def toggle_prompt_active_status(db: Session, p_code: str, image_type: str):
    """Toggle the ``is_active`` flag of the prompt identified by ``p_code``.

    An active prompt is simply deactivated. An inactive prompt is activated
    after the prompt currently active for ``image_type`` (if any) has been
    deactivated. Both changes are written in a single transaction, so a
    failure while activating cannot leave the image type without an active
    prompt.

    Raises:
        ValueError: when ``image_type`` is not present in ``ImageTypes``.

    Returns the refreshed prompt, or ``None`` when ``p_code`` is unknown.
    """
    # Validate that the image_type exists
    image_type_obj = db.query(models.ImageTypes).filter(models.ImageTypes.image_type == image_type).first()
    if not image_type_obj:
        raise ValueError(f"Invalid image_type: {image_type}")

    # Get the prompt to toggle
    prompt = db.query(models.Prompts).filter(models.Prompts.p_code == p_code).first()
    if not prompt:
        return None

    if prompt.is_active:
        # Already active: toggling means switching it off.
        prompt.is_active = False
    else:
        # NOTE(review): the lookup uses the ``image_type`` argument, not
        # ``prompt.image_type`` - confirm callers always pass the prompt's
        # own image type.
        current_active = db.query(models.Prompts).filter(
            models.Prompts.image_type == image_type,
            models.Prompts.is_active == True  # noqa: E712
        ).first()
        if current_active:
            current_active.is_active = False
            # Emit the deactivating UPDATE before the activating one to
            # avoid a constraint violation, but keep both in the same
            # transaction (flush, not commit).
            db.flush()
        prompt.is_active = True

    db.commit()
    db.refresh(prompt)
    return prompt
def create_prompt(db: Session, prompt_data: schemas.PromptCreate):
    """Create a new prompt from ``prompt_data`` and commit it.

    When the new prompt is flagged active, the prompt currently active for
    the same image type (if any) is deactivated first. The deactivation and
    the insert share one transaction, so a failed insert does not leave the
    image type without an active prompt.

    Raises:
        ValueError: when ``prompt_data.image_type`` is not present in
            ``ImageTypes``, or a prompt with the same ``p_code`` already
            exists.

    Returns the refreshed ``Prompts`` instance.
    """
    # Validate that the image_type exists
    image_type_obj = db.query(models.ImageTypes).filter(models.ImageTypes.image_type == prompt_data.image_type).first()
    if not image_type_obj:
        raise ValueError(f"Invalid image_type: {prompt_data.image_type}")

    # Check if prompt code already exists
    existing_prompt = db.query(models.Prompts).filter(models.Prompts.p_code == prompt_data.p_code).first()
    if existing_prompt:
        raise ValueError(f"Prompt with code '{prompt_data.p_code}' already exists")

    # If this prompt is set as active, deactivate the currently active prompt for this image type
    if prompt_data.is_active:
        current_active = db.query(models.Prompts).filter(
            models.Prompts.image_type == prompt_data.image_type,
            models.Prompts.is_active == True  # noqa: E712
        ).first()
        if current_active:
            current_active.is_active = False
            # Emit the deactivating UPDATE before the INSERT to avoid a
            # constraint violation, but keep both in the same transaction
            # (flush, not commit).
            db.flush()

    # Create the new prompt
    new_prompt = models.Prompts(
        p_code=prompt_data.p_code,
        label=prompt_data.label,
        metadata_instructions=prompt_data.metadata_instructions,
        image_type=prompt_data.image_type,
        is_active=prompt_data.is_active,
    )
    db.add(new_prompt)
    db.commit()
    db.refresh(new_prompt)
    return new_prompt
def update_prompt(db: Session, p_code: str, prompt_update: schemas.PromptUpdate):
    """Apply the explicitly set fields of ``prompt_update`` to a prompt.

    When the update sets ``is_active`` to a truthy value, any other prompt
    currently active for the same image type is deactivated first. The
    deactivation and the field updates share one transaction, so a failed
    update does not leave the image type without an active prompt.

    Returns the refreshed prompt, or ``None`` when ``p_code`` is unknown.
    """
    prompt = db.query(models.Prompts).filter(models.Prompts.p_code == p_code).first()
    if not prompt:
        return None

    # Only fields the caller actually supplied are applied.
    update_data = prompt_update.dict(exclude_unset=True)

    # If we're setting this prompt as active, deactivate other prompts for this image type
    if update_data.get('is_active'):
        # NOTE(review): the lookup uses the prompt's current image_type; if
        # the update can also change image_type, confirm this is intended.
        current_active = db.query(models.Prompts).filter(
            models.Prompts.image_type == prompt.image_type,
            models.Prompts.is_active == True,  # noqa: E712
            models.Prompts.p_code != p_code  # Exclude current prompt
        ).first()
        if current_active:
            current_active.is_active = False
            # Emit the deactivating UPDATE first to avoid a constraint
            # violation, but keep it in the same transaction as the
            # activation (flush, not commit).
            db.flush()

    # Update all fields
    for field, value in update_data.items():
        setattr(prompt, field, value)

    db.commit()
    db.refresh(prompt)
    return prompt
def update_caption(db: Session, caption_id: str, update: schemas.CaptionUpdate):
    """Apply the explicitly set fields of ``update`` to a caption and commit.

    Returns the refreshed caption, or ``None`` when ``caption_id`` is unknown.
    """
    target = db.get(models.Captions, caption_id)
    if not target:
        return None

    changes = update.dict(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(target, attr_name, new_value)

    db.commit()
    db.refresh(target)
    return target
def delete_caption(db: Session, caption_id: str):
    """Delete a caption by primary key.

    Returns ``True`` when a caption was deleted and committed, ``False``
    when no caption exists for ``caption_id``.
    """
    doomed = db.get(models.Captions, caption_id)
    if doomed:
        db.delete(doomed)
        db.commit()
        return True
    return False
def get_sources(db: Session):
    """Return every ``Source`` row (lookup data)."""
    source_query = db.query(models.Source)
    return source_query.all()
def get_regions(db: Session):
    """Return every ``Region`` row (lookup data)."""
    region_query = db.query(models.Region)
    return region_query.all()
def get_types(db: Session):
    """Return every ``EventType`` row (lookup data)."""
    event_type_query = db.query(models.EventType)
    return event_type_query.all()
def get_spatial_references(db: Session):
    """Return every ``SpatialReference`` row (lookup data)."""
    srs_query = db.query(models.SpatialReference)
    return srs_query.all()
def get_image_types(db: Session):
    """Return every ``ImageTypes`` row (lookup data)."""
    image_type_query = db.query(models.ImageTypes)
    return image_type_query.all()
def get_countries(db: Session):
    """Return every ``Country`` row (lookup data)."""
    country_query = db.query(models.Country)
    return country_query.all()
def get_country(db: Session, c_code: str):
    """Look up a country by primary key; returns ``None`` when not found."""
    country = db.get(models.Country, c_code)
    return country
def get_models(db: Session):
    """Return every ``Models`` row."""
    model_query = db.query(models.Models)
    return model_query.all()
def get_model(db: Session, m_code: str):
    """Look up a model by primary key; returns ``None`` when not found."""
    found = db.get(models.Models, m_code)
    return found
def create_model(db: Session, m_code: str, label: str, model_type: str, provider: str, model_id: str, is_available: bool = False):
    """Insert a new ``Models`` row, commit, and return the refreshed instance."""
    column_values = {
        "m_code": m_code,
        "label": label,
        "model_type": model_type,
        "provider": provider,
        "model_id": model_id,
        "is_available": is_available,
    }
    created = models.Models(**column_values)
    db.add(created)
    db.commit()
    db.refresh(created)
    return created
def update_model(db: Session, m_code: str, update_data: dict):
    """Update attributes of an existing model from ``update_data`` and commit.

    Keys that are not attributes of the model instance are ignored.
    Returns the refreshed model, or ``None`` when ``m_code`` is unknown.
    """
    target = db.get(models.Models, m_code)
    if not target:
        return None

    for attr_name, new_value in update_data.items():
        if not hasattr(target, attr_name):
            continue  # skip keys the model does not define
        setattr(target, attr_name, new_value)

    db.commit()
    db.refresh(target)
    return target
def delete_model(db: Session, m_code: str):
    """Hard-delete a model row by primary key.

    Returns ``True`` when a model was deleted and committed, ``False`` when
    no model exists for ``m_code``.
    """
    doomed = db.get(models.Models, m_code)
    if doomed:
        # Remove the row itself rather than flagging it.
        db.delete(doomed)
        db.commit()
        return True
    return False
def get_all_schemas(db: Session):
    """Return every ``JSONSchema`` row."""
    schema_query = db.query(models.JSONSchema)
    return schema_query.all()
def get_schema(db: Session, schema_id: str):
    """Return the first JSON schema whose ``schema_id`` matches, or ``None``."""
    id_matches = models.JSONSchema.schema_id == schema_id
    return db.query(models.JSONSchema).filter(id_matches).first()
def get_schemas_by_image_type(db: Session, image_type: str):
    """Return all JSON schemas whose ``image_type`` equals the given value."""
    type_matches = models.JSONSchema.image_type == image_type
    return db.query(models.JSONSchema).filter(type_matches).all()
def get_recent_images_with_validation(db: Session, limit: int = 100):
    """Return up to ``limit`` images, newest ``captured_at`` first.

    This function only orders and limits the query; it performs no
    validation itself.
    """
    newest_first = models.Images.captured_at.desc()
    recent_query = db.query(models.Images).order_by(newest_first).limit(limit)
    return recent_query.all()
# Fallback model CRUD operations
def get_fallback_model(db: Session) -> Optional[str]:
    """Return the ``m_code`` of the first model flagged as fallback, or ``None``."""
    flagged = db.query(models.Models).filter(models.Models.is_fallback == True).first()  # noqa: E712
    if flagged:
        return flagged.m_code
    return None
def set_fallback_model(db: Session, model_code: str):
    """Make ``model_code`` the only model flagged as fallback.

    The target model is looked up before anything is modified, so an
    unknown ``model_code`` returns ``None`` without touching the existing
    fallback flag (and without leaving an uncommitted UPDATE pending in the
    session).

    Returns the refreshed model, or ``None`` when ``model_code`` is unknown.
    """
    model = db.query(models.Models).filter(models.Models.m_code == model_code).first()
    if not model:
        return None

    # Clear the flag on every other model; the target is excluded so its
    # in-session state is only changed by the assignment below.
    db.query(models.Models).filter(
        models.Models.is_fallback == True,  # noqa: E712
        models.Models.m_code != model_code,
    ).update({"is_fallback": False})

    model.is_fallback = True
    db.commit()
    db.refresh(model)
    return model