"""Gradio demo: text-to-music generation with an elastic-compressed MusicGen Large.

The app loads ``facebook/musicgen-large`` through ``elastic_models`` (mode "S"),
wraps it in a Hugging Face ``text-to-audio`` pipeline and shows up to six
generated variants as WAV files in a Gradio UI.
"""
import gc
import glob
import hashlib
import os
import random
import tempfile
import time
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch

# Set before importing elastic_models; presumably the library reads it at
# import time (TODO confirm).
os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'

from transformers import AutoProcessor, pipeline  # noqa: E402
from elastic_models.transformers import MusicgenForConditionalGeneration  # noqa: E402

MODEL_CONFIG = {
    'cost_per_hour': 1.8,  # $1.8 per hour on L40S
    'cost_savings_1000h': {
        'savings_dollars': 8.4,  # $8.4 saved per 1000 hours
        'savings_percent': 74.9,  # 74.9% savings
        'compressed_cost': 2.8,  # $2.8 for compressed
        'original_cost': 11.3,  # $11.3 for original
    },
    'batch_mode': True,
    'batch_size': 2,  # Number of variants to generate (2, 4, 6, etc.)
    # int -> seed all RNGs before each generation (reproducible output);
    # None -> leave RNGs untouched (non-deterministic sampling).
    'seed': None,
}

# Reference timing for the uncompressed model; not read anywhere in this file.
original_time_cache = {"original_time": 22.57}

# Number of variant slots in the UI and in generate_music_batch's return tuple.
MAX_VARIANTS = 6

# Generated WAV files older than this many seconds are removed by cleanup_temp_files().
TEMP_FILE_MAX_AGE_SECONDS = 3600


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and torch RNGs and request deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _apply_seed() -> None:
    """Apply ``MODEL_CONFIG['seed']`` if one is configured and log what was done."""
    seed = MODEL_CONFIG.get('seed')
    if seed is None:
        print("[GENERATION] Seed not fixed; sampling is non-deterministic")
    else:
        set_seed(seed)
        print(f"[GENERATION] Using seed: {seed}")


def cleanup_gpu():
    """Free unreferenced Python objects, then release cached CUDA memory."""
    # Collect first: tensors freed by the collector go back to torch's caching
    # allocator, so the following empty_cache() can actually release them.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def cleanup_temp_files(max_age_seconds=TEMP_FILE_MAX_AGE_SECONDS):
    """Delete old WAV files matching this app's name patterns in the system temp dir.

    Best effort: files that disappear concurrently or cannot be removed are skipped.

    Args:
        max_age_seconds: Files whose ctime is older than this are removed.
    """
    temp_dir = tempfile.gettempdir()
    cutoff_time = time.time() - max_age_seconds

    # Clean old generated music files
    patterns = [
        # NOTE(review): "tmp*.wav" can also match files other programs left in
        # the shared temp dir -- confirm this pattern is intended.
        os.path.join(temp_dir, "tmp*.wav"),
        os.path.join(temp_dir, "generated_music_*.wav"),
        os.path.join(temp_dir, "musicgen_variant_*.wav"),
    ]
    for pattern in patterns:
        for temp_file in glob.glob(pattern):
            try:
                if os.path.getctime(temp_file) < cutoff_time:
                    os.remove(temp_file)
                    print(f"[CLEANUP] Removed old temp file: {temp_file}")
            except OSError:
                # Already removed by someone else, or not ours to delete.
                pass


# Lazily-initialised singletons, populated by load_model() / load_original_model().
_generator = None
_processor = None
_original_generator = None
_original_processor = None


def load_model():
    """Load (once) the elastic-compressed MusicGen Large pipeline.

    Returns:
        ``(text-to-audio pipeline, AutoProcessor)``; cached in module globals.
    """
    global _generator, _processor
    if _generator is None:
        print("[MODEL] Starting model initialization...")
        cleanup_gpu()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[MODEL] Using device: {device}")

        print("[MODEL] Loading processor...")
        _processor = AutoProcessor.from_pretrained("facebook/musicgen-large")

        print("[MODEL] Loading model...")
        model = MusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
            device=device,
            mode="S",
            __paged=True,
        )
        model.eval()

        print("[MODEL] Creating pipeline...")
        _generator = pipeline(
            task="text-to-audio",
            model=model,
            tokenizer=_processor.tokenizer,
            device=device,
        )
        print("[MODEL] Model initialization completed successfully")
    return _generator, _processor


def load_original_model():
    """Load (once) the uncompressed Hugging Face MusicGen Large pipeline.

    Not used by the UI below; kept for comparison runs.

    Returns:
        ``(text-to-audio pipeline, AutoProcessor)``; cached in module globals.
    """
    global _original_generator, _original_processor
    if _original_generator is None:
        print("[ORIGINAL MODEL] Starting original model initialization...")
        cleanup_gpu()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[ORIGINAL MODEL] Using device: {device}")

        print("[ORIGINAL MODEL] Loading processor...")
        _original_processor = AutoProcessor.from_pretrained("facebook/musicgen-large")

        # Aliased: the module-level name refers to the elastic_models class.
        from transformers import MusicgenForConditionalGeneration as HFMusicgenForConditionalGeneration

        print("[ORIGINAL MODEL] Loading original model...")
        model = HFMusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
        ).to(device)
        model.eval()

        print("[ORIGINAL MODEL] Creating pipeline...")
        _original_generator = pipeline(
            task="text-to-audio",
            model=model,
            tokenizer=_original_processor.tokenizer,
            device=device,
        )
        print("[ORIGINAL MODEL] Original model initialization completed successfully")
    return _original_generator, _original_processor


def calculate_max_tokens(duration_seconds):
    """Convert a clip length in seconds to a token budget at 50 tokens/second.

    The 50/s rate is presumably MusicGen's audio-codec frame rate (TODO confirm).
    """
    token_rate = 50
    max_new_tokens = int(duration_seconds * token_rate)
    print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
    return max_new_tokens


def _build_generation_params(guidance_scale, max_new_tokens):
    """Return the ``generate`` kwargs; min == max pins the output to exactly max_new_tokens."""
    return {
        'do_sample': True,
        'guidance_scale': guidance_scale,
        'max_new_tokens': max_new_tokens,
        'min_new_tokens': max_new_tokens,
        'cache_implementation': 'paged',
    }


def _to_int16_mono(audio_data):
    """Convert one pipeline waveform to a peak-normalised mono int16 NumPy array.

    Accepts a torch tensor or array-like of rank 1-3:
    - rank 3: only the first element is kept;
    - rank 2: oriented as (samples, channels) by assuming the shorter axis is
      channels, and only channel 0 is kept.
    The result is scaled so its peak is 95% of int16 full scale. Silent or
    empty input is returned as zeros / an empty array.
    """
    if hasattr(audio_data, 'cpu'):
        # Upcast so half/bfloat16 tensors convert to NumPy cleanly.
        audio_data = audio_data.detach().cpu().float().numpy()
    # float32 avoids fp16 rounding when scaling to the 16-bit PCM range.
    audio_data = np.asarray(audio_data, dtype=np.float32)

    if audio_data.ndim == 3:
        audio_data = audio_data[0]
    if audio_data.ndim == 2:
        if audio_data.shape[0] < audio_data.shape[1]:
            audio_data = audio_data.T
        if audio_data.shape[1] > 1:
            audio_data = audio_data[:, 0]
    audio_data = audio_data.flatten()

    max_val = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if max_val > 0:
        audio_data = audio_data / max_val * 0.95
    return (audio_data * 32767).astype(np.int16)


def generate_music(text_prompt, duration=10, guidance_scale=3.0):
    """Generate a single clip with the compressed model (not wired to the UI below).

    Args:
        text_prompt: Text description of the desired music.
        duration: Target clip length in seconds.
        guidance_scale: Guidance strength forwarded to ``generate``.

    Returns:
        ``(sample_rate, int16_mono_array)`` on success, otherwise ``None``.
        A WAV copy is also written to the system temp directory.
    """
    try:
        generator, _ = load_model()

        print("[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")

        cleanup_gpu()
        _apply_seed()

        max_new_tokens = calculate_max_tokens(duration)
        outputs = generator(
            [text_prompt],
            batch_size=1,
            generate_kwargs=_build_generation_params(guidance_scale, max_new_tokens),
        )
        print("[GENERATION] Generation completed successfully")

        output = outputs[0]
        sample_rate = output['sampling_rate']
        print(f"[GENERATION] Raw audio shape: {getattr(output['audio'], 'shape', None)}")
        print(f"[GENERATION] Sample rate: {sample_rate}")

        audio_data = _to_int16_mono(output['audio'])
        print(f"[GENERATION] Final audio shape: {audio_data.shape}")
        if audio_data.size:
            print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")

        temp_filename = f"generated_music_{int(time.time() * 1000)}.wav"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
        sf.write(temp_path, audio_data, sample_rate)

        if not os.path.exists(temp_path):
            print(f"[ERROR] Failed to create audio file: {temp_path}")
            return None

        print(f"[GENERATION] Audio saved to: {temp_path}")
        print(f"[GENERATION] File size: {os.path.getsize(temp_path)} bytes")
        print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
        return (sample_rate, audio_data)

    except Exception as e:
        # UI boundary: report and return a sentinel instead of crashing the app.
        print(f"[ERROR] Generation failed: {str(e)}")
        traceback.print_exc()
        cleanup_gpu()
        return None


def calculate_generation_cost(generation_time_seconds, mode='S'):
    """Return the dollar cost of ``generation_time_seconds`` at the configured hourly rate.

    ``mode`` is accepted for callers but does not affect the result.
    """
    hours = generation_time_seconds / 3600
    return hours * MODEL_CONFIG['cost_per_hour']


def calculate_cost_savings(compressed_time, original_time):
    """Compare compressed vs. original run costs.

    Returns:
        dict with ``compressed_cost``, ``original_cost``, ``savings`` (dollars)
        and ``savings_percent`` (0 when ``original_cost`` is not positive).
    """
    compressed_cost = calculate_generation_cost(compressed_time, 'S')
    original_cost = calculate_generation_cost(original_time, 'original')
    savings = original_cost - compressed_cost
    savings_percent = (savings / original_cost * 100) if original_cost > 0 else 0
    return {
        'compressed_cost': compressed_cost,
        'original_cost': original_cost,
        'savings': savings,
        'savings_percent': savings_percent,
    }


def get_fixed_savings_message():
    """Build the Markdown banner from the pre-computed 1000-hour savings figures."""
    # NOTE(review): the banner says "batch size 4" while MODEL_CONFIG['batch_size']
    # is 2 -- presumably the figures were measured at batch size 4; confirm.
    config = MODEL_CONFIG['cost_savings_1000h']
    return f"💰 **Cost Savings for generation batch size 4 on L40S (1000h)**: ${config['savings_dollars']:.1f}" \
           f" ({config['savings_percent']:.1f}%) - Compressed: ${config['compressed_cost']:.1f} " \
           f"vs Original: ${config['original_cost']:.1f}"


def get_cache_key(prompt, duration, guidance_scale):
    """Return a cache key for a generation request that is stable across processes.

    Uses SHA-256 rather than the built-in ``hash`` (salted per process for str).
    """
    prompt_digest = hashlib.sha256(str(prompt).encode("utf-8")).hexdigest()[:16]
    return f"{prompt_digest}_{duration}_{guidance_scale}"


def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mode="compressed"):
    """Generate up to MAX_VARIANTS variants of one prompt with the compressed model.

    Args:
        text_prompt: Text description of the desired music.
        duration: Target clip length in seconds.
        guidance_scale: Guidance strength forwarded to ``generate``.
        model_mode: Accepted for compatibility; the compressed model is always used.

    Returns:
        A 7-tuple: six WAV file paths (``None`` for unused slots) followed by a
        Markdown status string. On failure all paths are ``None`` and the string
        carries the error.
    """
    try:
        # Keep the temp dir from growing without bound on a long-running server.
        cleanup_temp_files()

        generator, _ = load_model()
        model_name = "Compressed (S)"
        batch_size = MODEL_CONFIG['batch_size'] if MODEL_CONFIG['batch_mode'] else 1

        print(f"[GENERATION] Starting generation using {model_name} model...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")
        print(f"[GENERATION] Batch mode: {MODEL_CONFIG['batch_mode']}")
        print(f"[GENERATION] Batch size: {MODEL_CONFIG['batch_size']}")

        cleanup_gpu()
        _apply_seed()

        max_new_tokens = calculate_max_tokens(duration)
        prompts = [text_prompt] * batch_size

        start_time = time.perf_counter()
        outputs = generator(
            prompts,
            batch_size=batch_size,
            generate_kwargs=_build_generation_params(guidance_scale, max_new_tokens),
        )
        generation_time = time.perf_counter() - start_time
        print(f"[GENERATION] Generation completed in {generation_time:.2f}s")

        sample_rate = outputs[0]['sampling_rate']
        # One timestamp shared by every file of this batch.
        batch_timestamp = int(time.time() * 1000)

        audio_variants = []
        # Only MAX_VARIANTS slots exist in the UI; don't write files nobody sees.
        for i, output in enumerate(outputs[:MAX_VARIANTS], start=1):
            print(f"[GENERATION] Processing variant {i} audio shape: {getattr(output['audio'], 'shape', None)}")
            audio_data = _to_int16_mono(output['audio'])

            temp_filename = f"musicgen_variant_{i}_{batch_timestamp}.wav"
            temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
            sf.write(temp_path, audio_data, sample_rate)

            print(f"[GENERATION] Variant {i} saved to: {temp_path}")
            print(f"[GENERATION] Variant {i} file size: {os.path.getsize(temp_path)} bytes")
            print(f"[GENERATION] Variant {i} final shape: {audio_data.shape}")
            audio_variants.append(temp_path)

        audio_variants += [None] * (MAX_VARIANTS - len(audio_variants))
        generation_info = f"✅ Generated audio in {generation_time:.2f}s\n"
        return (*audio_variants, generation_info)

    except Exception as e:
        # UI boundary: surface the error in the status box instead of crashing.
        print(f"[ERROR] Batch generation failed: {str(e)}")
        traceback.print_exc()
        cleanup_gpu()
        error_msg = f"❌ Generation failed: {str(e)}"
        return (*([None] * MAX_VARIANTS), error_msg)


with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator. 2.3x Accelerated by TheStage ANNA")
    gr.Markdown(
        "Generate music from text descriptions using Facebook's MusicGen "
        "Large model accelerated by TheStage for 2.3x faster performance."
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="Music Description",
            placeholder="Enter a description of the music you want to generate",
            lines=3,
            value="A groovy funk bassline with a tight drum beat",
        )
        with gr.Row():
            duration = gr.Slider(
                minimum=5,
                maximum=30,
                value=10,
                step=1,
                label="Duration (seconds)",
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=3.0,
                step=0.5,
                label="Guidance Scale",
                info="Higher values follow prompt more closely",
            )
        generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

    generation_info = gr.Markdown("Ready to generate music with elastic acceleration")

    audio_section_title = "### Generated Music"
    gr.Markdown(audio_section_title)

    actual_outputs = MODEL_CONFIG['batch_size'] if MODEL_CONFIG['batch_mode'] else 1

    audio_outputs = []
    with gr.Row():
        audio_output1 = gr.Audio(label="Variant 1", type="filepath", visible=actual_outputs >= 1)
        audio_output2 = gr.Audio(label="Variant 2", type="filepath", visible=actual_outputs >= 2)
        audio_outputs.extend([audio_output1, audio_output2])
    with gr.Row():
        audio_output3 = gr.Audio(label="Variant 3", type="filepath", visible=actual_outputs >= 3)
        audio_output4 = gr.Audio(label="Variant 4", type="filepath", visible=actual_outputs >= 4)
        audio_outputs.extend([audio_output3, audio_output4])
    with gr.Row():
        audio_output5 = gr.Audio(label="Variant 5", type="filepath", visible=actual_outputs >= 5)
        audio_output6 = gr.Audio(label="Variant 6", type="filepath", visible=actual_outputs >= 6)
        audio_outputs.extend([audio_output5, audio_output6])

    savings_banner = gr.Markdown(get_fixed_savings_message())

    with gr.Accordion("💡 Tips & Information", open=False):
        gr.Markdown(f"""
**Generation Tips:**
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher guidance scale = follows prompt more closely
- Lower guidance scale = more creative/varied results
- Duration is limited to 30 seconds for faster generation

**Performance:**
- Accelerated by TheStage elastic compression
- L40S GPU pricing: ${MODEL_CONFIG['cost_per_hour']}/hour
""")

    def generate_simple(text_prompt, duration, guidance_scale):
        """Click handler: always generate with the compressed model."""
        return generate_music_batch(text_prompt, duration, guidance_scale, "compressed")

    generate_btn.click(
        fn=generate_simple,
        inputs=[text_input, duration, guidance_scale],
        outputs=audio_outputs + [generation_info],
        show_progress=True,
    )

    gr.Examples(
        examples=[
            "A groovy funk bassline with a tight drum beat",
            "Relaxing acoustic guitar melody",
            "Electronic dance music with heavy bass",
            "Classical violin concerto",
            "Reggae with steel drums and bass",
            "Rock ballad with electric guitar solo",
            "Jazz piano improvisation with brushed drums",
            "Ambient synthwave with retro vibes",
        ],
        inputs=text_input,
        label="Example Prompts",
    )

    gr.Markdown("---")
    # Markdown list items so each limitation renders on its own line.
    gr.Markdown("""
**Limitations:**

- The model is not able to generate realistic vocals.
- The model has been trained with English descriptions and will not perform as well in other languages.
- The model does not perform equally well for all music styles and cultures.
- The model sometimes generates end of songs, collapsing to silence.
- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
""")


if __name__ == "__main__":
    cleanup_temp_files()
    demo.launch()