File size: 621 Bytes
287c28c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import glob
import os
import shutil
from tests import run_cli
def test_continue_train():
output_path = "output/"
command_train = "python tests/utils/train_mnist.py"
run_cli(command_train)
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
run_cli(command_continue)
assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
shutil.rmtree(continue_path)
|