# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import Dict, List

import torch
from torch import nn

from custom_detectron2.config import configurable
from custom_detectron2.structures import ImageList

from ..postprocessing import detector_postprocess, sem_seg_postprocess
from .build import META_ARCH_REGISTRY
from .rcnn import GeneralizedRCNN
from .semantic_seg import build_sem_seg_head

__all__ = ["PanopticFPN"]


@META_ARCH_REGISTRY.register()
class PanopticFPN(GeneralizedRCNN):
    """
    Implement the paper :paper:`PanopticFPN`.
    """

    @configurable
    def __init__(
        self,
        *,
        sem_seg_head: nn.Module,
        combine_overlap_thresh: float = 0.5,
        combine_stuff_area_thresh: float = 4096,
        combine_instances_score_thresh: float = 0.5,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            sem_seg_head: a module for the semantic segmentation head.
            combine_overlap_thresh: combine masks into one instance if they have
                enough overlap
            combine_stuff_area_thresh: ignore stuff areas smaller than this threshold
            combine_instances_score_thresh: ignore instances whose score is
                smaller than this threshold

        Other arguments are the same as :class:`GeneralizedRCNN`.
        """
        super().__init__(**kwargs)
        self.sem_seg_head = sem_seg_head
        # options when combining instance & semantic outputs
        self.combine_overlap_thresh = combine_overlap_thresh
        self.combine_stuff_area_thresh = combine_stuff_area_thresh
        self.combine_instances_score_thresh = combine_instances_score_thresh

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update(
            {
                "combine_overlap_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH,
                "combine_stuff_area_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT,
                "combine_instances_score_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH,  # noqa
            }
        )
        ret["sem_seg_head"] = build_sem_seg_head(cfg, ret["backbone"].output_shape())
        logger = logging.getLogger(__name__)
        if not cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED:
            logger.warning(
                "PANOPTIC_FPN.COMBINE.ENABLED is no longer used. "
                "model.inference(do_postprocess=) should be used to toggle postprocessing."
            )
        if cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT != 1.0:
            w = cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT
            logger.warning(
                "PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT should be replaced by weights on each ROI head."
            )

            def update_weight(x):
                if isinstance(x, dict):
                    return {k: v * w for k, v in x.items()}
                else:
                    return x * w

            roi_heads = ret["roi_heads"]
            roi_heads.box_predictor.loss_weight = update_weight(roi_heads.box_predictor.loss_weight)
            roi_heads.mask_head.loss_weight = update_weight(roi_heads.mask_head.loss_weight)
        return ret

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "instances": Instances
                * "sem_seg": semantic segmentation ground truth.
                * Other information that's included in the original dicts, such as:
                  "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                * "panoptic_seg": See the return value of
                  :func:`combine_semantic_and_instance_outputs` for its format.
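
        Example (illustrative): at inference time a minimal input is
        ``[{"image": img, "height": h, "width": w}]`` with ``img`` a (C, H, W)
        tensor; the "instances" and "sem_seg" fields are only needed for training.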
""" if not self.training: return self.inference(batched_inputs) images = self.preprocess_image(batched_inputs) features = self.backbone(images.tensor) assert "sem_seg" in batched_inputs[0] gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs] gt_sem_seg = ImageList.from_tensors( gt_sem_seg, self.backbone.size_divisibility, self.sem_seg_head.ignore_value, self.backbone.padding_constraints, ).tensor sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg) gt_instances = [x["instances"].to(self.device) for x in batched_inputs] proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) detector_results, detector_losses = self.roi_heads( images, features, proposals, gt_instances ) losses = sem_seg_losses losses.update(proposal_losses) losses.update(detector_losses) return losses def inference(self, batched_inputs: List[Dict[str, torch.Tensor]], do_postprocess: bool = True): """ Run inference on the given inputs. Args: batched_inputs (list[dict]): same as in :meth:`forward` do_postprocess (bool): whether to apply post-processing on the outputs. Returns: When do_postprocess=True, see docs in :meth:`forward`. Otherwise, returns a (list[Instances], list[Tensor]) that contains the raw detector outputs, and raw semantic segmentation outputs. """ images = self.preprocess_image(batched_inputs) features = self.backbone(images.tensor) sem_seg_results, sem_seg_losses = self.sem_seg_head(features, None) proposals, _ = self.proposal_generator(images, features, None) detector_results, _ = self.roi_heads(images, features, proposals, None) if do_postprocess: processed_results = [] for sem_seg_result, detector_result, input_per_image, image_size in zip( sem_seg_results, detector_results, batched_inputs, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width) detector_r = detector_postprocess(detector_result, height, width) processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r}) panoptic_r = combine_semantic_and_instance_outputs( detector_r, sem_seg_r.argmax(dim=0), self.combine_overlap_thresh, self.combine_stuff_area_thresh, self.combine_instances_score_thresh, ) processed_results[-1]["panoptic_seg"] = panoptic_r return processed_results else: return detector_results, sem_seg_results def combine_semantic_and_instance_outputs( instance_results, semantic_results, overlap_threshold, stuff_area_thresh, instances_score_thresh, ): """ Implement a simple combining logic following "combine_semantic_and_instance_predictions.py" in panopticapi to produce panoptic segmentation outputs. Args: instance_results: output of :func:`detector_postprocess`. semantic_results: an (H, W) tensor, each element is the contiguous semantic category id Returns: panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. segments_info (list[dict]): Describe each segment in `panoptic_seg`. Each dict contains keys "id", "category_id", "isthing". 
""" panoptic_seg = torch.zeros_like(semantic_results, dtype=torch.int32) # sort instance outputs by scores sorted_inds = torch.argsort(-instance_results.scores) current_segment_id = 0 segments_info = [] instance_masks = instance_results.pred_masks.to(dtype=torch.bool, device=panoptic_seg.device) # Add instances one-by-one, check for overlaps with existing ones for inst_id in sorted_inds: score = instance_results.scores[inst_id].item() if score < instances_score_thresh: break mask = instance_masks[inst_id] # H,W mask_area = mask.sum().item() if mask_area == 0: continue intersect = (mask > 0) & (panoptic_seg > 0) intersect_area = intersect.sum().item() if intersect_area * 1.0 / mask_area > overlap_threshold: continue if intersect_area > 0: mask = mask & (panoptic_seg == 0) current_segment_id += 1 panoptic_seg[mask] = current_segment_id segments_info.append( { "id": current_segment_id, "isthing": True, "score": score, "category_id": instance_results.pred_classes[inst_id].item(), "instance_id": inst_id.item(), } ) # Add semantic results to remaining empty areas semantic_labels = torch.unique(semantic_results).cpu().tolist() for semantic_label in semantic_labels: if semantic_label == 0: # 0 is a special "thing" class continue mask = (semantic_results == semantic_label) & (panoptic_seg == 0) mask_area = mask.sum().item() if mask_area < stuff_area_thresh: continue current_segment_id += 1 panoptic_seg[mask] = current_segment_id segments_info.append( { "id": current_segment_id, "isthing": False, "category_id": semantic_label, "area": mask_area, } ) return panoptic_seg, segments_info