Spaces:
Running
Running
from __future__ import annotations | |
import time | |
from pathlib import Path | |
import librosa | |
import soundfile as sf | |
import torch | |
from modules.asr import get_asr_model | |
from modules.llm import get_llm_model | |
from modules.svs import get_svs_model | |
from evaluation.svs_eval import load_evaluators, run_evaluation | |
from modules.melody import MelodyController | |
from modules.utils.text_normalize import clean_llm_output | |
class SingingDialoguePipeline:
    """Speech-in, singing-out dialogue pipeline.

    A call to :meth:`run` transcribes an input recording with the ASR model,
    asks the LLM for a reply (with melody constraints appended to the system
    prompt), converts the cleaned reply into a score through the melody
    controller, and synthesizes that score with the SVS model.
    """

    def __init__(self, config: dict):
        """Load every model and helper named in ``config``.

        Args:
            config: Pipeline settings.
                Required keys: ``cache_dir``, ``asr_model``, ``llm_model``,
                ``svs_model``, ``melody_source``.
                Optional keys: ``device`` (falls back to ``"cuda"`` when
                available, otherwise ``"cpu"``), ``max_sentences``
                (default 5), ``track_latency`` (default ``False``) and
                ``evaluators`` (a mapping whose ``"svs"`` entry is handed to
                ``load_evaluators``).

        Raises:
            KeyError: If a required key is missing from ``config``.
        """
        if "device" in config:
            self.device = config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_dir = config["cache_dir"]

        self.asr = get_asr_model(
            config["asr_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.llm = get_llm_model(
            config["llm_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.svs = get_svs_model(
            config["svs_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.melody_controller = MelodyController(
            config["melody_source"], self.cache_dir
        )

        self.max_sentences = config.get("max_sentences", 5)
        self.track_latency = config.get("track_latency", False)

        # `or {}` tolerates an "evaluators" key that is present but None
        # (e.g. an empty section in a config file).
        evaluator_config = config.get("evaluators") or {}
        self.evaluators = load_evaluators(evaluator_config.get("svs", []))

    @staticmethod
    def _release_memory() -> None:
        """Run a garbage-collection pass and free cached CUDA memory."""
        # Function-scope import: the module does not import gc at top level.
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def set_asr_model(self, asr_model: str):
        """Replace the ASR model with the one identified by ``asr_model``.

        The previous model is dropped before the new one is loaded. If
        loading fails, ``self.asr`` is left as ``None``.
        """
        # Rebinding to None drops the old model's reference just like `del`,
        # but keeps the attribute defined if the loader below raises.
        self.asr = None
        self._release_memory()
        self.asr = get_asr_model(
            asr_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_llm_model(self, llm_model: str):
        """Replace the LLM with the one identified by ``llm_model``.

        The previous model is dropped before the new one is loaded. If
        loading fails, ``self.llm`` is left as ``None``.
        """
        self.llm = None
        self._release_memory()
        self.llm = get_llm_model(
            llm_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_svs_model(self, svs_model: str):
        """Replace the SVS model with the one identified by ``svs_model``.

        The previous model is dropped before the new one is loaded. If
        loading fails, ``self.svs`` is left as ``None``.
        """
        self.svs = None
        self._release_memory()
        self.svs = get_svs_model(
            svs_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_melody_controller(self, melody_source: str):
        """Rebuild the melody controller from ``melody_source``."""
        self.melody_controller = MelodyController(melody_source, self.cache_dir)

    def run(
        self,
        audio_path,
        language,
        system_prompt,
        speaker,
        output_audio_path: Path | str | None = None,
    ):
        """Run one dialogue turn from an input recording to a sung reply.

        Args:
            audio_path: Recording to transcribe; loaded with librosa at
                ``sr=16000``.
            language: Passed to the ASR model, the LLM output cleaner, the
                melody controller and the SVS model.
            system_prompt: LLM system prompt; the melody constraints are
                appended to it.
            speaker: Passed to the SVS model.
            output_audio_path: When truthy, the synthesized audio is written
                there (parent directories are created as needed).

        Returns:
            A dict with ``"asr_text"``, ``"llm_text"`` and ``"svs_audio"``
            (a ``(sample_rate, audio)`` tuple). ``"output_audio_path"`` is
            added when the audio was written to disk, and ``"metrics"``
            (per-stage latencies in seconds) when ``track_latency`` is set.
        """
        # perf_counter is monotonic, so the intervals cannot be skewed by
        # system clock adjustments the way time.time() differences can.
        stage_start = time.perf_counter()
        audio_array, audio_sample_rate = librosa.load(audio_path, sr=16000)
        asr_result = self.asr.transcribe(
            audio_array, audio_sample_rate=audio_sample_rate, language=language
        )
        asr_latency = time.perf_counter() - stage_start

        melody_prompt = self.melody_controller.get_melody_constraints(
            max_num_phrases=self.max_sentences
        )

        stage_start = time.perf_counter()
        output = self.llm.generate(asr_result, system_prompt + melody_prompt)
        llm_latency = time.perf_counter() - stage_start

        llm_response = clean_llm_output(
            output, language=language, max_sentences=self.max_sentences
        )
        score = self.melody_controller.generate_score(llm_response, language)

        stage_start = time.perf_counter()
        singing_audio, sample_rate = self.svs.synthesize(
            score, language=language, speaker=speaker
        )
        svs_latency = time.perf_counter() - stage_start

        results = {
            "asr_text": asr_result,
            "llm_text": llm_response,
            "svs_audio": (sample_rate, singing_audio),
        }
        if output_audio_path:
            Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
            sf.write(output_audio_path, singing_audio, sample_rate)
            results["output_audio_path"] = output_audio_path
        if self.track_latency:
            results["metrics"] = {
                "asr_latency": asr_latency,
                "llm_latency": llm_latency,
                "svs_latency": svs_latency,
            }
        return results

    def evaluate(self, audio_path, **kwargs):
        """Run the configured evaluators on ``audio_path``.

        Extra keyword arguments are forwarded to ``run_evaluation``.
        """
        return run_evaluation(audio_path, self.evaluators, **kwargs)