Spaces:
Running
on
Zero
Running
on
Zero
Update registry.py
Browse files- registry.py +33 -43
registry.py
CHANGED
@@ -1,43 +1,33 @@
|
|
1 |
-
import logging
|
2 |
-
|
3 |
-
# Configure Logger
|
4 |
-
logger = logging.getLogger(__name__)
|
5 |
-
|
6 |
-
def get_model(task: str, model_key: str, device="cpu"):
|
7 |
-
"""
|
8 |
-
Dynamically retrieves the model instance based on the task and model_key.
|
9 |
-
|
10 |
-
Args:
|
11 |
-
task (str): One of "detection", "segmentation", or "depth".
|
12 |
-
model_key (str): Model identifier or variant.
|
13 |
-
device (str): Device to run inference on ("cpu" or "cuda").
|
14 |
-
|
15 |
-
Returns:
|
16 |
-
object:
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
""
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
return DepthEstimator(model_key=model_key, device=device)
|
35 |
-
|
36 |
-
else:
|
37 |
-
error_msg = f"Unsupported task '{task}'. Valid options are: 'detection', 'segmentation', 'depth'."
|
38 |
-
logger.error(error_msg)
|
39 |
-
raise ValueError(error_msg)
|
40 |
-
|
41 |
-
except Exception as e:
|
42 |
-
logger.error(f"Error while loading model '{model_key}' for task '{task}': {e}")
|
43 |
-
raise
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
# Configure Logger
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
|
6 |
+
def get_model(task: str, model_key: str, device="cpu"):
|
7 |
+
"""
|
8 |
+
Dynamically retrieves the model instance based on the task and model_key.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
task (str): One of "detection", "segmentation", or "depth".
|
12 |
+
model_key (str): Model identifier or variant.
|
13 |
+
device (str): Device to run inference on ("cpu" or "cuda").
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
object: Uninitialized model instance.
|
17 |
+
"""
|
18 |
+
logger.info(f"Preparing model wrapper '{model_key}' for task '{task}' on device '{device}'")
|
19 |
+
|
20 |
+
if task == "detection":
|
21 |
+
from models.detection.detector import ObjectDetector
|
22 |
+
return ObjectDetector(model_key=model_key, device=device)
|
23 |
+
|
24 |
+
elif task == "segmentation":
|
25 |
+
from models.segmentation.segmenter import Segmenter
|
26 |
+
return Segmenter(model_key=model_key, device=device)
|
27 |
+
|
28 |
+
elif task == "depth":
|
29 |
+
from models.depth.depth_estimator import DepthEstimator
|
30 |
+
return DepthEstimator(model_key=model_key, device=device)
|
31 |
+
|
32 |
+
else:
|
33 |
+
raise ValueError(f"Unsupported task '{task}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|