# Spaces: Running on L40S
import logging
from typing import Any, Iterable, List, Optional, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model
# Length that normalize_bbox assigns to the largest bounding-box extent by default.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(
    bounding_box_xyz: Iterable[float],
    *,
    max_size: float = BOUNDING_BOX_MAX_SIZE,
) -> List[float]:
    """
    Rescale bounding-box extents so that the largest one equals ``max_size``.

    Every extent is multiplied by the same factor, so ratios between the
    axes are preserved.

    Args:
        bounding_box_xyz: Per-axis extents (e.g. ``(x, y, z)``). Any iterable
            of numbers is accepted; it is consumed exactly once.
        max_size: Value assigned to the largest extent. Defaults to
            ``BOUNDING_BOX_MAX_SIZE``.

    Returns:
        A new list with the rescaled extents, in input order.

    Raises:
        ValueError: If ``bounding_box_xyz`` is empty, or if its largest extent
            is not positive (the scale factor would be undefined or would flip
            the sign of every extent).
    """
    # Materialize once: both max() and the comprehension iterate the input,
    # so a one-shot iterator (e.g. a generator) would otherwise be exhausted
    # by max() and the result would silently be empty.
    extents = tuple(bounding_box_xyz)
    if not extents:
        raise ValueError("bounding_box_xyz must contain at least one extent")
    max_l = max(extents)
    if max_l <= 0:
        raise ValueError(
            f"largest bounding-box extent must be positive, got {max_l!r}"
        )
    return [max_size * elem / max_l for elem in extents]
def load_config(cfg_path: str) -> DictConfig:
    """
    Load a configuration file with OmegaConf and resolve its interpolations.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        DictConfig: The loaded configuration, with interpolations resolved.

    Raises:
        AssertionError: If the loaded configuration is not an instance of
            DictConfig (e.g. the file's top level is a list rather than a
            mapping).
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # An explicit check instead of a bare ``assert`` keeps this contract in
    # force under ``python -O``; AssertionError is still raised because it is
    # the documented exception for callers.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Config '{cfg_path}' must load as a DictConfig, "
            f"got {type(cfg).__name__}"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Turn a plain config mapping into a structured OmegaConf config.

    Each entry of ``cfg`` is passed to ``cfg_type`` as a keyword argument,
    and the resulting instance is wrapped with ``OmegaConf.structured``.

    Args:
        cfg_type (Any): Callable used to build the instance (presumably a
            dataclass type; confirm against callers).
        cfg (DictConfig): Mapping whose keys become keyword arguments of
            ``cfg_type``.

    Returns:
        Any: The structured config returned by ``OmegaConf.structured``.
    """
    return OmegaConf.structured(cfg_type(**cfg))
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the ``.safetensors`` checkpoint file. Any
            ``os.PathLike`` (e.g. ``pathlib.Path``) is accepted as well as
            ``str``.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    import os

    # Normalize PathLike inputs so the suffix check below works for them too.
    path_str = os.fspath(ckpt_path)
    # An explicit check instead of a bare ``assert`` keeps the validation in
    # force under ``python -O``; AssertionError is kept for existing callers.
    if not path_str.endswith(".safetensors"):
        raise AssertionError(
            f"Checkpoint path '{ckpt_path}' is not a safetensors file"
        )
    load_model(model, path_str)
def select_device() -> Any: | |
""" | |
Selects the appropriate PyTorch device for tensor allocation. | |
Returns: | |
Any: The `torch.device` object. | |
""" | |
return torch.device( | |
"cuda" | |
if torch.cuda.is_available() | |
else "mps" | |
if torch.backends.mps.is_available() | |
else "cpu" | |
) | |