import gradio as gr
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration
import librosa
import numpy as np
import soundfile as sf
import tempfile
import os
import re
import spaces

# Global variables to persist across requests
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None
current_model_id = None

# Model configuration
MODELS = {
    "CSM 1B Dhivehi 5-Speaker": "alakxender/csm-1b-dhivehi-5-spk-gd",
    "CSM 1B Dhivehi 2-Speakers": "alakxender/csm-1b-dhivehi-2-speakers",
}

# Model-specific speaker prompts
MODEL_SPEAKER_PROMPTS = {
    "alakxender/csm-1b-dhivehi-2-speakers": {
        "female_01": {"text": None, "audio": None, "speaker_id": "0", "name": "Female Speaker 01"},
        "male_01": {"text": None, "audio": None, "speaker_id": "1", "name": "Male Speaker 01"},
    },
    "alakxender/csm-1b-dhivehi-5-spk-gd": {
        "female_01": {"text": None, "audio": None, "speaker_id": "0", "name": "Female Speaker 01"},
        "male_01": {"text": None, "audio": None, "speaker_id": "1", "name": "Male Speaker 01"},
        "female_02": {"text": None, "audio": None, "speaker_id": "2", "name": "Female Speaker 02"},
        "male_02": {"text": None, "audio": None, "speaker_id": "3", "name": "Male Speaker 02"},
        "female_03": {"text": None, "audio": None, "speaker_id": "4", "name": "Female Speaker 03"},
    },
}


@spaces.GPU
def load_model(model_name):
    """Load the requested CSM model and processor, caching them in module globals."""
    global model, processor, current_model_id

    print(f"load_model called with '{model_name}', current_model_id before: {current_model_id}")

    if model_name not in MODELS:
        print(f"Error: Model '{model_name}' not found")
        return False

    model_id = MODELS[model_name]
    print(f"Model ID: {model_id}")

    if current_model_id == model_id and model is not None and processor is not None:
        print(f"Model '{model_name}' already loaded and valid")
        return True
    elif current_model_id == model_id:
        print(f"Model '{model_name}' was previously loaded but model/processor is None, reloading...")
        # Reset current_model_id to force a reload
        current_model_id = None

    try:
        print(f"Loading model '{model_name}' on {device}...")

        # Free the previous model before loading a new one
        if model is not None:
            del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Loading processor from {model_id}...")
        processor = AutoProcessor.from_pretrained(model_id)
        print(f"Processor loaded successfully: {processor is not None}")

        if hasattr(processor.tokenizer, "init_kwargs"):
            processor.tokenizer.init_kwargs.pop("pad_to_multiple_of", None)

        print(f"Loading model from {model_id}...")
        model = CsmForConditionalGeneration.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.float32,
        )
        print(f"Model loaded successfully: {model is not None}")

        # Sanity check: the model must expose a generate method
        if hasattr(model, "generate"):
            print(f"Model has generate method: {type(model.generate)}")
        else:
            print("ERROR: Model does not have a generate method!")
            return False

        current_model_id = model_id
        print(f"Model '{model_name}' loaded successfully! current_model_id after: {current_model_id}")
        return True

    except Exception as e:
        print(f"Error loading model '{model_name}': {e}")
        import traceback
        traceback.print_exc()
        return False


def get_current_speaker_prompts():
    """Return the speaker-prompt table for the currently loaded model."""
    global current_model_id
    print(f"Getting speaker prompts for model_id: {current_model_id}")
    print(f"Available keys in MODEL_SPEAKER_PROMPTS: {list(MODEL_SPEAKER_PROMPTS.keys())}")

    if current_model_id in MODEL_SPEAKER_PROMPTS:
        print(f"Found speaker prompts for {current_model_id}")
        return MODEL_SPEAKER_PROMPTS[current_model_id]

    print(f"No speaker prompts found for {current_model_id}, using defaults")
    return {
        "female_01": {"text": "", "audio": None, "speaker_id": "0", "name": "Female Speaker 01"},
        "male_01": {"text": "", "audio": None, "speaker_id": "1", "name": "Male Speaker 01"},
    }


def get_model_info():
    """Return a short status string describing the currently loaded model."""
    global current_model_id
    print(f"Getting model info, current_model_id: {current_model_id}")
    if current_model_id:
        model_name = next((name for name, model_id in MODELS.items() if model_id == current_model_id), "Unknown")
        result = f"Current Model: {model_name}"
        print(f"Model info: {result}")
        return result
    print("No model loaded")
    return "No model loaded"


def get_speaker_choices():
    """Return the speaker keys available for the currently loaded model."""
    global current_model_id
    print(f"get_speaker_choices called, current_model_id: {current_model_id}")
    prompts = get_current_speaker_prompts()
    choices = list(prompts.keys())
    print(f"Speaker choices: {choices}")
    return choices


def load_audio_file(filepath, target_sr=24000):
    """Load an audio file as a mono float32 array at the target sample rate."""
    if filepath is None or not os.path.exists(filepath):
        return None
    try:
        audio, _ = librosa.load(filepath, sr=target_sr)
        return audio.astype(np.float32)
    except Exception:
        return None


@spaces.GPU
def generate_simple_audio(text, speaker_id, model_name=None):
    """Generate speech for plain text with the given speaker ID (no audio context)."""
    global model, processor

    if not text.strip():
        return None, "Please enter some text."

    # Debug: check model and processor status
    print(f"generate_simple_audio: model is None = {model is None}")
    print(f"generate_simple_audio: processor is None = {processor is None}")
    print(f"generate_simple_audio: model_name = {model_name}")

    if model_name and model_name in MODELS:
        success = load_model(model_name)
        print(f"load_model result: {success}")
    elif model is None or processor is None:
        success = load_model(list(MODELS.keys())[0])
        print(f"load_model result: {success}")

    # Check again after loading
    print(f"After loading: model is None = {model is None}")
    print(f"After loading: processor is None = {processor is None}")

    if model is None:
        return None, "Error: Model failed to load. Please try again."
    if processor is None:
        return None, "Error: Processor failed to load. Please try again."

    try:
        # CSM expects the speaker ID as a "[id]" prefix on the text
        formatted_text = f"[{speaker_id}]{text}"
        print(f"Formatted text: {formatted_text}")

        inputs = processor(formatted_text, add_special_tokens=True).to(device)
        print("Inputs processed successfully")
        print(f"Model type: {type(model)}")
        print(f"Model generate method: {type(model.generate)}")
        print(f"Inputs type: {type(inputs)}")
        print(f"Inputs keys: {inputs.keys() if hasattr(inputs, 'keys') else 'Not a dict'}")

        with torch.no_grad():
            try:
                audio = model.generate(**inputs, output_audio=True)
                print("Audio generated successfully")
            except Exception as gen_error:
                print(f"Error in model.generate: {gen_error}")
                print(f"Error type: {type(gen_error)}")
                import traceback
                traceback.print_exc()
                raise gen_error

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        processor.save_audio(audio, temp_file.name)
        return temp_file.name, "Audio generated successfully!"
    except Exception as e:
        print(f"Exception in generate_simple_audio: {e}")
        return None, f"Error: {str(e)}"


@spaces.GPU
def generate_context_audio(text, speaker_id, context_text, context_audio_file, model_name=None):
    """Generate speech for text, conditioned on a reference (context) text/audio pair."""
    global model, processor

    if not text.strip():
        return None, "Please enter text to generate."
    if not context_text.strip() or context_audio_file is None:
        return None, "Please provide both context text and audio."

    if model_name and model_name in MODELS:
        load_model(model_name)
    elif model is None or processor is None:
        load_model(list(MODELS.keys())[0])

    try:
        context_audio = load_audio_file(context_audio_file)
        if context_audio is None:
            return None, "Failed to load context audio."

        # A context turn (text + audio) is followed by the target text for the same speaker
        conversation = [
            {
                "role": str(speaker_id),
                "content": [
                    {"type": "text", "text": context_text},
                    {"type": "audio", "path": context_audio},
                ],
            },
            {
                "role": str(speaker_id),
                "content": [{"type": "text", "text": text}],
            },
        ]

        inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_dict=True,
        ).to(device)

        with torch.no_grad():
            audio = model.generate(**inputs, output_audio=True)

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        processor.save_audio(audio, temp_file.name)
        return temp_file.name, "Audio generated with context!"

    except Exception as e:
        return None, f"Error: {str(e)}"


def trim_silence(audio_array, top_db=20):
    """Trim leading and trailing silence from an audio array."""
    try:
        trimmed, _ = librosa.effects.trim(audio_array, top_db=top_db)
        return trimmed
    except Exception:
        return audio_array


def split_sentences(text):
    """Split text into sentences on Latin and Arabic-script punctuation."""
    sentences = re.split(r'[.؟!،]', text)
    return [s.strip() for s in sentences if s.strip()]


def extract_audio_output(audio_output):
    """Convert the output of model.generate(..., output_audio=True) into a NumPy array."""
    global processor
    try:
        if hasattr(audio_output, 'audio_values'):
            return audio_output.audio_values.cpu().squeeze().numpy()
        elif isinstance(audio_output, torch.Tensor):
            return audio_output.cpu().squeeze().numpy()
        else:
            # Fall back to round-tripping through a temporary WAV file
            temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            processor.save_audio(audio_output, temp_file.name)
            audio, _ = librosa.load(temp_file.name, sr=24000)
            os.unlink(temp_file.name)
            return audio
    except Exception:
        return np.array([])


@spaces.GPU
def generate_conversation(speaker_a, speaker_b, speaker_a_audio, speaker_a_text,
                          speaker_b_audio, speaker_b_text, dialogue_text,
                          split_sentences_flag, use_style, model_name=None,
                          progress=gr.Progress()):
    """Generate a dual-speaker conversation from alternating dialogue lines."""
    global model, processor

    if model_name and model_name in MODELS:
        load_model(model_name)
    elif model is None or processor is None:
        load_model(list(MODELS.keys())[0])

    try:
        lines = [line.strip() for line in dialogue_text.strip().split("\n") if line.strip()]
        if not lines:
            return None, "Please enter dialogue text."
        current_prompts = get_current_speaker_prompts()
        speaker_a_id = current_prompts.get(speaker_a, {}).get("speaker_id", "0")
        speaker_b_id = current_prompts.get(speaker_b, {}).get("speaker_id", "1")
        speakers = [speaker_a_id, speaker_b_id]

        progress(0, desc="Preparing...")

        # Optional per-speaker style contexts (reference audio + transcript)
        speaker_contexts = {}
        if use_style:
            if speaker_a in current_prompts and speaker_a_audio and speaker_a_text:
                audio_data = load_audio_file(speaker_a_audio)
                if audio_data is not None:
                    speaker_contexts[speaker_a_id] = {
                        "audio": trim_silence(audio_data),
                        "text": speaker_a_text.strip(),
                    }
            if speaker_b in current_prompts and speaker_b_audio and speaker_b_text:
                audio_data = load_audio_file(speaker_b_audio)
                if audio_data is not None:
                    speaker_contexts[speaker_b_id] = {
                        "audio": trim_silence(audio_data),
                        "text": speaker_b_text.strip(),
                    }

        # Dialogue lines alternate between the two speakers
        text_units = []
        if split_sentences_flag:
            for i, line in enumerate(lines):
                speaker_role = speakers[i % 2]
                for sentence in split_sentences(line):
                    text_units.append((sentence, speaker_role))
        else:
            for i, line in enumerate(lines):
                speaker_role = speakers[i % 2]
                text_units.append((line, speaker_role))

        if not text_units:
            return None, "No text to generate."

        audio_segments = []
        progress(0.1, desc="Generating audio...")

        for i, (text, role) in enumerate(text_units):
            progress(0.1 + (i / len(text_units)) * 0.8,
                     desc=f"Generating {i+1}/{len(text_units)}: Speaker {role}")

            conversation = []

            # Prepend the style context the first time a speaker appears
            if use_style and role in speaker_contexts:
                is_first = True
                for prev_i in range(i):
                    if text_units[prev_i][1] == role:
                        is_first = False
                        break
                if is_first:
                    context = speaker_contexts[role]
                    conversation.append({
                        "role": role,
                        "content": [
                            {"type": "text", "text": context["text"]},
                            {"type": "audio", "path": context["audio"]},
                        ],
                    })

            conversation.append({
                "role": role,
                "content": [{"type": "text", "text": text}],
            })

            inputs = processor.apply_chat_template(
                conversation,
                tokenize=True,
                return_dict=True,
            ).to(device)

            with torch.no_grad():
                audio_output = model.generate(**inputs, output_audio=True)

            audio_segment = extract_audio_output(audio_output)
            if len(audio_segment) > 0:
                audio_segment = trim_silence(audio_segment)
                audio_segments.append(audio_segment)

        if not audio_segments:
            return None, "No audio generated."
        progress(0.9, desc="Finalizing...")

        if len(audio_segments) == 1:
            final_audio = audio_segments[0]
        else:
            final_audio = np.concatenate(audio_segments)

        # Normalize to avoid clipping
        if np.max(np.abs(final_audio)) > 0:
            final_audio = final_audio / np.max(np.abs(final_audio)) * 0.95

        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        sf.write(output_file.name, final_audio, 24000)

        duration = len(final_audio) / 24000
        progress(1.0, desc="Complete!")
        return output_file.name, f"Generated {len(text_units)} segments, {duration:.1f}s total"

    except Exception as e:
        return None, f"Error: {str(e)}"


def change_model_and_update_ui(model_name):
    """Load the selected model and refresh all speaker selectors in the UI."""
    success = load_model(model_name)
    if not success:
        return (
            gr.update(value="Error loading model"),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    choices = get_speaker_choices()
    current_prompts = get_current_speaker_prompts()
    new_speaker_choices = [
        (f"{prompt_data['speaker_id']}: {prompt_data['name']}", prompt_data['speaker_id'])
        for prompt_key, prompt_data in current_prompts.items()
    ]
    display_choices = [choice[0] for choice in new_speaker_choices]

    return (
        gr.update(value=get_model_info()),
        gr.update(choices=choices, value=choices[0]),
        gr.update(choices=choices, value=choices[1] if len(choices) > 1 else choices[0]),
        gr.update(choices=display_choices, value=display_choices[0] if display_choices else "0: Speaker"),
        gr.update(choices=display_choices, value=display_choices[0] if display_choices else "0: Speaker"),
    )


def get_csm1b_tab():
    """Build the CSM-1B tab inside the caller's active Gradio Blocks context."""
    global current_model_id
    print(f"get_csm1b_tab called, current_model_id at start: {current_model_id}")

    # Load the first model by default before building the UI
    try:
        first_model = list(MODELS.keys())[0]
        print(f"Loading default model in csm1b_dv: {first_model}")
        success = load_model(first_model)
        if success:
            print(f"Model loaded successfully in csm1b_dv, current_model_id: {current_model_id}")
            # Force-set current_model_id to ensure it persists
            current_model_id = MODELS[first_model]
            print(f"Force set current_model_id to: {current_model_id}")

            # Test speaker loading
            print("\n=== TESTING SPEAKER LOADING ===")
            test_speakers = get_current_speaker_prompts()
            print(f"Available speakers: {list(test_speakers.keys())}")
            for speaker_key, speaker_data in test_speakers.items():
                print(f"  {speaker_key}: {speaker_data['name']} (ID: {speaker_data['speaker_id']})")
            print("=== END SPEAKER TEST ===\n")
        else:
            print("Failed to load model in csm1b_dv")
    except Exception as e:
        print(f"Failed to load default model in csm1b_dv: {e}")

    print(f"current_model_id before UI creation: {current_model_id}")

    with gr.Tab("🎙️ CSM-1B"):
        gr.Markdown("# 🎙️ CSM-1B Text-to-Speech Synthesis")
        gr.Markdown(
            "**CSM (Conversational Speech Model)** is a speech generation model from "
            "[Sesame](sesame.com) that generates **RVQ audio codes** from text and audio inputs. "
            "The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller "
            "audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes. "
            "This demo uses a **fine-tuned version** of the model for **Dhivehi speech synthesis**."
        )

        print(f"Creating model dropdown and info, current_model_id: {current_model_id}")

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="🤖 Select Model",
            )
            model_info = gr.Textbox(
                value=get_model_info(),
                label="Model Status",
                interactive=False,
            )

        with gr.Tabs():
            with gr.TabItem("🎯 Simple Generation"):
                gr.Markdown("### Generate speech from text without context")
                with gr.Row():
                    with gr.Column():
                        simple_text = gr.Textbox(
                            label="Text to Generate (Dhivehi)",
                            placeholder="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                            value="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                            lines=3,
                            elem_classes=["dhivehi-text"],
                        )

                        current_prompts = get_current_speaker_prompts()
                        speaker_choices = [
                            (f"{prompt_data['speaker_id']}: {prompt_data['name']}", prompt_data['speaker_id'])
                            for prompt_key, prompt_data in current_prompts.items()
                        ]
                        simple_speaker = gr.Radio(
                            choices=[choice[0] for choice in speaker_choices],
                            label="Speaker",
                            value=speaker_choices[0][0] if speaker_choices else "0: Speaker",
                        )
                        simple_btn = gr.Button("🎵 Generate", variant="primary")
                    with gr.Column():
                        simple_audio = gr.Audio(label="Generated Audio")
                        simple_status = gr.Textbox(label="Status", interactive=False)

                def simple_generate_with_mapping(text, speaker_display, selected_model):
                    # Map the "id: name" radio label back to the raw speaker ID
                    speaker_id = speaker_display.split(":")[0]
                    return generate_simple_audio(text, speaker_id, selected_model)

                simple_btn.click(
                    simple_generate_with_mapping,
                    inputs=[simple_text, simple_speaker, model_dropdown],
                    outputs=[simple_audio, simple_status],
                )

            with gr.TabItem("🎭 Context Generation"):
                gr.Markdown("### Generate speech with a voice prompt")
                with gr.Row():
                    with gr.Column():
                        context_text = gr.Textbox(
                            label="Speaker Prompt Text",
                            placeholder="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                            value="",
                            lines=2,
                            elem_classes=["dhivehi-text"],
                        )
                        context_audio = gr.Audio(
                            label="Speaker Prompt Audio",
                            type="filepath",
                        )
                        target_text = gr.Textbox(
                            label="Text to Generate",
                            placeholder="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                            value="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                            lines=3,
                            elem_classes=["dhivehi-text"],
                        )
                        context_speaker = gr.Radio(
                            choices=[choice[0] for choice in speaker_choices],
                            label="Speaker",
                            value=speaker_choices[0][0] if speaker_choices else "0: Speaker",
                        )
                        context_btn = gr.Button("🎵 Generate with Context", variant="primary")
                    with gr.Column():
                        context_audio_out = gr.Audio(label="Generated Audio")
                        context_status = gr.Textbox(label="Status", interactive=False)

                def context_generate_with_mapping(text, speaker_display, context_text_val, context_audio_val, selected_model):
                    speaker_id = speaker_display.split(":")[0]
                    return generate_context_audio(text, speaker_id, context_text_val, context_audio_val, selected_model)

                context_btn.click(
                    context_generate_with_mapping,
                    inputs=[target_text, context_speaker, context_text, context_audio, model_dropdown],
                    outputs=[context_audio_out, context_status],
                )

            with gr.TabItem("💬 Conversation"):
                gr.Markdown("### Generate dual-speaker conversations")
                with gr.Row():
                    speaker_a = gr.Dropdown(
                        choices=get_speaker_choices(),
                        label="Speaker A",
                        value=get_speaker_choices()[0] if get_speaker_choices() else "female_01",
                    )
                    speaker_b = gr.Dropdown(
                        choices=get_speaker_choices(),
                        label="Speaker B",
                        value=(get_speaker_choices()[1] if len(get_speaker_choices()) > 1
                               else get_speaker_choices()[0] if get_speaker_choices() else "male_01"),
                    )

                with gr.Accordion("🎵 Audio Style References", open=False):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("**Speaker A Prompt**")
                            speaker_a_audio = gr.Audio(
                                type="filepath",
                                label="Speaker A Audio Style",
                            )
                            speaker_a_text = gr.Textbox(
                                label="Speaker A Prompt Text",
                                lines=2,
                                placeholder="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                                elem_classes=["dhivehi-text"],
                            )
                        with gr.Column():
                            gr.Markdown("**Speaker B Prompt**")
                            speaker_b_audio = gr.Audio(
                                type="filepath",
                                label="Speaker B Audio Style",
                            )
                            speaker_b_text = gr.Textbox(
                                label="Speaker B Prompt Text",
                                lines=2,
                                placeholder="މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ",
                                elem_classes=["dhivehi-text"],
                            )

                with gr.Accordion("⚙️ Options", open=False):
                    use_style = gr.Checkbox(
                        label="Use audio style references",
                        value=False,
                    )
                    split_sentences_checkbox = gr.Checkbox(
                        label="Split sentences",
                        value=True,
                    )

                dialogue_text = gr.Textbox(
                    lines=6,
                    placeholder="ދަރިފުޅު މިއަދު ހާދަ ލަސްތިވީ.. މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ...... ކޮބާ ތިޔަ ރިޕޯޓް ފޮތް؟\nދަބަހުގަ އެބައޮތް!.",
                    value="ދަރިފުޅު މިއަދު ހާދަ ލަސްތިވީ.. މަންމަ ކިހާއިރެއް ދަރިފުޅުގެ އިންތިޒާރުގަ އިންނަތާ...... ކޮބާ ތިޔަ ރިޕޯޓް ފޮތް؟\nދަބަހުގަ އެބައޮތް!.",
                    label="Dialogue Lines (one per line)",
                    elem_classes=["dhivehi-text"],
                )

                conv_btn = gr.Button("🎵 Generate Conversation", variant="primary")
                conv_audio = gr.Audio(label="Generated Conversation")
                conv_status = gr.Textbox(label="Status", interactive=False)

                def conversation_generate_with_model(speaker_a_val, speaker_b_val, speaker_a_audio_val, speaker_a_text_val,
                                                     speaker_b_audio_val, speaker_b_text_val, dialogue_text_val,
                                                     split_sentences_flag, use_style_flag, selected_model):
                    return generate_conversation(speaker_a_val, speaker_b_val, speaker_a_audio_val, speaker_a_text_val,
                                                 speaker_b_audio_val, speaker_b_text_val, dialogue_text_val,
                                                 split_sentences_flag, use_style_flag, selected_model)

                conv_btn.click(
                    conversation_generate_with_model,
                    inputs=[speaker_a, speaker_b, speaker_a_audio, speaker_a_text,
                            speaker_b_audio, speaker_b_text, dialogue_text,
                            split_sentences_checkbox, use_style, model_dropdown],
                    outputs=[conv_audio, conv_status],
                )

        # Changing the model reloads it and refreshes every speaker selector
        model_dropdown.change(
            change_model_and_update_ui,
            inputs=[model_dropdown],
            outputs=[model_info, speaker_a, speaker_b, simple_speaker, context_speaker],
        )

        gr.Markdown("""
---
**Tips:**
- Simple: Basic text-to-speech
- Context: Use reference audio for voice consistency
- Conversation: Multi-speaker dialogues with style control

**Issues:**
- Context: The context sometimes breaks. Adding multiple context audio clips, or adding a previous generation to the context, seems to help.
- Audio: The generated audio is sometimes not in sync with the text.
- Long sentences: Long generated sentences can sound sped up.
- Repeating words: The generated speech sometimes repeats words.
""")

    # No explicit return needed: components are registered via Gradio's context-manager pattern
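
# Example usage (a hedged sketch, not part of the original module): get_csm1b_tab()
# registers its components in whatever gr.Blocks context is active when it is called,
# so a parent app is assumed to wrap it roughly like this. The Blocks title and the
# launch() call below are illustrative assumptions, not the original entry point.
if __name__ == "__main__":
    with gr.Blocks(title="Dhivehi CSM-1B Demo") as demo:
        get_csm1b_tab()
    demo.launch()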