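"""Resume training from a saved checkpoint.

Loads the model and datasets via the project's parser/checkpoint
utilities, restores the saved weights, and trains for additional
epochs, snapshotting checkpoints along the way.
"""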
import os

import torch
from torch.utils.data import DataLoader
from lion_pytorch import Lion  # noqa: F401  (optional optimizer, see main())

import src.utils.fixseed  # noqa  (imported for its side effect: fixing random seeds)
from src.utils.trainer import train
from src.utils.tensors import collate
from src.utils.get_model_and_data import get_model_and_data
from src.parser.checkpoint import parser


def add_epochs(model, datasets, parameters, optimizer, origepoch):
    """Train `model` for `parameters["num_epochs"]` additional epochs."""
    dataset = datasets["train"]
    train_iterator = DataLoader(dataset, batch_size=parameters["batch_size"],
                                shuffle=True, num_workers=8, collate_fn=collate)
    for epoch in range(1, parameters["num_epochs"] + 1):
        dict_loss = train(model, optimizer, train_iterator, model.device)
        # Average the accumulated losses over the number of batches.
        for key in dict_loss.keys():
            dict_loss[key] /= len(train_iterator)
        print(f"Epoch {epoch}, train losses: {dict_loss}")
        # Snapshot every `snapshot` epochs, and always at the final epoch.
        if ((epoch % parameters["snapshot"]) == 0) or (epoch == parameters["num_epochs"]):
            checkpoint_path = os.path.join(
                parameters["folder"],
                'retraincheckpoint_orig_{:04d}_added_{:04d}.pth.tar'.format(origepoch, epoch))
            print('Saving checkpoint {}'.format(checkpoint_path))
            torch.save(model.state_dict(), checkpoint_path)


def main():
    # Parse options: training parameters, checkpoint folder/name, and the
    # epoch the checkpoint was saved at.
    parameters, folder, checkpointname, epoch = parser()
    device = parameters["device"]

    model, datasets = get_model_and_data(parameters)
    datasets.pop("test")  # only the train split is needed here

    print("Restoring weights..")
    checkpointpath = os.path.join(folder, checkpointname)
    state_dict = torch.load(checkpointpath, map_location=device)
    model.load_state_dict(state_dict)

    # Optimizer: AdamW by default; Lion is available as an alternative.
    optimizer = torch.optim.AdamW(model.parameters(), lr=parameters["lr"])
    # Note: Lion typically wants a smaller learning rate than AdamW
    # (the Lion paper suggests roughly 3-10x smaller):
    # optimizer = Lion(model.parameters(), lr=parameters["lr"])

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
    print("Training model..")
    add_epochs(model, datasets, parameters, optimizer, epoch)


if __name__ == '__main__':
    main()