MPCIRCLE committed
Commit c21ab36 · verified · 1 Parent(s): 61222b7

Update webui.py


Changed the web UI from Gradio to Streamlit.

Files changed (1)
  1. webui.py +113 -61
webui.py CHANGED
@@ -1,81 +1,133 @@
-import spaces
+import streamlit as st
 import os
-import shutil
-import threading
 import time
 import sys
-
+import torch
 from huggingface_hub import snapshot_download
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(current_dir)
 sys.path.append(os.path.join(current_dir, "indextts"))
 
-import gradio as gr
 from indextts.infer import IndexTTS
 from tools.i18n.i18n import I18nAuto
 
-i18n = I18nAuto(language="zh_CN")
-MODE = 'local'
-snapshot_download("IndexTeam/IndexTTS-1.5",local_dir="checkpoints",)
-tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
+# Initialize internationalization
+i18n = I18nAuto(language="en") # Changed to English
+
+# GPU configuration
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# App configuration
+st.set_page_config(page_title="echoAI - IndexTTS", layout="wide")
 
-os.makedirs("outputs/tasks",exist_ok=True)
-os.makedirs("prompts",exist_ok=True)
+# Create necessary directories
+os.makedirs("outputs/tasks", exist_ok=True)
+os.makedirs("prompts", exist_ok=True)
 
-@spaces.GPU
-def infer(voice, text,output_path=None):
-    if not tts:
-        raise Exception("Model not loaded")
+# Download checkpoints if not exists
+if not os.path.exists("checkpoints"):
+    snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
+
+# Load TTS model with GPU support
+@st.cache_resource
+def load_model():
+    tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
+    tts.load_normalizer()
+    if DEVICE == "cuda":
+        tts.model.to(DEVICE) # Move model to GPU if available
+    return tts
+
+tts = load_model()
+
+# Inference function with device awareness
+def infer(voice_path, text, output_path=None):
     if not output_path:
         output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
-    tts.infer(voice, text, output_path)
+
+    # Ensure input is on correct device
+    tts.infer(voice_path, text, output_path)
     return output_path
 
-def gen_single(prompt, text):
-    output_path = infer(prompt, text)
-    return gr.update(value=output_path,visible=True)
-
-def update_prompt_audio():
-    update_button = gr.update(interactive=True)
-    return update_button
-
-
-with gr.Blocks() as demo:
-    mutex = threading.Lock()
-    gr.HTML('''
-    <h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
-
-    <p align="center">
-    <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
-    ''')
-    with gr.Tab("音频生成"):
-        with gr.Row():
-            os.makedirs("prompts",exist_ok=True)
-            prompt_audio = gr.Audio(label="请上传参考音频",key="prompt_audio",
-                                    sources=["upload","microphone"],type="filepath")
-            prompt_list = os.listdir("prompts")
-            default = ''
-            if prompt_list:
-                default = prompt_list[0]
-            input_text_single = gr.Textbox(label="请输入目标文本",key="input_text_single")
-            gen_button = gr.Button("生成语音",key="gen_button",interactive=True)
-            output_audio = gr.Audio(label="生成结果", visible=False,key="output_audio")
-
-    prompt_audio.upload(update_prompt_audio,
-                        inputs=[],
-                        outputs=[gen_button])
-
-    gen_button.click(gen_single,
-                     inputs=[prompt_audio, input_text_single],
-                     outputs=[output_audio])
-
-
-def main():
-    tts.load_normalizer()
-    demo.queue(20)
-    demo.launch(server_name="0.0.0.0")
+# Streamlit UI
+st.title("echoAI - IndexTTS")
+st.markdown("""
+<h4 style='text-align: center;'>
+An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System
+</h4>
+<p style='text-align: center;'>
+<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
+</p>
+""", unsafe_allow_html=True)
+
+# Device status indicator
+st.sidebar.markdown(f"**Device:** {DEVICE.upper()}")
+
+# Main interface
+with st.container():
+    st.header("Audio Generation") # Translated
+
+    col1, col2 = st.columns(2)
+
+    with col1:
+        uploaded_audio = st.file_uploader(
+            "Upload reference audio", # Translated
+            type=["wav", "mp3", "ogg"],
+            accept_multiple_files=False
+        )
+
+        input_text = st.text_area(
+            "Input target text", # Translated
+            height=150,
+            placeholder="Enter text to synthesize..."
+        )
+
+        generate_btn = st.button("Generate Speech") # Translated
+
+    with col2:
+        if generate_btn and uploaded_audio and input_text:
+            with st.spinner("Generating audio..."):
+                # Save uploaded audio
+                audio_path = os.path.join("prompts", uploaded_audio.name)
+                with open(audio_path, "wb") as f:
+                    f.write(uploaded_audio.getbuffer())
+
+                # Perform inference
+                try:
+                    output_path = infer(audio_path, input_text)
+                    st.audio(output_path, format="audio/wav")
+                    st.success("Generation complete!")
+
+                    # Download button
+                    with open(output_path, "rb") as f:
+                        st.download_button(
+                            "Download Result", # Translated
+                            f,
+                            file_name=os.path.basename(output_path))
+                except Exception as e:
+                    st.error(f"Error: {str(e)}")
+        elif generate_btn:
+            st.warning("Please upload an audio file and enter text first!") # Translated
+
+# Sidebar with additional info
+with st.sidebar:
+    st.header("About echoAI")
+    st.markdown("""
+    ### Key Features:
+    - Zero-shot voice cloning
+    - Industrial-grade TTS
+    - Efficient synthesis
+    - Controllable output
+    """)
+
+    st.markdown("---")
+    st.markdown("""
+    ### Usage Instructions:
+    1. Upload a reference audio clip
+    2. Enter target text
+    3. Click 'Generate Speech'
+    """)
 
 if __name__ == "__main__":
-    main()
-
+    # Cleanup old files if needed
+    pass
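
Note: the Streamlit version is launched with "streamlit run webui.py" rather than by executing the module directly (the __main__ block in the new file is a no-op). Below is a minimal sketch of exercising the same IndexTTS calls outside the UI, using only the functions that appear in the diff; it assumes the repo root is the working directory, that checkpoints/ already holds the IndexTeam/IndexTTS-1.5 snapshot, and that prompts/sample.wav is a hypothetical reference clip.

# Standalone smoke test of the IndexTTS calls used in webui.py.
# Run from the repo root so that indextts/ is importable.
import os
import time

from indextts.infer import IndexTTS

tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
tts.load_normalizer()

os.makedirs("outputs", exist_ok=True)
output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
# Same call signature as in the app: (reference voice, target text, output path).
tts.infer("prompts/sample.wav", "Hello from IndexTTS.", output_path)
print("Wrote", output_path)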