Optimize model loading for HF Space
- Implement global model instance with one-time initialization (see the sketch after this list)
- Add progressive loading status updates with Chinese descriptions
- Remove manual model loading; auto-initialize on startup instead
- Add model status refresh functionality
- Improve user experience with clear progress indicators
- Cache model in memory to avoid reloading on each request
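
The bullets above describe one pattern: keep a single model object in a module-level variable, load it the first time the app is touched, expose a status string the UI can poll, and reuse the cached instance for every later request. The sketch below shows that pattern in isolation with Gradio's `demo.load` hook and a refresh button; the names (`load_heavy_model`, `initialize_once`, `current_status`, `_model`, `_status`) are illustrative placeholders, not the identifiers used in app.py, which wires the same idea through `AudioFoleyModel`, `initialize_model()`, and `get_model_status()` in the diff below.

```python
import gradio as gr

# Module-level cache: the process keeps these between requests,
# so the expensive load happens at most once per worker.
_model = None
_status = "not initialized"


def load_heavy_model():
    """Stand-in for the real download + weight-loading work."""
    return object()


def initialize_once() -> str:
    """Create the model on first call; later calls reuse the cached instance."""
    global _model, _status
    if _model is None:
        _status = "loading model..."
        _model = load_heavy_model()
        _status = "model ready"
    return _status


def current_status() -> str:
    """Return the latest loading status for the UI."""
    return _status


with gr.Blocks() as demo:
    status_box = gr.Textbox(label="Model status", interactive=False)
    refresh_btn = gr.Button("Refresh status")

    # demo.load runs when a browser session opens: the first visitor triggers
    # the one-time initialization, later visitors just see the cached status.
    demo.load(fn=initialize_once, outputs=status_box)
    refresh_btn.click(fn=current_status, outputs=status_box)

if __name__ == "__main__":
    demo.launch()
```

Because a Space keeps its Python process alive between requests, the module-level cache means the roughly 3 GB download mentioned in the UI note and the weight load happen once per worker rather than once per generation; the status textbox reflects progress only when the page loads or the refresh button is clicked, since the load itself runs synchronously.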
app.py CHANGED

@@ -99,7 +99,9 @@ class AudioFoleyModel:
         self.feature_utils = None
 
     def load_model(self, variant='large_44k', model_path=None):
-        """Load the hf_AC model"""
+        """Load the hf_AC model with progress updates"""
+        global model_loading_status
+
         try:
             if not HF_AC_AVAILABLE:
                 return "❌ hf_AC modules not available. Please install the hf_AC package."
@@ -108,16 +110,20 @@ class AudioFoleyModel:
             available_variants = list(all_model_cfg.keys()) if all_model_cfg else []
             return f"❌ Unknown model variant: {variant}. Available: {available_variants}"
 
+        # Step 1: Initialize model config
+        model_loading_status = "🔧 初始化模型配置..."
         log.info(f"Loading model variant: {variant}")
         self.model: ModelConfig = all_model_cfg[variant]
 
-        # Download model components
+        # Step 2: Download model components
+        model_loading_status = "📥 下载模型组件..."
         try:
            self.model.download_if_needed()
         except Exception as e:
             log.warning(f"Could not download model components: {e}")
 
-        #
+        # Step 3: Download main model weights
+        model_loading_status = "📥 下载主模型权重..."
         if not hasattr(self.model, 'model_path') or not self.model.model_path or not Path(self.model.model_path).exists():
             try:
                 from huggingface_hub import hf_hub_download
@@ -146,10 +152,12 @@ class AudioFoleyModel:
             self.model.model_path = Path(model_path)
             log.info(f"Using custom model path: {model_path}")
 
-        # Load network
+        # Step 4: Load neural network
+        model_loading_status = "🧠 加载神经网络..."
         self.net: MMAudio = get_my_mmaudio(self.model.model_name).to(self.device, self.dtype).eval()
 
-        # Load weights
+        # Step 5: Load weights
+        model_loading_status = "⚖️ 加载模型权重..."
         if hasattr(self.model, 'model_path') and self.model.model_path and Path(self.model.model_path).exists():
             try:
                 weights = torch.load(self.model.model_path, map_location=self.device, weights_only=True)
@@ -157,15 +165,19 @@ class AudioFoleyModel:
                 log.info(f'✅ Loaded weights from {self.model.model_path}')
             except Exception as e:
                 log.error(f"Failed to load weights: {e}")
+                model_loading_status = f"❌ Failed to load model weights: {e}"
+                return model_loading_status
         else:
             log.warning('⚠️ No model weights found, using default initialization')
+            model_loading_status = "⚠️ 模型组件已加载,但主权重不可用。某些功能可能受限。"
+            return model_loading_status
 
-        # Initialize flow matching
+        # Step 6: Initialize flow matching
+        model_loading_status = "🌊 初始化流匹配..."
         self.fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
 
-        # Initialize feature utils
+        # Step 7: Initialize feature utils
+        model_loading_status = "🔧 初始化特征工具..."
         try:
             self.feature_utils = FeaturesUtils(
                 tod_vae_ckpt=self.model.vae_path,
@@ -178,13 +190,17 @@ class AudioFoleyModel:
             self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()
         except Exception as e:
             log.error(f"Failed to initialize feature utils: {e}")
+            model_loading_status = f"❌ Failed to initialize feature utilities: {e}"
+            return model_loading_status
 
+        # Step 8: Complete
+        model_loading_status = "✅ 模型加载完成!可以开始生成音频。"
+        return model_loading_status
 
         except Exception as e:
-            error_msg = f"❌
+            error_msg = f"❌ 模型加载错误: {str(e)}"
             log.error(error_msg)
+            model_loading_status = error_msg
             return error_msg
 
     def generate_audio(self, video_file, prompt: str, negative_prompt: str = "",
@@ -298,11 +314,35 @@ class AudioFoleyModel:
             log.error(error_msg)
             return None, error_msg
 
-#
-audio_model =
+# Global model instance - initialized once
+audio_model = None
+model_loading_status = "未初始化"
+
+def initialize_model():
+    """Initialize model once at startup"""
+    global audio_model, model_loading_status
+
+    if audio_model is None:
+        try:
+            model_loading_status = "正在初始化模型..."
+            audio_model = AudioFoleyModel()
+            load_result = audio_model.load_model()
+            model_loading_status = load_result
+            return load_result
+        except Exception as e:
+            model_loading_status = f"❌ 模型初始化失败: {str(e)}"
+            return model_loading_status
+    else:
+        return "✅ 模型已加载"
 
 def generate_audio_interface(video_file, prompt, duration, cfg_strength):
     """Interface function for generating audio"""
+    global audio_model, model_loading_status
+
+    # Check if model is loaded
+    if audio_model is None or audio_model.net is None:
+        return None, "❌ 模型未加载,请等待初始化完成或刷新页面"
+
     # Use fixed seed for consistency in HF Space
     seed = 42
     negative_prompt = "" # Simplified interface
@@ -312,6 +352,11 @@ def generate_audio_interface(video_file, prompt, duration, cfg_strength):
     )
     return audio_path, message
 
+def get_model_status():
+    """Get current model loading status"""
+    global model_loading_status
+    return model_loading_status
+
 # Create Gradio interface
 with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
@@ -319,16 +364,23 @@ with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as d
 
     基于AI的视频音频生成工具。上传视频并提供文本描述,模型将生成匹配的音频内容。
 
-    **注意**:
+    **注意**: 模型会在启动时自动加载,首次使用需要下载约3GB的模型文件。
     """)
 
-    # Model status display
+    # Model status display - will be updated automatically
     model_status = gr.Textbox(
         label="模型状态",
-        value=
+        value=model_loading_status,
         interactive=False
     )
 
+    # Add a refresh button for status
+    refresh_status_btn = gr.Button("🔄 刷新状态", size="sm")
+    refresh_status_btn.click(
+        fn=get_model_status,
+        outputs=model_status
+    )
+
     with gr.Row():
         with gr.Column():
             video_input = gr.Video(
@@ -404,12 +456,18 @@ with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as d
     - "木地板上轻柔的脚步声"
     """)
 
-    # Auto-
+    # Auto-initialize model on startup
     demo.load(
-        fn=
+        fn=initialize_model,
         outputs=[model_status]
     )
 
+# Initialize model when module is imported (for HF Space)
+if HF_AC_AVAILABLE:
+    print("🚀 Starting model initialization...")
+    initialize_model()
+    print(f"📊 Model status: {model_loading_status}")
+
 if __name__ == "__main__":
     # HF Space will handle the server configuration
     demo.launch()
|