# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from detectron2.config import configurable
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import log_first_n
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from .Wordnn_embedding import WordnnEmbedding
__all__ = ["VGT"]


def torch_memory(device, tag=""):
    # Checks and prints current/peak GPU memory usage for the given device.
    print(tag, f"{torch.cuda.memory_allocated(device)/1024/1024:.2f} MB USED")
    print(tag, f"{torch.cuda.memory_reserved(device)/1024/1024:.2f} MB RESERVED")
    print(tag, f"{torch.cuda.max_memory_allocated(device)/1024/1024:.2f} MB USED MAX")
    print(tag, f"{torch.cuda.max_memory_reserved(device)/1024/1024:.2f} MB RESERVED MAX")
    print("")
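
# Hedged usage note: torch_memory is a debugging aid and is not called anywhere in
# this file. A typical (assumed) call site would be inside forward(), e.g.
#   torch_memory(self.device, tag="after backbone")
# to track allocator growth between stages.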


@META_ARCH_REGISTRY.register()
class VGT(GeneralizedRCNN):
    @configurable
    def __init__(
        self,
        *,
        vocab_size: int = 30552,
        hidden_size: int = 768,
        embedding_dim: int = 64,
        bros_embedding_path: str = "",
        use_pretrain_weight: bool = True,
        use_UNK_text: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.Wordgrid_embedding = WordnnEmbedding(
            vocab_size, hidden_size, embedding_dim, bros_embedding_path, use_pretrain_weight, use_UNK_text
        )

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update(
            {
                "vocab_size": cfg.MODEL.WORDGRID.VOCAB_SIZE,
                "hidden_size": cfg.MODEL.WORDGRID.HIDDEN_SIZE,
                "embedding_dim": cfg.MODEL.WORDGRID.EMBEDDING_DIM,
                "bros_embedding_path": cfg.MODEL.WORDGRID.MODEL_PATH,
                "use_pretrain_weight": cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT,
                "use_UNK_text": cfg.MODEL.WORDGRID.USE_UNK_TEXT,
            }
        )
        return ret
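
    # Hedged example of the config block consumed by from_config() above. The key
    # names come from this file; the values shown mirror the __init__ defaults, and
    # the exact YAML layout of the project's configs is an assumption:
    #
    #   MODEL:
    #     META_ARCHITECTURE: "VGT"
    #     WORDGRID:
    #       VOCAB_SIZE: 30552
    #       HIDDEN_SIZE: 768
    #       EMBEDDING_DIM: 64
    #       MODEL_PATH: ""            # path to the pretrained word-embedding weights (BROS, per the parameter name)
    #       USE_PRETRAIN_WEIGHT: True
    #       USE_UNK_TEXT: False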

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # Build the word-grid (chargrid) embedding and feed it to the backbone
        # alongside the image tensor.
        chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
        features = self.backbone(images.tensor, chargrid)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
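
    # Hedged sketch of a single item in `batched_inputs`, following the docstring
    # above. "image", "height"/"width", and "instances" are standard detectron2
    # fields; Wordgrid_embedding also reads text-grid information from these dicts,
    # and the exact keys it expects are defined in WordnnEmbedding, not in this file:
    #
    #   {
    #       "image": torch.zeros(3, 800, 600, dtype=torch.uint8),  # (C, H, W)
    #       "height": 800, "width": 600,                           # output resolution
    #       "instances": Instances(...),                           # training only
    #       # ... plus the word-grid fields consumed by WordnnEmbedding
    #   }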

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # Same fusion as in forward(): the word-grid embedding is passed to the backbone.
        chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
        features = self.backbone(images.tensor, chargrid)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results
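

# A minimal, hedged inference sketch, not part of the original module. It assumes a
# project YAML config that registers a backbone accepting the chargrid tensor and
# defines the MODEL.WORDGRID node added below; all paths are placeholders.
if __name__ == "__main__":
    from detectron2.config import get_cfg, CfgNode as CN
    from detectron2.modeling import build_model
    from detectron2.checkpoint import DetectionCheckpointer

    cfg = get_cfg()
    # The WORDGRID node is custom to this project, so it must exist before merging a
    # config file that references it (values mirror the __init__ defaults above).
    cfg.MODEL.WORDGRID = CN()
    cfg.MODEL.WORDGRID.VOCAB_SIZE = 30552
    cfg.MODEL.WORDGRID.HIDDEN_SIZE = 768
    cfg.MODEL.WORDGRID.EMBEDDING_DIM = 64
    cfg.MODEL.WORDGRID.MODEL_PATH = ""
    cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT = True
    cfg.MODEL.WORDGRID.USE_UNK_TEXT = False
    cfg.merge_from_file("path/to/vgt_config.yaml")  # placeholder path
    cfg.MODEL.META_ARCHITECTURE = "VGT"

    model = build_model(cfg)  # instantiates VGT through META_ARCH_REGISTRY
    DetectionCheckpointer(model).load("path/to/model_weights.pth")  # placeholder path
    model.eval()

    # batched_inputs must follow the forward() docstring and additionally carry the
    # word-grid fields consumed by WordnnEmbedding (defined outside this file).
    # with torch.no_grad():
    #     outputs = model(batched_inputs)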