"""Base Atari environment for PettingZoo, backed by multi_agent_ale_py (ALE)."""
from pathlib import Path

import gym
import multi_agent_ale_py
import numpy as np
from gym import spaces
from gym.utils import EzPickle, seeding

from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from pettingzoo.utils.conversions import parallel_to_aec_wrapper, parallel_wrapper_fn
from pettingzoo.utils.env import ParallelEnv


def base_env_wrapper_fn(raw_env_fn):
    """Return a factory that wraps ``raw_env_fn``'s env with the standard
    PettingZoo safety wrappers (out-of-bounds action assertion and
    reset/step order enforcement)."""
    def env_fn(**kwargs):
        env = raw_env_fn(**kwargs)
        env = wrappers.AssertOutOfBoundsWrapper(env)
        env = wrappers.OrderEnforcingWrapper(env)
        return env
    return env_fn


def BaseAtariEnv(**kwargs):
    """Construct a ``ParallelAtariEnv`` and expose it through the AEC API."""
    return parallel_to_aec_wrapper(ParallelAtariEnv(**kwargs))


class ParallelAtariEnv(ParallelEnv, EzPickle):
    """Multi-player Atari environment implementing the PettingZoo parallel API."""

    def __init__(
            self,
            game,
            num_players,
            mode_num=None,
            seed=None,
            obs_type='rgb_image',
            full_action_space=True,
            env_name=None,
            max_cycles=100000,
            auto_rom_install_path=None):
        """Frameskip should be either a tuple (indicating a random range to
        choose from, with the top value exclude), or an int.

        Args:
            game: ROM name (without the ``.bin`` extension).
            num_players: number of active players; must match what the ROM
                reports for the selected mode.
            mode_num: ALE game mode; defaults to the first available mode.
            seed: RNG seed forwarded to :meth:`seed`.
            obs_type: one of ``'ram'``, ``'rgb_image'``, ``'grayscale_image'``.
            full_action_space: if True, all 18 ALE actions are exposed;
                otherwise only the game's minimal action set.
            env_name: name stored in ``metadata``; defaults to ``custom_<game>``.
            max_cycles: frame limit after which all agents are done.
            auto_rom_install_path: directory to search for the ROM; defaults
                to the ``multi_agent_ale_py`` package directory.

        Raises:
            OSError: if the ROM cannot be located.
        """
        EzPickle.__init__(
            self,
            game,
            num_players,
            mode_num,
            seed,
            obs_type,
            full_action_space,
            env_name,
            max_cycles,
            auto_rom_install_path,
        )

        assert obs_type in ('ram', 'rgb_image', "grayscale_image"), "obs_type must either be 'ram' or 'rgb_image' or 'grayscale_image'"
        self.obs_type = obs_type
        self.full_action_space = full_action_space
        self.num_players = num_players
        self.max_cycles = max_cycles
        if env_name is None:
            env_name = "custom_" + game
        self.metadata = {'render.modes': ['human', 'rgb_array'],
                         'name': env_name,
                         'video.frames_per_second': 60}

        multi_agent_ale_py.ALEInterface.setLoggerMode("error")
        self.ale = multi_agent_ale_py.ALEInterface()

        # Deterministic emulator: never repeat the previous action stochastically.
        self.ale.setFloat(b'repeat_action_probability', 0.)

        if auto_rom_install_path is None:
            start = Path(multi_agent_ale_py.__file__).parent
        else:
            start = Path(auto_rom_install_path).resolve()

        # ROM lookup, in order: local directory, a 'roms' subdirectory,
        # then the legacy AutoROM layout 'ROM/<game>/<game>.bin'.
        final = start / f"{game}.bin"
        if not final.exists():
            final = start / "roms" / f"{game}.bin"
        if not final.exists():
            final = start / "ROM" / game / f"{game}.bin"
        if not final.exists():
            # NOTE: message previously pointed at a nonexistent `rom_path`
            # argument; the actual constructor parameter is `auto_rom_install_path`.
            raise OSError(f"rom {game} is not installed. Please install roms using AutoROM tool (https://github.com/Farama-Foundation/AutoROM) "
                          "or specify and double-check the path to your Atari rom using the `auto_rom_install_path` argument.")

        self.rom_path = str(final)
        self.ale.loadROM(self.rom_path)

        all_modes = self.ale.getAvailableModes(num_players)
        if mode_num is None:
            mode = all_modes[0]
        else:
            mode = mode_num
            assert mode in all_modes, f"mode_num parameter is wrong. Mode {mode_num} selected, only {list(all_modes)} modes are supported"

        self.mode = mode
        self.ale.setMode(self.mode)
        assert num_players == self.ale.numPlayersActive()

        if full_action_space:
            action_size = 18
            action_mapping = np.arange(action_size)
        else:
            action_mapping = self.ale.getMinimalActionSet()
            action_size = len(action_mapping)
        # Maps agent action indices onto raw ALE action codes.
        self.action_mapping = action_mapping

        if obs_type == 'ram':
            observation_space = gym.spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
        else:
            (screen_width, screen_height) = self.ale.getScreenDims()
            if obs_type == 'rgb_image':
                num_channels = 3
            elif obs_type == 'grayscale_image':
                num_channels = 1
            observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, num_channels), dtype=np.uint8)

        player_names = ["first", "second", "third", "fourth"]
        self.agents = [f"{player_names[n]}_0" for n in range(num_players)]
        self.possible_agents = self.agents[:]

        self.action_spaces = {agent: gym.spaces.Discrete(action_size) for agent in self.possible_agents}
        # All agents see the same (shared) screen/RAM observation space.
        self.observation_spaces = {agent: observation_space for agent in self.possible_agents}

        self._screen = None  # lazily-created pygame display surface
        self.seed(seed)

    def seed(self, seed=None):
        """Seed the emulator RNG and reload the ROM so the seed takes effect."""
        if seed is None:
            seed = seeding.create_seed(seed, max_bytes=4)
        self.ale.setInt(b"random_seed", seed)
        # Seeding only applies on ROM load, and loading resets the mode.
        self.ale.loadROM(self.rom_path)
        self.ale.setMode(self.mode)

    def reset(self):
        """Reset the game and return the initial observation for every agent."""
        self.ale.reset_game()
        self.agents = self.possible_agents[:]
        self.dones = {agent: False for agent in self.possible_agents}
        self.frame = 0

        obs = self._observe()
        return {agent: obs for agent in self.agents}

    def observation_space(self, agent):
        """Return the observation space for ``agent``."""
        return self.observation_spaces[agent]

    def action_space(self, agent):
        """Return the action space for ``agent``."""
        return self.action_spaces[agent]

    def _observe(self):
        """Return the raw observation for the configured ``obs_type``."""
        if self.obs_type == 'ram':
            return self.ale.getRAM()
        elif self.obs_type == 'rgb_image':
            return self.ale.getScreenRGB()
        elif self.obs_type == 'grayscale_image':
            return self.ale.getScreenGrayscale()

    def step(self, action_dict):
        """Advance one frame with the given per-agent actions.

        Agents absent from ``action_dict`` act with action 0 (NOOP under the
        full action space). Returns ``(observations, rewards, dones, infos)``
        dicts keyed by the agents active at the start of the step; done agents
        are removed from ``self.agents`` afterwards.
        """
        actions = np.zeros(self.max_num_agents, dtype=np.int32)
        for i, agent in enumerate(self.possible_agents):
            if agent in action_dict:
                actions[i] = action_dict[agent]

        actions = self.action_mapping[actions]
        rewards = self.ale.act(actions)
        self.frame += 1
        if self.ale.game_over() or self.frame >= self.max_cycles:
            dones = {agent: True for agent in self.agents}
        else:
            lives = self.ale.allLives()
            # an inactive agent in ale gets a -1 life.
            dones = {agent: int(life) < 0 for agent, life in zip(self.possible_agents, lives) if agent in self.agents}

        obs = self._observe()
        observations = {agent: obs for agent in self.agents}
        rewards = {agent: rew for agent, rew in zip(self.possible_agents, rewards) if agent in self.agents}
        infos = {agent: {} for agent in self.possible_agents if agent in self.agents}
        self.agents = [agent for agent in self.agents if not dones[agent]]
        return observations, rewards, dones, infos

    def render(self, mode="human"):
        """Render the screen; ``'human'`` opens a 4x-zoomed pygame window,
        ``'rgb_array'`` returns the RGB frame."""
        (screen_width, screen_height) = self.ale.getScreenDims()
        image = self.ale.getScreenRGB()
        if mode == "human":
            import pygame
            zoom_factor = 4
            if self._screen is None:
                pygame.init()
                self._screen = pygame.display.set_mode((screen_width * zoom_factor, screen_height * zoom_factor))

            myImage = pygame.image.fromstring(image.tobytes(), image.shape[:2][::-1], "RGB")

            myImage = pygame.transform.scale(myImage, (screen_width * zoom_factor, screen_height * zoom_factor))

            self._screen.blit(myImage, (0, 0))

            pygame.display.flip()
        elif mode == "rgb_array":
            return image
        else:
            raise ValueError("bad value for render mode")

    def close(self):
        """Tear down the pygame window, if one was opened."""
        if self._screen is not None:
            import pygame
            pygame.quit()
            self._screen = None

    def clone_state(self):
        """Clone emulator state w/o system state. Restoring this state will
        *not* give an identical environment. For complete cloning and
        restoring of the full state, see `{clone,restore}_full_state()`."""
        state_ref = self.ale.cloneState()
        state = self.ale.encodeState(state_ref)
        self.ale.deleteState(state_ref)
        return state

    def restore_state(self, state):
        """Restore emulator state w/o system state."""
        state_ref = self.ale.decodeState(state)
        self.ale.restoreState(state_ref)
        self.ale.deleteState(state_ref)

    def clone_full_state(self):
        """Clone emulator state w/ system state including pseudorandomness.
        Restoring this state will give an identical environment."""
        state_ref = self.ale.cloneSystemState()
        state = self.ale.encodeState(state_ref)
        self.ale.deleteState(state_ref)
        return state

    def restore_full_state(self, state):
        """Restore emulator state w/ system state including pseudorandomness."""
        state_ref = self.ale.decodeState(state)
        self.ale.restoreSystemState(state_ref)
        self.ale.deleteState(state_ref)