File size: 3,586 Bytes
6d95ea1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F


@dataclass(eq=False)
class VideoData:
    """
    Container for video frames together with their point tracks.

    The same class is used both for a single dataset sample and for a
    collated batch. The shape comments below describe the batched layout:
    the collate functions in this module stack per-sample tensors with
    ``torch.stack(..., dim=0)``, which adds the leading ``B`` axis. A single,
    un-collated sample therefore presumably lacks that ``B`` axis (TODO
    confirm against the datasets).

    ``eq=False`` leaves ``__eq__``/``__hash__`` inherited from ``object``,
    so instances compare and hash by identity. This was presumably chosen to
    avoid a generated ``__eq__`` that compares tensor fields.
    """

    video: torch.Tensor  # B, S, C, H, W
    trajs: torch.Tensor  # B, S, N, 2
    visibs: torch.Tensor  # B, S, N
    # optional data
    valids: Optional[torch.Tensor] = None  # B, S, N
    hards: Optional[torch.Tensor] = None  # B, S, N
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W
    # A single string per sample. The collate functions gather these into a
    # list with one entry per sample, hence the List[str] alternative.
    seq_name: Optional[Union[str, List[str]]] = None
    dname: Optional[Union[str, List[str]]] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format
    # One dict per sample. collate_fn_train gathers them into a list of dicts.
    transforms: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None
    # Stacked along dim 0 by collate_fn_train, like ``video``.
    aug_video: Optional[torch.Tensor] = None


def collate_fn(batch):
    """
    Merge a sequence of per-sample ``VideoData`` objects into one batch.

    ``video``, ``trajs`` and ``visibs`` are always stacked along a new leading
    dimension. ``query_points`` and ``segmentation`` are stacked only when the
    first sample carries them; otherwise they stay ``None``. ``seq_name`` and
    ``dname`` are gathered into plain lists. ``valids``, ``hards``,
    ``transforms`` and ``aug_video`` are not carried over, so they keep the
    ``VideoData`` default of ``None``.

    Args:
        batch: Sequence of ``VideoData`` samples.

    Returns:
        A single ``VideoData`` holding the batched values.
    """

    def _stack(field_name):
        # dim=0: the new leading axis indexes the sample within the batch.
        return torch.stack([getattr(item, field_name) for item in batch], dim=0)

    def _stack_if_present(field_name):
        # Only the first sample is inspected; all samples are assumed to agree.
        if getattr(batch[0], field_name) is None:
            return None
        return _stack(field_name)

    stacked_video = _stack("video")
    stacked_trajs = _stack("trajs")
    stacked_visibs = _stack("visibs")
    stacked_queries = _stack_if_present("query_points")
    stacked_segmentation = _stack_if_present("segmentation")
    sequence_names = [item.seq_name for item in batch]
    dataset_names = [item.dname for item in batch]

    return VideoData(
        video=stacked_video,
        trajs=stacked_trajs,
        visibs=stacked_visibs,
        segmentation=stacked_segmentation,
        seq_name=sequence_names,
        dname=dataset_names,
        query_points=stacked_queries,
    )


def collate_fn_train(batch):
    """
    Training-time collate: merge ``(VideoData, gotit)`` pairs into one batch.

    Each element of ``batch`` must be a 2-tuple. The first item is a
    per-sample ``VideoData``. The second item is passed through untouched in
    the returned list; its meaning is defined by the producing dataset
    (presumably a success flag, to be confirmed).

    ``video``, ``trajs``, ``visibs`` and ``valids`` are always stacked along a
    new leading dimension, so every sample must provide ``valids``.
    ``query_points``, ``hards`` and ``aug_video`` are stacked only when the
    first sample has them. ``transforms`` is likewise gathered into a list
    only when the first sample has it. ``seq_name`` and ``dname`` become
    lists. ``segmentation`` is not carried over.

    Args:
        batch: Sequence of ``(VideoData, gotit)`` pairs.

    Returns:
        Tuple ``(batched VideoData, list of gotit values in batch order)``.
    """
    gotit = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def _stack(field_name):
        # dim=0: the new leading axis indexes the sample within the batch.
        return torch.stack([getattr(s, field_name) for s in samples], dim=0)

    def _stack_if_present(field_name):
        # Only the first sample is inspected; all samples are assumed to agree.
        if getattr(samples[0], field_name) is None:
            return None
        return _stack(field_name)

    video = _stack("video")
    trajs = _stack("trajs")
    visibs = _stack("visibs")
    valids = _stack("valids")
    seq_name = [s.seq_name for s in samples]
    dname = [s.dname for s in samples]
    query_points = _stack_if_present("query_points")
    hards = _stack_if_present("hards")
    if samples[0].transforms is None:
        transforms = None
    else:
        # Per-sample transform records are kept as a list, not stacked.
        transforms = [s.transforms for s in samples]
    aug_video = _stack_if_present("aug_video")

    batched = VideoData(
        video=video,
        trajs=trajs,
        visibs=visibs,
        valids=valids,
        hards=hards,
        seq_name=seq_name,
        dname=dname,
        query_points=query_points,
        aug_video=aug_video,
        transforms=transforms,
    )
    return batched, gotit


def try_to_cuda(t: Any) -> Any:
    """
    Best-effort conversion of ``t`` with ``t.float().cuda()``.

    For a ``torch.Tensor`` this returns a float32 tensor on the current CUDA
    device. Note the dtype cast: bool/integer tensors come back as float.
    Values lacking ``.float()``, or whose ``.float()`` result lacks
    ``.cuda()`` (``None``, ``str``, ``list``, ``dict``, ...), are returned
    unchanged.

    Only ``AttributeError`` is swallowed. Any other exception, such as the
    error torch raises when no CUDA device is available, propagates.

    Args:
        t: Value to convert.

    Returns:
        ``t.float().cuda()`` when supported, otherwise ``t`` itself.
    """
    try:
        return t.float().cuda()
    except AttributeError:
        return t


def dataclass_to_cuda_(obj):
    """
    Pass every field of a dataclass instance through ``try_to_cuda``, in place.

    Fields whose values support ``.float().cuda()`` (e.g. tensors) are
    replaced by ``value.float().cuda()``. This means tensors are also cast to
    float32, not just moved. Values without those methods (``None``, strings,
    lists, dicts) are left as they are.

    Args:
        obj: Dataclass instance to update. It must not be frozen, because
            each field is reassigned with ``setattr``.

    Returns:
        The same ``obj``, mutated in place and returned for convenience.

    Raises:
        TypeError: If ``obj`` is not a dataclass (raised by
            ``dataclasses.fields``).
    """
    for field in dataclasses.fields(obj):
        current = getattr(obj, field.name)
        setattr(obj, field.name, try_to_cuda(current))
    return obj