File size: 12,411 Bytes
eaf8c2f
 
 
 
 
 
 
 
 
b8eecd0
 
eaf8c2f
b8eecd0
 
 
 
 
 
 
eaf8c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6af7869
 
 
 
 
 
 
eaf8c2f
6af7869
 
 
 
 
 
eaf8c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
#!/usr/bin/env python3
"""
simulate.py

Core simulation logic for a single game simulation.
Handles environment creation, policy initialization, and the simulation loop.
"""

import logging
import sys
from pathlib import Path
from typing import Dict, Any

# Ensure the src directory is in the Python path
current_dir = Path(__file__).parent
src_dir = current_dir / ".." / "src"
sys.path.insert(0, str(src_dir.resolve()))

# pylint: disable=wrong-import-position
from game_reasoning_arena.arena.utils.seeding import set_seed
from game_reasoning_arena.arena.games.registry import registry  # Games registry
from game_reasoning_arena.backends import initialize_llm_registry
from game_reasoning_arena.arena.agents.policy_manager import (
    initialize_policies, policy_mapping_fn
)
from game_reasoning_arena.arena.utils.loggers import SQLiteLogger
from torch.utils.tensorboard import SummaryWriter


logger = logging.getLogger(__name__)


def log_llm_action(agent_id: int,
                   agent_model: str,
                   observation: Dict[str, Any],
                   chosen_action: int,
                   reasoning: str,
                   flag: bool = False
                   ) -> None:
    """Logs the LLM agent's decision."""
    logger.info("Board state: \n%s", observation['state_string'])
    logger.info("Legal actions: %s", observation['legal_actions'])
    logger.info(
        "Agent %s (%s) chose action: %s with reasoning: %s",
        agent_id, agent_model, chosen_action, reasoning
    )
    if flag:
        logger.error("Terminated due to illegal move: %s.", chosen_action)


def compute_actions(env, player_to_agent, observations):
    """
    Computes actions for all agents in the current state.

    Args:
        env: The environment (OpenSpiel env).
        player_to_agent: Dictionary mapping player IDs to agent instances.
        observations: Dictionary of observations for each player.

    Returns:
        Dictionary mapping player IDs to their chosen actions.
        Also stores reasoning in agent objects for later retrieval.
    """

    def extract_action_and_store_reasoning(player_id, agent_response):
        agent = player_to_agent[player_id]
        if isinstance(agent_response, dict) and "action" in agent_response:
            # Store reasoning in the agent object for later retrieval
            if "reasoning" in agent_response:
                agent.last_reasoning = agent_response["reasoning"]
            else:
                agent.last_reasoning = "None"
            return agent_response.get("action", -1)
        else:
            # Fallback for unexpected response formats
            agent.last_reasoning = "None"
            return -1

    if env.state.is_simultaneous_node():
        # Simultaneous-move game: All players act at once
        actions = {}
        for player in player_to_agent:
            agent_response = player_to_agent[player](observations[player])
            actions[player] = extract_action_and_store_reasoning(
                player, agent_response)
        return actions
    else:
        # Turn-based game: Only the current player acts
        current_player = env.state.current_player()
        agent_response = player_to_agent[current_player](
            observations[current_player])
        return {current_player: extract_action_and_store_reasoning(
            current_player, agent_response)}


def simulate_game(game_name: str, config: Dict[str, Any], seed: int) -> str:
    """
    Runs a game simulation, logs agent actions and final rewards to
    TensorBoard.

    Args:
        game_name: The name of the game.
        config: Simulation configuration.
        seed: Random seed for reproducibility.

    Returns:
        str: Confirmation that the simulation is complete.
    """

    # Set global seed for reproducibility across all random number generators
    set_seed(seed)

    # Initialize LLM registry
    initialize_llm_registry()

    # Initialize loggers for all agents
    logger.info("Initializing environment for %s.", game_name)

    # Assign players to their policy classes
    policies_dict = initialize_policies(config, game_name, seed)

    # Initialize loggers and writers for all agents
    agent_loggers_dict = {}
    for agent_id, policy_name in enumerate(policies_dict.keys()):
        # Get agent config and pass it to the logger
        player_key = f"player_{agent_id}"
        default_config = {"type": "unknown", "model": "None"}
        agent_config = config["agents"].get(player_key, default_config)

        # Sanitize model name for filename use
        model_name = agent_config.get("model", "None")
        sanitized_model_name = model_name.replace("-", "_").replace("/", "_")
        agent_loggers_dict[policy_name] = SQLiteLogger(
            agent_type=agent_config["type"],
            model_name=sanitized_model_name
        )
    writer = SummaryWriter(log_dir=f"runs/{game_name}")  # Tensorboard writer

    # Create player_to_agent mapping for RLLib-style action computation
    player_to_agent = {}
    for i, policy_name in enumerate(policies_dict.keys()):
        player_to_agent[i] = policies_dict[policy_name]

    # Loads the pyspiel game and the env simulator
    env = registry.make_env(game_name, config)

    for episode in range(config["num_episodes"]):
        episode_seed = seed + episode
        observation_dict, _ = env.reset(seed=episode_seed)
        terminated = truncated = False
        rewards_dict = {}  # Initialize rewards_dict

        logger.info(
            "Episode %d started with seed %d.", episode + 1, episode_seed
            )
        turn = 0

        while not (terminated or truncated):
            # Use RLLib-style action computation
            try:
                action_dict = compute_actions(
                    env, player_to_agent, observation_dict
                )
            except Exception as e:
                logger.error("Error computing actions: %s", e)
                truncated = True
                break

            # Process each action for logging and validation
            for agent_id, chosen_action in action_dict.items():
                policy_key = policy_mapping_fn(agent_id)
                agent_logger = agent_loggers_dict[policy_key]
                observation = observation_dict[agent_id]

                # Get agent config for logging - ensure we get the right
                # agent's config
                agent_type = None
                agent_model = "None"
                player_key = f"player_{agent_id}"
                if player_key in config["agents"]:
                    agent_config = config["agents"][player_key]
                    agent_type = agent_config["type"]
                    # Only set model for LLM agents
                    if agent_type == "llm":
                        agent_model = agent_config.get("model", "None")
                    else:
                        agent_model = "None"

                # Check if the chosen action is legal
                if (chosen_action is None or
                        chosen_action not in observation["legal_actions"]):
                    logger.error(
                        f"ILLEGAL MOVE DETECTED - Agent {agent_id}: "
                        f"chosen_action={chosen_action} (type: {type(chosen_action)}), "
                        f"legal_actions={observation['legal_actions']}"
                    )
                    if agent_type == "llm":
                        log_llm_action(
                            agent_id, agent_model, observation,
                            chosen_action, "Illegal action", flag=True
                        )
                    agent_logger.log_illegal_move(
                        game_name=game_name, episode=episode + 1, turn=turn,
                        agent_id=agent_id, illegal_action=chosen_action,
                        reason="Illegal action",
                        board_state=observation["state_string"]
                    )
                    truncated = True
                    break

                # Get reasoning if available (for LLM agents)
                reasoning = "None"
                if (agent_type == "llm" and
                        hasattr(player_to_agent[agent_id], 'last_reasoning')):
                    reasoning = getattr(
                        player_to_agent[agent_id], 'last_reasoning', "None"
                    )

                # Logging
                opponents_list = []
                for a_id in config["agents"]:
                    if a_id != f"player_{agent_id}":
                        opp_agent_type = config['agents'][a_id]['type']
                        model = config['agents'][a_id].get('model', 'None')
                        model_clean = model.replace('-', '_')
                        opponents_list.append(f"{opp_agent_type}_{model_clean}")
                opponents = ", ".join(opponents_list)

                agent_logger.log_move(
                    game_name=game_name,
                    episode=episode + 1,
                    turn=turn,
                    action=chosen_action,
                    reasoning=reasoning,
                    opponent=opponents,
                    generation_time=0.0,  # TODO: Add timing back
                    agent_type=agent_type,
                    agent_model=agent_model,
                    seed=episode_seed,
                    board_state=observation["state_string"]
                )

                if agent_type == "llm":
                    log_llm_action(
                        agent_id, agent_model, observation,
                        chosen_action, reasoning
                    )

            # Step forward in the environment
            if not truncated:
                (observation_dict, rewards_dict,
                 terminated, truncated, _) = env.step(action_dict)
                turn += 1

        # Logging
        game_status = "truncated" if truncated else "terminated"
        logger.info(
            "Game status: %s with rewards dict: %s", game_status, rewards_dict
        )

        for agent_id, reward in rewards_dict.items():
            policy_key = policy_mapping_fn(agent_id)
            agent_logger = agent_loggers_dict[policy_key]

            # Calculate opponents for this agent
            opponents_list = []
            for a_id in config["agents"]:
                if a_id != f"player_{agent_id}":
                    opp_agent_type = config['agents'][a_id]['type']
                    opp_model = config['agents'][a_id].get('model', 'None')
                    opp_model_clean = opp_model.replace('-', '_')
                    opponent_str = f"{opp_agent_type}_{opp_model_clean}"
                    opponents_list.append(opponent_str)
            opponents = ", ".join(opponents_list)

            # Log reward to the rewards table
            agent_logger.log_rewards(
                game_name=game_name,
                episode=episode + 1,
                reward=reward
            )

            agent_logger.log_game_result(
                game_name=game_name,
                episode=episode + 1,
                status=game_status,
                reward=reward,
                opponent=opponents
            )
            # Tensorboard logging
            agent_type = "unknown"
            agent_model = "None"

            # Find the agent config by index - handle both string and int keys
            for key, value in config["agents"].items():
                if (key.startswith("player_") and
                        int(key.split("_")[1]) == agent_id):
                    agent_type = value["type"]
                    agent_model = value.get("model", "None")
                    break
                elif str(key) == str(agent_id):
                    agent_type = value["type"]
                    agent_model = value.get("model", "None")
                    break

            tensorboard_key = f"{agent_type}_{agent_model.replace('-', '_')}"
            writer.add_scalar(
                f"Rewards/{tensorboard_key}", reward, episode + 1
            )

        logger.info(
            "Simulation for game %s, Episode %d completed.",
            game_name, episode + 1
        )
    writer.close()
    return "Simulation Completed"

# start tensorboard from the terminal:
# tensorboard --logdir=runs

# In the browser:
# http://localhost:6006/