eadali committed on
Commit 5f1587b · 1 Parent(s): 42f668e

Test yolo onnx model
app.py CHANGED
@@ -19,6 +19,8 @@ from typing import List, Optional, Tuple
  from PIL import Image
  from transformers import AutoModelForObjectDetection, AutoImageProcessor
  from transformers.image_utils import load_image
+ from pipeline import build_pipeline
+ from utils import cfg, load_config, load_onnx_model
 
 
  # Configuration constants
@@ -564,4 +566,7 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
      )
 
  if __name__ == "__main__":
+     load_config(cfg, 'configs/yolo8n-bytetrack-cpu.yaml')
+     pipeline = build_pipeline(cfg.pipeline)
+     load_onnx_model(pipeline.detector, 'downloads/yolo8n-416.onnx')
      demo.queue(max_size=20).launch()
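For context, a minimal sketch of how the objects added in this hunk wire together at startup; the per-frame call at the end is illustrative only (the hunk shows construction, not the app's actual inference loop):

    import numpy as np
    from pipeline import build_pipeline
    from utils import cfg, load_config, load_onnx_model

    # Same wiring as the __main__ block above
    load_config(cfg, 'configs/yolo8n-bytetrack-cpu.yaml')
    pipeline = build_pipeline(cfg.pipeline)
    load_onnx_model(pipeline.detector, 'downloads/yolo8n-416.onnx')

    # Illustrative per-frame call: returns supervision.Detections with tracker IDs
    frame = np.zeros((416, 416, 3), dtype=np.uint8)  # dummy frame
    detections = pipeline(frame)
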
configs/yolo8n-bytetrack-cpu.yaml ADDED
@@ -0,0 +1,14 @@
+ # YOLOv8n + ByteTrack Configuration
+ pipeline:
+   detector:
+     model: yolov8n
+     categories: ['LightVehicle', 'Person', 'Building', 'UPole', 'Boat', 'Bike', 'Container', 'Truck', 'Gastank', 'Digger', 'Solarpanels', 'Bus']
+     thresholds:
+       confidence: 0.6
+       iou: 0.4
+     slicing:
+       overlap: 0.2
+     device: cpu
+
+   tracker:
+     algorithm: bytetrack
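A quick sketch of how these keys are read once the config is merged (paths assume the Space is run from the repo root):

    from utils import cfg, load_config

    load_config(cfg, 'configs/yolo8n-bytetrack-cpu.yaml')
    det_cfg = cfg.pipeline.detector
    print(det_cfg.thresholds.confidence)  # 0.6
    print(det_cfg.slicing.overlap)        # 0.2
    print(det_cfg.device)                 # cpu
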
downloads/yolo8n-416.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:afbf81f217973b0c9e0b256c0b156d2f08b55169b7ac3f9e2e52004a2449fad1
+ size 12162747
pipeline/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .pipeline import Pipeline
+ from .detectors import build_detector
+ from .trackers import build_tracker
+
+ def build_pipeline(config):
+     """
+     Build and return a pipeline based on the provided configuration.
+     """
+     # Build detector and tracker using the config
+     detector = build_detector(config.detector)
+     tracker = build_tracker(config.tracker)
+
+     # Create and return a Pipeline object with detector and tracker
+     return Pipeline(detector=detector, tracker=tracker)
pipeline/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (679 Bytes).
 
pipeline/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (1.38 kB).
 
pipeline/detectors/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .yolo import YOLO
+
+
+ def build_detector(config):
+     """
+     Build the detection model based on the provided configuration.
+     """
+     # Initialize the YOLO object detection model
+     return YOLO(config.thresholds.confidence,
+                 config.thresholds.iou,
+                 config.slicing.overlap,
+                 config.categories,
+                 config.device)
+
pipeline/detectors/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (705 Bytes).
 
pipeline/detectors/__pycache__/yolo.cpython-312.pyc ADDED
Binary file (4 kB).
 
pipeline/detectors/yolo.py ADDED
@@ -0,0 +1,80 @@
+ import numpy as np
+ import supervision as sv
+ from sahi import AutoDetectionModel
+ from sahi.predict import get_sliced_prediction
+
+ class YOLO:
+     def __init__(self, confidence_threshold, iou_threshold, slicing_overlap, categories, device):
+         """
+         YOLO detector wrapper using SAHI for sliced prediction.
+
+         Args:
+             confidence_threshold (float): Minimum confidence for detections.
+             iou_threshold (float): IoU threshold for NMS (not used directly here).
+             slicing_overlap (float): Overlap ratio for slicing.
+             categories (list): List of class names.
+             device (str): Device to run the model on ('cpu' or 'cuda').
+         """
+         self.model = None
+         self.confidence_threshold = confidence_threshold
+         self.iou_threshold = iou_threshold
+         self.slicing_overlap = slicing_overlap
+         self.categories = categories
+         self.category_mapping = {str(i): category for i, category in enumerate(categories)}
+         self.device = device
+
+     def load_onnx_model(self, path):
+         """
+         Loads the ONNX model using SAHI's AutoDetectionModel.
+         """
+         self.model = AutoDetectionModel.from_pretrained(
+             model_type='yolov8onnx',
+             model_path=path,
+             confidence_threshold=self.confidence_threshold,
+             category_mapping=self.category_mapping,
+             device=self.device
+         )
+
+     def __call__(self, frame):
+         """
+         Runs sliced prediction on the input frame and returns a supervision.Detections object.
+         """
+         # Get the square input size from the ONNX model's input tensor
+         input_shape = self.model.model.get_inputs()[0].shape[2]
+         result = get_sliced_prediction(
+             frame,
+             self.model,
+             slice_height=input_shape,
+             slice_width=input_shape,
+             overlap_height_ratio=self.slicing_overlap,
+             overlap_width_ratio=self.slicing_overlap,
+             verbose=False,
+         )
+         boxes = []
+         confidences = []
+         class_ids = []
+         for det in result.object_prediction_list:
+             boxes.append(det.bbox.to_xyxy())
+             confidences.append(det.score.value)
+             class_ids.append(det.category.id)
+         if boxes:
+             boxes = np.array(boxes)
+             confidences = np.array(confidences)
+             class_ids = np.array(class_ids, dtype=int)
+         else:
+             boxes = np.zeros((0, 4))
+             confidences = np.zeros((0,))
+             class_ids = np.zeros((0,), dtype=int)
+         detections = sv.Detections(
+             xyxy=boxes,
+             confidence=confidences,
+             class_id=class_ids,
+         )
+         return detections
+
+     def get_category_mapping(self):
+         """
+         Returns the category mapping.
+         """
+         # Convert string keys to integers
+         return {int(k): v for k, v in self.category_mapping.items()}
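For reference, a minimal standalone check of the detector wrapper (assumes the config and ONNX weights added in this commit are present locally; a blank dummy frame will normally yield zero detections):

    import numpy as np
    from pipeline.detectors import build_detector
    from utils import cfg, load_config

    load_config(cfg, 'configs/yolo8n-bytetrack-cpu.yaml')
    detector = build_detector(cfg.pipeline.detector)
    detector.load_onnx_model('downloads/yolo8n-416.onnx')

    frame = np.zeros((416, 416, 3), dtype=np.uint8)   # dummy RGB frame
    detections = detector(frame)                      # supervision.Detections
    print(len(detections), detector.get_category_mapping()[0])
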
pipeline/pipeline.py ADDED
@@ -0,0 +1,26 @@
+ class Pipeline:
+     def __init__(self, detector, tracker):
+         """
+         Initialize the Pipeline class with a detector and tracker.
+
+         Args:
+             detector (object): The object detection model.
+             tracker (object): The object tracking model.
+         """
+         self.detector = detector
+         self.tracker = tracker
+
+     def load_onnx_model(self, onnx_path):
+         self.detector.load_onnx_model(onnx_path)
+
+     def __call__(self, frame):
+         """
+         Run the detection and tracking on the input image.
+
+         Args:
+             frame (np.ndarray): The input image to process.
+
+         Returns:
+             supervision.Detections: Detections object after tracking.
+         """
+         return self.tracker(self.detector(frame))
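A tiny composition check with stub callables (hypothetical stand-ins, only to illustrate that __call__ is tracker(detector(frame))):

    from pipeline.pipeline import Pipeline

    # Duck-typed stubs standing in for the real detector and tracker
    def fake_detector(frame):
        return f"detections({frame})"

    def fake_tracker(dets):
        return f"tracked({dets})"

    p = Pipeline(detector=fake_detector, tracker=fake_tracker)
    print(p("frame0"))  # -> tracked(detections(frame0))
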
pipeline/trackers/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .bytetrack import ByteTrack
+
+ def build_tracker(cfg):
+     """
+     Build the tracking model based on the provided configuration.
+     """
+     # Initialize the tracker
+     return ByteTrack()
pipeline/trackers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (407 Bytes).
 
pipeline/trackers/__pycache__/bytetrack.cpython-312.pyc ADDED
Binary file (1.05 kB).
 
pipeline/trackers/bytetrack.py ADDED
@@ -0,0 +1,14 @@
+ import supervision as sv
+
+
+ class ByteTrack:
+     def __init__(self):
+         self.tracker = sv.ByteTrack()
+         self.smoother = sv.DetectionsSmoother()
+
+     def __call__(self, detections):
+         """Process detections using ByteTrack."""
+         # Assign track IDs to the detections, then smooth boxes across frames
+         tracked_detections = self.tracker.update_with_detections(detections)
+         smoothed_detections = self.smoother.update_with_detections(tracked_detections)
+         return smoothed_detections
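A short sketch of the tracker wrapper on hand-built detections (synthetic box; the assigned tracker_id values depend on ByteTrack's internal state across successive calls):

    import numpy as np
    import supervision as sv
    from pipeline.trackers.bytetrack import ByteTrack

    tracker = ByteTrack()
    detections = sv.Detections(
        xyxy=np.array([[10.0, 10.0, 50.0, 50.0]]),
        confidence=np.array([0.9]),
        class_id=np.array([0]),
    )
    tracked = tracker(detections)  # adds tracker_id, then smooths boxes
    print(tracked.tracker_id)
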
requirements.txt CHANGED
@@ -10,3 +10,5 @@ supervision
  trackers[deepsort] @ git+https://github.com/roboflow/trackers
  spaces
  imageio[pyav]
+ onnxruntime
+ sahi
utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+
+ from .config import cfg, load_config
+ from .model_loader import load_onnx_model
+
+
+ __all__ = ["cfg", "load_config", "load_onnx_model"]
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (293 Bytes).
 
utils/__pycache__/check_point.cpython-312.pyc ADDED
Binary file (310 Bytes).
 
utils/__pycache__/config.cpython-312.pyc ADDED
Binary file (1.72 kB).
 
utils/__pycache__/model_loader.cpython-312.pyc ADDED
Binary file (311 Bytes).
 
utils/__pycache__/yacs.cpython-312.pyc ADDED
Binary file (21.7 kB).
 
utils/config.py ADDED
@@ -0,0 +1,40 @@
+ from .yacs import CfgNode
+
+ cfg = CfgNode(new_allowed=True)
+ cfg.save_dir = "./"
+ # common params for NETWORK
+ cfg.model = CfgNode(new_allowed=True)
+ cfg.model.arch = CfgNode(new_allowed=True)
+ cfg.model.arch.backbone = CfgNode(new_allowed=True)
+ cfg.model.arch.fpn = CfgNode(new_allowed=True)
+ cfg.model.arch.head = CfgNode(new_allowed=True)
+
+ # DATASET related params
+ cfg.data = CfgNode(new_allowed=True)
+ cfg.data.train = CfgNode(new_allowed=True)
+ cfg.data.val = CfgNode(new_allowed=True)
+ cfg.device = CfgNode(new_allowed=True)
+ cfg.device.precision = 32
+ # train
+ cfg.schedule = CfgNode(new_allowed=True)
+
+ # logger
+ cfg.log = CfgNode()
+ cfg.log.interval = 50
+
+ # testing
+ cfg.test = CfgNode()
+ # size of images for each device
+
+
+ def load_config(cfg, args_cfg):
+     cfg.defrost()
+     cfg.merge_from_file(args_cfg)
+     cfg.freeze()
+
+
+ if __name__ == "__main__":
+     import sys
+
+     with open(sys.argv[1], "w") as f:
+         print(cfg, file=f)
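Quick usage note for load_config (the module-level cfg is mutated in place and frozen afterwards; path assumes the repo root as working directory):

    from utils.config import cfg, load_config

    load_config(cfg, 'configs/yolo8n-bytetrack-cpu.yaml')
    print(cfg.pipeline.detector.model)     # yolov8n
    print(cfg.pipeline.tracker.algorithm)  # bytetrack
    print(cfg.is_frozen())                 # True
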
utils/model_loader.py ADDED
@@ -0,0 +1,2 @@
+ def load_onnx_model(detector, path):
+     detector.load_onnx_model(path)
utils/yacs.py ADDED
@@ -0,0 +1,531 @@
+ # Copyright (c) 2018-present, Facebook, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ ##############################################################################
+ """YACS -- Yet Another Configuration System is designed to be a simple
+ configuration management system for academic and industrial research
+ projects.
+
+ See README.md for usage and examples.
+ """
+
+ import copy
+ import io
+ import logging
+ import os
+ import sys
+ from ast import literal_eval
+
+ import yaml
+
+ # Flag for py2 and py3 compatibility to use when separate code paths are necessary
+ # When _PY2 is False, we assume Python 3 is in use
+ _PY2 = sys.version_info.major == 2
+
+ # Filename extensions for loading configs from files
+ _YAML_EXTS = {"", ".yaml", ".yml"}
+ _PY_EXTS = {".py"}
+
+ _FILE_TYPES = (io.IOBase,)
+
+ # CfgNodes can only contain a limited set of valid types
+ _VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}
+ # py2 allow for str and unicode
+ if _PY2:
+     _VALID_TYPES = _VALID_TYPES.union({unicode})  # noqa: F821
+
+ # Utilities for importing modules from file paths
+ if _PY2:
+     # imp is available in both py2 and py3 for now, but is deprecated in py3
+     import imp
+ else:
+     import importlib.util
+
+ logger = logging.getLogger(__name__)
+
+
+ class CfgNode(dict):
+     """
+     CfgNode represents an internal node in the configuration tree. It's a simple
+     dict-like container that allows for attribute-based access to keys.
+     """
+
+     IMMUTABLE = "__immutable__"
+     DEPRECATED_KEYS = "__deprecated_keys__"
+     RENAMED_KEYS = "__renamed_keys__"
+     NEW_ALLOWED = "__new_allowed__"
+
+     def __init__(self, init_dict=None, key_list=None, new_allowed=False):
+         """
+         Args:
+             init_dict (dict): the possibly-nested dictionary to initialize the
+                 CfgNode.
+             key_list (list[str]): a list of names which index this CfgNode from
+                 the root.
+                 Currently only used for logging purposes.
+             new_allowed (bool): whether adding new key is allowed when merging with
+                 other configs.
+         """
+         # Recursively convert nested dictionaries in init_dict into CfgNodes
+         init_dict = {} if init_dict is None else init_dict
+         key_list = [] if key_list is None else key_list
+         init_dict = self._create_config_tree_from_dict(init_dict, key_list)
+         super(CfgNode, self).__init__(init_dict)
+         # Manage if the CfgNode is frozen or not
+         self.__dict__[CfgNode.IMMUTABLE] = False
+         # Deprecated options
+         # If an option is removed from the code and you don't want to break existing
+         # yaml configs, you can add the full config key as a string to the set below.
+         self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
+         # Renamed options
+         # If you rename a config option, record the mapping from the old name to the
+         # new name in the dictionary below. Optionally, if the type also changed, you
+         # can make the value a tuple that specifies first the renamed key and then
+         # instructions for how to edit the config file.
+         self.__dict__[CfgNode.RENAMED_KEYS] = {
+             # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY',  # Dummy example to follow
+             # 'EXAMPLE.OLD.KEY': (  # A more complex example to follow
+             #     'EXAMPLE.NEW.KEY',
+             #     "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
+             #     + "'foo:bar' -> ('foo', 'bar')"
+             # ),
+         }
+
+         # Allow new attributes after initialisation
+         self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
+
+     @classmethod
+     def _create_config_tree_from_dict(cls, dic, key_list):
+         """
+         Create a configuration tree using the given dict.
+         Any dict-like objects inside dict will be treated as a new CfgNode.
+
+         Args:
+             dic (dict):
+             key_list (list[str]): a list of names which index this CfgNode from
+                 the root. Currently only used for logging purposes.
+         """
+         dic = copy.deepcopy(dic)
+         for k, v in dic.items():
+             if isinstance(v, dict):
+                 # Convert dict to CfgNode
+                 dic[k] = cls(v, key_list=key_list + [k])
+             else:
+                 # Check for valid leaf type or nested CfgNode
+                 _assert_with_logging(
+                     _valid_type(v, allow_cfg_node=False),
+                     "Key {} with value {} is not a valid type; valid types: {}".format(
+                         ".".join(key_list + [k]), type(v), _VALID_TYPES
+                     ),
+                 )
+         return dic
+
+     def __getattr__(self, name):
+         if name in self:
+             return self[name]
+         else:
+             raise AttributeError(name)
+
+     def __setattr__(self, name, value):
+         if self.is_frozen():
+             raise AttributeError(
+                 "Attempted to set {} to {}, but CfgNode is immutable".format(
+                     name, value
+                 )
+             )
+
+         _assert_with_logging(
+             name not in self.__dict__,
+             "Invalid attempt to modify internal CfgNode state: {}".format(name),
+         )
+         _assert_with_logging(
+             _valid_type(value, allow_cfg_node=True),
+             "Invalid type {} for key {}; valid types = {}".format(
+                 type(value), name, _VALID_TYPES
+             ),
+         )
+
+         self[name] = value
+
+     def __str__(self):
+         def _indent(s_, num_spaces):
+             s = s_.split("\n")
+             if len(s) == 1:
+                 return s_
+             first = s.pop(0)
+             s = [(num_spaces * " ") + line for line in s]
+             s = "\n".join(s)
+             s = first + "\n" + s
+             return s
+
+         r = ""
+         s = []
+         for k, v in sorted(self.items()):
+             seperator = "\n" if isinstance(v, CfgNode) else " "
+             attr_str = "{}:{}{}".format(str(k), seperator, str(v))
+             attr_str = _indent(attr_str, 2)
+             s.append(attr_str)
+         r += "\n".join(s)
+         return r
+
+     def __repr__(self):
+         return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
+
+     def dump(self, **kwargs):
+         """Dump to a string."""
+
+         def convert_to_dict(cfg_node, key_list):
+             if not isinstance(cfg_node, CfgNode):
+                 _assert_with_logging(
+                     _valid_type(cfg_node),
+                     "Key {} with value {} is not a valid type; valid types: {}".format(
+                         ".".join(key_list), type(cfg_node), _VALID_TYPES
+                     ),
+                 )
+                 return cfg_node
+             else:
+                 cfg_dict = dict(cfg_node)
+                 for k, v in cfg_dict.items():
+                     cfg_dict[k] = convert_to_dict(v, key_list + [k])
+                 return cfg_dict
+
+         self_as_dict = convert_to_dict(self, [])
+         return yaml.safe_dump(self_as_dict, **kwargs)
+
+     def merge_from_file(self, cfg_filename):
+         """Load a yaml config file and merge it into this CfgNode."""
+         with open(cfg_filename, "r", encoding="utf-8") as f:
+             cfg = self.load_cfg(f)
+         self.merge_from_other_cfg(cfg)
+
+     def merge_from_other_cfg(self, cfg_other):
+         """Merge `cfg_other` into this CfgNode."""
+         _merge_a_into_b(cfg_other, self, self, [])
+
+     def merge_from_list(self, cfg_list):
+         """Merge config (keys, values) in a list (e.g., from command line) into
+         this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
+         """
+         _assert_with_logging(
+             len(cfg_list) % 2 == 0,
+             "Override list has odd length: {}; it must be a list of pairs".format(
+                 cfg_list
+             ),
+         )
+         root = self
+         for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
+             if root.key_is_deprecated(full_key):
+                 continue
+             if root.key_is_renamed(full_key):
+                 root.raise_key_rename_error(full_key)
+             key_list = full_key.split(".")
+             d = self
+             for subkey in key_list[:-1]:
+                 _assert_with_logging(
+                     subkey in d, "Non-existent key: {}".format(full_key)
+                 )
+                 d = d[subkey]
+             subkey = key_list[-1]
+             _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
+             value = self._decode_cfg_value(v)
+             value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
+             d[subkey] = value
+
+     def freeze(self):
+         """Make this CfgNode and all of its children immutable."""
+         self._immutable(True)
+
+     def defrost(self):
+         """Make this CfgNode and all of its children mutable."""
+         self._immutable(False)
+
+     def is_frozen(self):
+         """Return mutability."""
+         return self.__dict__[CfgNode.IMMUTABLE]
+
+     def _immutable(self, is_immutable):
+         """Set immutability to is_immutable and recursively apply the setting
+         to all nested CfgNodes.
+         """
+         self.__dict__[CfgNode.IMMUTABLE] = is_immutable
+         # Recursively set immutable state
+         for v in self.__dict__.values():
+             if isinstance(v, CfgNode):
+                 v._immutable(is_immutable)
+         for v in self.values():
+             if isinstance(v, CfgNode):
+                 v._immutable(is_immutable)
+
+     def clone(self):
+         """Recursively copy this CfgNode."""
+         return copy.deepcopy(self)
+
+     def register_deprecated_key(self, key):
+         """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated
+         keys a warning is generated and the key is ignored.
+         """
+         _assert_with_logging(
+             key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
+             "key {} is already registered as a deprecated key".format(key),
+         )
+         self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
+
+     def register_renamed_key(self, old_name, new_name, message=None):
+         """Register a key as having been renamed from `old_name` to `new_name`.
+         When merging a renamed key, an exception is thrown alerting the user to
+         the fact that the key has been renamed.
+         """
+         _assert_with_logging(
+             old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
+             "key {} is already registered as a renamed cfg key".format(old_name),
+         )
+         value = new_name
+         if message:
+             value = (new_name, message)
+         self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
+
+     def key_is_deprecated(self, full_key):
+         """Test if a key is deprecated."""
+         if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
+             logger.warning("Deprecated config key (ignoring): {}".format(full_key))
+             return True
+         return False
+
+     def key_is_renamed(self, full_key):
+         """Test if a key is renamed."""
+         return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
+
+     def raise_key_rename_error(self, full_key):
+         new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
+         if isinstance(new_key, tuple):
+             msg = " Note: " + new_key[1]
+             new_key = new_key[0]
+         else:
+             msg = ""
+         raise KeyError(
+             "Key {} was renamed to {}; please update your config.{}".format(
+                 full_key, new_key, msg
+             )
+         )
+
+     def is_new_allowed(self):
+         return self.__dict__[CfgNode.NEW_ALLOWED]
+
+     @classmethod
+     def load_cfg(cls, cfg_file_obj_or_str):
+         """
+         Load a cfg.
+         Args:
+             cfg_file_obj_or_str (str or file):
+                 Supports loading from:
+                 - A file object backed by a YAML file
+                 - A file object backed by a Python source file that exports an attribute
+                   "cfg" that is either a dict or a CfgNode
+                 - A string that can be parsed as valid YAML
+         """
+         _assert_with_logging(
+             isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
+             "Expected first argument to be of type {} or {}, but it was {}".format(
+                 _FILE_TYPES, str, type(cfg_file_obj_or_str)
+             ),
+         )
+         if isinstance(cfg_file_obj_or_str, str):
+             return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
+         elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
+             return cls._load_cfg_from_file(cfg_file_obj_or_str)
+         else:
+             raise NotImplementedError("Impossible to reach here (unless there's a bug)")
+
+     @classmethod
+     def _load_cfg_from_file(cls, file_obj):
+         """Load a config from a YAML file or a Python source file."""
+         _, file_extension = os.path.splitext(file_obj.name)
+         if file_extension in _YAML_EXTS:
+             return cls._load_cfg_from_yaml_str(file_obj.read())
+         elif file_extension in _PY_EXTS:
+             return cls._load_cfg_py_source(file_obj.name)
+         else:
+             raise Exception(
+                 "Attempt to load from an unsupported file type {}; "
+                 "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
+             )
+
+     @classmethod
+     def _load_cfg_from_yaml_str(cls, str_obj):
+         """Load a config from a YAML string encoding."""
+         cfg_as_dict = yaml.safe_load(str_obj)
+         return cls(cfg_as_dict)
+
+     @classmethod
+     def _load_cfg_py_source(cls, filename):
+         """Load a config from a Python source file."""
+         module = _load_module_from_file("yacs.config.override", filename)
+         _assert_with_logging(
+             hasattr(module, "cfg"),
+             "Python module from file {} must have 'cfg' attr".format(filename),
+         )
+         VALID_ATTR_TYPES = {dict, CfgNode}
+         _assert_with_logging(
+             type(module.cfg) in VALID_ATTR_TYPES,
+             "Imported module 'cfg' attr must be in {} but is {} instead".format(
+                 VALID_ATTR_TYPES, type(module.cfg)
+             ),
+         )
+         return cls(module.cfg)
+
+     @classmethod
+     def _decode_cfg_value(cls, value):
+         """
+         Decodes a raw config value (e.g., from a yaml config file or command
+         line argument) into a Python object.
+
+         If the value is a dict, it will be interpreted as a new CfgNode.
+         If the value is a str, it will be evaluated as literals.
+         Otherwise it is returned as-is.
+         """
+         # Configs parsed from raw yaml will contain dictionary keys that need to be
+         # converted to CfgNode objects
+         if isinstance(value, dict):
+             return cls(value)
+         # All remaining processing is only applied to strings
+         if not isinstance(value, str):
+             return value
+         # Try to interpret `value` as a:
+         #   string, number, tuple, list, dict, boolean, or None
+         try:
+             value = literal_eval(value)
+         # The following two excepts allow v to pass through when it represents a
+         # string.
+         #
+         # Longer explanation:
+         # The type of v is always a string (before calling literal_eval), but
+         # sometimes it *represents* a string and other times a data structure, like
+         # a list. In the case that v represents a string, what we got back from the
+         # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
+         # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
+         # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
+         # will raise a SyntaxError.
+         except ValueError:
+             pass
+         except SyntaxError:
+             pass
+         return value
+
+
+ load_cfg = (
+     CfgNode.load_cfg
+ )  # keep this function in global scope for backward compatibility
+
+
+ def _valid_type(value, allow_cfg_node=False):
+     return (type(value) in _VALID_TYPES) or (
+         allow_cfg_node and isinstance(value, CfgNode)
+     )
+
+
+ def _merge_a_into_b(a, b, root, key_list):
+     """Merge config dictionary a into config dictionary b, clobbering the
+     options in b whenever they are also specified in a.
+     """
+     _assert_with_logging(
+         isinstance(a, CfgNode),
+         "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
+     )
+     _assert_with_logging(
+         isinstance(b, CfgNode),
+         "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
+     )
+
+     for k, v_ in a.items():
+         full_key = ".".join(key_list + [k])
+
+         v = copy.deepcopy(v_)
+         v = b._decode_cfg_value(v)
+
+         if k in b:
+             v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
+             # Recursively merge dicts
+             if isinstance(v, CfgNode):
+                 try:
+                     _merge_a_into_b(v, b[k], root, key_list + [k])
+                 except BaseException:
+                     raise
+             else:
+                 b[k] = v
+         elif b.is_new_allowed():
+             b[k] = v
+         else:
+             if root.key_is_deprecated(full_key):
+                 continue
+             elif root.key_is_renamed(full_key):
+                 root.raise_key_rename_error(full_key)
+             else:
+                 raise KeyError("Non-existent config key: {}".format(full_key))
+
+
+ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
+     """Checks that `replacement`, which is intended to replace `original`, is of
+     the right type. The type is correct if it matches exactly or is one of a few
+     cases in which the type can be easily coerced.
+     """
+     original_type = type(original)
+     replacement_type = type(replacement)
+
+     # The types must match (with some exceptions)
+     if replacement_type == original_type:
+         return replacement
+
+     # Cast replacement from from_type to to_type if the replacement and original
+     # types match from_type and to_type
+     def conditional_cast(from_type, to_type):
+         if replacement_type == from_type and original_type == to_type:
+             return True, to_type(replacement)
+         else:
+             return False, None
+
+     # Conditionally casts
+     # list <-> tuple
+     casts = [(tuple, list), (list, tuple)]
+     # For py2: allow converting from str (bytes) to a unicode string
+     try:
+         casts.append((str, unicode))  # noqa: F821
+     except Exception:
+         pass
+
+     for (from_type, to_type) in casts:
+         converted, converted_value = conditional_cast(from_type, to_type)
+         if converted:
+             return converted_value
+
+     raise ValueError(
+         "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
+         "key: {}".format(
+             original_type, replacement_type, original, replacement, full_key
+         )
+     )
+
+
+ def _assert_with_logging(cond, msg):
+     if not cond:
+         logger.debug(msg)
+     assert cond, msg
+
+
+ def _load_module_from_file(name, filename):
+     if _PY2:
+         module = imp.load_source(name, filename)
+     else:
+         spec = importlib.util.spec_from_file_location(name, filename)
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)
+     return module
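This file is the standard YACS CfgNode vendored for the config system above; a minimal sketch of the API it provides (values here are illustrative only):

    from utils.yacs import CfgNode

    base = CfgNode(new_allowed=True)
    base.lr = 0.01
    override = CfgNode.load_cfg("lr: 0.02\nname: demo")  # parse a YAML string
    base.merge_from_other_cfg(override)  # 'name' is accepted because new_allowed=True
    base.freeze()
    print(base.lr, base.name)  # 0.02 demo
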