import gradio as gr
import os
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend: figures are saved to files, never displayed
import matplotlib.pyplot as plt
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import tempfile
import spaces
from TTS.api import TTS
import psutil
import GPUtil

# -------------------------------
# Configuration
# -------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HF_TOKEN_SECRET"))

MODEL_CONFIG = {
    "llama_models": {
        "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
        "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    },
    "tts_models": {
        "Standard English": "tts_models/en/ljspeech/tacotron2-DDC",
        "High Quality": "tts_models/en/ljspeech/vits",
    },
    "musicgen_model": "facebook/musicgen-medium",
}

# -------------------------------
# Model Manager with Cache
# -------------------------------
class ModelManager:
    """Lazily loads and caches every model the app needs."""

    def __init__(self):
        self.llama_pipelines = {}
        self.musicgen_model = None
        self.tts_models = {}
        self.processor = None  # MusicGen processor cache

    def get_llama_pipeline(self, model_id, token):
        """Load (or reuse) a cached text-generation pipeline."""
        if model_id not in self.llama_pipelines:
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token, legacy=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=token,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            self.llama_pipelines[model_id] = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device_map="auto",
            )
        return self.llama_pipelines[model_id]

    def get_tts_model(self, model_name):
        """Load (or reuse) a cached Coqui TTS model.

        generate_voice() calls this method but it was missing from the
        original class; this is a minimal implementation backed by the
        existing tts_models cache.
        """
        if model_name not in self.tts_models:
            self.tts_models[model_name] = TTS(model_name=model_name, gpu=torch.cuda.is_available())
        return self.tts_models[model_name]

    def get_musicgen_model(self):
        """Load (or reuse) the MusicGen model together with its processor."""
        if not self.musicgen_model:
            self.musicgen_model = MusicgenForConditionalGeneration.from_pretrained(
                MODEL_CONFIG["musicgen_model"]
            )
            self.processor = AutoProcessor.from_pretrained(MODEL_CONFIG["musicgen_model"])
            self.musicgen_model.to("cuda" if torch.cuda.is_available() else "cpu")
        return self.musicgen_model, self.processor


model_manager = ModelManager()

# -------------------------------
# Core Functions with Enhanced Error Handling
# -------------------------------
@spaces.GPU
def generate_script(user_prompt, model_id, duration, progress=gr.Progress()):
    try:
        progress(0.1, "Initializing script generation...")
        text_pipeline = model_manager.get_llama_pipeline(model_id, HF_TOKEN)

        system_prompt = f"""Generate a {duration}-second radio promo with:
1. Voice Script: [Clear narration, 25-35 words]
2. Sound Design: [3-5 specific sound effects]
3. Music: [Genre, tempo, mood]

Format strictly as:
Voice Script: [content]
Sound Design: [effects]
Music: [description]"""

        progress(0.3, "Generating content...")
        response = text_pipeline(
            f"{system_prompt}\nConcept: {user_prompt}",
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
        )

        progress(0.8, "Parsing results...")
        return parse_generated_content(response[0]["generated_text"])
    except Exception as e:
        return [f"Error: {str(e)}"] * 3


def parse_generated_content(text):
    """Split model output into Voice Script / Sound Design / Music sections."""
    sections = {"Voice Script": "", "Sound Design": "", "Music": ""}
    current_section = None
    for line in text.split("\n"):
        line = line.strip()
        for section in sections:
            if line.startswith(section + ":"):
                current_section = section
                line = line.replace(section + ":", "").strip()
                break
        if current_section and line:
            sections[current_section] += line + "\n"
    return [sections[section].strip() for section in sections]


@spaces.GPU
def generate_voice(script, tts_model, speed=1.0, progress=gr.Progress()):
    try:
        progress(0.2, "Initializing TTS...")
        if not script.strip():
            return None, "No script provided"
        tts = model_manager.get_tts_model(tts_model)
        output_path = os.path.join(tempfile.gettempdir(), "voice.wav")
        progress(0.5, "Generating audio...")
        tts.tts_to_file(text=script, file_path=output_path, speed=speed)
        return output_path, None
    except Exception as e:
        return None, f"Voice Error: {str(e)}"


@spaces.GPU
def generate_music(prompt, duration_sec=30, progress=gr.Progress()):
    try:
        progress(0.1, "Initializing MusicGen...")
        # get_musicgen_model() returns (model, processor); the original unpacked
        # only the model and reloaded the processor from the Hub on every call.
        model, processor = model_manager.get_musicgen_model()

        progress(0.4, "Processing input...")
        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(model.device)

        progress(0.6, "Generating music...")
        # MusicGen produces roughly 50 audio tokens per second of output.
        audio_values = model.generate(**inputs, max_new_tokens=int(duration_sec * 50))

        output_path = os.path.join(tempfile.gettempdir(), "music.wav")
        sampling_rate = model.config.audio_encoder.sampling_rate  # 32 kHz for MusicGen
        # Convert float audio to 16-bit PCM so pydub can read the WAV later.
        samples = np.clip(audio_values[0, 0].cpu().numpy(), -1.0, 1.0)
        write(output_path, sampling_rate, (samples * 32767).astype(np.int16))
        return output_path, None
    except Exception as e:
        return None, f"Music Error: {str(e)}"


def blend_audio(voice_path, music_path, ducking=True, progress=gr.Progress()):
    try:
        progress(0.2, "Loading audio files...")
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)

        progress(0.4, "Aligning durations...")
        # Loop the music bed until it covers the voiceover, then trim to length.
        if len(music) < len(voice):
            music = music * (len(voice) // len(music) + 1)
        music = music[: len(voice)]

        progress(0.6, "Mixing audio...")
        if ducking:
            music = music - 10  # duck the music bed by 10 dB under the voice
        mixed = music.overlay(voice)

        output_path = os.path.join(tempfile.gettempdir(), "final_mix.wav")
        mixed.export(output_path, format="wav")
        return output_path, None
    except Exception as e:
        return None, f"Mixing Error: {str(e)}"


# -------------------------------
# UI Components
# -------------------------------
def create_audio_visualization(audio_path):
    """Render a simple waveform PNG for the given audio file."""
    if not audio_path:
        return None
    audio = AudioSegment.from_file(audio_path)
    samples = np.array(audio.get_array_of_samples())
    plt.figure(figsize=(10, 3))
    plt.plot(samples)
    plt.axis("off")
    plt.tight_layout()
    temp_file = os.path.join(tempfile.gettempdir(), "waveform.png")
    plt.savefig(temp_file, bbox_inches="tight", pad_inches=0)
    plt.close()
    return temp_file


def system_monitor():
    """Report current CPU, RAM, and GPU utilization."""
    gpus = GPUtil.getGPUs()
    return {
        "CPU": f"{psutil.cpu_percent()}%",
        "RAM": f"{psutil.virtual_memory().percent}%",
        "GPU": f"{gpus[0].load * 100:.1f}%" if gpus else "N/A",
    }
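

# The original wiring passed `preprocess=create_audio_visualization` to
# Button.click(), which is not a Gradio API and would raise a TypeError.
# These wrapper handlers are new helpers (not part of the original code):
# a minimal sketch that pairs each generator with its waveform preview,
# forwards the injected gr.Progress, and surfaces errors via gr.Error.
def voice_handler(script, tts_model, speed, progress=gr.Progress()):
    path, error = generate_voice(script, tts_model, speed, progress)
    if error:
        raise gr.Error(error)
    return path, create_audio_visualization(path)


def music_handler(prompt, progress=gr.Progress()):
    path, error = generate_music(prompt, progress=progress)
    if error:
        raise gr.Error(error)
    return path, create_audio_visualization(path)


def mix_handler(voice_path, music_path, progress=gr.Progress()):
    path, error = blend_audio(voice_path, music_path, progress=progress)
    if error:
        raise gr.Error(error)
    return path, create_audio_visualization(path)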


# -------------------------------
# Gradio Interface
# -------------------------------
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="teal",
).set(
    body_text_color_dark="#FFFFFF",
    background_fill_primary_dark="#1F1F1F",
)

with gr.Blocks(theme=theme, title="AI Radio Studio Pro") as demo:
    gr.Markdown("# 🎙️ AI Radio Studio Pro")

    with gr.Row():
        with gr.Column(scale=3):
            concept_input = gr.Textbox(
                label="Concept Description",
                placeholder="Describe your radio segment...",
                lines=3,
            )
            with gr.Accordion("Advanced Settings", open=False):
                model_selector = gr.Dropdown(
                    list(MODEL_CONFIG["llama_models"].values()),
                    label="AI Model",
                    value=next(iter(MODEL_CONFIG["llama_models"].values())),
                )
                duration_selector = gr.Slider(15, 120, 30, step=15, label="Duration (seconds)")
            generate_btn = gr.Button("Generate Script", variant="primary")
        with gr.Column(scale=2):
            script_output = gr.Textbox(label="Voice Script", interactive=True)
            sound_output = gr.Textbox(label="Sound Design", interactive=True)
            music_output = gr.Textbox(label="Music Style", interactive=True)

    with gr.Tabs():
        with gr.Tab("🎤 Voice Production"):
            with gr.Row():
                tts_selector = gr.Dropdown(
                    list(MODEL_CONFIG["tts_models"].values()),
                    label="Voice Model",
                    value=next(iter(MODEL_CONFIG["tts_models"].values())),
                )
                speed_selector = gr.Slider(0.5, 2.0, 1.0, step=0.1, label="Speaking Rate")
            voice_btn = gr.Button("Generate Voiceover", variant="primary")
            with gr.Row():
                # type="filepath" so blend_audio receives file paths, not (rate, array) tuples
                voice_audio = gr.Audio(label="Voice Preview", type="filepath", interactive=False)
                voice_viz = gr.Image(label="Waveform", interactive=False)

        with gr.Tab("🎵 Music Production"):
            music_btn = gr.Button("Generate Music Track", variant="primary")
            with gr.Row():
                music_audio = gr.Audio(label="Music Preview", type="filepath", interactive=False)
                music_viz = gr.Image(label="Waveform", interactive=False)

        with gr.Tab("🔉 Final Mix"):
            mix_btn = gr.Button("Create Final Mix", variant="primary")
            with gr.Row():
                final_mix_audio = gr.Audio(label="Final Mix", type="filepath", interactive=False)
                final_mix_viz = gr.Image(label="Waveform", interactive=False)
            with gr.Row():
                download_btn = gr.Button("Download Mix")
                play_btn = gr.Button("▶️ Play in Browser")

    with gr.Accordion("📊 System Monitor", open=False):
        monitor = gr.JSON(label="Resource Usage", value=system_monitor, every=5)

    gr.Markdown("""

Created with ❤️ by Bils Imaging

""") # Event Handling generate_btn.click( generate_script, [concept_input, model_selector, duration_selector], [script_output, sound_output, music_output] ) voice_btn.click( generate_voice, [script_output, tts_selector, speed_selector], [voice_audio, voice_viz], preprocess=create_audio_visualization ) music_btn.click( generate_music, [music_output], [music_audio, music_viz], preprocess=create_audio_visualization ) mix_btn.click( blend_audio, [voice_audio, music_audio], [final_mix_audio, final_mix_viz], preprocess=create_audio_visualization ) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)