File size: 2,681 Bytes
50dd0bc
 
20bd7b4
50dd0bc
 
 
120744c
50dd0bc
 
 
 
 
 
 
2ce9d86
20bd7b4
 
 
2ce9d86
 
50dd0bc
 
 
20bd7b4
50dd0bc
 
 
 
 
 
 
 
 
 
 
 
 
120744c
50dd0bc
2ce9d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50dd0bc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import csv
from argparse import ArgumentParser
from logging import INFO, basicConfig, getLogger
from pathlib import Path

import yaml

from characters import get_character
from pipeline import SingingDialoguePipeline

logger = getLogger(__name__)


def get_parser():
    """Build the argument parser for the singing-dialogue evaluation CLI.

    Returns:
        An ``ArgumentParser`` with:
        - ``--query_audios``: one or more input audio paths (required);
        - ``--config_path``: configuration file path, defaulting to
          ``config/cli/yaoyin_default.yaml``;
        - ``--output_audio_folder``: directory for generated audio (required);
        - ``--eval_results_csv``: path of the results CSV (required).
        Every value is converted to ``pathlib.Path``.
    """
    cli = ArgumentParser()
    cli.add_argument("--query_audios", type=Path, nargs="+", required=True)
    cli.add_argument(
        "--config_path",
        type=Path,
        default="config/cli/yaoyin_default.yaml",
    )
    # The two remaining options share an identical shape: a required Path.
    for required_path_flag in ("--output_audio_folder", "--eval_results_csv"):
        cli.add_argument(required_path_flag, type=Path, required=True)
    return cli


def load_config(config_path: Path):
    """Load the YAML configuration file at *config_path*.

    The file is decoded explicitly as UTF-8. The default text encoding of
    ``open()`` depends on the platform locale, so a UTF-8 config containing
    non-ASCII text could otherwise fail to load (or load garbled) on
    machines whose locale is not UTF-8.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The document parsed by ``yaml.safe_load`` (``main`` indexes it as a
        mapping, e.g. ``config["speaker"]``).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# Column names of the evaluation CSV. Written once, when the file is new/empty.
_EVAL_CSV_HEADER = (
    "query_audio",
    "asr_model",
    "llm_model",
    "svs_model",
    "melody_source",
    "language",
    "speaker",
    "output_audio",
    "asr_text",
    "llm_text",
    "metrics",
)


def main():
    """Run the singing-dialogue pipeline on each query audio and record results.

    For every path in ``--query_audios``:
    - the pipeline writes a response to
      ``<output_audio_folder>/<stem>_response.wav``;
    - the response is evaluated;
    - one row is appended to ``--eval_results_csv``.

    The header row is written only when the CSV does not exist yet or is
    empty, so repeated runs can safely append to the same file.

    Raises:
        Exception: any error raised while processing queries is logged with
            its traceback and re-raised unchanged.
    """
    # Without a configured handler the logger.info() progress line below is
    # dropped (the root logger defaults to WARNING). No-op if already configured.
    basicConfig(level=INFO)
    parser = get_parser()
    args = parser.parse_args()
    config = load_config(args.config_path)
    pipeline = SingingDialoguePipeline(config)
    speaker = config["speaker"]
    language = config["language"]
    character = get_character(config["prompt_template_character"])
    prompt_template = character.prompt

    args.output_audio_folder.mkdir(parents=True, exist_ok=True)
    csv_path = args.eval_results_csv
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # Appending the header unconditionally would interleave header rows with
    # data every time the script is rerun against the same CSV.
    write_header = not csv_path.exists() or csv_path.stat().st_size == 0

    # newline="" is required by the csv module. csv.writer quotes fields that
    # contain commas, quotes or newlines (likely in free-form ASR/LLM text),
    # which plain ",".join() would let corrupt the row. UTF-8 avoids
    # locale-dependent encode errors on non-ASCII text.
    with open(csv_path, "a", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file, lineterminator="\n")
        if write_header:
            writer.writerow(_EVAL_CSV_HEADER)
            csv_file.flush()
        try:
            for query_audio in args.query_audios:
                # NOTE(review): inputs sharing a stem in different directories
                # overwrite each other's response audio.
                output_audio = (
                    args.output_audio_folder / f"{query_audio.stem}_response.wav"
                )
                results = pipeline.run(
                    query_audio,
                    language,
                    prompt_template,
                    speaker,
                    output_audio_path=output_audio,
                )
                metrics = pipeline.evaluate(output_audio, **results)
                metrics.update(results.get("metrics", {}))
                # NOTE(review): metric values become separate trailing columns
                # in sorted-key order, while the header names a single
                # "metrics" column. An empty field is kept when there are no
                # metrics, matching the previous output.
                metric_values = [metrics[key] for key in sorted(metrics)] or [""]
                logger.info(
                    "Input: %s, Output: %s, ASR results: %s, LLM results: %s",
                    query_audio,
                    output_audio,
                    results["asr_text"],
                    results["llm_text"],
                )
                fields = [
                    query_audio,
                    config["asr_model"],
                    config["llm_model"],
                    config["svs_model"],
                    config["melody_source"],
                    language,
                    speaker,
                    output_audio,
                    results["asr_text"],
                    results["llm_text"],
                    *metric_values,
                ]
                # Stringify ourselves so values (including None) render exactly
                # as they did with the previous f-string formatting.
                writer.writerow([str(field) for field in fields])
                # Flush per row so completed results survive a later failure.
                csv_file.flush()
        except Exception as e:
            # Log with traceback and propagate. The former breakpoint() here
            # would hang unattended/batch runs in the debugger.
            logger.exception("Error in main: %s", e)
            raise


if __name__ == "__main__":
    main()