Spaces:
Runtime error
Runtime error
| import warnings | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| # from mmcv.ops import RoIPool | |
| from mmcv.parallel import collate, scatter | |
| from mmcv.runner import load_checkpoint | |
| from mmdet.core import get_classes | |
| from mmdet.datasets import replace_ImageToTensor | |
| from mmdet.datasets.pipelines import Compose | |
| from mmdet.models import build_detector | |
def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Build a detector from a config and optionally load its weights.

    Args:
        config (str or :obj:`mmcv.Config`): Path to a config file, or an
            already-parsed config object.
        checkpoint (str, optional): Checkpoint to load weights from. When
            None, no weights are loaded and ``model.CLASSES`` is not set.
        device (str): Device the model is moved to. When it is exactly
            ``'cpu'`` the checkpoint is also loaded onto the CPU.
        cfg_options (dict, optional): Key/value overrides merged into the
            config before the model is built.

    Returns:
        nn.Module: The detector in eval mode, with the final config attached
        as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if not isinstance(config, (str, mmcv.Config)):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)

    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Clear `pretrained` and drop the training config; this builds a model
    # for inference (it is put in eval mode below), and any weights come
    # from `checkpoint`.
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        map_location = None if device != 'cpu' else 'cpu'
        loaded = load_checkpoint(model, checkpoint, map_location=map_location)
        meta = loaded.get('meta', {})
        if 'CLASSES' not in meta:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
        else:
            model.CLASSES = meta['CLASSES']

    # `inference_detector` / `async_inference_detector` read the test
    # pipeline from this attribute.
    model.cfg = config
    model.to(device)
    model.eval()
    return model
class LoadImage(object):
    """Deprecated pipeline step that reads an image into a results dict.

    Kept only for backward compatibility; each call emits a deprecation
    warning pointing at ``LoadImageFromWebcam``.
    """

    def __call__(self, results):
        """Load ``results['img']`` and record the image metadata.

        Args:
            results (dict): Must contain ``'img'``, either a file name (str)
                or something ``mmcv.imread`` accepts directly.

        Returns:
            dict: The same ``results`` dict, updated in place with ``img``,
            ``img_fields``, ``img_shape``, ``ori_shape``, ``filename`` and
            ``ori_filename`` (the latter two are None unless ``img`` was a
            str).
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines.` instead.')
        source = results['img']
        # Only a path gives us a file name to record.
        name = source if isinstance(source, str) else None
        results['filename'] = name
        results['ori_filename'] = name

        image = mmcv.imread(source)
        shape = image.shape
        results.update(
            img=image,
            img_fields=['img'],
            img_shape=shape,
            ori_shape=shape)
        return results
def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. It must carry the ``cfg``
            attribute (as set by ``init_detector``); its test pipeline is
            read from ``model.cfg.data.test.pipeline`` and is NOT modified.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images. The loading step is chosen
            from the type of the first element, so do not mix file names and
            arrays in one call.

    Returns:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    import copy

    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # Deep-copy the test pipeline so the edits below never leak into
    # `model.cfg`: a shallow `cfg.copy()` shares the nested pipeline dicts,
    # and a persisted loader change would break later calls with file paths.
    pipeline = copy.deepcopy(cfg.data.test.pipeline)
    if isinstance(imgs[0], np.ndarray):
        # Arrays are passed in as data['img'] (see below), so swap the
        # first (loading) step for 'LoadImageFromWebcam'.
        pipeline[0].type = 'LoadImageFromWebcam'
    pipeline = replace_ImageToTensor(pipeline)
    test_pipeline = Compose(pipeline)

    datas = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]

    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # RoIPool is rejected for CPU inference. Import it lazily: the
        # module-level import is disabled in this file (presumably because
        # mmcv.ops is unavailable in some builds), which previously made
        # this branch raise NameError on every CPU call.
        try:
            from mmcv.ops import RoIPool
        except ImportError:
            # Without mmcv.ops the model presumably cannot contain RoIPool.
            RoIPool = None
        if RoIPool is not None:
            for m in model.modules():
                assert not isinstance(m, RoIPool), (
                    'CPU inference with RoIPool is not supported currently.')

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)
    return results if is_batch else results[0]
async def async_inference_detector(model, img):
    """Asynchronously run the detector on a single image.

    Args:
        model (nn.Module): The loaded detector. It must carry the ``cfg``
            attribute (as set by ``init_detector``); its test pipeline is
            read but NOT modified.
        img (str | ndarray): One image file name or one loaded image.

    Returns:
        The detection result produced by ``model.aforward_test``.

    Note:
        This calls ``torch.set_grad_enabled(False)`` and does not restore
        the previous setting, so gradients stay disabled after it returns.
    """
    import copy

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # Deep-copy the test pipeline so switching the loader below never leaks
    # into `model.cfg` (a shallow `cfg.copy()` shares the nested dicts).
    pipeline = copy.deepcopy(cfg.data.test.pipeline)
    if isinstance(img, np.ndarray):
        # directly add img; the first (loading) step must read it from there
        data = dict(img=img)
        pipeline[0].type = 'LoadImageFromWebcam'
    else:
        # add information into dict
        data = dict(img_info=dict(filename=img), img_prefix=None)

    # build the data pipeline
    test_pipeline = Compose(pipeline)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result
def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0):
    """Draw detection results on an image and display them.

    Args:
        model (nn.Module): The detector, or a wrapper exposing it as
            ``.module``; its ``show_result`` method does the drawing.
        img (str or np.ndarray): Image file name or loaded image.
        result (tuple[list] or list): Detection result, either
            ``(bbox, segm)`` or just ``bbox``.
        score_thr (float): Minimum score for boxes/masks to be drawn.
        title (str): Window title, forwarded as ``win_name``.
        wait_time (float): Forwarded as ``wait_time`` (waitKey param).
            Default: 0.
    """
    # Unwrap a container (presumably a DataParallel-style wrapper) that
    # holds the real detector in `.module`.
    detector = model.module if hasattr(model, 'module') else model
    palette = (72, 101, 241)  # same color for boxes and labels
    detector.show_result(
        img,
        result,
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=palette,
        text_color=palette)