lj1995 committed
Commit 48ddb2e · 1 Parent(s): 6c168a1

Delete eres2net
eres2net/ERes2Net.py DELETED
@@ -1,260 +0,0 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve performance.
Local feature fusion (LFF) fuses the features within a single residual block to extract the local signal.
Global feature fusion (GFF) takes acoustic features of different scales as input to aggregate the global signal.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' + inplace_str + ')'


class BasicBlockERes2Net(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # Split channels into `scale` groups and process them hierarchically.
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                # Fuse the running branch output with the next split via AFF
                # instead of plain addition (the LFF of the paper).
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module for each layer
        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)

        # Bottom-up fusion module (the GFF of the paper)
        self.fuse_mode12 = AFF(channels=m_channels * 4)
        self.fuse_mode123 = AFF(channels=m_channels * 8)
        self.fuse_mode1234 = AFF(channels=m_channels * 16)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        stats = self.pool(fuse_out1234)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        # Like forward(), but skips pooling/projection and returns the
        # frame-averaged fused feature map instead.
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
        return fuse_out1234


if __name__ == '__main__':
    x = torch.zeros(10, 300, 80)
    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
    model.eval()
    out = model(x)
    print(out.shape)  # torch.Size([10, 192])

    num_params = sum(param.numel() for param in model.parameters())
    print("{} M".format(num_params / 1e6))  # 6.61M
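For reference, a quick sanity check of the deleted model is to feed it random fbank-shaped input. This is a sketch, not part of the commit: it assumes it is run from inside the deleted eres2net/ directory so that pooling_layers and fusion resolve, and the printed shapes follow from the default configuration above.

    import torch
    from ERes2Net import ERes2Net  # the deleted module above

    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
    model.eval()
    feats = torch.randn(1, 300, 80)  # (batch, frames, mel bins)
    with torch.no_grad():
        emb = model(feats)             # utterance embedding, (1, 192)
        fused = model.forward3(feats)  # frame-averaged fused features, (1, 5120)
    print(emb.shape, fused.shape)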
eres2net/ERes2NetV2.py DELETED
@@ -1,292 +0,0 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
within each stage. However, this modification also increases the number of model parameters and the computational
complexity. To alleviate this problem, we propose an improved ERes2NetV2 that prunes redundant structures,
ultimately reducing both the model parameters and the computational cost.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' + inplace_str + ')'


class BasicBlockERes2NetV2(nn.Module):

    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2NetV2AFF(nn.Module):

    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width, r=4))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2NetV2(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2NetV2,
                 block_fuse=BasicBlockERes2NetV2AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=64,
                 feat_dim=80,
                 embedding_size=192,
                 baseWidth=26,
                 scale=2,
                 expansion=2,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2NetV2, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self.baseWidth = baseWidth
        self.scale = scale
        self.expansion = expansion

        self.conv1 = nn.Conv2d(1,
                               m_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block,
                                       m_channels,
                                       num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       m_channels * 2,
                                       num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block_fuse,
                                       m_channels * 4,
                                       num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block_fuse,
                                       m_channels * 8,
                                       num_blocks[3],
                                       stride=2)

        # Downsampling module (only layer3 -> layer4 is fused in V2)
        self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3,
                                   padding=1, stride=2, bias=False)

        # Bottom-up fusion module
        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * self.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
            self.in_planes = planes * self.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        stats = self.pool(fuse_out34)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        # Returns the frame-averaged fused feature map instead of an embedding.
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)


if __name__ == '__main__':
    from thop import profile  # needed for the MACs/params measurement below (pip install thop)

    x = torch.randn(1, 300, 80)
    model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
    model.eval()
    y = model(x)
    print(y.size())
    macs, num_params = profile(model, inputs=(x,))
    print("Params: {} M".format(num_params / 1e6))  # 17.86 M
    print("MACs: {} G".format(macs / 1e9))  # 12.69 G
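A hedged shape check of the V2 variant, under the same assumption that the deleted files are on the import path; the expected shape follows from the defaults above (one 192-dim embedding per utterance):

    import torch
    from ERes2NetV2 import ERes2NetV2  # the deleted module above

    model = ERes2NetV2(feat_dim=80, embedding_size=192).eval()
    with torch.no_grad():
        emb = model(torch.randn(2, 200, 80))  # (batch, frames, mel bins)
    assert emb.shape == (2, 192)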
eres2net/ERes2Net_huge.py DELETED
@@ -1,286 +0,0 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve performance.
Local feature fusion (LFF) fuses the features within a single residual block to extract the local signal.
Global feature fusion (GFF) takes acoustic features of different scales as input to aggregate the global signal.
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. The expansion, baseWidth, and scale parameters can be modified to obtain optimal performance.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' + inplace_str + ')'


class BasicBlockERes2Net(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=64,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)

        self.fuse_mode12 = AFF(channels=m_channels * 8)
        self.fuse_mode123 = AFF(channels=m_channels * 16)
        self.fuse_mode1234 = AFF(channels=m_channels * 32)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        stats = self.pool(fuse_out1234)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward2(self, x, if_mean):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2)  # (bs, 20480, T)
        if not if_mean:
            mean = fuse_out1234[0].transpose(1, 0)  # (T, 20480), bs = T
        else:
            mean = fuse_out1234.mean(2)  # (bs, 20480)
        # Zero-fill the std half of the pooled statistics before projecting.
        mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
        return self.seg_1(mean_std)  # (T, 192)

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
        return fuse_out1234
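forward2() is the one method unique to this variant: with if_mean=False it projects each frame separately (the code indexes batch element 0, so batch size should be 1), and with if_mean=True it projects the utterance-level mean. A hedged usage sketch under the same import-path assumption as above:

    import torch
    from ERes2Net_huge import ERes2Net  # the deleted module above

    model = ERes2Net(feat_dim=80, embedding_size=192).eval()
    feats = torch.randn(1, 300, 80)  # (batch, frames, mel bins)
    with torch.no_grad():
        per_frame = model.forward2(feats, if_mean=False)  # (T', 192), one row per frame
        per_utt = model.forward2(feats, if_mean=True)     # (1, 192)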
eres2net/fusion.py DELETED
@@ -1,29 +0,0 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn as nn


class AFF(nn.Module):

    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x, ds_y):
        xa = torch.cat((x, ds_y), dim=1)
        x_att = self.local_att(xa)
        x_att = 1.0 + torch.tanh(x_att)
        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)

        return xo
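AFF computes a per-position gate a = 1 + tanh(·) in (0, 2) from the concatenated inputs and blends them as x·a + ds_y·(2−a), so the two weights always sum to 2. A minimal smoke test of the module above (a sketch; the tensor shapes are arbitrary):

    import torch
    from fusion import AFF  # the deleted module above

    aff = AFF(channels=64, r=4)
    x = torch.randn(2, 64, 10, 25)     # features at the current scale
    ds_y = torch.randn(2, 64, 10, 25)  # downsampled features from an earlier stage
    out = aff(x, ds_y)                 # attention-weighted blend, same shape as x
    print(out.shape)  # torch.Size([2, 64, 10, 25])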
eres2net/kaldi.py DELETED
@@ -1,819 +0,0 @@
1
- import math
2
- from typing import Tuple
3
-
4
- import torch
5
- import torchaudio
6
- from torch import Tensor
7
-
8
- __all__ = [
9
- "get_mel_banks",
10
- "inverse_mel_scale",
11
- "inverse_mel_scale_scalar",
12
- "mel_scale",
13
- "mel_scale_scalar",
14
- "spectrogram",
15
- "fbank",
16
- "mfcc",
17
- "vtln_warp_freq",
18
- "vtln_warp_mel_freq",
19
- ]
20
-
21
- # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
- EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
- # 1 milliseconds = 0.001 seconds
24
- MILLISECONDS_TO_SECONDS = 0.001
25
-
26
- # window types
27
- HAMMING = "hamming"
28
- HANNING = "hanning"
29
- POVEY = "povey"
30
- RECTANGULAR = "rectangular"
31
- BLACKMAN = "blackman"
32
- WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
-
34
-
35
- def _get_epsilon(device, dtype):
36
- return EPSILON.to(device=device, dtype=dtype)
37
-
38
-
39
- def _next_power_of_2(x: int) -> int:
40
- r"""Returns the smallest power of 2 that is greater than x"""
41
- return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
-
43
-
44
- def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
- r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
- representing how the window is shifted along the waveform. Each row is a frame.
47
-
48
- Args:
49
- waveform (Tensor): Tensor of size ``num_samples``
50
- window_size (int): Frame length
51
- window_shift (int): Frame shift
52
- snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
- in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
- depends only on the frame_shift, and we reflect the data at the ends.
55
-
56
- Returns:
57
- Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
- """
59
- assert waveform.dim() == 1
60
- num_samples = waveform.size(0)
61
- strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
-
63
- if snip_edges:
64
- if num_samples < window_size:
65
- return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
- else:
67
- m = 1 + (num_samples - window_size) // window_shift
68
- else:
69
- reversed_waveform = torch.flip(waveform, [0])
70
- m = (num_samples + (window_shift // 2)) // window_shift
71
- pad = window_size // 2 - window_shift // 2
72
- pad_right = reversed_waveform
73
- if pad > 0:
74
- # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
- # but we want [2, 1, 0, 0, 1, 2]
76
- pad_left = reversed_waveform[-pad:]
77
- waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
- else:
79
- # pad is negative so we want to trim the waveform at the front
80
- waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
-
82
- sizes = (m, window_size)
83
- return waveform.as_strided(sizes, strides)
84
-
85
-
86
- def _feature_window_function(
87
- window_type: str,
88
- window_size: int,
89
- blackman_coeff: float,
90
- device: torch.device,
91
- dtype: int,
92
- ) -> Tensor:
93
- r"""Returns a window function with the given type and size"""
94
- if window_type == HANNING:
95
- return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
96
- elif window_type == HAMMING:
97
- return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
98
- elif window_type == POVEY:
99
- # like hanning but goes to zero at edges
100
- return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
101
- elif window_type == RECTANGULAR:
102
- return torch.ones(window_size, device=device, dtype=dtype)
103
- elif window_type == BLACKMAN:
104
- a = 2 * math.pi / (window_size - 1)
105
- window_function = torch.arange(window_size, device=device, dtype=dtype)
106
- # can't use torch.blackman_window as they use different coefficients
107
- return (
108
- blackman_coeff
109
- - 0.5 * torch.cos(a * window_function)
110
- + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
111
- ).to(device=device, dtype=dtype)
112
- else:
113
- raise Exception("Invalid window type " + window_type)
114
-
115
-
116
- def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
- r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
- device, dtype = strided_input.device, strided_input.dtype
119
- log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
- if energy_floor == 0.0:
121
- return log_energy
122
- return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
-
124
-
125
- def _get_waveform_and_window_properties(
126
- waveform: Tensor,
127
- channel: int,
128
- sample_frequency: float,
129
- frame_shift: float,
130
- frame_length: float,
131
- round_to_power_of_two: bool,
132
- preemphasis_coefficient: float,
133
- ) -> Tuple[Tensor, int, int, int]:
134
- r"""Gets the waveform and window properties"""
135
- channel = max(channel, 0)
136
- assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
137
- waveform = waveform[channel, :] # size (n)
138
- window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
139
- window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
140
- padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
141
-
142
- assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
143
- window_size, len(waveform)
144
- )
145
- assert 0 < window_shift, "`window_shift` must be greater than 0"
146
- assert padded_window_size % 2 == 0, (
147
- "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
148
- )
149
- assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
150
- assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
151
- return waveform, window_shift, window_size, padded_window_size
152
-
153
-
154
- def _get_window(
155
- waveform: Tensor,
156
- padded_window_size: int,
157
- window_size: int,
158
- window_shift: int,
159
- window_type: str,
160
- blackman_coeff: float,
161
- snip_edges: bool,
162
- raw_energy: bool,
163
- energy_floor: float,
164
- dither: float,
165
- remove_dc_offset: bool,
166
- preemphasis_coefficient: float,
167
- ) -> Tuple[Tensor, Tensor]:
168
- r"""Gets a window and its log energy
169
-
170
- Returns:
171
- (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
172
- """
173
- device, dtype = waveform.device, waveform.dtype
174
- epsilon = _get_epsilon(device, dtype)
175
-
176
- # size (m, window_size)
177
- strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
178
-
179
- if dither != 0.0:
180
- rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
181
- strided_input = strided_input + rand_gauss * dither
182
-
183
- if remove_dc_offset:
184
- # Subtract each row/frame by its mean
185
- row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
186
- strided_input = strided_input - row_means
187
-
188
- if raw_energy:
189
- # Compute the log energy of each row/frame before applying preemphasis and
190
- # window function
191
- signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
192
-
193
- if preemphasis_coefficient != 0.0:
194
- # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
195
- offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
196
- 0
197
- ) # size (m, window_size + 1)
198
- strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
199
-
200
- # Apply window_function to each row/frame
201
- window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
202
- 0
203
- ) # size (1, window_size)
204
- strided_input = strided_input * window_function # size (m, window_size)
205
-
206
- # Pad columns with zero until we reach size (m, padded_window_size)
207
- if padded_window_size != window_size:
208
- padding_right = padded_window_size - window_size
209
- strided_input = torch.nn.functional.pad(
210
- strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
211
- ).squeeze(0)
212
-
213
- # Compute energy after window function (not the raw one)
214
- if not raw_energy:
215
- signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
216
-
217
- return strided_input, signal_log_energy
218
-
219
-
220
- def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
221
- # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
222
- # it returns size (m, n)
223
- if subtract_mean:
224
- col_means = torch.mean(tensor, dim=0).unsqueeze(0)
225
- tensor = tensor - col_means
226
- return tensor
227
-
228
-
229
- def spectrogram(
230
- waveform: Tensor,
231
- blackman_coeff: float = 0.42,
232
- channel: int = -1,
233
- dither: float = 0.0,
234
- energy_floor: float = 1.0,
235
- frame_length: float = 25.0,
236
- frame_shift: float = 10.0,
237
- min_duration: float = 0.0,
238
- preemphasis_coefficient: float = 0.97,
239
- raw_energy: bool = True,
240
- remove_dc_offset: bool = True,
241
- round_to_power_of_two: bool = True,
242
- sample_frequency: float = 16000.0,
243
- snip_edges: bool = True,
244
- subtract_mean: bool = False,
245
- window_type: str = POVEY,
246
- ) -> Tensor:
247
- r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
248
- compute-spectrogram-feats.
249
-
250
- Args:
251
- waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
252
- blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
253
- channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
254
- dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
255
- the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
256
- energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
257
- this floor is applied to the zeroth component, representing the total signal energy. The floor on the
258
- individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
259
- frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
260
- frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
261
- min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
262
- preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
263
- raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
264
- remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
265
- round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
266
- to FFT. (Default: ``True``)
267
- sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
268
- specified there) (Default: ``16000.0``)
269
- snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
270
- in the file, and the number of frames depends on the frame_length. If False, the number of frames
271
- depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
272
- subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
273
- it this way. (Default: ``False``)
274
- window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
275
- (Default: ``'povey'``)
276
-
277
- Returns:
278
- Tensor: A spectrogram identical to what Kaldi would output. The shape is
279
- (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
280
- """
281
- device, dtype = waveform.device, waveform.dtype
282
- epsilon = _get_epsilon(device, dtype)
283
-
284
- waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
285
- waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
286
- )
287
-
288
- if len(waveform) < min_duration * sample_frequency:
289
- # signal is too short
290
- return torch.empty(0)
291
-
292
- strided_input, signal_log_energy = _get_window(
293
- waveform,
294
- padded_window_size,
295
- window_size,
296
- window_shift,
297
- window_type,
298
- blackman_coeff,
299
- snip_edges,
300
- raw_energy,
301
- energy_floor,
302
- dither,
303
- remove_dc_offset,
304
- preemphasis_coefficient,
305
- )
306
-
307
- # size (m, padded_window_size // 2 + 1, 2)
308
- fft = torch.fft.rfft(strided_input)
309
-
310
- # Convert the FFT into a power spectrum
311
- power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
312
- power_spectrum[:, 0] = signal_log_energy
313
-
314
- power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
315
- return power_spectrum
316
-
317
-
318
- def inverse_mel_scale_scalar(mel_freq: float) -> float:
319
- return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
320
-
321
-
322
- def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
323
- return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
324
-
325
-
326
- def mel_scale_scalar(freq: float) -> float:
327
- return 1127.0 * math.log(1.0 + freq / 700.0)
328
-
329
-
330
- def mel_scale(freq: Tensor) -> Tensor:
331
- return 1127.0 * (1.0 + freq / 700.0).log()
332
-
333
-
334
- def vtln_warp_freq(
335
- vtln_low_cutoff: float,
336
- vtln_high_cutoff: float,
337
- low_freq: float,
338
- high_freq: float,
339
- vtln_warp_factor: float,
340
- freq: Tensor,
341
- ) -> Tensor:
342
- r"""This computes a VTLN warping function that is not the same as HTK's one,
343
- but has similar inputs (this function has the advantage of never producing
344
- empty bins).
345
-
346
- This function computes a warp function F(freq), defined between low_freq
347
- and high_freq inclusive, with the following properties:
348
- F(low_freq) == low_freq
349
- F(high_freq) == high_freq
350
- The function is continuous and piecewise linear with two inflection
351
- points.
352
- The lower inflection point (measured in terms of the unwarped
353
- frequency) is at frequency l, determined as described below.
354
- The higher inflection point is at a frequency h, determined as
355
- described below.
356
- If l <= f <= h, then F(f) = f/vtln_warp_factor.
357
- If the higher inflection point (measured in terms of the unwarped
358
- frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
359
- Since (by the last point) F(h) == h/vtln_warp_factor, then
360
- max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
361
- h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
362
- = vtln_high_cutoff * min(1, vtln_warp_factor).
363
- If the lower inflection point (measured in terms of the unwarped
364
- frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
365
- This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
366
- = vtln_low_cutoff * max(1, vtln_warp_factor)
367
- Args:
368
- vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
369
- vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
370
- low_freq (float): Lower frequency cutoffs in mel computation
371
- high_freq (float): Upper frequency cutoffs in mel computation
372
- vtln_warp_factor (float): Vtln warp factor
373
- freq (Tensor): given frequency in Hz
374
-
375
- Returns:
376
- Tensor: Freq after vtln warp
377
- """
378
- assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
379
- assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
380
- l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
381
- h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
382
- scale = 1.0 / vtln_warp_factor
383
- Fl = scale * l # F(l)
384
- Fh = scale * h # F(h)
385
- assert l > low_freq and h < high_freq
386
- # slope of left part of the 3-piece linear function
387
- scale_left = (Fl - low_freq) / (l - low_freq)
388
- # [slope of center part is just "scale"]
389
-
390
- # slope of right part of the 3-piece linear function
391
- scale_right = (high_freq - Fh) / (high_freq - h)
392
-
393
- res = torch.empty_like(freq)
394
-
395
- outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
396
- before_l = torch.lt(freq, l) # freq < l
397
- before_h = torch.lt(freq, h) # freq < h
398
- after_h = torch.ge(freq, h) # freq >= h
399
-
400
- # order of operations matter here (since there is overlapping frequency regions)
401
- res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
402
- res[before_h] = scale * freq[before_h]
403
- res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
404
- res[outside_low_high_freq] = freq[outside_low_high_freq]
405
-
406
- return res
407
-
408
-
409
- def vtln_warp_mel_freq(
410
- vtln_low_cutoff: float,
411
- vtln_high_cutoff: float,
412
- low_freq,
413
- high_freq: float,
414
- vtln_warp_factor: float,
415
- mel_freq: Tensor,
416
- ) -> Tensor:
417
- r"""
418
- Args:
419
- vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
420
- vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
421
- low_freq (float): Lower frequency cutoffs in mel computation
422
- high_freq (float): Upper frequency cutoffs in mel computation
423
- vtln_warp_factor (float): Vtln warp factor
424
- mel_freq (Tensor): Given frequency in Mel
425
-
426
- Returns:
427
- Tensor: ``mel_freq`` after vtln warp
428
- """
429
- return mel_scale(
430
- vtln_warp_freq(
431
- vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
432
- )
433
- )
434
-
435
-
436
- def get_mel_banks(
437
- num_bins: int,
438
- window_length_padded: int,
439
- sample_freq: float,
440
- low_freq: float,
441
- high_freq: float,
442
- vtln_low: float,
443
- vtln_high: float,
444
- vtln_warp_factor: float,device=None,dtype=None
445
- ) -> Tuple[Tensor, Tensor]:
446
- """
447
- Returns:
448
- (Tensor, Tensor): The tuple consists of ``bins`` (which is
449
- melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
450
- center frequencies of bins of size (``num_bins``)).
451
- """
452
- assert num_bins > 3, "Must have at least 3 mel bins"
453
- assert window_length_padded % 2 == 0
454
- num_fft_bins = window_length_padded / 2
455
- nyquist = 0.5 * sample_freq
456
-
457
- if high_freq <= 0.0:
458
- high_freq += nyquist
459
-
460
- assert (
461
- (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
462
- ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
463
-
464
- # fft-bin width [think of it as Nyquist-freq / half-window-length]
465
- fft_bin_width = sample_freq / window_length_padded
466
- mel_low_freq = mel_scale_scalar(low_freq)
467
- mel_high_freq = mel_scale_scalar(high_freq)
468
-
469
- # divide by num_bins+1 in next line because of end-effects where the bins
470
- # spread out to the sides.
471
- mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
472
-
473
- if vtln_high < 0.0:
474
- vtln_high += nyquist
475
-
476
- assert vtln_warp_factor == 1.0 or (
477
- (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
478
- ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
479
- vtln_low, vtln_high, low_freq, high_freq
480
- )
481
-
482
- bin = torch.arange(num_bins).unsqueeze(1)
483
- left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
484
- center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
485
- right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
486
-
487
- if vtln_warp_factor != 1.0:
488
- left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
489
- center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
490
- right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
491
-
492
- # center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
493
- # size(1, num_fft_bins)
494
- mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
495
-
496
- # size (num_bins, num_fft_bins)
497
- up_slope = (mel - left_mel) / (center_mel - left_mel)
498
- down_slope = (right_mel - mel) / (right_mel - center_mel)
499
-
500
- if vtln_warp_factor == 1.0:
501
- # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
502
- bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
503
- else:
504
- # warping can move the order of left_mel, center_mel, right_mel anywhere
505
- bins = torch.zeros_like(up_slope)
506
- up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
507
- down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
508
- bins[up_idx] = up_slope[up_idx]
509
- bins[down_idx] = down_slope[down_idx]
510
-
511
- return bins.to(device=device,dtype=dtype)#, center_freqs
512
-
- cache = {}
-
-
- def fbank(
-     waveform: Tensor,
-     blackman_coeff: float = 0.42,
-     channel: int = -1,
-     dither: float = 0.0,
-     energy_floor: float = 1.0,
-     frame_length: float = 25.0,
-     frame_shift: float = 10.0,
-     high_freq: float = 0.0,
-     htk_compat: bool = False,
-     low_freq: float = 20.0,
-     min_duration: float = 0.0,
-     num_mel_bins: int = 23,
-     preemphasis_coefficient: float = 0.97,
-     raw_energy: bool = True,
-     remove_dc_offset: bool = True,
-     round_to_power_of_two: bool = True,
-     sample_frequency: float = 16000.0,
-     snip_edges: bool = True,
-     subtract_mean: bool = False,
-     use_energy: bool = False,
-     use_log_fbank: bool = True,
-     use_power: bool = True,
-     vtln_high: float = -500.0,
-     vtln_low: float = 100.0,
-     vtln_warp: float = 1.0,
-     window_type: str = POVEY,
- ) -> Tensor:
-     r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
-     compute-fbank-feats.
-
-     Args:
-         waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
-         blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
-         channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
-         dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
-             the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
-         energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
-             this floor is applied to the zeroth component, representing the total signal energy. The floor on the
-             individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
-         frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
-         frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
-         high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
-             (Default: ``0.0``)
-         htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
-             (need to change other parameters). (Default: ``False``)
-         low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
-         min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
-         num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
-         preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
-         raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
-         remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
-         round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
-             to FFT. (Default: ``True``)
-         sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
-             specified there) (Default: ``16000.0``)
-         snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
-             in the file, and the number of frames depends on the frame_length. If False, the number of frames
-             depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
-         subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
-             it this way. (Default: ``False``)
-         use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
-         use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
-         use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
-         vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
-             negative, offset from high-mel-freq) (Default: ``-500.0``)
-         vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
-         vtln_warp (float, optional): VTLN warp factor (only applicable if vtln_map is not specified) (Default: ``1.0``)
-         window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
-             (Default: ``'povey'``)
-
-     Returns:
-         Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``),
-         where m is calculated in _get_strided.
-     """
-     device, dtype = waveform.device, waveform.dtype
-
-     waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
-         waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
-     )
-
-     if len(waveform) < min_duration * sample_frequency:
-         # signal is too short
-         return torch.empty(0, device=device, dtype=dtype)
-
-     # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
-     strided_input, signal_log_energy = _get_window(
-         waveform,
-         padded_window_size,
-         window_size,
-         window_shift,
-         window_type,
-         blackman_coeff,
-         snip_edges,
-         raw_energy,
-         energy_floor,
-         dither,
-         remove_dc_offset,
-         preemphasis_coefficient,
-     )
-
-     # size (m, padded_window_size // 2 + 1)
-     spectrum = torch.fft.rfft(strided_input).abs()
-     if use_power:
-         spectrum = spectrum.pow(2.0)
-
-     # size (num_mel_bins, padded_window_size // 2)
-     # the mel banks depend only on these parameters, so they are memoized per device/dtype
-     cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
-         num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype
-     )
-     if cache_key not in cache:
-         mel_energies = get_mel_banks(
-             num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype
-         )
-         cache[cache_key] = mel_energies
-     else:
-         mel_energies = cache[cache_key]
-
-     # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
-     mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
-
-     # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
-     mel_energies = torch.mm(spectrum, mel_energies.T)
-     if use_log_fbank:
-         # avoid log of zero (which should be prevented anyway by dithering)
-         mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
-
-     # if use_energy then add it as the last column for htk_compat == true else first column
-     if use_energy:
-         signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
-         # returns size (m, num_mel_bins + 1)
-         if htk_compat:
-             mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
-         else:
-             mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
-
-     mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
-     return mel_energies
-
-
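A minimal usage sketch for fbank, assuming the rest of this module (the POVEY constant, _get_window, get_mel_banks, etc.) is intact; the import path eres2net.kaldi is a guess for illustration, and the signature mirrors torchaudio.compliance.kaldi.fbank, from which this code is adapted:

import torch
from eres2net import kaldi  # hypothetical import path for this deleted module

# 1 second of mono 16 kHz audio, shape (c, n) with c == 1
waveform = torch.randn(1, 16000)
feats = kaldi.fbank(waveform, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
print(feats.shape)  # torch.Size([98, 80]): 25 ms windows, 10 ms shift, snip_edges=True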
- def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
-     # returns a dct matrix of size (num_mel_bins, num_ceps)
-     # size (num_mel_bins, num_mel_bins)
-     dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
-     # Kaldi expects the first cepstral coefficient to be a weighted sum with weight sqrt(1/num_mel_bins).
-     # This is the first column of the dct_matrix for torchaudio, since it expects a
-     # right multiply (and it would be the first row of Kaldi's dct_matrix, as Kaldi
-     # expects a left multiply, e.g. dct_matrix * vector).
-     dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
-     dct_matrix = dct_matrix[:, :num_ceps]
-     return dct_matrix
-
-
- def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
-     # returns size (num_ceps)
-     # Compute liftering coefficients (scaling on cepstral coeffs)
-     # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
-     i = torch.arange(num_ceps)
-     return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
-
-
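The coefficients above implement the standard sinusoidal cepstral lifter (as in Kaldi). Writing L for cepstral_lifter, the i-th cepstral coefficient is scaled as

$$ c_i' = \Big(1 + \frac{L}{2}\sin\frac{\pi i}{L}\Big)\, c_i $$

so C0 (i = 0) passes through unscaled and the boost peaks at a factor of 1 + L/2 near i = L/2 (a factor of 12 at i = 11 for the default L = 22).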
- def mfcc(
-     waveform: Tensor,
-     blackman_coeff: float = 0.42,
-     cepstral_lifter: float = 22.0,
-     channel: int = -1,
-     dither: float = 0.0,
-     energy_floor: float = 1.0,
-     frame_length: float = 25.0,
-     frame_shift: float = 10.0,
-     high_freq: float = 0.0,
-     htk_compat: bool = False,
-     low_freq: float = 20.0,
-     num_ceps: int = 13,
-     min_duration: float = 0.0,
-     num_mel_bins: int = 23,
-     preemphasis_coefficient: float = 0.97,
-     raw_energy: bool = True,
-     remove_dc_offset: bool = True,
-     round_to_power_of_two: bool = True,
-     sample_frequency: float = 16000.0,
-     snip_edges: bool = True,
-     subtract_mean: bool = False,
-     use_energy: bool = False,
-     vtln_high: float = -500.0,
-     vtln_low: float = 100.0,
-     vtln_warp: float = 1.0,
-     window_type: str = POVEY,
- ) -> Tensor:
-     r"""Create an mfcc from a raw audio signal. This matches the input/output of Kaldi's
-     compute-mfcc-feats.
-
-     Args:
-         waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
-         blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
-         cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
-         channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
-         dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
-             the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
-         energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
-             this floor is applied to the zeroth component, representing the total signal energy. The floor on the
-             individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
-         frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
-         frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
-         high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
-             (Default: ``0.0``)
-         htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
-             features (need to change other parameters). (Default: ``False``)
-         low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
-         num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
-         min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
-         num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
-         preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
-         raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
-         remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
-         round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
-             to FFT. (Default: ``True``)
-         sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
-             specified there) (Default: ``16000.0``)
-         snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
-             in the file, and the number of frames depends on the frame_length. If False, the number of frames
-             depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
-         subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
-             it this way. (Default: ``False``)
-         use_energy (bool, optional): Add an extra dimension with energy to the MFCC output. (Default: ``False``)
-         vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
-             negative, offset from high-mel-freq) (Default: ``-500.0``)
-         vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
-         vtln_warp (float, optional): VTLN warp factor (only applicable if vtln_map is not specified) (Default: ``1.0``)
-         window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
-             (Default: ``"povey"``)
-
-     Returns:
-         Tensor: An mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``),
-         where m is calculated in _get_strided.
-     """
-     assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
-
-     device, dtype = waveform.device, waveform.dtype
-
-     # fbank must be computed with use_power=True, subtract_mean=False and
-     # use_log_fbank=True to match Kaldi's MFCC pipeline.
-     # size (m, num_mel_bins + use_energy)
-     feature = fbank(
-         waveform=waveform,
-         blackman_coeff=blackman_coeff,
-         channel=channel,
-         dither=dither,
-         energy_floor=energy_floor,
-         frame_length=frame_length,
-         frame_shift=frame_shift,
-         high_freq=high_freq,
-         htk_compat=htk_compat,
-         low_freq=low_freq,
-         min_duration=min_duration,
-         num_mel_bins=num_mel_bins,
-         preemphasis_coefficient=preemphasis_coefficient,
-         raw_energy=raw_energy,
-         remove_dc_offset=remove_dc_offset,
-         round_to_power_of_two=round_to_power_of_two,
-         sample_frequency=sample_frequency,
-         snip_edges=snip_edges,
-         subtract_mean=False,
-         use_energy=use_energy,
-         use_log_fbank=True,
-         use_power=True,
-         vtln_high=vtln_high,
-         vtln_low=vtln_low,
-         vtln_warp=vtln_warp,
-         window_type=window_type,
-     )
-
-     if use_energy:
-         # size (m)
-         signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
-         # offset is 0 if htk_compat==True else 1
-         mel_offset = int(not htk_compat)
-         feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
-
-     # size (num_mel_bins, num_ceps)
-     dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
-
-     # size (m, num_ceps)
-     feature = feature.matmul(dct_matrix)
-
-     if cepstral_lifter != 0.0:
-         # size (1, num_ceps)
-         lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
-         feature *= lifter_coeffs.to(device=device, dtype=dtype)
-
-     # if use_energy then replace the last column for htk_compat == true else first column
-     if use_energy:
-         feature[:, 0] = signal_log_energy
-
-     if htk_compat:
-         energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
-         feature = feature[:, 1:]  # size (m, num_ceps - 1)
-         if not use_energy:
-             # scale on C0 (actually removing a scale we previously added that's
-             # part of one common definition of the cosine transform.)
-             energy *= math.sqrt(2)
-
-         feature = torch.cat((feature, energy), dim=1)
-
-     feature = _subtract_column_mean(feature, subtract_mean)
-     return feature
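
Under the same assumptions as the fbank sketch above (hypothetical eres2net.kaldi import path; signature mirrors torchaudio.compliance.kaldi.mfcc), mfcc can be exercised like this; note that use_energy=True keeps the shape (m, num_ceps) but replaces C0 with the frame log-energy:

import torch
from eres2net import kaldi  # hypothetical import path for this deleted module

waveform = torch.randn(1, 16000)  # (c, n) mono audio at 16 kHz
mfccs = kaldi.mfcc(waveform, num_ceps=13, num_mel_bins=23, dither=0.0)
print(mfccs.shape)  # torch.Size([98, 13])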
eres2net/pooling_layers.py DELETED
@@ -1,104 +0,0 @@
- # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
- # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
- """ This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
-
- import torch
- import torch.nn as nn
-
-
- class TAP(nn.Module):
-     """
-     Temporal average pooling: only the first-order mean is considered.
-     """
-     def __init__(self, **kwargs):
-         super(TAP, self).__init__()
-
-     def forward(self, x):
-         pooling_mean = x.mean(dim=-1)
-         # To be compatible with 2D input
-         pooling_mean = pooling_mean.flatten(start_dim=1)
-         return pooling_mean
-
-
- class TSDP(nn.Module):
-     """
-     Temporal standard deviation pooling: only the second-order std is considered.
-     """
-     def __init__(self, **kwargs):
-         super(TSDP, self).__init__()
-
-     def forward(self, x):
-         # The last dimension is the temporal axis
-         pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
-         pooling_std = pooling_std.flatten(start_dim=1)
-         return pooling_std
-
-
- class TSTP(nn.Module):
-     """
-     Temporal statistics pooling: concatenates the mean and std, as used in
-     the x-vector architecture.
-     Comment: simple concatenation cannot make full use of both statistics.
-     """
-     def __init__(self, **kwargs):
-         super(TSTP, self).__init__()
-
-     def forward(self, x):
-         # The last dimension is the temporal axis
-         pooling_mean = x.mean(dim=-1)
-         pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
-         pooling_mean = pooling_mean.flatten(start_dim=1)
-         pooling_std = pooling_std.flatten(start_dim=1)
-
-         stats = torch.cat((pooling_mean, pooling_std), 1)
-         return stats
-
-
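A quick shape check for the three statistics poolings, a sketch assuming the classes above are in scope and a ResNet-style 4-D feature map of shape (B, C, F, T):

import torch

x = torch.randn(4, 64, 10, 200)  # (batch, channels, freq, time)
print(TAP()(x).shape)   # torch.Size([4, 640]): mean over time, (C, F) flattened
print(TSDP()(x).shape)  # torch.Size([4, 640]): std over time
print(TSTP()(x).shape)  # torch.Size([4, 1280]): mean and std concatenated
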
- class ASTP(nn.Module):
-     """ Attentive statistics pooling: channel- and context-dependent
-     statistics pooling, first used in ECAPA-TDNN.
-     """
-     def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
-         super(ASTP, self).__init__()
-         self.global_context_att = global_context_att
-
-         # Use Conv1d with stride == 1 rather than Linear, so we don't
-         # need to transpose inputs.
-         if global_context_att:
-             self.linear1 = nn.Conv1d(
-                 in_dim * 3, bottleneck_dim,
-                 kernel_size=1)  # equals W and b in the paper
-         else:
-             self.linear1 = nn.Conv1d(
-                 in_dim, bottleneck_dim,
-                 kernel_size=1)  # equals W and b in the paper
-         self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
-                                  kernel_size=1)  # equals V and k in the paper
-
-     def forward(self, x):
-         """
-         x: a 3-dimensional tensor in TDNN-based architectures (B, F, T)
-         or a 4-dimensional tensor in ResNet architectures (B, C, F, T);
-         0-dim: batch dimension, last dim: time (frame) dimension
-         """
-         if len(x.shape) == 4:
-             x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
-         assert len(x.shape) == 3
-
-         if self.global_context_att:
-             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
-             context_std = torch.sqrt(
-                 torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
-             x_in = torch.cat((x, context_mean, context_std), dim=1)
-         else:
-             x_in = x
-
-         # Don't use ReLU here: with ReLU the attention can be hard to train to convergence.
-         alpha = torch.tanh(
-             self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
-         alpha = torch.softmax(self.linear2(alpha), dim=2)
-         mean = torch.sum(alpha * x, dim=2)
-         var = torch.sum(alpha * (x**2), dim=2) - mean**2
-         std = torch.sqrt(var.clamp(min=1e-10))
-         return torch.cat([mean, std], dim=1)
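
A usage sketch for ASTP; per the reshape in forward, in_dim must equal C * F when feeding ResNet-style 4-D features (the concrete sizes below are illustrative):

import torch

pool = ASTP(in_dim=64 * 10, bottleneck_dim=128)
x = torch.randn(4, 64, 10, 200)  # (B, C, F, T)
emb = pool(x)                    # attention-weighted mean and std over time
print(emb.shape)                 # torch.Size([4, 1280]) == (B, 2 * in_dim)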