File size: 3,022 Bytes
841f290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import copy
from typing import Optional
from wenet.dataset.dataset import Dataset, BigDataset

from wenet.text.base_tokenizer import BaseTokenizer


def init_asr_dataset(
        data_type,
        data_list_file,
        tokenizer: Optional[BaseTokenizer] = None,
        conf=None,
        partition=True):
    """Create the ASR dataset.

    Thin wrapper: every argument is forwarded positionally, in the same
    order, to ``Dataset`` (imported from ``wenet.dataset.dataset``) and
    whatever that call returns is handed back unchanged.

    Args:
        data_type: forwarded as-is to ``Dataset``.
        data_list_file: forwarded as-is to ``Dataset``.
        tokenizer: optional tokenizer, forwarded as-is.
        conf: dataset configuration, forwarded as-is.
        partition: forwarded as-is to ``Dataset``.

    Returns:
        The object returned by ``Dataset``.
    """
    asr_dataset = Dataset(data_type, data_list_file, tokenizer, conf,
                          partition)
    return asr_dataset

def init_asr_big_dataset(
        data_type,
        data_list_file_s2t,
        data_list_file_t2s,
        data_list_file_s2s,
        data_list_file_t2t,
        tokenizer: Optional[BaseTokenizer] = None,
        conf=None,
        partition=True):
    """Create the multi-list ASR dataset.

    Thin wrapper: the four data list files (s2t, t2s, s2s, t2t) and the
    remaining arguments are forwarded positionally, in the same order, to
    ``BigDataset`` (imported from ``wenet.dataset.dataset``).

    Returns:
        The object returned by ``BigDataset``.
    """
    # Keep the four list files together so the forwarding call stays short.
    list_files = (
        data_list_file_s2t,
        data_list_file_t2s,
        data_list_file_s2s,
        data_list_file_t2t,
    )
    return BigDataset(data_type, *list_files, tokenizer, conf, partition)

def init_dataset(
        dataset_type,
        data_type,
        data_list_file,
        tokenizer: Optional[BaseTokenizer] = None,
        conf=None,
        partition=True,
        split='train'):
    """Build an ``'asr'`` or ``'ssl'`` dataset for the given split.

    For every split other than ``'train'`` the configuration is deep-copied
    and the copy is overridden with ``cycle = 1`` and with
    ``speed_perturb``, ``spec_aug``, ``spec_sub``, ``spec_trim``,
    ``shuffle`` and ``list_shuffle`` set to ``False``. The caller's ``conf``
    object is never modified.

    Args:
        dataset_type: either ``'asr'`` or ``'ssl'`` (asserted).
        data_type: forwarded to the selected dataset initializer.
        data_list_file: forwarded to the selected dataset initializer.
        tokenizer: optional tokenizer; only used on the ``'asr'`` path.
        conf: dataset configuration; must support item assignment when
            ``split`` is not ``'train'``.
        partition: forwarded to the selected dataset initializer.
        split: ``'train'`` keeps ``conf`` untouched, anything else applies
            the overrides described above.

    Returns:
        Whatever the selected dataset initializer returns.
    """
    assert dataset_type in ('asr', 'ssl')

    if split != 'train':
        # Work on a private copy so the training configuration stays intact.
        eval_conf = copy.deepcopy(conf)
        eval_conf['cycle'] = 1
        for disabled_key in ('speed_perturb', 'spec_aug', 'spec_sub',
                             'spec_trim', 'shuffle', 'list_shuffle'):
            eval_conf[disabled_key] = False
        conf = eval_conf

    if dataset_type != 'asr':
        # Imported lazily, only when the SSL path is actually requested.
        from wenet.ssl.init_dataset import init_dataset as init_ssl_dataset
        return init_ssl_dataset(data_type, data_list_file, conf, partition)

    return init_asr_dataset(data_type, data_list_file, tokenizer, conf,
                            partition)


def init_big_dataset(dataset_type,
                     data_type,
                     data_list_file_s2t,
                     data_list_file_t2s,
                     data_list_file_s2s,
                     data_list_file_t2t,
                     tokenizer: Optional[BaseTokenizer] = None,
                     conf=None,
                     partition=True,
                     split='train'):
    """Build the multi-list ("big") dataset for the given split.

    Only ``dataset_type == 'asr'`` is implemented. ``'ssl'`` passes the
    type assertion but has no big-dataset implementation, so it raises
    ``NotImplementedError`` instead of silently returning ``None``.

    For every split other than ``'train'`` the configuration is deep-copied
    and the copy is overridden with ``cycle = 1`` and with
    ``speed_perturb``, ``spec_aug``, ``spec_sub``, ``spec_trim``,
    ``shuffle`` and ``list_shuffle`` set to ``False``. The caller's ``conf``
    object is never modified.

    Args:
        dataset_type: ``'asr'`` or ``'ssl'`` (asserted); only ``'asr'`` is
            supported.
        data_type: forwarded to ``init_asr_big_dataset``.
        data_list_file_s2t: forwarded to ``init_asr_big_dataset``.
        data_list_file_t2s: forwarded to ``init_asr_big_dataset``.
        data_list_file_s2s: forwarded to ``init_asr_big_dataset``.
        data_list_file_t2t: forwarded to ``init_asr_big_dataset``.
        tokenizer: optional tokenizer, forwarded as-is.
        conf: dataset configuration; must support item assignment when
            ``split`` is not ``'train'``.
        partition: forwarded to ``init_asr_big_dataset``.
        split: ``'train'`` keeps ``conf`` untouched, anything else applies
            the overrides described above.

    Returns:
        Whatever ``init_asr_big_dataset`` returns.

    Raises:
        NotImplementedError: if ``dataset_type`` is not ``'asr'``.
    """
    assert dataset_type in ['asr', 'ssl']

    # Fail fast and loudly: previously this case fell through a bare `pass`
    # and returned None, which only blew up later at the call site.
    if dataset_type != 'asr':
        raise NotImplementedError(
            "init_big_dataset only supports dataset_type 'asr', "
            "got {!r}".format(dataset_type))

    if split != 'train':
        # Work on a private copy so the training configuration stays intact.
        cv_conf = copy.deepcopy(conf)
        cv_conf['cycle'] = 1
        cv_conf['speed_perturb'] = False
        cv_conf['spec_aug'] = False
        cv_conf['spec_sub'] = False
        cv_conf['spec_trim'] = False
        cv_conf['shuffle'] = False
        cv_conf['list_shuffle'] = False
        conf = cv_conf

    return init_asr_big_dataset(data_type,
                                data_list_file_s2t,
                                data_list_file_t2s,
                                data_list_file_s2s,
                                data_list_file_t2t,
                                tokenizer, conf, partition)