DurgaDeepak commited on
Commit
d7e8ce4
·
verified ·
1 Parent(s): 974d3d5

Update registry.py

Browse files
Files changed (1) hide show
  1. registry.py +33 -43
registry.py CHANGED
@@ -1,43 +1,33 @@
1
- import logging
2
-
3
- # Configure Logger
4
- logger = logging.getLogger(__name__)
5
-
6
- def get_model(task: str, model_key: str, device="cpu"):
7
- """
8
- Dynamically retrieves the model instance based on the task and model_key.
9
-
10
- Args:
11
- task (str): One of "detection", "segmentation", or "depth".
12
- model_key (str): Model identifier or variant.
13
- device (str): Device to run inference on ("cpu" or "cuda").
14
-
15
- Returns:
16
- object: Initialized model ready for inference.
17
-
18
- Raises:
19
- ValueError: If task is unsupported or model loading fails.
20
- """
21
- logger.info(f"Request received to load model '{model_key}' for task '{task}' on device '{device}'")
22
-
23
- try:
24
- if task == "detection":
25
- from models.detection.detector import ObjectDetector
26
- return ObjectDetector(model_key=model_key, device=device)
27
-
28
- elif task == "segmentation":
29
- from models.segmentation.segmenter import Segmenter
30
- return Segmenter(model_key=model_key, device=device)
31
-
32
- elif task == "depth":
33
- from models.depth.depth_estimator import DepthEstimator
34
- return DepthEstimator(model_key=model_key, device=device)
35
-
36
- else:
37
- error_msg = f"Unsupported task '{task}'. Valid options are: 'detection', 'segmentation', 'depth'."
38
- logger.error(error_msg)
39
- raise ValueError(error_msg)
40
-
41
- except Exception as e:
42
- logger.error(f"Error while loading model '{model_key}' for task '{task}': {e}")
43
- raise
 
1
+ import logging
2
+
3
+ # Configure Logger
4
+ logger = logging.getLogger(__name__)
5
+
6
+ def get_model(task: str, model_key: str, device="cpu"):
7
+ """
8
+ Dynamically retrieves the model instance based on the task and model_key.
9
+
10
+ Args:
11
+ task (str): One of "detection", "segmentation", or "depth".
12
+ model_key (str): Model identifier or variant.
13
+ device (str): Device to run inference on ("cpu" or "cuda").
14
+
15
+ Returns:
16
+ object: Uninitialized model instance.
17
+ """
18
+ logger.info(f"Preparing model wrapper '{model_key}' for task '{task}' on device '{device}'")
19
+
20
+ if task == "detection":
21
+ from models.detection.detector import ObjectDetector
22
+ return ObjectDetector(model_key=model_key, device=device)
23
+
24
+ elif task == "segmentation":
25
+ from models.segmentation.segmenter import Segmenter
26
+ return Segmenter(model_key=model_key, device=device)
27
+
28
+ elif task == "depth":
29
+ from models.depth.depth_estimator import DepthEstimator
30
+ return DepthEstimator(model_key=model_key, device=device)
31
+
32
+ else:
33
+ raise ValueError(f"Unsupported task '{task}'")