Update webui.py
Browse fileschanged overall streamlit code
webui.py
CHANGED
@@ -1,133 +1,190 @@
|
|
1 |
-
import streamlit as st
|
2 |
import os
|
3 |
import time
|
4 |
-
import
|
5 |
-
import torch
|
6 |
from huggingface_hub import snapshot_download
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
sys.path.append(os.path.join(current_dir, "indextts"))
|
11 |
-
|
12 |
from indextts.infer import IndexTTS
|
13 |
-
from tools.i18n.i18n import I18nAuto
|
14 |
-
|
15 |
-
# Initialize internationalization
|
16 |
-
i18n = I18nAuto(language="en") # Changed to English
|
17 |
-
|
18 |
-
# GPU configuration
|
19 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
-
|
21 |
-
# App configuration
|
22 |
-
st.set_page_config(page_title="echoAI - IndexTTS", layout="wide")
|
23 |
-
|
24 |
-
# Create necessary directories
|
25 |
-
os.makedirs("outputs/tasks", exist_ok=True)
|
26 |
-
os.makedirs("prompts", exist_ok=True)
|
27 |
-
|
28 |
-
# Download checkpoints if not exists
|
29 |
-
if not os.path.exists("checkpoints"):
|
30 |
-
snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
|
31 |
|
32 |
-
#
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
tts.load_normalizer()
|
37 |
-
|
38 |
-
tts.model.to(DEVICE) # Move model to GPU if available
|
39 |
return tts
|
40 |
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
# Inference function with device awareness
|
44 |
-
def infer(voice_path, text, output_path=None):
|
45 |
-
if not output_path:
|
46 |
-
output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
|
47 |
-
|
48 |
-
# Ensure input is on correct device
|
49 |
-
tts.infer(voice_path, text, output_path)
|
50 |
return output_path
|
51 |
|
|
|
52 |
# Streamlit UI
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
<
|
60 |
-
|
61 |
-
""
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
#
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Standard library ---
import os
import shutil  # reserved for cleaning up old generated files; unused in this version
import time

# --- Third-party ---
import streamlit as st
from huggingface_hub import snapshot_download

# --- Local package ---
# Ensure 'indextts' is correctly installed or available in your environment/requirements.txt
from indextts.infer import IndexTTS

# ------------------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------------------

# Storage locations, relative to the root directory of the Spaces repository.
CHECKPOINT_DIR = "checkpoints"  # downloaded model weights/config
OUTPUT_DIR = "outputs"          # generated .wav files
PROMPTS_DIR = "prompts"         # uploaded reference audio

# Create the directories up front so later open()/write calls cannot fail with
# a missing-directory error. Hugging Face Spaces provides a writable filesystem.
for _dir in (CHECKPOINT_DIR, OUTPUT_DIR, PROMPTS_DIR):
    os.makedirs(_dir, exist_ok=True)

# Hugging Face Hub repository that holds the IndexTTS checkpoint, and the
# config filename expected inside CHECKPOINT_DIR after download.
MODEL_REPO = "IndexTeam/IndexTTS-1.5"
CFG_FILENAME = "config.yaml"
# ------------------------------------------------------------------------------
# Model loading (cached so it only runs once per resource identifier)
# ------------------------------------------------------------------------------

# @st.cache_resource is Streamlit's recommended way to cache large objects such
# as ML models: the decorated function body runs once per server process and
# the same object is reused on every subsequent script re-run. This prevents
# re-loading the model on every user interaction.
@st.cache_resource(show_spinner=False)
def load_tts_model():
    """Download the model snapshot and initialize the IndexTTS engine.

    Returns:
        IndexTTS: a ready-to-use TTS engine with its normalizer loaded.

    Cached via st.cache_resource so the download/initialization only happens
    once per process.
    """
    # NOTE(review): this st.write() executes before st.set_page_config()
    # further down the script on first load — Streamlit requires
    # set_page_config to be the first Streamlit command; confirm this does
    # not raise on the deployed Streamlit version.
    st.write("⏳ Loading model... This may take a moment.")
    # Download the model snapshot if not already present. Since
    # huggingface_hub v0.23 files are written directly into local_dir and the
    # deprecated local_dir_use_symlinks flag is ignored, so it is omitted here.
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=CHECKPOINT_DIR,
    )
    # Initialize the TTS engine. The underlying IndexTTS library is expected
    # to pick the GPU when available — presumably via its own device probing;
    # verify against the indextts documentation.
    tts = IndexTTS(
        model_dir=CHECKPOINT_DIR,
        cfg_path=os.path.join(CHECKPOINT_DIR, CFG_FILENAME)
    )
    # Load the text normalizer / auxiliary data required before inference.
    tts.load_normalizer()
    st.write("✅ Model loaded!")
    return tts
|
# Load the TTS model via the cached loader.
# This line is executed on each Streamlit script re-run, but the function body
# only runs the first time (or if the cached resource is invalidated, e.g. by
# a code change to the function).
tts = load_tts_model()
# ------------------------------------------------------------------------------
# Inference function
# ------------------------------------------------------------------------------

def run_inference(reference_audio_path: str, text: str) -> str:
    """Run TTS inference conditioned on the uploaded reference audio.

    Args:
        reference_audio_path: Path to the saved reference audio file on disk.
        text: Target text to synthesize.

    Returns:
        Path to the generated .wav file inside OUTPUT_DIR.

    Raises:
        FileNotFoundError: If the reference audio does not exist on disk.
    """
    if not os.path.exists(reference_audio_path):
        raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")

    # Nanosecond-resolution timestamp: the previous int(time.time()) naming
    # collided (and silently overwrote output) when two generations happened
    # within the same second, e.g. with concurrent users on a shared Space.
    output_filename = f"generated_{time.time_ns()}.wav"
    output_path = os.path.join(OUTPUT_DIR, output_filename)

    # Perform the TTS inference; duration depends on the IndexTTS library,
    # the text length and the available hardware.
    tts.infer(reference_audio_path, text, output_path)

    # Optional future work: prune old files in outputs/ and prompts/ if the
    # Space's disk fills up (e.g. remove files older than N hours/days).
    return output_path
# ------------------------------------------------------------------------------
# Streamlit UI
# ------------------------------------------------------------------------------

st.set_page_config(page_title="IndexTTS Demo", layout="wide")

# Centered title plus a link to the IndexTTS paper.
st.markdown(
    """
    <h1 style="text-align: center;">IndexTTS: Zero-Shot Controllable & Efficient TTS</h1>
    <p style="text-align: center;">
        <a href="https://arxiv.org/abs/2502.05512" target="_blank">
            View the paper on arXiv (2502.05512)
        </a>
    </p>
    """,
    unsafe_allow_html=True
)

# Sidebar: show where files live on the Space's filesystem.
st.sidebar.header("Settings")
with st.sidebar.expander("🗂️ Output Directories"):
    st.write(f"- Checkpoints: `{CHECKPOINT_DIR}`")
    st.write(f"- Generated audio: `{OUTPUT_DIR}`")
    st.write(f"- Uploaded prompts: `{PROMPTS_DIR}`")
    st.info("These directories are located within your Space's persistent storage.")


# Step 1: upload the voice-conditioning reference audio.
st.header("1. Upload Reference Audio")
ref_audio_file = st.file_uploader(
    label="Upload a reference audio (wav or mp3)",
    type=["wav", "mp3"],
    help="This audio will condition the voice characteristics.",
    key="ref_audio_uploader"
)

ref_path = None  # Path of the saved reference audio once one is uploaded.

if ref_audio_file:
    # Persist the upload into the prompts directory so the TTS engine can
    # read it back from disk.
    ref_path = os.path.join(PROMPTS_DIR, ref_audio_file.name)
    with open(ref_path, "wb") as sink:
        # getbuffer() avoids an extra copy for large uploads.
        sink.write(ref_audio_file.getbuffer())
    st.success(f"Saved reference audio: `{ref_audio_file.name}`")
    # Echo the uploaded audio back so the user can verify it.
    st.audio(ref_path, format="audio/wav")


# Step 2: collect the text to synthesize.
st.header("2. Enter Text to Synthesize")
text_input = st.text_area(
    label="Enter the text you want to convert to speech",
    placeholder="Type your sentence here...",
    key="text_input_area"
)

generate_button = st.button("Generate Speech", key="generate_tts_button")

# ------------------------------------------------------------------------------
# Trigger Inference and Display Results
# ------------------------------------------------------------------------------

# Runs only on the script re-run triggered by the button click.
if generate_button:
    if not ref_path or not os.path.exists(ref_path):
        st.error("Please upload a reference audio first.")
    elif not text_input or not text_input.strip():
        st.error("Please enter some text to synthesize.")
    else:
        with st.spinner("🚀 Generating speech..."):
            try:
                generated_wav = run_inference(ref_path, text_input)
                # Guard against the backend silently producing nothing.
                if os.path.exists(generated_wav):
                    st.success("🎉 Done! Here’s your generated audio:")
                    st.audio(generated_wav, format="audio/wav")
                else:
                    st.error("Generation failed: Output file was not created.")
            except Exception as e:
                st.error(f"An error occurred during inference: {e}")
                # To surface the full traceback in the app while debugging:
                # st.exception(e)

# Footer.
st.markdown("---")
st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.")