# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
# MIT License
from typing import List, Union, Optional

import numpy as np
from numpy.lib.stride_tricks import as_strided
import librosa
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from .config import VoiceEncConfig
from .melspec import melspectrogram


def pack(arrays, seq_len: Optional[int]=None, pad_value=0):
    """
    Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of
    shape (B, T, ...) by padding each individual array on the right.

    :param arrays: a list of array-like objects of matching shapes except for the first axis.
    :param seq_len: the value of T. It must be at least the maximum of the lengths Ti of the
    arrays. Defaults to that maximum if None.
    :param pad_value: the value to pad the arrays with.
    :return: a (B, T, ...) tensor
    """
    if seq_len is None:
        seq_len = max(len(array) for array in arrays)
    else:
        assert seq_len >= max(len(array) for array in arrays)

    # Convert lists to np.array
    if isinstance(arrays[0], list):
        arrays = [np.array(array) for array in arrays]

    # Convert to tensor and handle device
    device = None
    if isinstance(arrays[0], torch.Tensor):
        tensors = arrays
        device = tensors[0].device
    else:
        tensors = [torch.as_tensor(array) for array in arrays]

    # Fill the packed tensor with the array data
    packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:])
    packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device)

    for i, tensor in enumerate(tensors):
        packed_tensor[i, :tensor.size(0)] = tensor

    return packed_tensor
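
# Illustrative use of pack() (hypothetical shapes, assuming 80-channel mels):
#   pack([np.zeros((100, 80)), np.zeros((60, 80))]).shape  ->  torch.Size([2, 100, 80])
# The second row is right-padded with pad_value over its last 40 frames.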


def get_num_wins(
    n_frames: int,
    step: int,
    min_coverage: float,
    hp: VoiceEncConfig,
):
    assert n_frames > 0
    win_size = hp.ve_partial_frames
    n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step)
    if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage:
        n_wins += 1
    target_n = win_size + step * (n_wins - 1)
    return n_wins, target_n
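
# Worked example for get_num_wins() (hypothetical values, assuming hp.ve_partial_frames == 160):
#   n_frames=400, step=80, min_coverage=0.8
#   divmod(400 - 160 + 80, 80) -> n_wins=4, remainder=0
#   coverage of an extra window: (0 + (160 - 80)) / 160 = 0.5 < 0.8, so no extra partial is added
#   target_n = 160 + 80 * (4 - 1) = 400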


def get_frame_step(
    overlap: float,
    rate: Optional[float],
    hp: VoiceEncConfig,
):
    # Compute how many frames separate two partial utterances
    assert 0 <= overlap < 1
    if rate is None:
        frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap)))
    else:
        frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames))
    assert 0 < frame_step <= hp.ve_partial_frames
    return frame_step
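
# E.g. with the default overlap=0.5 and an assumed hp.ve_partial_frames of 160, consecutive
# partials start round(160 * (1 - 0.5)) = 80 frames apart.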


def stride_as_partials(
    mel: np.ndarray,
    hp: VoiceEncConfig,
    overlap=0.5,
    rate: Optional[float]=None,
    min_coverage=0.8,
):
    """
    Takes unscaled mels in (T, M) format
    TODO: doc
    """
    assert 0 < min_coverage <= 1
    frame_step = get_frame_step(overlap, rate, hp)

    # Compute how many partials can fit in the mel
    n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp)

    # Trim or pad the mel spectrogram to match the number of partials
    if target_len > len(mel):
        mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0)))
    elif target_len < len(mel):
        mel = mel[:target_len]

    # Ensure the numpy array data is float32 and contiguous in memory
    mel = mel.astype(np.float32, order="C")

    # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping each other,
    # where N is the number of partials, P is the number of frames of each partial and M the
    # number of channels of the mel spectrograms.
    shape = (n_partials, hp.ve_partial_frames, hp.num_mels)
    strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1])
    partials = as_strided(mel, shape, strides)
    return partials
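
# Note: as_strided() returns overlapping views into the same mel buffer rather than copies. With
# the hypothetical values above (400 frames, frame_step=80, 160-frame partials, 80 mel channels),
# the result is a (4, 160, 80) array that shares memory with the input mel.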


class VoiceEncoder(nn.Module):
    def __init__(self, hp=VoiceEncConfig()):
        super().__init__()

        self.hp = hp

        # Network definition
        self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True)
        if hp.flatten_lstm_params:
            self.lstm.flatten_parameters()
        self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size)

        # Cosine similarity scaling (fixed initial parameter values)
        self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, mels: torch.FloatTensor):
        """
        Computes the embeddings of a batch of partial utterances.

        :param mels: a batch of unscaled mel spectrograms of the same duration, as a float32 tensor
        of shape (B, T, M) where T is hp.ve_partial_frames
        :return: the embeddings as a float32 tensor of shape (B, E) where E is
        hp.speaker_embed_size. Embeddings are L2-normalized and thus lie in the range [-1, 1].
        """
        if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1):
            raise ValueError(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}")

        # Pass the input through the LSTM layers
        _, (hidden, _) = self.lstm(mels)

        # Project the final hidden state
        raw_embeds = self.proj(hidden[-1])
        if self.hp.ve_final_relu:
            raw_embeds = F.relu(raw_embeds)

        # L2 normalize the embeddings.
        return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)

    def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: Optional[float]=None, min_coverage=0.8, batch_size=None):
        """
        Computes the embeddings of a batch of full utterances. Gradients are kept, so callers that
        do not need them should wrap the call in torch.inference_mode(), as embeds_from_mels() does.

        :param mels: (B, T, M) unscaled mels
        :param mel_lens: the length in frames of each mel in the batch
        :return: (B, E) embeddings on CPU
        """
        mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens

        # Compute where to split the utterances into partials
        frame_step = get_frame_step(overlap, rate, self.hp)
        n_partials, target_lens = zip(*(get_num_wins(mel_len, frame_step, min_coverage, self.hp) for mel_len in mel_lens))

        # Possibly pad the mels to reach the target lengths
        len_diff = max(target_lens) - mels.size(1)
        if len_diff > 0:
            pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32)
            mels = torch.cat((mels, pad.to(mels.device)), dim=1)

        # Group all partials together so that we can batch them easily
        partials = [
            mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames]
            for mel, n_partial in zip(mels, n_partials) for i in range(n_partial)
        ]
        assert all(partials[0].shape == partial.shape for partial in partials)
        partials = torch.stack(partials)

        # Forward the partials
        n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials))))
        partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu()

        # Reduce the partial embeds into full embeds and L2-normalize them
        slices = np.concatenate(([0], np.cumsum(n_partials)))
        raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])]
        raw_embeds = torch.stack(raw_embeds)
        embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)

        return embeds

    @staticmethod
    def utt_to_spk_embed(utt_embeds: np.ndarray):
        """
        Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a
        speaker embedding.
        """
        assert utt_embeds.ndim == 2
        utt_embeds = np.mean(utt_embeds, axis=0)
        return utt_embeds / np.linalg.norm(utt_embeds, 2)

    @staticmethod
    def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray):
        """
        Cosine similarity for L2-normalized utterance embeddings or speaker embeddings
        """
        embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x)
        embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y)
        return embeds_x @ embeds_y
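
    # For example, VoiceEncoder.voice_similarity(embeds_a, embeds_b) returns a scalar cosine
    # similarity; values closer to 1 indicate more similar voices.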

    def embeds_from_mels(
        self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs
    ):
        """
        Convenience function for deriving utterance or speaker embeddings from mel spectrograms.

        :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays.
        :param mel_lens: if passing mels as a tensor, individual mel lengths
        :param as_spk: whether to return utterance embeddings or a single speaker embedding
        :param kwargs: args for inference()

        :return: embeds as a (B, E) float32 numpy array if <as_spk> is False, else as a (E,) array
        """
        # Load mels in memory and pack them
        if isinstance(mels, list):
            mels = [np.asarray(mel) for mel in mels]
            assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels must all have the same number of channels (M)"
            mel_lens = [mel.shape[0] for mel in mels]
            mels = pack(mels)

        # Embed them
        with torch.inference_mode():
            utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy()

        return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds

    def embeds_from_wavs(
        self,
        wavs: List[np.ndarray],
        sample_rate,
        as_spk=False,
        batch_size=32,
        trim_top_db: Optional[float]=20,
        **kwargs
    ):
        """
        Wrapper around embeds_from_mels

        :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation
        """
        if sample_rate != self.hp.sample_rate:
            wavs = [
                librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast")
                for wav in wavs
            ]

        if trim_top_db:
            wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs]

        if "rate" not in kwargs:
            kwargs["rate"] = 1.3  # Resemble's default value.

        mels = [melspectrogram(w, self.hp).T for w in wavs]

        return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
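

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It embeds a short synthetic
    # waveform with randomly initialized weights, assuming the default VoiceEncConfig produces
    # mels that forward() accepts; real use would load trained weights and actual speech.
    encoder = VoiceEncoder()
    rng = np.random.default_rng(0)
    wav = rng.uniform(-0.1, 0.1, 16_000).astype(np.float32)  # ~1 s at an assumed 16 kHz
    utt_embeds = encoder.embeds_from_wavs([wav], sample_rate=16_000, trim_top_db=None)
    spk_embed = encoder.embeds_from_wavs([wav, wav], sample_rate=16_000, as_spk=True, trim_top_db=None)
    print(utt_embeds.shape, spk_embed.shape)  # expected: (1, E) and (E,) with E = hp.speaker_embed_size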