"""Loaders and inference helpers for Depth-Anything-V2 (monocular depth) and
Segment Anything (SAM) automatic mask generation."""
import pathlib, urllib.request, torch, sys

# Expected project layout: the Depth-Anything-V2 and segment-anything repos
# under models/, checkpoints under weights/ (downloaded on first use).
ROOT = pathlib.Path(__file__).resolve().parents[3]
MODEL_DIR = ROOT / "models"
DEPTH_REPO = MODEL_DIR / "Depth-Anything-V2"
DEPTH_W = ROOT / "weights" / "depth_anything_v2" / "depth_anything_v2_vits.pth"
SAM_REPO = MODEL_DIR / "segment-anything"
SAM_W    = ROOT / "weights" / "segment-anything" / "sam_vit_b_01ec64.pth"

# Make the local model repos importable.
sys.path.insert(0, str(DEPTH_REPO))
sys.path.insert(0, str(SAM_REPO))
from depth_anything_v2.dpt import DepthAnythingV2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

def _download(url: str, dest: pathlib.Path):
    dest.parent.mkdir(parents=True, exist_ok=True)
    if not dest.exists():
        urllib.request.urlretrieve(url, dest)

def load_depth(device):
    _download("https://huggingface.co/Depth-Anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth", DEPTH_W)
    
    # Model configuration for ViT-S
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    }
    
    net = DepthAnythingV2(**model_configs['vits'])
    state_dict = torch.load(DEPTH_W, map_location=device)
    net.load_state_dict(state_dict, strict=True)
    net = net.to(device).eval()
    return net

def load_sam(device):
    _download("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", SAM_W)
    sam = sam_model_registry["vit_b"](checkpoint=str(SAM_W)).to(device)
    return SamAutomaticMaskGenerator(sam)

def predict_depth(net, img_rgb, device):
    from agent.io import pad_to_multiple_of_14, safe_resize

    # Resize, pad to a multiple of 14 (the ViT patch size), and scale to [0, 1].
    img = pad_to_multiple_of_14(safe_resize(img_rgb))
    ten = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() / 255
    with torch.no_grad():
        d = net(ten.to(device)).squeeze().cpu().numpy()
    return d

def generate_masks(mask_gen, img_rgb, use_goal1_approach=True):
    from agent.io import safe_resize, pad_to_multiple_of_14
    
    if use_goal1_approach:
        # Goal1 approach: Just resize to 1024px, NO padding (this works!)
        img_processed = safe_resize(img_rgb, long_side=1024)
    else:
        # Goal2 approach: Use same preprocessing as depth (may misalign but consistent)
        img_processed = pad_to_multiple_of_14(safe_resize(img_rgb))
    
    return mask_gen.generate(img_processed)

def predict_depth_and_masks(depth_net, mask_gen, img_rgb, device, approach="goal1_sam"):
    """
    Generate depth and masks with different alignment strategies.
    
    Args:
        approach: 
            - "goal1_sam": Use Goal1's SAM approach (1024px, no padding) - better SAM results
            - "aligned": Use identical preprocessing for both - perfect alignment
    """
    from agent.io import pad_to_multiple_of_14, safe_resize
    
    if approach == "goal1_sam":
        # Goal1 approach: Different preprocessing for better SAM results
        depth_img = pad_to_multiple_of_14(safe_resize(img_rgb))  # For depth (960px + padding)
        sam_img = safe_resize(img_rgb, long_side=1024)           # For SAM (1024px, no padding)
        
        # Generate depth map
        ten = torch.from_numpy(depth_img).permute(2,0,1).unsqueeze(0).float() / 255
        with torch.no_grad():
            depth = depth_net(ten.to(device)).squeeze().cpu().numpy()
        
        # Generate masks with Goal1 approach
        masks = mask_gen.generate(sam_img)
        
        return depth, masks, depth_img, sam_img
        
    else:  # "aligned"
        # Goal2 approach: Same preprocessing for perfect alignment
        img_processed = pad_to_multiple_of_14(safe_resize(img_rgb))
        
        # Generate depth map
        ten = torch.from_numpy(img_processed).permute(2,0,1).unsqueeze(0).float() / 255
        with torch.no_grad():
            depth = depth_net(ten.to(device)).squeeze().cpu().numpy()
        
        # Generate masks on the same processed image
        masks = mask_gen.generate(img_processed)
        
        return depth, masks, img_processed
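
if __name__ == "__main__":
    # Minimal usage sketch: load both models, run the combined predictor on a
    # synthetic frame, and print the output shapes. Assumes the agent.io helpers
    # imported above are available and that the checkpoints can be downloaded.
    import numpy as np

    device = "cuda" if torch.cuda.is_available() else "cpu"
    depth_net = load_depth(device)
    mask_gen = load_sam(device)

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real RGB frame
    depth, masks, depth_img, sam_img = predict_depth_and_masks(
        depth_net, mask_gen, frame, device, approach="goal1_sam"
    )
    print("depth:", depth.shape, "masks:", len(masks),
          "depth input:", depth_img.shape, "sam input:", sam_img.shape)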