import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_storm_oracle import StormOracleConfig

# ---- import your actual model code ----
# If your code lives in tornado_predictor.py (as pasted), import from there:
from .tornado_predictor import TornadoSuperPredictor  # adjust if filename differs


class StormOracleModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TornadoSuperPredictor``.

    The wrapper owns a single ``TornadoSuperPredictor`` instance (stored as
    ``self.model``) and forwards every call to it unchanged.
    """

    config_class = StormOracleConfig

    def __init__(self, config: StormOracleConfig):
        """Build the wrapped predictor from ``config``.

        Args:
            config: Model configuration; only ``config.in_channels`` is read
                here and passed to ``TornadoSuperPredictor``.
        """
        super().__init__(config)
        self.model = TornadoSuperPredictor(in_channels=config.in_channels)
        # HF bookkeeping: must run after all submodules have been created.
        self.post_init()

    def forward(self, radar_x: torch.Tensor, atmo: dict):
        """Run the wrapped predictor on one batch.

        Args:
            radar_x: Radar input tensor, documented as ``(B, C, H, W)``.
            atmo: Dict of tensors; documented keys are ``cape``,
                ``wind_shear``, ``helicity``, ``temperature``, ``dewpoint``
                and ``pressure``.

        Returns:
            Whatever ``TornadoSuperPredictor`` returns for these inputs
            (documented as its ``TornadoPredictionBatch`` dataclass).
        """
        return self.model(radar_x, atmo)