# ---------------------------------------------------------------------------- # Adapted from https://github.com/IDEA-Research/ED-Pose/ \ # tree/master/models/edpose # Original licence: IDEA License 1.0 # ---------------------------------------------------------------------------- import copy import math from typing import Dict, List, Tuple import numpy as np import torch import torch.nn.functional as F from mmcv.ops import MultiScaleDeformableAttention from mmengine.model import BaseModule, ModuleList, constant_init from mmengine.structures import InstanceData from torch import Tensor, nn from mmpose.models.utils import inverse_sigmoid from mmpose.registry import KEYPOINT_CODECS, MODELS from mmpose.utils.tensor_utils import to_numpy from mmpose.utils.typing import (ConfigType, Features, OptConfigType, OptSampleList, Predictions) from .base_transformer_head import TransformerHead from .transformers.deformable_detr_layers import ( DeformableDetrTransformerDecoderLayer, DeformableDetrTransformerEncoder) from .transformers.utils import FFN, PositionEmbeddingSineHW class EDPoseDecoder(BaseModule): """Transformer decoder of EDPose: `Explicit Box Detection Unifies End-to- End Multi-Person Pose Estimation. Args: layer_cfg (ConfigDict): the config of each encoder layer. All the layers will share the same config. num_layers (int): Number of decoder layers. return_intermediate (bool, optional): Whether to return outputs of intermediate layers. Defaults to `True`. embed_dims (int): Dims of embed. query_dim (int): Dims of queries. num_feature_levels (int): Number of feature levels. num_box_decoder_layers (int): Number of box decoder layers. num_keypoints (int): Number of datasets' body keypoints. num_dn (int): Number of denosing points. num_group (int): Number of decoder layers. """ def __init__(self, layer_cfg, num_layers, return_intermediate, embed_dims: int = 256, query_dim=4, num_feature_levels=1, num_box_decoder_layers=2, num_keypoints=17, num_dn=100, num_group=100): super().__init__() self.layer_cfg = layer_cfg self.num_layers = num_layers self.embed_dims = embed_dims assert return_intermediate, 'support return_intermediate only' self.return_intermediate = return_intermediate assert query_dim in [ 2, 4 ], 'query_dim should be 2/4 but {}'.format(query_dim) self.query_dim = query_dim self.num_feature_levels = num_feature_levels self.layers = ModuleList([ DeformableDetrTransformerDecoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) self.norm = nn.LayerNorm(self.embed_dims) self.ref_point_head = FFN(self.query_dim // 2 * self.embed_dims, self.embed_dims, self.embed_dims, 2) self.num_keypoints = num_keypoints self.query_scale = None self.bbox_embed = None self.class_embed = None self.pose_embed = None self.pose_hw_embed = None self.num_box_decoder_layers = num_box_decoder_layers self.box_pred_damping = None self.num_group = num_group self.rm_detach = None self.num_dn = num_dn self.hw = nn.Embedding(self.num_keypoints, 2) self.keypoint_embed = nn.Embedding(self.num_keypoints, embed_dims) self.kpt_index = [ x for x in range(self.num_group * (self.num_keypoints + 1)) if x % (self.num_keypoints + 1) != 0 ] def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, reference_points: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, humandet_attn_mask: Tensor, human2pose_attn_mask: Tensor, **kwargs) -> Tuple[Tensor]: """Forward function of decoder Args: query (Tensor): The input queries, has shape (bs, num_queries, dim). value (Tensor): The input values, has shape (bs, num_value, dim). key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` input. ByteTensor, has shape (bs, num_value). reference_points (Tensor): The initial reference, has shape (bs, num_queries, 4) with the last dimension arranged as (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has shape (bs, num_queries, 2) with the last dimension arranged as (cx, cy). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). level_start_index (Tensor): The start index of each level. A tensor has shape (num_levels, ) and can be represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). reg_branches: (obj:`nn.ModuleList`, optional): Used for refining the regression results. Returns: Tuple[Tuple[Tensor]]: Outputs of Deformable Transformer Decoder. - output (Tuple[Tensor]): Output embeddings of the last decoder, each has shape (num_decoder_layers, num_queries, bs, embed_dims) - reference_points (Tensor): The reference of the last decoder layer, each has shape (num_decoder_layers, bs, num_queries, 4). The coordinates are arranged as (cx, cy, w, h) """ output = query attn_mask = humandet_attn_mask intermediate = [] intermediate_reference_points = [reference_points] effect_num_dn = self.num_dn if self.training else 0 inter_select_number = self.num_group for layer_id, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = \ reference_points[:, :, None] * \ torch.cat([valid_ratios, valid_ratios], -1)[None, :] else: assert reference_points.shape[-1] == 2 reference_points_input = \ reference_points[:, :, None] * \ valid_ratios[None, :] query_sine_embed = self.get_proposal_pos_embed( reference_points_input[:, :, 0, :]) # nq, bs, 256*2 query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 output = layer( output.transpose(0, 1), query_pos=query_pos.transpose(0, 1), value=value.transpose(0, 1), key_padding_mask=key_padding_mask, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reference_points=reference_points_input.transpose( 0, 1).contiguous(), self_attn_mask=attn_mask, **kwargs) output = output.transpose(0, 1) intermediate.append(self.norm(output)) # human update if layer_id < self.num_box_decoder_layers: delta_unsig = self.bbox_embed[layer_id](output) new_reference_points = delta_unsig + inverse_sigmoid( reference_points) new_reference_points = new_reference_points.sigmoid() # query expansion if layer_id == self.num_box_decoder_layers - 1: dn_output = output[:effect_num_dn] dn_new_reference_points = new_reference_points[:effect_num_dn] class_unselected = self.class_embed[layer_id]( output)[effect_num_dn:] topk_proposals = torch.topk( class_unselected.max(-1)[0], inter_select_number, dim=0)[1] new_reference_points_for_box = torch.gather( new_reference_points[effect_num_dn:], 0, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) new_output_for_box = torch.gather( output[effect_num_dn:], 0, topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) bs = new_output_for_box.shape[1] new_output_for_keypoint = new_output_for_box[:, None, :, :] \ + self.keypoint_embed.weight[None, :, None, :] if self.num_keypoints == 17: delta_xy = self.pose_embed[-1](new_output_for_keypoint)[ ..., :2] else: delta_xy = self.pose_embed[0](new_output_for_keypoint)[ ..., :2] keypoint_xy = (inverse_sigmoid( new_reference_points_for_box[..., :2][:, None]) + delta_xy).sigmoid() num_queries, _, bs, _ = keypoint_xy.shape keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze( -2).repeat(num_queries, 1, bs, 1).sigmoid() keypoint_wh = keypoint_wh_weight * \ new_reference_points_for_box[..., 2:][:, None] new_reference_points_for_keypoint = torch.cat( (keypoint_xy, keypoint_wh), dim=-1) new_reference_points = torch.cat( (new_reference_points_for_box.unsqueeze(1), new_reference_points_for_keypoint), dim=1).flatten(0, 1) output = torch.cat( (new_output_for_box.unsqueeze(1), new_output_for_keypoint), dim=1).flatten(0, 1) new_reference_points = torch.cat( (dn_new_reference_points, new_reference_points), dim=0) output = torch.cat((dn_output, output), dim=0) attn_mask = human2pose_attn_mask # human-to-keypoints update if layer_id >= self.num_box_decoder_layers: effect_num_dn = self.num_dn if self.training else 0 inter_select_number = self.num_group ref_before_sigmoid = inverse_sigmoid(reference_points) output_bbox_dn = output[:effect_num_dn] output_bbox_norm = output[effect_num_dn:][0::( self.num_keypoints + 1)] ref_before_sigmoid_bbox_dn = \ ref_before_sigmoid[:effect_num_dn] ref_before_sigmoid_bbox_norm = \ ref_before_sigmoid[effect_num_dn:][0::( self.num_keypoints + 1)] delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn) delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm) outputs_unsig_dn = delta_unsig_dn + ref_before_sigmoid_bbox_dn outputs_unsig_norm = delta_unsig_norm + \ ref_before_sigmoid_bbox_norm new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid() new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid( ) output_kpt = output[effect_num_dn:].index_select( 0, torch.tensor(self.kpt_index, device=output.device)) delta_xy_unsig = self.pose_embed[layer_id - self.num_box_decoder_layers]( output_kpt) outputs_unsig = ref_before_sigmoid[ effect_num_dn:].index_select( 0, torch.tensor(self.kpt_index, device=output.device)).clone() delta_hw_unsig = self.pose_hw_embed[ layer_id - self.num_box_decoder_layers]( output_kpt) outputs_unsig[..., :2] += delta_xy_unsig[..., :2] outputs_unsig[..., 2:] += delta_hw_unsig new_reference_points_for_keypoint = outputs_unsig.sigmoid() bs = new_reference_points_for_box_norm.shape[1] new_reference_points_norm = torch.cat( (new_reference_points_for_box_norm.unsqueeze(1), new_reference_points_for_keypoint.view( -1, self.num_keypoints, bs, 4)), dim=1).flatten(0, 1) new_reference_points = torch.cat( (new_reference_points_for_box_dn, new_reference_points_norm), dim=0) reference_points = new_reference_points.detach() intermediate_reference_points.append(reference_points) decoder_outputs = [itm_out.transpose(0, 1) for itm_out in intermediate] reference_points = [ itm_refpoint.transpose(0, 1) for itm_refpoint in intermediate_reference_points ] return decoder_outputs, reference_points @staticmethod def get_proposal_pos_embed(pos_tensor: Tensor, temperature: int = 10000, num_pos_feats: int = 128) -> Tensor: """Get the position embedding of the proposal. Args: pos_tensor (Tensor): Not normalized proposals, has shape (bs, num_queries, 4) with the last dimension arranged as (cx, cy, w, h). temperature (int, optional): The temperature used for scaling the position embedding. Defaults to 10000. num_pos_feats (int, optional): The feature dimension for each position along x, y, w, and h-axis. Note the final returned dimension for each position is 4 times of num_pos_feats. Default to 128. Returns: Tensor: The position embedding of proposal, has shape (bs, num_queries, num_pos_feats * 4), with the last dimension arranged as (cx, cy, w, h) """ scale = 2 * math.pi dim_t = torch.arange( num_pos_feats, dtype=torch.float32, device=pos_tensor.device) dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) x_embed = pos_tensor[:, :, 0] * scale y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 2: pos = torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack( (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack( (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError('Unknown pos_tensor shape(-1):{}'.format( pos_tensor.size(-1))) return pos class EDPoseOutHead(BaseModule): """Final Head of EDPose: `Explicit Box Detection Unifies End-to-End Multi- Person Pose Estimation. Args: num_classes (int): The number of classes. num_keypoints (int): The number of datasets' body keypoints. num_queries (int): The number of queries. cls_no_bias (bool): Weather add the bias to class embed. embed_dims (int): The dims of embed. as_two_stage (bool, optional): Whether to generate the proposal from the outputs of encoder. Defaults to `False`. refine_queries_num (int): The number of refines queries after decoders. num_box_decoder_layers (int): The number of bbox decoder layer. num_group (int): The number of groups. num_pred_layer (int): The number of the prediction layers. Defaults to 6. dec_pred_class_embed_share (bool): Whether to share parameters for all the class prediction layers. Defaults to `False`. dec_pred_bbox_embed_share (bool): Whether to share parameters for all the bbox prediction layers. Defaults to `False`. dec_pred_pose_embed_share (bool): Whether to share parameters for all the pose prediction layers. Defaults to `False`. """ def __init__(self, num_classes, num_keypoints: int = 17, num_queries: int = 900, cls_no_bias: bool = False, embed_dims: int = 256, as_two_stage: bool = False, refine_queries_num: int = 100, num_box_decoder_layers: int = 2, num_group: int = 100, num_pred_layer: int = 6, dec_pred_class_embed_share: bool = False, dec_pred_bbox_embed_share: bool = False, dec_pred_pose_embed_share: bool = False, **kwargs): super().__init__() self.embed_dims = embed_dims self.as_two_stage = as_two_stage self.num_classes = num_classes self.refine_queries_num = refine_queries_num self.num_box_decoder_layers = num_box_decoder_layers self.num_keypoints = num_keypoints self.num_queries = num_queries # prepare pred layers self.dec_pred_class_embed_share = dec_pred_class_embed_share self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share self.dec_pred_pose_embed_share = dec_pred_pose_embed_share # prepare class & box embed _class_embed = nn.Linear( self.embed_dims, self.num_classes, bias=(not cls_no_bias)) if not cls_no_bias: prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) _class_embed.bias.data = torch.ones(self.num_classes) * bias_value _bbox_embed = FFN(self.embed_dims, self.embed_dims, 4, 3) _pose_embed = FFN(self.embed_dims, self.embed_dims, 2, 3) _pose_hw_embed = FFN(self.embed_dims, self.embed_dims, 2, 3) self.num_group = num_group if dec_pred_bbox_embed_share: box_embed_layerlist = [_bbox_embed for i in range(num_pred_layer)] else: box_embed_layerlist = [ copy.deepcopy(_bbox_embed) for i in range(num_pred_layer) ] if dec_pred_class_embed_share: class_embed_layerlist = [ _class_embed for i in range(num_pred_layer) ] else: class_embed_layerlist = [ copy.deepcopy(_class_embed) for i in range(num_pred_layer) ] if num_keypoints == 17: if dec_pred_pose_embed_share: pose_embed_layerlist = [ _pose_embed for i in range(num_pred_layer - num_box_decoder_layers + 1) ] else: pose_embed_layerlist = [ copy.deepcopy(_pose_embed) for i in range(num_pred_layer - num_box_decoder_layers + 1) ] else: if dec_pred_pose_embed_share: pose_embed_layerlist = [ _pose_embed for i in range(num_pred_layer - num_box_decoder_layers) ] else: pose_embed_layerlist = [ copy.deepcopy(_pose_embed) for i in range(num_pred_layer - num_box_decoder_layers) ] pose_hw_embed_layerlist = [ _pose_hw_embed for i in range(num_pred_layer - num_box_decoder_layers) ] self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.class_embed = nn.ModuleList(class_embed_layerlist) self.pose_embed = nn.ModuleList(pose_embed_layerlist) self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist) def init_weights(self) -> None: """Initialize weights of the Deformable DETR head.""" for m in self.bbox_embed: constant_init(m[-1], 0, bias=0) for m in self.pose_embed: constant_init(m[-1], 0, bias=0) def forward(self, hidden_states: List[Tensor], references: List[Tensor], mask_dict: Dict, hidden_states_enc: Tensor, referens_enc: Tensor, batch_data_samples) -> Dict: """Forward function. Args: hidden_states (Tensor): Hidden states output from each decoder layer, has shape (num_decoder_layers, bs, num_queries, dim). references (list[Tensor]): List of the reference from the decoder. Returns: tuple[Tensor]: results of head containing the following tensor. - pred_logits (Tensor): Outputs from the classification head, the socres of every bboxes. - pred_boxes (Tensor): The output boxes. - pred_keypoints (Tensor): The output keypoints. """ # update human boxes effec_dn_num = self.refine_queries_num if self.training else 0 outputs_coord_list = [] outputs_class = [] for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, layer_hs) in enumerate( zip(references[:-1], self.bbox_embed, self.class_embed, hidden_states)): if dec_lid < self.num_box_decoder_layers: layer_delta_unsig = layer_bbox_embed(layer_hs) layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid( layer_ref_sig) layer_outputs_unsig = layer_outputs_unsig.sigmoid() layer_cls = layer_cls_embed(layer_hs) outputs_coord_list.append(layer_outputs_unsig) outputs_class.append(layer_cls) else: layer_hs_bbox_dn = layer_hs[:, :effec_dn_num, :] layer_hs_bbox_norm = \ layer_hs[:, effec_dn_num:, :][:, 0::( self.num_keypoints + 1), :] bs = layer_ref_sig.shape[0] ref_before_sigmoid_bbox_dn = \ layer_ref_sig[:, : effec_dn_num, :] ref_before_sigmoid_bbox_norm = \ layer_ref_sig[:, effec_dn_num:, :][:, 0::( self.num_keypoints + 1), :] layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn) layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm) layer_outputs_unsig_dn = layer_delta_unsig_dn + \ inverse_sigmoid(ref_before_sigmoid_bbox_dn) layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid() layer_outputs_unsig_norm = layer_delta_unsig_norm + \ inverse_sigmoid(ref_before_sigmoid_bbox_norm) layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid() layer_outputs_unsig = torch.cat( (layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1) layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn) layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm) layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) outputs_class.append(layer_cls) outputs_coord_list.append(layer_outputs_unsig) # update keypoints boxes outputs_keypoints_list = [] kpt_index = [ x for x in range(self.num_group * (self.num_keypoints + 1)) if x % (self.num_keypoints + 1) != 0 ] for dec_lid, (layer_ref_sig, layer_hs) in enumerate( zip(references[:-1], hidden_states)): if dec_lid < self.num_box_decoder_layers: assert isinstance(layer_hs, torch.Tensor) bs = layer_hs.shape[0] layer_res = layer_hs.new_zeros( (bs, self.num_queries, self.num_keypoints * 3)) outputs_keypoints_list.append(layer_res) else: bs = layer_ref_sig.shape[0] layer_hs_kpt = \ layer_hs[:, effec_dn_num:, :].index_select( 1, torch.tensor(kpt_index, device=layer_hs.device)) delta_xy_unsig = self.pose_embed[dec_lid - self.num_box_decoder_layers]( layer_hs_kpt) layer_ref_sig_kpt = \ layer_ref_sig[:, effec_dn_num:, :].index_select( 1, torch.tensor(kpt_index, device=layer_hs.device)) layer_outputs_unsig_keypoints = delta_xy_unsig + \ inverse_sigmoid(layer_ref_sig_kpt[..., :2]) vis_xy_unsig = torch.ones_like( layer_outputs_unsig_keypoints, device=layer_outputs_unsig_keypoints.device) xyv = torch.cat((layer_outputs_unsig_keypoints, vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1) xyv = xyv.sigmoid() layer_res = xyv.reshape( (bs, self.num_group, self.num_keypoints, 3)).flatten(2, 3) layer_res = self.keypoint_xyzxyz_to_xyxyzz(layer_res) outputs_keypoints_list.append(layer_res) dn_mask_dict = mask_dict if self.refine_queries_num > 0 and dn_mask_dict is not None: outputs_class, outputs_coord_list, outputs_keypoints_list = \ self.dn_post_process2( outputs_class, outputs_coord_list, outputs_keypoints_list, dn_mask_dict ) for _out_class, _out_bbox, _out_keypoint in zip( outputs_class, outputs_coord_list, outputs_keypoints_list): assert _out_class.shape[1] == \ _out_bbox.shape[1] == _out_keypoint.shape[1] return outputs_class[-1], outputs_coord_list[ -1], outputs_keypoints_list[-1] def keypoint_xyzxyz_to_xyxyzz(self, keypoints: torch.Tensor): """ Args: keypoints (torch.Tensor): ..., 51 """ res = torch.zeros_like(keypoints) num_points = keypoints.shape[-1] // 3 res[..., 0:2 * num_points:2] = keypoints[..., 0::3] res[..., 1:2 * num_points:2] = keypoints[..., 1::3] res[..., 2 * num_points:] = keypoints[..., 2::3] return res @MODELS.register_module() class EDPoseHead(TransformerHead): """Head introduced in `Explicit Box Detection Unifies End-to-End Multi- Person Pose Estimation`_ by J Yang1 et al (2023). The head is composed of Encoder, Decoder and Out_head. Code is modified from the `official github repo `_. More details can be found in the `paper `_ . Args: num_queries (int): Number of query in Transformer. num_feature_levels (int): Number of feature levels. Defaults to 4. num_keypoints (int): Number of keypoints. Defaults to 4. as_two_stage (bool, optional): Whether to generate the proposal from the outputs of encoder. Defaults to `False`. encoder (:obj:`ConfigDict` or dict, optional): Config of the Transformer encoder. Defaults to None. decoder (:obj:`ConfigDict` or dict, optional): Config of the Transformer decoder. Defaults to None. out_head (:obj:`ConfigDict` or dict, optional): Config for the bounding final out head module. Defaults to None. positional_encoding (:obj:`ConfigDict` or dict): Config for transformer position encoding. Defaults None. denosing_cfg (:obj:`ConfigDict` or dict, optional): Config of the human query denoising training strategy. data_decoder (:obj:`ConfigDict` or dict, optional): Config of the data decoder which transform the results from output space to input space. dec_pred_class_embed_share (bool): Whether to share the class embed layer. Default False. dec_pred_bbox_embed_share (bool): Whether to share the bbox embed layer. Default False. refine_queries_num (int): Number of refined human content queries and their position queries . two_stage_keep_all_tokens (bool): Whether to keep all tokens. """ def __init__(self, num_queries: int = 100, num_feature_levels: int = 4, num_keypoints: int = 17, as_two_stage: bool = False, encoder: OptConfigType = None, decoder: OptConfigType = None, out_head: OptConfigType = None, positional_encoding: OptConfigType = None, data_decoder: OptConfigType = None, denosing_cfg: OptConfigType = None, dec_pred_class_embed_share: bool = False, dec_pred_bbox_embed_share: bool = False, refine_queries_num: int = 100, two_stage_keep_all_tokens: bool = False) -> None: self.as_two_stage = as_two_stage self.num_feature_levels = num_feature_levels self.refine_queries_num = refine_queries_num self.dec_pred_class_embed_share = dec_pred_class_embed_share self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share self.two_stage_keep_all_tokens = two_stage_keep_all_tokens self.num_heads = decoder['layer_cfg']['self_attn_cfg']['num_heads'] self.num_group = decoder['num_group'] self.num_keypoints = num_keypoints self.denosing_cfg = denosing_cfg if data_decoder is not None: self.data_decoder = KEYPOINT_CODECS.build(data_decoder) else: self.data_decoder = None super().__init__( encoder=encoder, decoder=decoder, out_head=out_head, positional_encoding=positional_encoding, num_queries=num_queries) self.positional_encoding = PositionEmbeddingSineHW( **self.positional_encoding_cfg) self.encoder = DeformableDetrTransformerEncoder(**self.encoder_cfg) self.decoder = EDPoseDecoder( num_keypoints=num_keypoints, **self.decoder_cfg) self.out_head = EDPoseOutHead( num_keypoints=num_keypoints, as_two_stage=as_two_stage, refine_queries_num=refine_queries_num, **self.out_head_cfg, **self.decoder_cfg) self.embed_dims = self.encoder.embed_dims self.label_enc = nn.Embedding( self.denosing_cfg['dn_labelbook_size'] + 1, self.embed_dims) if not self.as_two_stage: self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) self.refpoint_embedding = nn.Embedding(self.num_queries, 4) self.level_embed = nn.Parameter( torch.Tensor(self.num_feature_levels, self.embed_dims)) self.decoder.bbox_embed = self.out_head.bbox_embed self.decoder.pose_embed = self.out_head.pose_embed self.decoder.pose_hw_embed = self.out_head.pose_hw_embed self.decoder.class_embed = self.out_head.class_embed if self.as_two_stage: self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) self.memory_trans_norm = nn.LayerNorm(self.embed_dims) if dec_pred_class_embed_share and dec_pred_bbox_embed_share: self.enc_out_bbox_embed = self.out_head.bbox_embed[0] else: self.enc_out_bbox_embed = copy.deepcopy( self.out_head.bbox_embed[0]) if dec_pred_class_embed_share and dec_pred_bbox_embed_share: self.enc_out_class_embed = self.out_head.class_embed[0] else: self.enc_out_class_embed = copy.deepcopy( self.out_head.class_embed[0]) def init_weights(self) -> None: """Initialize weights for Transformer and other components.""" super().init_weights() for coder in self.encoder, self.decoder: for p in coder.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MultiScaleDeformableAttention): m.init_weights() if self.as_two_stage: nn.init.xavier_uniform_(self.memory_trans_fc.weight) nn.init.normal_(self.level_embed) def pre_transformer(self, img_feats: Tuple[Tensor], batch_data_samples: OptSampleList = None ) -> Tuple[Dict]: """Process image features before feeding them to the transformer. Args: img_feats (tuple[Tensor]): Multi-level features that may have different resolutions, output from neck. Each feature has shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. batch_data_samples (list[:obj:`DetDataSample`], optional): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Defaults to None. Returns: tuple[dict]: The first dict contains the inputs of encoder and the second dict contains the inputs of decoder. - encoder_inputs_dict (dict): The keyword args dictionary of `self.encoder()`. - decoder_inputs_dict (dict): The keyword args dictionary of `self.forward_decoder()`, which includes 'memory_mask'. """ batch_size = img_feats[0].size(0) # construct binary masks for the transformer. assert batch_data_samples is not None batch_input_shape = batch_data_samples[0].batch_input_shape img_shape_list = [sample.img_shape for sample in batch_data_samples] input_img_h, input_img_w = batch_input_shape masks = img_feats[0].new_ones((batch_size, input_img_h, input_img_w)) for img_id in range(batch_size): img_h, img_w = img_shape_list[img_id] masks[img_id, :img_h, :img_w] = 0 # NOTE following the official DETR repo, non-zero values representing # ignored positions, while zero values means valid positions. mlvl_masks = [] mlvl_pos_embeds = [] for feat in img_feats: mlvl_masks.append( F.interpolate(masks[None], size=feat.shape[-2:]).to(torch.bool).squeeze(0)) mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) feat_flatten = [] lvl_pos_embed_flatten = [] mask_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(img_feats, mlvl_masks, mlvl_pos_embeds)): batch_size, c, h, w = feat.shape # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] feat = feat.view(batch_size, c, -1).permute(0, 2, 1) pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] mask = mask.flatten(1) spatial_shape = (h, w) feat_flatten.append(feat) lvl_pos_embed_flatten.append(lvl_pos_embed) mask_flatten.append(mask) spatial_shapes.append(spatial_shape) # (bs, num_feat_points, dim) feat_flatten = torch.cat(feat_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) mask_flatten = torch.cat(mask_flatten, 1) spatial_shapes = torch.as_tensor( # (num_level, 2) spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat(( spatial_shapes.new_zeros((1, )), # (num_level) spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( # (bs, num_level, 2) [self.get_valid_ratio(m) for m in mlvl_masks], 1) if self.refine_queries_num > 0 or batch_data_samples is not None: input_query_label, input_query_bbox, humandet_attn_mask, \ human2pose_attn_mask, mask_dict =\ self.prepare_for_denosing( batch_data_samples, device=img_feats[0].device) else: assert batch_data_samples is None input_query_bbox = input_query_label = \ humandet_attn_mask = human2pose_attn_mask = mask_dict = None encoder_inputs_dict = dict( query=feat_flatten, query_pos=lvl_pos_embed_flatten, key_padding_mask=mask_flatten, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios) decoder_inputs_dict = dict( memory_mask=mask_flatten, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, humandet_attn_mask=humandet_attn_mask, human2pose_attn_mask=human2pose_attn_mask, input_query_bbox=input_query_bbox, input_query_label=input_query_label, mask_dict=mask_dict) return encoder_inputs_dict, decoder_inputs_dict def forward_encoder(self, img_feats: Tuple[Tensor], batch_data_samples: OptSampleList = None) -> Dict: """Forward with Transformer encoder. The forward procedure is defined as: 'pre_transformer' -> 'encoder' Args: img_feats (tuple[Tensor]): Multi-level features that may have different resolutions, output from neck. Each feature has shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. batch_data_samples (list[:obj:`DetDataSample`], optional): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Defaults to None. Returns: dict: The dictionary of encoder outputs, which includes the `memory` of the encoder output. """ encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( img_feats, batch_data_samples) memory = self.encoder(**encoder_inputs_dict) encoder_outputs_dict = dict(memory=memory, **decoder_inputs_dict) return encoder_outputs_dict def pre_decoder(self, memory: Tensor, memory_mask: Tensor, spatial_shapes: Tensor, input_query_bbox: Tensor, input_query_label: Tensor) -> Tuple[Dict, Dict]: """Prepare intermediate variables before entering Transformer decoder, such as `query` and `reference_points`. Args: memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). memory_mask (Tensor): ByteTensor, the padding mask of the memory, has shape (bs, num_feat_points). It will only be used when `as_two_stage` is `True`. spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). It will only be used when `as_two_stage` is `True`. input_query_bbox (Tensor): Denosing bbox query for training. input_query_label (Tensor): Denosing label query for training. Returns: tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. - decoder_inputs_dict (dict): The keyword dictionary args of `self.decoder()`. - head_inputs_dict (dict): The keyword dictionary args of the bbox_head functions. """ bs, _, c = memory.shape if self.as_two_stage: output_memory, output_proposals = \ self.gen_encoder_output_proposals( memory, memory_mask, spatial_shapes) enc_outputs_class = self.enc_out_class_embed(output_memory) enc_outputs_coord_unact = self.enc_out_bbox_embed( output_memory) + output_proposals topk_proposals = torch.topk( enc_outputs_class.max(-1)[0], self.num_queries, dim=1)[1] topk_coords_undetach = torch.gather( enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) topk_coords_unact = topk_coords_undetach.detach() reference_points = topk_coords_unact.sigmoid() query_undetach = torch.gather( output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) query = query_undetach.detach() if input_query_bbox is not None: reference_points = torch.cat( [input_query_bbox, topk_coords_unact], dim=1).sigmoid() query = torch.cat([input_query_label, query], dim=1) if self.two_stage_keep_all_tokens: hidden_states_enc = output_memory.unsqueeze(0) referens_enc = enc_outputs_coord_unact.unsqueeze(0) else: hidden_states_enc = query_undetach.unsqueeze(0) referens_enc = topk_coords_undetach.sigmoid().unsqueeze(0) else: hidden_states_enc, referens_enc = None, None query = self.query_embedding.weight[:, None, :].repeat( 1, bs, 1).transpose(0, 1) reference_points = \ self.refpoint_embedding.weight[:, None, :].repeat(1, bs, 1) if input_query_bbox is not None: reference_points = torch.cat( [input_query_bbox, reference_points], dim=1) query = torch.cat([input_query_label, query], dim=1) reference_points = reference_points.sigmoid() decoder_inputs_dict = dict( query=query, reference_points=reference_points) head_inputs_dict = dict( hidden_states_enc=hidden_states_enc, referens_enc=referens_enc) return decoder_inputs_dict, head_inputs_dict def forward_decoder(self, memory: Tensor, memory_mask: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, humandet_attn_mask: Tensor, human2pose_attn_mask: Tensor, input_query_bbox: Tensor, input_query_label: Tensor, mask_dict: Dict) -> Dict: """Forward with Transformer decoder. The forward procedure is defined as: 'pre_decoder' -> 'decoder' Args: memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). memory_mask (Tensor): ByteTensor, the padding mask of the memory, has shape (bs, num_feat_points). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). level_start_index (Tensor): The start index of each level. A tensor has shape (num_levels, ) and can be represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). humandet_attn_mask (Tensor): Human attention mask. human2pose_attn_mask (Tensor): Human to pose attention mask. input_query_bbox (Tensor): Denosing bbox query for training. input_query_label (Tensor): Denosing label query for training. Returns: dict: The dictionary of decoder outputs, which includes the `hidden_states` of the decoder output and `references` including the initial and intermediate reference_points. """ decoder_in, head_in = self.pre_decoder(memory, memory_mask, spatial_shapes, input_query_bbox, input_query_label) inter_states, inter_references = self.decoder( query=decoder_in['query'].transpose(0, 1), value=memory.transpose(0, 1), key_padding_mask=memory_mask, # for cross_attn reference_points=decoder_in['reference_points'].transpose(0, 1), spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, humandet_attn_mask=humandet_attn_mask, human2pose_attn_mask=human2pose_attn_mask) references = inter_references decoder_outputs_dict = dict( hidden_states=inter_states, references=references, mask_dict=mask_dict) decoder_outputs_dict.update(head_in) return decoder_outputs_dict def forward_out_head(self, batch_data_samples: OptSampleList, hidden_states: List[Tensor], references: List[Tensor], mask_dict: Dict, hidden_states_enc: Tensor, referens_enc: Tensor) -> Tuple[Tensor]: """Forward function.""" out = self.out_head(hidden_states, references, mask_dict, hidden_states_enc, referens_enc, batch_data_samples) return out def predict(self, feats: Features, batch_data_samples: OptSampleList, test_cfg: ConfigType = {}) -> Predictions: """Predict results from features.""" input_shapes = np.array( [d.metainfo['input_size'] for d in batch_data_samples]) if test_cfg.get('flip_test', False): assert NotImplementedError( 'flip_test is currently not supported ' 'for EDPose. Please set `model.test_cfg.flip_test=False`') else: pred_logits, pred_boxes, pred_keypoints = self.forward( feats, batch_data_samples) # (B, K, D) pred = self.decode( input_shapes, pred_logits=pred_logits, pred_boxes=pred_boxes, pred_keypoints=pred_keypoints) return pred def decode(self, input_shapes: np.ndarray, pred_logits: Tensor, pred_boxes: Tensor, pred_keypoints: Tensor): """Select the final top-k keypoints, and decode the results from normalize size to origin input size. Args: input_shapes (Tensor): The size of input image. pred_logits (Tensor): The result of score. pred_boxes (Tensor): The result of bbox. pred_keypoints (Tensor): The result of keypoints. Returns: """ if self.data_decoder is None: raise RuntimeError(f'The data decoder has not been set in \ {self.__class__.__name__}. ' 'Please set the data decoder configs in \ the init parameters to ' 'enable head methods `head.predict()` and \ `head.decode()`') preds = [] pred_logits = pred_logits.sigmoid() pred_logits, pred_boxes, pred_keypoints = to_numpy( [pred_logits, pred_boxes, pred_keypoints]) for input_shape, pred_logit, pred_bbox, pred_kpts in zip( input_shapes, pred_logits, pred_boxes, pred_keypoints): bboxes, keypoints, keypoint_scores = self.data_decoder.decode( input_shape, pred_logit, pred_bbox, pred_kpts) # pack outputs preds.append( InstanceData( keypoints=keypoints, keypoint_scores=keypoint_scores, bboxes=bboxes)) return preds def gen_encoder_output_proposals(self, memory: Tensor, memory_mask: Tensor, spatial_shapes: Tensor ) -> Tuple[Tensor, Tensor]: """Generate proposals from encoded memory. The function will only be used when `as_two_stage` is `True`. Args: memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). memory_mask (Tensor): ByteTensor, the padding mask of the memory, has shape (bs, num_feat_points). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). Returns: tuple: A tuple of transformed memory and proposals. - output_memory (Tensor): The transformed memory for obtaining top-k proposals, has shape (bs, num_feat_points, dim). - output_proposals (Tensor): The inverse-normalized proposal, has shape (batch_size, num_keys, 4) with the last dimension arranged as (cx, cy, w, h). """ bs = memory.size(0) proposals = [] _cur = 0 # start index in the sequence of the current level for lvl, (H, W) in enumerate(spatial_shapes): mask_flatten_ = memory_mask[:, _cur:(_cur + H * W)].view(bs, H, W, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1) grid_y, grid_x = torch.meshgrid( torch.linspace( 0, H - 1, H, dtype=torch.float32, device=memory.device), torch.linspace( 0, W - 1, W, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2) grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) proposal = torch.cat((grid, wh), -1).view(bs, -1, 4) proposals.append(proposal) _cur += (H * W) output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all( -1, keepdim=True) output_proposals = inverse_sigmoid(output_proposals) output_proposals = output_proposals.masked_fill( memory_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill( ~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill( memory_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) output_memory = self.memory_trans_fc(output_memory) output_memory = self.memory_trans_norm(output_memory) # [bs, sum(hw), 2] return output_memory, output_proposals @property def default_init_cfg(self): init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] return init_cfg def prepare_for_denosing(self, targets: OptSampleList, device): """prepare for dn components in forward function.""" if not self.training: bs = len(targets) attn_mask_infere = torch.zeros( bs, self.num_heads, self.num_group * (self.num_keypoints + 1), self.num_group * (self.num_keypoints + 1), device=device, dtype=torch.bool) group_bbox_kpt = (self.num_keypoints + 1) kpt_index = [ x for x in range(self.num_group * (self.num_keypoints + 1)) if x % (self.num_keypoints + 1) == 0 ] for matchj in range(self.num_group * (self.num_keypoints + 1)): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt if sj > 0: attn_mask_infere[:, :, matchj, :sj] = True if ej < self.num_group * (self.num_keypoints + 1): attn_mask_infere[:, :, matchj, ej:] = True for match_x in range(self.num_group * (self.num_keypoints + 1)): if match_x % group_bbox_kpt == 0: attn_mask_infere[:, :, match_x, kpt_index] = False attn_mask_infere = attn_mask_infere.flatten(0, 1) return None, None, None, attn_mask_infere, None # targets, dn_scalar, noise_scale = dn_args device = targets[0]['boxes'].device bs = len(targets) refine_queries_num = self.refine_queries_num # gather gt boxes and labels gt_boxes = [t['boxes'] for t in targets] gt_labels = [t['labels'] for t in targets] gt_keypoints = [t['keypoints'] for t in targets] # repeat them def get_indices_for_repeat(now_num, target_num, device='cuda'): """ Input: - now_num: int - target_num: int Output: - indices: tensor[target_num] """ out_indice = [] base_indice = torch.arange(now_num).to(device) multiplier = target_num // now_num out_indice.append(base_indice.repeat(multiplier)) residue = target_num % now_num out_indice.append(base_indice[torch.randint( 0, now_num, (residue, ), device=device)]) return torch.cat(out_indice) gt_boxes_expand = [] gt_labels_expand = [] gt_keypoints_expand = [] for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( zip(gt_boxes, gt_labels, gt_keypoints)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i > 0: indices = get_indices_for_repeat(num_gt_i, refine_queries_num, device) gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 gt_labels_expand_i = gt_labels_i[indices] gt_keypoints_expand_i = gt_keypoint_i[indices] else: # all negative samples when no gt boxes gt_boxes_expand_i = torch.rand( refine_queries_num, 4, device=device) gt_labels_expand_i = torch.ones( refine_queries_num, dtype=torch.int64, device=device) * int(self.num_classes) gt_keypoints_expand_i = torch.rand( refine_queries_num, self.num_keypoints * 3, device=device) gt_boxes_expand.append(gt_boxes_expand_i) gt_labels_expand.append(gt_labels_expand_i) gt_keypoints_expand.append(gt_keypoints_expand_i) gt_boxes_expand = torch.stack(gt_boxes_expand) gt_labels_expand = torch.stack(gt_labels_expand) gt_keypoints_expand = torch.stack(gt_keypoints_expand) knwon_boxes_expand = gt_boxes_expand.clone() knwon_labels_expand = gt_labels_expand.clone() # add noise if self.denosing_cfg['dn_label_noise_ratio'] > 0: prob = torch.rand_like(knwon_labels_expand.float()) chosen_indice = prob < self.denosing_cfg['dn_label_noise_ratio'] new_label = torch.randint_like( knwon_labels_expand[chosen_indice], 0, self.dn_labelbook_size) # randomly put a new one here knwon_labels_expand[chosen_indice] = new_label if self.denosing_cfg['dn_box_noise_scale'] > 0: diff = torch.zeros_like(knwon_boxes_expand) diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 diff[..., 2:] = knwon_boxes_expand[..., 2:] knwon_boxes_expand += torch.mul( (torch.rand_like(knwon_boxes_expand) * 2 - 1.0), diff) * self.denosing_cfg['dn_box_noise_scale'] knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) input_query_label = self.label_enc(knwon_labels_expand) input_query_bbox = inverse_sigmoid(knwon_boxes_expand) # prepare mask if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: attn_mask = torch.zeros( bs, self.num_heads, refine_queries_num + self.num_queries, refine_queries_num + self.num_queries, device=device, dtype=torch.bool) attn_mask[:, :, refine_queries_num:, :refine_queries_num] = True for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(refine_queries_num): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask[idx, :, matchi, :si] = True if ei < refine_queries_num: attn_mask[idx, :, matchi, ei:refine_queries_num] = True attn_mask = attn_mask.flatten(0, 1) if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: attn_mask2 = torch.zeros( bs, self.num_heads, refine_queries_num + self.num_group * (self.num_keypoints + 1), refine_queries_num + self.num_group * (self.num_keypoints + 1), device=device, dtype=torch.bool) attn_mask2[:, :, refine_queries_num:, :refine_queries_num] = True group_bbox_kpt = (self.num_keypoints + 1) kpt_index = [ x for x in range(self.num_group * (self.num_keypoints + 1)) if x % (self.num_keypoints + 1) == 0 ] for matchj in range(self.num_group * (self.num_keypoints + 1)): sj = (matchj // group_bbox_kpt) * group_bbox_kpt ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt if sj > 0: attn_mask2[:, :, refine_queries_num:, refine_queries_num:][:, :, matchj, :sj] = True if ej < self.num_group * (self.num_keypoints + 1): attn_mask2[:, :, refine_queries_num:, refine_queries_num:][:, :, matchj, ej:] = True for match_x in range(self.num_group * (self.num_keypoints + 1)): if match_x % group_bbox_kpt == 0: attn_mask2[:, :, refine_queries_num:, refine_queries_num:][:, :, match_x, kpt_index] = False for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): num_gt_i = gt_boxes_i.shape[0] if num_gt_i == 0: continue for matchi in range(refine_queries_num): si = (matchi // num_gt_i) * num_gt_i ei = (matchi // num_gt_i + 1) * num_gt_i if si > 0: attn_mask2[idx, :, matchi, :si] = True if ei < refine_queries_num: attn_mask2[idx, :, matchi, ei:refine_queries_num] = True attn_mask2 = attn_mask2.flatten(0, 1) mask_dict = { 'pad_size': refine_queries_num, 'known_bboxs': gt_boxes_expand, 'known_labels': gt_labels_expand, 'known_keypoints': gt_keypoints_expand } return input_query_label, input_query_bbox, \ attn_mask, attn_mask2, mask_dict def loss(self, feats: Tuple[Tensor], batch_data_samples: OptSampleList, train_cfg: OptConfigType = {}) -> dict: """Calculate losses from a batch of inputs and data samples.""" assert NotImplementedError( 'the training of EDPose has not been ' 'supported. Please stay tuned for further update.')