import hashlib
import io
import logging
from typing import List, Optional

from fastapi import HTTPException
from sqlalchemy.orm import Session, joinedload

from . import models, schemas

logger = logging.getLogger(__name__)


def hash_bytes(data: bytes) -> str:
    """Compute SHA-256 hex digest of the data."""
    return hashlib.sha256(data).hexdigest()


def create_image(
    db: Session,
    src,
    type_code,
    key,
    sha,
    countries: list[str],
    epsg: Optional[str],
    image_type: str,
    center_lon: Optional[float] = None,
    center_lat: Optional[float] = None,
    amsl_m: Optional[float] = None,
    agl_m: Optional[float] = None,
    heading_deg: Optional[float] = None,
    yaw_deg: Optional[float] = None,
    pitch_deg: Optional[float] = None,
    roll_deg: Optional[float] = None,
    rtk_fix: Optional[bool] = None,
    std_h_m: Optional[float] = None,
    std_v_m: Optional[float] = None,
    thumbnail_key: Optional[str] = None,
    thumbnail_sha256: Optional[str] = None,
    detail_key: Optional[str] = None,
    detail_sha256: Optional[str] = None,
):
    """Insert a row into images and link it to its countries.

    Missing ``type_code`` and ``epsg`` default to ``"OTHER"`` for every image
    type. For anything other than ``"drone_image"``, a missing ``src`` also
    defaults to ``"OTHER"`` and all drone-only fields (position, altitude,
    orientation, RTK fix and standard deviations) are stored as ``None``
    regardless of what was passed in.

    Country codes that cannot be found via ``db.get(models.Country, code)``
    are skipped silently. Repeated codes are linked only once, and
    ``countries=None`` is treated as an empty list.

    Returns:
        The committed and refreshed ``models.Images`` instance.
    """
    if type_code is None:
        type_code = "OTHER"
    if epsg is None:
        epsg = "OTHER"

    if image_type != "drone_image":
        if src is None:
            src = "OTHER"
        # Drone telemetry is meaningless for other image types; never store it.
        center_lon = center_lat = None
        amsl_m = agl_m = None
        heading_deg = yaw_deg = pitch_deg = roll_deg = None
        rtk_fix = None
        std_h_m = std_v_m = None

    img = models.Images(
        source=src,
        event_type=type_code,
        file_key=key,
        sha256=sha,
        thumbnail_key=thumbnail_key,
        thumbnail_sha256=thumbnail_sha256,
        detail_key=detail_key,
        detail_sha256=detail_sha256,
        epsg=epsg,
        image_type=image_type,
        center_lon=center_lon,
        center_lat=center_lat,
        amsl_m=amsl_m,
        agl_m=agl_m,
        heading_deg=heading_deg,
        yaw_deg=yaw_deg,
        pitch_deg=pitch_deg,
        roll_deg=roll_deg,
        rtk_fix=rtk_fix,
        std_h_m=std_h_m,
        std_v_m=std_v_m,
    )
    db.add(img)
    # Flush so the image row exists before the country links are created.
    db.flush()

    # dict.fromkeys de-duplicates while preserving the caller's order, so the
    # same country is never appended to the relationship twice.
    for code in dict.fromkeys(countries or ()):
        country = db.get(models.Country, code)
        if country:
            img.countries.append(country)

    db.commit()
    db.refresh(img)
    return img


def get_images(db: Session):
    """Get all images with their countries and captions"""
    return (
        db.query(models.Images)
        .options(
            joinedload(models.Images.countries),
            joinedload(models.Images.captions).joinedload(models.Captions.images),
        )
        .all()
    )


def get_image(db: Session, image_id: str):
    """Get a single image by ID with its countries and captions

    Returns ``None`` when no image has the given ``image_id``.
    """
    return (
        db.query(models.Images)
        .options(
            joinedload(models.Images.countries),
            joinedload(models.Images.captions).joinedload(models.Captions.images),
        )
        .filter(models.Images.image_id == image_id)
        .first()
    )


def create_caption(db: Session, image_id, title, prompt, model_code, raw_json, text, metadata=None, image_count=None):
    """Create a caption and link it to an existing image.

    The caption's ``schema_id`` is ``"drone_caption@1.0.0"`` for images whose
    ``image_type`` is ``"drone_image"`` and ``"default_caption@1.0.0"``
    otherwise. ``text`` is stored as both the ``generated`` and the ``edited``
    value.

    When ``metadata`` is truthy it is stored under the
    ``"extracted_metadata"`` key of a *copy* of ``raw_json``; the caller's
    mapping is never modified, and ``raw_json=None`` is treated as empty in
    that case.

    Raises:
        HTTPException: 404 when no image with ``image_id`` exists.

    Returns:
        The committed and refreshed ``models.Captions`` instance.
    """
    logger.debug("Creating caption for image_id: %s", image_id)
    logger.debug("Caption data: title=%s, prompt=%s, model=%s", title, prompt, model_code)
    logger.debug("Database session ID: %s", id(db))
    logger.debug("Database session is active: %s", db.is_active)

    if metadata:
        # Build a new dict instead of writing into the argument.
        raw_json = {**(raw_json or {}), "extracted_metadata": metadata}

    img = db.get(models.Images, image_id)
    if not img:
        raise HTTPException(404, "Image not found")

    # Set schema based on image type
    schema_id = "default_caption@1.0.0"  # default
    if img.image_type == "drone_image":
        schema_id = "drone_caption@1.0.0"

    caption = models.Captions(
        title=title,
        prompt=prompt,
        model=model_code,
        schema_id=schema_id,
        raw_json=raw_json,
        generated=text,
        edited=text,
        image_count=image_count,
    )
    db.add(caption)
    db.flush()

    # Link caption to image
    img.captions.append(caption)

    logger.debug("About to commit caption to database...")
    db.commit()
    logger.debug("Caption commit successful!")
    db.refresh(caption)
    logger.debug("Caption created successfully for image: %s", img.image_id)
    return caption


def get_caption(db: Session, caption_id: str):
    """Get caption data for a specific caption ID"""
    return db.get(models.Captions, caption_id)


def get_captions_by_image(db: Session, image_id: str):
    """Get all captions for a specific image

    Returns an empty list when the image does not exist.
    """
    img = db.get(models.Images, image_id)
    if img:
        return img.captions
    return []


def get_all_captions_with_images(db: Session):
    """Get all captions with their associated images"""
    return (
        db.query(models.Captions)
        .options(
            joinedload(models.Captions.images).joinedload(models.Images.countries),
        )
        .all()
    )


def get_prompts(db: Session):
    """Get all available prompts"""
    return db.query(models.Prompts).all()


def get_prompt(db: Session, p_code: str):
    """Get a specific prompt by code"""
    return db.query(models.Prompts).filter(models.Prompts.p_code == p_code).first()


def get_prompt_by_label(db: Session, label: str):
    """Get a specific prompt by label text"""
    return db.query(models.Prompts).filter(models.Prompts.label == label).first()


def get_active_prompt_by_image_type(db: Session, image_type: str):
    """Get the active prompt for a specific image type"""
    return db.query(models.Prompts).filter(
        models.Prompts.image_type == image_type,
        models.Prompts.is_active == True,  # noqa: E712 - SQL expression, not a Python bool test
    ).first()


def _require_image_type(db: Session, image_type: str) -> None:
    """Raise ValueError unless ``image_type`` is present in ImageTypes."""
    known = (
        db.query(models.ImageTypes)
        .filter(models.ImageTypes.image_type == image_type)
        .first()
    )
    if not known:
        raise ValueError(f"Invalid image_type: {image_type}")


def toggle_prompt_active_status(db: Session, p_code: str, image_type: str):
    """Toggle the active status of a prompt for a specific image type

    An active prompt is simply deactivated. An inactive prompt is activated
    after the prompt currently active for ``image_type`` (if any) has been
    deactivated; both changes are committed in a single transaction.

    Raises:
        ValueError: when ``image_type`` is not a known image type.

    Returns:
        The refreshed prompt, or ``None`` when ``p_code`` does not exist.
    """
    _require_image_type(db, image_type)

    prompt = get_prompt(db, p_code)
    if not prompt:
        return None

    if prompt.is_active:
        # Already active: the toggle just switches it off.
        prompt.is_active = False
    else:
        # NOTE(review): the lookup uses the ``image_type`` argument, not
        # ``prompt.image_type`` - confirm callers always pass matching values.
        current_active = get_active_prompt_by_image_type(db, image_type)
        if current_active:
            current_active.is_active = False
            # Emit the deactivating UPDATE before the activating one so the
            # two never coexist, without committing a half-finished swap.
            db.flush()
        prompt.is_active = True

    db.commit()
    db.refresh(prompt)
    return prompt


def create_prompt(db: Session, prompt_data: schemas.PromptCreate):
    """Create a new prompt

    When the new prompt is flagged active, the prompt currently active for
    the same image type is deactivated in the same transaction.

    Raises:
        ValueError: when the image type is unknown or the prompt code is
            already in use.

    Returns:
        The committed and refreshed ``models.Prompts`` instance.
    """
    _require_image_type(db, prompt_data.image_type)

    # Check if prompt code already exists
    if get_prompt(db, prompt_data.p_code):
        raise ValueError(f"Prompt with code '{prompt_data.p_code}' already exists")

    if prompt_data.is_active:
        current_active = get_active_prompt_by_image_type(db, prompt_data.image_type)
        if current_active:
            current_active.is_active = False
            # Flush (not commit) so the UPDATE precedes the INSERT but is
            # rolled back together with it if the INSERT fails.
            db.flush()

    new_prompt = models.Prompts(
        p_code=prompt_data.p_code,
        label=prompt_data.label,
        metadata_instructions=prompt_data.metadata_instructions,
        image_type=prompt_data.image_type,
        is_active=prompt_data.is_active,
    )
    db.add(new_prompt)
    db.commit()
    db.refresh(new_prompt)
    return new_prompt


def update_prompt(db: Session, p_code: str, prompt_update: schemas.PromptUpdate):
    """Update a specific prompt by code

    Only fields explicitly set on ``prompt_update`` are applied. When the
    update sets ``is_active`` to a truthy value, another prompt that is
    active for the same image type is deactivated in the same transaction.

    Returns:
        The refreshed prompt, or ``None`` when ``p_code`` does not exist.
    """
    prompt = get_prompt(db, p_code)
    if not prompt:
        return None

    update_data = prompt_update.dict(exclude_unset=True)

    if update_data.get("is_active"):
        current_active = db.query(models.Prompts).filter(
            models.Prompts.image_type == prompt.image_type,
            models.Prompts.is_active == True,  # noqa: E712 - SQL expression
            models.Prompts.p_code != p_code,  # Exclude current prompt
        ).first()
        if current_active:
            current_active.is_active = False
            # Send the deactivation first, but keep it in this transaction.
            db.flush()

    for field, value in update_data.items():
        setattr(prompt, field, value)

    db.commit()
    db.refresh(prompt)
    return prompt


def update_caption(db: Session, caption_id: str, update: schemas.CaptionUpdate):
    """Update caption data for a caption

    Only fields explicitly set on ``update`` are applied. Returns ``None``
    when the caption does not exist.
    """
    caption = db.get(models.Captions, caption_id)
    if not caption:
        return None

    for field, value in update.dict(exclude_unset=True).items():
        setattr(caption, field, value)

    db.commit()
    db.refresh(caption)
    return caption


def delete_caption(db: Session, caption_id: str):
    """Delete caption data for a caption

    Returns ``True`` when a caption was deleted, ``False`` when none exists.
    """
    caption = db.get(models.Captions, caption_id)
    if not caption:
        return False

    db.delete(caption)
    db.commit()
    return True


def get_sources(db: Session):
    """Get all sources for lookup"""
    return db.query(models.Source).all()


def get_regions(db: Session):
    """Get all regions for lookup"""
    return db.query(models.Region).all()


def get_types(db: Session):
    """Get all types for lookup"""
    return db.query(models.EventType).all()


def get_spatial_references(db: Session):
    """Get all spatial references for lookup"""
    return db.query(models.SpatialReference).all()


def get_image_types(db: Session):
    """Get all image types for lookup"""
    return db.query(models.ImageTypes).all()


def get_countries(db: Session):
    """Get all countries for lookup"""
    return db.query(models.Country).all()


def get_country(db: Session, c_code: str):
    """Get a single country by code"""
    return db.get(models.Country, c_code)


def get_models(db: Session):
    """Get all models"""
    return db.query(models.Models).all()


def get_model(db: Session, m_code: str):
    """Get a specific model by code"""
    return db.get(models.Models, m_code)


def create_model(db: Session, m_code: str, label: str, model_type: str, provider: str, model_id: str, is_available: bool = False):
    """Create a new model

    Returns the committed and refreshed ``models.Models`` instance.
    """
    new_model = models.Models(
        m_code=m_code,
        label=label,
        model_type=model_type,
        provider=provider,
        model_id=model_id,
        is_available=is_available,
    )
    db.add(new_model)
    db.commit()
    db.refresh(new_model)
    return new_model


def update_model(db: Session, m_code: str, update_data: dict):
    """Update an existing model

    Keys of ``update_data`` that are not attributes of the model are ignored.
    Returns ``None`` when the model does not exist.
    """
    model = db.get(models.Models, m_code)
    if not model:
        return None

    for field, value in update_data.items():
        if hasattr(model, field):
            setattr(model, field, value)

    db.commit()
    db.refresh(model)
    return model


def delete_model(db: Session, m_code: str):
    """Hard delete a model by removing it from the database

    Returns ``True`` when a model was deleted, ``False`` when none exists.
    """
    model = db.get(models.Models, m_code)
    if not model:
        return False

    db.delete(model)
    db.commit()
    return True


def get_all_schemas(db: Session):
    """Get all JSON schemas"""
    return db.query(models.JSONSchema).all()


def get_schema(db: Session, schema_id: str):
    """Get a specific JSON schema by ID"""
    return db.query(models.JSONSchema).filter(models.JSONSchema.schema_id == schema_id).first()


def get_recent_images_with_validation(db: Session, limit: int = 100):
    """Get recent images with validation info

    Returns at most ``limit`` images ordered by ``captured_at``, newest first.
    """
    return (
        db.query(models.Images)
        .order_by(models.Images.captured_at.desc())
        .limit(limit)
        .all()
    )