learnmlf commited on
Commit
2c1dff6
·
1 Parent(s): 9c07de8

Optimize model loading for HF Space

Browse files

- Implement global model instance with one-time initialization
- Add progressive loading status updates with Chinese descriptions
- Remove manual model loading - auto-initialize on startup
- Add model status refresh functionality
- Improve user experience with clear progress indicators
- Cache model in memory to avoid reloading on each request

Files changed (1) hide show
  1. app.py +77 -19
app.py CHANGED
@@ -99,7 +99,9 @@ class AudioFoleyModel:
99
  self.feature_utils = None
100
 
101
  def load_model(self, variant='large_44k', model_path=None):
102
- """Load the hf_AC model"""
 
 
103
  try:
104
  if not HF_AC_AVAILABLE:
105
  return "❌ hf_AC modules not available. Please install the hf_AC package."
@@ -108,16 +110,20 @@ class AudioFoleyModel:
108
  available_variants = list(all_model_cfg.keys()) if all_model_cfg else []
109
  return f"❌ Unknown model variant: {variant}. Available: {available_variants}"
110
 
 
 
111
  log.info(f"Loading model variant: {variant}")
112
  self.model: ModelConfig = all_model_cfg[variant]
113
 
114
- # Download model components if needed
 
115
  try:
116
  self.model.download_if_needed()
117
  except Exception as e:
118
  log.warning(f"Could not download model components: {e}")
119
 
120
- # Try to download main model weights from HuggingFace
 
121
  if not hasattr(self.model, 'model_path') or not self.model.model_path or not Path(self.model.model_path).exists():
122
  try:
123
  from huggingface_hub import hf_hub_download
@@ -146,10 +152,12 @@ class AudioFoleyModel:
146
  self.model.model_path = Path(model_path)
147
  log.info(f"Using custom model path: {model_path}")
148
 
149
- # Load network
 
150
  self.net: MMAudio = get_my_mmaudio(self.model.model_name).to(self.device, self.dtype).eval()
151
 
152
- # Load weights
 
153
  if hasattr(self.model, 'model_path') and self.model.model_path and Path(self.model.model_path).exists():
154
  try:
155
  weights = torch.load(self.model.model_path, map_location=self.device, weights_only=True)
@@ -157,15 +165,19 @@ class AudioFoleyModel:
157
  log.info(f'✅ Loaded weights from {self.model.model_path}')
158
  except Exception as e:
159
  log.error(f"Failed to load weights: {e}")
160
- return f"❌ Failed to load model weights: {e}"
 
161
  else:
162
  log.warning('⚠️ No model weights found, using default initialization')
163
- return "⚠️ Model components loaded, but main weights not available. Some features may be limited."
 
164
 
165
- # Initialize flow matching
 
166
  self.fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
167
 
168
- # Initialize feature utils
 
169
  try:
170
  self.feature_utils = FeaturesUtils(
171
  tod_vae_ckpt=self.model.vae_path,
@@ -178,13 +190,17 @@ class AudioFoleyModel:
178
  self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()
179
  except Exception as e:
180
  log.error(f"Failed to initialize feature utils: {e}")
181
- return f"❌ Failed to initialize feature utilities: {e}"
 
182
 
183
- return "✅ Model loaded successfully!"
 
 
184
 
185
  except Exception as e:
186
- error_msg = f"❌ Error loading model: {str(e)}\n{traceback.format_exc()}"
187
  log.error(error_msg)
 
188
  return error_msg
189
 
190
  def generate_audio(self, video_file, prompt: str, negative_prompt: str = "",
@@ -298,11 +314,35 @@ class AudioFoleyModel:
298
  log.error(error_msg)
299
  return None, error_msg
300
 
301
- # Initialize model
302
- audio_model = AudioFoleyModel()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  def generate_audio_interface(video_file, prompt, duration, cfg_strength):
305
  """Interface function for generating audio"""
 
 
 
 
 
 
306
  # Use fixed seed for consistency in HF Space
307
  seed = 42
308
  negative_prompt = "" # Simplified interface
@@ -312,6 +352,11 @@ def generate_audio_interface(video_file, prompt, duration, cfg_strength):
312
  )
313
  return audio_path, message
314
 
 
 
 
 
 
315
  # Create Gradio interface
316
  with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as demo:
317
  gr.Markdown("""
@@ -319,16 +364,23 @@ with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as d
319
 
320
  基于AI的视频音频生成工具。上传视频并提供文本描述,模型将生成匹配的音频内容。
321
 
322
- **注意**: 首次使用时模型需要下载,请耐心等待。
323
  """)
324
 
325
- # Model status display
326
  model_status = gr.Textbox(
327
  label="模型状态",
328
- value="正在初始化模型...",
329
  interactive=False
330
  )
331
 
 
 
 
 
 
 
 
332
  with gr.Row():
333
  with gr.Column():
334
  video_input = gr.Video(
@@ -404,12 +456,18 @@ with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as d
404
  - "木地板上轻柔的脚步声"
405
  """)
406
 
407
- # Auto-load model on startup
408
  demo.load(
409
- fn=lambda: audio_model.load_model(),
410
  outputs=[model_status]
411
  )
412
 
 
 
 
 
 
 
413
  if __name__ == "__main__":
414
  # HF Space will handle the server configuration
415
  demo.launch()
 
99
  self.feature_utils = None
100
 
101
  def load_model(self, variant='large_44k', model_path=None):
102
+ """Load the hf_AC model with progress updates"""
103
+ global model_loading_status
104
+
105
  try:
106
  if not HF_AC_AVAILABLE:
107
  return "❌ hf_AC modules not available. Please install the hf_AC package."
 
110
  available_variants = list(all_model_cfg.keys()) if all_model_cfg else []
111
  return f"❌ Unknown model variant: {variant}. Available: {available_variants}"
112
 
113
+ # Step 1: Initialize model config
114
+ model_loading_status = "🔧 初始化模型配置..."
115
  log.info(f"Loading model variant: {variant}")
116
  self.model: ModelConfig = all_model_cfg[variant]
117
 
118
+ # Step 2: Download model components
119
+ model_loading_status = "📥 下载模型组件..."
120
  try:
121
  self.model.download_if_needed()
122
  except Exception as e:
123
  log.warning(f"Could not download model components: {e}")
124
 
125
+ # Step 3: Download main model weights
126
+ model_loading_status = "📥 下载主模型权重..."
127
  if not hasattr(self.model, 'model_path') or not self.model.model_path or not Path(self.model.model_path).exists():
128
  try:
129
  from huggingface_hub import hf_hub_download
 
152
  self.model.model_path = Path(model_path)
153
  log.info(f"Using custom model path: {model_path}")
154
 
155
+ # Step 4: Load neural network
156
+ model_loading_status = "🧠 加载神经网络..."
157
  self.net: MMAudio = get_my_mmaudio(self.model.model_name).to(self.device, self.dtype).eval()
158
 
159
+ # Step 5: Load weights
160
+ model_loading_status = "⚖️ 加载模型权重..."
161
  if hasattr(self.model, 'model_path') and self.model.model_path and Path(self.model.model_path).exists():
162
  try:
163
  weights = torch.load(self.model.model_path, map_location=self.device, weights_only=True)
 
165
  log.info(f'✅ Loaded weights from {self.model.model_path}')
166
  except Exception as e:
167
  log.error(f"Failed to load weights: {e}")
168
+ model_loading_status = f"❌ Failed to load model weights: {e}"
169
+ return model_loading_status
170
  else:
171
  log.warning('⚠️ No model weights found, using default initialization')
172
+ model_loading_status = "⚠️ 模型组件已加载,但主权重不可用。某些功能可能受限。"
173
+ return model_loading_status
174
 
175
+ # Step 6: Initialize flow matching
176
+ model_loading_status = "🌊 初始化流匹配..."
177
  self.fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
178
 
179
+ # Step 7: Initialize feature utils
180
+ model_loading_status = "🔧 初始化特征工具..."
181
  try:
182
  self.feature_utils = FeaturesUtils(
183
  tod_vae_ckpt=self.model.vae_path,
 
190
  self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()
191
  except Exception as e:
192
  log.error(f"Failed to initialize feature utils: {e}")
193
+ model_loading_status = f"❌ Failed to initialize feature utilities: {e}"
194
+ return model_loading_status
195
 
196
+ # Step 8: Complete
197
+ model_loading_status = "✅ 模型加载完成!可以开始生成音频。"
198
+ return model_loading_status
199
 
200
  except Exception as e:
201
+ error_msg = f"❌ 模型加载错误: {str(e)}"
202
  log.error(error_msg)
203
+ model_loading_status = error_msg
204
  return error_msg
205
 
206
  def generate_audio(self, video_file, prompt: str, negative_prompt: str = "",
 
314
  log.error(error_msg)
315
  return None, error_msg
316
 
317
+ # Global model instance - initialized once
318
+ audio_model = None
319
+ model_loading_status = "未初始化"
320
+
321
+ def initialize_model():
322
+ """Initialize model once at startup"""
323
+ global audio_model, model_loading_status
324
+
325
+ if audio_model is None:
326
+ try:
327
+ model_loading_status = "正在初始化模型..."
328
+ audio_model = AudioFoleyModel()
329
+ load_result = audio_model.load_model()
330
+ model_loading_status = load_result
331
+ return load_result
332
+ except Exception as e:
333
+ model_loading_status = f"❌ 模型初始化失败: {str(e)}"
334
+ return model_loading_status
335
+ else:
336
+ return "✅ 模型已加载"
337
 
338
  def generate_audio_interface(video_file, prompt, duration, cfg_strength):
339
  """Interface function for generating audio"""
340
+ global audio_model, model_loading_status
341
+
342
+ # Check if model is loaded
343
+ if audio_model is None or audio_model.net is None:
344
+ return None, "❌ 模型未加载,请等待初始化完成或刷新页面"
345
+
346
  # Use fixed seed for consistency in HF Space
347
  seed = 42
348
  negative_prompt = "" # Simplified interface
 
352
  )
353
  return audio_path, message
354
 
355
+ def get_model_status():
356
+ """Get current model loading status"""
357
+ global model_loading_status
358
+ return model_loading_status
359
+
360
  # Create Gradio interface
361
  with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as demo:
362
  gr.Markdown("""
 
364
 
365
  基于AI的视频音频生成工具。上传视频并提供文本描述,模型将生成匹配的音频内容。
366
 
367
+ **注意**: 模型会在启动时自动加载,首次使用需要下载约3GB的模型文件。
368
  """)
369
 
370
+ # Model status display - will be updated automatically
371
  model_status = gr.Textbox(
372
  label="模型状态",
373
+ value=model_loading_status,
374
  interactive=False
375
  )
376
 
377
+ # Add a refresh button for status
378
+ refresh_status_btn = gr.Button("🔄 刷新状态", size="sm")
379
+ refresh_status_btn.click(
380
+ fn=get_model_status,
381
+ outputs=model_status
382
+ )
383
+
384
  with gr.Row():
385
  with gr.Column():
386
  video_input = gr.Video(
 
456
  - "木地板上轻柔的脚步声"
457
  """)
458
 
459
+ # Auto-initialize model on startup
460
  demo.load(
461
+ fn=initialize_model,
462
  outputs=[model_status]
463
  )
464
 
465
+ # Initialize model when module is imported (for HF Space)
466
+ if HF_AC_AVAILABLE:
467
+ print("🚀 Starting model initialization...")
468
+ initialize_model()
469
+ print(f"📊 Model status: {model_loading_status}")
470
+
471
  if __name__ == "__main__":
472
  # HF Space will handle the server configuration
473
  demo.launch()