File size: 1,288 Bytes
d7e8ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd1cf2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def get_model(task: str, model_key: str, device: str = "cpu"):
    """
    Build and return the model wrapper that handles ``task``.

    The wrapper class is imported inside the function, so only the module
    belonging to the requested task is imported by a given call.

    Args:
        task (str): One of "detection", "segmentation", or "depth".
        model_key (str): Model identifier or variant; forwarded unchanged to
            the wrapper constructor as ``model_key``.
        device (str): Forwarded unchanged to the wrapper constructor as
            ``device``. Defaults to "cpu".

    Returns:
        object: A newly constructed ``ObjectDetector``, ``Segmenter`` or
        ``DepthEstimator``, depending on ``task``.
        NOTE(review): whether weights are loaded at construction time depends
        on the wrapper class — confirm there.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
        Exception: Any error raised while importing or constructing the
            wrapper is logged (with traceback) and re-raised unchanged.
    """
    logger.info(
        "Preparing model wrapper '%s' for task '%s' on device '%s'",
        model_key,
        task,
        device,
    )

    try:
        # Imports stay inside the branches so that requesting one task does
        # not import the modules of the other tasks.
        if task == "detection":
            from models.detection.detector import ObjectDetector as wrapper_cls
        elif task == "segmentation":
            from models.segmentation.segmenter import Segmenter as wrapper_cls
        elif task == "depth":
            from models.depth.depth_estimator import DepthEstimator as wrapper_cls
        else:
            raise ValueError(f"Unsupported task '{task}'")

        return wrapper_cls(model_key=model_key, device=device)
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback,
        # which a plain logger.error call would drop.
        logger.exception(
            "Error loading model '%s' for task '%s': %s", model_key, task, e
        )
        raise