import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass(eq=False)
class VideoData:
    """
    Dataclass for storing video tracks data.
    """

    video: torch.Tensor  # B, S, C, H, W
    trajs: torch.Tensor  # B, S, N, 2
    visibs: torch.Tensor  # B, S, N
    # optional data
    valids: Optional[torch.Tensor] = None  # B, S, N
    hards: Optional[torch.Tensor] = None  # B, S, N
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W
    seq_name: Optional[str] = None
    dname: Optional[str] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format
    transforms: Optional[Dict[str, Any]] = None
    aug_video: Optional[torch.Tensor] = None
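
# Illustrative construction of a single sample (a sketch; the concrete shapes
# are assumptions, and per-sample tensors typically omit the leading batch
# dim B, which the collate functions below add via `torch.stack`):
#
#   S, N, H, W = 8, 64, 256, 256
#   sample = VideoData(
#       video=torch.zeros(S, 3, H, W),
#       trajs=torch.zeros(S, N, 2),
#       visibs=torch.ones(S, N),
#       seq_name="demo",
#   )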


def collate_fn(batch):
    """
    Collate function for video tracks data.

    Stacks a list of `VideoData` samples into a single batched `VideoData`.
    Optional fields (`query_points`, `segmentation`) are stacked only when
    present on the first sample.
    """
    video = torch.stack([b.video for b in batch], dim=0)
    trajs = torch.stack([b.trajs for b in batch], dim=0)
    visibs = torch.stack([b.visibs for b in batch], dim=0)
    query_points = segmentation = None
    if batch[0].query_points is not None:
        query_points = torch.stack([b.query_points for b in batch], dim=0)
    if batch[0].segmentation is not None:
        segmentation = torch.stack([b.segmentation for b in batch], dim=0)
    seq_name = [b.seq_name for b in batch]
    dname = [b.dname for b in batch]
    return VideoData(
        video=video,
        trajs=trajs,
        visibs=visibs,
        segmentation=segmentation,
        seq_name=seq_name,
        dname=dname,
        query_points=query_points,
    )
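
# A minimal sketch of plugging `collate_fn` into a PyTorch DataLoader
# (`eval_dataset` and the batch size are assumptions; any dataset whose
# `__getitem__` returns a `VideoData` works):
#
#   from torch.utils.data import DataLoader
#
#   loader = DataLoader(eval_dataset, batch_size=4, collate_fn=collate_fn)
#   for batch in loader:
#       print(batch.video.shape)  # (4, S, C, H, W)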


def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.

    Unlike `collate_fn`, each element of `batch` is a `(VideoData, gotit)`
    tuple, where `gotit` flags whether the sample was successfully produced.
    """
    gotit = [gotit for _, gotit in batch]
    video = torch.stack([b.video for b, _ in batch], dim=0)
    trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
    visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
    valids = torch.stack([b.valids for b, _ in batch], dim=0)
    seq_name = [b.seq_name for b, _ in batch]
    dname = [b.dname for b, _ in batch]
    query_points = transforms = aug_video = hards = None
    if batch[0][0].query_points is not None:
        query_points = torch.stack([b.query_points for b, _ in batch], dim=0)
    if batch[0][0].hards is not None:
        hards = torch.stack([b.hards for b, _ in batch], dim=0)
    if batch[0][0].transforms is not None:
        transforms = [b.transforms for b, _ in batch]
    if batch[0][0].aug_video is not None:
        aug_video = torch.stack([b.aug_video for b, _ in batch], dim=0)
    return (
        VideoData(
            video=video,
            trajs=trajs,
            visibs=visibs,
            valids=valids,
            hards=hards,
            seq_name=seq_name,
            dname=dname,
            query_points=query_points,
            aug_video=aug_video,
            transforms=transforms,
        ),
        gotit,
    )
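
# A minimal sketch for the training path (an assumption about the dataset
# contract: `__getitem__` must return a `(VideoData, gotit)` tuple):
#
#   from torch.utils.data import DataLoader
#
#   train_loader = DataLoader(
#       train_dataset, batch_size=4, collate_fn=collate_fn_train
#   )
#   for batch, gotit in train_loader:
#       if not all(gotit):
#           continue  # skip batches containing failed samples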


def try_to_cuda(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cuda device.

    Args:
        t: Input.

    Returns:
        t_cuda: `t` moved to a cuda device, if supported.
    """
    try:
        # Tensors are cast to float32 before the device move; non-tensor
        # inputs (strings, lists, None) raise AttributeError and are
        # returned unchanged.
        t = t.float().cuda()
    except AttributeError:
        pass
    return t


def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda inplace if supported.

    Args:
        obj: Input dataclass.

    Returns:
        obj: `obj` with its fields moved to a cuda device, if supported.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj
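

# A minimal runnable sketch (illustrative, not part of the original module):
# build a tiny batch with `collate_fn` and move it to the GPU in place with
# `dataclass_to_cuda_`. The shapes are assumptions chosen for the demo.
if __name__ == "__main__":
    S, N, H, W = 8, 16, 64, 64
    sample = VideoData(
        video=torch.zeros(S, 3, H, W),
        trajs=torch.zeros(S, N, 2),
        visibs=torch.ones(S, N),
        seq_name="demo",
        dname="demo_dataset",
    )
    batch = collate_fn([sample, sample])
    print(batch.video.shape)  # torch.Size([2, 8, 3, 64, 64])
    if torch.cuda.is_available():
        batch = dataclass_to_cuda_(batch)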