"""Gymnasium environment wrapping the Magic Civilization player-API harness. One Gym `step()` corresponds to one PlayerAction. When the policy's chosen action does not advance the turn (i.e. is anything except `end_turn`), we keep collecting actions inside the same Gym step's trajectory until the policy emits `end_turn` or the per-turn-action budget is exhausted. This mirrors how the built-in AI takes "a turn": many micro-actions then an `end_turn` boundary. The opponent is whatever AI the harness ships with for the non-Claude slots — that's our frozen baseline. As the policy trains, we measure its win rate against this baseline; the policy is considered to have "beat the built-in AI" when it crosses a configurable threshold (default 55%). """ from __future__ import annotations from typing import Any import gymnasium as gym import numpy as np from gymnasium import spaces from .encoders import ( ACTION_DIM, OBS_DIM, decode_action_index, encode_legal_actions, encode_observation, ) from .harness_client import HarnessClient, HarnessConfig, HarnessError # Reward shape: # +1.0 on win (score-fallback or domination) # -1.0 on loss (all our cities lost OR opponent wins) # 0.0 on draw / unresolved at turn limit # Plus an intermediate dense signal: small reward for each delta in # score_estimate so the policy doesn't have to learn from sparse # terminal-only rewards from scratch. Scaled small (1e-3) so terminal # dominates once the agent starts winning. SCORE_DELTA_SCALE = 1e-3 WIN_REWARD = 1.0 LOSS_REWARD = -1.0 DRAW_REWARD = 0.0 # Hard ceiling on env.step() calls per episode. A policy that learned # "ending the turn lowers my reward" would otherwise produce episodes # of unbounded length (observed: 1.3M harness round-trips in a single # eval episode). A total-episode budget catches that without biasing # intra-turn behavior — players in late game with hundreds of units # legitimately have hundreds of micro-actions per turn, so a per-turn # cap would interfere with normal play. 50k bounds eval wall-clock to # ~10 min at 50 fps while sitting an order of magnitude above any # plausibly legitimate game length (200 units * 200 turns * 5 acts/unit # = 200k upper bound, but real PPO eval games end far earlier). DEFAULT_MAX_STEPS_PER_EPISODE = 50_000 class MagicCivEnv(gym.Env[np.ndarray, np.int64]): """Single-player Gym wrapper: our policy controls slot 0, the harness's built-in AI controls slot 1..N-1.""" metadata = {"render_modes": []} def __init__( self, harness_config: HarnessConfig | None = None, max_turns: int = 200, max_micro_actions_per_turn: int = DEFAULT_MAX_MICRO_ACTIONS_PER_TURN, ) -> None: super().__init__() self._config = harness_config or HarnessConfig() self._max_turns = max_turns self._max_micro_actions_per_turn = max_micro_actions_per_turn self.observation_space = spaces.Box( low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32 ) self.action_space = spaces.Discrete(ACTION_DIM) self._client: HarnessClient | None = None self._last_view: dict[str, Any] = {} self._last_score: float = 0.0 self._idx_to_action: dict[int, dict[str, Any]] = {} self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool) self._terminated: bool = False self._cur_turn: int = 0 self._micro_actions_this_turn: int = 0 # ── Gymnasium API ──────────────────────────────────────────────── def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None, ) -> tuple[np.ndarray, dict[str, Any]]: if self._client is not None: self._client.shutdown() cfg = self._config if seed is not None: cfg = HarnessConfig( seed=seed, players=cfg.players, player_slot=cfg.player_slot, map_size=cfg.map_size, map_type=cfg.map_type, omniscient=cfg.omniscient, timeout_sec=cfg.timeout_sec, ) self._client = HarnessClient(cfg) self._terminated = False self._cur_turn = 0 self._micro_actions_this_turn = 0 view = self._client.view() self._sync_state(view) return encode_observation(view), {"action_mask": self._cur_mask.copy()} def step( self, action: np.int64 | int ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: if self._client is None: raise RuntimeError("step() called before reset()") if self._terminated: raise RuntimeError("step() called on terminated env; call reset()") idx = int(action) if not self._cur_mask[idx]: # Mask should prevent this, but be defensive: substitute end_turn. idx = 0 player_action = decode_action_index(idx, self._idx_to_action) # Hard ceiling: if the policy refuses to end its turn after # MAX_MICRO_ACTIONS_PER_TURN, force end_turn. Without this an eval # policy that has learned "ending the turn lowers my reward" # produces an episode of unbounded length. forced_end = False if ( self._micro_actions_this_turn >= MAX_MICRO_ACTIONS_PER_TURN and player_action.get("type") != "end_turn" ): player_action = {"type": "end_turn"} forced_end = True reward = 0.0 try: if player_action.get("type") == "end_turn": self._client.end_turn() else: self._client.act(player_action) except HarnessError: # Treat any harness failure as a loss — bad action, dead # subprocess, etc. Terminate the episode. self._terminated = True return ( np.zeros(OBS_DIM, dtype=np.float32), LOSS_REWARD, True, False, {"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"}, ) view = self._client.view() new_turn = int(view.get("turn", 0)) if new_turn != self._cur_turn: self._cur_turn = new_turn self._micro_actions_this_turn = 0 else: self._micro_actions_this_turn += 1 prev_score = self._last_score new_score = float(view.get("score", {}).get("score_estimate", 0.0)) reward += SCORE_DELTA_SCALE * (new_score - prev_score) terminated, terminal_reward, reason = self._check_termination(view) reward += terminal_reward self._sync_state(view) self._terminated = terminated truncated = (not terminated) and int(view.get("turn", 0)) >= self._max_turns if truncated: self._terminated = True info: dict[str, Any] = { "action_mask": self._cur_mask.copy(), "turn": int(view.get("turn", 0)), "score": new_score, "city_count": int(view.get("score", {}).get("city_count", 0)), } if reason: info["reason"] = reason if forced_end: info["forced_end_turn"] = True return encode_observation(view), reward, terminated, truncated, info def close(self) -> None: if self._client is not None: self._client.shutdown() self._client = None # ── Internals ──────────────────────────────────────────────────── def _sync_state(self, view: dict[str, Any]) -> None: self._last_view = view self._last_score = float(view.get("score", {}).get("score_estimate", 0.0)) mask, idx_to_action = encode_legal_actions(view) self._cur_mask = mask self._idx_to_action = idx_to_action def _check_termination( self, view: dict[str, Any] ) -> tuple[bool, float, str | None]: """Decide whether the episode ended this step. Termination conditions: - All our cities + founders gone → loss - Opponent in same state → win - Score-fallback victory or domination victory recorded in the view's `pending_events` / global `winner_index` (TODO once the player API exposes it consistently) """ score = view.get("score", {}) if int(score.get("city_count", 0)) == 0: # Verify no founder either — a founder can still found a city. units = view.get("units", []) me = int(view.get("player", 0)) has_founder = any( int(u.get("owner", -1)) == me and "founder" in str(u.get("type", "")) and float(u.get("hp", 0)) > 0 for u in units ) if not has_founder: return True, LOSS_REWARD, "eliminated" # TODO: detect win via view; for now the env relies on max_turns # truncation + win/loss via elimination only. return False, 0.0, None def action_masks(self) -> np.ndarray: """sb3-contrib MaskablePPO hook — returns the current mask.""" return self._cur_mask.copy()