import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from segmentation_models_pytorch.base import SegmentationHead
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm.layers.create_act import create_act_layer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

from .convlstm import ConvLSTM


class ACTUConfig(PretrainedConfig):
    """Configuration for the ACTU segmentation model."""

    model_type = "actu"

    def __init__(
        self,
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.backbone = backbone
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes

        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers

        # The DEM is concatenated as one extra input channel, so the encoder
        # must be built to accept it.
        if self.use_dem_input:
            self.in_channels += 1

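
# Illustrative usage (values are assumptions, not project defaults): a config
# for four-band imagery with the climate branch enabled.
#
#   config = ACTUConfig(
#       in_channels=4,
#       use_climate_branch=True,
#       climate_seq_len=12,
#       climate_input_dim=6,
#   )
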
class ACTUForImageSegmentation(PreTrainedModel):
    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config

        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )

        # Probe the encoder with a dummy batch to record the shape of each
        # feature stage at the configured input resolution.
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
        self.embs_shape = [e.shape for e in embs]
        self.encoder_channels = [shape[1] for shape in self.embs_shape]

        # One ConvLSTM per encoder stage, preserving the stage's channel count.
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )

        if self.config.use_climate_branch:
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
            self.fusers = nn.ModuleList(
                GatedFusion(enc, enc) for enc in self.encoder_channels
            )

        # The leading dummy channel pairs with the `None` feature passed in
        # `_decode`; UnetDecoder drops this first, full-resolution skip.
        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )

        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor | None = None,
        dem: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]

        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            # Repeat the static DEM along the time axis and append it as an
            # extra channel of every frame.
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)

        encoded_sequence = self._encode_images(pixel_values)

        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )

            climate_features = self.climate_branch(climate)

            # Flatten batch and time so each stage can be fused frame-wise.
            encoded_sequence_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
            ]
            climate_features_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
            ]

            fused_features = [
                fuser(img, clim)
                for fuser, img, clim in zip(
                    self.fusers, encoded_sequence_reshaped, climate_features_reshaped
                )
            ]

            encoded_sequence = [
                rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
            ]

        temporal_features = self._encode_timeseries(encoded_sequence)

        logits = self._decode(temporal_features, size=original_size)

        loss = None
        if labels is not None:
            # The segmentation head already applies the configured activation
            # (sigmoid by default), so compare probabilities with binary
            # cross-entropy; softmax-based CrossEntropyLoss degenerates on a
            # single-channel output.
            loss_fct = nn.BCELoss()
            loss = loss_fct(logits, labels.float().unsqueeze(1))

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        b = x.size(0)
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=b)
            for frames in encoded_frames
        ]

    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        # Run each stage through its ConvLSTM, deepest stage first, keeping
        # only the final timestep of every output sequence.
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        # `x` arrives deepest-first; the decoder expects a dummy full-resolution
        # skip followed by shallow-to-deep features (see __init__).
        trend_map = self.decoder(*([None] + x[::-1]))
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map
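
# Shape trace (illustrative, not enforced): pixel_values (B, T, C, H, W)
# -> per-stage encoder features (B, T, C_s, H_s, W_s) -> optional gated fusion
# with broadcast climate features of the same shape -> ConvLSTM final step
# (B, C_s, H_s, W_s) -> UnetDecoder + seg_head + interpolation -> logits
# (B, n_classes, H, W).
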
class ClimateBranchLSTM(nn.Module):
    """
    Processes climate time series with an LSTM and broadcasts the resulting
    embedding into one spatial feature map per encoder stage.

    Input shape: (B, T, T_clim, C_clim) -> e.g., (B, 5, 6, 5): one climate
    sequence of length T_clim with C_clim variables per image frame.
    Output: a list with one tensor of shape (B, T, C_s, H_s, W_s) per entry
    in `output_shapes`.
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
    ):
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.proj_dim = 128
        self.output_shapes = output_shapes

        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )

        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)

        # One upsampling tower per encoder stage, mapping the (proj_dim, 1, 1)
        # embedding to that stage's (C_s, H_s, H_s) resolution.
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        B_img, B_cli, T, C = climate_data.shape

        # Fold the image-batch and frame dimensions into a single LSTM batch.
        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")

        _, (hidden, _) = self.lstm(lstm_input)

        # Keep the final layer's hidden state; for a bidirectional LSTM both
        # directions are selected and then averaged.
        last_hidden = (
            hidden[[hidden.size(0) // 2, -1]] if self.lstm.bidirectional else hidden[-1]
        )
        if last_hidden.ndim == 3:
            last_hidden = last_hidden.mean(dim=0)

        climate_features = self.fc(last_hidden)
        climate_features = rearrange(climate_features, "b c -> b c 1 1")
        climate_features = [
            rearrange(
                u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
            )
            for u in self.upsamples
        ]

        return climate_features


class GatedFusion(nn.Module):
    """Fuses image and climate feature maps with a learned per-pixel gate."""

    def __init__(self, img_channels, clim_channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(
                img_channels + clim_channels, img_channels, kernel_size=3, padding=1
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, img_feat, clim_feat):
        # Convex combination: gate -> 1 keeps the image feature, gate -> 0
        # keeps the climate feature.
        gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
        return gate * img_feat + (1 - gate) * clim_feat


def _build_upsampler(
    in_channels: int, target_channels: int, target_h: int
) -> nn.Sequential:
    """Expands a (in_channels, 1, 1) embedding into a square
    (target_channels, target_h, target_h) map by repeated doubling."""
    layers = [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]
    current_h = 1

    while current_h < target_h:
        next_h = min(current_h * 2, target_h)
        layers += [
            # Upsample to the exact next size so a non-power-of-two target
            # is hit instead of overshot.
            nn.Upsample(size=(next_h, next_h), mode="nearest"),
            nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
            nn.GELU(),
        ]
        current_h = next_h

    return nn.Sequential(*layers)
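
# Minimal smoke-test sketch, assuming the sibling `convlstm` module provides
# the ConvLSTM interface used above (runs only as a script, with randomly
# initialized weights).
if __name__ == "__main__":
    config = ACTUConfig(use_climate_branch=True)
    model = ACTUForImageSegmentation(config)
    pixel_values = torch.randn(2, 5, 3, 256, 256)  # (B, T, C, H, W)
    climate = torch.randn(2, 5, config.climate_seq_len, config.climate_input_dim)
    out = model(pixel_values=pixel_values, climate=climate)
    print(out.logits.shape)  # expected: torch.Size([2, 1, 256, 256])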