#!/usr/bin/env python3
"""
simulate.py
Core simulation logic for a single game simulation.
Handles environment creation, policy initialization, and the simulation loop.
"""
import logging
import sys
from pathlib import Path
from typing import Dict, Any

# Ensure the src directory is in the Python path
current_dir = Path(__file__).parent
src_dir = current_dir / ".." / "src"
sys.path.insert(0, str(src_dir.resolve()))

# pylint: disable=wrong-import-position
from game_reasoning_arena.arena.utils.seeding import set_seed
from game_reasoning_arena.arena.games.registry import registry # Games registry
from game_reasoning_arena.backends import initialize_llm_registry
from game_reasoning_arena.arena.agents.policy_manager import (
initialize_policies, policy_mapping_fn
)
from game_reasoning_arena.arena.utils.loggers import SQLiteLogger
from torch.utils.tensorboard import SummaryWriter

logger = logging.getLogger(__name__)


def log_llm_action(agent_id: int,
agent_model: str,
observation: Dict[str, Any],
chosen_action: int,
reasoning: str,
flag: bool = False
) -> None:
"""Logs the LLM agent's decision."""
logger.info("Board state: \n%s", observation['state_string'])
logger.info("Legal actions: %s", observation['legal_actions'])
logger.info(
"Agent %s (%s) chose action: %s with reasoning: %s",
agent_id, agent_model, chosen_action, reasoning
)
if flag:
logger.error("Terminated due to illegal move: %s.", chosen_action)
def compute_actions(env, player_to_agent, observations):
"""
Computes actions for all agents in the current state.
Args:
env: The environment (OpenSpiel env).
player_to_agent: Dictionary mapping player IDs to agent instances.
observations: Dictionary of observations for each player.
Returns:
Dictionary mapping player IDs to their chosen actions.
Also stores reasoning in agent objects for later retrieval.
"""
def extract_action_and_store_reasoning(player_id, agent_response):
agent = player_to_agent[player_id]
if isinstance(agent_response, dict) and "action" in agent_response:
# Store reasoning in the agent object for later retrieval
if "reasoning" in agent_response:
agent.last_reasoning = agent_response["reasoning"]
else:
agent.last_reasoning = "None"
return agent_response.get("action", -1)
else:
# Fallback for unexpected response formats
agent.last_reasoning = "None"
return -1
if env.state.is_simultaneous_node():
# Simultaneous-move game: All players act at once
actions = {}
for player in player_to_agent:
agent_response = player_to_agent[player](observations[player])
actions[player] = extract_action_and_store_reasoning(
player, agent_response)
return actions
else:
# Turn-based game: Only the current player acts
current_player = env.state.current_player()
agent_response = player_to_agent[current_player](
observations[current_player])
return {current_player: extract_action_and_store_reasoning(
            current_player, agent_response)}


def simulate_game(game_name: str, config: Dict[str, Any], seed: int) -> str:
"""
Runs a game simulation, logs agent actions and final rewards to
TensorBoard.
Args:
game_name: The name of the game.
config: Simulation configuration.
seed: Random seed for reproducibility.
Returns:
str: Confirmation that the simulation is complete.
"""
# Set global seed for reproducibility across all random number generators
set_seed(seed)
# Initialize LLM registry
initialize_llm_registry()
# Initialize loggers for all agents
logger.info("Initializing environment for %s.", game_name)
# Assign players to their policy classes
policies_dict = initialize_policies(config, game_name, seed)
# Initialize loggers and writers for all agents
agent_loggers_dict = {}
for agent_id, policy_name in enumerate(policies_dict.keys()):
# Get agent config and pass it to the logger
player_key = f"player_{agent_id}"
default_config = {"type": "unknown", "model": "None"}
agent_config = config["agents"].get(player_key, default_config)
# Sanitize model name for filename use
model_name = agent_config.get("model", "None")
sanitized_model_name = model_name.replace("-", "_").replace("/", "_")
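        # e.g. a hypothetical "meta-llama/Llama-3-8B" becomes
        # "meta_llama_Llama_3_8B"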
agent_loggers_dict[policy_name] = SQLiteLogger(
agent_type=agent_config["type"],
model_name=sanitized_model_name
)

    writer = SummaryWriter(log_dir=f"runs/{game_name}")  # TensorBoard writer
# Create player_to_agent mapping for RLLib-style action computation
player_to_agent = {}
for i, policy_name in enumerate(policies_dict.keys()):
player_to_agent[i] = policies_dict[policy_name]
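    # e.g. {0: <policy for "player_0">, 1: <policy for "player_1">}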

    # Load the pyspiel game and create the environment simulator
env = registry.make_env(game_name, config)
for episode in range(config["num_episodes"]):
episode_seed = seed + episode
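        # Per-episode seed: episodes differ but the run stays reproducible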
observation_dict, _ = env.reset(seed=episode_seed)
terminated = truncated = False
        rewards_dict = {}  # Default in case the episode ends before any step
logger.info(
"Episode %d started with seed %d.", episode + 1, episode_seed
)
turn = 0
while not (terminated or truncated):
# Use RLLib-style action computation
try:
action_dict = compute_actions(
env, player_to_agent, observation_dict
)
except Exception as e:
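                # Truncate the episode rather than crash the whole run;
                # the while-loop exits once truncated is True.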
logger.error("Error computing actions: %s", e)
truncated = True
break
# Process each action for logging and validation
for agent_id, chosen_action in action_dict.items():
policy_key = policy_mapping_fn(agent_id)
agent_logger = agent_loggers_dict[policy_key]
observation = observation_dict[agent_id]
                # Look up this agent's config for logging
agent_type = None
agent_model = "None"
player_key = f"player_{agent_id}"
if player_key in config["agents"]:
agent_config = config["agents"][player_key]
agent_type = agent_config["type"]
# Only set model for LLM agents
if agent_type == "llm":
agent_model = agent_config.get("model", "None")
else:
agent_model = "None"
# Check if the chosen action is legal
if (chosen_action is None or
chosen_action not in observation["legal_actions"]):
                    logger.error(
                        "ILLEGAL MOVE DETECTED - Agent %s: "
                        "chosen_action=%s (type: %s), legal_actions=%s",
                        agent_id, chosen_action, type(chosen_action),
                        observation["legal_actions"]
                    )
if agent_type == "llm":
log_llm_action(
agent_id, agent_model, observation,
chosen_action, "Illegal action", flag=True
)
agent_logger.log_illegal_move(
game_name=game_name, episode=episode + 1, turn=turn,
agent_id=agent_id, illegal_action=chosen_action,
reason="Illegal action",
board_state=observation["state_string"]
)
truncated = True
break
# Get reasoning if available (for LLM agents)
reasoning = "None"
if (agent_type == "llm" and
hasattr(player_to_agent[agent_id], 'last_reasoning')):
reasoning = getattr(
player_to_agent[agent_id], 'last_reasoning', "None"
)
                # Build a comma-separated description of this agent's
                # opponents for the move log
opponents_list = []
for a_id in config["agents"]:
if a_id != f"player_{agent_id}":
opp_agent_type = config['agents'][a_id]['type']
model = config['agents'][a_id].get('model', 'None')
model_clean = model.replace('-', '_')
opponents_list.append(f"{opp_agent_type}_{model_clean}")
opponents = ", ".join(opponents_list)
agent_logger.log_move(
game_name=game_name,
episode=episode + 1,
turn=turn,
action=chosen_action,
reasoning=reasoning,
opponent=opponents,
generation_time=0.0, # TODO: Add timing back
agent_type=agent_type,
agent_model=agent_model,
seed=episode_seed,
board_state=observation["state_string"]
)
if agent_type == "llm":
log_llm_action(
agent_id, agent_model, observation,
chosen_action, reasoning
)
# Step forward in the environment
if not truncated:
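                # Gymnasium-style step: returns observations, rewards,
                # terminated, truncated, and an info dict (unused here)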
(observation_dict, rewards_dict,
terminated, truncated, _) = env.step(action_dict)
turn += 1
        # Episode finished: log final status and per-agent rewards
game_status = "truncated" if truncated else "terminated"
logger.info(
"Game status: %s with rewards dict: %s", game_status, rewards_dict
)
for agent_id, reward in rewards_dict.items():
policy_key = policy_mapping_fn(agent_id)
agent_logger = agent_loggers_dict[policy_key]
# Calculate opponents for this agent
opponents_list = []
for a_id in config["agents"]:
if a_id != f"player_{agent_id}":
opp_agent_type = config['agents'][a_id]['type']
opp_model = config['agents'][a_id].get('model', 'None')
opp_model_clean = opp_model.replace('-', '_')
opponent_str = f"{opp_agent_type}_{opp_model_clean}"
opponents_list.append(opponent_str)
opponents = ", ".join(opponents_list)
# Log reward to the rewards table
agent_logger.log_rewards(
game_name=game_name,
episode=episode + 1,
reward=reward
)
agent_logger.log_game_result(
game_name=game_name,
episode=episode + 1,
status=game_status,
reward=reward,
opponent=opponents
)
            # TensorBoard logging
agent_type = "unknown"
agent_model = "None"
# Find the agent config by index - handle both string and int keys
for key, value in config["agents"].items():
if (key.startswith("player_") and
int(key.split("_")[1]) == agent_id):
agent_type = value["type"]
agent_model = value.get("model", "None")
break
elif str(key) == str(agent_id):
agent_type = value["type"]
agent_model = value.get("model", "None")
break
tensorboard_key = f"{agent_type}_{agent_model.replace('-', '_')}"
writer.add_scalar(
f"Rewards/{tensorboard_key}", reward, episode + 1
)
logger.info(
"Simulation for game %s, Episode %d completed.",
game_name, episode + 1
)
writer.close()
return "Simulation Completed"
# Start TensorBoard from the terminal:
#   tensorboard --logdir=runs
# Then open in the browser:
# http://localhost:6006/
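
# Example invocation (a minimal sketch; the game name, the model name, and
# any config keys beyond "num_episodes" and "agents" are assumptions
# inferred from the lookups above):
#
#   config = {
#       "num_episodes": 2,
#       "agents": {
#           "player_0": {"type": "llm", "model": "gpt-4"},
#           "player_1": {"type": "random"},
#       },
#   }
#   simulate_game("tic_tac_toe", config, seed=42)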