2026-05-17 03:54:40 -07:00
|
|
|
"""Gymnasium environment wrapping the Magic Civilization player-API harness.
|
|
|
|
|
|
|
|
|
|
One Gym `step()` corresponds to one PlayerAction. Only `end_turn`
advances the game turn, so a single game turn spans many consecutive
Gym steps: the policy keeps issuing micro-actions until it emits
`end_turn` (or the episode's total step budget, `max_steps_per_episode`,
runs out). This mirrors how the built-in AI takes "a turn": many
micro-actions followed by an `end_turn` boundary.
|
|
|
|
|
|
|
|
|
|
The opponent is whatever AI the harness ships with for the non-Claude
|
|
|
|
|
slots — that's our frozen baseline. As the policy trains, we measure
|
|
|
|
|
its win rate against this baseline; the policy is considered to have
|
|
|
|
|
"beat the built-in AI" when it crosses a configurable threshold
|
|
|
|
|
(default 55%).
|
|
|
|
|
"""
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-05-17 05:34:29 -07:00
|
|
|
import sys
|
2026-05-17 03:54:40 -07:00
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import gymnasium as gym
|
|
|
|
|
import numpy as np
|
|
|
|
|
from gymnasium import spaces
|
|
|
|
|
|
|
|
|
|
from .encoders import (
|
|
|
|
|
ACTION_DIM,
|
|
|
|
|
OBS_DIM,
|
|
|
|
|
decode_action_index,
|
|
|
|
|
encode_legal_actions,
|
|
|
|
|
encode_observation,
|
|
|
|
|
)
|
|
|
|
|
from .harness_client import HarnessClient, HarnessConfig, HarnessError
|
|
|
|
|
|
|
|
|
|
# Reward shape:
#   +1.0 (WIN_REWARD)  on win
#   -1.0 (LOSS_REWARD) on loss
#    0.0 (DRAW_REWARD) on draw / unresolved at the turn limit
# NOTE: MagicCivEnv currently detects losses only: elimination (no cities
# and no living founder), plus any harness error, which is scored as a
# loss. Win detection (opponent elimination, score-fallback or domination
# victory) is still a TODO in MagicCivEnv._check_termination, so episodes
# presently end in a loss or by truncation.
#
# On top of that there is a dense intermediate signal: every step earns
# SCORE_DELTA_SCALE * (change in score_estimate). This saves the policy
# from having to learn from sparse, terminal-only rewards from scratch.
# The scale is kept small (1e-3) so the terminal reward dominates once the
# agent starts winning.
SCORE_DELTA_SCALE: float = 1e-3
WIN_REWARD: float = 1.0
LOSS_REWARD: float = -1.0
DRAW_REWARD: float = 0.0

# Hard ceiling on env.step() calls per episode.
#
# A policy that learned "ending the turn lowers my reward" would otherwise
# produce episodes of unbounded length (observed: 1.3M harness round-trips
# in a single eval episode). A total-episode budget catches that without
# biasing intra-turn behavior. Late-game players with hundreds of units
# legitimately take hundreds of micro-actions per turn, so a per-turn cap
# would interfere with normal play.
#
# At ~50 steps/s, 50k steps bounds an eval episode to roughly 1,000 s
# (~17 min) of wall-clock. The theoretical worst case for a legitimate game
# (200 units * 200 turns * 5 actions/unit = 200k) is higher, but real PPO
# eval games end far earlier. 50k is meant to sit well above them (roughly
# an order of magnitude) while still cutting off degenerate episodes.
DEFAULT_MAX_STEPS_PER_EPISODE: int = 50_000
|
2026-05-17 03:54:40 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
    """Single-player Gym wrapper over the Magic Civilization player-API harness.

    Our policy controls the player slot selected by
    ``HarnessConfig.player_slot``; the harness's built-in AI controls the
    remaining slots.

    Each ``step()`` submits exactly one PlayerAction; only ``end_turn``
    advances the game turn. An episode is ``terminated`` on elimination or
    on any harness failure (both scored as a loss). It is ``truncated`` once
    the game turn reaches ``max_turns`` or the number of ``step()`` calls
    reaches ``max_steps_per_episode``.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        harness_config: HarnessConfig | None = None,
        max_turns: int = 200,
        max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE,
    ) -> None:
        """Build the observation/action spaces and empty per-episode state.

        No harness client is created here; ``reset()`` starts one.

        Args:
            harness_config: Harness settings; ``HarnessConfig()`` when omitted.
            max_turns: Truncate once the game turn reported by the harness
                reaches this value.
            max_steps_per_episode: Truncate after this many ``step()`` calls
                within a single episode.
        """
        super().__init__()
        self._config = (
            harness_config if harness_config is not None else HarnessConfig()
        )
        self._max_turns = max_turns
        self._max_steps_per_episode = max_steps_per_episode
        self.observation_space = spaces.Box(
            low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(ACTION_DIM)

        # Per-episode state; (re)populated by reset() and _sync_state().
        self._client: HarnessClient | None = None
        self._last_view: dict[str, Any] = {}
        self._last_score: float = 0.0
        self._idx_to_action: dict[int, dict[str, Any]] = {}
        self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
        self._terminated: bool = False
        self._step_count: int = 0

    # ── Gymnasium API ────────────────────────────────────────────────

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a fresh game in a new harness client.

        Any previous client is shut down first. When ``seed`` is given it
        replaces the seed of the stored harness config for this game only.
        ``options`` is accepted for Gymnasium API compatibility and ignored.

        Returns:
            ``(observation, info)``, where ``info["action_mask"]`` is the
            boolean legal-action mask for the first step.
        """
        # Seed self.np_random per the Gymnasium contract; the game itself is
        # seeded through the harness config below.
        super().reset(seed=seed)
        if self._client is not None:
            self._client.shutdown()
            # Drop the shut-down handle so that, if creating the new client
            # fails below, step() reports "called before reset()" instead of
            # talking to a dead client.
            self._client = None
        cfg = self._config
        if seed is not None:
            # NOTE(review): fields not listed here fall back to HarnessConfig
            # defaults. If HarnessConfig is a dataclass,
            # dataclasses.replace(cfg, seed=seed) would carry all of them over.
            cfg = HarnessConfig(
                seed=seed,
                players=cfg.players,
                player_slot=cfg.player_slot,
                map_size=cfg.map_size,
                map_type=cfg.map_type,
                omniscient=cfg.omniscient,
                timeout_sec=cfg.timeout_sec,
            )
        self._client = HarnessClient(cfg)
        self._terminated = False
        self._step_count = 0
        view = self._client.view()
        self._sync_state(view)
        return encode_observation(view), {"action_mask": self._cur_mask.copy()}

    def step(
        self, action: np.int64 | int
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Submit one action; return ``(obs, reward, terminated, truncated, info)``.

        An illegal or out-of-range action index is replaced with index 0
        (``end_turn``) rather than raising. The reward is the scaled change in
        ``score_estimate`` plus any terminal reward. A ``HarnessError`` raised
        by the action or by the follow-up ``view()`` ends the episode as a
        loss.

        Raises:
            RuntimeError: if called before ``reset()`` or after the episode
                has ended.
        """
        if self._client is None:
            raise RuntimeError("step() called before reset()")
        if self._terminated:
            raise RuntimeError("step() called on terminated env; call reset()")

        idx = int(action)
        # The mask should prevent illegal picks, but be defensive and
        # substitute end_turn (index 0). The range check stops a negative
        # index from silently wrapping to another action and an oversized
        # one from raising IndexError.
        if not (0 <= idx < len(self._cur_mask)) or not self._cur_mask[idx]:
            idx = 0
        player_action = decode_action_index(idx, self._idx_to_action)
        self._step_count += 1

        try:
            if player_action.get("type") == "end_turn":
                self._client.end_turn()
            else:
                self._client.act(player_action)
            # Inside the try: a dead harness typically fails here first.
            view = self._client.view()
        except HarnessError:
            return self._harness_failure()

        new_score = float(view.get("score", {}).get("score_estimate", 0.0))
        reward = SCORE_DELTA_SCALE * (new_score - self._last_score)

        terminated, terminal_reward, reason = self._check_termination(view)
        reward += terminal_reward
        self._sync_state(view)
        self._terminated = terminated

        turn = int(view.get("turn", 0))
        step_capped = (
            not terminated and self._step_count >= self._max_steps_per_episode
        )
        turn_capped = not terminated and turn >= self._max_turns
        truncated = step_capped or turn_capped
        if truncated:
            self._terminated = True

        info: dict[str, Any] = {
            "action_mask": self._cur_mask.copy(),
            "turn": turn,
            "score": new_score,
            "city_count": int(view.get("score", {}).get("city_count", 0)),
        }
        if reason:
            info["reason"] = reason
        elif step_capped:
            info["reason"] = "step_cap"
            print(
                f"[MagicCivEnv] step_cap hit at step={self._step_count} "
                f"turn={turn} — truncating episode",
                file=sys.stderr,
                flush=True,
            )
        elif turn_capped:
            info["reason"] = "turn_cap"
        return encode_observation(view), reward, terminated, truncated, info

    def close(self) -> None:
        """Shut down the harness client, if any. Safe to call repeatedly."""
        if self._client is not None:
            self._client.shutdown()
            self._client = None

    # ── Internals ────────────────────────────────────────────────────

    def _harness_failure(
        self,
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """End the episode as a loss after a harness failure.

        Any harness failure (rejected action, dead subprocess, ...) is scored
        as ``LOSS_REWARD``. The step returns a zero observation and an
        all-False action mask.
        """
        self._terminated = True
        return (
            np.zeros(OBS_DIM, dtype=np.float32),
            LOSS_REWARD,
            True,
            False,
            {"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"},
        )

    def _sync_state(self, view: dict[str, Any]) -> None:
        """Cache ``view`` plus the score and legal-action mask derived from it."""
        self._last_view = view
        self._last_score = float(view.get("score", {}).get("score_estimate", 0.0))
        mask, idx_to_action = encode_legal_actions(view)
        self._cur_mask = mask
        self._idx_to_action = idx_to_action

    def _check_termination(
        self, view: dict[str, Any]
    ) -> tuple[bool, float, str | None]:
        """Decide whether the episode ended this step.

        Only elimination is detected today. If we own no cities and no
        living founder unit (which could still found one), the episode ends
        as a loss. Win detection (opponent eliminated, score-fallback or
        domination victory via the view's ``pending_events`` / a global
        ``winner_index``) is not implemented yet, so this never reports a win.

        Returns:
            ``(terminated, terminal_reward, reason)``. ``reason`` is
            ``"eliminated"`` on a loss and ``None`` otherwise.
        """
        score = view.get("score", {})
        if int(score.get("city_count", 0)) == 0:
            # Verify no founder either — a founder can still found a city.
            units = view.get("units", [])
            me = int(view.get("player", 0))
            has_founder = any(
                int(u.get("owner", -1)) == me
                and "founder" in str(u.get("type", ""))
                and float(u.get("hp", 0)) > 0
                for u in units
            )
            if not has_founder:
                return True, LOSS_REWARD, "eliminated"
        # TODO: detect win via view once the player API exposes it
        # consistently; until then episodes end via elimination or
        # max_turns / step-budget truncation only.
        return False, 0.0, None

    def action_masks(self) -> np.ndarray:
        """sb3-contrib MaskablePPO hook — returns a copy of the current mask."""
        return self._cur_mask.copy()
|