File size: 16,738 Bytes
249b0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e668b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249b0a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import tempfile
import time
from pathlib import Path
from typing import Optional, Tuple
import spaces

import gradio as gr
import numpy as np
import soundfile as sf
import torch

from dia.model import Dia

# Model selection: maps the display name offered in the UI model dropdown
# (keys) to the model identifier handed to load_dia_model / Dia.from_pretrained
# (values). The first entry is used as the default selection.
DIA_MODELS = {
    "Dhivehi Dia-1.6B": "alakxender/Dia-1.6B-dhivehi-ep1",
    #"Dhivehi 18k": "alakxender/Dia-1.6B-dhivehi-18k"
}

# Process-wide cache of already-loaded models, keyed by model identifier.
dia_models = {}

def load_dia_model(model_id):
    """Return the Dia model for ``model_id``, loading it on first use.

    The model is fetched with ``Dia.from_pretrained`` the first time an
    identifier is requested and stored in ``dia_models``; later calls with
    the same identifier return the stored instance.

    Args:
        model_id: Identifier passed verbatim to ``Dia.from_pretrained``.

    Returns:
        The cached (or freshly loaded) model object.
    """
    if model_id in dia_models:
        model = dia_models[model_id]
    else:
        print(f"Loading model {model_id}")
        model = Dia.from_pretrained(model_id)
        dia_models[model_id] = model
    # Printed on cache hits as well as after a fresh load.
    print(f"Loaded model {model_id}")
    return model

# Sample rate reported alongside every audio array returned to Gradio.
_OUTPUT_SAMPLE_RATE = 44100
# Longest accepted audio prompt, in seconds.
_MAX_PROMPT_SECONDS = 10.0
_EMPTY_PROMPT_WARNING = "Audio prompt seems empty or silent, ignoring prompt."


def _remove_audio_prompt_file(path: Optional[str]) -> None:
    """Best-effort deletion of the temporary audio prompt file; failures are only logged."""
    if path and Path(path).exists():
        try:
            Path(path).unlink()
            print(f"Deleted temporary audio prompt file: {path}")
        except OSError as e:
            print(
                f"Warning: Error deleting temporary audio prompt file {path}: {e}"
            )


def _audio_prompt_to_mono_float(audio_data: np.ndarray) -> np.ndarray:
    """Convert a raw prompt array to floating-point samples with a single channel."""
    if np.issubdtype(audio_data.dtype, np.integer):
        # Scale integer PCM by the dtype's maximum so samples land in about [-1, 1].
        max_val = np.iinfo(audio_data.dtype).max
        audio_data = audio_data.astype(np.float32) / max_val
    elif not np.issubdtype(audio_data.dtype, np.floating):
        gr.Warning(
            f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
        )
        try:
            audio_data = audio_data.astype(np.float32)
        except Exception as conv_e:
            raise gr.Error(
                f"Failed to convert audio prompt to float32: {conv_e}"
            ) from conv_e
    if audio_data.ndim > 1:
        # Down-mix stereo by averaging whichever axis has length 2.
        if audio_data.shape[0] == 2:
            audio_data = np.mean(audio_data, axis=0)
        elif audio_data.shape[1] == 2:
            audio_data = np.mean(audio_data, axis=1)
        else:
            gr.Warning(
                f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
            )
            audio_data = (
                audio_data[0]
                if audio_data.shape[0] < audio_data.shape[1]
                else audio_data[:, 0]
            )
        audio_data = np.ascontiguousarray(audio_data)
    return audio_data


def _save_audio_prompt(
    audio_prompt_input: Tuple[int, np.ndarray],
) -> Optional[str]:
    """Validate the audio prompt and write it to a temporary WAV file.

    Returns the path of the written file, or ``None`` when the prompt is
    missing/empty/silent and should be ignored. Raises ``gr.Error`` when the
    prompt is too long or cannot be converted or saved. No temporary file is
    left behind when this function raises.
    """
    sr, audio_data = audio_prompt_input
    # Guard before touching len()/size so a missing array is ignored, not a crash.
    if audio_data is None or audio_data.size == 0:
        gr.Warning(_EMPTY_PROMPT_WARNING)
        return None
    duration_sec = len(audio_data) / float(sr) if sr else 0
    if duration_sec > _MAX_PROMPT_SECONDS:
        raise gr.Error("Audio prompt must be 10 seconds or shorter.")
    # "Silent" means every sample is zero; max() == 0 alone would also reject
    # audio whose samples are all non-positive.
    if not np.any(audio_data):
        gr.Warning(_EMPTY_PROMPT_WARNING)
        return None

    # Convert first so a conversion failure cannot leak a temporary file.
    mono_audio = _audio_prompt_to_mono_float(audio_data)

    # Only reserve a unique name here; the handle is closed before soundfile
    # reopens the path, which is required on platforms that forbid opening a
    # file twice.
    with tempfile.NamedTemporaryFile(
        mode="wb", suffix=".wav", delete=False
    ) as f_audio:
        temp_path = f_audio.name
    try:
        sf.write(temp_path, mono_audio, sr, subtype="FLOAT")
    except Exception as write_e:
        print(f"Error writing temporary audio file: {write_e}")
        _remove_audio_prompt_file(temp_path)
        raise gr.Error(f"Failed to save audio prompt: {write_e}") from write_e
    print(f"Created temporary audio prompt file: {temp_path} (orig sr: {sr})")
    return temp_path


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Change playback speed by linearly resampling ``audio`` to a new length."""
    original_len = len(audio)
    speed_factor = max(0.1, min(speed_factor, 5.0))
    target_len = int(original_len / speed_factor)
    if target_len == original_len or target_len <= 0:
        print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
        return audio
    x_original = np.arange(original_len)
    x_resampled = np.linspace(0, original_len - 1, target_len)
    resampled = np.interp(x_resampled, x_original, audio).astype(np.float32)
    print(
        f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
    )
    return resampled


@spaces.GPU
def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    transcription_input: Optional[str],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
    model_name: str,
) -> Tuple[int, np.ndarray]:
    """Generate speech for ``text_input`` with the selected Dia model.

    Args:
        text_input: Text to synthesize; must contain non-whitespace characters.
        audio_prompt_input: Optional ``(sample_rate, samples)`` pair used as an
            audio prompt. Prompts longer than 10 seconds are rejected; empty
            or silent prompts are ignored with a warning.
        transcription_input: Optional transcript of the audio prompt. When
            non-blank it is placed before ``text_input`` in the text given to
            the model.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_tokens``.
        cfg_scale: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        cfg_filter_top_k: Forwarded to ``model.generate``.
        speed_factor: Playback speed multiplier, clamped to ``[0.1, 5.0]``.
        model_name: Key into ``DIA_MODELS`` selecting the model to use.

    Returns:
        ``(44100, samples)``. ``samples`` is int16 when the model produced
        float32/float64 audio; if the model produced nothing, a single
        float32 zero sample is returned.

    Raises:
        gr.Error: For empty text, an invalid audio prompt, or any failure
            during generation.
    """
    # Validate first so an empty request fails before any model is loaded.
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")
    model = load_dia_model(DIA_MODELS[model_name])

    temp_audio_prompt_path = None
    output_audio = (_OUTPUT_SAMPLE_RATE, np.zeros(1, dtype=np.float32))
    try:
        if audio_prompt_input is not None:
            temp_audio_prompt_path = _save_audio_prompt(audio_prompt_input)

        # The prompt audio corresponds to the start of the text, so its
        # transcript has to come before the text to be generated.
        if transcription_input and not transcription_input.isspace():
            combined_text = transcription_input.strip() + "\n" + text_input.strip()
        else:
            combined_text = text_input

        start_time = time.time()
        with torch.inference_mode():
            output_audio_np = model.generate(
                combined_text,
                max_tokens=max_new_tokens,
                cfg_scale=cfg_scale,
                temperature=temperature,
                top_p=top_p,
                cfg_filter_top_k=cfg_filter_top_k,
                use_torch_compile=False,
                audio_prompt_path=temp_audio_prompt_path,
            )
        end_time = time.time()
        print(f"Generation finished in {end_time - start_time:.2f} seconds.")

        if output_audio_np is None:
            print("\nGeneration finished, but no valid tokens were produced.")
            gr.Warning("Generation produced no output.")
        else:
            samples = _apply_speed_factor(output_audio_np, speed_factor)
            print(
                f"Audio conversion successful. Final shape: {samples.shape}, Sample Rate: {_OUTPUT_SAMPLE_RATE}"
            )
            if samples.dtype == np.float32 or samples.dtype == np.float64:
                # Clip to [-1, 1] and scale to 16-bit PCM for the Gradio player.
                samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
                print("Converted audio to int16 for Gradio output.")
            output_audio = (_OUTPUT_SAMPLE_RATE, samples)
    except gr.Error:
        # Already a user-facing message (e.g. prompt too long); do not wrap it
        # in a second "Inference failed" error.
        raise
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Inference failed: {e}") from e
    finally:
        _remove_audio_prompt_file(temp_audio_prompt_path)
    return output_audio

def get_dia_1_6B_tab():
    """Build the "Dia-1.6B" text-to-speech tab.

    Creates a ``gr.Tab`` containing the model dropdown, text / audio-prompt /
    transcription inputs, the generation-parameter sliders, the generate
    button wired to ``run_inference``, a set of clickable examples, and the
    usage guidelines. The default model is loaded while the UI is built so
    its configured audio length can seed the "Max New Tokens" slider.

    Returns nothing; the components are created inside the tab's context
    manager (presumably within a layout opened by the caller — confirm
    against the call site).
    """
    # NOTE(review): this stylesheet is not applied anywhere in this function;
    # the "dhivehi-text-nofont" class used below only takes effect if the
    # caller supplies equivalent CSS to its top-level layout — confirm.
    css = """
    #col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
    .dhivehi-text-nofont textarea {
        font-size: 18px !important;
        line-height: 1.8 !important;
        direction: rtl !important;
        text-align: right !important;
    }
    .dhivehi-text-nofont input {
        font-size: 18px !important;
        direction: rtl !important;
        text-align: right !important;
    }
    """
    # Pre-fill the text box from ./example.txt when it exists and is readable.
    default_text = ""
    example_txt_path = Path("./example.txt")
    if example_txt_path.exists():
        try:
            default_text = example_txt_path.read_text(encoding="utf-8").strip()
            if not default_text:
                default_text = "Example text file was empty."
        except Exception as e:
            print(f"Warning: Could not read example.txt: {e}")

    # The first configured model is the default everywhere in this tab.
    default_model_name = next(iter(DIA_MODELS))

    with gr.Tab("🎙️ Dia-1.6B"):
        gr.Markdown("# Dia Text-to-Speech Synthesis (Dia-1.6B)")
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=list(DIA_MODELS.keys()),
                    value=default_model_name,
                    label="Select Dia Model"
                )
                text_input = gr.Textbox(
                    label="Input Text",
                    placeholder="ލިޔެލަން",
                    value=default_text,
                    lines=5,
                    elem_classes=["dhivehi-text-nofont"]
                )
                audio_prompt_input = gr.Audio(
                    label="Audio Prompt (≤ 10 s, Optional)",
                    show_label=True,
                    sources=["upload", "microphone"],
                    type="numpy",
                )
                transcription_input = gr.Textbox(
                    label="Audio Prompt Transcription (Optional)",
                    placeholder="ޓްރާންސްކްރިޕްޓް ލިޔެލަން",
                    lines=3,
                    elem_classes=["dhivehi-text-nofont"]
                )
                with gr.Accordion("Generation Parameters", open=False):
                    # Loaded at build time only to read the configured audio
                    # length; the instance is cached for later inference.
                    default_model = load_dia_model(DIA_MODELS[default_model_name])
                    max_new_tokens = gr.Slider(
                        label="Max New Tokens (Audio Length)",
                        minimum=860,
                        maximum=3072,
                        value=getattr(getattr(default_model.config, 'data', None), 'audio_length', 1536),
                        step=50,
                        info="Controls the maximum length of the generated audio (more tokens = longer audio).",
                    )
                    cfg_scale = gr.Slider(
                        label="CFG Scale (Guidance Strength)",
                        minimum=1.0,
                        maximum=5.0,
                        value=3.0,
                        step=0.1,
                        info="Higher values increase adherence to the text prompt.",
                    )
                    temperature = gr.Slider(
                        label="Temperature (Randomness)",
                        minimum=1.0,
                        maximum=2.5,
                        value=1.8,
                        step=0.05,
                        info="Lower values make the output more deterministic, higher values increase randomness.",
                    )
                    top_p = gr.Slider(
                        label="Top P (Nucleus Sampling)",
                        minimum=0.70,
                        maximum=1.0,
                        value=0.95,
                        step=0.01,
                        info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
                    )
                    cfg_filter_top_k = gr.Slider(
                        label="CFG Filter Top K",
                        minimum=15,
                        maximum=100,
                        value=45,
                        step=1,
                        info="Top k filter for CFG guidance.",
                    )
                    speed_factor_slider = gr.Slider(
                        label="Speed Factor",
                        minimum=0.8,
                        maximum=1.0,
                        value=1.0,
                        step=0.02,
                        info="Adjusts the speed of the generated audio (1.0 = original speed).",
                    )
                generate_btn = gr.Button("Generate Audio", variant="primary")
            with gr.Column(scale=1):
                audio_output = gr.Audio(
                    label="Generated Audio",
                    type="numpy",
                    autoplay=False,
                )

        # Same component order as run_inference's parameters; shared by the
        # button and the examples so the two cannot drift apart.
        inference_inputs = [
            text_input,
            audio_prompt_input,
            transcription_input,
            max_new_tokens,
            cfg_scale,
            temperature,
            top_p,
            cfg_filter_top_k,
            speed_factor_slider,
            model_dropdown,
        ]
        generate_btn.click(
            run_inference,
            inputs=inference_inputs,
            outputs=[audio_output],
        )
        # Examples (optional, can be extended). Each row follows the order of
        # inference_inputs.
        examples_list = [
            [
                """[S1] އައްސަލާމު އަލައިކުމް. (clears throat) Good morning!
[S2] How are you today?
[S1] I'm fine, thanks. ކިހިނެއް ހާލު؟
[S2] (coughs) ކުޑަކޮށް ބަލިކޮށް މިއުޅެނީ
[S1] Oh okay. Get well soon...
[S2] Thanks! See you later... ފަހުން ދިމާވެލާނީ
[S1]""",
                None,
                "",
                1536,
                5.0,
                2.5,
                0.95,
                45,
                1.0,
                default_model_name,
            ],
            ["""[FEMALE-01] [S1] ގައުމަށް އައި މިނިވަން ނޫރާނީ... [S2] ދައުރުން މި ހަނދާންތައް އާކުރަނީ... [S1] އައުދާނަ އިތުރު އަބުޠާލުންނަށް... [S2] ޒިކުރާގެ މަލުން މި ވެދުން ކުރަނީ.""",
                None,
                "",
                1536,
                3.0,
                1.8,
                0.95,
                45,
                0.96,
                default_model_name
            ],
            ["""[MALE-01] [S1] މާޒީގެ އުޖާލާ މަންޒަރުތައް... [S2] މާރީތި އުފާވެރި ކުރެހުންތައް... [S2] ދާތީ އަދު ހާމަ ވަމުން ކުލަތައް... [S2] ތާރީޚު އަލުން މި އިޢާދަ ވަނީ.""",
                None,
                "",
                1536,
                3.0,
                1.8,
                0.95,
                45,
                0.96,
                default_model_name
            ],
        ]
        if examples_list:
            gr.Examples(
                examples=examples_list,
                inputs=inference_inputs,
                outputs=[audio_output],
                fn=run_inference,
                cache_examples=False,
                label="Examples (Click to Run)",
            )
        else:
            gr.Markdown("_(No examples configured or example prompt file missing)_")
        # Rendered unconditionally: previously this sat inside the `else`
        # branch above and was never shown because examples are always present.
        gr.Markdown(
            "---\n"
            "**General Guidelines:**\n"
            "- Keep input text length moderate\n"
            "  - Short input (corresponding to under 5s of audio) will sound unnatural\n"
            "  - Very long input (corresponding to over 20s of audio) will make the speech unnaturally fast\n\n"
            "- Use non-verbal tags sparingly, from the list in the README. Overusing or using unlisted non-verbals may cause weird artifacts\n\n"
            "- Always begin input text with [S1], and always alternate between [S1] and [S2] (i.e. [S1]... [S1]... is not good)\n\n"
            "**When using audio prompts (voice cloning):**\n"
            "- Provide the transcript of the to-be cloned audio before the generation text\n"
            "- Transcript must use [S1], [S2] speaker tags correctly:\n"
            "  - Single speaker: [S1]...\n"
            "  - Two speakers: [S1]... [S2]...\n"
            "- Duration of the to-be cloned audio should be 5~10 seconds for the best results\n"
            "  - (Keep in mind: 1 second ≈ 86 tokens)\n"
            "- Put [S1] or [S2] (the second-to-last speaker's tag) at the end of the audio to improve audio quality at the end"
        )
    # No explicit return needed for context manager pattern