# model.py
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
import math
import copy
from typing import Optional, Tuple, Union, List
from torch import Tensor
from collections import OrderedDict
import numpy as np
from timm.models.resnet import resnet50d, resnet101d, resnet26d, resnet18d
from timm.models.registry import register_model


# --- Helper Functions for Model Definition ---


def to_2tuple(x):
    """Return ``x`` unchanged if it is already a tuple, otherwise ``(x, x)``."""
    if isinstance(x, tuple):
        return x
    return (x, x)


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return the functional activation named by ``activation``.

    Raises:
        RuntimeError: if ``activation`` is not "relu", "gelu" or "glu".
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _make_activation(activation):
    """Turn an activation spec into a module instance.

    ``activation`` may be an already-built ``nn.Module`` (returned as-is) or a
    zero-argument callable such as ``nn.ReLU`` / ``nn.GELU`` (called once).
    """
    if isinstance(activation, nn.Module):
        return activation
    return activation()


def build_attn_mask(mask_type, device=None):
    """Build a boolean (151, 151) attention mask; ``True`` marks blocked pairs.

    The token index ranges are hard-coded; NOTE(review): they only match one
    specific encoder token layout (151 tokens) — confirm against the feature
    map sizes actually fed to the encoder.

    Args:
        mask_type: ``"seperate_all"`` allows attention only inside each of the
            ranges [0, 50), [50, 67), [67, 84), [84, 101) and [101, 151).
            ``"seperate_view"`` does the same for the first four ranges, while
            tokens 101..150 may attend to, and be attended by, every token.
        device: device of the returned tensor. Defaults to CUDA when it is
            available (the previous hard-coded behaviour) and CPU otherwise.

    Raises:
        ValueError: for any other ``mask_type`` (previously an all-``True``
            mask that blocks every pair was returned silently).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    view_ranges = [(0, 50), (50, 67), (67, 84), (84, 101)]
    mask = torch.ones((151, 151), dtype=torch.bool, device=device)
    if mask_type == "seperate_all":
        for start, end in view_ranges + [(101, 151)]:
            mask[start:end, start:end] = False
    elif mask_type == "seperate_view":
        for start, end in view_ranges:
            mask[start:end, start:end] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    else:
        raise ValueError(
            f"mask_type should be 'seperate_all' or 'seperate_view', not {mask_type!r}"
        )
    return mask


_RESNET_FACTORIES = {
    "r18": resnet18d,
    "r26": resnet26d,
    "r50": resnet50d,
    "r101": resnet101d,
}


def _build_resnet_backbone(name, pretrained, in_chans):
    """Create a timm ResNet-D feature extractor returning only feature index 4.

    Raises:
        ValueError: if ``name`` is not one of "r18", "r26", "r50", "r101".
    """
    try:
        factory = _RESNET_FACTORIES[name]
    except KeyError:
        raise ValueError(
            f"unsupported backbone {name!r}; expected one of {sorted(_RESNET_FACTORIES)}"
        ) from None
    return factory(
        pretrained=pretrained, in_chans=in_chans, features_only=True, out_indices=[4]
    )


# --- Model Components ---


class HybridEmbed(nn.Module):
    """Run a CNN backbone and project its last feature map to ``embed_dim`` channels.

    ``forward`` returns ``(tokens, global_token)``: ``tokens`` is the 1x1-conv
    projection of the backbone's (last) feature map, shape
    (N, embed_dim, H', W'), and ``global_token`` is its spatial mean with shape
    (N, embed_dim, 1).
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size  # stored only; not used by forward
        self.backbone = backbone
        if feature_size is None:
            # Probe the backbone with a dummy image to discover its output
            # channel count, temporarily switching to eval mode.
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x


class PositionEmbeddingSine(nn.Module):
    """Sine/cosine 2-D positional encoding.

    ``forward`` uses only the shape and device of its input (N, C, H, W) and
    returns a tensor of shape (N, 2 * num_pos_feats, H, W): the first half of
    the channels encodes the row position, the second half the column.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` plus optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output


class SpatialSoftmax(nn.Module):
    """Soft-argmax over each channel of a (height x width) map.

    Returns a tensor of shape (N, channel, 2). Expected coordinates are taken
    over a [-1, 1] grid; column 0 is then scaled by 12 and column 1 is mapped
    through ``(v - 1) * 12``. NOTE(review): the physical meaning of the factor
    12 (e.g. metres) is not visible here — confirm with the training targets.
    """

    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()

        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        if self.data_format == "NHWC":
            # The transposes make the tensor non-contiguous, so ``reshape`` is
            # required (``view`` would raise). Also fixes the former
            # ``.tranpose`` typo, which raised AttributeError.
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .reshape(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)

        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)
        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints


class MultiPath_Generator(nn.Module):
    """Decode waypoint keypoints from query tokens via transposed convolutions.

    ``forward(x, measurements)``: ``x`` is (n, tokens, c); with 4 tokens it is
    reshaped to an (n, c, 2, 2) map. ``measurements[:, 6:7]`` is broadcast into
    32 extra channels, and ``measurements[:, :6]`` weights the outputs of six
    parallel final heads. Returns SpatialSoftmax keypoints, shape
    (n, out_channel, 2).
    """

    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )

        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)

        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = []
        for i in range(6):
            xt = self.tconv4_list[i](x)
            xs.append(xt)
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x


class LinearWaypointsPredictor(nn.Module):
    """Per-token linear waypoint heads, one per command, mixed by a command mask.

    ``forward(x, measurements)``: ``x`` is (bs, n, input_dim) and must have
    n == 10 tokens (``rank_embed`` is (1, 10, input_dim)); ``measurements[:, :6]``
    weights the six heads. Returns (bs, n, 2), cumulatively summed over tokens
    when ``cumsum`` is true.
    """

    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        bs, n, dim = x.shape
        x = x + self.rank_embed
        x = x.reshape(-1, dim)  # row b * n + i holds token i of sample b

        # Repeat each sample's command mask n times so that mask row b * n + i
        # lines up with row b * n + i of ``x``. The former ``repeat(n, 1, 2)``
        # tiled the whole batch instead (row i * bs + b), pairing tokens with
        # other samples' commands whenever bs > 1. Shape: (bs * n, 6, 1),
        # broadcast over the 2 output coordinates below.
        mask = measurements[:, :6].repeat_interleave(n, dim=0).unsqueeze(-1)

        rs = []
        for i in range(6):
            res = self.head_fc1_list[i](x)
            res = self.head_relu(res)
            res = self.head_fc2_list[i](res)
            rs.append(res)
        rs = torch.stack(rs, 1)  # (bs * n, 6, 2)
        x = torch.sum(rs * mask, dim=1)

        x = x.view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x


class GRUWaypointsPredictor(nn.Module):
    """GRU waypoint decoder whose initial hidden state is encoded from ``target_point``.

    ``x`` must have ``waypoints`` tokens; returns (bs, waypoints, 2) offsets
    accumulated with ``cumsum`` along the token axis.
    """

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = output.reshape(bs * self.waypoints, -1)
        output = self.decoder(output).reshape(bs, self.waypoints, 2)
        output = torch.cumsum(output, 1)
        return output


class GRUWaypointsPredictorWithCommand(nn.Module):
    """Six GRU waypoint decoders mixed per sample by ``measurements[:, :6]``."""

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList(
            [
                torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
                for _ in range(6)
            ]
        )
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs, n, dim = x.shape
        mask = measurements[:, :6, None, None]
        mask = mask.repeat(1, 1, self.waypoints, 2)

        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for i in range(6):
            output, _ = self.grus[i](x, z)
            output = output.reshape(bs * self.waypoints, -1)
            output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
            output = torch.cumsum(output, 1)
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        output = torch.sum(outputs * mask, dim=1)
        return output


class TransformerDecoder(nn.Module):
    """Stack of decoder layers.

    Returns the stacked per-layer (normed) outputs when ``return_intermediate``
    is true, otherwise the final output with a leading dimension of size 1.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # Guard: ``norm`` is optional, calling None would raise.
                intermediate.append(
                    self.norm(output) if self.norm is not None else output
                )

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    """Self-attention + feed-forward encoder layer (post- or pre-norm).

    ``activation`` may be a module class/factory such as ``nn.ReLU`` (called
    once) or an already-built module instance. The former default,
    ``nn.ReLU()``, was an instance that was then called with no input and
    always raised.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _make_activation(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    """Self-attention + cross-attention + feed-forward decoder layer.

    ``activation`` accepts a module class/factory (e.g. ``nn.ReLU``) or an
    already-built module instance; see ``TransformerEncoderLayer``.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _make_activation(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


class Interfuser(nn.Module):
    """Multi-view sensor-fusion transformer.

    Sensor images are embedded by ResNet backbones, encoded jointly by a
    transformer encoder, and read out by a transformer decoder whose queries
    feed a traffic head, junction / traffic-light / stop-sign heads and a
    waypoint head (selected by ``waypoints_pred_head``).

    Raises:
        ValueError: for an unsupported backbone name, or (when ``end2end`` is
            false) an unknown ``waypoints_pred_head``.
    """

    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r26",
        lidar_backbone_name="r26",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=True,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=True,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=True,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor

        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            # All four RGB views are stacked channel-wise into one input.
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        if self.separate_view_attention:
            attn_mask = build_attn_mask("seperate_view", device="cpu")
        elif self.separate_all_attention:
            attn_mask = build_attn_mask("seperate_all", device="cpu")
        else:
            attn_mask = None
        # Non-persistent buffer: follows model.to() / .cuda() (the previous
        # plain attribute was pinned to CUDA) but stays out of the state_dict,
        # so checkpoint keys are unchanged.
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        if use_different_backbone:
            self.rgb_backbone = _build_resnet_backbone(
                rgb_backbone_name, pretrained=True, in_chans=in_chans
            )
            # ``lidar_patch_embed`` below always probes its backbone with a
            # 3-channel dummy input, so the LiDAR backbone is built for 3
            # channels regardless of ``in_chans``. Previously only "r18" did
            # this; "r26"/"r50" used ``in_chans`` and crashed when
            # direct_concat=True.
            self.lidar_backbone = _build_resnet_backbone(
                lidar_backbone_name, pretrained=False, in_chans=3
            )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)

            if use_mmad_pretrain:
                # SECURITY NOTE(review): torch.load unpickles the file; only
                # load checkpoints from trusted sources. map_location="cpu"
                # lets GPU-saved checkpoints load on CPU-only hosts.
                params = torch.load(use_mmad_pretrain, map_location="cpu")["state_dict"]
                updated_params = OrderedDict(
                    (key.replace("backbone.", ""), value)
                    for key, value in params.items()
                    if "backbone" in key
                )
                self.rgb_backbone.load_state_dict(updated_params)

            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            # NOTE(review): this shared backbone is built for 3 input channels
            # while both embeddings probe it with ``in_chans`` channels (12
            # when direct_concat=True); that combination presumably fails at
            # construction. Kept unchanged pending confirmation of intent.
            self.rgb_backbone = _build_resnet_backbone(
                rgb_backbone_name, pretrained=True, in_chans=3
            )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        # One global/view embedding slot per sensor: 0 front, 1 left, 2 right,
        # 3 front-center, 4 LiDAR (see forward_features).
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            # NOTE(review): LinearWaypointsPredictor defaults to cumsum=True,
            # so "linear" currently behaves exactly like "linear-sum".
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
        else:
            # Fail fast instead of an UnboundLocalError later in forward().
            raise ValueError(
                f"unknown waypoints_pred_head {self.waypoints_pred_head!r}"
            )

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            # NOTE(review): forward() always feeds embed_dim + 32 features to
            # traffic_pred_head, which does not match this embed_dim input.
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )

        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the learned embeddings/queries from U(0, 1)."""
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)

    def _embed_view(self, image, patch_embed, view_idx):
        """Embed one sensor view into sequence-first ``(tokens, global_token)``.

        ``tokens`` has shape (H' * W', N, embed_dim) and ``global_token``
        (1, N, embed_dim). ``view_idx`` selects the slot of ``view_embed`` /
        ``global_embed``; the global token always receives the view embedding,
        while the patch tokens receive it only when ``use_view_embed`` is set.
        """
        token, token_global = patch_embed(image)
        pos = self.position_encoding(token)  # depends only on shape/device
        if self.use_view_embed:
            token = token + self.view_embed[:, :, view_idx : view_idx + 1, :] + pos
        else:
            token = token + pos
        token = token.flatten(2).permute(2, 0, 1)
        token_global = (
            token_global
            + self.view_embed[:, :, view_idx, :]
            + self.global_embed[:, :, view_idx : view_idx + 1]
        )
        token_global = token_global.permute(2, 0, 1)
        return token, token_global

    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        """Concatenate the enabled sensors' tokens along the sequence axis.

        Order: front, [left, right], [front-center], [LiDAR]; each view
        contributes its patch tokens followed by its global token.
        ``measurements`` is accepted for interface compatibility but unused.
        """
        features = list(self._embed_view(front_image, self.rgb_patch_embed, 0))
        if self.with_right_left_sensors:
            features.extend(self._embed_view(left_image, self.rgb_patch_embed, 1))
            features.extend(self._embed_view(right_image, self.rgb_patch_embed, 2))
        if self.with_center_sensor:
            features.extend(
                self._embed_view(front_center_image, self.rgb_patch_embed, 3)
            )
        if self.with_lidar:
            features.extend(self._embed_view(lidar, self.lidar_patch_embed, 4))
        return torch.cat(features, 0)

    def forward(self, x):
        """Run the full model on a dict of sensor tensors.

        Reads keys "rgb", "rgb_left", "rgb_right", "rgb_center",
        "measurements", "target_point" and "lidar" from ``x``. Returns only the
        waypoints when ``end2end`` is set, otherwise
        ``(traffic, waypoints, is_junction, traffic_light_state, stop_sign,
        traffic_feature)``.
        """
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]

        if self.direct_concat:
            # Resize the side views to the front image size and stack all four
            # views channel-wise.
            img_size = front_image.shape[-1]
            left_image = F.interpolate(left_image, size=(img_size, img_size))
            right_image = F.interpolate(right_image, size=(img_size, img_size))
            front_center_image = F.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )
        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )

        bs = front_image.shape[0]

        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            # 400 query positions from a 20x20 sine grid, followed by the
            # learned positional embeddings of the remaining queries.
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]

        hs = hs.permute(1, 0, 2)  # (bs, num_queries, embed_dim)
        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints

        # Queries 0..399 feed the traffic head; query 400 is shared by the
        # junction, traffic-light and stop-sign heads; the rest are waypoint
        # queries (4 for "heatmap", 10 otherwise).
        traffic_feature = hs[:, :400]
        is_junction_feature = hs[:, 400]
        traffic_light_state_feature = hs[:, 400]
        stop_sign_feature = hs[:, 400]
        if self.waypoints_pred_head == "heatmap":
            waypoints_feature = hs[:, 401:405]
        else:
            waypoints_feature = hs[:, 401:411]

        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(
                waypoints_feature, target_point, measurements
            )
        elif self.waypoints_pred_head == "linear":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "linear-sum":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)

        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)

        # Append measurements[:, 6] (named velocity) as 32 extra channels to
        # every traffic query: (bs, 400, embed_dim + 32).
        velocity = measurements[:, 6:7].unsqueeze(-1)
        velocity = velocity.repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)
        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature


# --- Model Builder Function ---


def build_interfuser_model(config):
    """Build an ``Interfuser`` model from a configuration dictionary.

    Keys read from ``config`` (with defaults): ``enc_depth`` (6),
    ``dec_depth`` (6), ``embed_dim`` (256), ``rgb_backbone_name`` ("r50"),
    ``lidar_backbone_name`` ("r18"), ``waypoints_pred_head`` ("gru"),
    ``use_different_backbone`` (True) and ``direct_concat`` (True). All other
    ``Interfuser`` arguments keep their constructor defaults.
    """
    model = Interfuser(
        enc_depth=config.get("enc_depth", 6),
        dec_depth=config.get("dec_depth", 6),
        embed_dim=config.get("embed_dim", 256),
        rgb_backbone_name=config.get("rgb_backbone_name", "r50"),
        lidar_backbone_name=config.get("lidar_backbone_name", "r18"),
        waypoints_pred_head=config.get("waypoints_pred_head", "gru"),
        use_different_backbone=config.get("use_different_backbone", True),
        direct_concat=config.get("direct_concat", True),
    )
    return model