# Model Card: FloodDetNet-Prithvi v5
Physics-informed satellite flood segmentation with cross-modal SAR/optical fusion, built on the IBM/NASA Prithvi-EO-1.0-100M geospatial foundation model.
## Model Details
### Model Description
FloodDetNet-Prithvi v5 is a pixel-level flood segmentation model for 12-channel satellite patches (6 optical + 5 SAR-derived + 1 JRC permanent-water prior). It fine-tunes the Prithvi-EO-1.0-100M temporal Vision Transformer as the optical backbone, augmented by a parallel SAR encoder, cross-modal attention, and a physics-informed water-body gate that uses the 37-year JRC Global Surface Water dataset to suppress false positives over historically wet areas.
The model outputs binary flood masks (flood / no-flood) at full input resolution via sliding-window inference with optional 4-flip test-time augmentation. At evaluation time, waterbody pixels are reconstructed from the binary prediction using the JRC prior, enabling 3-class IoU reporting (no-flood / flood / waterbody) against the original 3-class ground truth.
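The waterbody reconstruction step can be sketched in a few lines (a minimal illustration with assumed array inputs, not the project's actual evaluation code):

```python
import numpy as np

def reconstruct_three_class(binary_pred: np.ndarray, jrc_prior: np.ndarray) -> np.ndarray:
    """Rebuild a 3-class map (0 = no-flood, 1 = flood, 2 = waterbody) from the
    binary flood prediction and the binary JRC permanent-water prior."""
    out = binary_pred.astype(np.uint8).copy()
    # Pixels predicted dry that JRC marks as permanent water become waterbody.
    out[(binary_pred == 0) & (jrc_prior == 1)] = 2
    return out
```

Flood pixels keep class 1 even over JRC water, so only the model's "dry" pixels are reinterpreted.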
- Developed by: AI-Hackers
- Funded by: ANRF (Anusandhan National Research Foundation) AISEHack Programme
- Model type: Semantic segmentation – geospatial / Earth observation
- Input modalities: Optical (HLS-style 6-band) + SAR (HH, HV, SAR_diff, log_HH, log_HV) + JRC binary water prior
- Output: Binary flood mask (0 = no-flood, 1 = flood), 2-channel softmax probability map
- License: MIT (see `checkpoint/LICENSE`)
- Fine-tuned from: ibm-nasa-geospatial/Prithvi-EO-1.0-100M
### Model Sources
- Model weights: `prithvi_best-v7.ckpt` – epoch 125, val flood IoU 0.368 (Google Drive / Zenodo link)
## Uses
### Direct Use
Run flood segmentation inference on any 512×512 (or larger) satellite patch with the 6-band input (HH, HV, Green, Red, NIR, SWIR); during preprocessing the four optical bands are expanded to the six Prithvi channels by duplicating Green and SWIR. The model handles arbitrary patch sizes via sliding-window inference (window 224×224, stride 112).
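The sliding-window mechanics look roughly like this (a simplified sketch of overlap-averaged inference; the shipped `sliding_window_predict` may differ in details such as padding and device handling):

```python
import torch

def sliding_window(model, img, crop=224, stride=112, n_classes=2):
    """Average overlapping softmax windows into a full-resolution probability map."""
    _, H, W = img.shape
    prob = torch.zeros(n_classes, H, W)
    count = torch.zeros(1, H, W)
    ys = list(range(0, max(H - crop, 0) + 1, stride))
    xs = list(range(0, max(W - crop, 0) + 1, stride))
    # Make sure the bottom/right edges are always covered.
    if ys[-1] != H - crop:
        ys.append(H - crop)
    if xs[-1] != W - crop:
        xs.append(W - crop)
    with torch.no_grad():
        for y in ys:
            for x in xs:
                patch = img[:, y:y + crop, x:x + crop].unsqueeze(0)
                logits = model(patch)[0]               # (n_classes, crop, crop)
                prob[:, y:y + crop, x:x + crop] += logits.softmax(0)
                count[:, y:y + crop, x:x + crop] += 1
    return prob / count
```

With stride = crop/2, interior pixels are averaged over up to four windows, which smooths seams at patch boundaries.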
Intended users:
- NDRF field commanders – real-time flood extent maps for route passability decisions
- District agriculture officers – pixel-level flood vs. water-body proof for PMFBY crop insurance claims
- State Emergency Operations Centre analysts – multi-event flood extent tracking
- Remote sensing researchers – flood segmentation benchmark and foundation-model fine-tuning reference
### Downstream Use
The model is integrated into the FloodSense web application as the inference backend:
- Served via AWS SageMaker endpoint (primary) or HuggingFace Inference API (fallback)
- Results feed a MapLibre GL JS flood overlay, Gemini AI chatbot, and automated PDF report generation
- Outputs can be exported as RLE masks (competition format), GeoTIFF, or village-level CSV statistics
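For illustration, a binary mask can be run-length encoded along these lines (the exact competition RLE convention is not specified here; this sketch uses the common 1-indexed, column-major "start length" format):

```python
import numpy as np

def rle_encode(mask: np.ndarray) -> str:
    """Run-length encode a binary mask as '<start> <length> ...' pairs,
    1-indexed over the column-major flattened mask."""
    flat = mask.flatten(order="F")            # column-major, as in many segmentation contests
    padded = np.concatenate([[0], flat, [0]]) # sentinels so edge runs are detected
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts, ends = changes[::2], changes[1::2]
    return " ".join(f"{s} {e - s}" for s, e in zip(starts, ends))
```

An all-zero mask encodes to the empty string; verify against the organisers' reference encoder before submitting.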
### Out-of-Scope Use
- Single-modality optical-only input: The model expects SAR channels 6–10; optical-only patches will produce degraded results
- Non-flood water detection: The model is trained to detect transient flood water, not permanent water bodies, rivers, or ocean
- Sub-metre resolution imagery: Trained on 30 m HLS-resolution patches; very high resolution imagery (< 5 m) is out of distribution
- Real-time video or streaming: Designed for batch patch inference, not streaming
- Regions with no JRC coverage: The JRCGate falls back to zeros outside the downloaded tile footprint, reducing waterbody suppression accuracy
## Bias, Risks, and Limitations
Geographic bias: The model was trained exclusively on a single flood event in West Bengal / Bangladesh (May 2024, patches 001–079). Performance on other geographies, flood types (flash floods, coastal inundation, urban flooding), or seasons is unknown and likely degraded.
Class imbalance: Flood pixels constitute only ~12% of the training data. The model is tuned for high recall (Tversky α=0.90) at the cost of precision – it will over-predict flood in ambiguous boundary regions.
JRC tile dependency: Waterbody suppression accuracy depends on the JRC Global Surface Water tile covering the inference region. Images outside the occurrence_80E_30N tile footprint receive no permanent-water prior, causing elevated false-positive rates in water-body regions (test WB FP rate without JRC: ~0.34).
Train/val generalisation gap: Flood IoU drops from 0.34 (train) to 0.15 (val/test), indicating overfitting on the 59-image training set. The model should be fine-tuned on local data before operational deployment in a new region.
No temporal context: Despite Prithvi's native multi-temporal capability, this model uses single-frame input (T=1). Pre-flood / post-flood change detection would significantly improve accuracy.
### Recommendations
- Always validate on locally labelled data before operational use
- Use the JRC waterbody mask as a post-processing filter in addition to the JRCGate
- Set the flood threshold (default 0.45) based on local precision/recall requirements – lower for higher recall (emergency response), higher for higher precision (insurance claims)
- Do not use model outputs as sole evidence for legal or financial decisions without human expert review
## How to Get Started with the Model
```python
import torch
import numpy as np
from pathlib import Path

# Clone repo and install dependencies
# pip install torch lightning timm rasterio scipy
from best import (
    _load_model_from_checkpoint, cfg,
    read_tif, normalise_optical, normalise_sar,
    compute_sar_features, load_jrc_channel,
    sliding_window_predict, postprocess_mask,
)

# Point JRC dir to your pre-generated masks
cfg.JRC_DIR = Path("data/jrc_masks")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
model = _load_model_from_checkpoint("prithvi_best-v7.ckpt", device)
model.to(device).eval()

# Prepare input (6-band GeoTIFF)
img_path = Path("data/image/20240529_EO4_RES2_fl_pid_002_image.tif")
raw = read_tif(img_path)                          # (6, H, W)
opt = normalise_optical(raw)                      # (6, H, W)
sar = normalise_sar(compute_sar_features(raw))    # (5, H, W)
jrc = load_jrc_channel(img_path, raw.shape[1:])   # (1, H, W)
img_t = torch.from_numpy(np.concatenate([opt, sar, jrc], axis=0)).float()

# Inference
with torch.no_grad():
    prob = sliding_window_predict(model, img_t, crop=224, stride=112, device=device)
# prob: (2, H, W); channel 1 is flood probability

# Post-process
flood_mask = postprocess_mask(prob[1].numpy(), threshold=0.45)
# flood_mask: (H, W) binary uint8; 1 = flood, 0 = no-flood
print(f"Flood pixels: {flood_mask.sum()} ({100*flood_mask.mean():.1f}%)")
```
For 4-flip TTA (recommended for submission):
```python
from best import tta_predict
prob = tta_predict(model, img_t, device=device)
```
## Training Details
### Training Data
- Competition dataset: 79 labelled 512Γ512 satellite patches, West Bengal / Bangladesh flood event, May 2024. 6-band float32 GeoTIFF (HH, HV, Green, Red, NIR, SWIR). 3-class labels: 0 = no-flood, 1 = flood, 2 = waterbody. Provided by ANRF AISEHack organisers.
- Split: Stratified 59 train / 10 val / 10 test by flood pixel percentage
- Class distribution: flood ~12%, waterbody ~55%, no-flood ~33%
- External – JRC Global Surface Water v1.4 (Pekel et al., 2016): 37-year occurrence layer, tile `occurrence_80E_30N`, reprojected per-image to match patch CRS/resolution. Used as binary permanent-water prior (channel 11, threshold ≥ 75% occurrence).
- External – Prithvi-EO-1.0-100M pretrained weights: IBM/NASA-IMPACT, pretrained on 1M+ HLS (Harmonized Landsat Sentinel-2) scenes at 30 m resolution.
### Training Procedure
#### Preprocessing
Optical (channels 0–5):

```python
hls = raw[[2, 2, 3, 4, 5, 5]] * 0.0001        # scale to reflectance
opt = (hls - PRITHVI_MEANS) / PRITHVI_STDS    # official Prithvi normalisation
```
Means: [0.0333, 0.0570, 0.0589, 0.2323, 0.1973, 0.1194]
Stds: [0.0227, 0.0268, 0.0400, 0.0779, 0.0871, 0.0724]
SAR (channels 6–10): HH, HV, HH−HV, log(HH), log(HV), z-score normalised.
Means: [786.83, 357.99, 434.18, 6.47, 5.78]
Stds: [373.43, 148.94, 305.09, 0.55, 0.62]
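Putting the SAR recipe together, feature construction and normalisation can be sketched as follows (assuming the raw band order HH, HV, Green, Red, NIR, SWIR from the dataset description; the project's `compute_sar_features` / `normalise_sar` may differ in details such as the log epsilon):

```python
import numpy as np

# Training statistics from this card, one per SAR channel.
SAR_MEANS = np.array([786.83, 357.99, 434.18, 6.47, 5.78]).reshape(-1, 1, 1)
SAR_STDS = np.array([373.43, 148.94, 305.09, 0.55, 0.62]).reshape(-1, 1, 1)

def sar_features(raw: np.ndarray) -> np.ndarray:
    """Build the 5 SAR-derived channels (HH, HV, HH-HV, log HH, log HV)
    from a (6, H, W) raw patch and z-score normalise them."""
    hh, hv = raw[0], raw[1]
    eps = 1e-6  # guard against log(0); exact value is an assumption
    feats = np.stack([hh, hv, hh - hv, np.log(hh + eps), np.log(hv + eps)])
    return (feats - SAR_MEANS) / SAR_STDS
```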
JRC (channel 11): Binary 0/1 (occurrence ≥ 75% → 1).
Label: Waterbody (class 2) remapped to no-flood (class 0) during training. Ignore index: −1 (nodata pixels).
Augmentation: horizontal/vertical flip, rot90, elastic deformation (p=0.3), SAR gamma speckle (looks=4), optical brightness jitter (std=0.25), cutout (48×48), MixUp (α=0.3), 8 random 224×224 crops per image per epoch.
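The SAR gamma speckle augmentation, for example, amounts to multiplicative Gamma noise with unit mean (a sketch; it assumes the noise is applied to linear-intensity channels before the log features are taken):

```python
import numpy as np

def sar_speckle(sar: np.ndarray, looks: int = 4, rng=None) -> np.ndarray:
    """Simulate L-look speckle: multiply by Gamma(looks, 1/looks) noise,
    which has mean 1 and variance 1/looks."""
    rng = rng or np.random.default_rng()
    noise = rng.gamma(shape=looks, scale=1.0 / looks, size=sar.shape)
    return sar * noise
```

Higher `looks` means weaker speckle; looks=4 matches the augmentation listed above.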
#### Training Hyperparameters
- Training regime: fp16 mixed precision (PyTorch Lightning `precision="16-mixed"`)
- Optimiser: AdamW, weight decay 0.05
- Learning rate: backbone 5e-5, head/neck/SAR 3e-4 (differential LR)
- Scheduler: OneCycleLR, cosine annealing, pct_start=0.1, div_factor=10, final_div_factor=1e4
- Batch size: 4
- Epochs: 150 (early stopping patience 40 on val/flood_iou)
- Backbone freeze: epochs 0–9 (backbone frozen, only head trained)
- Gradient clipping: 1.0
- Loss: L = 0.15·CE + 0.25·Dice + 0.60·Tversky(α=0.90, β=0.10) + 0.05·SAR_KD + 0.20·FocalFlood(γ=3)
- Class weights: [0.20, 0.80] (no-flood, flood)
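The dominant Tversky term can be written out directly; with α=0.90 a missed flood pixel (FN) costs nine times as much as a false alarm (FP). A minimal sketch covering only the flood channel, not the project's full composite loss:

```python
import torch

def tversky_loss(logits, target, alpha=0.90, beta=0.10, eps=1e-6):
    """Tversky loss on the flood channel of (B, 2, H, W) logits against
    (B, H, W) integer targets (1 = flood)."""
    p = logits.softmax(1)[:, 1]        # flood probability
    t = (target == 1).float()
    tp = (p * t).sum()
    fn = ((1 - p) * t).sum()
    fp = (p * (1 - t)).sum()
    return 1 - (tp + eps) / (tp + alpha * fn + beta * fp + eps)
```

Setting alpha = beta = 0.5 recovers the ordinary Dice loss.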
#### Speeds, Sizes, Times
| Metric | Value |
|---|---|
| Training hardware | NVIDIA T4 16 GB (single GPU) |
| Training time | ~8 hours (125 epochs to best checkpoint) |
| Best checkpoint epoch | 125 (step 14,868) |
| Checkpoint size (full) | 1,052 MB |
| Checkpoint size (inference-only) | ~400 MB |
| Inference time – T4 GPU | ~8 s per 512×512 patch (sliding window + 4-flip TTA) |
| Inference time – CPU | ~45 s per 512×512 patch |
| Total parameters | ~110M (ViT-Base backbone + heads) |
## Evaluation
### Testing Data, Factors & Metrics
#### Testing Data
- Val split: 10 images (pid_060–069), held out during training
- Test split: 10 images (pid_070–079), never seen during training or hyperparameter tuning
- Flood-rich subset: Top-10 training images by flood pixel count (pid_002, 010, 011, 015–017, 020, 021, 027, 028)
#### Factors
Evaluation is disaggregated by:
- Split (train / val / test) – measures generalisation
- Per-image – identifies high/low-performing patches
- JRC coverage – images with JRC tile coverage vs. without (WB IoU = 0 without coverage)
- Flood prevalence – flood-rich images (> 30% flood pixels) vs. flood-sparse (< 5%)
#### Metrics
| Metric | Description |
|---|---|
| Flood IoU | Intersection-over-Union for class 1 (flood) – primary competition metric |
| No-Flood IoU | IoU for class 0 (no-flood) |
| WaterBody IoU | IoU for class 2 (waterbody), reconstructed via JRC: pred = 0 ∧ JRC = 1 → class 2 |
| mIoU (3-class) | Mean IoU over classes present in ground truth |
| mIoU (binary) | Mean of flood + no-flood IoU (matches training objective) |
| Overall Pixel Accuracy | Fraction of correctly classified valid pixels |
| WB FP Rate | Fraction of waterbody GT pixels predicted as flood |
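The two less standard metrics, flood IoU and WB FP rate, can be computed as follows (a sketch assuming integer class maps and the ignore index −1):

```python
import numpy as np

def flood_metrics(pred, gt, ignore=-1):
    """Return (flood IoU, WB FP rate): IoU of class 1, and the fraction of
    ground-truth waterbody pixels (class 2) predicted as flood."""
    valid = gt != ignore
    p1 = (pred == 1) & valid
    g1 = (gt == 1) & valid
    inter = (p1 & g1).sum()
    union = (p1 | g1).sum()
    flood_iou = inter / union if union else float("nan")
    wb = (gt == 2) & valid
    wb_fp = (p1 & wb).sum() / wb.sum() if wb.sum() else float("nan")
    return flood_iou, wb_fp
```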
### Results
#### Best Checkpoint (prithvi_best-v7, epoch 125)
By split:
| Split | NF IoU | FL IoU | WB IoU | mIoU(3) | mIoU(2) | Px Acc | WB FP Rate |
|---|---|---|---|---|---|---|---|
| Train (59 images) | 0.722 | 0.340 | 0.505 | 0.522 | 0.531 | 0.809 | 0.173 |
| Val (10 images) | 0.739 | 0.153 | 0.264 | 0.385 | 0.446 | 0.760 | 0.120 |
| Test (10 images) | 0.684 | 0.145 | 0.204 | 0.344 | 0.414 | 0.686 | 0.344 |
Top-10 flood-rich training images (best-case performance):
| Image | FL IoU | WB IoU | mIoU(3) | Px Acc |
|---|---|---|---|---|
| pid_021 | 0.594 | 0.245 | 0.441 | 0.700 |
| pid_002 | 0.504 | 0.823 | 0.582 | 0.695 |
| pid_015 | 0.535 | 0.439 | 0.475 | 0.662 |
| pid_016 | 0.497 | 0.634 | 0.536 | 0.663 |
| pid_011 | 0.471 | 0.807 | 0.546 | 0.637 |
| Mean | 0.523 | 0.387 | 0.466 | 0.682 |
Checkpoint progression (val flood IoU):
| Checkpoint | Epoch | Val FL IoU (callback) | Val FL IoU (3-class eval) |
|---|---|---|---|
| prithvi_best | 1 | 0.147 | 0.000 |
| prithvi_best-v2 | 6 | 0.194 | 0.000 |
| prithvi_best-v1 | 19 | 0.195 | 0.001 |
| prithvi_best-v4 | 23 | 0.171 | 0.001 |
| prithvi_best-v3 | 51 | 0.201 | 0.000 |
| prithvi_best-v5/v6 | 56 | 0.323 | 0.000 |
| prithvi_best-v7 | 125 | 0.368 | 0.153 |
#### Summary
The model achieves val flood IoU 0.368 (binary training metric) and test flood IoU 0.145 (3-class evaluation) on the competition dataset. The train/val gap reflects overfitting on 59 training images. On flood-rich images the model performs strongly (mean flood IoU 0.52), suggesting the architecture is sound; the primary bottleneck is training data volume, not model capacity.
## Model Examination
PhysicsRuleLayer interpretability: The learned thresholds `ndwi_flood_thresh` and `sar_flood_thresh` can be read directly from the checkpoint:
```python
ckpt = torch.load("prithvi_best-v7.ckpt", map_location="cpu", weights_only=False)
print(ckpt["state_dict"]["physics.ndwi_flood_thresh"])   # learned NDWI threshold
print(ckpt["state_dict"]["physics.sar_flood_thresh"])    # learned SAR threshold
print(ckpt["state_dict"]["jrc_gate.suppress_strength"])  # JRC suppression strength
```
A positive `suppress_strength` (> 0) confirms the model learned to use the permanent-water prior.
LightMCAM attention: The cross-modal attention weights (optical Q × SAR K/V, pooled 8×8) can be extracted during the forward pass to produce spatial attention maps showing which SAR regions most influence the optical flood prediction.
SARConfidenceGate: The differentiable NDWI + SAR threshold rule produces a pixel-level confidence map (0–1) that can be visualised as a physics-based flood probability, independent of the neural network output – useful for sanity-checking predictions.
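The rule itself is small enough to write out (illustrative threshold values only; the deployed values are learned parameters stored in the checkpoint):

```python
import torch

def physics_flood_score(green, nir, log_hh, tau_ndwi=0.0, tau_sar=-6.5):
    """Differentiable NDWI + SAR rule: high NDWI (wet spectral signature) and
    low SAR backscatter (smooth water surface) both push the score toward 1.
    tau_ndwi / tau_sar are placeholders for the learned thresholds."""
    ndwi = (green - nir) / (green + nir + 1e-6)
    return torch.sigmoid(ndwi - tau_ndwi) * torch.sigmoid(-log_hh - tau_sar)
```

Because both factors are sigmoids, the score stays in (0, 1) and gradients flow through the thresholds during training.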
## Environmental Impact
Carbon emissions estimated using the ML Impact Calculator (Lacoste et al., 2019).
- Hardware type: NVIDIA T4 GPU (16 GB VRAM)
- Hours used: ~8 hours (best checkpoint); ~40 hours total across all training runs
- Cloud provider: Kaggle (Google Cloud Platform)
- Compute region: us-central1 (Iowa, USA) – grid carbon intensity ~0.385 kg CO₂/kWh
- Estimated carbon emitted: ~1.2 kg CO₂eq (best run); ~6 kg CO₂eq (all runs)

T4 TDP ~70 W; 8 h × 0.070 kW × 0.385 kg/kWh × PUE 1.1 ≈ 0.24 kg CO₂eq per run.
## Technical Specifications
### Model Architecture and Objective
```text
FloodDetNet-Prithvi v5
─────────────────────────────────────────────────────────────────
Input: (B, 12, H, W)
    ch 0–5   Optical (HLS-normalised)
    ch 6–10  SAR features (HH, HV, HH−HV, log_HH, log_HV)
    ch 11    JRC binary water prior

Branch 1 – Optical:
    TemporalViTEncoder (Prithvi-EO-1.0-100M, ViT-Base)
        PatchEmbed3D: Conv3D(6, 768, kernel=(1,16,16))
        12× ViT Block (768-dim, 12 heads, MLP ratio 4.0)
        LayerNorm → tokens (B, 197, 768)
    ConvTransformerNeck:
        2× ConvTranspose2d (768→256, 256→256, stride 2)
        → (B, 256, H, W)

Branch 2 – SAR:
    SAREncoderWithEdge:
        shallow: Conv(5,64,3) → BN → ReLU → (B, 64, H, W)
        deep:    3× stride-2 Conv → (B, 256, H/8, W/8)
                 3× ConvTranspose → (B, 256, H, W)

Branch 3 – Physics:
    PhysicsRuleLayer:
        NDWI = (green − nir) / (green + nir + ε)
        flood_score = σ(NDWI − τ_ndwi) × σ(−log_HH − τ_sar)
        → Conv1×1(2, 2) → (B, 2, H, W)

Fusion:
    LightMCAM: Q=optical(256), K/V=SAR(256), pool=8×8
        → attended optical (B, 256, H, W)
    SEGM: dec_feat × (1 + σ(edge_conv(sar_shallow)))
    Fuse: concat(256+256+2) → Conv(256) → Conv(128) → Conv(2)
    JRCGate:
        gate = σ(conv([jrc_ch, log_hh])) ∈ [0,1]
        logits[:,1] -= gate × suppress_strength   (learned)
        logits[:,0] += gate × suppress_strength

Output: (B, 2, H, W) logits → softmax → flood probability map
─────────────────────────────────────────────────────────────────
```
Total parameters: ~110M
Trainable at epoch 125: all (backbone unfrozen at epoch 10)
Objective: Binary cross-entropy + Dice + Tversky (flood-recall maximisation) + SAR confidence knowledge distillation + flood focal loss.
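The JRCGate logic from the diagram can be sketched as a small module (the 1×1 conv and layer sizes are assumptions; only the gating arithmetic is taken from the architecture description):

```python
import torch
import torch.nn as nn

class JRCGate(nn.Module):
    """Learned suppression gate: a conv over the JRC prior and log-HH
    backscatter yields a gate in (0, 1) that shifts logits from 'flood'
    toward 'no-flood' over permanent water."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=1)  # assumed 1x1 conv
        self.suppress_strength = nn.Parameter(torch.tensor(1.0))

    def forward(self, logits, jrc, log_hh):
        gate = torch.sigmoid(self.conv(torch.stack([jrc, log_hh], dim=1)))
        shift = gate.squeeze(1) * self.suppress_strength
        out = logits.clone()
        out[:, 1] = out[:, 1] - shift   # suppress the flood logit
        out[:, 0] = out[:, 0] + shift   # boost the no-flood logit
        return out
```

Shifting logits (rather than masking probabilities) keeps the operation differentiable, so the suppression strength can be learned end to end.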
### Compute Infrastructure
#### Hardware
- Training: Kaggle Notebook, NVIDIA T4 16 GB GPU, 2× CPU cores, 13 GB RAM
- Inference (production): AWS SageMaker `ml.g4dn.xlarge` (T4 GPU)
- Inference (fallback): HuggingFace Inference API
#### Software
- Python 3.10
- PyTorch 2.x + CUDA 11.8
- Lightning 2.x (PyTorch Lightning)
- timm 0.9.x (ViT blocks, PatchEmbed)
- rasterio 1.3.x (GeoTIFF I/O)
- scipy (elastic deformation, morphological post-processing)
- numpy, pandas
## Citation
If you use this model or the FloodSense system in your work, please cite:
BibTeX:
```bibtex
@misc{flooddetnet_prithvi_2025,
  title        = {FloodDetNet-Prithvi v5: Physics-Informed Satellite Flood Segmentation
                  with Cross-Modal SAR/Optical Fusion},
  author       = {AISEHack Theme 1 Team},
  year         = {2025},
  howpublished = {ANRF AISEHack Phase 2 Submission},
  note         = {Fine-tuned from ibm-nasa-geospatial/Prithvi-EO-1.0-100M}
}
```
APA: AISEHack Theme 1 Team. (2025). FloodDetNet-Prithvi v5: Physics-Informed Satellite Flood Segmentation with Cross-Modal SAR/Optical Fusion. ANRF AISEHack Phase 2 Submission.
Please also cite the base model and key dependencies:
```bibtex
@article{jakubik2023prithvi,
  title   = {Foundation Models for Generalist Geospatial Artificial Intelligence},
  author  = {Jakubik, Johannes and others},
  journal = {arXiv preprint arXiv:2310.18660},
  year    = {2023}
}

@article{pekel2016jrc,
  title   = {High-resolution global maps of 21st-century surface water and its changes},
  author  = {Pekel, Jean-Fran{\c{c}}ois and Cottam, Andrew and Gorelick, Noel
             and Belward, Alan S.},
  journal = {Nature},
  volume  = {540},
  pages   = {418--422},
  year    = {2016}
}

@inproceedings{abraham2019tversky,
  title     = {A Novel Focal Tversky Loss Function with Improved Attention U-Net
               for Lesion Segmentation},
  author    = {Abraham, Nabila and Khan, Naimul Mefraz},
  booktitle = {IEEE International Symposium on Biomedical Imaging (ISBI)},
  year      = {2019},
  note      = {arXiv:1810.07842}
}

@article{xu2025sgcad,
  title   = {SGCAD: A SAR-Guided Confidence-Gated Distillation Framework of Optical
             and SAR Images for Water-Enhanced Land-Cover Semantic Segmentation},
  author  = {Xu, Zhenghao and others},
  journal = {Remote Sensing},
  volume  = {18},
  number  = {6},
  pages   = {962},
  year    = {2025}
}
```
## Glossary
| Term | Definition |
|---|---|
| HLS | Harmonized Landsat Sentinel-2 – NASA surface reflectance product at 30 m resolution |
| SAR | Synthetic Aperture Radar – active microwave sensor, all-weather, cloud-penetrating |
| HH / HV | SAR polarisation channels: horizontal transmit/horizontal receive, horizontal transmit/vertical receive |
| NDWI | Normalised Difference Water Index = (Green − NIR) / (Green + NIR); positive values indicate water |
| JRC GSW | JRC Global Surface Water – 37-year (1984–2021) Landsat-derived water occurrence dataset |
| JRCGate | Learned spatial gate that suppresses the flood logit in pixels where JRC occurrence ≥ 75% AND SAR backscatter is low |
| LightMCAM | Lightweight Multi-modal Cross-Attention Module – SAR as K/V, optical as Q, pooled 8×8 attention |
| SEGM | SAR Edge Guidance Module – amplifies boundary responses in the decoder using SAR shallow edge features |
| Tversky loss | Asymmetric Dice variant: 1 − (TP + ε) / (TP + α·FN + β·FP + ε); α=0.90 penalises missed floods |
| TTA | Test-Time Augmentation – 4-flip ensemble (H-flip × V-flip) averaged for the final prediction |
| RLE | Run-Length Encoding – competition submission format for binary masks |
| mIoU | Mean Intersection-over-Union – mean of per-class IoU scores |
| WB FP Rate | Waterbody False Positive Rate – fraction of waterbody GT pixels predicted as flood |
## More Information
- FloodSense web application: Full-stack offline-capable disaster response platform built on this model – React 18 + TypeScript + MapLibre GL JS + Gemini AI chatbot + AWS SageMaker inference. See `flood-response-app/README.md`.
- Evaluation pipeline: `evaluate_model.py` – full 3-class IoU + pixel accuracy evaluation across all checkpoints and splits, with TensorBoard log parsing and 4 plot types.
- Quick evaluation: `quick_eval.py` – fast single-checkpoint evaluation on any image subset.
- JRC mask generation: `fetch_jrc_masks.py`, `integrate_jrc.py` – download and reproject JRC Global Surface Water tiles to match patch CRS/resolution.
- Competition context: ANRF AISEHack Theme 1 Phase 2 – flood detection from satellite imagery, West Bengal / Bangladesh flood event, May 2024.
## Model Card Authors
Roshan Rateria