Spaces:
Sleeping
Sleeping
File size: 4,630 Bytes
b5e825c 91394e0 b5e825c 91394e0 b5e825c 91394e0 11e246d 91394e0 2195601 11e246d 2195601 91394e0 2195601 11e246d 2195601 91394e0 2195601 11e246d 2195601 91394e0 780954b f1b8d35 b5e825c 91394e0 780954b 91394e0 780954b 91394e0 b5e825c 91394e0 f1b8d35 91394e0 b5e825c 91394e0 b5e825c 91394e0 1a42cf5 91394e0 1a42cf5 91394e0 ea9e31f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from __future__ import annotations
import time
from pathlib import Path
import librosa
import soundfile as sf
import torch
from modules.asr import get_asr_model
from modules.llm import get_llm_model
from modules.svs import get_svs_model
from evaluation.svs_eval import load_evaluators, run_evaluation
from modules.melody import MelodyController
from modules.utils.text_normalize import clean_llm_output
class SingingDialoguePipeline:
    """Turn a spoken audio prompt into a sung reply.

    One ``run`` call chains four stages: speech recognition (ASR), text
    generation (LLM), melody/score construction, and singing-voice synthesis
    (SVS). The individual models are created by the project's
    ``get_asr_model`` / ``get_llm_model`` / ``get_svs_model`` factories and
    can be swapped at runtime through the ``set_*`` methods.
    """

    # Sample rate (Hz) the input audio is resampled to before transcription.
    _ASR_SAMPLE_RATE = 16000

    def __init__(self, config: dict):
        """Build every pipeline stage from ``config``.

        Required keys: ``cache_dir``, ``asr_model``, ``llm_model``,
        ``svs_model``, ``melody_source`` (a missing one raises ``KeyError``).

        Optional keys:
            device: passed to the model factories; when the key is absent,
                ``"cuda"`` is used if available, otherwise ``"cpu"``.
            max_sentences: cap handed to the melody controller and to the
                LLM output cleaner (default 5).
            track_latency: when true, ``run`` reports per-stage latencies
                (default False).
            evaluators: mapping whose ``"svs"`` entry is handed to
                ``load_evaluators`` (default: no evaluators).
        """
        if "device" in config:
            self.device = config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_dir = config["cache_dir"]

        self.asr = get_asr_model(
            config["asr_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.llm = get_llm_model(
            config["llm_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.svs = get_svs_model(
            config["svs_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.melody_controller = MelodyController(
            config["melody_source"], self.cache_dir
        )

        self.max_sentences = config.get("max_sentences", 5)
        self.track_latency = config.get("track_latency", False)
        self.evaluators = load_evaluators(
            config.get("evaluators", {}).get("svs", [])
        )

    def _release_model(self, attr: str) -> None:
        """Drop the model stored in ``self.<attr>`` and reclaim its memory.

        The slot is reset to ``None`` rather than deleted, so that a loader
        failure afterwards leaves the pipeline with a defined (empty) slot
        instead of a missing attribute.
        """
        if getattr(self, attr, None) is None:
            return
        # Local import, as in the original setters: the module header does
        # not import gc.
        import gc

        setattr(self, attr, None)
        # Collect before emptying the cache so memory held only through
        # reference cycles is released first.
        gc.collect()
        torch.cuda.empty_cache()

    def set_asr_model(self, asr_model: str):
        """Replace the ASR model with the one identified by ``asr_model``."""
        self._release_model("asr")
        self.asr = get_asr_model(
            asr_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_llm_model(self, llm_model: str):
        """Replace the LLM with the one identified by ``llm_model``."""
        self._release_model("llm")
        self.llm = get_llm_model(
            llm_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_svs_model(self, svs_model: str):
        """Replace the SVS model with the one identified by ``svs_model``."""
        self._release_model("svs")
        self.svs = get_svs_model(
            svs_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_melody_controller(self, melody_source: str):
        """Rebuild the melody controller for a new ``melody_source``."""
        self.melody_controller = MelodyController(melody_source, self.cache_dir)

    def run(
        self,
        audio_path,
        language,
        system_prompt,
        speaker,
        output_audio_path: Path | str | None = None,
    ):
        """Run the full ASR -> LLM -> melody -> SVS chain on one audio file.

        Args:
            audio_path: input audio, loaded with ``librosa.load`` at 16 kHz.
            language: forwarded to transcription, output cleaning, score
                generation and synthesis.
            system_prompt: LLM system prompt; the melody constraints returned
                by the melody controller are appended to it with ``+``.
            speaker: forwarded to the SVS model.
            output_audio_path: if truthy, the synthesized audio is written
                there (parent directories are created as needed).

        Returns:
            dict with ``"asr_text"`` (the transcription result),
            ``"llm_text"`` (the cleaned LLM output) and ``"svs_audio"`` (a
            ``(sample_rate, audio)`` tuple). ``"output_audio_path"`` is added
            when the audio was written, and ``"metrics"`` (per-stage
            latencies in seconds) when latency tracking is enabled.
        """
        # Read the flag once so a change during the call cannot produce a
        # half-populated metrics report.
        track_latency = self.track_latency

        # perf_counter is monotonic, so the intervals stay valid even if the
        # wall clock is adjusted while a stage is running.
        stage_start = time.perf_counter()
        audio_array, audio_sample_rate = librosa.load(
            audio_path, sr=self._ASR_SAMPLE_RATE
        )
        asr_result = self.asr.transcribe(
            audio_array, audio_sample_rate=audio_sample_rate, language=language
        )
        asr_latency = time.perf_counter() - stage_start

        melody_prompt = self.melody_controller.get_melody_constraints(
            max_num_phrases=self.max_sentences
        )

        stage_start = time.perf_counter()
        output = self.llm.generate(asr_result, system_prompt + melody_prompt)
        llm_latency = time.perf_counter() - stage_start

        llm_response = clean_llm_output(
            output, language=language, max_sentences=self.max_sentences
        )
        score = self.melody_controller.generate_score(llm_response, language)

        stage_start = time.perf_counter()
        singing_audio, sample_rate = self.svs.synthesize(
            score, language=language, speaker=speaker
        )
        svs_latency = time.perf_counter() - stage_start

        results = {
            "asr_text": asr_result,
            "llm_text": llm_response,
            "svs_audio": (sample_rate, singing_audio),
        }
        if output_audio_path:
            Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
            sf.write(output_audio_path, singing_audio, sample_rate)
            results["output_audio_path"] = output_audio_path
        if track_latency:
            results["metrics"] = {
                "asr_latency": asr_latency,
                "llm_latency": llm_latency,
                "svs_latency": svs_latency,
            }
        return results

    def evaluate(self, audio_path, **kwargs):
        """Score ``audio_path`` with the configured evaluators.

        Extra keyword arguments are forwarded to ``run_evaluation`` as-is.
        """
        return run_evaluation(audio_path, self.evaluators, **kwargs)
|