# Helpers that build fake OnlineRLContext / OfflineRLContext instances filled
# with random data, intended for unit-testing framework middleware.
from ding.framework import Context, OnlineRLContext, OfflineRLContext
import random
import numpy as np
import treetensor.torch as ttorch
import torch

batch_size = 64
n_sample = 8
action_dim = 1
obs_dim = 4
logit_dim = 2
n_episodes = 2
n_episode_length = 16
update_per_collect = 4
collector_env_num = 8

# The value ranges below are arbitrary; the data only needs to look like a real
# transition for testing purposes.
def fake_train_data():
    """Return a single fake transition packed as a treetensor."""
    train_data = ttorch.as_tensor(
        {
            'action': torch.randint(0, 2, size=(action_dim, )),
            'collect_train_iter': torch.randint(0, 100, size=(1, )),
            'done': torch.tensor(False),
            'env_data_id': torch.tensor([2]),
            'next_obs': torch.randn(obs_dim),
            'obs': torch.randn(obs_dim),
            'reward': torch.randint(0, 2, size=(1, )),
        }
    )
    return train_data
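
# Usage sketch (added for illustration, not part of the original helpers): a fake
# transition behaves like a nested tensor, assuming treetensor's attribute-style
# access, e.g.
#   sample = fake_train_data()
#   sample.obs.shape   # torch.Size([obs_dim])
#   sample.done        # tensor(False)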

def fake_online_rl_context():
    """Return an OnlineRLContext whose fields are pre-filled with random fake data."""
    ctx = OnlineRLContext(
        env_step=random.randint(0, 100),
        env_episode=random.randint(0, 100),
        train_iter=random.randint(0, 100),
        train_data=[fake_train_data() for _ in range(batch_size)],
        train_output=[{
            'cur_lr': 0.001,
            'total_loss': random.uniform(0, 2)
        } for _ in range(update_per_collect)],
        obs=torch.randn(collector_env_num, obs_dim),
        # high=1 means every sampled action is 0; the value itself is irrelevant here.
        action=[
            np.random.randint(low=0, high=1, size=(action_dim, ), dtype=np.int64)
            for _ in range(collector_env_num)
        ],
        inference_output={
            env_id: {
                'logit': torch.randn(logit_dim),
                'action': torch.randint(0, 2, size=(action_dim, ))
            }
            for env_id in range(collector_env_num)
        },
        collect_kwargs={'eps': random.uniform(0, 1)},
        trajectories=[fake_train_data() for _ in range(n_sample)],
        episodes=[[fake_train_data() for _ in range(n_episode_length)] for _ in range(n_episodes)],
        trajectory_end_idx=[i for i in range(n_sample)],
        eval_value=random.uniform(-1.0, 1.0),
        last_eval_iter=random.randint(0, 100),
    )
    return ctx
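
# Field-to-phase note (illustrative, inferred from the field names): a collector
# middleware would fill obs, trajectories and episodes; a trainer would consume
# train_data and append to train_output; an evaluator would write eval_value and
# last_eval_iter. The fake context pre-populates all of them so that each
# middleware can be tested in isolation.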

def fake_offline_rl_context():
    ctx = OfflineRLContext(
        train_epoch=random.randint(0, 100),
        train_iter=random.randint(0, 100),
        train_data=[fake_train_data() for _ in range(batch_size)],
        train_output=[{
            'cur_lr': 0.001,
            'total_loss': random.uniform(0, 2)
        } for _ in range(update_per_collect)],
        eval_value=random.uniform(-1.0, 1.0),
        last_eval_iter=random.randint(0, 100),
    )
    return ctx
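
if __name__ == '__main__':
    # Minimal smoke test (added sketch, not part of the original file): build both
    # fake contexts and check a few sizes implied by the module-level constants.
    online_ctx = fake_online_rl_context()
    assert len(online_ctx.train_data) == batch_size
    assert online_ctx.obs.shape == (collector_env_num, obs_dim)
    assert len(online_ctx.train_output) == update_per_collect
    assert len(online_ctx.episodes) == n_episodes
    assert len(online_ctx.episodes[0]) == n_episode_length

    offline_ctx = fake_offline_rl_context()
    assert len(offline_ctx.train_data) == batch_size
    assert 0 <= offline_ctx.train_epoch <= 100
    print('fake context smoke test passed')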