# actu-direction-classification / modeling_actu.py
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from segmentation_models_pytorch.base import SegmentationHead
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm.layers.create_act import create_act_layer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from .convlstm import ConvLSTM
class ACTUConfig(PretrainedConfig):
    """Configuration for ACTU segmentation models, including the optional DEM-input and climate-branch variants."""

    model_type = "actu"
def __init__(
self,
# Base ACTU parameters
in_channels: int = 3,
kernel_size: tuple[int, int] = (3, 3),
padding="same",
stride=(1, 1),
backbone="resnet34",
bias=True,
batch_first=True,
bidirectional=False,
original_resolution=(256, 256),
act_layer="sigmoid",
n_classes=1,
# Variant control parameters
use_dem_input: bool = False,
use_climate_branch: bool = False,
# Climate branch parameters
climate_seq_len=5,
climate_input_dim=6,
lstm_hidden_dim=128,
num_lstm_layers=1,
**kwargs,
):
super().__init__(**kwargs)
self.in_channels = in_channels
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.backbone = backbone
self.bias = bias
self.batch_first = batch_first
self.bidirectional = bidirectional
self.original_resolution = original_resolution
self.act_layer = act_layer
self.n_classes = n_classes
# Parameters to control variants
self.use_dem_input = use_dem_input
self.use_climate_branch = use_climate_branch
self.climate_seq_len = climate_seq_len
self.climate_input_dim = climate_input_dim
self.lstm_hidden_dim = lstm_hidden_dim
self.num_lstm_layers = num_lstm_layers
# Adjust in_channels if DEM is used
if self.use_dem_input:
self.in_channels += 1
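
# Illustrative configuration sketches (hypothetical values):
#
#   ACTUConfig()                            # base RGB variant, in_channels == 3
#   ACTUConfig(use_dem_input=True)          # in_channels becomes 4; forward() needs `dem`
#   ACTUConfig(use_climate_branch=True,
#              climate_seq_len=6,
#              climate_input_dim=5)         # forward() needs `climate`
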
class ACTUForImageSegmentation(PreTrainedModel):
    """U-Net-style segmenter over image time series: a timm encoder per frame,
    one ConvLSTM per feature stage, and a UnetDecoder on the last hidden states."""

    config_class = ACTUConfig
def __init__(self, config: ACTUConfig):
super().__init__(config)
self.config = config
self.encoder: nn.Module = timm.create_model(
config.backbone, features_only=True, in_chans=config.in_channels
)
        # Infer the per-stage feature shapes with a dummy forward pass
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
        self.embs_shape = [e.shape for e in embs]
        self.encoder_channels = [shape[1] for shape in self.embs_shape]
self.convlstm = nn.ModuleList(
[
ConvLSTM(
in_channels=shape[1],
hidden_channels=shape[1],
kernel_size=config.kernel_size,
padding=config.padding,
stride=config.stride,
bias=config.bias,
batch_first=config.batch_first,
bidirectional=config.bidirectional,
)
for shape in self.embs_shape
]
)
if self.config.use_climate_branch:
self.climate_branch = ClimateBranchLSTM(
output_shapes=[e[1:] for e in self.embs_shape],
lstm_hidden_dim=config.lstm_hidden_dim,
climate_seq_len=config.climate_seq_len,
climate_input_dim=config.climate_input_dim,
num_lstm_layers=config.num_lstm_layers,
)
self.fusers = nn.ModuleList(
GatedFusion(enc, enc) for enc in self.encoder_channels
)
        self.decoder = UnetDecoder(
            # UnetDecoder drops the first entry of encoder_channels (the stem),
            # so a dummy 1-channel slot is prepended here and None is passed
            # for it in _decode
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )
        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            # Final activation (sigmoid by default): the returned "logits" are
            # therefore probabilities in [0, 1]
            create_act_layer(config.act_layer, inplace=True),
        )
    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor | None = None,
        dem: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        """
        pixel_values: (B, T, C, H, W) image time series.
        climate: (B, T, T_clim, C_clim) climate series; required when use_climate_branch.
        dem: (B, 1, H, W) elevation map; required when use_dem_input.
        labels: (B, H, W) binary mask; when given, a loss is computed.
        """
b, t = pixel_values.shape[:2]
original_size = pixel_values.shape[-2:]
# Handle DEM input
if self.config.use_dem_input:
if dem is None:
raise ValueError(
"DEM tensor must be provided when use_dem_input is True."
)
dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)
# 1. Encode images per time step
encoded_sequence = self._encode_images(pixel_values)
# 2. Handle Climate Branch Fusion
if self.config.use_climate_branch:
if climate is None:
raise ValueError(
"Climate tensor must be provided when use_climate_branch is True."
)
climate_features = self.climate_branch(climate)
# Reshape for fusion
encoded_sequence_reshaped = [
rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
]
climate_features_reshaped = [
rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
]
# Fuse features
fused_features = [
fuser(img, clim)
for fuser, img, clim in zip(
self.fusers, encoded_sequence_reshaped, climate_features_reshaped
)
]
# Reshape back to sequence
encoded_sequence = [
rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
]
# 3. Process sequence with ConvLSTM
temporal_features = self._encode_timeseries(encoded_sequence)
# 4. Decode to get the segmentation map
logits = self._decode(temporal_features, size=original_size)
        loss = None
        if labels is not None:
            # The head ends in a sigmoid, so `logits` already holds probabilities;
            # binary cross-entropy applies here, whereas CrossEntropyLoss on a
            # single-channel input would be degenerate (constantly zero).
            loss_fct = nn.BCELoss()
            loss = loss_fct(logits, labels.float().unsqueeze(1))
return SemanticSegmenterOutput(
loss=loss,
logits=logits,
)
    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run the backbone on every frame; returns one (B, T, C, H, W) tensor per stage."""
        B = x.size(0)
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]
    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply one ConvLSTM per stage and keep the last time step (deepest stage first)."""
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1, :, :, :])
        return outs
    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        # x arrives deepest-first; reverse it back to encoder order and prepend
        # None as the (unused) stem feature expected by UnetDecoder
        trend_map = self.decoder(*([None] + x[::-1]))
        trend_map = self.seg_head(trend_map)
        return F.interpolate(trend_map, size=size, mode="bilinear", align_corners=False)
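
# A minimal usage sketch for the base variant (illustrative shapes; assumes the
# default resnet34 backbone can be built by timm without pretrained weights):
#
#   model = ACTUForImageSegmentation(ACTUConfig())
#   frames = torch.randn(2, 4, 3, 256, 256)   # (B, T, C, H, W)
#   out = model(pixel_values=frames)          # out.logits: (2, 1, 256, 256)
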
class ClimateBranchLSTM(nn.Module):
"""
Processes climate time series data using an LSTM.
Input shape: (B, T, T_1, C_clim) -> e.g., (B, 5, 6, 5)
Output shape: (B, T, output_dim) -> e.g., (B, 5, 128)
"""
def __init__(
self,
output_shapes: list[tuple[int, int, int]],
climate_input_dim=5,
climate_seq_len=6,
lstm_hidden_dim=64,
num_lstm_layers=1,
):
super().__init__()
self.climate_seq_len = climate_seq_len
self.climate_input_dim = climate_input_dim
self.lstm_hidden_dim = lstm_hidden_dim
self.num_lstm_layers = num_lstm_layers
self.proj_dim = 128
self.output_shapes = output_shapes
self.lstm = nn.LSTM(
input_size=climate_input_dim,
hidden_size=lstm_hidden_dim,
num_layers=num_lstm_layers,
batch_first=True, # Crucial: expects input shape (batch, seq_len, features)
dropout=0.3 if num_lstm_layers > 1 else 0,
bidirectional=False,
)
# Linear layer to project LSTM output to the desired final dimension
self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
self.upsamples = nn.ModuleList(
_build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
)
    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        # climate_data shape: (B, T_img, T_clim, C_clim), e.g., (B, 5, 6, 5)
        B, T_img, T_clim, C = climate_data.shape
        # Flatten the batch and image-time axes so each climate sequence runs
        # through the LSTM independently
        lstm_input = rearrange(climate_data, "b t s c -> (b t) s c")
        _, (hidden, _) = self.lstm(lstm_input)
        # Final layer's hidden state; for a bidirectional LSTM the last two
        # entries are the forward and backward states of the final layer
        last_hidden = hidden[[-2, -1]] if self.lstm.bidirectional else hidden[-1]
        if last_hidden.ndim == 3:
            # Average the two directions into a single state
            last_hidden = last_hidden.mean(dim=0)
        # Project the hidden state, then broadcast it spatially to every
        # encoder stage with the per-stage upsamplers
        climate_features = self.fc(last_hidden)
        climate_features = rearrange(climate_features, "b c -> b c 1 1")
        climate_features = [
            rearrange(u(climate_features), "(b t) c h w -> b t c h w", b=B, t=T_img)
            for u in self.upsamples
        ]
        return climate_features
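
# Shape walk-through (illustrative numbers): a climate tensor of shape
# (2, 5, 6, 5) is flattened into 10 sequences of length 6, the LSTM's final
# hidden state is projected to proj_dim=128, and each upsampler broadcasts it
# to one encoder stage, e.g. (2, 5, 64, 128, 128) for the first resnet34 stage
# at a 256x256 input resolution.
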
class GatedFusion(nn.Module):
    """Blends image and climate features with a learned per-pixel gate."""

    def __init__(self, img_channels, clim_channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(
                img_channels + clim_channels, img_channels, kernel_size=3, padding=1
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),  # Gate values between 0 and 1
        )

    def forward(self, img_feat, clim_feat):
        # Convex combination: gate -> 1 keeps the image feature,
        # gate -> 0 keeps the climate feature
        gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
        return gate * img_feat + (1 - gate) * clim_feat
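
# Example (hypothetical shapes): fusing matching 64-channel feature maps.
#
#   fuse = GatedFusion(64, 64)
#   y = fuse(torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32))  # (2, 64, 32, 32)
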
def _build_upsampler(
    in_channels: int, target_channels: int, target_h: int
) -> nn.Sequential:
    """Expand a (B, in_channels, 1, 1) vector into a square (B, target_channels, target_h, target_h) map."""
    layers = []
    current_h = 1
    # Expand to target channels early (1x1 → 1x1 with target_channels)
    layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]
    # Double the spatial size until target_h is reached; the explicit `size`
    # argument keeps non-power-of-two targets exact instead of overshooting them
    while current_h < target_h:
        next_h = min(current_h * 2, target_h)
        layers += [
            nn.Upsample(size=(next_h, next_h), mode="nearest"),
            nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        current_h = next_h
    return nn.Sequential(*layers)
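
if __name__ == "__main__":
    # Smoke test: a minimal sketch, not part of the published API. It assumes
    # timm can build "resnet34" without downloading weights and that the bundled
    # ConvLSTM accepts (B, T, C, H, W) input as used in forward() above.
    cfg = ACTUConfig(
        use_dem_input=True,
        use_climate_branch=True,
        climate_seq_len=6,
        climate_input_dim=5,
    )
    model = ACTUForImageSegmentation(cfg).eval()
    B, T = 1, 3
    pixel_values = torch.randn(B, T, 3, 256, 256)  # RGB sequence
    dem = torch.randn(B, 1, 256, 256)              # single-band elevation map
    climate = torch.randn(B, T, 6, 5)              # (B, T_img, T_clim, C_clim)
    with torch.no_grad():
        out = model(pixel_values=pixel_values, climate=climate, dem=dem)
    print(out.logits.shape)  # expected: torch.Size([1, 1, 256, 256])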