File size: 4,895 Bytes
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780954b
91394e0
 
780954b
 
 
91394e0
 
 
 
780954b
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a23964
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5749621
 
 
 
91394e0
 
 
 
 
7a23964
91394e0
 
 
f1b8d35
91394e0
 
 
d5992a5
 
91394e0
 
 
 
 
 
f1b8d35
91394e0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from typing import Callable

import numpy as np

from modules.utils.g2p import (
    kana_to_phonemes_openjtalk,
    pinyin_to_phonemes_ace,
    pinyin_to_phonemes_opencpop,
)

from .base import AbstractSVSModel
from .registry import register_svs_model


@register_svs_model("espnet/")
class ESPNetSVS(AbstractSVSModel):
    """Singing voice synthesis backed by pretrained ESPnet SVS models.

    Supports the two registered espnet checkpoints: the 40-singer ACE-opencpop
    VISinger2 model (integer speaker IDs) and the mixdata VISinger2 model
    (speaker embeddings + language IDs for Mandarin/Japanese).
    """

    def __init__(self, model_id: str, device="auto", cache_dir="cache", **kwargs):
        """Download ``model_id`` via espnet_model_zoo and build the synthesizer.

        Args:
            model_id: Hugging Face / model-zoo identifier (``espnet/...``).
            device: "auto" picks CUDA when available, else CPU; otherwise
                passed through to :class:`SingingGenerate` as-is.
            cache_dir: Directory for downloaded/unpacked model files.
        """
        # Heavy dependencies are imported lazily so importing this module does
        # not require torch/espnet to be installed.
        # NOTE: `torch` was previously referenced below without being imported
        # anywhere in this file, raising NameError for device="auto".
        import torch
        from espnet2.bin.svs_inference import SingingGenerate
        from espnet_model_zoo.downloader import ModelDownloader

        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        downloaded = ModelDownloader(cache_dir).download_and_unpack(model_id)
        self.model = SingingGenerate(
            train_config=downloaded["train_config"],
            model_file=downloaded["model_file"],
            device=self.device,
        )
        self.model_id = model_id
        # Output sample rate is dictated by the checkpoint itself.
        self.output_sample_rate = self.model.fs
        self.phoneme_mappers = self._build_phoneme_mappers()

    def _build_phoneme_mappers(self) -> dict[str, Callable[[str], list[str]]]:
        """Return {language: grapheme->phoneme function} for this checkpoint.

        The mixdata model expects language-tagged phonemes ("@zh"/"@jp");
        the ACE-opencpop model uses plain opencpop phonemes. Unknown model
        IDs get an empty mapping, so every language is rejected later.
        """
        if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
            return {"mandarin": pinyin_to_phonemes_opencpop}

        if self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":

            def mandarin_mapper(pinyin: str) -> list[str]:
                # Tag each phoneme with its language, as the mixdata
                # checkpoint's vocabulary requires.
                return [phn + "@zh" for phn in pinyin_to_phonemes_ace(pinyin)]

            def japanese_mapper(kana: str) -> list[str]:
                return [phn + "@jp" for phn in kana_to_phonemes_openjtalk(kana)]

            return {
                "mandarin": mandarin_mapper,
                "japanese": japanese_mapper,
            }

        return {}

    def _preprocess(self, score: list[tuple[float, float, str, int] | tuple[float, float, str, float]], language: str):
        """Convert a (start, end, text, pitch) score into a SingingGenerate batch.

        Args:
            score: Segments as (start_sec, end_sec, lyric_text, pitch).
            language: Key into ``self.phoneme_mappers``.

        Returns:
            dict with "score" (tempo, note list) and "text" (space-joined
            phoneme string) as expected by espnet's SingingGenerate.

        Raises:
            ValueError: Unsupported language, unmappable lyric text, a score
                containing raw "<AP>"/"<SP>" markers, or a leading
                continuation symbol.
        """
        if language not in self.phoneme_mappers:
            raise ValueError(f"Unsupported language: {language} for {self.model_id}")
        phoneme_mapper = self.phoneme_mappers[language]

        # text to phoneme
        notes = []
        phns = []
        pre_phn = None  # last phoneme emitted; reused for continuation symbols
        for st, ed, text, pitch in score:
            # Guard against unprocessed pause markers leaking in from the
            # caller (was a debug assert; `assert` is stripped under -O).
            if text in ("<AP>", "<SP>"):
                raise ValueError(
                    f"Processed score segments should not contain <AP> or <SP>. {score}"
                )
            if text == "AP" or text == "SP":
                # Pause tokens pass through unchanged.
                lyric_units = [text]
                phn_units = [text]
            elif text == "-" or text == "——":
                # Lyric continuation: hold the previous phoneme.
                lyric_units = [text]
                if pre_phn is None:
                    raise ValueError(
                        f"Text `{text}` cannot be recognized by {self.model_id}. Lyrics cannot start with a lyric continuation symbol `-` or `——`"
                    )
                phn_units = [pre_phn]
            else:
                try:
                    lyric_units = phoneme_mapper(text)
                except ValueError as e:
                    raise ValueError(
                        f"Text `{text}` cannot be recognized by {self.model_id}"
                    ) from e
                phn_units = lyric_units
            notes.append((st, ed, "".join(lyric_units), pitch, "_".join(phn_units)))
            phns.extend(phn_units)
            pre_phn = phn_units[-1]

        batch = {
            "score": (
                120,  # does not affect svs result, as note durations are in time unit
                notes,
            ),
            "text": " ".join(phns),
        }
        return batch

    def synthesize(
        self, score: list[tuple[float, float, str, float] | tuple[float, float, str, int]], language: str, speaker: str, **kwargs
    ):
        """Synthesize a waveform for ``score``.

        Args:
            score: Segments as (start_sec, end_sec, lyric_text, pitch).
            language: Language key supported by the loaded checkpoint.
            speaker: Integer singer ID (as a string) for the ACE-opencpop
                model, or a path to a .npy speaker-embedding file for the
                mixdata model.

        Returns:
            (waveform ndarray, sample_rate).

        Raises:
            ValueError: Unsupported language for the mixdata model.
            NotImplementedError: Unrecognized model_id.
        """
        batch = self._preprocess(score, language)
        if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
            sid = np.array([int(speaker)])
            output_dict = self.model(batch, sids=sid)
        elif self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":
            # Language IDs fixed by the checkpoint's training configuration.
            langs = {
                "mandarin": 2,
                "japanese": 1,
            }
            if language not in langs:
                raise ValueError(
                    f"Unsupported language: {language} for {self.model_id}"
                )
            lid = np.array([langs[language]])
            spk_embed = np.load(speaker)
            output_dict = self.model(batch, lids=lid, spembs=spk_embed)
        else:
            raise NotImplementedError(f"Model {self.model_id} not supported")
        wav_info = output_dict["wav"].cpu().numpy()
        return wav_info, self.output_sample_rate