Spaces:
Paused
Paused
| import logging | |
| import os | |
| import sys | |
| from io import BytesIO | |
| from json import loads | |
| import av | |
| import pilk | |
| from flask import Flask | |
| from torch import load, FloatTensor | |
| from numpy import float32 | |
| import librosa | |
| from voice import Voice | |
# Module-level Flask application object; its configuration values are
# loaded from the "config.py" file.
app = Flask(__name__)
app.config.from_pyfile("config.py")
class HParams():
    """Attribute-style container for (possibly nested) hyper-parameters.

    Every keyword argument becomes an attribute of the instance. Values
    that are dictionaries are converted recursively into ``HParams``
    instances, so nested settings can be reached as ``hps.data.rate``.
    The object also supports a small mapping-like interface
    (``hps["key"]``, ``"key" in hps``, ``len(hps)``, ``keys()``,
    ``items()``, ``values()``).
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (not an exact type comparison) so that dict
            # subclasses such as OrderedDict are wrapped as well.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        """Return a view of the stored parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the stored parameter values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        # Item access is routed through attribute access.
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)
def load_checkpoint(checkpoint_path, model):
    """Load weights from a checkpoint file into ``model`` in place.

    The checkpoint must be a dictionary with an ``'iteration'`` entry and
    a ``'model'`` entry holding the saved state dict. Parameters that the
    model has but the checkpoint lacks keep the model's current value and
    are reported through ``logging.info``.

    Args:
        checkpoint_path: Path of the checkpoint passed to ``torch.load``.
        model: Module to update. If it has a ``module`` attribute (a
            wrapper around the real network), the wrapped module is
            updated instead.

    Returns:
        None.
    """
    # NOTE(review): torch.load deserializes via pickle unless restricted;
    # only load checkpoints from trusted sources.
    checkpoint_dict = load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    saved_state_dict = checkpoint_dict['model']

    # Unwrap once instead of repeating the hasattr check for read and write.
    target = model.module if hasattr(model, 'module') else model

    new_state_dict = {}
    for k, v in target.state_dict().items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key is expected here; anything else is a real
            # error and must not be hidden.
            logging.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v

    target.load_state_dict(new_state_dict)
    logging.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return
def get_hparams_from_file(config_path):
    """Read a JSON configuration file and return it as an ``HParams`` object.

    Args:
        config_path: Path of the JSON file. Its top-level value must be an
            object, because its members are passed to ``HParams`` as
            keyword arguments.

    Returns:
        An ``HParams`` instance built from the parsed JSON.
    """
    # Decode explicitly as UTF-8: without it open() falls back to the
    # locale's preferred encoding, which can mis-decode non-ASCII text
    # in the config on some platforms.
    with open(config_path, "r", encoding="utf-8") as f:
        data = f.read()
    config = loads(data)
    return HParams(**config)
def load_audio_to_torch(full_path, target_sampling_rate):
    """Load an audio file as a mono waveform and return it as a FloatTensor.

    Args:
        full_path: Location of the audio file, handed to ``librosa.load``.
        target_sampling_rate: Value passed as ``sr`` to ``librosa.load``.

    Returns:
        A ``FloatTensor`` built from the float32 samples.
    """
    # librosa also returns the sampling rate; it is not needed here.
    waveform, _ = librosa.load(full_path, sr=target_sampling_rate, mono=True)
    samples = waveform.astype(float32)
    return FloatTensor(samples)
def wav2ogg(input, output):
    """Re-encode the first audio stream of ``input`` as Ogg/Vorbis.

    Args:
        input: Source handed to ``av.open`` for reading.
        output: Destination handed to ``av.open`` for writing in the
            ``ogg`` container format.
    """
    with av.open(input, 'rb') as source, av.open(output, 'wb', format='ogg') as sink:
        vorbis = sink.add_stream('libvorbis')

        # Encode every decoded frame and write the resulting packets.
        for frame in source.decode(audio=0):
            for packet in vorbis.encode(frame):
                sink.mux(packet)

        # Encoding None drains whatever the encoder still has buffered.
        for packet in vorbis.encode(None):
            sink.mux(packet)
| # def wav2silk(input, output): | |
| # with av.open(input) as in_wav: | |
| # in_stream = in_wav.streams.audio[0] | |
| # sample_rate = in_stream.codec_context.sample_rate | |
| # with BytesIO() as pcm: | |
| # with av.open(pcm, 'w', 's16le') as out_pcm: | |
| # out_stream = out_pcm.add_stream( | |
| # 'pcm_s16le', | |
| # rate=sample_rate, | |
| # layout='mono' | |
| # ) | |
| # for frame in in_wav.decode(in_stream): | |
| # frame.pts = None | |
| # for packet in out_stream.encode(frame): | |
| # out_pcm.mux(packet) | |
| # | |
| # pilk.encode(out_pcm, output, pcm_rate=sample_rate, tencent=True) | |
def to_pcm(in_path: str) -> tuple[str, int]:
    """Convert the first audio stream of a media file to raw mono PCM.

    The output is written next to the input, with the extension replaced
    by ``.pcm``, using the ``s16le`` container and the ``pcm_s16le`` codec
    at the source stream's sample rate.

    The conversion is best-effort: if decoding or encoding fails part way
    through, the failure is logged and whatever was written so far is kept.

    Args:
        in_path: Path of the media file to convert.

    Returns:
        A tuple ``(out_path, sample_rate)``: the path of the PCM file and
        the sample rate reported by the input stream's codec context.
    """
    out_path = os.path.splitext(in_path)[0] + '.pcm'
    with av.open(in_path) as in_container:
        in_stream = in_container.streams.audio[0]
        sample_rate = in_stream.codec_context.sample_rate
        with av.open(out_path, 'w', 's16le') as out_container:
            out_stream = out_container.add_stream(
                'pcm_s16le',
                rate=sample_rate,
                layout='mono'
            )
            # NOTE(review): unlike wav2ogg, the encoder is never drained
            # with encode(None) here - confirm no tail samples are lost.
            try:
                for frame in in_container.decode(in_stream):
                    # Clear the timestamp so the muxer assigns its own.
                    frame.pts = None
                    for packet in out_stream.encode(frame):
                        out_container.mux(packet)
            except Exception:
                # Stay best-effort, but do not swallow KeyboardInterrupt /
                # SystemExit, and leave a trace of what went wrong.
                logging.warning(
                    "Conversion of %s to PCM stopped early; output may be truncated",
                    in_path,
                    exc_info=True,
                )
    return out_path, sample_rate
def convert_to_silk(media_path: str) -> str:
    """Convert a media file to a SILK file and return the SILK file's path.

    The media is first converted to an intermediate ``.pcm`` file via
    ``to_pcm``, then encoded with ``pilk.encode`` (``tencent=True``). The
    intermediate PCM file is always deleted, including when encoding fails.

    Args:
        media_path: Path of the media file to convert.

    Returns:
        Path of the generated ``.silk`` file (same base name as the input).
    """
    pcm_path, sample_rate = to_pcm(media_path)
    silk_path = os.path.splitext(pcm_path)[0] + '.silk'
    try:
        pilk.encode(pcm_path, silk_path, pcm_rate=sample_rate, tencent=True)
    finally:
        # Remove the temporary PCM file even if the encoder raised, so
        # failed conversions do not leave stray files behind.
        os.remove(pcm_path)
    return silk_path
def clean_folder(folder_path):
    """Delete every file located directly inside ``folder_path``.

    Entries that are not files (for example subdirectories) are skipped,
    so nested content is left untouched.
    """
    candidates = (
        os.path.join(folder_path, entry) for entry in os.listdir(folder_path)
    )
    for candidate in candidates:
        # Only files are removed; anything else is left in place.
        if not os.path.isfile(candidate):
            continue
        os.remove(candidate)
def merge_model(merging_model):
    """Create a ``Voice`` per model entry and number all speakers globally.

    Args:
        merging_model: Iterable of indexable entries; elements 0 and 1 of
            each entry are passed to ``Voice`` as its first two arguments.

    Returns:
        A tuple ``(voice_obj, voice_speakers)``:

        * ``voice_obj`` - one ``[local_speaker_index, voice, model_index]``
          list per speaker, in the order the speakers were encountered.
        * ``voice_speakers`` - one ``{global_id: speaker_name}`` dict per
          speaker, where ``global_id`` counts up from 0 across all models.
    """
    # Output directory handed to every Voice: "out_silk" next to the script.
    silk_dir = os.path.dirname(os.path.realpath(sys.argv[0])) + "/out_silk"

    voice_obj = []
    voice_speakers = []
    for model_index, entry in enumerate(merging_model):
        voice = Voice(entry[0], entry[1], silk_dir)
        for local_index, speaker in enumerate(voice.return_speakers()):
            voice_obj.append([int(local_index), voice, model_index])
            # The number of speakers collected so far is the next global id.
            voice_speakers.append({len(voice_speakers): speaker})

    return voice_obj, voice_speakers