Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| from typing import Dict, List, Tuple | |
| from torch import Tensor | |
| from mmdet.registry import MODELS | |
| from mmdet.structures import SampleList | |
| from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig | |
| from .single_stage import SingleStageDetector | |
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_.

    The detector is assembled from a backbone, an optional neck, a
    ``panoptic_head`` and a ``panoptic_fusion_head``, all built through the
    ``MODELS`` registry.

    Args:
        backbone (:obj:`ConfigDict` or dict): Config of the backbone.
        neck (:obj:`ConfigDict` or dict, optional): Config of the neck.
            ``self.neck`` is only created when this is not None.
        panoptic_head (:obj:`ConfigDict` or dict): Config of the panoptic
            head. Although the default is None, a config is required in
            practice because it is copied and updated unconditionally.
        panoptic_fusion_head (:obj:`ConfigDict` or dict): Config of the
            panoptic fusion head. Required in practice for the same reason
            as ``panoptic_head``.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            It is injected into the panoptic head config and stored on
            the detector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config.
            It is injected into both head configs and stored on the
            detector.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config
            forwarded to the parent initializer.
        init_cfg (:obj:`ConfigDict` or dict or list, optional):
            Initialization config forwarded to the parent initializer.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 panoptic_head: OptConfigType = None,
                 panoptic_fusion_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        # Imported locally so this class does not depend on a module-level
        # ``import copy`` being present in the file.
        from copy import deepcopy

        # ``super(SingleStageDetector, self)`` starts the MRO lookup *after*
        # ``SingleStageDetector``, so that class's own ``__init__`` is
        # bypassed and the sub-modules are built explicitly below.
        super(SingleStageDetector, self).__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)

        # Work on deep copies so the caller's config objects are never
        # mutated. ``copy.deepcopy`` is used instead of the ``.deepcopy()``
        # method so that plain ``dict`` configs work as well.
        head_cfg = deepcopy(panoptic_head)
        head_cfg.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.panoptic_head = MODELS.build(head_cfg)

        fusion_head_cfg = deepcopy(panoptic_fusion_head)
        fusion_head_cfg.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = MODELS.build(fusion_head_cfg)

        # Mirror the class-count attributes exposed by the built head.
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the losses for a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # ``extract_feat`` is not defined in this class; it is presumably
        # inherited from ``SingleStageDetector``.
        feats = self.extract_feat(batch_inputs)
        return self.panoptic_head.loss(feats, batch_data_samples)

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contain
            'pred_instances' and `pred_panoptic_seg`. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).

            And the ``pred_panoptic_seg`` contains the following key

                - sem_seg (Tensor): panoptic segmentation mask, has a
                    shape (1, h, w).
        """
        feats = self.extract_feat(batch_inputs)
        mask_cls_results, mask_pred_results = self.panoptic_head.predict(
            feats, batch_data_samples)
        # The fusion head post-processes the raw head outputs into one
        # result dict per data sample.
        results_list = self.panoptic_fusion_head.predict(
            mask_cls_results,
            mask_pred_results,
            batch_data_samples,
            rescale=rescale)
        return self.add_pred_to_datasample(batch_data_samples, results_list)

    def add_pred_to_datasample(self, data_samples: SampleList,
                               results_list: List[dict]) -> SampleList:
        """Add predictions to `DetDataSample`.

        The data samples are updated in place and the same list is returned.

        Args:
            data_samples (list[:obj:`DetDataSample`], optional): A batch of
                data samples that contain annotations and predictions.
            results_list (List[dict]): Instance segmentation, semantic
                segmentation and panoptic segmentation results. Only the
                keys ``'pan_results'`` and ``'ins_results'`` are consumed;
                ``'sem_results'`` must be absent.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contain
            'pred_instances' and `pred_panoptic_seg`. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).

            And the ``pred_panoptic_seg`` contains the following key

                - sem_seg (Tensor): panoptic segmentation mask, has a
                    shape (1, h, w).

        Raises:
            AssertionError: If any result dict contains ``'sem_results'``.
        """
        for data_sample, pred_results in zip(data_samples, results_list):
            if 'pan_results' in pred_results:
                data_sample.pred_panoptic_seg = pred_results['pan_results']

            if 'ins_results' in pred_results:
                data_sample.pred_instances = pred_results['ins_results']

            # NOTE: this guard is an ``assert`` and is therefore skipped
            # when Python runs with ``-O``; kept as-is so the raised
            # exception type stays ``AssertionError``.
            assert 'sem_results' not in pred_results, 'segmantic ' \
                'segmentation results are not supported yet.'

        return data_samples

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            tuple[List[Tensor]]: A tuple of features from ``panoptic_head``
            forward.
        """
        feats = self.extract_feat(batch_inputs)
        return self.panoptic_head.forward(feats, batch_data_samples)