Spaces:
Running
on
L40S
Running
on
L40S
File size: 2,089 Bytes
616f571 ce16420 616f571 ce16420 616f571 ce16420 616f571 ce16420 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import logging
from typing import Any, Optional, Tuple
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model
# Length that the largest bounding-box extent is rescaled to by normalize_bbox.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]):
    """
    Rescale bounding-box extents so the largest one equals BOUNDING_BOX_MAX_SIZE.

    Every extent is multiplied by the same factor, so the proportions between
    the extents are preserved.

    Args:
        bounding_box_xyz: The extents of the bounding box. Any iterable of
            numbers is accepted, including one-shot iterators.

    Returns:
        list: The rescaled extents, in the same order as the input.

    Raises:
        ValueError: If no extents are given, or if the largest extent is not
            strictly positive (the scale factor would be undefined or would
            flip signs).
    """
    # Materialize once: the input is traversed twice (max + rescale), which
    # would silently yield an empty result for a generator.
    extents = tuple(bounding_box_xyz)
    if not extents:
        raise ValueError("bounding_box_xyz must contain at least one extent")

    max_l = max(extents)
    if max_l <= 0:
        raise ValueError(
            f"Largest bounding box extent must be positive, got {max_l}"
        )

    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in extents]
def load_config(cfg_path: str) -> Any:
    """
    Load and resolve a configuration file.

    The file is read with `OmegaConf.load` and then passed through
    `OmegaConf.resolve` before being returned.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration object (a DictConfig).

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)

    # Raise explicitly instead of using `assert`: assert statements are removed
    # when Python runs with -O, which would let a non-DictConfig through.
    # AssertionError is kept so existing callers see the same exception type.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Configuration loaded from '{cfg_path}' is not a DictConfig "
            f"(got {type(cfg).__name__})"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured configuration from a configuration dictionary.

    The entries of `cfg` are passed as keyword arguments to `cfg_type`, and the
    resulting object is wrapped with `OmegaConf.structured`.

    Args:
        cfg_type (Any): Callable (typically a config class) that accepts the
            entries of `cfg` as keyword arguments.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    # Instantiate the target type first so that unexpected or missing keys
    # surface from cfg_type itself.
    typed_instance = cfg_type(**cfg)
    return OmegaConf.structured(typed_instance)
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None

    Raises:
        AssertionError: If `ckpt_path` does not end with ".safetensors".
    """
    # Raise explicitly instead of using `assert`: assert statements are removed
    # when Python runs with -O, which would skip this check entirely.
    # AssertionError and the message are kept for backward compatibility.
    if not ckpt_path.endswith(".safetensors"):
        raise AssertionError(
            f"Checkpoint path '{ckpt_path}' is not a safetensors file"
        )

    load_model(model, ckpt_path)
def select_device() -> torch.device:
    """
    Selects the appropriate PyTorch device for tensor allocation.

    Preference order is CUDA, then MPS, then CPU as the fallback.

    Returns:
        torch.device: The selected `torch.device` object.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    # torch.backends.mps does not exist on older PyTorch builds; look it up
    # defensively so those builds fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
|