import os import sys import time import gradio as gr import spaces from huggingface_hub import snapshot_download from pathlib import Path import tempfile from pydub import AudioSegment import traceback # افزودن پوشه src به مسیر سیستم sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))) from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed # --- تنظیمات اولیه --- set_seed(42) DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml" DEFAULT_MOTION_MEAN_STD_PATH = "src.datasets/mean.pt" DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav" OUTPUT_DIR = "gradio_output" WEIGHTS_DIR = "pretrain_weights" REPO_ID = "lixinyizju/moda" PERSIAN_EMOTION_MAP = { "خنثی (Neutral)": "Neutral", "خوشحال (Happy)": "Happy", "عصبانی (Angry)": "Angry", "غمگین (Sad)": "Sad", "متعجب (Surprise)": "Surprise" } # --- دانلود وزن‌ها --- def download_weights(): motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth") if not os.path.exists(motion_model_file): print("📥 در حال دانلود مدل‌ها...") try: snapshot_download(repo_id=REPO_ID, local_dir=WEIGHTS_DIR, local_dir_use_symlinks=False, resume_download=True) except Exception as e: print(f"Error downloading: {e}") # --- تبدیل صدا --- def ensure_wav_format(audio_path): if audio_path is None: return None audio_path = Path(audio_path) if audio_path.suffix.lower() == '.wav': return str(audio_path) try: audio = AudioSegment.from_file(audio_path) with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file: wav_path = tmp_file.name audio.export(wav_path, format='wav', parameters=["-ar", "16000", "-ac", "1"]) return wav_path except Exception as e: raise gr.Error(f"فرمت فایل صوتی نامعتبر است: {e}") # --- لود اولیه --- os.makedirs(OUTPUT_DIR, exist_ok=True) download_weights() print("⏳ در حال لود مدل...") try: pipeline = LiveVASAPipeline(cfg_path=DEFAULT_CFG_PATH, motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH) print("✅ مدل با موفقیت لود شد.") except Exception as e: print(f"❌ خطا در لود مدل: {e}") pipeline = None emo_name_to_id = {v: k for k, v in emo_map.items()} # --- **جاوا اسکریپت جدید برای نمایش پیام خطای زیبا** --- js_func = """ () => { const observer = new MutationObserver((mutations) => { mutations.forEach((mutation) => { if (mutation.addedNodes.length) { mutation.addedNodes.forEach((node) => { // فقط روی المان‌های مربوط به خطای گرادیو کار کن if (node.nodeType === 1 && (node.classList.contains('toast-body') || node.classList.contains('error'))) { const originalText = node.innerText; // Regex برای پیدا کردن اعداد زمان const regex = /(\d+)s requested vs. (\d+)s left/; const match = originalText.match(regex); // اگر متن خطا مربوط به Quota بود و قبلا ترجمه نشده بود if (match && !node.dataset.translated) { const requested = match[1]; const left = match[2]; // **ساخت کارت HTML زیبا** const prettyHtml = `

ظرفیت سرور تکمیل است!

سهمیه رایگان GPU شما برای پردازش یک ویدیوی ${requested} ثانیه‌ای کافی نیست.

اعتبار باقیمانده: ${left} ثانیه
`; // جایگزینی محتوای قدیمی با کارت جدید node.innerHTML = prettyHtml; // جلوگیری از ترجمه مجدد node.dataset.translated = 'true'; } } }); } }); }); observer.observe(document.body, { childList: true, subtree: true }); } """ # --- تابع اصلی --- @spaces.GPU(duration=120) def generate_motion(source_image_path, driving_audio_path, persian_emotion_name, cfg_scale, progress=gr.Progress(track_tqdm=True)): if pipeline is None: raise gr.Error("❌ مدل هنوز بارگذاری نشده است. لطفاً صفحه را رفرش کنید.") if source_image_path is None: raise gr.Error("⚠️ لطفاً تصویر چهره را انتخاب کنید.") if driving_audio_path is None: raise gr.Error("⚠️ لطفاً فایل صوتی را انتخاب کنید.") try: english_emo_name = PERSIAN_EMOTION_MAP.get(persian_emotion_name, "Neutral") emotion_id = emo_name_to_id.get(english_emo_name, 8) wav_audio_path = ensure_wav_format(driving_audio_path) result_video_path = pipeline.driven_sample( image_path=source_image_path, audio_path=wav_audio_path, cfg_scale=float(cfg_scale), emo=emotion_id, save_dir=".", smooth=False, silent_audio_path=DEFAULT_SILENT_AUDIO_PATH, ) final_path = Path(result_video_path) renamed_path = final_path.with_name(f"final_{int(time.time())}.mp4") final_path.rename(renamed_path) return str(renamed_path) except Exception as e: traceback.print_exc() raise gr.Error(f"❌ خطا در پردازش: {e}") # --- رابط کاربری --- css = """ .gradio-container {max-width: 900px !important; margin: auto !important; font-family: 'Tahoma', sans-serif;} h1, h2, h3, p, span, div {direction: rtl; text-align: right;} .center-text {text-align: center !important;} .toast-wrap { direction: rtl !important; } .toast-body { padding: 0 !important; } /* برای اینکه کارت ما کل فضا را بگیرد */ """ with gr.Blocks(theme=gr.themes.Soft(), css=css, title="MoDA Farsi", js=js_func) as demo: gr.HTML("

MoDA: ساخت ویدیو سخنگو

") with gr.Row(): with gr.Column(): source_image = gr.Image(label="۱. تصویر چهره", type="filepath", value="src/examples/reference_images/7.jpg", height=250) driving_audio = gr.Audio(label="۲. فایل صوتی", type="filepath", value="src/examples/driving_audios/5.wav") with gr.Accordion("⚙️ تنظیمات پیشرفته", open=True): with gr.Row(): emotion_dropdown = gr.Dropdown(label="حالت چهره", choices=list(PERSIAN_EMOTION_MAP.keys()), value="خنثی (Neutral)") cfg_slider = gr.Slider(label="دقت (CFG)", minimum=1.0, maximum=3.0, value=1.2, step=0.1) submit_btn = gr.Button("🎥 ساخت ویدیو", variant="primary", size="lg") with gr.Column(): output_video = gr.Video(label="خروجی نهایی", height=400, autoplay=True) submit_btn.click( fn=generate_motion, inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider], outputs=output_video ) if __name__ == "__main__": demo.launch()