File size: 7,501 Bytes
08aebf0
61d7bec
08aebf0
 
a6b1b80
08aebf0
a6b1b80
08aebf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6b1b80
 
 
 
 
 
 
 
be961e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6b1b80
be961e5
61d7bec
be961e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61d7bec
a6b1b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61d7bec
a6b1b80
 
 
 
61d7bec
a6b1b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61d7bec
a6b1b80
 
 
 
61d7bec
 
 
a6b1b80
61d7bec
a6b1b80
61d7bec
 
a6b1b80
61d7bec
 
 
a6b1b80
61d7bec
 
 
 
a63dad4
 
61d7bec
a6b1b80
a63dad4
61d7bec
 
 
a6b1b80
61d7bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338c48e
61d7bec
a6b1b80
 
61d7bec
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import logging
import os
import pathlib
import time
import tempfile
import platform
import gc
# Alias the foreign pathlib class to the native one, presumably so objects
# pickled on another OS (e.g. model checkpoints holding Path objects) can be
# unpickled here. The replaced class is kept in `temp`. The non-Windows branch
# applies to every POSIX system (macOS included), not only Linux, because none
# of them can instantiate WindowsPath.
_system = platform.system().lower()
if _system == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
else:
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
# Force the pure-Python protobuf implementation; this is set before the
# third-party imports below so it is in effect when they load protobuf.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import langid
langid.set_languages(['en', 'zh', 'ja'])

import torch
import torchaudio

import numpy as np

from data.tokenizer import (
    AudioTokenizer,
    tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from descriptions import *
from macros import *
from examples import *

import gradio as gr
from vocos import Vocos
from transformers import WhisperProcessor, WhisperForConditionalGeneration


# Required language tables.
# Language id -> special token wrapped around a prompt transcript.
lang2token = {"en": "<en>", "ja": "<ja>", "zh": "<zh>"}
# Language id -> integer code stored as `lang_code` in saved .npz prompts.
lang2code = {"en": 0, "ja": 1, "zh": 2}
# NOTE: do not rebind `langid` here -- the module imported and configured at
# the top of the file is used by make_npz_prompt via `langid.classify`.

# Mock / placeholder functions (replace with proper implementations for production).
def clear_prompts(directory=None, max_age=60.0):
    """Delete stale ``.npz`` prompt files from a directory (best effort).

    Args:
        directory: Directory to scan. Defaults to ``tempfile.gettempdir()``,
            the directory make_npz_prompt saves prompts into.
        max_age: Files last modified more than this many seconds ago are
            removed. Defaults to 60 seconds.

    Never raises for filesystem problems: an unreadable directory makes this
    a no-op, and a file that cannot be stat'ed or removed (for example one
    deleted concurrently) is skipped without aborting the sweep.
    """
    if directory is None:
        directory = tempfile.gettempdir()
    try:
        entries = os.listdir(directory)
    except OSError:
        return
    # One cutoff for the whole sweep so every file is judged against the same instant.
    cutoff = time.time() - max_age
    for entry in entries:
        if not entry.endswith(".npz"):
            continue
        file_path = os.path.join(directory, entry)
        try:
            if os.path.isfile(file_path) and os.stat(file_path).st_mtime < cutoff:
                os.remove(file_path)
        except OSError:
            # Best effort: the file may have vanished or be locked.
            continue
    gc.collect()
def transcribe_one(wav, sr):
    """Transcribe a waveform with Whisper and detect its language.

    Uses the module-level names ``whisper_processor``, ``whisper`` and
    ``device``. NOTE(review): they are not defined anywhere in this file as
    shown -- confirm they are provided (e.g. via one of the star imports).

    Args:
        wav: Waveform tensor; make_npz_prompt passes shape (1, samples).
        sr: Sample rate of ``wav`` in Hz; resampled to 16000 if different.

    Returns:
        ``(lang, text)``: the second generated token decoded with ``<``, ``|``
        and ``>`` stripped (used as a language id by the caller), and the
        transcript with "." appended when it does not already end in ASCII or
        CJK punctuation. An empty transcript is returned unchanged.
    """
    if sr != 16000:
        wav4trans = torchaudio.transforms.Resample(sr, 16000)(wav)
    else:
        wav4trans = wav

    input_features = whisper_processor(
        wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt"
    ).input_features

    # Generate token ids.
    predicted_ids = whisper.generate(input_features.to(device))
    # Assumed to be Whisper's language token (e.g. "<|en|>") at position 1.
    lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
    # Decode token ids to text.
    text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    print(text_pr)

    # Guard against an empty transcript (indexing [-1] would raise) and strip
    # all trailing whitespace, not just spaces, before checking punctuation.
    stripped = text_pr.strip()
    if stripped and stripped[-1] not in "?!.,。，？！、":
        text_pr += "."

    # Release large intermediates early.
    del wav4trans, input_features, predicted_ids
    gc.collect()
    return lang, text_pr
    
from data.tokenizer import (
    AudioTokenizer,
    tokenize_audio,
)

def _sanitize_prompt_name(name):
    """Reduce ``name`` to a bare file name; return None if nothing usable remains."""
    # Drop directory components (either separator style) so user input cannot
    # write outside the temp dir; os.path.join would even discard the temp dir
    # entirely if given an absolute name.
    base = os.path.basename(str(name or "").replace("\\", "/")).strip()
    if base in ("", ".", ".."):
        return None
    return base


def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    """Build a voice prompt from an audio clip and save it as a ``.npz`` file.

    Args:
        name: File name (without extension) for the prompt. Directory
            components are discarded; the file is written to
            ``tempfile.gettempdir()``.
        uploaded_audio: ``(sample_rate, samples)`` pair, or None.
        recorded_audio: Fallback used when ``uploaded_audio`` is None. It may
            not be audio at all (the UI wires a hidden Textbox here).
        transcript_content: Transcript of the clip. If empty or blank, the
            clip is transcribed and its language detected by transcribe_one;
            otherwise the language is detected with ``langid``.

    Returns:
        ``(message, file_path)`` on success, or ``(rejection_message, None)``.
    """
    clear_prompts()

    safe_name = _sanitize_prompt_name(name)
    if safe_name is None:
        return "Rejected, a file name is required", None

    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    # Check for a (sample_rate, samples) pair before unpacking.
    if not isinstance(audio_prompt, (tuple, list)) or len(audio_prompt) != 2:
        return "Rejected, no audio provided", None
    sr, wav_pr = audio_prompt
    if len(wav_pr) / sr > 15:
        return "Rejected, Audio too long (should be less than 15 seconds)", None

    # Out-of-place ops from here on so the caller's buffer is never modified.
    wav_pr = torch.as_tensor(wav_pr, dtype=torch.float32)
    if wav_pr.numel() == 0:
        return "Rejected, audio is empty", None
    peak = wav_pr.abs().max()
    if peak > 1:
        # Samples outside [-1, 1] (e.g. integer PCM) are peak-normalized.
        wav_pr = wav_pr / peak
    # Two-channel input shaped (samples, 2): keep the first channel only.
    if wav_pr.ndim == 2 and wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    if wav_pr.ndim != 2 or wav_pr.size(0) != 1:
        return "Rejected, audio must be mono or stereo", None

    transcript = "" if transcript_content is None else str(transcript_content)
    if not transcript.strip():
        lang_pr, text_pr = transcribe_one(wav_pr, sr)
    else:
        lang_pr = langid.classify(transcript)[0]
        text_pr = transcript.replace("\n", "")

    # Validate before indexing lang2token / lang2code, so an unsupported
    # language (e.g. one reported by Whisper) yields this message instead of
    # a KeyError.
    if lang_pr not in ("ja", "zh", "en"):
        return f"Prompt can only made with one of model-supported languages, got {lang_pr} instead", None
    lang_token = lang2token[lang_pr]
    text_pr = f"{lang_token}{text_pr}{lang_token}"
    message = f"Detected language: {lang_pr}\n Detected text: {text_pr}\n"

    # Tokenize audio.
    # NOTE(review): None is passed as the tokenizer; confirm tokenize_audio
    # accepts that, or pass the real AudioTokenizer instance.
    encoded_frames = tokenize_audio(None, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()

    # TODO: placeholder -- random ids stand in for real text tokens.
    text_tokens = np.random.randint(0, 100, (1, 50))

    # Save as an .npz file inside the temp dir.
    file_path = os.path.join(tempfile.gettempdir(), f"{safe_name}.npz")
    np.savez(
        file_path,
        audio_tokens=audio_tokens,
        text_tokens=text_tokens,
        lang_code=lang2code[lang_pr],
    )

    # Drop references to the large arrays before collecting.
    del encoded_frames, audio_tokens, text_tokens, wav_pr
    gc.collect()
    return message, file_path

def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
    """Placeholder synthesis endpoint for the "generate from .npz" tab.

    Only ``text`` is inspected; ``language``, ``accent``, ``preset_prompt``
    and ``prompt_file`` are accepted for UI wiring but not used yet.

    Returns:
        ``(message, (24000, audio))`` where ``audio`` is 24000 zero samples,
        or ``(rejection_message, None)`` when ``text`` exceeds 150 characters.
    """
    sample_rate = 24000
    too_long = len(text) > 150
    if too_long:
        return "Rejected, Text too long (should be less than 150 characters)", None
    silence = np.zeros(sample_rate)  # stand-in for real synthesized audio
    return f"Synthesized text: {text}", (sample_rate, silence)

def get_available_npz_files(directory=None):
    """List the ``.npz`` prompt files available for generation.

    Args:
        directory: Directory to scan; defaults to ``tempfile.gettempdir()``,
            where make_npz_prompt saves prompts.

    Returns:
        Sorted list of bare file names (not full paths) of regular files
        ending in ``.npz``. Returns an empty list if the directory cannot be
        read, so building the UI does not crash.
    """
    if directory is None:
        directory = tempfile.gettempdir()
    try:
        entries = os.listdir(directory)
    except OSError:
        return []
    # Sort for a stable dropdown order; skip directories that merely end in .npz.
    return sorted(
        entry
        for entry in entries
        if entry.endswith(".npz") and os.path.isfile(os.path.join(directory, entry))
    )

# Gradio application: one tab builds an .npz voice prompt from audio plus an
# optional transcript (make_npz_prompt); the other requests synthesis from a
# saved prompt (infer_from_prompt).
with gr.Blocks() as app:
    with gr.Tabs():
        # Tab: create an .npz prompt file
        with gr.Tab("NPZファイルを作成"):
            gr.Markdown("### 音声とテキストから .npz ファイルを作成")
            name = gr.Textbox(label="ファイル名", placeholder="保存する .npz ファイル名を入力")
            uploaded_audio = gr.Audio(label="アップロード音声", type="numpy")
            transcript_content = gr.Textbox(label="テキスト内容", placeholder="音声に対応する文字起こしを入力")
            result_message = gr.Textbox(label="結果", interactive=False)
            npz_output = gr.File(label=".npz ファイル")
            save_button = gr.Button("変換して保存")
            # NOTE(review): this hidden Textbox is wired to make_npz_prompt's
            # `recorded_audio` parameter. When no file is uploaded,
            # make_npz_prompt falls back to this textbox's text value, not to
            # recorded audio.
            dummy_input = gr.Textbox(visible=False)  # Dummy placeholder component
            
            save_button.click(
                make_npz_prompt,
                inputs=[name, uploaded_audio, dummy_input, transcript_content],
                outputs=[result_message, npz_output],
            )

        # Tab: generate speech from a saved .npz prompt file
        with gr.Tab("NPZファイルで生成"):
            gr.Markdown("### 保存した .npz ファイルから音声を生成")
            # NOTE(review): `choices` is evaluated once, when the UI is built,
            # and nothing here refreshes it, so prompts saved later are not
            # listed. The selected value is a bare file name (no directory).
            npz_files_dropdown = gr.Dropdown(
                label="利用可能な .npz ファイル", choices=get_available_npz_files(), interactive=True
            )
            text_input = gr.Textbox(label="生成するテキスト", placeholder="150文字以内のテキストを入力")
            language = gr.Radio(
                label="言語選択",
                choices=["auto-detect", "en", "ja", "zh"],
                value="auto-detect"
            )
            accent = gr.Radio(
                label="アクセント選択",
                choices=["no-accent", "en-accent", "ja-accent", "zh-accent"],
                value="no-accent"
            )
            preset_prompt = gr.Textbox(label="プロンプト名", placeholder="既存のプロンプトを選択")
            synthesis_message = gr.Textbox(label="結果", interactive=False)
            audio_output = gr.Audio(label="生成音声", type="numpy")
            generate_button = gr.Button("生成開始")

            # Input order matches infer_from_prompt's parameters:
            # (text, language, accent, preset_prompt, prompt_file).
            generate_button.click(
                infer_from_prompt,
                inputs=[text_input, language, accent, preset_prompt, npz_files_dropdown],
                outputs=[synthesis_message, audio_output],
            )

# NOTE(review): launch() runs at module import time (no __main__ guard);
# consider guarding it if this module is ever imported rather than executed.
app.launch()