# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from PIL import Image


def convert_to_numpy(image):
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    elif isinstance(image, np.ndarray):
        image = image.copy()
    else:
        raise TypeError(
            f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor '
            'and PIL.Image are supported.')
    return image
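
# Example (illustrative shapes): every supported input type comes back as an ndarray.
#   convert_to_numpy(Image.new('RGB', (64, 64)))   # -> ndarray, shape (64, 64, 3)
#   convert_to_numpy(torch.zeros(3, 64, 64))       # -> ndarray, shape (3, 64, 64)
#   convert_to_numpy(np.zeros((64, 64, 3)))        # -> independent copy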


class DepthV2Annotator:
    def __init__(self, cfg, device=None):
        from .dpt import DepthAnythingV2

        # Model configurations for the Depth Anything V2 encoder variants.
        self.model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        # Get the model variant from the config, defaulting to 'vitl'.
        model_variant = cfg.get('MODEL_VARIANT', 'vitl')
        if model_variant not in self.model_configs:
            raise ValueError(
                f"Invalid model variant '{model_variant}'. "
                f"Must be one of: {list(self.model_configs.keys())}")
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # Build the model for the selected variant and load its weights.
        config = self.model_configs[model_variant]
        self.model = DepthAnythingV2(
            encoder=config['encoder'],
            features=config['features'],
            out_channels=config['out_channels']
        ).to(self.device)
        self.model.load_state_dict(
            torch.load(
                pretrained_model,
                map_location=self.device,
                weights_only=True
            )
        )
        self.model.eval()

    def forward(self, image):
        image = convert_to_numpy(image)
        depth = self.model.infer_image(image)

        # Normalize depth to [0, 1]; guard against a constant depth map,
        # which would otherwise divide by zero.
        depth_pt = depth.copy()
        depth_pt -= np.min(depth_pt)
        depth_pt /= max(np.max(depth_pt), 1e-8)

        # Scale to uint8 and replicate to three channels for an RGB-like map.
        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
        depth_image = depth_image[..., np.newaxis]
        depth_image = np.repeat(depth_image, 3, axis=2)
        return depth_image
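
# Single-image sketch (the checkpoint path and input image below are
# hypothetical; cfg keys follow this file):
#   annotator = DepthV2Annotator({'PRETRAINED_MODEL': 'depth_anything_v2_vitl.pth'})
#   depth_rgb = annotator.forward(Image.open('photo.jpg'))  # uint8, (H, W, 3)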


class DepthV2VideoAnnotator(DepthV2Annotator):
    def forward(self, frames):
        ret_frames = []
        for frame in frames:
            # convert_to_numpy in the parent forward handles PIL, tensor and
            # ndarray frames, so no extra np.array() conversion is needed.
            anno_frame = super().forward(frame)
            ret_frames.append(anno_frame)
        return ret_frames
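

# Minimal video sketch (assumptions: the module is run inside its package, e.g.
# via `python -m ...`, so the relative `.dpt` import resolves, and the
# checkpoint path below is hypothetical):
if __name__ == '__main__':
    cfg = {
        'MODEL_VARIANT': 'vitl',
        'PRETRAINED_MODEL': 'models/depth_anything_v2_vitl.pth',  # hypothetical path
    }
    annotator = DepthV2VideoAnnotator(cfg)
    frames = [Image.new('RGB', (640, 480)) for _ in range(4)]  # dummy frames
    depth_frames = annotator.forward(frames)
    print(len(depth_frames), depth_frames[0].shape)  # 4 (480, 640, 3)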