Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| import uuid | |
| from ding.torch_utils.checkpoint_helper import auto_checkpoint, build_checkpoint_helper, CountVar | |
| from ding.utils import read_file, save_file | |
class DstModel(nn.Module):
    """Destination model for the load tests.

    Shares ``fc1`` (3->3) and ``fc2`` (3->8) with ``SrcModel`` but has its own head ``fc_dst`` (3->6),
    so its parameters only partially match a ``SrcModel`` state dict. No ``forward`` is defined; the
    tests only use its parameters / ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept fixed: it determines parameter registration order and RNG consumption.
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 8)
        self.fc_dst = nn.Linear(3, 6)
class SrcModel(nn.Module):
    """Source model whose weights are checkpointed in the load tests.

    Shares ``fc1`` (3->3) and ``fc2`` (3->8) with ``DstModel`` but has its own head ``fc_src`` (3->7).
    No ``forward`` is defined; the tests only use its parameters / ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept fixed: it determines parameter registration order and RNG consumption.
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 8)
        self.fc_src = nn.Linear(3, 7)
class HasStateDict:
    """Minimal stateful stand-in (optimizer, dataset, collector info) for checkpoint tests.

    Its "state" is a string made of the constructor ``name`` followed by a random UUID. Every call
    to ``state_dict`` hands out the current state and immediately replaces it with a fresh token.
    Tests can therefore tell a restored state (prefixed by the source object's name) apart from a
    locally generated one.
    """

    def __init__(self, name):
        self._name = name
        self._state_dict = self._new_token()

    def _new_token(self):
        # Name prefix + random suffix; the prefix is what the tests check with ``startswith``.
        return self._name + str(uuid.uuid4())

    def state_dict(self):
        # Return the current state, then rotate to a new one (RHS is evaluated before assignment).
        current, self._state_dict = self._state_dict, self._new_token()
        return current

    def load_state_dict(self, state_dict):
        # Adopt the given state verbatim; the next ``state_dict`` call returns it unchanged.
        self._state_dict = state_dict
class TestCkptHelper:
    """Tests for the checkpoint helper built by ``build_checkpoint_helper``."""

    def test_load_model(self):
        """Round-trip models and auxiliary state through ``ckpt_helper.save`` / ``ckpt_helper.load``.

        Scenarios covered, in order:

        * Strict loading of a checkpoint taken from a differently-headed model raises ``RuntimeError``.
          Non-strict loading copies the shared ``fc1`` layer.
        * ``save``/``load`` with ``prefix_op``/``prefix`` and ``state_dict_mask=['fc1']`` restores ``fc2``,
          the optimizer / dataset / collector-info state and the iteration counter. It leaves ``fc1`` untouched.
        * Loading does not raise after ``dataset``, ``optimizer`` and ``last_iter`` are removed from the file.
        * Passing ``lr_schduler`` raises ``NotImplementedError``.
        * An unknown ``prefix_op`` raises ``KeyError`` in ``save`` and, separately, in ``load``.
        """
        # Local import keeps this change self-contained.
        import tempfile

        # A private temp dir is used instead of a shell ``rm -rf`` in the CWD. Cleanup is synchronous
        # (no sleep needed), needs no shell, and also happens when an assertion fails part-way.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, 'model.pt')

            # --- Plain model checkpoint written directly with torch.save ---
            dst_model = DstModel()
            src_model = SrcModel()
            ckpt_state_dict = {'model': src_model.state_dict()}
            torch.save(ckpt_state_dict, path)
            ckpt_helper = build_checkpoint_helper({})
            # The models' heads differ (fc_src 3->7 vs fc_dst 3->6), so strict loading must fail.
            with pytest.raises(RuntimeError):
                ckpt_helper.load(path, dst_model, strict=True)
            ckpt_helper.load(path, dst_model, strict=False)
            assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() < 1e-6
            assert torch.abs(dst_model.fc1.bias - src_model.fc1.bias).max() < 1e-6

            # --- Full checkpoint: model plus auxiliary state, with prefix handling and a mask ---
            dst_model = DstModel()
            src_model = SrcModel()
            # Fresh random init: fc1 must differ before loading for the mask check below to be meaningful.
            assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6
            src_optimizer = HasStateDict('src_optimizer')
            dst_optimizer = HasStateDict('dst_optimizer')
            src_last_epoch = CountVar(11)
            dst_last_epoch = CountVar(5)
            src_last_iter = CountVar(110)
            dst_last_iter = CountVar(50)
            src_dataset = HasStateDict('src_dataset')
            dst_dataset = HasStateDict('dst_dataset')
            src_collector_info = HasStateDict('src_collect_info')
            dst_collector_info = HasStateDict('dst_collect_info')
            ckpt_helper.save(
                path,
                src_model,
                optimizer=src_optimizer,
                dataset=src_dataset,
                collector_info=src_collector_info,
                last_iter=src_last_iter,
                last_epoch=src_last_epoch,
                prefix_op='remove',
                prefix="f"
            )
            ckpt_helper.load(
                path,
                dst_model,
                dataset=dst_dataset,
                optimizer=dst_optimizer,
                last_iter=dst_last_iter,
                last_epoch=dst_last_epoch,
                collector_info=dst_collector_info,
                strict=False,
                state_dict_mask=['fc1'],
                prefix_op='add',
                prefix="f"
            )
            # Restored auxiliary states carry the source objects' name prefix.
            assert dst_dataset.state_dict().startswith('src')
            assert dst_optimizer.state_dict().startswith('src')
            assert dst_collector_info.state_dict().startswith('src')
            assert dst_last_iter.val == 110
            # 'remove' on save followed by 'add' on load must restore the original parameter names.
            assert all(name.startswith('fc') for name, _ in dst_model.named_parameters())
            assert torch.abs(dst_model.fc2.weight - src_model.fc2.weight).max() < 1e-6
            # fc1 was excluded via state_dict_mask, so it keeps its own initialization.
            assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6

            # --- Loading must not fail when optional entries are missing from the checkpoint ---
            checkpoint = read_file(path)
            for key in ('dataset', 'optimizer', 'last_iter'):
                checkpoint.pop(key)
            save_file(path, checkpoint)
            ckpt_helper.load(
                path,
                dst_model,
                dataset=dst_dataset,
                optimizer=dst_optimizer,
                last_iter=dst_last_iter,
                last_epoch=dst_last_epoch,
                collector_info=dst_collector_info,
                strict=True,
                state_dict_mask=['fc1'],
                prefix_op='add',
                prefix="f"
            )

            # NOTE: 'lr_schduler' (sic) is passed as-is. Presumably this is the helper's keyword spelling,
            # since an unknown keyword would raise TypeError rather than NotImplementedError. Verify before renaming.
            with pytest.raises(NotImplementedError):
                ckpt_helper.load(
                    path,
                    dst_model,
                    strict=False,
                    lr_schduler='lr_scheduler',
                    last_iter=dst_last_iter,
                )

            # One ``pytest.raises`` block per call. In a shared block the first raising call would end the
            # block, and the second call would never run.
            with pytest.raises(KeyError):
                ckpt_helper.save(path, src_model, prefix_op='key_error', prefix="f")
            with pytest.raises(KeyError):
                ckpt_helper.load(path, dst_model, strict=False, prefix_op='key_error', prefix="f")
def test_count_var():
    """``CountVar.add`` increments the stored value; ``CountVar.update`` overwrites it."""
    counter = CountVar(0)

    counter.add(5)
    assert counter.val == 5  # 0 + 5

    counter.update(3)
    assert counter.val == 3  # replaced, not accumulated onto 5
def test_auto_checkpoint():
    """Check that an exception inside an ``@auto_checkpoint`` method leads to a ``save_checkpoint`` call.

    ``start`` fails part-way through its loop. The test expects the ``auto_checkpoint`` decorator
    (imported at the top of this module) to handle the exception by calling ``save_checkpoint`` on the
    instance, instead of letting it propagate out of ``start``. Without the decorator, ``start()``
    would always raise and the test could never pass.
    """

    class AutoCkptCls:

        def __init__(self):
            # Every path passed to ``save_checkpoint`` is recorded so the test can assert it was called.
            self.saved_paths = []

        @auto_checkpoint
        def start(self):
            # Simulate a job that runs a few iterations and then crashes.
            for i in range(10):
                if i >= 5:
                    raise Exception("There is an exception")

        def save_checkpoint(self, ckpt_path):
            self.saved_paths.append(ckpt_path)
            print('Checkpoint is saved successfully in {}!'.format(ckpt_path))

    auto_ckpt = AutoCkptCls()
    auto_ckpt.start()
    assert auto_ckpt.saved_paths, 'auto_checkpoint should call save_checkpoint when start() raises'
if __name__ == '__main__':
    # Running the file directly exercises only the checkpoint round-trip test.
    TestCkptHelper().test_load_model()