Spaces:
Sleeping
Sleeping
from copy import deepcopy
import glob
import os
import pickle
import shutil

import pytest

from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
    cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config, \
    cartpole_trex_offppo_create_config
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions
@pytest.fixture(scope='module')
def setup_state_dict():
    """Train an off-policy PPO agent on CartPole and expose its policy state dicts.

    Registered as a module-scoped pytest fixture: the tests in ``TestApplication`` request it
    by name, and module scope means the ``serial_pipeline`` training run happens once per
    module instead of once per test.

    Returns:
        dict: ``{'eval': policy.eval_mode.state_dict(),
        'collect': policy.collect_mode.state_dict()}``.
    """
    # Copies so the training run cannot modify the shared module-level configs.
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    try:
        policy = serial_pipeline(config, seed=0)
    except Exception as err:
        # ``pytest.fail`` (unlike ``assert False``) is not stripped under ``python -O``, and the
        # original exception stays attached as the context of the failure.
        pytest.fail(f'Serial pipeline failure: {err!r}')
    return {
        'eval': policy.eval_mode.state_dict(),
        'collect': policy.collect_mode.state_dict(),
    }
class TestApplication:
    """End-to-end tests for the ``ding.entry`` application helpers on CartPole.

    Every test consumes the ``setup_state_dict`` fixture (eval/collect policy state dicts of a
    trained off-policy PPO agent). Tests that write files remove them in a ``finally`` clause,
    so artifacts are cleaned up even when an assertion fails.
    """

    # Output location shared by the episodic-demo tests.
    _EPISODE_DIR = './test_episode'
    _EPISODE_DATA_PATH = './test_episode/expert.data'

    @staticmethod
    def _remove_paths(*patterns):
        """Delete every file or directory matching the glob ``patterns``; missing paths are ignored.

        Portable, synchronous replacement for ``os.popen('rm -rf ...')``, which spawned a shell,
        returned immediately and never waited for (or closed) the child process.
        """
        for pattern in patterns:
            for path in glob.glob(pattern):
                if os.path.isdir(path):
                    shutil.rmtree(path, ignore_errors=True)
                else:
                    try:
                        os.remove(path)
                    except FileNotFoundError:
                        pass

    def _collect_episodic_data(self, setup_state_dict):
        """Collect 16 episodic demos with the trained collect policy and return the pickle path."""
        config = deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)
        config[0].exp_name = 'cartpole_trex_offppo_episodic'
        os.makedirs(self._EPISODE_DIR, exist_ok=True)
        collect_episodic_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            expert_data_path=self._EPISODE_DATA_PATH,
            collect_count=16,
        )
        return self._EPISODE_DATA_PATH

    def test_eval(self, setup_state_dict):
        """``eval`` with the trained eval state dict reaches the configured ``env.stop_value``.

        Checked twice: once with the config's default environments and once with an explicit
        ``env_setting``.
        """
        # Pass copies: ``compile_config`` receives the shared module-level configs otherwise, and
        # every other call in this module deep-copies them first.
        cfg_for_stop_value = compile_config(
            deepcopy(cartpole_ppo_offpolicy_config),
            auto=True,
            create_cfg=deepcopy(cartpole_ppo_offpolicy_create_config),
        )
        stop_value = cfg_for_stop_value.env.stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
        assert episode_return >= stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(
            config,
            seed=0,
            # NOTE(review): presumably [env class, collector env cfgs, evaluator env cfgs];
            # here: CartPoleEnv, None, and five empty per-env configs.
            env_setting=[CartPoleEnv, None, [{} for _ in range(5)]],
            state_dict=setup_state_dict['eval'],
        )
        assert episode_return >= stop_value

    def test_collect_demo_data(self, setup_state_dict):
        """``collect_demo_data`` writes a pickled list whose first element is a dict."""
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        expert_data_path = './expert.data'
        try:
            collect_demo_data(
                config,
                seed=0,
                state_dict=setup_state_dict['collect'],
                collect_count=16,
                expert_data_path=expert_data_path,
            )
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert isinstance(exp_data[0], dict)
        finally:
            self._remove_paths(expert_data_path)

    def test_collect_episodic_demo_data(self, setup_state_dict):
        """``collect_episodic_demo_data`` writes a pickled list of episodes (lists of dicts)."""
        try:
            expert_data_path = self._collect_episodic_data(setup_state_dict)
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert isinstance(exp_data[0][0], dict)
        finally:
            self._remove_paths(self._EPISODE_DIR)

    def test_episode_to_transitions(self, setup_state_dict):
        """``episode_to_transitions`` rewrites episodic data so its first element is a dict."""
        try:
            expert_data_path = self._collect_episodic_data(setup_state_dict)
            # Same path for input and output: the episodic pickle is converted in place.
            episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert isinstance(exp_data[0], dict)
        finally:
            # Same targets as the original ``rm -rf`` commands, relative to the working directory.
            self._remove_paths(self._EPISODE_DIR, 'ckpt*', 'log')