ginipick committed on
Commit 5c398a5 · verified · 1 Parent(s): 47acf4f

Create app.py

Files changed (1)
  1. app.py +354 -0
app.py ADDED
@@ -0,0 +1,354 @@
import spaces
import logging
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import torchaudio
import os
import requests
from transformers import pipeline
import tempfile
import numpy as np
from einops import rearrange
import cv2
from scipy.io import wavfile
import librosa
import json
from typing import Optional, Tuple, List
import atexit

# Bypass the torch.load safety check via an environment variable (temporary workaround)
os.environ["TRANSFORMERS_ALLOW_UNSAFE_DESERIALIZATION"] = "1"

try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
log = logging.getLogger()

# CUDA setup
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

dtype = torch.bfloat16
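# Editor's note (assumption, not in the original commit): bfloat16 is only well
# supported on recent GPUs, and get_model() below enters torch.cuda.device(),
# which raises on CPU-only machines. A CPU-safe sketch would be:
#     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32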
# Model setup
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')

setup_eval_logging()
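# Editor's note (assumption): output_dir is later passed to launch(allowed_paths=...)
# but is never created. Ensuring it exists would be one line:
#     output_dir.mkdir(parents=True, exist_ok=True)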
# Translator setup - try safetensors first
try:
    # First check whether a safetensors checkpoint can be loaded
    translator = pipeline("translation",
                          model="Helsinki-NLP/opus-mt-ko-en",
                          device="cpu",
                          use_fast=True,  # use the fast tokenizer
                          trust_remote_code=False)
except Exception as e:
    log.warning(f"Failed to load translation model with safetensors: {e}")
    # Fallback: load again now that the environment variable is set
    try:
        translator = pipeline("translation",
                              model="Helsinki-NLP/opus-mt-ko-en",
                              device="cpu")
    except Exception as e2:
        log.error(f"Failed to load translation model: {e2}")
        translator = None

PIXABAY_API_KEY = "33492762-a28a596ec4f286f84cd328b17"
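# Editor's note (assumption): hardcoding an API key in source is unsafe; reading
# it from the environment instead would avoid committing the secret, e.g.:
#     PIXABAY_API_KEY = os.getenv("PIXABAY_API_KEY", "")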
def cleanup_temp_files():
    temp_dir = tempfile.gettempdir()
    for file in os.listdir(temp_dir):
        if file.endswith(('.mp4', '.flac')):
            try:
                os.remove(os.path.join(temp_dir, file))
            except OSError:  # was a bare except; only file-system errors are expected here
                pass

atexit.register(cleanup_temp_files)
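# Editor's note (assumption): this sweeps every .mp4/.flac in the system temp
# directory, including files other processes may own. A safer sketch would track
# only the paths this app creates and delete just those on exit:
#     _generated_files: list = []
#     def _cleanup_tracked():
#         for p in _generated_files:
#             try:
#                 os.remove(p)
#             except OSError:
#                 pass
#     atexit.register(_cleanup_tracked)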
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    with torch.cuda.device(device):
        seq_cfg = model.seq_cfg
        net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
        net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
        log.info(f'Loaded weights from {model.model_path}')

        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            synchformer_ckpt=model.synchformer_ckpt,
            enable_conditions=True,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False
        ).to(device, dtype).eval()

        return net, feature_utils, seq_cfg

net, feature_utils, seq_cfg = get_model()
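# Editor's note: torch.load(..., weights_only=True) above restricts loading to
# tensor data, so the unsafe-deserialization workaround near the top of the file
# presumably matters only for the Helsinki-NLP translation checkpoint.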
# translate_prompt (revised)
def translate_prompt(text):
    try:
        # Return the original text if the translator is unavailable
        if translator is None:
            return text

        if text and any(0x3131 <= ord(char) <= 0xD7A3 for char in text):
            # Run the translation on CPU
            with torch.no_grad():
                translation = translator(text)[0]['translation_text']
            return translation
        return text
    except Exception as e:
        logging.error(f"Translation error: {e}")
        return text
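# Editor's note: the 0x3131-0xD7A3 range runs from the Hangul Compatibility Jamo
# block through the end of the Hangul Syllables block, so any Korean character in
# the text triggers Korean-to-English translation before generation.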
# search_videos (revised)
@torch.no_grad()
def search_videos(query):
    try:
        # Run the translation on CPU
        query = translate_prompt(query)
        return search_pixabay_videos(query, PIXABAY_API_KEY)
    except Exception as e:
        logging.error(f"Video search error: {e}")
        return []
def search_pixabay_videos(query, api_key):
    try:
        base_url = "https://pixabay.com/api/videos/"
        params = {
            "key": api_key,
            "q": query,
            "per_page": 40
        }

        response = requests.get(base_url, params=params)
        if response.status_code == 200:
            data = response.json()
            return [video['videos']['large']['url'] for video in data.get('hits', [])]
        return []
    except Exception as e:
        logging.error(f"Pixabay API error: {e}")
        return []
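# Editor's note (assumption): requests.get() is called without a timeout, so a
# stalled connection would hang the search indefinitely; bounding the wait is a
# small change:
#     response = requests.get(base_url, params=params, timeout=10)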
@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip_frames, sync_frames, duration = load_video(video, duration)
    clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video,
               video_save_path,
               audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path
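# Editor's note: load_video() returns the clip duration, which overwrites the
# user-requested value; the effective length is therefore presumably clamped to
# the length of the uploaded video.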
@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
                  duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path
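# Editor's note (assumption): torchaudio.save() expects a 2-D (channels, time)
# tensor; audios appears to be batched as (batch, channels, samples), so
# audios.float().cpu()[0] already has the required shape.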
# CSS styles
custom_css = """
.gradio-container {
    background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    color: #e0e0e0;
}

.input-container, .output-container {
    background: rgba(40, 40, 40, 0.95);
    backdrop-filter: blur(10px);
    border-radius: 10px;
    padding: 20px;
    transform-style: preserve-3d;
    transition: transform 0.3s ease;
    border: 1px solid rgba(255, 255, 255, 0.1);
}

.input-container:hover {
    transform: translateZ(20px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.5);
}

.gallery-item {
    transition: transform 0.3s ease;
    border-radius: 8px;
    overflow: hidden;
    background: #2a2a2a;
}

.gallery-item:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(0,0,0,0.4);
}

.tabs {
    background: rgba(30, 30, 30, 0.95);
    border-radius: 10px;
    padding: 10px;
    border: 1px solid rgba(255, 255, 255, 0.05);
}

button {
    background: linear-gradient(45deg, #2196F3, #1976D2);
    border: none;
    border-radius: 5px;
    transition: all 0.3s ease;
    color: white;
}

button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 15px rgba(33,150,243,0.3);
}

textarea, input[type="text"], input[type="number"] {
    background: rgba(30, 30, 30, 0.95) !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    color: #e0e0e0 !important;
    border-radius: 5px !important;
}

label {
    color: #e0e0e0 !important;
}

.gallery {
    background: rgba(30, 30, 30, 0.95);
    padding: 15px;
    border-radius: 10px;
    border: 1px solid rgba(255, 255, 255, 0.05);
}
"""
css = """
footer {
    visibility: hidden;
}
""" + custom_css
# Build the Gradio interfaces
text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="Prompt (Korean supported)" if translator else "Prompt"),
        gr.Textbox(label="Negative Prompt"),
        gr.Number(label="Seed", value=0, precision=0),  # precision=0 keeps the seed an int for manual_seed
        gr.Number(label="Steps", value=25, precision=0),  # precision=0 keeps the step count an int
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    css=custom_css
)

video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(label="Prompt (Korean supported)" if translator else "Prompt"),
        gr.Textbox(label="Negative Prompt", value="music"),
        gr.Number(label="Seed", value=0, precision=0),  # precision=0 keeps the seed an int for manual_seed
        gr.Number(label="Steps", value=25, precision=0),  # precision=0 keeps the step count an int
        gr.Number(label="Guidance Scale", value=4.5),
        gr.Number(label="Duration (sec)", value=8),
    ],
    outputs=gr.Video(label="Generated Result"),
    css=custom_css
)

video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label="Search Query (Korean supported)" if translator else "Search Query"),
    outputs=gr.Gallery(label="Search Results", columns=4, rows=20),
    css=custom_css,
    api_name=False
)
# Main entry point
if __name__ == "__main__":
    # Warn if the translation model failed to load
    if translator is None:
        log.warning("Translation model failed to load. Korean translation will be disabled.")

    gr.TabbedInterface(
        [video_search_tab, video_to_audio_tab, text_to_audio_tab],
        ["Video Search", "Video-to-Audio", "Text-to-Audio"],
        theme="soft",
        css=css
    ).launch(allowed_paths=[str(output_dir)])  # allowed_paths is typed list[str]