File size: 713 Bytes
3beba17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
"""Model loading and caching."""

import sys
import types

from caliby import CalibyModel

# Process-wide cache of loaded models, keyed by variant name; filled lazily by get_model().
MODELS: dict[str, CalibyModel] = dict()


def get_model(variant: str, device: str) -> CalibyModel:
    """Return the CalibyModel for ``variant``, loading it on first request.

    Loaded models are memoized in the module-level ``MODELS`` dict, keyed by
    variant name only. ``device`` is therefore consulted solely on the first
    load of a given variant; later calls return the cached instance no matter
    which ``device`` they pass.

    Args:
        variant: Model variant name forwarded to ``caliby.load_model``.
        device: Device forwarded to ``caliby.load_model`` on the first load.

    Returns:
        The cached (or freshly loaded) CalibyModel for ``variant``.
    """
    # Fast path: this variant has already been loaded.
    if variant in MODELS:
        return MODELS[variant]

    # ZeroGPU's @spaces.GPU decorator can drop sys.modules["__main__"], yet
    # Lightning's load_from_checkpoint calls inspect.stack(), which needs it.
    # Install an empty stand-in module only when the real one is missing.
    sys.modules.setdefault("__main__", types.ModuleType("__main__"))

    # Deferred import: only pay for it when a model actually has to be loaded.
    from caliby import load_model

    MODELS[variant] = load_model(variant, device=device)
    return MODELS[variant]