import random
from typing import Iterator

from data_handlers import get_melody_handler

from .utils.g2p import preprocess_text


class MelodyController:
    """Produce melody constraints and note-level scores for lyrics.

    The behaviour is selected by ``melody_source_id``, a dash-separated
    string ``"<mode>-<align_type>-...-<dataset_name>"``:

    * ``mode`` is ``"gen"`` (random melody) or ``"sample"`` (melody taken
      from a reference song in the dataset).
    * ``align_type`` is forwarded to the dataset handler and controls how
      lyrics are aligned to reference notes.
    * ``dataset_name`` is the last component; ``"none"`` means no dataset
      is loaded.
    """

    def __init__(self, melody_source_id: str, cache_dir: str):
        """Parse ``melody_source_id`` and load the song database if needed.

        Args:
            melody_source_id: Dash-separated identifier, see class docstring.
            cache_dir: Passed to the dataset handler constructor.

        Raises:
            ValueError: If ``melody_source_id`` has fewer than two
                dash-separated components.
        """
        self.melody_source_id = melody_source_id
        self.song_id = None
        # Set by get_melody_constraints() in sample mode. Initialised here so
        # generate_score() can detect a missing call instead of failing with
        # AttributeError.
        self.reference_song = None

        parts = melody_source_id.split("-")
        if len(parts) < 2:
            raise ValueError(
                f"Malformed melody_source_id {melody_source_id!r}: expected "
                "'<mode>-<align_type>[-...-<dataset_name>]'."
            )
        self.mode = parts[0]
        self.align_type = parts[1]
        dataset_name = parts[-1]

        # load song database if needed
        if dataset_name == "none":
            self.database = None
        else:
            handler_cls = get_melody_handler(dataset_name)
            self.database = handler_cls(self.align_type, cache_dir)

    def get_melody_constraints(self, max_num_phrases: int = 5) -> str:
        """Return a lyric-format prompt based on melody structure.

        In ``"gen"`` mode no constraint is needed and an empty string is
        returned. In ``"sample"`` mode a random song is picked from the
        database, remembered as the reference for ``generate_score()``, and
        a prompt listing the character count of each of the first
        ``max_num_phrases`` phrases is returned (empty string if the song
        reports no phrase lengths).

        Raises:
            RuntimeError: In sample mode when no database was loaded.
            ValueError: If the mode is neither ``"gen"`` nor ``"sample"``.
        """
        if self.mode == "gen":
            return ""
        if self.mode != "sample":
            raise ValueError(f"Unsupported melody mode: {self.mode}")

        # Explicit raise rather than assert: asserts are stripped under -O.
        if self.database is None:
            raise RuntimeError("Song database is not loaded.")

        self.song_id = random.choice(self.database.get_song_ids())
        self.reference_song = self.database.iter_song_phrases(self.song_id)
        phrase_length = self.database.get_phrase_length(self.song_id)
        if not phrase_length:
            return ""

        rules = "".join(
            f"\n第{i}句:{c}个字"
            for i, c in enumerate(phrase_length[:max_num_phrases], 1)
        )
        return (
            "\n请按照歌词格式回复,每句需遵循以下字数规则:"
            + rules
            + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
        )

    def generate_score(
        self, lyrics: str, language: str
    ) -> list[tuple[float, float, str, int]]:
        """Turn lyrics into a note-level score.

        Args:
            lyrics: Raw lyric text; it is tokenised with ``preprocess_text``.
            language: Language passed through to ``preprocess_text``.

        Returns:
            A list of ``(start, end, lyric, pitch)`` tuples.

        Raises:
            RuntimeError: In sample mode when ``get_melody_constraints()``
                has not been called first to select a reference song.
            ValueError: If the mode / align type combination is unsupported.
        """
        text_list = preprocess_text(lyrics, language)
        if self.mode == "gen" and self.align_type == "random":
            return self._generate_random_score(text_list)
        if self.mode == "sample":
            # An iterator object is always truthy, so test against None.
            if self.reference_song is None:
                raise RuntimeError(
                    "Must call get_melody_constraints() before generate_score() in sample mode."
                )
            return self._align_text_to_score(
                text_list, self.reference_song, self.align_type
            )
        raise ValueError(f"Unsupported melody_source_id: {self.melody_source_id}")

    def _generate_random_score(
        self, text_list: list[str]
    ) -> list[tuple[float, float, str, int]]:
        """Assign each lyric token a random pitch and duration, back to back.

        Pitches are drawn from ``random.randint(57, 69)`` and durations from
        ``random.uniform(0.1, 0.5)`` rounded to 4 decimals; each note starts
        where the previous one ended.
        """
        st = 0
        score = []
        for lyric in text_list:
            pitch = random.randint(57, 69)
            duration = round(random.uniform(0.1, 0.5), 4)
            ed = st + duration
            score.append((st, ed, lyric, pitch))
            st = ed
        return score

    def _align_text_to_score(
        self,
        text_list: list[str],
        song_phrase_iterator: Iterator[dict],
        align_type: str,
    ) -> list[tuple[float, float, str, int]]:
        """Map lyric tokens onto the notes of successive reference phrases.

        Each phrase dict must provide ``note_start_times``,
        ``note_end_times``, ``note_lyrics`` and ``note_midi``. Notes with
        pitch 0 keep their reference lyric; with ``align_type == "lyric"``
        reference notes marked ``"-"`` or ``"——"`` are emitted as ``"-"``;
        every other note consumes the next token from ``text_list``.

        Alignment stops as soon as all tokens are placed. If the reference
        song runs out of phrases first, the score built so far is returned
        (remaining tokens are dropped) instead of leaking ``StopIteration``
        to the caller.
        """
        score = []
        text_idx = 0
        while text_idx < len(text_list):
            reference = next(song_phrase_iterator, None)
            if reference is None:
                # Reference melody exhausted before all lyrics were placed.
                break
            for st, ed, ref_lyric, pitch in zip(
                reference["note_start_times"],
                reference["note_end_times"],
                reference["note_lyrics"],
                reference["note_midi"],
            ):
                # NOTE(review): both list entries are empty strings and the
                # message names no tokens; presumably the original token
                # names were lost somewhere - confirm against upstream.
                assert ref_lyric not in [
                    "",
                    "",
                ], f"Proccessed {self.melody_source_id} score segments should not contain or ."  # TODO: remove in PR, only for debug
                if pitch == 0:
                    score.append((st, ed, ref_lyric, pitch))
                elif ref_lyric in ["-", "——"] and align_type == "lyric":
                    score.append((st, ed, "-", pitch))
                else:
                    score.append((st, ed, text_list[text_idx], pitch))
                    text_idx += 1
                if text_idx >= len(text_list):
                    break
        return score