# Hugging Face Hub upload metadata (not part of the program):
#   kevinwang676's picture
#   Upload folder using huggingface_hub
#   1b6bcbc verified
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import os
import pdb
# Forward the private "_CUDA_VISIBLE_DEVICES" variable, when the launcher set
# it, to the real "CUDA_VISIBLE_DEVICES". This sits ahead of the torch import,
# presumably so the value is in effect before CUDA is initialised.
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
except KeyError:
    # Nothing to forward: leave CUDA_VISIBLE_DEVICES exactly as it was.
    pass
import argparse
import logging
from pathlib import Path
import torch,platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
# Silence chatty third-party loggers: only WARNING and above get through.
for _noisy_logger in ("numba", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# Select the "high" float32 matmul precision mode (see the torch documentation
# for the exact speed/accuracy trade-off of each mode).
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
class my_model_ckpt(ModelCheckpoint):
    """``ModelCheckpoint`` that can prune old checkpoints and export fp16 weights.

    Each time the periodic epoch-end checkpoint is written, this callback can
    additionally:

    * delete the files that were in ``dirpath`` before the new checkpoint was
      written, so that only the newest one is kept (``if_save_latest``);
    * dump a half-precision copy of the model's ``state_dict`` together with
      the training config to ``half_weights_save_dir``
      (``if_save_every_weights``).

    All extra file-system work is done by the global-zero process only, so
    multi-process (DDP) runs do not race on the same files.

    Args:
        config: Training configuration; stored verbatim under the ``"config"``
            key of every exported weights file.
        if_save_latest: When ``True``, remove the previous contents of
            ``dirpath`` after each new checkpoint is saved.
        if_save_every_weights: When ``True``, export half-precision weights
            each time a checkpoint is saved.
        half_weights_save_dir: Directory that receives the exported weights.
        exp_name: Prefix of the exported file names
            (``<exp_name>-e<epoch>.ckpt``).
        **kwargs: Forwarded unchanged to ``ModelCheckpoint``.
    """

    def __init__(
        self,
        config,
        if_save_latest,
        if_save_every_weights,
        half_weights_save_dir,
        exp_name,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.if_save_latest = if_save_latest
        self.if_save_every_weights = if_save_every_weights
        self.half_weights_save_dir = half_weights_save_dir
        self.exp_name = exp_name
        self.config = config

    def on_train_epoch_end(self, trainer, pl_module):
        """Save the periodic checkpoint, then prune / export if configured.

        Args:
            trainer: The running Lightning ``Trainer``.
            pl_module: The LightningModule being trained; its ``state_dict``
                is what gets exported as half-precision weights.
        """
        if self._should_skip_saving_checkpoint(trainer):
            return
        if not self._should_save_on_train_epoch_end(trainer):
            return

        monitor_candidates = self._monitor_candidates(trainer)
        epoch = trainer.current_epoch + 1  # 1-based epoch used in file names

        if self._every_n_epochs >= 1 and epoch % self._every_n_epochs == 0:
            # Equality (not truthiness) is kept on purpose so that non-bool
            # config values, e.g. the string "False", behave exactly as before.
            prune_old = self.if_save_latest == True  # noqa: E712
            export_weights = self.if_save_every_weights == True  # noqa: E712

            # Only one process may touch the files: a slower rank could
            # otherwise list the directory *after* the new checkpoint has been
            # written and then delete it.
            is_writer = trainer.is_global_zero

            # Snapshot the directory BEFORE saving so the checkpoint written
            # below is never part of the clean-up list.
            stale_names = (
                os.listdir(self.dirpath) if (prune_old and is_writer) else []
            )

            # Must run on every rank: Lightning synchronises inside.
            self._save_topk_checkpoint(trainer, monitor_candidates)

            if stale_names:
                self._prune_stale_files(stale_names)
            if export_weights and is_writer:
                self._export_half_weights(pl_module, epoch)

        self._save_last_checkpoint(trainer, monitor_candidates)

    def _prune_stale_files(self, names):
        """Best-effort removal of ``names`` from ``dirpath``; failures are logged."""
        for name in names:
            path = "%s/%s" % (self.dirpath, name)
            try:
                os.remove(path)
            except OSError as err:
                # Covers missing files, sub-directories and permission errors;
                # a failed clean-up must never abort training.
                logging.getLogger(__name__).warning(
                    "could not remove old checkpoint %s: %s", path, err
                )

    def _export_half_weights(self, pl_module, epoch):
        """Write the fp16 ``state_dict`` plus config for ``epoch`` to disk."""
        os.makedirs(self.half_weights_save_dir, exist_ok=True)

        payload = OrderedDict()
        payload["weight"] = OrderedDict(
            (key, tensor.half())
            for key, tensor in pl_module.state_dict().items()
        )
        payload["config"] = self.config
        payload["info"] = "GPT-e%s" % epoch
        torch.save(
            payload,
            "%s/%s-e%s.ckpt" % (self.half_weights_save_dir, self.exp_name, epoch),
        )
def main(args):
    """Train the text-to-semantic model described by ``args.config_file``.

    Loads the YAML config, creates ``<output_dir>`` and ``<output_dir>/ckpt``,
    seeds all RNGs, builds the checkpoint callback, TensorBoard logger,
    trainer, model and data module, and finally calls ``trainer.fit``. If a
    checkpoint can be located in the ``ckpt`` directory, training resumes from
    it; otherwise it starts from scratch.

    Args:
        args: Parsed command-line namespace; only ``config_file`` is read.
    """
    config = load_yaml_config(args.config_file)

    output_dir = Path(config["output_dir"])
    ckpt_dir = output_dir / "ckpt"
    for directory in (output_dir, ckpt_dir):
        directory.mkdir(parents=True, exist_ok=True)

    train_cfg = config["train"]
    seed_everything(train_cfg["seed"], workers=True)

    checkpoint_cb = my_model_ckpt(
        config=config,
        if_save_latest=train_cfg["if_save_latest"],
        if_save_every_weights=train_cfg["if_save_every_weights"],
        half_weights_save_dir=train_cfg["half_weights_save_dir"],
        exp_name=train_cfg["exp_name"],
        save_top_k=-1,
        monitor="top_3_acc",
        mode="max",
        save_on_train_epoch_end=True,
        every_n_epochs=train_cfg["save_every_n_epoch"],
        dirpath=ckpt_dir,
    )

    tb_logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)

    # NCCL everywhere except on Windows, where gloo is used instead.
    ddp_backend = "gloo" if platform.system() == "Windows" else "nccl"

    trainer = Trainer(
        max_epochs=train_cfg["epochs"],
        accelerator="gpu",
        # Validation is switched off: no validation batches, no sanity check.
        limit_val_batches=0,
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(process_group_backend=ddp_backend),
        precision=train_cfg["precision"],
        logger=tb_logger,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_cb],
    )

    model = Text2SemanticLightningModule(config, output_dir)

    data_module = Text2SemanticDataModule(
        config,
        train_semantic_path=config["train_semantic_path"],
        train_phoneme_path=config["train_phoneme_path"],
    )

    # Resume from the newest checkpoint when one can be found. Any failure
    # while locating it (e.g. an empty directory) means "start from scratch".
    # NOTE(review): get_newest_ckpt presumably orders files by the numbers in
    # their names -- confirm against AR.utils.
    try:
        ckpt_path = ckpt_dir / get_newest_ckpt(os.listdir(ckpt_dir))
    except Exception:
        ckpt_path = None
    print("ckpt_path:", ckpt_path)

    trainer.fit(model, data_module, ckpt_path=ckpt_path)
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == "__main__":
    # Script entry point: the only command-line option is the YAML config
    # path. Dataset paths and the output directory are read from that config
    # inside main(), not from the command line.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        default="configs/s1longer.yaml",
        help="path of config file",
    )

    args = arg_parser.parse_args()
    logging.info("%s", args)
    main(args)