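"""Gradio demo: Facebook's MusicGen Large, accelerated by TheStage ANNA.

Generates short music clips from text prompts on an L40S GPU, producing a
small batch of variants per request and reporting the cost savings of the
compressed model against the original.
"""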
import gradio as gr
import torch
import gc
import numpy as np
import random
import os
import tempfile
import soundfile as sf
import time

os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'

from transformers import AutoProcessor, pipeline
from elastic_models.transformers import MusicgenForConditionalGeneration

MODEL_CONFIG = {
    'cost_per_hour': 1.8,  # $1.8 per hour on L40S
    'cost_savings_1000h': {
        'savings_dollars': 8.4,   # $8.4 saved per 1000 hours
        'savings_percent': 74.9,  # 74.9% savings
        'compressed_cost': 2.8,   # $2.8 for compressed
        'original_cost': 11.3,    # $11.3 for original
    },
    'batch_mode': True,
    'batch_size': 2,  # Number of variants to generate (2, 4, 6, etc.)
}

original_time_cache = {"original_time": 22.57}

# def set_seed(seed: int = 42):
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False


def cleanup_gpu():
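    """Release cached GPU memory and force garbage collection between runs."""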
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()


def cleanup_temp_files():
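    """Delete generated .wav files older than one hour from the system temp directory."""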
    import glob
    temp_dir = tempfile.gettempdir()
    cutoff_time = time.time() - 3600
    # Clean old generated music files
    patterns = [
        os.path.join(temp_dir, "tmp*.wav"),
        os.path.join(temp_dir, "generated_music_*.wav"),
        os.path.join(temp_dir, "musicgen_variant_*.wav"),
    ]
    for pattern in patterns:
        for temp_file in glob.glob(pattern):
            try:
                if os.path.getctime(temp_file) < cutoff_time:
                    os.remove(temp_file)
                    print(f"[CLEANUP] Removed old temp file: {temp_file}")
            except OSError:
                pass

_generator = None
_processor = None
_original_generator = None
_original_processor = None


def load_model():
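    """Lazily initialize the TheStage-compressed MusicGen pipeline (module-level singleton)."""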
    global _generator, _processor
    if _generator is None:
        print("[MODEL] Starting model initialization...")
        cleanup_gpu()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[MODEL] Using device: {device}")
        print("[MODEL] Loading processor...")
        _processor = AutoProcessor.from_pretrained(
            "facebook/musicgen-large"
        )
        print("[MODEL] Loading model...")
        model = MusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
            device=device,
            mode="S",
            __paged=True,
        )
        model.eval()
        print("[MODEL] Creating pipeline...")
        _generator = pipeline(
            task="text-to-audio",
            model=model,
            tokenizer=_processor.tokenizer,
            device=device,
        )
        print("[MODEL] Model initialization completed successfully")
    return _generator, _processor


def load_original_model():
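    """Lazily initialize the stock Hugging Face MusicGen pipeline for baseline comparison."""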
    global _original_generator, _original_processor
    if _original_generator is None:
        print("[ORIGINAL MODEL] Starting original model initialization...")
        cleanup_gpu()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[ORIGINAL MODEL] Using device: {device}")
        print("[ORIGINAL MODEL] Loading processor...")
        _original_processor = AutoProcessor.from_pretrained(
            "facebook/musicgen-large"
        )
        from transformers import MusicgenForConditionalGeneration as HFMusicgenForConditionalGeneration
        print("[ORIGINAL MODEL] Loading original model...")
        model = HFMusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
        ).to(device)
        model.eval()
        print("[ORIGINAL MODEL] Creating pipeline...")
        _original_generator = pipeline(
            task="text-to-audio",
            model=model,
            tokenizer=_original_processor.tokenizer,
            device=device,
        )
        print("[ORIGINAL MODEL] Original model initialization completed successfully")
    return _original_generator, _original_processor


def calculate_max_tokens(duration_seconds):
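    """Convert an audio duration in seconds to a generation token budget.

    MusicGen's codec emits 50 tokens per second of audio, so e.g. a 10 s clip
    requires 500 new tokens.
    """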
    token_rate = 50
    max_new_tokens = int(duration_seconds * token_rate)
    print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
    return max_new_tokens


def generate_music(text_prompt, duration=10, guidance_scale=3.0):
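    """Generate a single clip and return a (sample_rate, int16 ndarray) tuple for gr.Audio."""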
    try:
        generator, processor = load_model()
        print("[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")
        cleanup_gpu()
        # set_seed(42)  # Seeding is disabled so repeated runs produce varied outputs
        max_new_tokens = calculate_max_tokens(duration)
        generation_params = {
            'do_sample': True,
            'guidance_scale': guidance_scale,
            'max_new_tokens': max_new_tokens,
            'min_new_tokens': max_new_tokens,
            'cache_implementation': 'paged',
        }
        prompts = [text_prompt]
        outputs = generator(
            prompts,
            batch_size=1,
            generate_kwargs=generation_params
        )
        print("[GENERATION] Generation completed successfully")
        output = outputs[0]
        audio_data = output['audio']
        sample_rate = output['sampling_rate']
        print(f"[GENERATION] Audio shape: {audio_data.shape}")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Audio type: {type(audio_data)}")
        if hasattr(audio_data, 'cpu'):
            audio_data = audio_data.cpu().numpy()
            print(f"[GENERATION] Audio shape after tensor conversion: {audio_data.shape}")
        if len(audio_data.shape) == 3:
            audio_data = audio_data[0]
        if len(audio_data.shape) == 2:
            if audio_data.shape[0] < audio_data.shape[1]:
                audio_data = audio_data.T
            if audio_data.shape[1] > 1:
                audio_data = audio_data[:, 0]
            else:
                audio_data = audio_data.flatten()
        audio_data = audio_data.flatten()
        print(f"[GENERATION] Audio shape after flattening: {audio_data.shape}")
        max_val = np.max(np.abs(audio_data))
        if max_val > 0:
            audio_data = audio_data / max_val * 0.95
        audio_data = (audio_data * 32767).astype(np.int16)
        print(f"[GENERATION] Final audio shape: {audio_data.shape}")
        print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        timestamp = int(time.time() * 1000)
        temp_filename = f"generated_music_{timestamp}.wav"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
        sf.write(temp_path, audio_data, sample_rate)
        if os.path.exists(temp_path):
            file_size = os.path.getsize(temp_path)
            print(f"[GENERATION] Audio saved to: {temp_path}")
            print(f"[GENERATION] File size: {file_size} bytes")
            # Return a (sample_rate, ndarray) tuple rather than the file path
            print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
            return (sample_rate, audio_data)
        else:
            print(f"[ERROR] Failed to create audio file: {temp_path}")
            return None
    except Exception as e:
        print(f"[ERROR] Generation failed: {str(e)}")
        cleanup_gpu()
        return None


def calculate_generation_cost(generation_time_seconds, mode='S'):
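    """Convert a generation time to a dollar cost at the configured L40S hourly rate.

    The mode argument ('S' for compressed, 'original' for the baseline) is
    currently unused: both models are billed at the same hourly price.
    """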
    hours = generation_time_seconds / 3600
    cost_per_hour = MODEL_CONFIG['cost_per_hour']
    return hours * cost_per_hour


def calculate_cost_savings(compressed_time, original_time):
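    """Compare compressed vs. original generation cost and return absolute and percentage savings."""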
    compressed_cost = calculate_generation_cost(compressed_time, 'S')
    original_cost = calculate_generation_cost(original_time, 'original')
    savings = original_cost - compressed_cost
    savings_percent = (savings / original_cost * 100) if original_cost > 0 else 0
    return {
        'compressed_cost': compressed_cost,
        'original_cost': original_cost,
        'savings': savings,
        'savings_percent': savings_percent,
    }


def get_fixed_savings_message():
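    """Format the precomputed 1000-hour savings figures from MODEL_CONFIG as a banner string."""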
    config = MODEL_CONFIG['cost_savings_1000h']
    return f"💰 **Cost Savings for generation batch size 4 on L40S (1000h)**: ${config['savings_dollars']:.1f}" \
           f" ({config['savings_percent']:.1f}%) - Compressed: ${config['compressed_cost']:.1f} " \
           f"vs Original: ${config['original_cost']:.1f}"


def get_cache_key(prompt, duration, guidance_scale):
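    """Build a cache key from the prompt hash and generation settings (currently unused)."""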
return f"{hash(prompt)}_{duration}_{guidance_scale}" | |
def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mode="compressed"): | |
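    """Generate a batch of variants for one prompt and return six audio slots plus a status message.

    model_mode is currently ignored: the compressed model is always used. Slots
    beyond the configured batch size are filled with None so the fixed
    six-output Gradio wiring always receives a full tuple.
    """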
    try:
        generator, processor = load_model()
        model_name = "Compressed (S)"
        print(f"[GENERATION] Starting generation using {model_name} model...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")
        print(f"[GENERATION] Batch mode: {MODEL_CONFIG['batch_mode']}")
        print(f"[GENERATION] Batch size: {MODEL_CONFIG['batch_size']}")
        cleanup_gpu()
        # set_seed(42)  # Seeding is disabled so variants differ between runs
        max_new_tokens = calculate_max_tokens(duration)
        generation_params = {
            'do_sample': True,
            'guidance_scale': guidance_scale,
            'max_new_tokens': max_new_tokens,
            'min_new_tokens': max_new_tokens,
            'cache_implementation': 'paged',
        }
        batch_size = MODEL_CONFIG['batch_size'] if MODEL_CONFIG['batch_mode'] else 1
        prompts = [text_prompt] * batch_size
        start_time = time.time()
        outputs = generator(
            prompts,
            batch_size=batch_size,
            generate_kwargs=generation_params
        )
        generation_time = time.time() - start_time
        print(f"[GENERATION] Generation completed in {generation_time:.2f}s")
        audio_variants = []
        sample_rate = outputs[0]['sampling_rate']
        # Create unique timestamp for this generation batch
        batch_timestamp = int(time.time() * 1000)
        for i, output in enumerate(outputs):
            audio_data = output['audio']
            print(f"[GENERATION] Processing variant {i + 1} audio shape: {audio_data.shape}")
            if hasattr(audio_data, 'cpu'):
                audio_data = audio_data.cpu().numpy()
            if len(audio_data.shape) == 3:
                audio_data = audio_data[0]
            if len(audio_data.shape) == 2:
                if audio_data.shape[0] < audio_data.shape[1]:
                    audio_data = audio_data.T
                if audio_data.shape[1] > 1:
                    audio_data = audio_data[:, 0]
                else:
                    audio_data = audio_data.flatten()
            audio_data = audio_data.flatten()
            max_val = np.max(np.abs(audio_data))
            if max_val > 0:
                audio_data = audio_data / max_val * 0.95
            audio_data = (audio_data * 32767).astype(np.int16)
            # Save each variant to a unique temporary file
            temp_filename = f"musicgen_variant_{i + 1}_{batch_timestamp}.wav"
            temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
            sf.write(temp_path, audio_data, sample_rate)
            print(f"[GENERATION] Variant {i + 1} saved to: {temp_path}")
            print(f"[GENERATION] Variant {i + 1} file size: {os.path.getsize(temp_path)} bytes")
            audio_variants.append(temp_path)
            print(f"[GENERATION] Variant {i + 1} final shape: {audio_data.shape}")
        # Pad with None so all six fixed Gradio audio outputs receive a value
        while len(audio_variants) < 6:
            audio_variants.append(None)
        variants_text = f"{len(outputs)} audio variant(s)"
        generation_info = f"✅ Generated {variants_text} in {generation_time:.2f}s\n"
        return (audio_variants[0], audio_variants[1], audio_variants[2], audio_variants[3],
                audio_variants[4], audio_variants[5], generation_info)
    except Exception as e:
        print(f"[ERROR] Batch generation failed: {str(e)}")
        cleanup_gpu()
        error_msg = f"❌ Generation failed: {str(e)}"
        return None, None, None, None, None, None, error_msg

with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator. 2.3x Accelerated by TheStage ANNA")
    gr.Markdown(
        "Generate music from text descriptions using Facebook's MusicGen "
        "Large model accelerated by TheStage for 2.3x faster performance.")
    with gr.Column():
        text_input = gr.Textbox(
            label="Music Description",
            placeholder="Enter a description of the music you want to generate",
            lines=3,
            value="A groovy funk bassline with a tight drum beat"
        )
        with gr.Row():
            duration = gr.Slider(
                minimum=5,
                maximum=30,
                value=10,
                step=1,
                label="Duration (seconds)"
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=3.0,
                step=0.5,
                label="Guidance Scale",
                info="Higher values follow prompt more closely"
            )
        generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
        generation_info = gr.Markdown("Ready to generate music with elastic acceleration")
    audio_section_title = "### Generated Music"
    gr.Markdown(audio_section_title)
    actual_outputs = MODEL_CONFIG['batch_size'] if MODEL_CONFIG['batch_mode'] else 1
    audio_outputs = []
    with gr.Row():
        audio_output1 = gr.Audio(label="Variant 1", type="filepath", visible=actual_outputs >= 1)
        audio_output2 = gr.Audio(label="Variant 2", type="filepath", visible=actual_outputs >= 2)
    audio_outputs.extend([audio_output1, audio_output2])
    with gr.Row():
        audio_output3 = gr.Audio(label="Variant 3", type="filepath", visible=actual_outputs >= 3)
        audio_output4 = gr.Audio(label="Variant 4", type="filepath", visible=actual_outputs >= 4)
    audio_outputs.extend([audio_output3, audio_output4])
    with gr.Row():
        audio_output5 = gr.Audio(label="Variant 5", type="filepath", visible=actual_outputs >= 5)
        audio_output6 = gr.Audio(label="Variant 6", type="filepath", visible=actual_outputs >= 6)
    audio_outputs.extend([audio_output5, audio_output6])
    savings_banner = gr.Markdown(get_fixed_savings_message())
    with gr.Accordion("💡 Tips & Information", open=False):
        gr.Markdown("""
        **Generation Tips:**
        - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
        - Higher guidance scale = follows prompt more closely
        - Lower guidance scale = more creative/varied results
        - Duration is limited to 30 seconds for faster generation

        **Performance:**
        - Accelerated by TheStage elastic compression
        - L40S GPU pricing: $1.8/hour
        """)
    def generate_simple(text_prompt, duration, guidance_scale):
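        """Click-handler wrapper that always runs the compressed model."""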
        return generate_music_batch(text_prompt, duration, guidance_scale, "compressed")

    generate_btn.click(
        fn=generate_simple,
        inputs=[text_input, duration, guidance_scale],
        outputs=[audio_output1, audio_output2, audio_output3, audio_output4,
                 audio_output5, audio_output6, generation_info],
        show_progress=True
    )
    gr.Examples(
        examples=[
            "A groovy funk bassline with a tight drum beat",
            "Relaxing acoustic guitar melody",
            "Electronic dance music with heavy bass",
            "Classical violin concerto",
            "Reggae with steel drums and bass",
            "Rock ballad with electric guitar solo",
            "Jazz piano improvisation with brushed drums",
            "Ambient synthwave with retro vibes",
        ],
        inputs=text_input,
        label="Example Prompts"
    )
gr.Markdown("---") | |
gr.Markdown(""" | |
<div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;"> | |
<strong>Limitations:</strong><br> | |
β’ The model is not able to generate realistic vocals.<br> | |
β’ The model has been trained with English descriptions and will not perform as well in other languages.<br> | |
β’ The model does not perform equally well for all music styles and cultures.<br> | |
β’ The model sometimes generates end of songs, collapsing to silence.<br> | |
β’ It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results. | |
</div> | |
""") | |

if __name__ == "__main__":
    cleanup_temp_files()
    demo.launch()