# TorNet-Oracle / modeling_storm_oracle.py
# Commit 5376ec3 (verified) by Wonder-Griffin: "Create modeling_storm_oracle.py"
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_storm_oracle import StormOracleConfig
# ---- Wrapped model implementation ----
# TornadoSuperPredictor is expected in tornado_predictor.py inside this package;
# update the import below if the implementation lives in a different module.
from .tornado_predictor import TornadoSuperPredictor # adjust if filename differs
class StormOracleModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TornadoSuperPredictor``.

    Pairing the predictor with ``StormOracleConfig`` via ``config_class`` lets it
    use the standard ``save_pretrained`` / ``from_pretrained`` machinery. All
    computation is delegated to the wrapped ``self.model``.
    """

    config_class = StormOracleConfig
    # transformers defaults this to "input_ids", an argument this model does not
    # accept. Utilities that locate the primary input by this name (e.g. parts
    # of the Trainer) need the real argument name of ``forward``.
    main_input_name = "radar_x"

    def __init__(self, config: StormOracleConfig):
        """Build the wrapped predictor from ``config``.

        Args:
            config: Model configuration. Only ``config.in_channels`` is read
                here; it is forwarded to ``TornadoSuperPredictor``.
        """
        super().__init__(config)
        # The attribute name fixes the "model." prefix of every state_dict key,
        # so renaming it would break loading of existing checkpoints.
        self.model = TornadoSuperPredictor(in_channels=config.in_channels)
        # transformers finalization hook (weight init etc.); it must run after
        # all submodules have been created.
        self.post_init()

    def forward(self, radar_x: torch.Tensor, atmo: dict):
        """Run the wrapped predictor on one batch.

        Args:
            radar_x: Radar input, documented as shape (B, C, H, W). Presumably
                C must equal ``config.in_channels`` (TODO: confirm).
            atmo: Dict of atmospheric tensors (cape, wind_shear, helicity,
                temperature, dewpoint, pressure), passed through unchanged.

        Returns:
            Whatever ``TornadoSuperPredictor`` returns. Per the original notes,
            this is a ``TornadoPredictionBatch`` dataclass.
        """
        return self.model(radar_x, atmo)