import torch

def lengths_to_mask(lengths):
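    """Turn a 1-D tensor of sequence lengths into a boolean padding mask.

    Returns a (batch, max_len) bool tensor that is True for valid positions
    and False for padded ones.
    """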
    max_len = max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask
    

def collate_tensors(batch):
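    """Zero-pad a list of same-rank tensors to a common shape and stack them.

    The result has shape (len(batch), *max_size), where max_size is the
    element-wise maximum over the input shapes.
    """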
    dims = batch[0].dim()
    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
    size = (len(batch),) + tuple(max_size)
    canvas = batch[0].new_zeros(size=size)
    for i, b in enumerate(batch):
        # Take a view of the zero canvas narrowed to b's shape along every
        # dimension, then add b in place (equivalent to copying, since the
        # canvas is zero-initialized).
        sub_tensor = canvas[i]
        for d in range(dims):
            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
        sub_tensor.add_(b)
    return canvas


def collate(batch):
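    """Collate a list of dataset samples (dicts) into one padded batch dict.

    Each sample is expected to provide 'inp' (a motion tensor) and 'target'
    (a label); CLIP-related keys ('clip_image', 'clip_text', 'clip_path') and
    'all_categories' are forwarded when present.
    """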
    # Filter out None samples.
    notnone_batches = [b for b in batch if b is not None]
    if not notnone_batches:
        # Nothing valid in this batch: return the expected keys, all empty.
        out_batch = {"x": [], "y": [],
                     "mask": [], "lengths": [],
                     "clip_image": [], "clip_text": [],
                     "clip_path": [], "clip_images_emb": []
                     }
        return out_batch
    databatch = [b['inp'] for b in notnone_batches]
    labelbatch = [b['target'] for b in notnone_batches]
    # Sequence length is the size of the third axis of 'inp' (the number of
    # frames, assuming an (njoints, nfeats, nframes) layout).
    lenbatch = [len(b['inp'][0][0]) for b in notnone_batches]

    databatchTensor = collate_tensors(databatch)
    labelbatchTensor = torch.as_tensor(labelbatch)
    lenbatchTensor = torch.as_tensor(lenbatch)
    maskbatchTensor = lengths_to_mask(lenbatchTensor)

    out_batch = {"x": databatchTensor, "y": labelbatchTensor,
                 "mask": maskbatchTensor, "lengths": lenbatchTensor}
                 # "y_action_names": actionlabelbatchTensor}
    if 'clip_image' in notnone_batches[0]:
        clip_image_batch = [torch.as_tensor(b['clip_image']) for b in notnone_batches]
        out_batch.update({'clip_images': collate_tensors(clip_image_batch)})

    if 'clip_text' in notnone_batches[0]:
        textbatch = [b['clip_text'] for b in notnone_batches]
        out_batch.update({'clip_text': textbatch})

    if 'clip_path' in notnone_batches[0]:
        textbatch = [b['clip_path'] for b in notnone_batches]
        out_batch.update({'clip_path': textbatch})

    if 'all_categories' in notnone_batches[0]:
        textbatch = [b['all_categories'] for b in notnone_batches]
        out_batch.update({'all_categories': textbatch})

    return out_batch
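

if __name__ == "__main__":
    # Usage sketch (illustrative only): one way to plug `collate` into a
    # torch DataLoader. The sample layout used here ((njoints, nfeats,
    # nframes) motion tensors with an integer 'target') is an assumption
    # for demonstration, not something this file prescribes.
    from torch.utils.data import DataLoader

    toy_dataset = [{'inp': torch.randn(25, 6, n), 'target': i}
                   for i, n in enumerate((40, 60))]
    loader = DataLoader(toy_dataset, batch_size=2, collate_fn=collate)
    batch = next(iter(loader))
    print(batch['x'].shape)      # torch.Size([2, 25, 6, 60]), zero-padded
    print(batch['mask'].shape)   # torch.Size([2, 60]), True on valid frames
    print(batch['lengths'])      # tensor([40, 60])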