File size: 8,585 Bytes
b040570
 
 
 
a5bbb73
b040570
 
220c7ea
145d786
56bf13d
b040570
56bf13d
b040570
 
 
 
2ff5b31
b040570
 
 
70759be
b040570
 
 
 
 
faead0f
 
 
 
 
 
 
 
56bf13d
b040570
 
 
2ff5b31
b040570
a1efa23
b040570
2ff5b31
b040570
56bf13d
220c7ea
56bf13d
220c7ea
56bf13d
220c7ea
 
 
 
faead0f
220c7ea
 
2ff5b31
b040570
56bf13d
b040570
 
 
2ff5b31
b040570
56bf13d
2ff5b31
b040570
2ff5b31
b040570
 
70759be
b040570
70759be
0b650eb
a1efa23
f0c7d9e
 
 
 
70759be
f0c7d9e
 
145d786
70759be
 
 
 
 
 
 
 
 
 
 
f0c7d9e
 
 
 
70759be
 
 
 
 
 
f0c7d9e
 
 
70759be
 
f0c7d9e
70759be
 
 
f0c7d9e
 
 
0b650eb
 
 
a1efa23
0b650eb
 
 
 
 
a1efa23
faead0f
2ff5b31
0b650eb
2ff5b31
 
 
 
b040570
56bf13d
 
 
b040570
1f37d61
 
b040570
56bf13d
 
faead0f
56bf13d
 
 
a1efa23
b040570
145d786
faead0f
f0c7d9e
56bf13d
 
a1efa23
70759be
faead0f
 
f0c7d9e
145d786
56bf13d
 
1f37d61
f0c7d9e
a1efa23
2ff5b31
1f37d61
2ff5b31
56bf13d
 
2ff5b31
56bf13d
b040570
 
 
 
 
 
faead0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import sys
import time
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from pathlib import Path
import tempfile
from pydub import AudioSegment
import traceback

# افزودن پوشه src به مسیر سیستم
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed

# --- تنظیمات اولیه ---
set_seed(42)

DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
DEFAULT_MOTION_MEAN_STD_PATH = "src.datasets/mean.pt"
DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
OUTPUT_DIR = "gradio_output"
WEIGHTS_DIR = "pretrain_weights"
REPO_ID = "lixinyizju/moda"

PERSIAN_EMOTION_MAP = {
    "خنثی (Neutral)": "Neutral",
    "خوشحال (Happy)": "Happy",
    "عصبانی (Angry)": "Angry",
    "غمگین (Sad)": "Sad",
    "متعجب (Surprise)": "Surprise"
}

# --- دانلود وزن‌ها ---
def download_weights():
    motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")
    if not os.path.exists(motion_model_file):
        print("📥 در حال دانلود مدل‌ها...")
        try:
            snapshot_download(repo_id=REPO_ID, local_dir=WEIGHTS_DIR, local_dir_use_symlinks=False, resume_download=True)
        except Exception as e:
            print(f"Error downloading: {e}")

# --- تبدیل صدا ---
def ensure_wav_format(audio_path):
    if audio_path is None: return None
    audio_path = Path(audio_path)
    if audio_path.suffix.lower() == '.wav': return str(audio_path)
    try:
        audio = AudioSegment.from_file(audio_path)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
            wav_path = tmp_file.name
            audio.export(wav_path, format='wav', parameters=["-ar", "16000", "-ac", "1"])
        return wav_path
    except Exception as e:
        raise gr.Error(f"فرمت فایل صوتی نامعتبر است: {e}")

# --- لود اولیه ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
download_weights()

print("⏳ در حال لود مدل...")
try:
    pipeline = LiveVASAPipeline(cfg_path=DEFAULT_CFG_PATH, motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH)
    print("✅ مدل با موفقیت لود شد.")
except Exception as e:
    print(f"❌ خطا در لود مدل: {e}")
    pipeline = None

emo_name_to_id = {v: k for k, v in emo_map.items()}

# --- **جاوا اسکریپت جدید برای نمایش پیام خطای زیبا** ---
js_func = """
() => {
    const observer = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            if (mutation.addedNodes.length) {
                mutation.addedNodes.forEach((node) => {
                    // فقط روی المان‌های مربوط به خطای گرادیو کار کن
                    if (node.nodeType === 1 && (node.classList.contains('toast-body') || node.classList.contains('error'))) {
                        const originalText = node.innerText;
                        
                        // Regex برای پیدا کردن اعداد زمان
                        const regex = /(\d+)s requested vs. (\d+)s left/;
                        const match = originalText.match(regex);
                        
                        // اگر متن خطا مربوط به Quota بود و قبلا ترجمه نشده بود
                        if (match && !node.dataset.translated) {
                            const requested = match[1];
                            const left = match[2];
                            
                            // **ساخت کارت HTML زیبا**
                            const prettyHtml = `
                                <div style="display: flex; align-items: center; gap: 15px; font-family: 'Tahoma', sans-serif; direction: rtl; padding: 10px;">
                                    <div style="font-size: 2.5em; color: #dc3545;">⏳</div>
                                    <div>
                                        <h4 style="margin: 0; color: #5a6268; font-weight: bold;">ظرفیت سرور تکمیل است!</h4>
                                        <p style="margin: 5px 0 0 0; color: #6c757d; font-size: 0.9em;">
                                            سهمیه رایگان GPU شما برای پردازش یک ویدیوی <b>${requested} ثانیه‌ای</b> کافی نیست.
                                        </p>
                                        <div style="background-color: #f8d7da; border: 1px solid #f5c6cb; border-radius: 5px; padding: 5px 8px; margin-top: 10px; font-size: 0.85em;">
                                            اعتبار باقیمانده: <b>${left} ثانیه</b>
                                        </div>
                                    </div>
                                </div>
                            `;
                            
                            // جایگزینی محتوای قدیمی با کارت جدید
                            node.innerHTML = prettyHtml;
                            
                            // جلوگیری از ترجمه مجدد
                            node.dataset.translated = 'true'; 
                        }
                    }
                });
            }
        });
    });

    observer.observe(document.body, { childList: true, subtree: true });
}
"""

# --- تابع اصلی ---
@spaces.GPU(duration=120)
def generate_motion(source_image_path, driving_audio_path, persian_emotion_name, cfg_scale, progress=gr.Progress(track_tqdm=True)):
    if pipeline is None:
        raise gr.Error("❌ مدل هنوز بارگذاری نشده است. لطفاً صفحه را رفرش کنید.")
    if source_image_path is None:
        raise gr.Error("⚠️ لطفاً تصویر چهره را انتخاب کنید.")
    if driving_audio_path is None:
        raise gr.Error("⚠️ لطفاً فایل صوتی را انتخاب کنید.")
    try:
        english_emo_name = PERSIAN_EMOTION_MAP.get(persian_emotion_name, "Neutral")
        emotion_id = emo_name_to_id.get(english_emo_name, 8)
        wav_audio_path = ensure_wav_format(driving_audio_path)
        result_video_path = pipeline.driven_sample(
            image_path=source_image_path, audio_path=wav_audio_path, cfg_scale=float(cfg_scale),
            emo=emotion_id, save_dir=".", smooth=False, silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
        )
        final_path = Path(result_video_path)
        renamed_path = final_path.with_name(f"final_{int(time.time())}.mp4")
        final_path.rename(renamed_path)
        return str(renamed_path)
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"❌ خطا در پردازش: {e}")

# --- رابط کاربری ---
css = """
.gradio-container {max-width: 900px !important; margin: auto !important; font-family: 'Tahoma', sans-serif;}
h1, h2, h3, p, span, div {direction: rtl; text-align: right;}
.center-text {text-align: center !important;}
.toast-wrap { direction: rtl !important; }
.toast-body { padding: 0 !important; } /* برای اینکه کارت ما کل فضا را بگیرد */
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css, title="MoDA Farsi", js=js_func) as demo:
    gr.HTML("<div class='center-text'><h1>MoDA: ساخت ویدیو سخنگو</h1></div>")
    with gr.Row():
        with gr.Column():
            source_image = gr.Image(label="۱. تصویر چهره", type="filepath", value="src/examples/reference_images/7.jpg", height=250)
            driving_audio = gr.Audio(label="۲. فایل صوتی", type="filepath", value="src/examples/driving_audios/5.wav")
            with gr.Accordion("⚙️ تنظیمات پیشرفته", open=True):
                with gr.Row():
                    emotion_dropdown = gr.Dropdown(label="حالت چهره", choices=list(PERSIAN_EMOTION_MAP.keys()), value="خنثی (Neutral)")
                    cfg_slider = gr.Slider(label="دقت (CFG)", minimum=1.0, maximum=3.0, value=1.2, step=0.1)
            submit_btn = gr.Button("🎥 ساخت ویدیو", variant="primary", size="lg")
        with gr.Column():
            output_video = gr.Video(label="خروجی نهایی", height=400, autoplay=True)
    submit_btn.click(
        fn=generate_motion,
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider],
        outputs=output_video
    )

if __name__ == "__main__":
    demo.launch()