# dall-e-x / app.py
# Author: soiz1 - commit 08aebf0 ("Update app.py", verified)
import logging
import os
import pathlib
import time
import tempfile
import platform
import gc
# Cross-platform loading shim: swap the concrete pathlib path classes so that
# path objects created under the other OS's class can be rebuilt on this one.
# NOTE(review): presumably needed because model checkpoints were pickled on a
# different OS (PosixPath cannot normally be instantiated on Windows and vice
# versa); confirm against the checkpoint-loading code, which is not visible here.
# The replaced class is kept in ``temp``, but nothing visible here restores it,
# and on other systems (e.g. macOS) neither branch runs, so ``temp`` is unset.
if platform.system().lower() == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
# Force the pure-Python protobuf backend. It is set before the heavy imports
# below, presumably so it takes effect before any library imports protobuf.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import langid
# Restrict language identification to the three languages the prompt code
# supports (the same keys used by lang2token / lang2code).
langid.set_languages(['en', 'zh', 'ja'])
import torch
import torchaudio
import numpy as np
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from descriptions import *
from macros import *
from examples import *
import gradio as gr
from vocos import Vocos
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Required configuration.
# Language code -> marker token wrapped around the prompt text on both sides.
lang2token = {"en": "<en>", "ja": "<ja>", "zh": "<zh>"}
# Language code -> integer id saved as ``lang_code`` in prompt .npz files.
lang2code = {"en": 0, "ja": 1, "zh": 2}
# Do NOT rebind ``langid`` here: it must stay bound to the langid module
# imported at the top of the file, because make_npz_prompt calls
# ``langid.classify``.
# Mock / placeholder functions follow (implement properly for production).
def clear_prompts(directory=None, max_age=60):
    """Best-effort removal of stale ``.npz`` prompt files.

    Deletes every regular file whose name ends in ``.npz`` in ``directory``
    if its modification time is more than ``max_age`` seconds in the past.
    NOTE(review): with the default directory this scans the whole system temp
    directory, so ``.npz`` files not created by this app are removed too.

    Filesystem errors (``OSError``) never propagate. An unreadable directory
    is ignored. A file that cannot be inspected or removed (e.g. deleted
    concurrently, permission denied) is skipped, and the remaining files are
    still processed.

    Args:
        directory: Directory to scan; ``None`` means ``tempfile.gettempdir()``.
        max_age: Age threshold in seconds.
    """
    path = tempfile.gettempdir() if directory is None else directory
    # One cutoff for the whole scan, so every file is judged against the same instant.
    cutoff = time.time() - max_age
    try:
        entries = os.listdir(path)
    except OSError:
        return
    for entry in entries:
        if not entry.endswith(".npz"):
            continue
        filename = os.path.join(path, entry)
        try:
            if os.path.isfile(filename) and os.stat(filename).st_mtime < cutoff:
                os.remove(filename)
        except OSError:
            # Vanished or not removable: skip it and keep cleaning the rest.
            continue
    gc.collect()
def transcribe_one(wav, sr):
    """Transcribe a clip with Whisper and report the detected language.

    Relies on the module-level names ``whisper_processor``, ``whisper`` and
    ``device``. NOTE(review): none of them is defined anywhere in this file as
    shown; they must be created before this function runs, otherwise it
    raises NameError.

    Args:
        wav: Audio tensor. ``wav.squeeze(0)`` is what gets fed to the
            processor, so presumably shape ``(1, samples)``; confirm with callers.
        sr: Sample rate of ``wav``. The clip is resampled to 16 kHz if it differs.

    Returns:
        ``(lang, text)``. ``lang`` is the second generated token, decoded and
        stripped of ``<``, ``|`` and ``>`` characters (presumably Whisper's
        language token, e.g. ``"en"``). ``text`` is the decoded transcript, with
        ``"."`` appended unless it already ends in punctuation. An empty
        transcript becomes ``"."``.
    """
    if sr != 16000:
        wav4trans = torchaudio.transforms.Resample(sr, 16000)(wav)
    else:
        wav4trans = wav
    input_features = whisper_processor(wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt").input_features
    # Generate token ids.
    predicted_ids = whisper.generate(input_features.to(device))
    lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
    # Decode token ids to text.
    text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    # Print the recognized text.
    print(text_pr)
    stripped = text_pr.strip(" ")
    # Guard the empty transcript (e.g. silence): indexing [-1] on "" raises IndexError.
    if not stripped or stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."
    # Release large intermediates eagerly.
    del wav4trans, input_features, predicted_ids
    gc.collect()
    return lang, text_pr
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    """Create a voice-prompt ``.npz`` file from an audio clip and its transcript.

    Args:
        name: File name (without ``.npz``) typed by the user. It is untrusted,
            so only its last path component is used and the file is always
            written to the system temp directory.
        uploaded_audio: ``(sample_rate, samples)`` pair from the upload widget,
            or ``None``.
        recorded_audio: Fallback used when ``uploaded_audio`` is ``None``, in
            the same format. A falsy value (``None`` or ``""``) means no audio.
        transcript_content: Transcript of the clip.
            - Empty or blank: the clip is transcribed, and its language
              detected, by ``transcribe_one``.
            - Otherwise: the language is guessed with ``langid.classify``.

    Returns:
        ``(message, file_path)``. ``file_path`` is ``None`` when the input is
        rejected. The saved file holds ``audio_tokens``, ``text_tokens`` and
        ``lang_code``.
    """
    # Remove stale prompt files before creating a new one.
    clear_prompts()

    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    # The UI wires a hidden Textbox to recorded_audio, so "no audio" can arrive
    # as "" as well as None; unpacking either below would crash.
    if not audio_prompt:
        return "Rejected, no audio provided", None

    # Untrusted input: drop any directory part so "../x" cannot escape the
    # temp directory, and refuse names that end up empty.
    safe_name = os.path.basename(str(name or "").strip())
    if safe_name in ("", ".", ".."):
        return "Rejected, invalid file name", None

    sr, wav_pr = audio_prompt
    if len(wav_pr) == 0:
        return "Rejected, audio is empty", None
    if len(wav_pr) / sr > 15:
        return "Rejected, Audio too long (should be less than 15 seconds)", None

    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    # Samples outside [-1, 1] (presumably integer PCM) are peak-normalized.
    peak = wav_pr.abs().max()
    if peak > 1:
        wav_pr = wav_pr / peak
    # A 2-D clip with a trailing dimension of 2 is treated as (samples, 2)
    # stereo; keep the first channel only.
    if wav_pr.ndim == 2 and wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    # Downstream code expects a single-channel (1, samples) tensor.
    if wav_pr.ndim != 2 or wav_pr.size(0) != 1:
        return "Rejected, unsupported audio channel layout (use mono or stereo)", None

    if not str(transcript_content or "").strip():
        lang_pr, text_pr = transcribe_one(wav_pr, sr)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        text_pr = str(transcript_content).replace("\n", "")

    # Checked before the lang2token / lang2code lookups, so an unsupported
    # language (e.g. one reported by Whisper) yields this message, not a KeyError.
    if lang_pr not in ("ja", "zh", "en"):
        return f"Prompt can only made with one of model-supported languages, got {lang_pr} instead", None

    lang_token = lang2token[lang_pr]
    text_pr = f"{lang_token}{text_pr}{lang_token}"
    message = f"Detected language: {lang_pr}\n Detected text: {text_pr}\n"

    # Tokenize audio.
    # NOTE(review): the tokenizer argument is None here; confirm this
    # project's tokenize_audio accepts that.
    encoded_frames = tokenize_audio(None, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
    # Placeholder text tokens: random ids, NOT derived from text_pr.
    # TODO: replace with the real text tokenizer before relying on these prompts.
    text_tokens = np.random.randint(0, 100, (1, 50))

    # Save as an npz file.
    file_path = os.path.join(tempfile.gettempdir(), f"{safe_name}.npz")
    np.savez(file_path, audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])

    # Release large intermediates eagerly.
    del audio_tokens, text_tokens, encoded_frames, wav_pr, audio_prompt, uploaded_audio, recorded_audio
    gc.collect()
    return message, file_path
def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
    """Placeholder synthesis handler for the "generate from .npz" tab.

    Only the length of ``text`` is checked. ``language``, ``accent``,
    ``preset_prompt`` and ``prompt_file`` are accepted to match the UI wiring
    but are not used yet.

    Returns:
        ``(message, audio)``. ``audio`` is ``None`` when ``text`` exceeds 150
        characters. Otherwise it is ``(24000, samples)``: one second of silence
        at 24 kHz, standing in for real synthesized audio.
    """
    if len(text) <= 150:
        silence = np.zeros(24000)  # provisional audio output
        return f"Synthesized text: {text}", (24000, silence)
    return "Rejected, Text too long (should be less than 150 characters)", None
def get_available_npz_files(directory=None):
    """Return the sorted names of ``.npz`` files in ``directory``.

    Args:
        directory: Directory to list. ``None`` (the default) means the system
            temp directory, where this app saves its prompt files.

    Returns:
        Sorted list of bare file names (not full paths) of regular files ending
        in ``.npz``. Sorting gives the UI dropdown a stable order, since
        ``os.listdir`` returns entries in arbitrary order.
    """
    path = tempfile.gettempdir() if directory is None else directory
    return sorted(
        entry
        for entry in os.listdir(path)
        if entry.endswith(".npz") and os.path.isfile(os.path.join(path, entry))
    )
# Gradio application.
with gr.Blocks() as app:
    with gr.Tabs():
        # Tab 1: build a .npz prompt from audio + transcript.
        with gr.Tab("NPZファイルを作成"):
            gr.Markdown("### 音声とテキストから .npz ファイルを作成")
            name = gr.Textbox(label="ファイル名", placeholder="保存する .npz ファイル名を入力")
            uploaded_audio = gr.Audio(label="アップロード音声", type="numpy")
            transcript_content = gr.Textbox(label="テキスト内容", placeholder="音声に対応する文字起こしを入力")
            result_message = gr.Textbox(label="結果", interactive=False)
            npz_output = gr.File(label=".npz ファイル")
            save_button = gr.Button("変換して保存")
            # Hidden stand-in for make_npz_prompt's recorded_audio parameter:
            # this UI has no recording widget, so presumably it always sends "".
            dummy_input = gr.Textbox(visible=False)
            save_button.click(
                make_npz_prompt,
                inputs=[name, uploaded_audio, dummy_input, transcript_content],
                outputs=[result_message, npz_output],
            )
        # Tab 2: synthesize from a saved .npz file.
        with gr.Tab("NPZファイルで生成"):
            gr.Markdown("### 保存した .npz ファイルから音声を生成")
            npz_files_dropdown = gr.Dropdown(
                label="利用可能な .npz ファイル", choices=get_available_npz_files(), interactive=True
            )
            # The dropdown's choices are computed once, at startup. This button
            # re-scans so prompts saved later show up and removed ones disappear.
            refresh_button = gr.Button("ファイル一覧を更新")
            refresh_button.click(
                lambda: gr.update(choices=get_available_npz_files()),
                inputs=None,
                outputs=npz_files_dropdown,
            )
            text_input = gr.Textbox(label="生成するテキスト", placeholder="150文字以内のテキストを入力")
            language = gr.Radio(
                label="言語選択",
                choices=["auto-detect", "en", "ja", "zh"],
                value="auto-detect",
            )
            accent = gr.Radio(
                label="アクセント選択",
                choices=["no-accent", "en-accent", "ja-accent", "zh-accent"],
                value="no-accent",
            )
            preset_prompt = gr.Textbox(label="プロンプト名", placeholder="既存のプロンプトを選択")
            synthesis_message = gr.Textbox(label="結果", interactive=False)
            audio_output = gr.Audio(label="生成音声", type="numpy")
            generate_button = gr.Button("生成開始")
            generate_button.click(
                infer_from_prompt,
                inputs=[text_input, language, accent, preset_prompt, npz_files_dropdown],
                outputs=[synthesis_message, audio_output],
            )
app.launch()