import pathlib
import sys
import urllib.request

import torch

ROOT = pathlib.Path(__file__).resolve().parents[3]
MODEL_DIR = ROOT / "models"
DEPTH_REPO = MODEL_DIR / "Depth-Anything-V2"
DEPTH_W = ROOT / "weights" / "depth_anything_v2" / "depth_anything_v2_vits.pth"
SAM_REPO = MODEL_DIR / "segment-anything"
SAM_W = ROOT / "weights" / "segment-anything" / "sam_vit_b_01ec64.pth"

# The vendored model repos must be on sys.path before their imports below.
sys.path.insert(0, str(DEPTH_REPO))
sys.path.insert(0, str(SAM_REPO))

from depth_anything_v2.dpt import DepthAnythingV2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

def _download(url: str, dest: pathlib.Path) -> None:
    """Download `url` to `dest` unless the file already exists."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    if not dest.exists():
        urllib.request.urlretrieve(url, dest)

def load_depth(device):
    """Load Depth Anything V2 (ViT-S), downloading the weights on first use."""
    _download(
        "https://huggingface.co/Depth-Anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
        DEPTH_W,
    )
    # Model configuration for the ViT-S encoder.
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    }
    net = DepthAnythingV2(**model_configs['vits'])
    state_dict = torch.load(DEPTH_W, map_location=device)
    net.load_state_dict(state_dict, strict=True)
    return net.to(device).eval()

def load_sam(device):
    """Load SAM (ViT-B) and wrap it in an automatic mask generator."""
    _download("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", SAM_W)
    sam = sam_model_registry["vit_b"](checkpoint=str(SAM_W)).to(device)
    return SamAutomaticMaskGenerator(sam)
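
# Usage sketch: load both models once at startup and reuse them. The device
# string here is an assumption; any torch device the weights fit on works.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     depth_net = load_depth(device)
#     mask_gen = load_sam(device)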

def predict_depth(net, img_rgb, device):
    """Predict a depth map for an (H, W, 3) uint8 RGB image."""
    from agent.io import pad_to_multiple_of_14, safe_resize

    # Depth Anything V2 expects spatial dims divisible by its 14-px patch size.
    img = pad_to_multiple_of_14(safe_resize(img_rgb))
    ten = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 255
    with torch.no_grad():
        d = net(ten.to(device)).squeeze().cpu().numpy()
    return d
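
# Example (sketch): running depth on an image from disk. Assumes Pillow and
# numpy are installed; `scene.jpg` is a hypothetical file.
#
#     import numpy as np
#     from PIL import Image
#     img_rgb = np.asarray(Image.open("scene.jpg").convert("RGB"))
#     depth = predict_depth(depth_net, img_rgb, device)  # 2-D float array at the processed size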

def generate_masks(mask_gen, img_rgb, use_goal1_approach=True):
    """Generate SAM masks using one of two preprocessing strategies."""
    from agent.io import safe_resize, pad_to_multiple_of_14

    if use_goal1_approach:
        # Goal1 approach: resize the long side to 1024 px with no padding;
        # this matches SAM's native input size and produces better masks.
        img_processed = safe_resize(img_rgb, long_side=1024)
    else:
        # Goal2 approach: reuse the depth preprocessing so the masks share the
        # depth map's geometry (consistent alignment, weaker SAM results).
        img_processed = pad_to_multiple_of_14(safe_resize(img_rgb))
    return mask_gen.generate(img_processed)
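
# Each entry returned by SamAutomaticMaskGenerator.generate() is a dict with
# keys including "segmentation" (H x W bool array), "area", and "bbox".
# Sketch of dropping small masks (the 1000-px threshold is an assumption):
#
#     masks = generate_masks(mask_gen, img_rgb)
#     large = [m for m in masks if m["area"] > 1000]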

def predict_depth_and_masks(depth_net, mask_gen, img_rgb, device, approach="goal1_sam"):
    """
    Generate depth and masks with different alignment strategies.

    Args:
        approach:
            - "goal1_sam": Goal1's SAM preprocessing (1024 px, no padding).
              Better SAM results, but depth and masks use different geometries.
            - "aligned": identical preprocessing for both models, giving
              pixel-perfect alignment between the depth map and the masks.

    Returns:
        (depth, masks, depth_img, sam_img) for "goal1_sam";
        (depth, masks, img_processed) for "aligned".
    """
    from agent.io import pad_to_multiple_of_14, safe_resize

    if approach == "goal1_sam":
        # Separate preprocessing per model for better SAM results.
        depth_img = pad_to_multiple_of_14(safe_resize(img_rgb))  # depth: 960 px resize + padding
        sam_img = safe_resize(img_rgb, long_side=1024)           # SAM: 1024 px, no padding
        # Generate the depth map.
        ten = torch.from_numpy(depth_img).permute(2, 0, 1).unsqueeze(0).float() / 255
        with torch.no_grad():
            depth = depth_net(ten.to(device)).squeeze().cpu().numpy()
        # Generate masks with the Goal1 preprocessing.
        masks = mask_gen.generate(sam_img)
        return depth, masks, depth_img, sam_img

    # "aligned" (Goal2): run both models on the same processed image.
    img_processed = pad_to_multiple_of_14(safe_resize(img_rgb))
    ten = torch.from_numpy(img_processed).permute(2, 0, 1).unsqueeze(0).float() / 255
    with torch.no_grad():
        depth = depth_net(ten.to(device)).squeeze().cpu().numpy()
    masks = mask_gen.generate(img_processed)
    return depth, masks, img_processed
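
# End-to-end sketch, guarded so nothing runs on import. `scene.jpg` and the
# device choice are assumptions for illustration.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    depth_net = load_depth(device)
    mask_gen = load_sam(device)
    img_rgb = np.asarray(Image.open("scene.jpg").convert("RGB"))

    # "goal1_sam" returns a 4-tuple (two preprocessed images) ...
    depth, masks, depth_img, sam_img = predict_depth_and_masks(
        depth_net, mask_gen, img_rgb, device, approach="goal1_sam"
    )
    # ... while "aligned" returns a 3-tuple (one shared image).
    depth_a, masks_a, img_processed = predict_depth_and_masks(
        depth_net, mask_gen, img_rgb, device, approach="aligned"
    )
    print(depth.shape, len(masks), depth_a.shape, len(masks_a))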