MPCIRCLE committed on
Commit
f222c6c
·
verified ·
1 Parent(s): c026f7e

Update webui.py

Browse files

Changed the overall Streamlit code

Files changed (1) hide show
  1. webui.py +177 -120
webui.py CHANGED
@@ -1,133 +1,190 @@
1
- import streamlit as st
2
  import os
3
  import time
4
- import sys
5
- import torch
6
  from huggingface_hub import snapshot_download
 
7
 
8
- current_dir = os.path.dirname(os.path.abspath(__file__))
9
- sys.path.append(current_dir)
10
- sys.path.append(os.path.join(current_dir, "indextts"))
11
-
12
  from indextts.infer import IndexTTS
13
- from tools.i18n.i18n import I18nAuto
14
-
15
- # Initialize internationalization
16
- i18n = I18nAuto(language="en") # Changed to English
17
-
18
- # GPU configuration
19
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
-
21
- # App configuration
22
- st.set_page_config(page_title="echoAI - IndexTTS", layout="wide")
23
-
24
- # Create necessary directories
25
- os.makedirs("outputs/tasks", exist_ok=True)
26
- os.makedirs("prompts", exist_ok=True)
27
-
28
- # Download checkpoints if not exists
29
- if not os.path.exists("checkpoints"):
30
- snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
31
 
32
- # Load TTS model with GPU support
33
- @st.cache_resource
34
- def load_model():
35
- tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  tts.load_normalizer()
37
- if DEVICE == "cuda":
38
- tts.model.to(DEVICE) # Move model to GPU if available
39
  return tts
40
 
41
- tts = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Inference function with device awareness
44
- def infer(voice_path, text, output_path=None):
45
- if not output_path:
46
- output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
47
-
48
- # Ensure input is on correct device
49
- tts.infer(voice_path, text, output_path)
50
  return output_path
51
 
 
52
  # Streamlit UI
53
- st.title("echoAI - IndexTTS")
54
- st.markdown("""
55
- <h4 style='text-align: center;'>
56
- An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System
57
- </h4>
58
- <p style='text-align: center;'>
59
- <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
60
- </p>
61
- """, unsafe_allow_html=True)
62
-
63
- # Device status indicator
64
- st.sidebar.markdown(f"**Device:** {DEVICE.upper()}")
65
-
66
- # Main interface
67
- with st.container():
68
- st.header("Audio Generation") # Translated
69
-
70
- col1, col2 = st.columns(2)
71
-
72
- with col1:
73
- uploaded_audio = st.file_uploader(
74
- "Upload reference audio", # Translated
75
- type=["wav", "mp3", "ogg"],
76
- accept_multiple_files=False
77
- )
78
-
79
- input_text = st.text_area(
80
- "Input target text", # Translated
81
- height=150,
82
- placeholder="Enter text to synthesize..."
83
- )
84
-
85
- generate_btn = st.button("Generate Speech") # Translated
86
-
87
- with col2:
88
- if generate_btn and uploaded_audio and input_text:
89
- with st.spinner("Generating audio..."):
90
- # Save uploaded audio
91
- audio_path = os.path.join("prompts", uploaded_audio.name)
92
- with open(audio_path, "wb") as f:
93
- f.write(uploaded_audio.getbuffer())
94
-
95
- # Perform inference
96
- try:
97
- output_path = infer(audio_path, input_text)
98
- st.audio(output_path, format="audio/wav")
99
- st.success("Generation complete!")
100
-
101
- # Download button
102
- with open(output_path, "rb") as f:
103
- st.download_button(
104
- "Download Result", # Translated
105
- f,
106
- file_name=os.path.basename(output_path))
107
- except Exception as e:
108
- st.error(f"Error: {str(e)}")
109
- elif generate_btn:
110
- st.warning("Please upload an audio file and enter text first!") # Translated
111
-
112
- # Sidebar with additional info
113
- with st.sidebar:
114
- st.header("About echoAI")
115
- st.markdown("""
116
- ### Key Features:
117
- - Zero-shot voice cloning
118
- - Industrial-grade TTS
119
- - Efficient synthesis
120
- - Controllable output
121
- """)
122
-
123
- st.markdown("---")
124
- st.markdown("""
125
- ### Usage Instructions:
126
- 1. Upload a reference audio clip
127
- 2. Enter target text
128
- 3. Click 'Generate Speech'
129
- """)
130
-
131
- if __name__ == "__main__":
132
- # Cleanup old files if needed
133
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import time
3
+ import shutil # Added shutil for potentially cleaning old files if needed, though not used in this version
 
4
  from huggingface_hub import snapshot_download
5
+ import streamlit as st
6
 
7
+ # Imports from your package
8
+ # Ensure 'indextts' is correctly installed or available in your environment/requirements.txt
 
 
9
  from indextts.infer import IndexTTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # ------------------------------------------------------------------------------
12
+ # Configuration
13
+ # ------------------------------------------------------------------------------
14
+
15
+ # Where to store model checkpoints and outputs
16
+ # These paths are relative to the root directory of your Spaces repository
17
+ CHECKPOINT_DIR = "checkpoints"
18
+ OUTPUT_DIR = "outputs"
19
+ PROMPTS_DIR = "prompts" # Directory to save uploaded reference audio
20
+
21
+ # Ensure necessary directories exist. Hugging Face Spaces provides a writable filesystem.
22
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
23
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
24
+ os.makedirs(PROMPTS_DIR, exist_ok=True)
25
+
26
+ MODEL_REPO = "IndexTeam/IndexTTS-1.5"
27
+ CFG_FILENAME = "config.yaml"
28
+
29
# ------------------------------------------------------------------------------
# Model loading (cached so it only runs once per resource identifier)
# ------------------------------------------------------------------------------

@st.cache_resource(show_spinner="⏳ Loading model... This may take a moment.")
def load_tts_model():
    """Download the model snapshot and initialize the IndexTTS engine.

    Cached with ``st.cache_resource`` so the download and initialization run
    only once per Streamlit server process instead of on every rerun.

    Returns:
        IndexTTS: a ready-to-use TTS engine with its normalizer loaded.
    """
    # Download the model snapshot if not already present.
    # NOTE: the former local_dir_use_symlinks=False argument was dropped —
    # it is deprecated (ignored with a FutureWarning) in recent
    # huggingface_hub releases.
    snapshot_download(repo_id=MODEL_REPO, local_dir=CHECKPOINT_DIR)

    # Initialize the TTS object. The underlying IndexTTS library is expected
    # to pick CPU/GPU based on the installed backend — TODO confirm.
    tts = IndexTTS(
        model_dir=CHECKPOINT_DIR,
        cfg_path=os.path.join(CHECKPOINT_DIR, CFG_FILENAME),
    )
    # Load the normalizer the model needs for text preprocessing.
    tts.load_normalizer()
    # Progress is reported via the decorator's show_spinner instead of
    # st.write: Streamlit replays st.* calls made inside a cached function on
    # every cache hit, and a replayed st.write would also count as a Streamlit
    # command issued before st.set_page_config, which raises an exception.
    return tts
62
 
63
# Obtain the shared TTS engine. This statement executes on every script rerun,
# but thanks to st.cache_resource the decorated function body runs only the
# first time (or when its code/dependencies change).
tts = load_tts_model()
68
# ------------------------------------------------------------------------------
# Inference function
# ------------------------------------------------------------------------------

def run_inference(reference_audio_path: str, text: str) -> str:
    """Run TTS inference using a reference audio clip and target text.

    Args:
        reference_audio_path: Path on disk to the saved reference audio.
        text: The text to synthesize.

    Returns:
        Path to the generated ``.wav`` file inside ``OUTPUT_DIR``.

    Raises:
        FileNotFoundError: If the reference audio does not exist on disk.
    """
    if not os.path.exists(reference_audio_path):
        raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")

    # Nanosecond timestamp avoids filename collisions when two requests land
    # within the same second — int(time.time()) is too coarse on a shared
    # Space serving multiple users.
    output_filename = f"generated_{time.time_ns()}.wav"
    output_path = os.path.join(OUTPUT_DIR, output_filename)

    # Perform the TTS inference; runtime depends on the IndexTTS backend and
    # available hardware.
    tts.infer(reference_audio_path, text, output_path)

    # Optional future work: prune old files in OUTPUT_DIR/PROMPTS_DIR if disk
    # space on the Space becomes a concern.
    return output_path
95
 
96
# ------------------------------------------------------------------------------
# Streamlit UI
# ------------------------------------------------------------------------------

# NOTE(review): st.set_page_config must be the first Streamlit command in the
# script; confirm nothing above (e.g. the cached model loader) emits st.* calls
# before this line runs.
st.set_page_config(page_title="IndexTTS Demo", layout="wide")

st.markdown(
    """
    <h1 style="text-align: center;">IndexTTS: Zero-Shot Controllable & Efficient TTS</h1>
    <p style="text-align: center;">
        <a href="https://arxiv.org/abs/2502.05512" target="_blank">
            View the paper on arXiv (2502.05512)
        </a>
    </p>
    """,
    unsafe_allow_html=True
)

st.sidebar.header("Settings")
with st.sidebar.expander("🗂️ Output Directories"):
    st.write(f"- Checkpoints: `{CHECKPOINT_DIR}`")
    st.write(f"- Generated audio: `{OUTPUT_DIR}`")
    st.write(f"- Uploaded prompts: `{PROMPTS_DIR}`")
    st.info("These directories are located within your Space's persistent storage.")


st.header("1. Upload Reference Audio")
ref_audio_file = st.file_uploader(
    label="Upload a reference audio (wav or mp3)",
    type=["wav", "mp3"],
    help="This audio will condition the voice characteristics.",
    key="ref_audio_uploader"
)

ref_path = None  # Set once an upload has been saved to PROMPTS_DIR.

if ref_audio_file:
    # Persist the upload so the TTS engine can read it from disk.
    ref_filename = ref_audio_file.name
    ref_path = os.path.join(PROMPTS_DIR, ref_filename)

    # getbuffer() avoids an extra copy for large files.
    with open(ref_path, "wb") as f:
        f.write(ref_audio_file.getbuffer())

    st.success(f"Saved reference audio: `{ref_filename}`")
    # BUG FIX: the preview was always labelled audio/wav, which is wrong for
    # mp3 uploads — derive the MIME type from the file extension instead.
    mime = "audio/mp3" if ref_filename.lower().endswith(".mp3") else "audio/wav"
    st.audio(ref_path, format=mime)


st.header("2. Enter Text to Synthesize")
text_input = st.text_area(
    label="Enter the text you want to convert to speech",
    placeholder="Type your sentence here...",
    key="text_input_area"
)

# Button to trigger generation.
generate_button = st.button("Generate Speech", key="generate_tts_button")

# ------------------------------------------------------------------------------
# Trigger Inference and Display Results
# ------------------------------------------------------------------------------

# Runs only when the button is clicked; inputs are validated first.
if generate_button:
    if not ref_path or not os.path.exists(ref_path):
        st.error("Please upload a reference audio first.")
    elif not text_input or not text_input.strip():
        st.error("Please enter some text to synthesize.")
    else:
        with st.spinner("🚀 Generating speech..."):
            try:
                output_wav_path = run_inference(ref_path, text_input)

                # Verify the backend actually produced a file before playing.
                if os.path.exists(output_wav_path):
                    st.success("🎉 Done! Here’s your generated audio:")
                    st.audio(output_wav_path, format="audio/wav")
                    # Restore the download button the previous revision
                    # offered — it was lost in this rewrite.
                    with open(output_wav_path, "rb") as f:
                        st.download_button(
                            "Download Result",
                            f,
                            file_name=os.path.basename(output_wav_path),
                        )
                else:
                    st.error("Generation failed: Output file was not created.")

            except Exception as e:
                st.error(f"An error occurred during inference: {e}")
                # For debugging on Spaces, st.exception(e) would render the
                # full traceback in the app.

# Footer.
st.markdown("---")
st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.")