# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import log_first_n
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN

from .Wordnn_embedding import WordnnEmbedding

__all__ = ["VGT"]


def torch_memory(device, tag=""):
    """Print current and peak CUDA memory usage for `device`."""
    print(tag, f"{torch.cuda.memory_allocated(device)/1024/1024:.2f} MB USED")
    print(tag, f"{torch.cuda.memory_reserved(device)/1024/1024:.2f} MB RESERVED")
    print(tag, f"{torch.cuda.max_memory_allocated(device)/1024/1024:.2f} MB USED MAX")
    print(tag, f"{torch.cuda.max_memory_reserved(device)/1024/1024:.2f} MB RESERVED MAX")
    print("")


@META_ARCH_REGISTRY.register()
class VGT(GeneralizedRCNN):
    """GeneralizedRCNN variant that fuses a word-grid text embedding into the backbone."""

    @configurable
    def __init__(
        self,
        *,
        vocab_size: int = 30552,
        hidden_size: int = 768,
        embedding_dim: int = 64,
        bros_embedding_path: str = "",
        use_pretrain_weight: bool = True,
        use_UNK_text: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        # Embeds OCR tokens onto a 2D grid; optionally initialized from BROS weights.
        self.Wordgrid_embedding = WordnnEmbedding(
            vocab_size, hidden_size, embedding_dim, bros_embedding_path, use_pretrain_weight, use_UNK_text
        )

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update(
            {
                "vocab_size": cfg.MODEL.WORDGRID.VOCAB_SIZE,
                "hidden_size": cfg.MODEL.WORDGRID.HIDDEN_SIZE,
                "embedding_dim": cfg.MODEL.WORDGRID.EMBEDDING_DIM,
                "bros_embedding_path": cfg.MODEL.WORDGRID.MODEL_PATH,
                "use_pretrain_weight": cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT,
                "use_UNK_text": cfg.MODEL.WORDGRID.USE_UNK_TEXT,
            }
        )
        return ret
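
    # The keys above live under a VGT-specific config node. A minimal sketch of
    # the expected YAML (values mirror the constructor defaults; the weight path
    # is a placeholder, not a file shipped with this module):
    #
    #   MODEL:
    #     META_ARCHITECTURE: "VGT"
    #     WORDGRID:
    #       VOCAB_SIZE: 30552
    #       HIDDEN_SIZE: 768
    #       EMBEDDING_DIM: 64
    #       MODEL_PATH: ""            # e.g. a pretrained BROS embedding checkpoint
    #       USE_PRETRAIN_WEIGHT: True
    #       USE_UNK_TEXT: False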

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used
                  in inference. See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image. The dict contains one
                key "instances" whose value is a :class:`Instances` object with the
                keys "pred_boxes", "pred_classes", "scores", "pred_masks",
                "pred_keypoints".
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # Rasterize the page text into a word-grid embedding, then feed both the
        # image and the grid to the two-stream backbone.
        chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
        features = self.backbone(images.tensor, chargrid)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # Same two-stream feature extraction as in training.
        chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
        features = self.backbone(images.tensor, chargrid)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            # Boxes are given; only compute the remaining per-ROI outputs.
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results
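

# --- Usage sketch (illustrative; not part of the original module) ---
# Shows the calling convention only. A working model needs a project config that
# defines MODEL.WORDGRID.* (see the sketch near `from_config`) plus a backbone
# that accepts the extra `chargrid` argument, and real inputs must carry the
# OCR/word-grid fields that `WordnnEmbedding` reads from each dict. The config
# path is a hypothetical placeholder, so the build/run lines stay commented.
if __name__ == "__main__":
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.MODEL.META_ARCHITECTURE = "VGT"  # resolved through META_ARCH_REGISTRY
    # from detectron2.modeling import build_model
    # cfg.merge_from_file("configs/vgt.yaml")  # hypothetical VGT project config
    # model = build_model(cfg)
    # model.eval()
    # inputs = [{"image": torch.zeros(3, 800, 800), "height": 800, "width": 800}]
    # with torch.no_grad():
    #     outputs = model(inputs)  # list[dict] with an "instances" entry per image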