Spaces · Running on Zero

Den Pavloff committed · Commit 164603c · Parent(s): 91eb188

first

Files changed:
- app.py +212 -0
- requirements.txt +5 -0
- util.py +222 -0

app.py ADDED
@@ -0,0 +1,212 @@
import os
import subprocess
import sys

# Fix OMP_NUM_THREADS issue before any imports
os.environ["OMP_NUM_THREADS"] = "4"

# Install dependencies programmatically to avoid conflicts
def setup_dependencies():
    try:
        # Check if already installed
        if os.path.exists('/tmp/deps_installed'):
            return

        print("Installing transformers dev version...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir",
            "git+https://github.com/huggingface/transformers.git"
        ])

        # Mark as installed
        with open('/tmp/deps_installed', 'w') as f:
            f.write('done')

    except Exception as e:
        print(f"Dependencies setup error: {e}")

# Run setup
setup_dependencies()

import spaces
import gradio as gr
from util import Config, NemoAudioPlayer, KaniModel
import numpy as np
import torch

# Get HuggingFace token
token_ = os.getenv('HF_TOKEN')

# Model configurations
models_configs = {
    'Base_pretrained_model': Config(),
    'Female_voice': Config(
        model_name='nineninesix/lfm-nano-codec-expresso-ex02-v.0.2',
        temperature=0.2
    ),
    'Male_voice': Config(
        model_name='nineninesix/lfm-nano-codec-expresso-ex01-v.0.1',
        temperature=0.2
    )
}

# Global variables for models (loaded once)
player = None
models = {}

def initialize_models():
    """Initialize models globally to avoid reloading"""
    global player, models

    if player is None:
        print("Initializing NeMo Audio Player...")
        player = NemoAudioPlayer(Config())
        print("NeMo Audio Player initialized!")

    if not models:
        print("Loading TTS models...")
        for model_name, config in models_configs.items():
            print(f"Loading {model_name}...")
            models[model_name] = KaniModel(config, player, token_)
            print(f"{model_name} loaded!")
        print("All models loaded!")

@spaces.GPU
def generate_speech_gpu(text, model_choice):
    """
    Generate speech from text using the selected model on GPU
    """
    # Initialize models if not already done
    initialize_models()

    if not text.strip():
        return None, "Please enter text for speech generation."

    if not model_choice:
        return None, "Please select a model."

    try:
        # Check GPU availability
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Get selected model
        selected_model = models[model_choice]

        # Generate audio
        print(f"Generating speech with {model_choice}...")
        audio, _ = selected_model.run_model(text)

        # Convert to Gradio format (sample_rate, audio_data)
        sample_rate = 22050  # Standard sample rate for NeMo
        print("Speech generation completed!")

        return (sample_rate, audio), f"✅ Audio generated successfully using {model_choice} on {device}"

    except Exception as e:
        print(f"Error during generation: {str(e)}")
        return None, f"❌ Error during generation: {str(e)}"

def validate_input(text, model_choice):
    """Quick validation without GPU"""
    if not text.strip():
        return "⚠️ Please enter text for speech generation."
    if not model_choice:
        return "⚠️ Please select a model."
    return f"✅ Ready to generate with {model_choice}"

# Create Gradio interface
with gr.Blocks(title="KaniTTS - Text to Speech", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 KaniTTS - Text to Speech with Zero GPU")
    gr.Markdown("Select a model and enter text to generate high-quality speech")

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(models_configs.keys()),
                value=list(models_configs.keys())[0],
                label="Select Model",
                info="Base - default model, Female - female voice, Male - male voice"
            )

            text_input = gr.Textbox(
                label="Enter Text",
                placeholder="Enter text for speech generation...",
                lines=3,
                max_lines=10
            )

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")

            # Quick validation button (CPU only)
            validate_btn = gr.Button("🔍 Validate Input", variant="secondary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy"
            )

            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                value="Ready to generate speech"
            )

    # GPU generation event
    generate_btn.click(
        fn=generate_speech_gpu,
        inputs=[text_input, model_dropdown],
        outputs=[audio_output, status_text]
    )

    # CPU validation event
    validate_btn.click(
        fn=validate_input,
        inputs=[text_input, model_dropdown],
        outputs=status_text
    )

    # Update status on input change
    text_input.change(
        fn=validate_input,
        inputs=[text_input, model_dropdown],
        outputs=status_text
    )

    # Text examples
    gr.Markdown("### 📝 Text Examples:")
    examples = [
        "Hello! How are you today?",
        "Welcome to the world of artificial intelligence.",
        "This is a demonstration of neural text-to-speech synthesis.",
        "Zero GPU makes high-quality speech generation accessible to everyone!"
    ]

    gr.Examples(
        examples=examples,
        inputs=text_input,
        label="Click on an example to use it"
    )

    # Information section
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
        **Available Models:**
        - **Base Model**: Default pre-trained model for general use
        - **Female Voice**: Optimized for female voice characteristics
        - **Male Voice**: Optimized for male voice characteristics

        **Features:**
        - Powered by NVIDIA NeMo Toolkit
        - High-quality 22kHz audio output
        - Zero GPU acceleration for fast inference
        - Support for long text sequences
        """)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
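
Since generate_speech_gpu is an ordinary function apart from the @spaces.GPU decorator (which should be a no-op when the spaces package runs outside a Space), the generation path can be smoke-tested without the UI. A minimal sketch, assuming HF_TOKEN is set, the checkpoints are downloadable, and soundfile is installed; none of this is guaranteed by the commit itself:

# Smoke test for the generation path, bypassing the Gradio UI.
# Assumptions (not part of this commit): HF_TOKEN is set, the checkpoints
# can be downloaded, and soundfile is installed. Note that importing app
# also runs setup_dependencies(), which pip-installs transformers.
import soundfile as sf
from app import generate_speech_gpu

result, status = generate_speech_gpu("Hello from KaniTTS!", "Female_voice")
print(status)
if result is not None:
    sample_rate, audio = result
    sf.write("sample.wav", audio, sample_rate)  # save for listening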
    	
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch==2.8.0
librosa==0.11.0
nemo_toolkit[all]==2.4.0
numpy==1.26.4
gradio>=4.0.0
    	
util.py ADDED
@@ -0,0 +1,222 @@
import torch
from nemo.collections.tts.models import AudioCodecModel
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import os


@dataclass
class Config:
    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    device_map: str = "auto"
    tokeniser_length: int = 64400
    start_of_text: int = 1
    end_of_text: int = 2
    max_new_tokens: int = 2000
    temperature: float = .6
    top_p: float = .95
    repetition_penalty: float = 1.1


class NemoAudioPlayer:
    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")

        # Load NeMo codec model
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens"""
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids

        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

        print("Output validation passed - speech tokens found")

    def get_nano_codes(self, out_ids):
        """Extract nano codec tokens from model output"""
        try:
            start_a_idx = (out_ids == self.start_of_speech).nonzero(as_tuple=True)[0].item()
            end_a_idx = (out_ids == self.end_of_speech).nonzero(as_tuple=True)[0].item()
        except IndexError:
            raise ValueError('Speech start/end tokens not found!')

        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]

        if len(audio_codes) % 4:
            raise ValueError('Audio codes length must be multiple of 4!')

        audio_codes = audio_codes.reshape(-1, 4)

        # Decode audio codes
        audio_codes = audio_codes - torch.tensor([self.codebook_size * i for i in range(4)])
        audio_codes = audio_codes - self.audio_tokens_start

        if (audio_codes < 0).sum().item() > 0:
            raise ValueError('Invalid audio tokens detected!')

        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])

        print(f"Extracted audio codes shape: {audio_codes.shape}")
        return audio_codes, len_

    def get_text(self, out_ids):
        """Extract text from model output"""
        try:
            start_t_idx = (out_ids == self.start_of_text).nonzero(as_tuple=True)[0].item()
            end_t_idx = (out_ids == self.end_of_text).nonzero(as_tuple=True)[0].item()
        except IndexError:
            raise ValueError('Text start/end tokens not found!')

        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids):
        """Convert model output to audio waveform"""
        out_ids = out_ids.flatten()
        print("Starting waveform generation...")

        # Validate output
        self.output_validation(out_ids)

        # Extract audio codes
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)

        print("Decoding audio with NeMo codec...")
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
            output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        print(f"Generated audio shape: {output_audio.shape}")

        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None


class KaniModel:
    def __init__(self, config, player: NemoAudioPlayer, token: str) -> None:
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")

        # Load model with proper configuration
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )

        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model"""
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        # Tokenize input text
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids

        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)

        print(f"Input sequence length: {modified_input_ids.shape[1]}")
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate tokens using the model"""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        print("Starting model generation...")
        print(f"Generation parameters: max_tokens={self.conf.max_new_tokens}, "
              f"temp={self.conf.temperature}, top_p={self.conf.top_p}")

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
            )

        print(f"Generated sequence length: {generated_ids.shape[1]}")
        return generated_ids.to('cpu')

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio"""
        print(f"Processing text: '{text[:50]}{'...' if len(text) > 50 else ''}'")

        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text)

        # Generate tokens
        model_output = self.model_request(input_ids, attention_mask)

        # Convert to audio
        audio, _ = self.player.get_waveform(model_output)

        print("Text-to-speech generation completed successfully!")
        return audio, text
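
The ID arithmetic above is easy to lose track of: every special token sits just past the 64,400-entry text vocabulary, and each generated audio frame is four consecutive tokens, one per codebook, with codebook i occupying its own 4,032-wide slice above audio_tokens_start. A worked example of the inversion get_nano_codes performs; constants are copied from Config, the frame values are made up for illustration:

# Token-ID layout implied by Config: special tokens sit right after the
# 64400-entry text vocabulary; audio codes start 10 slots later.
import torch

TOKENISER_LENGTH = 64400
START_OF_SPEECH = TOKENISER_LENGTH + 1       # 64401
END_OF_SPEECH = TOKENISER_LENGTH + 2         # 64402
AUDIO_TOKENS_START = TOKENISER_LENGTH + 10   # 64410
CODEBOOK_SIZE = 4032

# One frame = 4 consecutive tokens, one per codebook; codebook i occupies
# [AUDIO_TOKENS_START + i*CODEBOOK_SIZE, AUDIO_TOKENS_START + (i+1)*CODEBOOK_SIZE).
# Hypothetical frame encoding codes (5, 7, 0, 99):
frame = torch.tensor([AUDIO_TOKENS_START + 5,
                      AUDIO_TOKENS_START + CODEBOOK_SIZE + 7,
                      AUDIO_TOKENS_START + 2 * CODEBOOK_SIZE + 0,
                      AUDIO_TOKENS_START + 3 * CODEBOOK_SIZE + 99])

# Same inversion as get_nano_codes: undo the per-column codebook offset,
# then the global audio-token offset.
codes = frame.reshape(-1, 4)
codes = codes - torch.tensor([CODEBOOK_SIZE * i for i in range(4)])
codes = codes - AUDIO_TOKENS_START
print(codes.tolist())  # [[5, 7, 0, 99]]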

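
And a minimal end-to-end sketch of the util.py pipeline on its own, without app.py's Gradio layer. Assumptions: the checkpoints can be downloaded, HF_TOKEN is set, and enough GPU/CPU memory is available:

# Minimal use of util.py outside the Space: Config -> codec player -> model.
import os
from util import Config, NemoAudioPlayer, KaniModel

config = Config(temperature=0.4)     # any Config field can be overridden
player = NemoAudioPlayer(config)     # loads the NeMo nano codec
model = KaniModel(config, player, os.getenv("HF_TOKEN"))

audio, text = model.run_model("A quick end-to-end test.")
print(audio.shape)  # 1-D float waveform at the codec's 22.05 kHz rate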