File size: 2,003 Bytes
e7bac26
 
 
 
 
 
 
deabc90
e7bac26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deabc90
e7bac26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg, CfgNode
from detectron2.engine import DefaultPredictor
from detectron2.structures import Instances
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.data.datasets import load_coco_json

DEVICE = 'cpu'

class Predictor():
    config: CfgNode
    
    def __init__(self) -> None:
        # 設定を初期化
        self.config = self._init_custom_config()
    
    def _init_custom_config(self):
        cfg = get_cfg()
        
        # 設定ファイルを取得
        cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
        cfg.MODEL.DEVICE = DEVICE

        load_coco_json('./test/_annotations.coco.json', './test', 'my_dataset_test')
        test_metadata = MetadataCatalog.get("my_dataset_test")
        print(test_metadata)
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(test_metadata.thing_classes)
        cfg.TEST.DETECTIONS_PER_IMAGE = 1000 # default 100
        
        return cfg

    def predict(self, model: str, img_path: str, score_min_percent: int):
        # 設定の変更
        # ※設定の変更は予測器の生成の前に実施する必要がある
        self.config.MODEL.WEIGHTS = f"models/{model}"
        self.config.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_min_percent / 100

        # 予測器を作成
        predictor = DefaultPredictor(self.config)
        
        # 推論する対象の画像を読み込み
        img = cv2.imread(img_path)

        outputs: Instances = predictor(img)["instances"]
        test_metadata = MetadataCatalog.get("my_dataset_test")
        # 推論した結果を画像に書き込む
        v = Visualizer(img[:, :, ::-1], test_metadata, scale=1.0)
        out = v.draw_instance_predictions(outputs.to(DEVICE))
        count = len(outputs)
        
        # 画像を返却する
        return out.get_image(), count