File size: 5,766 Bytes
f499d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from copy import deepcopy
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Union, Tuple, List
import numpy as np
from numpy import ndarray
import os
from random import shuffle
from box import Box
from torch.onnx.symbolic_opset11 import index_copy

from .spec import ConfigSpec

@dataclass
class DatapathConfig(ConfigSpec):
    '''
    Config to handle input data paths.
    '''
    # root directory; entries listed inside each list file are joined onto it
    input_dataset_dir: str

    # use proportion data sampling
    use_prob: bool

    # cls: [(path_1, p_1), ...] -- each path is a list file of entries for
    # that class, p is its (unnormalized) sampling weight
    data_path: Dict[str, List[Tuple[str, float]]]

    # how many files to return when using data sampling
    num_files: Union[int, None]

    @classmethod
    def from_args(cls, **kwargs) -> 'DatapathConfig':
        '''
        Make a temporary datapath from user inputs.

        NOTE(review): unimplemented stub. It reads the `input`, `output` and
        `recursive` keyword arguments but builds nothing and returns None,
        despite the return annotation.
        '''
        # TODO: construct and return a DatapathConfig from these values.
        input_path = kwargs.get('input', None)
        output_path = kwargs.get('output', None)
        recursive = kwargs.get('recursive', False)
        return None

    @classmethod
    def parse(cls, config) -> 'DatapathConfig':
        '''
        Build a config from a parsed config object.

        `config` must support attribute access (`input_dataset_dir`,
        `data_path`) and `.get` (`use_prob`, default True; `num_files`,
        default None). Presumably a `box.Box` -- TODO confirm against callers.
        `cls.check_keys` (inherited from `ConfigSpec`) runs first.
        '''
        cls.check_keys(config)
        # `cls(...)` rather than a hard-coded class name, so subclasses
        # parse into their own type.
        return cls(
            input_dataset_dir=config.input_dataset_dir,
            use_prob=config.get('use_prob', True),
            data_path=config.data_path,
            num_files=config.get('num_files', None),
        )

    def split_by_cls(self) -> Dict[str, 'DatapathConfig']:
        '''
        Return one deep copy of this config per class in `data_path`.

        Each copy's `data_path` holds only that class's entry. All other
        fields, including `num_files`, are copied unchanged.
        '''
        res: Dict[str, DatapathConfig] = {}
        for cls_name, entries in self.data_path.items():
            single = deepcopy(self)
            single.data_path = {cls_name: entries}
            res[cls_name] = single
        return res

class Datapath():
    '''
    Indexable collection of `(cls, path)` items.

    Built either from a `DatapathConfig` or directly from a list of file
    paths.

    With a config, each list file named in `config.data_path` is read line by
    line. Only entries whose `<input_dataset_dir>/<entry>/raw_data.npz`
    exists are kept.

    When `config.use_prob` is true, `__getitem__` ignores its index and uses
    two-stage sampling:
      1. Pick a class, weighted by the summed weights of its list files.
      2. Pick one of that class's list files, weighted by its own weight.
    Each list file is walked round-robin and reshuffled in place every time
    it wraps around. Otherwise items are returned in `file_list` order.
    '''
    def __init__(
        self,
        config: Union[DatapathConfig, None]=None,
        files: Union[List[str], None]=None,
        cls: Union[str, None]=None,
    ):
        '''
        Args:
            config: data-path config; takes precedence over `files`.
            files: explicit file paths, used only when `config` is None.
            cls: class label attached to every entry of `files` (defaults to
                'inference'); ignored when `config` is given.

        Raises:
            AssertionError: if neither `config` nor `files` is given, or if
                every entry of some list file is missing on disk.
        '''
        if config is not None:
            self._init_from_config(config)
        elif files is not None:
            self._init_from_files(files, 'inference' if cls is None else cls)
        else:
            # Explicit raise (not `assert`) so the check survives `python -O`;
            # AssertionError is kept for backward compatibility.
            raise AssertionError('Datapath needs either `config` or `files`')

    def _init_from_config(self, config: DatapathConfig) -> None:
        '''Load every list file named in `config.data_path` and build the sampling tables.'''
        self.config = config
        self.file_list: List[Dict] = []
        # files_by_class[cls][k] holds the items loaded from the k-th list file
        # of `cls`; class_positions[cls][k] is that list's round-robin cursor.
        self.files_by_class: Dict[str, List[List[Dict]]] = defaultdict(list)
        self.class_positions: Dict[str, List[int]] = defaultdict(list)
        self.cls_probs_second: Dict[str, ndarray] = {}
        cls_first: List[str] = []
        cls_weights: List[float] = []

        root = config.input_dataset_dir
        for cls_name, entries in config.data_path.items():
            class_weight = 0.
            file_weights = []
            for (path, p) in entries:
                class_weight += p
                file_weights.append(p)
                file_items = self._load_list_file(root, cls_name, path, p)
                self.files_by_class[cls_name].append(file_items)
                self.class_positions[cls_name].append(0)
                self.file_list.extend(file_items)
            file_weights = np.array(file_weights)
            self.cls_probs_second[cls_name] = file_weights / file_weights.sum()
            cls_first.append(cls_name)
            cls_weights.append(class_weight)

        cls_weights = np.array(cls_weights)
        self.cls_first: List[str] = cls_first
        self.cls_probs_first: ndarray = cls_weights / cls_weights.sum()

    @staticmethod
    def _load_list_file(root: str, cls_name: str, list_path: str, p: float) -> List[Dict]:
        '''Read one list file and return the entries whose `raw_data.npz` exists under `root`.'''
        file_items: List[Dict] = []
        missing = 0
        with open(list_path, 'r') as f:
            for line in f:
                name = line.strip()
                if not name:
                    # Blank lines are not entries; don't report them as missing.
                    continue
                sample_dir = os.path.join(root, name)
                if not os.path.exists(os.path.join(sample_dir, 'raw_data.npz')):
                    missing += 1
                    continue
                file_items.append({
                    'cls': cls_name,
                    'path': sample_dir,
                    'prob': p,
                })
        if len(file_items) == 0:
            raise AssertionError(f"files in {list_path} are all missing! root: {root}")
        if missing > 0:
            print(f"\033[31m{cls_name}: {missing} missing files\033[0m")
        return file_items

    def _init_from_files(self, files: List[str], cls: str) -> None:
        '''Wrap an explicit file list as one class with a single sub-list, served in order.'''
        self.file_list = [{'cls': cls, 'path': file} for file in files]
        # Same table layout as the config path (one class, one sub-list), so
        # files_by_class/class_positions/cls_probs_second agree with each other.
        self.files_by_class: Dict[str, List[List[Dict]]] = {cls: [self.file_list.copy()]}
        self.class_positions: Dict[str, List[int]] = {cls: [0]}
        self.cls_probs_second: Dict[str, ndarray] = {cls: np.array([1.])}
        self.cls_first: List[str] = [cls]
        self.cls_probs_first: ndarray = np.array([1.])
        self.config = Box({'use_prob': False})

    def __len__(self):
        '''Return `config.num_files` in sampling mode, else the size of `file_list`.'''
        if self.config.use_prob:
            if self.config.num_files is None:
                raise AssertionError('num_files is not specified')
            return self.config.num_files
        return len(self.file_list)

    def __getitem__(self, index) -> Tuple[str, str]:
        '''
        Return a `(cls, path)` pair.

        In sampling mode `index` is ignored and an item is drawn at random;
        this advances the chosen list file's cursor. Otherwise the method
        returns `file_list[index]`.
        '''
        if not self.config.use_prob:
            item = self.file_list[index]
            return (item['cls'], item['path'])

        # first sample a class, then one of that class's list files
        cls = np.random.choice(self.cls_first, p=self.cls_probs_first)
        idx = np.random.choice(len(self.files_by_class[cls]), p=self.cls_probs_second[cls])

        files = self.files_by_class[cls][idx]
        pos = self.class_positions[cls][idx]
        item = files[pos]

        # Advance the cursor; once the list has been fully consumed, reshuffle
        # it in place so the next pass visits entries in a new order.
        next_pos = (pos + 1) % len(files)
        self.class_positions[cls][idx] = next_pos
        if next_pos == 0:
            shuffle(files)
        return (item['cls'], item['path'])

    def get_data(self) -> List[Tuple[str, str]]:
        '''Return `len(self)` items via `__getitem__` (freshly sampled in sampling mode).'''
        return [self[i] for i in range(len(self))]