# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
# init_weights, get_padding, AttrDict
import ctypes
import glob
import os
import re
import shutil
import warnings
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
import torch.distributed as dist
from scipy.io.wavfile import read
def mask_from_lens(lens, max_len: Optional[int] = None):
    """Build a boolean padding mask from a batch of sequence lengths.

    Row ``i`` of the result has ``True`` in its first ``lens[i]`` positions
    and ``False`` elsewhere. When ``max_len`` is omitted, the longest
    length in ``lens`` is used.
    """
    if max_len is None:
        max_len = lens.max()
    positions = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    # Broadcast: (max_len,) < (batch, 1) -> (batch, max_len)
    return positions < lens.unsqueeze(1)
def load_wav(full_path, torch_tensor=False):
    """Read an audio file as int16 samples.

    Returns ``(data, sampling_rate)``. ``data`` is a float32 torch tensor
    when ``torch_tensor`` is True, otherwise the raw int16 numpy array.
    """
    import soundfile  # flac
    data, sampling_rate = soundfile.read(full_path, dtype='int16')
    if not torch_tensor:
        return data, sampling_rate
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_wav_to_torch(full_path, force_sampling_rate=None):
    """Load an audio file as a float32 torch tensor.

    With ``force_sampling_rate`` set, the audio is loaded (and resampled)
    through librosa; otherwise scipy reads the file at its native rate.
    Returns ``(tensor, sampling_rate)``.
    """
    if force_sampling_rate is None:
        sampling_rate, data = read(full_path)
    else:
        data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
    """Parse filelist files into tuples of dataset-relative paths plus fields.

    Each line of every file in ``fnames`` is split on ``split``; the
    path-like leading fields are joined onto ``dataset_path`` and the
    remaining fields are appended verbatim.
    """

    def parse(line):
        fields = line.strip().split(split)
        if has_speakers:
            # NOTE(review): treats the FIRST two fields as paths and the rest
            # as text/speaker columns; a previous comment questioned whether
            # the LAST two fields should be the non-paths instead -- confirm
            # against the filelist layout.
            path_fields, rest = fields[:2], fields[2:]
        else:
            path_fields, rest = fields[:-1], fields[-1:]
        return tuple(str(Path(dataset_path, p)) for p in path_fields) + tuple(rest)

    entries = []
    for fname in fnames:
        with open(fname, encoding='utf-8') as f:
            entries.extend(parse(line) for line in f)
    return entries
def to_gpu(x):
    """Return a contiguous copy of ``x``, moved to the GPU when CUDA is available."""
    x = x.contiguous()
    if torch.cuda.is_available():
        return x.cuda(non_blocking=True)
    return x
def l2_promote():
    """Raise the CUDA L2 cache fetch granularity limit to 128 bytes.

    Talks to the CUDA runtime directly through ctypes; requires
    ``libcudart.so`` to be loadable (raises OSError otherwise).
    """
    _libcudart = ctypes.CDLL('libcudart.so')
    # Set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    # Allocate a one-int buffer to receive the value read back below.
    pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    # Read the limit back to confirm the setting took effect.
    assert pValue.contents.value == 128
def prepare_tmp(path):
    """(Re)create an empty directory at ``path``; no-op when ``path`` is None.

    An existing directory at that location is deleted first, with a warning.
    """
    if path is None:
        return
    tmp_dir = Path(path)
    if tmp_dir.is_dir():
        warnings.warn(f'{tmp_dir} exists. Removing...')
        shutil.rmtree(tmp_dir, ignore_errors=True)
    tmp_dir.mkdir(parents=False, exist_ok=False)
def print_once(*msg):
    """Print ``msg`` only on rank 0 (or when torch.distributed is uninitialized)."""
    is_lead_rank = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_lead_rank:
        print(*msg)
def init_weights(m, mean=0.0, std=0.01):
    """Initialize convolution-layer weights in place from N(mean, std).

    Modules whose class name does not contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Padding that keeps the output length equal to the input length
    for a stride-1 dilated convolution."""
    return dilation * (kernel_size - 1) // 2
class AttrDict(dict):
    """Dict whose items are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute namespace and dict contents are the same object, so
        # item and attribute access stay in sync both ways.
        self.__dict__ = self
class DefaultAttrDict(defaultdict):
    """defaultdict with attribute-style access.

    Missing attributes fall back to item lookup, which triggers the
    default factory like any defaultdict access.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share the dict contents as the attribute namespace.
        self.__dict__ = self

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.
        return self[item]
class BenchmarkStats:
    """Accumulates per-epoch throughput and loss statistics for benchmarking."""

    def __init__(self):
        self.num_frames = []
        self.losses = []
        self.mel_losses = []
        self.took = []

    def update(self, num_frames, losses, mel_losses, took):
        """Record one epoch's measurements."""
        for store, value in ((self.num_frames, num_frames),
                             (self.losses, losses),
                             (self.mel_losses, mel_losses),
                             (self.took, took)):
            store.append(value)

    def get(self, n_epochs):
        """Summarize the most recent ``n_epochs`` recorded epochs."""
        tail = slice(-n_epochs, None)
        return {
            'frames/s': sum(self.num_frames[tail]) / sum(self.took[tail]),
            'loss': np.mean(self.losses[tail]),
            'mel_loss': np.mean(self.mel_losses[tail]),
            'took': np.mean(self.took[tail]),
            'benchmark_epochs_num': n_epochs,
        }

    def __len__(self):
        return len(self.losses)
class Checkpointer:
    """Saves and restores FastPitch training checkpoints, pruning old files.

    Keeps the two most recent checkpoints plus every epoch listed in
    ``keep_milestones``.
    """

    # Matches e.g. "FastPitch_checkpoint_123.pt" and captures the epoch.
    # Fixed: raw string (``"\d"`` is a SyntaxWarning on modern Python) and
    # escaped dot.
    _EPOCH_RE = re.compile(r"_(\d+)\.pt")

    def __init__(self, save_dir, keep_milestones=None):
        # Fixed: mutable default argument ([]) replaced with a None sentinel.
        self.save_dir = save_dir
        self.keep_milestones = list(keep_milestones) if keep_milestones else []

        def find(name):
            return [(int(self._EPOCH_RE.search(fn).group(1)), fn)
                    for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]

        # Track existing checkpoints, ordered by epoch.
        tracked = sorted(find("FastPitch"), key=lambda t: t[0])
        self.tracked = OrderedDict(tracked)

    def last_checkpoint(self, output):
        """Return the path of the newest loadable checkpoint in ``output``,
        or None when no usable checkpoint exists."""

        def corrupted(fpath):
            try:
                torch.load(fpath, map_location="cpu")
                return False
            # Fixed: bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
            except Exception:
                warnings.warn(f"Cannot load {fpath}")
                return True

        saved = sorted(
            glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
            key=lambda f: int(self._EPOCH_RE.search(f).group(1)))

        if len(saved) >= 1 and not corrupted(saved[-1]):
            return saved[-1]
        elif len(saved) >= 2:
            # NOTE(review): the penultimate checkpoint is returned without a
            # corruption check, matching the original behavior.
            return saved[-2]
        else:
            return None

    def maybe_load(self, model, optimizer, scaler, train_state, args,
                   ema_model=None):
        """Load model/optimizer/scaler state from a checkpoint, if requested.

        The checkpoint comes from ``args.checkpoint_path`` or, with
        ``args.resume``, from the newest file in ``args.output``. Mutates
        ``train_state`` ("epoch", "total_iter") in place. Does nothing when
        neither source is given.
        """
        assert args.checkpoint_path is None or args.resume is False, (
            "Specify a single checkpoint source")

        fpath = None
        if args.checkpoint_path is not None:
            fpath = args.checkpoint_path
            self.tracked = OrderedDict()  # Do not track/delete prev ckpts
        elif args.resume:
            fpath = self.last_checkpoint(args.output)

        if fpath is None:
            return

        print_once(f"Loading model and optimizer state from {fpath}")
        ckpt = torch.load(fpath, map_location="cpu")
        train_state["epoch"] = ckpt["epoch"] + 1
        train_state["total_iter"] = ckpt["iteration"]

        # Strip a DistributedDataParallel "module." prefix, if present.
        # Fixed: raw string and escaped dot in the pattern.
        no_pref = lambda sd: {re.sub(r"^module\.", "", k): v
                              for k, v in sd.items()}
        unwrap = lambda m: getattr(m, "module", m)

        unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))
        if ema_model is not None:
            unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))
        optimizer.load_state_dict(ckpt["optimizer"])

        if "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
        else:
            warnings.warn("AMP scaler state missing from the checkpoint.")

    def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
                   total_iter, config):
        """Save a checkpoint on rank 0 when the epoch warrants it, then prune.

        Saves on every ``args.epochs_per_checkpoint``-th epoch, on the final
        epoch, and on milestone epochs; afterwards deletes tracked
        checkpoints other than milestones and the last two.
        """
        intermediate = (args.epochs_per_checkpoint > 0
                        and epoch % args.epochs_per_checkpoint == 0)
        final = epoch == args.epochs

        if not intermediate and not final and epoch not in self.keep_milestones:
            return

        rank = 0
        if dist.is_initialized():
            dist.barrier()
            rank = dist.get_rank()
        if rank != 0:
            return

        unwrap = lambda m: getattr(m, "module", m)
        ckpt = {"epoch": epoch,
                "iteration": total_iter,
                "config": config,
                "train_setup": args.__dict__,
                "state_dict": unwrap(model).state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict()}
        if ema_model is not None:
            ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()

        fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
        print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
        torch.save(ckpt, fpath)

        # Remove old checkpoints; keep milestones and the last two
        self.tracked[epoch] = fpath
        # Fixed: use a distinct loop variable instead of clobbering ``epoch``.
        for old_epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
            try:
                os.remove(self.tracked[old_epoch])
            # Fixed: catch only OSError instead of a bare ``except:``.
            except OSError:
                pass
            del self.tracked[old_epoch]