#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
This script loads checkpoints and averages them.

python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 11 \
    --avg 4 \
    --model_name zipvoice \
    --model-config conf/zipvoice_base.json \
    --token-file data/tokens_emilia.txt \
    --exp-dir exp/zipvoice

It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`
(or `iter-{iter}-avg-{avg}.pt` when --iter is positive).
You can later load it by `torch.load("epoch-11-avg-4.pt")`.
"""

import argparse
import json
from pathlib import Path

import torch

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import SimpleTokenizer
from zipvoice.utils.checkpoint import (
    average_checkpoints_with_averaged_model,
    find_checkpoints,
)
from zipvoice.utils.common import AttributeDict

# Maps each accepted --model_name value to the model class it builds.
# Used both for the argparse choices and for model construction so the
# two can never drift apart.
_MODEL_CLASSES = {
    "zipvoice": ZipVoice,
    "zipvoice_distill": ZipVoiceDistill,
    "zipvoice_dialog": ZipVoiceDialog,
    "zipvoice_dialog_stereo": ZipVoiceDialogStereo,
}

# Model types whose constructor additionally takes speaker-turn token ids.
_DIALOG_MODEL_NAMES = ("zipvoice_dialog", "zipvoice_dialog_stereo")


def get_parser():
    """Build the command-line parser for the checkpoint-averaging script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=11,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=4,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' or '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipvoice/exp_zipvoice",
        help="The experiment dir",
    )

    # "--model-name" is accepted as an alias for consistency with the other
    # hyphenated flags; "--model_name" is kept first so dest stays model_name.
    parser.add_argument(
        "--model_name",
        "--model-name",
        type=str,
        default="zipvoice",
        choices=list(_MODEL_CLASSES),
        help="The model type to be averaged. ",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default="conf/zipvoice_base.json",
        help="The model configuration file.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default="data/tokens_emilia.txt",
        help="The file that contains information that maps tokens to ids, "
        "which is a text file with '{token}\t{token_id}' per line if type is "
        "char or phone, otherwise it is a bpe_model file.",
    )

    return parser


def _build_tokenizer_config(model_name, tokenizer):
    """Return the tokenizer-derived keyword arguments for the model constructor.

    Args:
        model_name: One of the keys of ``_MODEL_CLASSES``.
        tokenizer: Tokenizer exposing ``vocab_size`` and ``pad_id``; for dialog
            models it must also expose ``spk_a_id`` and ``spk_b_id``.

    Returns:
        A dict to be splatted into the model constructor.
    """
    tokenizer_config = {
        "vocab_size": tokenizer.vocab_size,
        "pad_id": tokenizer.pad_id,
    }
    if model_name in _DIALOG_MODEL_NAMES:
        tokenizer_config["spk_a_id"] = tokenizer.spk_a_id
        # Each speaker needs its own id; reusing spk_a_id here would make
        # the two speaker turns indistinguishable to the model.
        tokenizer_config["spk_b_id"] = tokenizer.spk_b_id
    return tokenizer_config


def _select_checkpoint_range(params):
    """Pick the checkpoint interval to average over.

    The averaged model covers checkpoints after ``filename_start`` (excluded)
    up to and including ``filename_end``.

    Args:
        params: Parsed arguments; reads ``iter``, ``epoch``, ``avg`` and
            ``exp_dir``.

    Returns:
        A ``(filename_start, filename_end)`` tuple.

    Raises:
        ValueError: If not enough checkpoints exist for the requested range.
    """
    if params.iter > 0:
        # A negative iteration asks find_checkpoints for checkpoints up to
        # --iter; take the newest avg + 1 of them (the oldest is the excluded
        # start point).
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        print(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        return filename_start, filename_end

    start = params.epoch - params.avg
    if start < 1:
        raise ValueError(
            f"--epoch {params.epoch} with --avg {params.avg} needs "
            f"epoch-{start}.pt as the start point; epochs count from 1"
        )
    filename_start = f"{params.exp_dir}/epoch-{start}.pt"
    filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
    print(
        f"Calculating the averaged model over epoch range from "
        f"{start} (excluded) to {params.epoch}"
    )
    return filename_start, filename_end


@torch.no_grad()
def main():
    """Average consecutive checkpoints and save the result into exp_dir.

    Raises:
        ValueError: If --avg is not positive, the model name is unknown, or
            the requested checkpoints are unavailable.
    """
    parser = get_parser()
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = AttributeDict()
    params.update(vars(args))

    # Validate up front for both the --iter and --epoch paths; averaging zero
    # checkpoints would make start and end the same file.
    if params.avg <= 0:
        raise ValueError(f"--avg must be positive, got {params.avg}")

    with open(params.model_config, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    tokenizer = SimpleTokenizer(token_file=params.token_file)
    tokenizer_config = _build_tokenizer_config(params.model_name, tokenizer)

    print("Script started")

    params.device = torch.device("cpu")
    print(f"Device: {params.device}")

    print("About to create model")
    model_cls = _MODEL_CLASSES.get(params.model_name)
    if model_cls is None:
        raise ValueError(f"Unknown model name: {params.model_name}")
    model = model_cls(
        **model_config["model"],
        **tokenizer_config,
    )

    filename_start, filename_end = _select_checkpoint_range(params)

    model.to(params.device)
    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=params.device,
        ),
        strict=True,
    )

    if params.iter > 0:
        filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
    else:
        filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
    torch.save({"model": model.state_dict()}, filename)

    num_param = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {num_param}")

    print("Done!")


if __name__ == "__main__":
    main()