"""
🌪️ STORM ORACLE — Tornado Super-Predictor (training-ready, no placeholders)

- RadarPatternExtractor: multi-scale CNN + spatial attention pooling
- AtmosphericConditionEncoder: per-variable MLPs -> tokens -> attention -> fused vector
- Heads: probability (sigmoid), EF (logits), location (reg), timing (reg), uncertainty (sigmoid)
- Calibration: single temperature parameter (learnable/fittable after training)
- ContinuousLearner: online fine-tuning with replay buffer and EMA weights
"""

import random
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------- Types ---------------------------------

@dataclass
class TornadoPredictionBatch:
    """Container for one forward pass. All outputs are BATCH TENSORS (no Python scalars)."""

    tornado_probability: torch.Tensor     # (B,)   sigmoid of the temperature-scaled logits
    ef_scale_probs: torch.Tensor          # (B,6)  softmax over the 6 EF-head outputs
    most_likely_ef_scale: torch.Tensor    # (B,)   argmax index of ef_scale_probs
    location_offset: torch.Tensor         # (B,2)  raw regression output
    timing_predictions: torch.Tensor      # (B,3)  raw regression output
    uncertainty_scores: torch.Tensor      # (B,4)  in [0,1]
    radar_signatures: torch.Tensor        # (B,3)  [hook, meso, couplet]
    atmospheric_indicators: torch.Tensor  # (B,3)  [cape, shear_norm, instability]
    logits: Optional[torch.Tensor] = None  # (B,) pre-sigmoid, already divided by temperature


# ---------------------- Building blocks --------------------------------

class SpatialAttentionPool(nn.Module):
    """
    Turns a 2D feature map (B,C,H,W) into (B,C) using a learned query and
    multi-head attention over the H*W spatial tokens.

    Args:
        channels: channel count C of the incoming feature map (also the
            attention embedding size; must be divisible by ``num_heads``).
        num_heads: number of attention heads.
    """

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.channels = channels
        # Despite the name this is a per-channel bias shared by every token,
        # not a position-dependent embedding.
        self.pos_embed = nn.Parameter(torch.randn(1, channels, 1))
        # Learned global query token that attends over all spatial tokens.
        self.query = nn.Parameter(torch.randn(1, 1, channels))
        self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.ln = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` of shape (B,C,H,W) down to (B,C)."""
        batch = x.size(0)
        # flatten() works for non-contiguous maps too, where view() would raise.
        tokens = x.flatten(2).transpose(1, 2)                    # (B, HW, C)
        # (1,C,1) -> (1,1,C): broadcasts over batch and tokens.
        tokens = self.ln(tokens + self.pos_embed.transpose(1, 2))
        query = self.query.expand(batch, -1, -1)                 # (B,1,C)
        pooled, _ = self.attn(query, tokens, tokens)             # (B,1,C)
        return pooled.squeeze(1)                                 # (B,C)


class RadarPatternExtractor(nn.Module):
    """
    Advanced radar pattern extraction with spatial attention pooling.
    Accepts variable input_channels (e.g., 3×T for T time steps).

    The trunk halves the spatial resolution three times, so H and W of the
    input must each be at least 8 for the final feature map to be non-empty.
    """

    def __init__(self, input_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)

        # Specialized detectors
        self.hook_echo_detector = nn.Conv2d(512, 64, kernel_size=3, padding=1)
        self.mesocyclone_detector = nn.Conv2d(512, 64, kernel_size=5, padding=2)
        self.velocity_couplet_detector = nn.Conv2d(512, 64, kernel_size=3, padding=1)

        # Attention pooling to summarize (B,512,H',W') -> (B,512)
        self.pool = SpatialAttentionPool(512, num_heads=8)

        # Combine base + specialists -> 512 + 64*3 = 704 -> project to 1024
        self.proj = nn.Sequential(
            nn.Linear(512 + 64 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, radar_data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            radar_data: tensor of shape (B,C,H,W).

        Returns:
            Dict with ``"combined_features"`` (B,1024) and
            ``"signature_strengths"`` (B,3) ordered hook, meso, velocity couplet.
        """
        x = F.max_pool2d(F.relu(self.conv1(radar_data)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        # Batch norm is applied after the ReLU here (kept as in the original design).
        x = self.bn4(F.relu(self.conv4(x)))

        hook = F.relu(self.hook_echo_detector(x))
        meso = F.relu(self.mesocyclone_detector(x))
        vel = F.relu(self.velocity_couplet_detector(x))

        base_vec = self.pool(x)              # (B,512)
        hook_vec = hook.mean(dim=(2, 3))     # (B,64)
        meso_vec = meso.mean(dim=(2, 3))     # (B,64)
        vel_vec = vel.mean(dim=(2, 3))       # (B,64)

        fused = torch.cat([base_vec, hook_vec, meso_vec, vel_vec], dim=1)  # (B,704)
        combined = self.proj(fused)                                        # (B,1024)

        # One scalar per detector: mean activation over channels and space.
        strengths = torch.stack(
            [hook_vec.mean(dim=1), meso_vec.mean(dim=1), vel_vec.mean(dim=1)],
            dim=1,
        )  # (B,3)

        return {
            "combined_features": combined,
            "signature_strengths": strengths,  # hook, meso, velocity couplet
        }


class AtmosphericConditionEncoder(nn.Module):
    """
    Encode environmental parameters using per-variable MLPs, then treat them
    as tokens and apply multi-head self-attention.
    """

    # (key in the input dict, feature width) for every variable group, in token order.
    _INPUT_DIMS = (
        ("cape", 1),
        ("wind_shear", 4),
        ("helicity", 2),
        ("temperature", 3),
        ("dewpoint", 2),
        ("pressure", 1),
    )

    def __init__(self):
        super().__init__()
        self.enc_cape = nn.Linear(1, 32)
        self.enc_shear = nn.Linear(4, 64)      # 0–1, 0–3, 0–6, deep
        self.enc_helicity = nn.Linear(2, 32)   # 0–1, 0–3
        self.enc_temp = nn.Linear(3, 32)       # sfc, 850, 500
        self.enc_dewpoint = nn.Linear(2, 32)   # sfc, 850
        self.enc_pressure = nn.Linear(1, 16)

        # we will embed each of the 6 groups to dim=64 and self-attend
        self.to_64 = nn.ModuleDict({
            "cape": nn.Linear(32, 64),
            "shear": nn.Identity(),            # already 64
            "helicity": nn.Linear(32, 64),
            "temp": nn.Linear(32, 64),
            "dewpoint": nn.Linear(32, 64),
            "pressure": nn.Linear(16, 64),
        })
        self.ln = nn.LayerNorm(64)
        self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
        self.fuse = nn.Sequential(
            nn.Linear(64 * 6, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def _prepare_inputs(
        self, atmo: Dict[str, torch.Tensor], batch_size: Optional[int]
    ) -> List[torch.Tensor]:
        """Return the six (B,d) input tensors in ``_INPUT_DIMS`` order, zero-filling absent keys."""
        ref = self.enc_cape.weight  # supplies device/dtype for zero defaults
        present: Dict[str, torch.Tensor] = {}
        for key, dim in self._INPUT_DIMS:
            value = atmo.get(key)
            if value is None:
                continue
            if value.ndim != 2:
                # reshape() also handles non-contiguous inputs.
                value = value.reshape(-1, dim)
            if not value.is_floating_point():
                # nn.Linear rejects integer inputs.
                value = value.to(ref.dtype)
            present[key] = value

        if batch_size is None:
            # Size defaults like the supplied data; 1 only when nothing was supplied.
            batch_size = next((t.size(0) for t in present.values()), 1)

        return [
            present[key] if key in present else ref.new_zeros(batch_size, dim)
            for key, dim in self._INPUT_DIMS
        ]

    def forward(
        self, atmo: Dict[str, torch.Tensor], batch_size: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            atmo: dict that may contain ``"cape"`` (B,1), ``"wind_shear"`` (B,4),
                ``"helicity"`` (B,2), ``"temperature"`` (B,3), ``"dewpoint"`` (B,2)
                and ``"pressure"`` (B,1). Non-2D tensors are reshaped to (-1, d).
                Missing keys are replaced by zeros.
            batch_size: batch size used for zero-filled defaults. When omitted it
                is taken from the first supplied tensor (1 if the dict is empty).

        Returns:
            Dict with ``"atmospheric_features"`` (B,256), ``"cape_score"`` (B,),
            ``"shear_magnitude"`` (B,) and ``"instability_index"`` (B,).
        """
        cape, shear, hel, temp, dew, pres = self._prepare_inputs(atmo, batch_size)

        embedded = (
            ("cape", F.relu(self.enc_cape(cape))),          # (B,32)
            ("shear", F.relu(self.enc_shear(shear))),       # (B,64)
            ("helicity", F.relu(self.enc_helicity(hel))),   # (B,32)
            ("temp", F.relu(self.enc_temp(temp))),          # (B,32)
            ("dewpoint", F.relu(self.enc_dewpoint(dew))),   # (B,32)
            ("pressure", F.relu(self.enc_pressure(pres))),  # (B,16)
        )
        tokens = torch.stack(
            [self.ln(self.to_64[name](emb)) for name, emb in embedded], dim=1
        )  # (B, 6, 64)

        attn_out, _ = self.attn(tokens, tokens, tokens)             # (B,6,64)
        fused = self.fuse(attn_out.reshape(attn_out.size(0), -1))   # (B,256)

        # easy indicators for explanations/QA (computed from the raw inputs)
        shear_mag = torch.linalg.vector_norm(shear, dim=-1)         # (B,)
        cape_score = cape.squeeze(-1)                               # (B,)
        instab = cape_score * shear_mag                             # (B,)

        return {
            "atmospheric_features": fused,   # (B,256)
            "cape_score": cape_score,        # (B,)
            "shear_magnitude": shear_mag,    # (B,)
            "instability_index": instab,     # (B,)
        }


# -------------------------- Main model --------------------------------

class TornadoSuperPredictor(nn.Module):
    """
    Fuses radar features (1024-d) and atmospheric features (256-d) and feeds
    five heads: tornado probability, EF class, location, timing, uncertainty.

    Args:
        in_channels: number of channels of the radar input.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        self.radar_extractor = RadarPatternExtractor(input_channels=in_channels)
        self.atmo_encoder = AtmosphericConditionEncoder()

        fused_dim = 1024 + 256
        self.prob_head = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 1),
        )
        self.ef_head = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 6),
        )
        self.loc_head = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 2),
        )
        self.time_head = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 3),
        )
        self.unc_head = nn.Sequential(
            nn.Linear(fused_dim, 256), nn.ReLU(),
            nn.Linear(256, 4),
        )

        # temperature parameter for calibration (start at 1.0, i.e. log T = 0)
        self.register_parameter("log_temperature", nn.Parameter(torch.zeros(())))

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every Linear, Kaiming-init every Conv2d, zero all their biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="relu")
            else:
                continue
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    @property
    def temperature(self) -> torch.Tensor:
        """Calibration temperature, kept positive by exponentiating ``log_temperature``."""
        return torch.exp(self.log_temperature)

    def forward(self, radar_x: torch.Tensor, atmo: Dict[str, torch.Tensor]) -> TornadoPredictionBatch:
        """
        Args:
            radar_x: radar tensor of shape (B,C,H,W).
            atmo: dict of (B,dim) tensors; see ``AtmosphericConditionEncoder.forward``.
                Keys that are absent are zero-filled at the radar batch size.

        Returns:
            A ``TornadoPredictionBatch`` whose fields all have leading dimension B.
        """
        r = self.radar_extractor(radar_x)
        # Pass B explicitly so a partial or empty atmo dict still lines up with the radar batch.
        a = self.atmo_encoder(atmo, batch_size=radar_x.size(0))

        fused = torch.cat([r["combined_features"], a["atmospheric_features"]], dim=1)  # (B,1280)

        logits = self.prob_head(fused).squeeze(-1)           # (B,)
        logits = logits / self.temperature.clamp_min(1e-6)   # calibrated logits
        probs = torch.sigmoid(logits)                        # (B,)

        ef_logits = self.ef_head(fused)                      # (B,6)
        ef_probs = F.softmax(ef_logits, dim=-1)
        ef_idx = ef_probs.argmax(dim=-1)

        loc = self.loc_head(fused)                           # (B,2)
        tim = self.time_head(fused)                          # (B,3)
        unc = torch.sigmoid(self.unc_head(fused))            # (B,4) in [0,1]

        indicators = torch.stack(
            [a["cape_score"], a["shear_magnitude"], a["instability_index"]], dim=1
        )

        return TornadoPredictionBatch(
            tornado_probability=probs,
            ef_scale_probs=ef_probs,
            most_likely_ef_scale=ef_idx,
            location_offset=loc,
            timing_predictions=tim,
            uncertainty_scores=unc,
            radar_signatures=r["signature_strengths"],
            atmospheric_indicators=indicators,
            logits=logits,
        )


# --------------------- Continuous learning wrapper --------------------

class ContinuousLearner(nn.Module):
    """
    Light wrapper that adds:
    - optimizer + (optional) pos_weight or focal loss
    - EMA weights for stable inference during online updates
    - small replay buffer to avoid catastrophic forgetting

    Args:
        model: the predictor to fine-tune.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        use_focal: use the focal loss instead of BCE-with-logits.
        pos_weight: optional positive-class weight for the BCE loss
            (not used by the focal loss).
        ema_decay: decay factor of the exponential moving average.
        replay_capacity: maximum number of stored batches; <= 0 disables replay.
        device: device that batches are moved to; defaults to the device of
            the model's first parameter. The model itself is not moved.
    """

    def __init__(
        self,
        model: TornadoSuperPredictor,
        lr: float = 1e-4,
        wd: float = 1e-4,
        use_focal: bool = False,
        pos_weight: Optional[float] = None,
        ema_decay: float = 0.999,
        replay_capacity: int = 2048,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.model = model
        self.device = device if device is not None else next(model.parameters()).device

        self.opt = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=wd)
        self.use_focal = use_focal
        self.pos_weight = (
            None
            if pos_weight is None
            else torch.tensor(pos_weight, dtype=torch.float32, device=self.device)
        )
        self.ema_decay = ema_decay

        # EMA weights: one shadow tensor per state_dict entry (parameters and buffers).
        self.shadow = {k: v.detach().clone() for k, v in self.model.state_dict().items()}

        self.replay_capacity = replay_capacity
        # Bounded FIFO of (radar_x, atmo_dict, y) tuples; the deque evicts the
        # oldest batch in O(1) once full.
        self._replay = deque(maxlen=max(replay_capacity, 0))

    def _bce_loss(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Mean BCE-with-logits, weighted by ``pos_weight`` when one was configured."""
        return F.binary_cross_entropy_with_logits(logits, y.float(), pos_weight=self.pos_weight)

    def _focal_loss(
        self, logits: torch.Tensor, y: torch.Tensor, gamma: float = 2.0, alpha: float = 0.5
    ) -> torch.Tensor:
        """
        Mean binary focal loss: ``alpha_t * (1 - p_t)**gamma * BCE``.

        The cross-entropy term is computed from the logits directly instead of
        ``log(sigmoid(x))`` so it stays finite and accurate for large |logits|.
        """
        y = y.float()
        ce = F.binary_cross_entropy_with_logits(logits, y, reduction="none")
        p = torch.sigmoid(logits)
        pt = p * y + (1 - p) * (1 - y)
        at = alpha * y + (1 - alpha) * (1 - y)
        return (at * (1 - pt).pow(gamma) * ce).mean()

    @torch.no_grad()
    def _update_ema(self):
        """Blend the current model state into the shadow weights."""
        for k, v in self.model.state_dict().items():
            shadow = self.shadow[k]
            if shadow.is_floating_point():
                shadow.mul_(self.ema_decay).add_(v, alpha=1.0 - self.ema_decay)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
                # scaled in place by a float; track the latest value instead.
                shadow.copy_(v)

    def train_step(
        self,
        radar_x: torch.Tensor,
        atmo: Dict[str, torch.Tensor],
        y: torch.Tensor,
        *,
        store: bool = True,
    ) -> Dict[str, float]:
        """
        Run one optimizer step on a batch and update the EMA weights.

        Args:
            radar_x: radar batch (B,C,H,W).
            atmo: dict of atmospheric tensors for the same batch.
            y: binary targets of shape (B,).
            store: when True (default) the batch is appended to the replay
                buffer, provided replay is enabled.

        Returns:
            ``{"loss": ..., "avg_prob": ...}`` where ``avg_prob`` is the mean
            predicted probability from the forward pass before the update.
        """
        radar_x = radar_x.to(self.device)
        atmo = {k: v.to(self.device) for k, v in atmo.items()}
        y = y.to(self.device)

        self.model.train()
        out = self.model(radar_x, atmo)  # contains logits & probs
        if self.use_focal:
            loss = self._focal_loss(out.logits, y)
        else:
            loss = self._bce_loss(out.logits, y)

        self.opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.opt.step()
        self._update_ema()

        if store and self.replay_capacity > 0:
            # copy=True gives a real CPU copy even for tensors already on the
            # CPU, so later in-place edits by the caller cannot alter the buffer.
            self._replay.append((
                radar_x.detach().to("cpu", copy=True),
                {k: v.detach().to("cpu", copy=True) for k, v in atmo.items()},
                y.detach().to("cpu", copy=True),
            ))

        return {
            "loss": float(loss.item()),
            "avg_prob": out.tornado_probability.detach().mean().item(),
        }

    @torch.no_grad()
    def ema_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return a cloned copy of the EMA (shadow) state dict."""
        return {k: v.clone() for k, v in self.shadow.items()}

    @torch.no_grad()
    def load_ema_weights(self):
        """Overwrite the live model's weights with the EMA weights."""
        self.model.load_state_dict(self.ema_state_dict())

    def replay_step(self, batch_size: int = 16) -> Optional[Dict[str, float]]:
        """
        Train on a random sample of stored batches.

        Args:
            batch_size: maximum number of stored batches to concatenate.

        Returns:
            The ``train_step`` metrics, or ``None`` when the buffer is empty.
        """
        if not self._replay:
            return None

        picked = random.sample(list(self._replay), k=min(batch_size, len(self._replay)))
        xs = torch.cat([radar for radar, _, _ in picked], dim=0)
        ys = torch.cat([target for _, _, target in picked], dim=0)

        # Only keys present in every sampled batch can be concatenated to the
        # full batch size; any other key falls back to the encoder's zero default.
        keys = [k for k in picked[0][1] if all(k in fields for _, fields, _ in picked)]
        atmo = {k: torch.cat([fields[k] for _, fields, _ in picked], dim=0) for k in keys}

        # store=False: replayed data must not re-enter the buffer as one merged
        # batch, or entries would duplicate and grow with every replay.
        return self.train_step(xs, atmo, ys, store=False)