|
import glob |
|
import os |
|
import shutil |
|
|
|
from tests import run_cli |
|
|
|
|
|
def test_continue_train():
    """Check that resuming a run with ``--continue_path`` writes more checkpoints.

    Steps:
      1. Run ``tests/utils/train_mnist.py`` once via ``run_cli``.
      2. Pick the most recently modified sub-directory of ``output/`` and
         count the ``*.pth`` files in it.
      3. Run the same script again with ``--continue_path`` pointing at that
         directory.
      4. Assert that the directory now holds strictly more ``*.pth`` files.

    The selected run directory is removed afterwards, whether or not the
    resume step or the final assertion succeeds.
    """
    output_path = "output/"

    command_train = "python tests/utils/train_mnist.py"
    run_cli(command_train)

    # The trailing "/" in the pattern restricts the glob to directories.
    run_dirs = glob.glob(os.path.join(output_path, "*/"))
    # Fail with a readable message instead of max()'s "empty sequence" error.
    assert run_dirs, f"no run directory found under {output_path!r} after the initial training run"

    # Newest directory by mtime — presumably the one created by the run above
    # (assumes the training script writes under output_path; confirm).
    continue_path = max(run_dirs, key=os.path.getmtime)

    try:
        number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))

        command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
        run_cli(command_continue)

        assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
    finally:
        # Always clean up so a failed run does not leave its directory behind.
        shutil.rmtree(continue_path)
|
|