import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class VideoData:
    """
    Dataclass for storing video tracks data.
    """

    video: torch.Tensor  # B, S, C, H, W
    trajs: torch.Tensor  # B, S, N, 2
    visibs: torch.Tensor  # B, S, N
    # optional data
    valids: Optional[torch.Tensor] = None  # B, S, N
    hards: Optional[torch.Tensor] = None  # B, S, N
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W
    seq_name: Optional[str] = None
    dname: Optional[str] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format
    transforms: Optional[Dict[str, Any]] = None
    aug_video: Optional[torch.Tensor] = None


def collate_fn(batch):
    """
    Collate function for video tracks data.
    """
    # Stack the required per-sample (S, ...) tensors into batched (B, S, ...) tensors.
    video = torch.stack([b.video for b in batch], dim=0)
    trajs = torch.stack([b.trajs for b in batch], dim=0)
    visibs = torch.stack([b.visibs for b in batch], dim=0)
    # Optional fields are collated only when the first sample provides them.
    query_points = segmentation = None
    if batch[0].query_points is not None:
        query_points = torch.stack([b.query_points for b in batch], dim=0)
    if batch[0].segmentation is not None:
        segmentation = torch.stack([b.segmentation for b in batch], dim=0)
    seq_name = [b.seq_name for b in batch]
    dname = [b.dname for b in batch]
    return VideoData(
        video=video,
        trajs=trajs,
        visibs=visibs,
        segmentation=segmentation,
        seq_name=seq_name,
        dname=dname,
        query_points=query_points,
    )
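
# Usage sketch (an assumption, not part of the original file): `collate_fn` is meant to
# be passed to a `torch.utils.data.DataLoader` whose dataset yields `VideoData` samples
# with unbatched (S, ...) tensors, e.g.
#
#     eval_loader = torch.utils.data.DataLoader(
#         eval_dataset,  # hypothetical dataset returning VideoData samples
#         batch_size=1,
#         collate_fn=collate_fn,
#     )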


def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.

    Each element of `batch` is expected to be a `(VideoData, gotit)` pair.
    """
    gotit = [gotit for _, gotit in batch]
    # Stack the required per-sample tensors into batched (B, S, ...) tensors.
    video = torch.stack([b.video for b, _ in batch], dim=0)
    trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
    visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
    valids = torch.stack([b.valids for b, _ in batch], dim=0)
    seq_name = [b.seq_name for b, _ in batch]
    dname = [b.dname for b, _ in batch]
    # Optional fields are collated only when the first sample provides them.
    query_points = transforms = aug_video = hards = None
    if batch[0][0].query_points is not None:
        query_points = torch.stack([b.query_points for b, _ in batch], dim=0)
    if batch[0][0].hards is not None:
        hards = torch.stack([b.hards for b, _ in batch], dim=0)
    if batch[0][0].transforms is not None:
        transforms = [b.transforms for b, _ in batch]
    if batch[0][0].aug_video is not None:
        aug_video = torch.stack([b.aug_video for b, _ in batch], dim=0)
    return (
        VideoData(
            video=video,
            trajs=trajs,
            visibs=visibs,
            valids=valids,
            hards=hards,
            seq_name=seq_name,
            dname=dname,
            query_points=query_points,
            aug_video=aug_video,
            transforms=transforms,
        ),
        gotit,
    )
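
# Usage sketch (an assumption, not part of the original file): the training dataset is
# expected to yield `(VideoData, gotit)` pairs, where `gotit` presumably flags whether
# the sample was loaded successfully, e.g.
#
#     train_loader = torch.utils.data.DataLoader(
#         train_dataset,  # hypothetical dataset returning (VideoData, bool) pairs
#         batch_size=4,
#         collate_fn=collate_fn_train,
#     )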


def try_to_cuda(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cuda device.

    Args:
        t: Input.

    Returns:
        t_cuda: `t` cast to float and moved to a cuda device, if supported.
    """
    try:
        t = t.float().cuda()
    except AttributeError:
        pass
    return t


def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda inplace if supported.

    Args:
        obj: Input dataclass.

    Returns:
        obj: the same dataclass with its fields moved to a cuda device, if supported.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj
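

if __name__ == "__main__":
    # Minimal self-contained sketch (not part of the original file): build two synthetic
    # samples with illustrative shapes (S=8 frames, N=16 tracks, 64x64 RGB frames),
    # collate them into a batch, and move the batch to the GPU when one is available.
    samples = [
        VideoData(
            video=torch.zeros(8, 3, 64, 64),
            trajs=torch.zeros(8, 16, 2),
            visibs=torch.ones(8, 16),
            seq_name=f"seq_{i}",
            dname="synthetic",
        )
        for i in range(2)
    ]
    batch = collate_fn(samples)
    print(batch.video.shape)  # torch.Size([2, 8, 3, 64, 64])
    if torch.cuda.is_available():
        batch = dataclass_to_cuda_(batch)
        print(batch.video.device)  # cuda:0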