import os

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import AutoTokenizer, VitsConfig, VitsModel

from thaicleantext import clean_thai_text


def load_tts_model(pth_path, speed=1.0):
    """Load a VITS model and its tokenizer from a ``.pth`` checkpoint.

    The checkpoint must be a dict holding a ``'config'`` entry (keyword
    arguments for ``VitsConfig``) and a ``'model_state'`` entry (the state
    dict passed to ``load_state_dict``).

    Args:
        pth_path: Path to the checkpoint file.
        speed: Value assigned to ``model.speaking_rate``.

    Returns:
        ``(model, tokenizer, None)`` on success, or
        ``(None, None, error_message)`` if any step raises.
    """
    try:
        # SECURITY NOTE: torch.load deserializes the file (pickle-based unless
        # weights-only loading is in effect), so only load checkpoints that
        # come from a trusted source.
        loaded_dict = torch.load(pth_path, map_location=torch.device('cpu'))
        config = VitsConfig(**loaded_dict['config'])
        model = VitsModel(config)
        model.load_state_dict(loaded_dict['model_state'])
        model.eval()
        model.speaking_rate = speed
        tokenizer = AutoTokenizer.from_pretrained("VIZINTZOR/tts-tha-vits")
        return model, tokenizer, None
    except Exception as e:
        # Errors are reported as a message so the UI can display them.
        return None, None, f"Error loading model: {e}"


def generate_speech(model, tokenizer, text, speed, volume, output_file="output.wav"):
    """Synthesize ``text`` with ``model`` and write the audio to ``output_file``.

    The waveform is peak-normalized to [-1, 1] and then scaled by ``volume``
    before being written at ``model.config.sampling_rate``.

    Args:
        model: Model whose call returns an object with a ``waveform`` tensor.
        tokenizer: Tokenizer called as ``tokenizer(text, return_tensors="pt")``.
        text: Text to synthesize.
        speed: Value assigned to ``model.speaking_rate`` before inference.
        volume: Multiplier applied to the normalized waveform.
        output_file: Destination path for the audio file.

    Returns:
        ``(output_file, None)`` on success, or ``(None, error_message)`` if
        any step raises.
    """
    try:
        model.speaking_rate = speed
        inputs = tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            waveform = model(**inputs).waveform

        waveform = waveform.squeeze().cpu().numpy()

        # Normalize to [-1, 1]. A completely silent waveform has a peak of 0;
        # dividing by it would fill the output with NaNs, so leave it as-is.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak
        waveform = waveform * volume  # Apply volume adjustment

        sample_rate = model.config.sampling_rate
        sf.write(output_file, waveform, sample_rate)
        return output_file, None
    except Exception as e:
        return None, f"Error generating speech: {e}"


def get_available_models(model_dir="./models"):
    """Return the paths of the ``.pth`` files found directly in ``model_dir``.

    Args:
        model_dir: Directory to scan (not recursive).

    Returns:
        A sorted list of ``os.path.join(model_dir, name)`` paths; an empty
        list when ``model_dir`` is missing or is not a directory.
    """
    # isdir (rather than exists) so a regular file at this path yields an
    # empty list instead of an exception from os.listdir.
    if not os.path.isdir(model_dir):
        return []
    # os.listdir order is arbitrary; sort so the dropdown and its default
    # choice are stable between runs.
    return sorted(
        os.path.join(model_dir, name)
        for name in os.listdir(model_dir)
        if name.endswith('.pth')
    )


def tts_interface(text, model_path, speed, volume):
    """Gradio callback: load the selected model and synthesize ``text``.

    Args:
        text: Raw text from the input box; it is passed through
            ``clean_thai_text`` before synthesis.
        model_path: Path of the checkpoint selected in the dropdown.
        speed: Speaking-rate value from the speed slider.
        volume: Volume multiplier from the volume slider.

    Returns:
        ``(audio_file_path, status_message)`` on success, or
        ``(None, error_message)`` on failure.
    """
    # Fail early with readable messages instead of surfacing low-level
    # exceptions from torch.load or the tokenizer.
    if not model_path:
        return None, "No model selected. Add a .pth file to ./models and select it."
    if not text or not text.strip():
        return None, "Please enter some text."

    model, tokenizer, error = load_tts_model(model_path, speed)
    if model is None or tokenizer is None:
        return None, error

    output_file = "output.wav"
    text = clean_thai_text(text)
    audio_file, error = generate_speech(model, tokenizer, text, speed, volume, output_file)

    if audio_file:
        return audio_file, "Audio generated successfully!"
    return None, error


# Scan the models directory once so the dropdown choices and its default
# value come from the same listing.
_available_models = get_available_models()

# Create Gradio interface
with gr.Blocks(title="Text-to-Speech Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text-to-Speech Generator")
    gr.Markdown("Enter text, select a model, adjust speed and volume, and generate audio!")

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here...",
                lines=5
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=_available_models,
                value=_available_models[0] if _available_models else None
            )
        with gr.Column(scale=1):
            speed_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Speaking Speed",
                info="1.0 is normal speed"
            )
            volume_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=1,
                step=0.05,
                label="Volume",
                info="Adjust output volume"
            )

    generate_btn = gr.Button("Generate Audio", variant="primary")

    with gr.Row():
        audio_output = gr.Audio(label="Generated Audio")
        status_output = gr.Textbox(label="Status", interactive=False)

    # Connect the button to the function
    generate_btn.click(
        fn=tts_interface,
        inputs=[text_input, model_dropdown, speed_slider, volume_slider],
        outputs=[audio_output, status_output]
    )

# Launch the interface
demo.launch()