import os
import requests
import torch
import scipy.io.wavfile as wav
import streamlit as st
from io import BytesIO
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration
)
from streamlit_lottie import st_lottie

# ---------------------------------------------------------------------
# 1) PAGE CONFIGURATION
# ---------------------------------------------------------------------
st.set_page_config(
    page_title="AI Radio Imaging with Llama 3",
    page_icon="🎧",
    layout="wide"
)

# ---------------------------------------------------------------------
# 2) CUSTOM CSS / UI DESIGN
# ---------------------------------------------------------------------
CUSTOM_CSS = """ """
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)


# ---------------------------------------------------------------------
# 3) LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
@st.cache_data
def load_lottie_url(url: str):
    """Fetch a Lottie animation as parsed JSON.

    Returns the decoded JSON on an HTTP 200 response, or ``None`` when the
    request fails, returns another status code, or the body is not valid
    JSON. The animation is decorative, so failures must never take the
    whole app down.
    """
    try:
        # A timeout keeps an unreachable CDN from blocking the app start-up.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        return None


LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)


# ---------------------------------------------------------------------
# 4) LOAD LLAMA 3 (GATED MODEL)
# ---------------------------------------------------------------------
@st.cache_resource
def load_llama_pipeline(model_id: str, device: str, token: str):
    """Build (and cache) a text-generation pipeline for a gated Llama model.

    Args:
        model_id: Hugging Face model identifier.
        device: Value forwarded as ``device_map`` ("auto" or "cpu").
        token: Hugging Face access token used to download the gated model.

    Returns:
        A transformers "text-generation" pipeline wrapping the loaded model.
    """
    # Half precision is only used when a CUDA device is actually present;
    # "auto" on a CPU-only host would otherwise load float16 weights on CPU.
    use_half = device == "auto" and torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16 if use_half else torch.float32,
        device_map=device,
        low_cpu_mem_usage=True
    )
    # The model is already placed via device_map above, so the pipeline does
    # not need (and should not repeat) a device_map of its own.
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
    return text_gen_pipeline


# ---------------------------------------------------------------------
# 5) GENERATE RADIO SCRIPT
# ---------------------------------------------------------------------
def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """Turn a user's promo concept into a short radio script.

    Args:
        user_input: Free-text description of the promo idea.
        pipeline_llama: Callable text-generation pipeline; it is called with
            the prompt and must return a list whose first item is a dict
            with a "generated_text" key.

    Returns:
        The generated script with the prompt removed and an attribution
        footer appended.
    """
    system_prompt = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"

    result = pipeline_llama(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9
    )
    output_text = result[0]["generated_text"]

    if output_text.startswith(combined_prompt):
        # Strip the exact prompt echo. Splitting on the first marker would
        # leak user text if the concept itself contains "Refined script:".
        output_text = output_text[len(combined_prompt):].strip()
    elif "Refined script:" in output_text:
        output_text = output_text.split("Refined script:", 1)[-1].strip()

    output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
    return output_text


# ---------------------------------------------------------------------
# 6) LOAD MUSICGEN
# ---------------------------------------------------------------------
@st.cache_resource
def load_musicgen_model():
    """Load (and cache) the facebook/musicgen-small model and its processor."""
    mg_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
    mg_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    return mg_model, mg_processor


def _to_int16_pcm(audio_data):
    """Peak-normalize a float waveform array and convert it to int16 PCM.

    A silent or empty waveform is returned as zeros instead of dividing
    by a zero peak.
    """
    peak = float(abs(audio_data).max()) if audio_data.size else 0.0
    if peak == 0.0:
        return audio_data.astype("int16")
    return (audio_data / peak * 32767).astype("int16")


# ---------------------------------------------------------------------
# 7) HEADER
# ---------------------------------------------------------------------
st.title("🎧 AI Radio Imaging with Llama 3")
st.subheader("Create engaging radio promos with Llama 3 + MusicGen")
st.markdown("""Create **radio imaging promos** and **jingles** easily. 
Ensure you have access to **meta-llama/Meta-Llama-3-70B** on Hugging Face and provide your token below.""")

if lottie_animation:
    st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")

st.markdown("---")

# ---------------------------------------------------------------------
# 8) USER INPUT
# ---------------------------------------------------------------------
st.subheader("🎤 Step 1: Describe Your Promo Idea")
prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120
)

col_model, col_device = st.columns(2)
with col_model:
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        value="meta-llama/Meta-Llama-3-70B",
        help="Enter the exact model ID from Hugging Face."
    )
with col_device:
    device_option = st.selectbox(
        "Device",
        ["auto", "cpu"],
        help="Choose GPU (auto) or CPU."
    )

hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.error("No HF_TOKEN found. Please set it in your environment.")
    st.stop()

if st.button("✍ Generate Promo Script"):
    if not prompt.strip():
        st.error("Please provide a concept first.")
    else:
        with st.spinner("Generating script..."):
            try:
                llama_pipeline = load_llama_pipeline(llama_model_id, device_option, hf_token)
                final_script = generate_radio_script(prompt, llama_pipeline)
                # Persist the script: Streamlit reruns the file on every
                # interaction, and Step 2 reads it from session state.
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")
            except Exception as e:
                st.error(f"Llama generation error: {e}")

# Rendered outside the button branch so the script stays visible on reruns
# (e.g. after pressing the audio button).
if "final_script" in st.session_state:
    st.text_area("Generated Script", value=st.session_state["final_script"], height=200)

st.markdown("---")

# ---------------------------------------------------------------------
# 9) GENERATE AUDIO WITH MUSICGEN
# ---------------------------------------------------------------------
st.subheader("🎵 Step 2: Generate Audio")
audio_length = st.slider("Track Length (tokens)", 128, 1024, 512, 64)

if st.button("🎧 Create Audio"):
    if "final_script" not in st.session_state:
        st.error("Please generate a script first.")
    else:
        with st.spinner("Generating audio..."):
            try:
                mg_model, mg_processor = load_musicgen_model()
                inputs = mg_processor(
                    text=[st.session_state["final_script"]],
                    padding=True,
                    return_tensors="pt"
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate

                audio_data = audio_values[0, 0].cpu().numpy()
                normalized_audio = _to_int16_pcm(audio_data)

                # Keep the WAV in memory: a fixed file name in the working
                # directory would be shared by every concurrent session.
                wav_buffer = BytesIO()
                wav.write(wav_buffer, rate=sr, data=normalized_audio)

                st.success("Audio generated! Play it below:")
                st.audio(wav_buffer.getvalue(), format="audio/wav")
            except Exception as e:
                st.error(f"MusicGen error: {e}")

# ---------------------------------------------------------------------
# 10) FOOTER
# ---------------------------------------------------------------------
st.markdown("---")
st.markdown(
    """ """,
    unsafe_allow_html=True
)