Model Card β€” FloodDetNet-Prithvi v5

Physics-informed satellite flood segmentation with cross-modal SAR/optical fusion, built on the IBM/NASA Prithvi-EO-1.0-100M geospatial foundation model.

Model Details

Model Description

FloodDetNet-Prithvi v5 is a pixel-level flood segmentation model for 12-channel satellite patches (6 optical + 5 SAR-derived + 1 JRC permanent-water prior). It fine-tunes the Prithvi-EO-1.0-100M temporal Vision Transformer as the optical backbone, augmented by a parallel SAR encoder, cross-modal attention, and a physics-informed water-body gate that uses the 37-year JRC Global Surface Water dataset to suppress false positives over historically wet areas.

The model outputs binary flood masks (flood / no-flood) at full input resolution via sliding-window inference with optional 4-flip test-time augmentation. At evaluation time, waterbody pixels are reconstructed from the binary prediction using the JRC prior, enabling 3-class IoU reporting (no-flood / flood / waterbody) against the original 3-class ground truth.

  • Developed by: AI-Hackers
  • Funded by: ANRF (Anusandhan National Research Foundation) AISEHack Programme
  • Model type: Semantic segmentation — geospatial / Earth observation
  • Input modalities: Optical (HLS-style 6-band) + SAR (HH, HV, SAR_diff, log_HH, log_HV) + JRC binary water prior
  • Output: Binary flood mask (0 = no-flood, 1 = flood), 2-channel softmax probability map
  • License: MIT (see checkpoint/LICENSE)
  • Fine-tuned from: ibm-nasa-geospatial/Prithvi-EO-1.0-100M

Model Sources

  • Model weights: prithvi_best-v7.ckpt — epoch 125, val flood IoU 0.368 (Google Drive / Zenodo link)

Uses

Direct Use

Run flood segmentation inference on any 512×512 (or larger) 6-band satellite patch (HH, HV, Green, Red, NIR, SWIR). The four optical bands are expanded internally into the 6-channel HLS layout Prithvi expects (Green and SWIR duplicated), and the model handles arbitrary patch sizes via sliding-window inference (window 224×224, stride 112).
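The sliding-window stitching can be sketched as follows. This is a simplified stand-in for the repo's sliding_window_predict (its internals are an assumption), averaging softmax probabilities over overlapping 224×224 tiles:

```python
import torch

def sliding_window_sketch(model, img, crop=224, stride=112, device="cpu"):
    """Average overlapping window predictions into a full-size probability map."""
    _, H, W = img.shape
    prob = torch.zeros(2, H, W)
    count = torch.zeros(1, H, W)
    ys = list(range(0, max(H - crop, 0) + 1, stride))
    xs = list(range(0, max(W - crop, 0) + 1, stride))
    # make sure the bottom/right edges are covered
    if ys[-1] != H - crop:
        ys.append(H - crop)
    if xs[-1] != W - crop:
        xs.append(W - crop)
    with torch.no_grad():
        for y in ys:
            for x in xs:
                tile = img[:, y:y + crop, x:x + crop].unsqueeze(0).to(device)
                out = torch.softmax(model(tile), dim=1)[0].cpu()
                prob[:, y:y + crop, x:x + crop] += out
                count[:, y:y + crop, x:x + crop] += 1
    return prob / count
```

Overlapping tiles are averaged rather than overwritten, which smooths seams at window boundaries.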

Intended users:

  • NDRF field commanders — real-time flood extent maps for route passability decisions
  • District agriculture officers — pixel-level flood vs. water-body proof for PMFBY crop insurance claims
  • State Emergency Operations Centre analysts — multi-event flood extent tracking
  • Remote sensing researchers — flood segmentation benchmark and foundation model fine-tuning reference

Downstream Use

The model is integrated into the FloodSense web application as the inference backend:

  • Served via AWS SageMaker endpoint (primary) or HuggingFace Inference API (fallback)
  • Results feed a MapLibre GL JS flood overlay, Gemini AI chatbot, and automated PDF report generation
  • Outputs can be exported as RLE masks (competition format), GeoTIFF, or village-level CSV statistics
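For illustration, a minimal run-length encoder/decoder over the flattened binary mask. The exact competition RLE convention is not specified here, so treat this (start, length) scheme as a sketch only:

```python
import numpy as np

def rle_encode(mask):
    """Encode a binary (H, W) mask as (start, length) runs of 1s.
    Starts are 0-based indices into the row-major flattened mask."""
    flat = mask.flatten().astype(np.uint8)
    # pad so runs touching the array edges are detected
    padded = np.concatenate([[0], flat, [0]])
    changes = np.flatnonzero(padded[1:] != padded[:-1])
    starts, ends = changes[::2], changes[1::2]
    return np.stack([starts, ends - starts], axis=1)

def rle_decode(runs, shape):
    """Inverse of rle_encode: rebuild the (H, W) binary mask."""
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in runs:
        flat[start:start + length] = 1
    return flat.reshape(shape)
```

Round-tripping a mask through rle_encode/rle_decode should reproduce it exactly, which makes the pair easy to sanity-check before submission.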

Out-of-Scope Use

  • Single-modality optical-only input: The model expects SAR channels 6–10; optical-only patches will produce degraded results
  • Non-flood water detection: The model is trained to detect transient flood water, not permanent water bodies, rivers, or ocean
  • Sub-metre resolution imagery: Trained on 30 m HLS-resolution patches; very high resolution imagery (< 5 m) is out of distribution
  • Real-time video or streaming: Designed for batch patch inference, not streaming
  • Regions with no JRC coverage: The JRCGate falls back to zeros outside the downloaded tile footprint, reducing waterbody suppression accuracy

Bias, Risks, and Limitations

Geographic bias: The model was trained exclusively on a single flood event in West Bengal / Bangladesh (May 2024, patches 001–079). Performance on other geographies, flood types (flash floods, coastal inundation, urban flooding), or seasons is unknown and likely degraded.

Class imbalance: Flood pixels constitute only ~12% of the training data. The model is tuned for high recall (Tversky α=0.90) at the cost of precision — it will over-predict flood in ambiguous boundary regions.

JRC tile dependency: Waterbody suppression accuracy depends on the JRC Global Surface Water tile covering the inference region. Images outside the occurrence_80E_30N tile footprint receive no permanent-water prior, causing elevated false-positive rates in water-body regions (test WB FP rate without JRC: ~0.34).

Train/val generalisation gap: Flood IoU drops from 0.34 (train) to 0.15 (val/test), indicating overfitting on the 59-image training set. The model should be fine-tuned on local data before operational deployment in a new region.

No temporal context: Despite Prithvi's native multi-temporal capability, this model uses single-frame input (T=1). Pre-flood / post-flood change detection would significantly improve accuracy.

Recommendations

  • Always validate on locally labelled data before operational use
  • Use the JRC waterbody mask as a post-processing filter in addition to the JRCGate
  • Set the flood threshold (default 0.45) based on local precision/recall requirements — lower for higher recall (emergency response), higher for higher precision (insurance claims)
  • Do not use model outputs as sole evidence for legal or financial decisions without human expert review
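Threshold selection against a locally labelled patch can be sketched as a simple sweep (sweep_thresholds is a hypothetical helper; prob is the flood probability map, gt a binary ground-truth mask):

```python
import numpy as np

def sweep_thresholds(prob, gt, thresholds=np.arange(0.20, 0.81, 0.05)):
    """Report (threshold, precision, recall, IoU) rows to guide the operating point."""
    rows = []
    for t in thresholds:
        pred = prob >= t
        tp = np.logical_and(pred, gt == 1).sum()
        fp = np.logical_and(pred, gt == 0).sum()
        fn = np.logical_and(~pred, gt == 1).sum()
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        iou = tp / max(tp + fp + fn, 1)
        rows.append((round(float(t), 2), precision, recall, iou))
    return rows
```

Pick the lowest threshold meeting the local precision floor for insurance work, or the highest threshold meeting the recall floor for emergency response.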

How to Get Started with the Model

import torch
import numpy as np
from pathlib import Path

# Clone repo and install dependencies
# pip install torch lightning timm rasterio scipy

from best import (
    _load_model_from_checkpoint, cfg,
    read_tif, normalise_optical, normalise_sar,
    compute_sar_features, load_jrc_channel,
    sliding_window_predict, postprocess_mask,
)

# Point JRC dir to your pre-generated masks
cfg.JRC_DIR = Path("data/jrc_masks")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
model = _load_model_from_checkpoint("prithvi_best-v7.ckpt", device)
model.to(device).eval()

# Prepare input (6-band GeoTIFF)
img_path = Path("data/image/20240529_EO4_RES2_fl_pid_002_image.tif")
raw = read_tif(img_path)                                    # (6, H, W)
opt = normalise_optical(raw)                                # (6, H, W)
sar = normalise_sar(compute_sar_features(raw))              # (5, H, W)
jrc = load_jrc_channel(img_path, raw.shape[1:])             # (1, H, W)
img_t = torch.from_numpy(np.concatenate([opt, sar, jrc], axis=0)).float()

# Inference
with torch.no_grad():
    prob = sliding_window_predict(model, img_t, crop=224, stride=112, device=device)
    # prob: (2, H, W) β€” channel 1 is flood probability

# Post-process
flood_mask = postprocess_mask(prob[1].cpu().numpy(), threshold=0.45)
# flood_mask: (H, W) binary uint8 β€” 1 = flood, 0 = no-flood

print(f"Flood pixels: {flood_mask.sum()} ({100*flood_mask.mean():.1f}%)")

For 4-flip TTA (recommended for submission):

from best import tta_predict
prob = tta_predict(model, img_t, device=device)
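If tta_predict is unavailable, the 4-flip ensemble reduces to averaging predictions over horizontal/vertical flips. A minimal sketch (tta4 is a hypothetical helper, not the repo function), where predict is any callable returning a (2, H, W) probability map:

```python
import torch

def tta4(predict, img):
    """Average predictions over the 4 horizontal/vertical flip combinations."""
    probs = []
    for flip_h in (False, True):
        for flip_v in (False, True):
            x = img
            if flip_h:
                x = torch.flip(x, dims=[-1])
            if flip_v:
                x = torch.flip(x, dims=[-2])
            p = predict(x)
            # undo the flips on the prediction before averaging
            if flip_v:
                p = torch.flip(p, dims=[-2])
            if flip_h:
                p = torch.flip(p, dims=[-1])
            probs.append(p)
    return torch.stack(probs).mean(dim=0)
```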

Training Details

Training Data

  • Competition dataset: 79 labelled 512×512 satellite patches, West Bengal / Bangladesh flood event, May 2024. 6-band float32 GeoTIFF (HH, HV, Green, Red, NIR, SWIR). 3-class labels: 0 = no-flood, 1 = flood, 2 = waterbody. Provided by ANRF AISEHack organisers.
  • Split: Stratified 59 train / 10 val / 10 test by flood pixel percentage
  • Class distribution: flood ~12%, waterbody ~55%, no-flood ~33%
  • External — JRC Global Surface Water v1.4 (Pekel et al., 2016): 37-year occurrence layer, tile occurrence_80E_30N, reprojected per-image to match patch CRS/resolution. Used as binary permanent-water prior (channel 11, threshold ≥ 75% occurrence).
  • External — Prithvi-EO-1.0-100M pretrained weights: IBM/NASA-IMPACT, pretrained on 1M+ HLS (Harmonized Landsat Sentinel-2) scenes at 30 m resolution.
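Thresholding the reprojected occurrence raster into the binary prior can be sketched as follows (the 255-nodata convention is an assumption about the JRC byte rasters):

```python
import numpy as np

def jrc_binary_prior(occurrence, threshold=75):
    """Turn a JRC occurrence raster (0-100 %) into a 0/1 permanent-water prior.
    Nodata (assumed 255 in the byte rasters) falls back to 0, i.e. no prior."""
    occ = occurrence.astype(np.float32)
    occ[occurrence == 255] = 0          # nodata convention (assumption)
    return (occ >= threshold).astype(np.float32)[None]   # (1, H, W)
```

Falling back to 0 outside valid coverage matches the model card's note that the JRCGate degrades gracefully (but loses waterbody suppression) where the tile has no data.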

Training Procedure

Preprocessing

Optical (channels 0–5):

PRITHVI_MEANS = np.array([0.0333, 0.0570, 0.0589, 0.2323, 0.1973, 0.1194]).reshape(6, 1, 1)
PRITHVI_STDS  = np.array([0.0227, 0.0268, 0.0400, 0.0779, 0.0871, 0.0724]).reshape(6, 1, 1)

hls = raw[[2, 2, 3, 4, 5, 5]] * 0.0001      # Green/SWIR duplicated into the 6-band HLS layout; scale to reflectance
opt = (hls - PRITHVI_MEANS) / PRITHVI_STDS  # official Prithvi normalisation

SAR (channels 6–10): HH, HV, HH−HV, log(HH), log(HV), z-score normalised.
Means: [786.83, 357.99, 434.18, 6.47, 5.78]
Stds: [373.43, 148.94, 305.09, 0.55, 0.62]

JRC (channel 11): Binary 0/1 (occurrence ≥ 75% → 1).

Label: Waterbody (class 2) remapped to no-flood (class 0) during training. Ignore index: −1 (nodata pixels).

Augmentation: horizontal/vertical flip, rot90, elastic deformation (p=0.3), SAR gamma speckle (looks=4), optical brightness jitter (std=0.25), cutout (48×48), MixUp (α=0.3), 8 random 224×224 crops per image per epoch.
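The SAR gamma speckle term can be sketched as multiplicative Gamma(looks, 1/looks) noise with unit mean, a standard multi-look speckle model (a sketch, not necessarily the exact training implementation):

```python
import numpy as np

def gamma_speckle(sar, looks=4, rng=None):
    """Multiply SAR channels by Gamma(looks, 1/looks) noise (mean 1, std 1/sqrt(looks)),
    simulating multi-look speckle. Intended for linear-scale channels; log channels
    would need recomputing afterwards."""
    rng = np.random.default_rng(rng)
    noise = rng.gamma(shape=looks, scale=1.0 / looks, size=sar.shape)
    return sar * noise.astype(sar.dtype)
```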

Training Hyperparameters

  • Training regime: fp16 mixed precision (PyTorch Lightning precision="16-mixed")
  • Optimiser: AdamW, weight decay 0.05
  • Learning rate: backbone 5e-5, head/neck/SAR 3e-4 (differential LR)
  • Scheduler: OneCycleLR, cosine annealing, pct_start=0.1, div_factor=10, final_div_factor=1e4
  • Batch size: 4
  • Epochs: 150 (early stopping patience 40 on val/flood_iou)
  • Backbone freeze: epochs 0–9 (backbone frozen, only head trained)
  • Gradient clipping: 1.0
  • Loss: L = 0.15·CE + 0.25·Dice + 0.60·Tversky(α=0.90, β=0.10) + 0.05·SAR_KD + 0.20·FocalFlood(γ=3)
  • Class weights: [0.20, 0.80] (no-flood, flood)
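The dominant Tversky term (α=0.90, β=0.10) can be sketched as follows; with α ≫ β, missed flood pixels (FN) cost nine times more than false alarms (FP):

```python
import torch

def tversky_loss(logits, target, alpha=0.90, beta=0.10, eps=1e-6, ignore_index=-1):
    """Soft Tversky loss on the flood channel of (B, 2, H, W) logits.
    alpha >> beta penalises missed flood pixels far more than false alarms."""
    prob = torch.softmax(logits, dim=1)[:, 1]        # flood probability
    valid = (target != ignore_index).float()
    tgt = (target == 1).float() * valid
    p = prob * valid
    tp = (p * tgt).sum()
    fn = ((1 - p) * tgt).sum()
    fp = (p * (1 - tgt) * valid).sum()
    return 1 - (tp + eps) / (tp + alpha * fn + beta * fp + eps)
```

With alpha=beta=0.5 this reduces to soft Dice, which is why Tversky is described as its asymmetric variant in the glossary.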

Speeds, Sizes, Times

  • Training hardware: NVIDIA T4 16 GB (single GPU)
  • Training time: ~8 hours (125 epochs to best checkpoint)
  • Best checkpoint: epoch 125 (step 14,868)
  • Checkpoint size (full): 1,052 MB
  • Checkpoint size (inference-only): ~400 MB
  • Inference time (T4 GPU): ~8 s per 512×512 patch (sliding window + 4-flip TTA)
  • Inference time (CPU): ~45 s per 512×512 patch
  • Total parameters: ~110M (ViT-Base backbone + heads)

Evaluation

Testing Data, Factors & Metrics

Testing Data

  • Val split: 10 images (pid_060–069), held out during training
  • Test split: 10 images (pid_070–079), never seen during training or hyperparameter tuning
  • Flood-rich subset: Top-10 training images by flood pixel count (pid_002, 010, 011, 015–017, 020, 021, 027, 028)

Factors

Evaluation is disaggregated by:

  • Split (train / val / test) — measures generalisation
  • Per-image — identifies high/low-performing patches
  • JRC coverage — images with JRC tile coverage vs. without (WB IoU = 0 without coverage)
  • Flood prevalence — flood-rich images (> 30% flood pixels) vs. flood-sparse (< 5%)

Metrics

  • Flood IoU: Intersection-over-Union for class 1 (flood) — primary competition metric
  • No-Flood IoU: IoU for class 0 (no-flood)
  • WaterBody IoU: IoU for class 2 (waterbody), reconstructed via JRC: pred=0 ∧ JRC=1 → class 2
  • mIoU (3-class): mean IoU over classes present in ground truth
  • mIoU (binary): mean of flood + no-flood IoU (matches training objective)
  • Overall Pixel Accuracy: fraction of correctly classified valid pixels
  • WB FP Rate: fraction of waterbody GT pixels predicted as flood
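The JRC-based waterbody reconstruction and per-class IoU can be sketched as:

```python
import numpy as np

def reconstruct_3class(pred_binary, jrc):
    """Binary prediction (0/1) -> 3-class map: pixels predicted no-flood
    that sit on JRC permanent water become class 2 (waterbody)."""
    pred3 = pred_binary.astype(np.int64).copy()
    pred3[(pred_binary == 0) & (jrc == 1)] = 2
    return pred3

def per_class_iou(pred, gt, num_classes=3, ignore_index=-1):
    """IoU per class, skipping classes absent from both pred and gt."""
    valid = gt != ignore_index
    ious = {}
    for c in range(num_classes):
        inter = ((pred == c) & (gt == c) & valid).sum()
        union = (((pred == c) | (gt == c)) & valid).sum()
        if union > 0:
            ious[c] = inter / union
    return ious
```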

Results

Best Checkpoint (prithvi_best-v7, epoch 125)

By split:

Split               NF IoU  FL IoU  WB IoU  mIoU(3)  mIoU(2)  Px Acc  WB FP Rate
Train (59 images)   0.722   0.340   0.505   0.522    0.531    0.809   0.173
Val (10 images)     0.739   0.153   0.264   0.385    0.446    0.760   0.120
Test (10 images)    0.684   0.145   0.204   0.344    0.414    0.686   0.344

Top-10 flood-rich training images (best-case performance):

Image     FL IoU  WB IoU  mIoU(3)  Px Acc
pid_021   0.594   0.245   0.441    0.700
pid_002   0.504   0.823   0.582    0.695
pid_015   0.535   0.439   0.475    0.662
pid_016   0.497   0.634   0.536    0.663
pid_011   0.471   0.807   0.546    0.637
Mean      0.523   0.387   0.466    0.682

Checkpoint progression (val flood IoU):

Checkpoint          Epoch  Val FL IoU (callback)  Val FL IoU (3-class eval)
prithvi_best        1      0.147                  0.000
prithvi_best-v2     6      0.194                  0.000
prithvi_best-v1     19     0.195                  0.001
prithvi_best-v4     23     0.171                  0.001
prithvi_best-v3     51     0.201                  0.000
prithvi_best-v5/v6  56     0.323                  0.000
prithvi_best-v7     125    0.368                  0.153

Summary

The model achieves val flood IoU 0.368 (training-callback metric; 0.153 under the 3-class evaluation) and test flood IoU 0.145 on the competition dataset. The train/val gap reflects overfitting on 59 training images. On flood-rich images the model demonstrates strong performance (mean flood IoU 0.52), confirming the architecture is sound. The primary bottleneck is training data volume, not model capacity.


Model Examination

PhysicsRuleLayer interpretability: The learned thresholds ndwi_flood_thresh and sar_flood_thresh can be read directly from the checkpoint:

ckpt = torch.load("prithvi_best-v7.ckpt", map_location="cpu", weights_only=False)
print(ckpt["state_dict"]["physics.ndwi_flood_thresh"])  # learned NDWI threshold
print(ckpt["state_dict"]["physics.sar_flood_thresh"])   # learned SAR threshold
print(ckpt["state_dict"]["jrc_gate.suppress_strength"]) # JRC suppression strength

A positive suppress_strength (> 0) confirms the model learned to use the permanent-water prior.

LightMCAM attention: The cross-modal attention weights (optical Q × SAR K/V, pooled 8×8) can be extracted during the forward pass to produce spatial attention maps showing which SAR regions most influence the optical flood prediction.

SARConfidenceGate: The differentiable NDWI + SAR threshold rule produces a pixel-level confidence map (0–1) that can be visualised as a physics-based flood probability, independent of the neural network output — useful for sanity-checking predictions.
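The physics confidence map can be reproduced outside the network from the raw bands (a sketch; the placeholder thresholds below should be replaced by the learned values read from the checkpoint):

```python
import numpy as np

def physics_confidence(green, nir, log_hh, tau_ndwi=0.0, tau_sar=-6.0, eps=1e-6):
    """Differentiable water rule: high NDWI AND low SAR backscatter -> high confidence.
    tau_ndwi / tau_sar here are placeholder values, not the learned thresholds."""
    sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
    ndwi = (green - nir) / (green + nir + eps)
    return sigmoid(ndwi - tau_ndwi) * sigmoid(-log_hh - tau_sar)
```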


Environmental Impact

Carbon emissions estimated using the ML Impact Calculator (Lacoste et al., 2019).

  • Hardware type: NVIDIA T4 GPU (16 GB VRAM)
  • Hours used: ~8 hours (best checkpoint); ~40 hours total across all training runs
  • Cloud provider: Kaggle (Google Cloud Platform)
  • Compute region: us-central1 (Iowa, USA) — grid carbon intensity ~0.385 kg CO₂/kWh
  • Estimated carbon emitted: ~1.2 kg CO₂eq (best run); ~6 kg CO₂eq (all runs)

GPU-only lower bound: T4 TDP ~70 W; 8 h × 0.070 kW × 0.385 kg/kWh × PUE 1.1 ≈ 0.24 kg CO₂eq per run; the ~1.2 kg estimate above accounts for whole-node power draw rather than GPU TDP alone.


Technical Specifications

Model Architecture and Objective

FloodDetNet-Prithvi v5
─────────────────────────────────────────────────────────────────
Input:  (B, 12, H, W)
  ch 0–5   Optical (HLS-normalised)
  ch 6–10  SAR features (HH, HV, HH−HV, log_HH, log_HV)
  ch 11    JRC binary water prior

Branch 1 — Optical:
  TemporalViTEncoder (Prithvi-EO-1.0-100M, ViT-Base)
    PatchEmbed3D: Conv3D(6, 768, kernel=(1,16,16))
    12× ViT Block (768-dim, 12 heads, MLP ratio 4.0)
    LayerNorm → tokens (B, 197, 768)
  ConvTransformerNeck:
    2× ConvTranspose2d (768→256, 256→256, stride 2)
    → (B, 256, H, W)

Branch 2 — SAR:
  SAREncoderWithEdge:
    shallow: Conv(5,64,3) → BN → ReLU → (B, 64, H, W)
    deep:    3× stride-2 Conv → (B, 256, H/8, W/8)
             3× ConvTranspose → (B, 256, H, W)

Branch 3 — Physics:
  PhysicsRuleLayer:
    NDWI = (green − nir) / (green + nir + ε)
    flood_score = σ(NDWI − τ_ndwi) × σ(−log_HH − τ_sar)
    → Conv1×1(2, 2) → (B, 2, H, W)

Fusion:
  LightMCAM: Q=optical(256), K/V=SAR(256), pool=8×8
             → attended optical (B, 256, H, W)
  SEGM:      dec_feat × (1 + σ(edge_conv(sar_shallow)))
  Fuse:      concat(256+256+2) → Conv(256) → Conv(128) → Conv(2)

JRCGate:
  gate = σ(conv([jrc_ch, log_hh]))  ∈ [0,1]
  logits[:,1] −= gate × suppress_strength  (learned)
  logits[:,0] += gate × suppress_strength

Output: (B, 2, H, W) logits → softmax → flood probability map
─────────────────────────────────────────────────────────────────
Total parameters: ~110M
Trainable at epoch 125: all (backbone unfrozen at epoch 10)

Objective: Binary cross-entropy + Dice + Tversky (flood-recall maximisation) + SAR confidence knowledge distillation + flood focal loss.
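The JRCGate mechanism can be sketched as a small module (shapes and the 3×3 conv are assumptions; this illustrates the mechanism, not the exact training code):

```python
import torch
import torch.nn as nn

class JRCGateSketch(nn.Module):
    """Learned gate that shifts logit mass from flood to no-flood over
    pixels the JRC prior (and low SAR backscatter) mark as permanent water."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)   # kernel size assumed
        self.suppress_strength = nn.Parameter(torch.tensor(1.0))

    def forward(self, logits, jrc, log_hh):
        # gate in [0, 1]: high where JRC + SAR agree on permanent water
        gate = torch.sigmoid(self.conv(torch.cat([jrc, log_hh], dim=1)))
        shift = gate * self.suppress_strength
        logits = logits.clone()
        logits[:, 1:2] = logits[:, 1:2] - shift   # suppress flood logit
        logits[:, 0:1] = logits[:, 0:1] + shift   # boost no-flood logit
        return logits
```

Because the shift is antisymmetric across the two logits, a positive suppress_strength strictly lowers flood probability over gated pixels without touching ungated ones.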

Compute Infrastructure

Hardware

  • Training: Kaggle Notebook, NVIDIA T4 16 GB GPU, 2× CPU cores, 13 GB RAM
  • Inference (production): AWS SageMaker ml.g4dn.xlarge (T4 GPU)
  • Inference (fallback): HuggingFace Inference API

Software

  • Python 3.10
  • PyTorch 2.x + CUDA 11.8
  • Lightning 2.x (PyTorch Lightning)
  • timm 0.9.x (ViT blocks, PatchEmbed)
  • rasterio 1.3.x (GeoTIFF I/O)
  • scipy (elastic deformation, morphological post-processing)
  • numpy, pandas

Citation

If you use this model or the FloodSense system in your work, please cite:

BibTeX:

@misc{flooddetnet_prithvi_2025,
  title        = {FloodDetNet-Prithvi v5: Physics-Informed Satellite Flood Segmentation
                  with Cross-Modal SAR/Optical Fusion},
  author       = {AISEHack Theme 1 Team},
  year         = {2025},
  howpublished = {ANRF AISEHack Phase 2 Submission},
  note         = {Fine-tuned from ibm-nasa-geospatial/Prithvi-EO-1.0-100M}
}

APA: AISEHack Theme 1 Team. (2025). FloodDetNet-Prithvi v5: Physics-Informed Satellite Flood Segmentation with Cross-Modal SAR/Optical Fusion. ANRF AISEHack Phase 2 Submission.

Please also cite the base model and key dependencies:

@article{jakubik2023prithvi,
  title   = {Foundation Models for Generalist Geospatial Artificial Intelligence},
  author  = {Jakubik, Johannes and others},
  journal = {arXiv preprint arXiv:2310.18660},
  year    = {2023}
}

@article{pekel2016jrc,
  title   = {High-resolution global maps of 21st-century surface water and its changes},
  author  = {Pekel, Jean-Fran{\c{c}}ois and Cottam, Andrew and Gorelick, Noel
             and Belward, Alan S.},
  journal = {Nature},
  volume  = {540},
  pages   = {418--422},
  year    = {2016}
}

@inproceedings{abraham2019tversky,
  title     = {A Novel Focal Tversky Loss Function with Improved Attention U-Net
               for Lesion Segmentation},
  author    = {Abraham, Nabila and Khan, Naimul Mefraz},
  booktitle = {IEEE International Symposium on Biomedical Imaging (ISBI)},
  year      = {2019},
  note      = {arXiv:1810.07842}
}

@article{xu2025sgcad,
  title   = {SGCAD: A SAR-Guided Confidence-Gated Distillation Framework of Optical
             and SAR Images for Water-Enhanced Land-Cover Semantic Segmentation},
  author  = {Xu, Zhenghao and others},
  journal = {Remote Sensing},
  volume  = {18},
  number  = {6},
  pages   = {962},
  year    = {2025}
}

Glossary

  • HLS: Harmonized Landsat Sentinel-2 — NASA surface reflectance product at 30 m resolution
  • SAR: Synthetic Aperture Radar — active microwave sensor, all-weather, cloud-penetrating
  • HH / HV: SAR polarisation channels — horizontal transmit/horizontal receive, horizontal transmit/vertical receive
  • NDWI: Normalised Difference Water Index = (Green − NIR) / (Green + NIR); positive values indicate water
  • JRC GSW: JRC Global Surface Water — 37-year (1984–2021) Landsat-derived water occurrence dataset
  • JRCGate: Learned spatial gate that suppresses the flood logit in pixels where JRC occurrence ≥ 75% AND SAR backscatter is low
  • LightMCAM: Lightweight Multi-modal Cross-Attention Module — SAR as K/V, optical as Q, pooled 8×8 attention
  • SEGM: SAR Edge Guidance Module — amplifies boundary responses in the decoder using SAR shallow edge features
  • Tversky loss: Asymmetric Dice variant: 1 − (TP + ε) / (TP + α·FN + β·FP + ε); α=0.90 penalises missed floods
  • TTA: Test-Time Augmentation — 4-flip ensemble (H-flip × V-flip) averaged for the final prediction
  • RLE: Run-Length Encoding — competition submission format for binary masks
  • mIoU: Mean Intersection-over-Union — mean of per-class IoU scores
  • WB FP Rate: Waterbody False Positive Rate — fraction of waterbody GT pixels predicted as flood

More Information

  • FloodSense web application: Full-stack offline-capable disaster response platform built on this model — React 18 + TypeScript + MapLibre GL JS + Gemini AI chatbot + AWS SageMaker inference. See flood-response-app/README.md.
  • Evaluation pipeline: evaluate_model.py — full 3-class IoU + pixel accuracy evaluation across all checkpoints and splits, with TensorBoard log parsing and 4 plot types.
  • Quick evaluation: quick_eval.py — fast single-checkpoint evaluation on any image subset.
  • JRC mask generation: fetch_jrc_masks.py, integrate_jrc.py — download and reproject JRC Global Surface Water tiles to match patch CRS/resolution.
  • Competition context: ANRF AISEHack Theme 1 Phase 2 — flood detection from satellite imagery, West Bengal / Bangladesh flood event, May 2024.

Model Card Authors

Roshan Rateria
