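# NOTE: the next line preserves web-page residue that was captured with this file
# (apparently an avatar alt text, a commit message and a commit hash):
# "aharley's picture" / "added basics" / "6d95ea1"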
import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional, Dict
@dataclass(eq=False)
class VideoData:
    """
    Container for a video clip together with its point-track annotations.

    Instances are built per sample and also reused as the batched result of
    the collate functions in this file. `eq=False` means no `__eq__` is
    generated, so instances compare (and hash) by identity rather than by
    comparing tensor contents.
    """

    # --- required tensors -------------------------------------------------
    video: torch.Tensor  # B, S, C, H, W
    trajs: torch.Tensor  # B, S, N, 2
    visibs: torch.Tensor  # B, S, N

    # --- optional per-point annotations -----------------------------------
    valids: Optional[torch.Tensor] = None  # B, S, N
    hards: Optional[torch.Tensor] = None  # B, S, N

    # --- optional per-frame annotation ------------------------------------
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W

    # --- optional metadata ------------------------------------------------
    # After collation, `seq_name` and `dname` hold a list with one entry per
    # sample (the collate functions in this file build them that way).
    seq_name: Optional[str] = None
    dname: Optional[str] = None

    # --- optional evaluation / augmentation extras ------------------------
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format
    # After training collation this holds a list of per-sample values.
    transforms: Optional[Dict[str, Any]] = None
    aug_video: Optional[torch.Tensor] = None
def collate_fn(batch):
    """
    Merge a sequence of VideoData samples into a single batched VideoData.

    `video`, `trajs` and `visibs` are always stacked along a new leading
    dimension. `query_points` and `segmentation` are stacked only when the
    first sample provides them (the first sample decides for the whole
    batch); otherwise they stay None. `seq_name` and `dname` are gathered
    into plain lists, one entry per sample. Other VideoData fields are left
    at their defaults.
    """

    def stack_field(name):
        # Stack the same attribute of every sample along dim 0.
        return torch.stack([getattr(item, name) for item in batch], dim=0)

    # Required tensors first, in the order video -> trajs -> visibs.
    collated = {key: stack_field(key) for key in ("video", "trajs", "visibs")}

    # Optional tensors: presence is determined from the first sample only.
    head = batch[0]
    for key in ("query_points", "segmentation"):
        collated[key] = stack_field(key) if getattr(head, key) is not None else None

    # Non-tensor metadata is kept as per-sample lists.
    collated["seq_name"] = [item.seq_name for item in batch]
    collated["dname"] = [item.dname for item in batch]

    return VideoData(**collated)
def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.

    Args:
        batch: Sequence of `(sample, gotit)` pairs, where `sample` is a
            VideoData and `gotit` is passed through untouched.

    Returns:
        A tuple `(data, gotit)` where `data` is a VideoData whose tensor
        fields are stacked along a new leading dimension and `gotit` is the
        list of per-sample second elements.

    `video`, `trajs` and `visibs` are always stacked. `valids`, `hards`,
    `query_points` and `aug_video` are stacked only when the first sample
    provides them, and are None otherwise. `transforms`, `seq_name` and
    `dname` are collected as per-sample lists.
    """
    samples = [sample for sample, _ in batch]
    gotit = [flag for _, flag in batch]

    def _stack(name):
        # Stack the same attribute of every sample along dim 0.
        return torch.stack([getattr(s, name) for s in samples], dim=0)

    def _stack_optional(name):
        # The first sample decides whether the field is present in the batch.
        if getattr(samples[0], name) is None:
            return None
        return _stack(name)

    # Required tensors are stacked first so an empty batch fails in
    # torch.stack exactly as before.
    video = _stack("video")
    trajs = _stack("trajs")
    visibs = _stack("visibs")

    # `valids` is declared Optional on VideoData, so it is guarded like the
    # other optional fields instead of being stacked unconditionally (which
    # raised a TypeError from torch.stack when it was None).
    valids = _stack_optional("valids")

    seq_name = [s.seq_name for s in samples]
    dname = [s.dname for s in samples]

    query_points = _stack_optional("query_points")
    hards = _stack_optional("hards")

    # `transforms` is not a tensor, so it is gathered into a list.
    transforms = None
    if samples[0].transforms is not None:
        transforms = [s.transforms for s in samples]

    aug_video = _stack_optional("aug_video")

    return (
        VideoData(
            video=video,
            trajs=trajs,
            visibs=visibs,
            valids=valids,
            hards=hards,
            seq_name=seq_name,
            dname=dname,
            query_points=query_points,
            aug_video=aug_video,
            transforms=transforms,
        ),
        gotit,
    )
def try_to_cuda(t: Any) -> Any:
    """
    Best-effort transfer of `t` to a cuda device.

    The value is converted with `.float()` and then moved with `.cuda()`.
    Objects that do not offer those methods (an AttributeError is raised)
    are handed back untouched; any other exception propagates to the caller.

    Args:
        t: Input.
    Returns:
        The float-cast cuda copy of `t`, or `t` itself when it lacks the
        required methods.
    """
    try:
        return t.float().cuda()
    except AttributeError:
        # Not tensor-like (e.g. str, list, None): leave it as it is.
        return t
def dataclass_to_cuda_(obj):
    """
    Apply `try_to_cuda` to every field of a dataclass instance, in place.

    Args:
        obj: Input dataclass instance; its fields are overwritten.
    Returns:
        The same `obj`, with each field replaced by the result of
        `try_to_cuda` on its previous value.
    """
    for field_info in dataclasses.fields(obj):
        attr = field_info.name
        current = getattr(obj, attr)
        setattr(obj, attr, try_to_cuda(current))
    return obj