import os
import pathlib

import torch

def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
    """Write a checkpoint, keeping at most `keep_latest` files in ckpt_dir."""
    pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
    # Prune old checkpoints: keep the keep_latest-1 newest existing files so
    # that, together with the file written below, keep_latest remain on disk.
    prev_ckpts = sorted(pathlib.Path(ckpt_dir).glob('%s-*.pth' % model_name),
                        key=lambda p: p.stat().st_mtime, reverse=True)
    for f in prev_ckpts[keep_latest - 1:]:
        f.unlink()
    save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
    save_dict = {
        "model": module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": global_step,
    }
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    print(f"saving {save_path}")
    torch.save(save_dict, save_path)
    return save_path
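
# Usage sketch (illustrative only; 'net' and 'opt' are placeholders that are
# not defined in this module):
#
#   save('./checkpoints/run1', net, opt, scheduler=None, global_step=50000)
#   # -> writes ./checkpoints/run1/model-000050000.pth; with keep_latest=2 it
#   #    first deletes all but the single newest pre-existing model-*.pth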

def load(fabric, ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None,
         step=0, model_name='model', ignore_load=None, strict=True, verbose=True,
         weights_only=False):
    """Load the newest checkpoint in ckpt_dir into model; return its global step.

    model_ema is accepted for API compatibility but is not restored here.
    """
    if verbose:
        print('reading ckpt from %s' % ckpt_dir)
    if not os.path.exists(ckpt_dir):
        print('...there is no full checkpoint in %s' % ckpt_dir)
        print('-- note this function no longer prepends "saved_checkpoints/" to ckpt_dir --')
        raise FileNotFoundError(ckpt_dir)
    prev_ckpts = sorted(pathlib.Path(ckpt_dir).glob('%s-*.pth' % model_name),
                        key=lambda p: p.stat().st_mtime, reverse=True)
    if prev_ckpts:
        path = prev_ckpts[0]
        # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
        step = int(path.stem.split('-')[-1])
        if verbose:
            print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
        # Prefer fabric.load so distributed/sharded checkpoints load correctly.
        if fabric is not None:
            checkpoint = fabric.load(path)
        else:
            checkpoint = torch.load(path, weights_only=weights_only)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
        assert ignore_load is None, 'ignore_load is not implemented yet'
        # Accept both wrapped ({'model': ...}) and bare state dicts.
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        model.load_state_dict(state_dict, strict=strict)
    else:
        print('...there is no full checkpoint here!')
    return step
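

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the training code: the
    # Linear model, SGD optimizer, and temp directory below are assumptions
    # chosen only to exercise save/load round-tripping without Fabric.
    import tempfile
    import time

    with tempfile.TemporaryDirectory() as tmp:
        net = torch.nn.Linear(4, 2)
        opt = torch.optim.SGD(net.parameters(), lr=0.1)
        for s in (10, 20, 30):
            save(tmp, net, opt, scheduler=None, global_step=s)
            time.sleep(0.05)  # keep mtimes distinct so pruning order is stable
        # keep_latest=2: only the two newest checkpoints should remain.
        assert len(list(pathlib.Path(tmp).glob('model-*.pth'))) == 2
        net2 = torch.nn.Linear(4, 2)
        step = load(None, tmp, net2, optimizer=opt)  # fabric=None -> torch.load
        assert step == 30
        print('round-trip OK at step', step)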