talking-head / app.py
Opera8's picture
Update app.py
70759be verified
import os
import sys
import time
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from pathlib import Path
import tempfile
from pydub import AudioSegment
import traceback
# افزودن پوشه src به مسیر سیستم
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed
# --- تنظیمات اولیه ---
set_seed(42)
DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
DEFAULT_MOTION_MEAN_STD_PATH = "src.datasets/mean.pt"
DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
OUTPUT_DIR = "gradio_output"
WEIGHTS_DIR = "pretrain_weights"
REPO_ID = "lixinyizju/moda"
PERSIAN_EMOTION_MAP = {
"خنثی (Neutral)": "Neutral",
"خوشحال (Happy)": "Happy",
"عصبانی (Angry)": "Angry",
"غمگین (Sad)": "Sad",
"متعجب (Surprise)": "Surprise"
}
# --- دانلود وزن‌ها ---
def download_weights():
motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")
if not os.path.exists(motion_model_file):
print("📥 در حال دانلود مدل‌ها...")
try:
snapshot_download(repo_id=REPO_ID, local_dir=WEIGHTS_DIR, local_dir_use_symlinks=False, resume_download=True)
except Exception as e:
print(f"Error downloading: {e}")
# --- تبدیل صدا ---
def ensure_wav_format(audio_path):
if audio_path is None: return None
audio_path = Path(audio_path)
if audio_path.suffix.lower() == '.wav': return str(audio_path)
try:
audio = AudioSegment.from_file(audio_path)
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
wav_path = tmp_file.name
audio.export(wav_path, format='wav', parameters=["-ar", "16000", "-ac", "1"])
return wav_path
except Exception as e:
raise gr.Error(f"فرمت فایل صوتی نامعتبر است: {e}")
# --- لود اولیه ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
download_weights()
print("⏳ در حال لود مدل...")
try:
pipeline = LiveVASAPipeline(cfg_path=DEFAULT_CFG_PATH, motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH)
print("✅ مدل با موفقیت لود شد.")
except Exception as e:
print(f"❌ خطا در لود مدل: {e}")
pipeline = None
emo_name_to_id = {v: k for k, v in emo_map.items()}
# --- **جاوا اسکریپت جدید برای نمایش پیام خطای زیبا** ---
js_func = """
() => {
const observer = new MutationObserver((mutations) => {
mutations.forEach((mutation) => {
if (mutation.addedNodes.length) {
mutation.addedNodes.forEach((node) => {
// فقط روی المان‌های مربوط به خطای گرادیو کار کن
if (node.nodeType === 1 && (node.classList.contains('toast-body') || node.classList.contains('error'))) {
const originalText = node.innerText;
// Regex برای پیدا کردن اعداد زمان
const regex = /(\d+)s requested vs. (\d+)s left/;
const match = originalText.match(regex);
// اگر متن خطا مربوط به Quota بود و قبلا ترجمه نشده بود
if (match && !node.dataset.translated) {
const requested = match[1];
const left = match[2];
// **ساخت کارت HTML زیبا**
const prettyHtml = `
<div style="display: flex; align-items: center; gap: 15px; font-family: 'Tahoma', sans-serif; direction: rtl; padding: 10px;">
<div style="font-size: 2.5em; color: #dc3545;">⏳</div>
<div>
<h4 style="margin: 0; color: #5a6268; font-weight: bold;">ظرفیت سرور تکمیل است!</h4>
<p style="margin: 5px 0 0 0; color: #6c757d; font-size: 0.9em;">
سهمیه رایگان GPU شما برای پردازش یک ویدیوی <b>${requested} ثانیه‌ای</b> کافی نیست.
</p>
<div style="background-color: #f8d7da; border: 1px solid #f5c6cb; border-radius: 5px; padding: 5px 8px; margin-top: 10px; font-size: 0.85em;">
اعتبار باقیمانده: <b>${left} ثانیه</b>
</div>
</div>
</div>
`;
// جایگزینی محتوای قدیمی با کارت جدید
node.innerHTML = prettyHtml;
// جلوگیری از ترجمه مجدد
node.dataset.translated = 'true';
}
}
});
}
});
});
observer.observe(document.body, { childList: true, subtree: true });
}
"""
# --- تابع اصلی ---
@spaces.GPU(duration=120)
def generate_motion(source_image_path, driving_audio_path, persian_emotion_name, cfg_scale, progress=gr.Progress(track_tqdm=True)):
if pipeline is None:
raise gr.Error("❌ مدل هنوز بارگذاری نشده است. لطفاً صفحه را رفرش کنید.")
if source_image_path is None:
raise gr.Error("⚠️ لطفاً تصویر چهره را انتخاب کنید.")
if driving_audio_path is None:
raise gr.Error("⚠️ لطفاً فایل صوتی را انتخاب کنید.")
try:
english_emo_name = PERSIAN_EMOTION_MAP.get(persian_emotion_name, "Neutral")
emotion_id = emo_name_to_id.get(english_emo_name, 8)
wav_audio_path = ensure_wav_format(driving_audio_path)
result_video_path = pipeline.driven_sample(
image_path=source_image_path, audio_path=wav_audio_path, cfg_scale=float(cfg_scale),
emo=emotion_id, save_dir=".", smooth=False, silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
)
final_path = Path(result_video_path)
renamed_path = final_path.with_name(f"final_{int(time.time())}.mp4")
final_path.rename(renamed_path)
return str(renamed_path)
except Exception as e:
traceback.print_exc()
raise gr.Error(f"❌ خطا در پردازش: {e}")
# --- رابط کاربری ---
css = """
.gradio-container {max-width: 900px !important; margin: auto !important; font-family: 'Tahoma', sans-serif;}
h1, h2, h3, p, span, div {direction: rtl; text-align: right;}
.center-text {text-align: center !important;}
.toast-wrap { direction: rtl !important; }
.toast-body { padding: 0 !important; } /* برای اینکه کارت ما کل فضا را بگیرد */
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css, title="MoDA Farsi", js=js_func) as demo:
gr.HTML("<div class='center-text'><h1>MoDA: ساخت ویدیو سخنگو</h1></div>")
with gr.Row():
with gr.Column():
source_image = gr.Image(label="۱. تصویر چهره", type="filepath", value="src/examples/reference_images/7.jpg", height=250)
driving_audio = gr.Audio(label="۲. فایل صوتی", type="filepath", value="src/examples/driving_audios/5.wav")
with gr.Accordion("⚙️ تنظیمات پیشرفته", open=True):
with gr.Row():
emotion_dropdown = gr.Dropdown(label="حالت چهره", choices=list(PERSIAN_EMOTION_MAP.keys()), value="خنثی (Neutral)")
cfg_slider = gr.Slider(label="دقت (CFG)", minimum=1.0, maximum=3.0, value=1.2, step=0.1)
submit_btn = gr.Button("🎥 ساخت ویدیو", variant="primary", size="lg")
with gr.Column():
output_video = gr.Video(label="خروجی نهایی", height=400, autoplay=True)
submit_btn.click(
fn=generate_motion,
inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider],
outputs=output_video
)
if __name__ == "__main__":
demo.launch()