File size: 10,953 Bytes
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
597cecf
 
 
 
 
 
 
39d2f14
 
 
 
 
 
597cecf
 
 
39d2f14
 
597cecf
 
 
 
 
39d2f14
597cecf
 
 
39d2f14
 
 
 
 
 
 
 
597cecf
 
 
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
 
 
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
 
 
 
 
 
 
 
 
39d2f14
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
597cecf
39d2f14
 
597cecf
39d2f14
 
597cecf
39d2f14
 
 
 
597cecf
39d2f14
 
 
 
 
597cecf
39d2f14
 
 
 
597cecf
39d2f14
597cecf
 
 
 
 
 
 
 
 
39d2f14
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
597cecf
 
 
39d2f14
 
 
597cecf
39d2f14
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
597cecf
39d2f14
597cecf
39d2f14
 
 
 
597cecf
39d2f14
 
 
 
 
 
597cecf
 
39d2f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
 
 
 
 
 
 
 
 
 
 
597cecf
39d2f14
597cecf
39d2f14
 
 
 
597cecf
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding

from f5_tts.model.modules import (AdaLayerNormZero_Final,
                                  ConvPositionEmbedding, DiTBlock, MMDiTBlock,
                                  TimestepEmbedding, get_pos_embed_indices,
                                  precompute_freqs_cis)
from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
                                list_str_to_tensor, mask_from_frac_lengths)

# text embedding


class TextEmbedding(nn.Module):
    """Embed text token ids and add a precomputed sinusoidal position table.

    Args:
        out_dim: embedding width.
        text_num_embeds: vocabulary size; one extra row is allocated because
            id 0 is reserved as the filler token.
    """

    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        self.text_embed = nn.Embedding(
            text_num_embeds + 1, out_dim
        )  # will use 0 as filler token

        # size of the precomputed positional table; forwarded to
        # get_pos_embed_indices as max_pos (presumably longer inputs are
        # clamped there — verify in f5_tts.model.modules)
        self.precompute_max_pos = 1024
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(out_dim, self.precompute_max_pos),
            persistent=False,
        )

    def forward(
        self, text: int["b nt"], drop_text=False  # noqa: F722
    ) -> float["b nt d"]:  # noqa: F722
        """Return token embeddings plus positional embeddings.

        Args:
            text: integer token ids; they are shifted by +1 so that id 0
                stays free for the filler token.
            drop_text: if true, every position is replaced by the filler id 0
                (used for classifier-free guidance).

        Returns:
            Float tensor of shape [b, nt, out_dim].
        """
        text = text + 1
        if drop_text:
            text = torch.zeros_like(text)
        text = self.text_embed(text)

        # sinus pos emb — build the start offsets on the same device as the
        # activations so the lookup into the freqs_cis buffer does not mix
        # CPU indices with a GPU table
        batch_start = torch.zeros(
            (text.shape[0],), dtype=torch.long, device=text.device
        )
        pos_idx = get_pos_embed_indices(
            batch_start, text.shape[1], max_pos=self.precompute_max_pos
        )
        text_pos_embed = self.freqs_cis[pos_idx]

        return text + text_pos_embed


# noised input & masked cond audio embedding


class AudioEmbedding(nn.Module):
    """Fuse the noised mel input with the masked conditioning mel.

    Both inputs are concatenated along the feature axis, projected to
    ``out_dim`` by a linear layer, and a convolutional positional embedding
    of that projection is added back as a residual.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # the projection sees [x ; cond], hence twice the input width
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(
        self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False
    ):  # noqa: F722
        """Return ``conv_pos_embed(h) + h`` where ``h = linear([x ; cond])``.

        Args:
            x: noised input audio features.
            cond: masked conditioning audio features, same layout as ``x``.
            drop_audio_cond: if true, ``cond`` is replaced by zeros
                (classifier-free guidance on the audio prompt).
        """
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond
        hidden = self.linear(torch.cat((x, audio_cond), dim=-1))
        return self.conv_pos_embed(hidden) + hidden


# Transformer backbone using MM-DiT blocks


class MMDiT(nn.Module):
    """Transformer backbone built from joint audio/text MM-DiT blocks.

    The noised mel input and the masked conditioning mel are fused by
    ``AudioEmbedding``; text ids are encoded either by a ``TextEncoder``
    (``text_encoder=True``) or by a plain ``TextEmbedding``. Both streams are
    then updated jointly by ``depth`` ``MMDiTBlock`` layers conditioned on the
    timestep embedding, and the audio stream is projected back to
    ``mel_dim`` features.

    Args:
        dim: hidden width shared by the audio and text streams.
        text_depth: number of blocks in the optional text encoder.
        depth: number of MM-DiT blocks.
        heads: attention heads per block.
        dim_head: width of each attention head (also the rotary dimension).
        dropout: dropout rate forwarded to the blocks.
        ff_mult: feed-forward expansion factor forwarded to the blocks.
        text_num_embeds: size of the text vocabulary.
        mel_dim: number of mel channels of the input and output.
        checkpoint_activations: if true, each MM-DiT block is recomputed
            during backward instead of keeping its activations (only while
            the module is in training mode), trading compute for memory.
        text_encoder: use a transformer ``TextEncoder`` instead of a bare
            ``TextEmbedding``.
    """

    def __init__(
        self,
        *,
        dim,
        text_depth=4,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        text_num_embeds=256,
        mel_dim=100,
        checkpoint_activations=False,
        text_encoder=True,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_encoder:
            self.text_encoder = TextEncoder(
                text_num_embeds=text_num_embeds,
                text_dim=dim,
                depth=text_depth,
                heads=heads,
                dim_head=dim_head,
                ff_mult=ff_mult,
                dropout=dropout,
            )
        else:
            self.text_encoder = None
            self.text_embed = TextEmbedding(dim, text_num_embeds)

        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                    # the last block only needs to update the audio stream
                    context_pre_only=i == depth - 1,
                )
                for i in range(depth)
            ]
        )
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.checkpoint_activations = checkpoint_activations

    @staticmethod
    def _run_block(block, x, c, t, mask, text_mask, rope_audio, rope_text):
        """Apply one MM-DiT block; takes the block explicitly so a checkpointed
        recompute always re-runs the same block (no late-bound loop variable)."""
        return block(
            x,
            c,
            t,
            mask=mask,
            src_mask=text_mask,
            rope=rope_audio,
            c_rope=rope_text,
        )

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
        text_mask: bool["b nt"] | None = None,  # noqa: F722
    ):
        """Run the backbone and return ``mel_dim`` features per audio frame.

        Args:
            x: noised input audio.
            cond: masked conditioning audio, same layout as ``x``.
            text: text token ids.
            time: diffusion time step, per sample or a 0-d scalar shared by
                the whole batch.
            drop_audio_cond: drop the audio condition (classifier-free guidance).
            drop_text: drop the text condition (classifier-free guidance).
            mask: optional audio mask, forwarded to every block as ``mask``.
            text_mask: optional text mask, forwarded to the text encoder and
                to every block as ``src_mask``.
        """
        batch = x.shape[0]
        if time.ndim == 0:
            # a scalar time step is shared by the whole batch
            time = time.repeat(batch)

        # t: conditioning (time), c: text context,
        # x: noised input audio fused with the masked cond audio
        t = self.time_embed(time)
        if self.text_encoder is not None:
            c = self.text_encoder(text, t, mask=text_mask, drop_text=drop_text)
        else:
            c = self.text_embed(text, drop_text=drop_text)

        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        rope_audio = self.rotary_embed.forward_from_seq_len(x.shape[1])
        rope_text = self.rotary_embed.forward_from_seq_len(text.shape[1])

        # NOTE: an earlier experiment (removed) appended an always-masked dummy
        # token to x / c and extended the rotary tables by one position.

        use_checkpoint = self.checkpoint_activations and self.training
        if use_checkpoint:
            from torch.utils.checkpoint import checkpoint

        for block in self.transformer_blocks:
            if use_checkpoint:
                # non-reentrant mode supports non-tensor args (block, None
                # masks, rope tuples) and tensors that do not require grad
                c, x = checkpoint(
                    self._run_block,
                    block,
                    x,
                    c,
                    t,
                    mask,
                    text_mask,
                    rope_audio,
                    rope_text,
                    use_reentrant=False,
                )
            else:
                c, x = self._run_block(
                    block, x, c, t, mask, text_mask, rope_audio, rope_text
                )

        x = self.norm_out(x, t)
        return self.proj_out(x)


class TextEncoder(nn.Module):
    """Text encoder: a ``TextEmbedding`` followed by a stack of ``DiTBlock``s.

    The blocks run text-only self-attention with rotary position embeddings.

    Args:
        text_num_embeds: size of the text vocabulary.
        text_dim: hidden width of the encoder.
        depth: number of ``DiTBlock`` layers.
        heads: attention heads per block.
        dim_head: width of each attention head (also the rotary dimension).
        ff_mult: feed-forward expansion factor.
        dropout: dropout rate inside each block.
    """

    def __init__(
        self,
        text_num_embeds: int,
        text_dim: int = 512,
        depth: int = 4,
        heads: int = 8,
        dim_head: int = 64,
        ff_mult: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Embeddings
        self.text_embed = TextEmbedding(text_dim, text_num_embeds)
        self.rotary_embed = RotaryEmbedding(dim_head)

        # Stack of self-attention blocks over the text sequence
        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=text_dim,
                    heads=heads,
                    dim_head=dim_head,
                    ff_mult=ff_mult,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        text: int["b nt"],  # noqa: F821
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        mask: bool["b nt"] | None = None,  # noqa: F821 F722
        drop_text: bool = False,
    ):
        """Encode text into hidden states of shape [b, nt, d].

        Args:
            text: text token ids.
            time: conditioning forwarded to every block as ``t``
                (``MMDiT`` passes its timestep embedding here).
            mask: optional text mask forwarded to every block.
            drop_text: if true, ``TextEmbedding`` replaces every id with the
                filler id 0 (classifier-free guidance).
        """
        seq_len = text.shape[1]

        # drop_text must be passed by keyword: TextEmbedding.forward's second
        # positional parameter is drop_text, so passing seq_len positionally
        # would drop the text on every call. Dropping is left to
        # TextEmbedding so dropped text maps to the filler id 0 (zeroing
        # here first would map it to id 1 after the +1 shift).
        hidden_states = self.text_embed(text, drop_text=drop_text)  # [b, nt, d]

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        # Pass through self-attention blocks (no cross-attention)
        for block in self.transformer_blocks:
            hidden_states = block(
                x=hidden_states,
                t=time,  # time conditioning for the adaptive norms in each block
                mask=mask,
                rope=rope,
            )
        return hidden_states


if __name__ == "__main__":
    # Smoke test: build an MMDiT and run one forward pass on dummy inputs.
    from f5_tts.model.utils import get_tokenizer

    bsz = 16

    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
    dataset_name = "Emilia_ZH_EN"
    if tokenizer != "custom":
        # non-custom tokenizers are given the dataset name in place of a path
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    text = ["hello world"] * bsz
    text_lens = torch.ones((bsz,), dtype=torch.long) * len("hello world")
    text_lens[-1] = 5  # shorter last sample so text masking is exercised
    # fall back to CPU so the smoke test also runs without a GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = bsz

    # handle text as string
    if isinstance(text, list):
        if exists(vocab_char_map):
            text = list_str_to_idx(text, vocab_char_map).to(device)
        else:
            text = list_str_to_tensor(text).to(device)
        assert text.shape[0] == batch

    time = torch.rand((batch,), device=device)
    text_mask = lens_to_mask(text_lens).to(device)

    # test MMDiT
    mel_dim = 80
    model = MMDiT(
        dim=512,
        text_depth=4,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        text_num_embeds=vocab_size,
        mel_dim=mel_dim,
    ).to(device)

    x = torch.rand((batch, 100, mel_dim), device=device)
    cond = torch.rand((batch, 100, mel_dim), device=device)
    lens = torch.ones((batch,), dtype=torch.long) * 100
    mask = lens_to_mask(lens).to(device)

    output = model(
        x,
        cond,
        text,
        time,
        drop_audio_cond=False,
        drop_text=False,
        mask=mask,
        text_mask=text_mask,
    )

    print(output.shape)  # [bsz, seq_len, mel_dim]