File size: 5,939 Bytes
10b3362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

from typing import Dict, Any, Tuple, SupportsFloat, List, Union

from aworld.config import ConfigDict, ToolConfig
from examples.tools.tool_action import GymAction
from aworld.core.common import Observation, ActionModel, ActionResult
from aworld.core.tool.base import Tool, ToolFactory
from aworld.utils.import_package import import_packages
from aworld.tools.utils import build_observation


class ActionType:
    """String tags naming the kind of action space an environment exposes."""

    # Tag for a finite set of actions.
    DISCRETE = "discrete"
    # Tag for a real-valued action vector.
    CONTINUOUS = "continuous"


@ToolFactory.register(name="openai_gym",
                      desc="gym classic control game",
                      supported_action=GymAction,
                      conf_file_name='openai_gym_tool.yaml')
class OpenAIGym(Tool):
    """Tool that drives a single gymnasium environment.

    The environment is created in the constructor from the tool configuration.
    ``do_step`` forwards one action to it, ``reset`` starts a new episode and
    ``close`` releases it. Nested (tuple / dict) states and actions are
    converted to and from flat dictionaries by ``transform_state`` and
    ``transform_action``.
    """

    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, ToolConfig], **kwargs) -> None:
        """Create the gymnasium environment described by ``conf``.

        Args:
            conf: Tool configuration. The keys read here are ``env_id``
                (passed to ``gymnasium.make``), ``render`` (default ``True``),
                ``render_mode`` (default ``'human'``, only used when rendering
                is enabled) and ``wrappers`` (callables applied to the
                environment in order, default empty).
            **kwargs: Forwarded to the base ``Tool`` constructor and then,
                without the ``name`` entry, to ``gymnasium.make``.
        """
        import_packages(['pygame', 'gymnasium'])
        super(OpenAIGym, self).__init__(conf, **kwargs)
        self.env_id = self.conf.get("env_id")
        self._render = self.conf.get('render', True)
        if self._render:
            kwargs['render_mode'] = self.conf.get('render_mode', 'human')
        # 'name' is consumed by the base class and is not passed on to gymnasium.make.
        kwargs.pop('name', None)
        self.env = self._gym_env_wrappers(self.env_id, self.conf.get("wrappers", []), **kwargs)
        self.action_space = self.env.action_space

    def do_step(self, actions: List[ActionModel], **kwargs) -> Tuple[
        Observation, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Apply the first action to the environment and report the outcome.

        Only ``actions[0].params['result']`` is sent to the environment; one
        ``ActionResult`` carrying the resulting state is still produced for
        every entry of ``actions``.

        Args:
            actions: Action models; the first one supplies the action value.
            **kwargs: Merged into the returned ``info`` dict and forwarded to
                ``build_observation``.

        Returns:
            A ``(observation, reward, terminal, truncate, info)`` tuple, where
            the last four values come from ``env.step``.

        Raises:
            IndexError: If ``actions`` is empty.
            KeyError: If the first action has no ``'result'`` parameter.
        """
        if self._render:
            self.render()
        action = OpenAIGym.transform_action(action=actions[0].params['result'])
        state, reward, terminal, truncate, info = self.env.step(action)
        info.update(kwargs)
        # NOTE(review): only `terminal` marks the tool as finished; `truncate`
        # is returned to the caller but does not set the flag — confirm intended.
        self._finished = terminal

        # The state is the same for every result, so transform it once.
        content = OpenAIGym.transform_state(state=state)
        action_results = [ActionResult(content=content, success=True) for _ in actions]
        observation = build_observation(observer=self.name(),
                                        action_result=action_results,
                                        ability=GymAction.PLAY.value.name,
                                        content=content,
                                        env_id=self.env_id,
                                        done=terminal,
                                        **kwargs)
        return observation, reward, terminal, truncate, info

    def render(self):
        """Render the environment and return whatever ``env.render()`` returns."""
        return self.env.render()

    def close(self):
        """Close the environment if there is one and drop the reference to it."""
        if self.env is not None:
            self.env.close()
        self.env = None

    def reset(self, *, seed: int | None = None, options: Dict[str, str] | None = None) -> Tuple[Any, Dict[str, Any]]:
        """Start a new episode.

        Args:
            seed: Forwarded to ``env.reset``.
            options: Forwarded to ``env.reset``.

        Returns:
            A ``(observation, info)`` pair; ``info`` is always an empty dict.
        """
        state = self.env.reset(seed=seed, options=options)
        # A fresh episode must not inherit the finished flag set by do_step.
        self._finished = False
        # NOTE(review): the whole return value of env.reset() goes through
        # transform_state. If it is an (observation, info) tuple, the content
        # becomes a flat dict keyed 'gym0', 'gym1-<key>', ... rather than the
        # bare observation used in do_step. Kept as-is for compatibility.
        return build_observation(observer=self.name(),
                                 ability=GymAction.PLAY.value.name,
                                 content=OpenAIGym.transform_state(state=state),
                                 env_id=self.env_id,
                                 done=False), {}

    def _action_dim(self):
        """Return the action dimension and record the action type.

        Sets ``self.action_type`` to ``ActionType.DISCRETE`` and returns ``n``
        for a ``Discrete`` space, or to ``ActionType.CONTINUOUS`` and returns
        ``shape[0]`` for a ``Box`` space.

        Raises:
            Exception: If the action space is neither ``Discrete`` nor ``Box``.
        """
        from gymnasium import spaces

        if isinstance(self.env.action_space, spaces.Discrete):
            self.action_type = ActionType.DISCRETE
            return self.env.action_space.n
        elif isinstance(self.env.action_space, spaces.Box):
            self.action_type = ActionType.CONTINUOUS
            return self.env.action_space.shape[0]
        else:
            raise Exception('unsupported env.action_space: {}'.format(self.env.action_space))

    def _state_dim(self):
        """Return the length of a one-dimensional observation space.

        Raises:
            Exception: If the observation space shape is not one-dimensional.
        """
        if len(self.env.observation_space.shape) == 1:
            return self.env.observation_space.shape[0]
        else:
            raise Exception('unsupported observation_space.shape: {}'.format(self.env.observation_space))

    def _gym_env_wrappers(self, env_id, wrappers: list | None = None, **kwargs):
        """Build the environment and apply the wrappers in order.

        Args:
            env_id: Environment id passed to ``gymnasium.make``.
            wrappers: Callables taking an environment and returning a wrapped
                one; ``None`` or empty means no wrapping.
            **kwargs: Extra keyword arguments for ``gymnasium.make``.

        Returns:
            The (possibly wrapped) environment.
        """
        import gymnasium

        env = gymnasium.make(env_id, **kwargs)
        # `None` default instead of a shared mutable list.
        for wrapper in wrappers or ():
            env = wrapper(env)
        return env

    @staticmethod
    def transform_state(state: Any):
        """Flatten a nested tuple / dict state into a single-level dict.

        Tuple elements get the prefix ``gym<index>`` and dict entries get their
        key as prefix; nested names are joined with ``-``. Any other value is
        returned unchanged.

        Example: ``({'a': 1}, 2)`` becomes ``{'gym0-a': 1, 'gym1': 2}``.
        """
        if isinstance(state, tuple):
            prefixed = (('gym{}'.format(index), part) for index, part in enumerate(state))
        elif isinstance(state, dict):
            prefixed = (('{}'.format(key), part) for key, part in state.items())
        else:
            return state

        flat = dict()
        for prefix, part in prefixed:
            part = OpenAIGym.transform_state(state=part)
            if isinstance(part, dict):
                for name, leaf in part.items():
                    flat['{}-{}'.format(prefix, name)] = leaf
            else:
                flat[prefix] = part
        return flat

    @staticmethod
    def transform_action(action: Any):
        """Rebuild a nested dict action from its flat ``-``-joined form.

        Keys are split on the first ``-`` into an outer and an inner name, and
        the resulting inner dicts are expanded recursively. A non-dict action
        is returned unchanged.

        Example: ``{'a-b-c': 1, 'd': 2}`` becomes ``{'a': {'b': {'c': 1}}, 'd': 2}``.
        """
        if not isinstance(action, dict):
            return action

        nested = dict()
        for key, value in action.items():
            if '-' in key:
                outer, inner = key.split('-', 1)
                nested.setdefault(outer, dict())[inner] = value
            else:
                nested[key] = value
        # Only values of existing keys are replaced, so iterating is safe.
        for key, value in nested.items():
            if isinstance(value, dict):
                nested[key] = OpenAIGym.transform_action(action=value)
        return nested