from pydub import AudioSegment, silence
import tempfile
import hashlib
import matplotlib.pylab as plt
import librosa
from transformers import pipeline
import re
import torch
import numpy as np
import os
from scipy.io import wavfile
from scipy.signal import resample_poly

_ref_audio_cache = {}
asr_pipe = None


def resample_to_24khz(input_path: str, output_path: str):
    """
    Resample a WAV audio file to 24,000 Hz using scipy.

    Parameters:
    - input_path (str): Path to the input WAV file.
    - output_path (str): Path to save the output WAV file.
    """
    # Load WAV file
    orig_sr, audio = wavfile.read(input_path)

    # Convert integer PCM to float32 in [-1, 1] before any processing
    # (checking for an integer dtype avoids calling np.iinfo on float data)
    if np.issubdtype(audio.dtype, np.integer):
        audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        audio = audio.astype(np.float32)

    # Convert to mono if stereo
    if audio.ndim == 2:
        audio = audio.mean(axis=1)

    # Resample
    target_sr = 24000
    resampled = resample_poly(audio, target_sr, orig_sr)

    # Convert back to int16 for saving (clip to avoid wraparound from resampling overshoot)
    resampled_int16 = (np.clip(resampled, -1.0, 1.0) * 32767).astype(np.int16)

    # Save output
    wavfile.write(output_path, target_sr, resampled_int16)


def chunk_text(text, max_chars=135):
    # Note: max_chars is kept for API compatibility; the chunking below is word-count based.
    # print(text)
    # Step 1: split into sentences on ". "
    sentences = [s.strip() for s in text.split('. ') if s.strip()]

    # Merge sentences shorter than 4 words into an adjacent sentence
    i = 0
    while i < len(sentences):
        if len(sentences[i].split()) < 4:
            if i == 0 and i + 1 < len(sentences):
                # Merge with the following sentence
                sentences[i + 1] = sentences[i] + ', ' + sentences[i + 1]
                del sentences[i]
            elif i - 1 >= 0:
                # Merge with the preceding sentence
                sentences[i - 1] = sentences[i - 1] + ', ' + sentences[i]
                del sentences[i]
                i -= 1
            else:
                # Lone short sentence: keep it and move on (avoids an infinite loop)
                i += 1
        else:
            i += 1
    # print(sentences)

    # Step 2: split overly long sentences on ", "
    final_sentences = []
    for sentence in sentences:
        parts = [p.strip() for p in sentence.split(', ')]
        buffer = []
        for part in parts:
            buffer.append(part)
            total_words = sum(len(p.split()) for p in buffer)
            if total_words > 20:
                # Flush the buffered parts as their own chunk
                long_part = ', '.join(buffer)
                final_sentences.append(long_part)
                buffer = []
        if buffer:
            final_sentences.append(', '.join(buffer))
    # print(final_sentences)

    # Merge a trailing chunk shorter than 4 words into the previous chunk
    if final_sentences and len(final_sentences[-1].split()) < 4 and len(final_sentences) >= 2:
        final_sentences[-2] = final_sentences[-2] + ", " + final_sentences[-1]
        final_sentences = final_sentences[0:-1]
    # print(final_sentences)
    return final_sentences


def initialize_asr_pipeline(device="cuda", dtype=None):
    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )


# transcribe
def transcribe(ref_audio, language=None):
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device="cuda")
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()


def caculate_spec(audio):
    # Compute spectrogram (Short-Time Fourier Transform)
    stft = librosa.stft(audio, n_fft=512, hop_length=256, win_length=512)
    spectrogram = np.abs(stft)
    # Convert to dB
    spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
    return spectrogram_db


def save_spectrogram(audio, path):
    spectrogram = caculate_spec(audio)
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    plt.close()


def remove_silence_edges(audio, silence_threshold=-42):
    # Remove silence from the start
    non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[non_silent_start_idx:]

    # Remove silence from the end
    non_silent_end_duration = audio.duration_seconds
    for ms in reversed(audio):
        if ms.dBFS > silence_threshold:
            break
        non_silent_end_duration -= 0.001
    trimmed_audio = audio[: int(non_silent_end_duration * 1000)]

    return trimmed_audio


def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device="cuda"):
    show_info("Converting audio...")
    # ref_audio_orig_converted = ref_audio_orig.replace(".wav", "_24k.wav").replace(".mp3", "_24k.mp3").replace(".m4a", "_24k.m4a").replace(".flac", "_24k.flac")
    # resample_to_24khz(ref_audio_orig, ref_audio_orig_converted)
    # ref_audio_orig = ref_audio_orig_converted

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()

    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Use cached asr transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with a proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)

    return ref_audio, ref_text
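

# Usage sketch (added for illustration; not part of the original module).
# It shows how these helpers might be wired together: preprocess a reference clip,
# transcribing it with PhoWhisper when no reference text is supplied, then split the
# text to be synthesized into short chunks. The path "ref.wav" and the sample sentence
# are placeholders, and running this requires a CUDA device because transcribe()
# initializes the ASR pipeline with device="cuda".
if __name__ == "__main__":
    ref_audio, ref_text = preprocess_ref_audio_text("ref.wav", "", clip_short=True)
    chunks = chunk_text("This is the first sentence. This is a second, slightly longer sentence to chunk.")
    print("Processed reference audio:", ref_audio)
    print("Reference text:", ref_text)
    print("Text chunks:", chunks)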