465 lines
21 KiB
Python
465 lines
21 KiB
Python
"""Gymnasium environment wrapping the Magic Civilization player-API harness.
|
|
|
|
One Gym `step()` corresponds to exactly one PlayerAction. Actions other
than `end_turn` do not advance the game turn, so a single game turn
spans many consecutive Gym steps until the policy emits `end_turn`.
There is no per-turn action budget; instead, episodes are bounded by an
episode-wide step cap (`max_steps_per_episode`) and a game-turn cap
(`max_turns`). This mirrors how the built-in AI takes "a turn": many
micro-actions, then an `end_turn` boundary.
|
|
|
|
By default, the opponent is whatever AI the harness ships with for the
non-Claude slots — that's our frozen baseline. A frozen `ModelOpponent`
can be passed instead for self-play; see `MagicCivEnv`. As the policy
trains, we measure its win rate against this baseline. The policy is
considered to have "beat the built-in AI" when it crosses a configurable
threshold (default 55%).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
from dataclasses import replace
|
|
from typing import Any
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
from gymnasium import spaces
|
|
|
|
from .encoders import (
|
|
ACTION_DIM,
|
|
OBS_DIM,
|
|
decode_action_index,
|
|
encode_legal_actions,
|
|
encode_observation,
|
|
)
|
|
from .harness_client import HarnessClient, HarnessConfig, HarnessError
|
|
from .opponent import ModelOpponent
|
|
|
|
# Reward shape — Stage 6.1.5 redesign (2026-05-18).
#
# Why the redesign: the previous shape paid TURN_ADVANCE_BONUS = 1e-2 per
# turn. Over a 200-turn cap that summed to +2.0, which is more than the
# +1.0 terminal win, so the policy rationally learned to stall (60%
# turn-cap rate at the 6.1 eval gate). This catalog:
#   * drops that bonus,
#   * sharpens the terminal rewards,
#   * adds event-attributed shaping from wire events the simulator
#     already emits, and
#   * adds a decisive-win bonus plus a slow-game ramp, pushing the
#     policy toward fast, decisive play.
#
# Budget analysis: `~/.claude/plans/in-the-game-civilization-elegant-popcorn.md`,
# Stage 6.1.5. Calibrated against Civ5 norms: Standard speed = 500 turns;
# median domination win ≈ 400 turns.

# ── Terminal rewards (sparse, decisive) ──────────────────────────────
WIN_BASE: float = 2.0
LOSS_REWARD: float = -2.0
DRAW_REWARD: float = 0.0
TURN_CAP_PENALTY: float = -0.5
STEP_CAP_PENALTY: float = -0.5
HARNESS_ERROR_REWARD: float = -2.0

# ── Decisive-win bonus ───────────────────────────────────────────────
# Stacks on top of WIN_BASE. It shrinks linearly from DECISIVE_BONUS_MAX
# at turn 0 to nothing at turn DECISIVE_DECAY_TURNS, rewarding fast wins.
DECISIVE_BONUS_MAX: float = 2.0
DECISIVE_DECAY_TURNS: int = 500

# ── Event-driven shaping ─────────────────────────────────────────────
# Sourced from wire events defined in `mc-player-api/src/wire.rs:135-301`.
# step() already gathers these via response["events"] plus
# self._client.drain_notifications().
CAPITAL_CAPTURED_BY_ME: float = 1.0
CAPITAL_LOST_BY_ME: float = -1.0
CITY_CAPTURED_BY_ME: float = 0.30
CITY_LOST_BY_ME: float = -0.30
CITY_FOUNDED_BY_ME: float = 0.15  # paid at most MAX_CITY_FOUNDED_REWARDS times
MAX_CITY_FOUNDED_REWARDS: int = 6
WONDER_BUILT_BY_ME: float = 0.30
ENEMY_UNIT_KILLED_BY_ME: float = 0.05
OWN_UNIT_LOST_TO_ENEMY: float = -0.04  # smaller magnitude: an even trade nets +0.01
TECH_RESEARCHED_BY_ME: float = 0.05
CULTURE_RESEARCHED_BY_ME: float = 0.05
OPPONENT_ELIMINATED: float = 0.50

# ── Per-step shaping (anti-stall + slow-game ramp) ───────────────────
# The symmetric score delta supplies a dense intra-turn gradient. Every
# step also costs STEP_PENALTY_BASE. Once the game passes
# SLOW_PENALTY_START turns, an extra penalty grows linearly, reaching
# SLOW_PENALTY_PEAK per step at turn SLOW_PENALTY_START + SLOW_PENALTY_SPAN.
SCORE_DELTA_SCALE: float = 1e-3
STEP_PENALTY_BASE: float = 5e-4
SLOW_PENALTY_PEAK: float = 1e-3
SLOW_PENALTY_START: int = 500
SLOW_PENALTY_SPAN: int = 500  # ramp saturates at turn 1000
|
|
|
|
|
def _step_penalty(turn: int) -> float:
    """Return the per-step penalty magnitude for `turn`.

    The value is positive; callers subtract it from the reward. Up to and
    including turn SLOW_PENALTY_START it is the flat STEP_PENALTY_BASE.
    After that, SLOW_PENALTY_PEAK is phased in linearly over
    SLOW_PENALTY_SPAN turns and then held at its full value.
    """
    overtime = turn - SLOW_PENALTY_START
    if overtime <= 0:
        return STEP_PENALTY_BASE
    # Fraction of the slow-game ramp applied so far, saturating at 1.
    fraction = min(1.0, overtime / SLOW_PENALTY_SPAN)
    return STEP_PENALTY_BASE + SLOW_PENALTY_PEAK * fraction
|
|
|
|
|
|
# Hard upper bound on env.step() calls in one episode. Without it, a
# policy that has learned "ending the turn lowers my reward" produces
# episodes of unbounded length (one eval episode was observed making
# 1.3M harness round-trips). A budget on the whole episode catches that
# without skewing behavior inside a turn. Late-game players with hundreds
# of units legitimately issue hundreds of micro-actions per turn, so a
# per-turn cap would get in the way of normal play.
DEFAULT_MAX_STEPS_PER_EPISODE: int = 250_000

# Default game-turn limit; an unresolved episode is truncated at this turn.
DEFAULT_MAX_TURNS: int = 1000
|
|
|
|
|
|
class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
    """Single-learner Gym wrapper: our policy controls slot 0.

    The opponent on slot 1..N-1 is one of:

    * the harness's built-in MCTS (default — `opponent=None`), driven
      internally by the simulator's `apply_end_turn` AI loop; or
    * a frozen `ModelOpponent` (self-play curriculum), driven in-process
      over the multi-slot wire. When a model opponent is supplied, both
      slot 0 and the opponent slots are *externally* controlled. The
      harness's internal AI loop therefore skips them (see dispatch.rs
      Stage 4), and this env advances the opponent's turn after the
      learner's `end_turn`.

    Each `step()` applies exactly one PlayerAction. An episode terminates
    on a decisive game event (see `_check_termination`) or on any
    `HarnessError`, which is scored as a loss. An unresolved episode is
    truncated, with a cap penalty, once `max_turns` or
    `max_steps_per_episode` is reached.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        harness_config: HarnessConfig | None = None,
        max_turns: int = DEFAULT_MAX_TURNS,
        max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE,
        opponent: ModelOpponent | None = None,
    ) -> None:
        """Store configuration; the harness client itself is created in reset().

        Args:
            harness_config: harness configuration; `HarnessConfig()` when None.
            max_turns: game turn at which an unresolved episode is truncated.
            max_steps_per_episode: number of `step()` calls after which an
                unresolved episode is truncated.
            opponent: optional frozen model opponent (self-play). None means
                the harness's built-in AI plays the other slot(s).
        """
        super().__init__()
        self._config = harness_config or HarnessConfig()
        self._max_turns = max_turns
        self._max_steps_per_episode = max_steps_per_episode
        self._opponent = opponent
        self._my_slot = self._config.player_slot
        # When a model opponent drives the other slot(s), every wire call
        # MUST name its slot (multi-slot adapter contract). With the
        # default MCTS opponent we keep the legacy single-slot wire shape
        # (slot omitted) so nothing about the shipping path changes.
        self._multi_slot = opponent is not None
        self._slot_kw = self._my_slot if self._multi_slot else None
        self.observation_space = spaces.Box(
            low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(ACTION_DIM)
        self._client: HarnessClient | None = None
        self._last_view: dict[str, Any] = {}
        self._last_score: float = 0.0
        self._idx_to_action: dict[int, dict[str, Any]] = {}
        self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
        self._terminated: bool = False
        self._step_count: int = 0
        # Maps PlayerId → city_id of that player's capital. Populated in
        # reset() and updated when CityFounded arrives for a player that
        # previously had no city. CityCaptured events look up old_owner
        # here to decide whether to use the CAPITAL_* or CITY_* reward.
        self._capital_by_player: dict[int, str] = {}
        # Throttle for CITY_FOUNDED_BY_ME — settler-spam protection.
        self._city_founded_rewards_issued: int = 0

    # ── Gymnasium API ────────────────────────────────────────────────

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a fresh game and return the initial observation.

        Any previous harness client is shut down, and a new
        `HarnessClient` is created for every episode. When given, `seed`
        seeds both Gymnasium's `np_random` and the harness game config.
        `options` is accepted for API compatibility and is not used.

        Returns:
            `(observation, info)`, where `info["action_mask"]` is the
            boolean legal-action mask.
        """
        # Gymnasium contract: seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        if self._client is not None:
            self._client.shutdown()
            # Drop the reference immediately. If constructing the new
            # client below fails, step() then reports "called before
            # reset()" instead of talking to a shut-down client.
            self._client = None
        cfg = self._config
        # When self-playing against a model opponent, both the learner's
        # slot and the opponent's slot(s) are externally driven. Declare
        # them so the simulator's AI loop skips them.
        if self._multi_slot and self._opponent is not None:
            cfg = replace(cfg, player_slots=(self._my_slot, *self._opponent.slots))
        # `replace` preserves every other field (player_slots, victory_mode,
        # player_controllers, …). The old field-by-field rebuild silently
        # dropped them, which would have un-declared the external slots.
        if seed is not None:
            cfg = replace(cfg, seed=seed)
        self._client = HarnessClient(cfg)
        self._terminated = False
        self._step_count = 0
        self._capital_by_player = {}
        self._city_founded_rewards_issued = 0
        view = self._client.view(slot=self._slot_kw)
        # Seed capitals from any cities present at game start. In duel
        # maps each player begins with a founder, so the capital map is
        # normally populated by the first CityFounded event per player
        # (handled in _apply_event_rewards). If the simulator ever
        # pre-places cities, this scan picks them up.
        for city in view.get("cities", []):
            owner = self._to_int(city.get("owner"))
            cid = str(city.get("id", ""))
            if owner >= 0 and cid and owner not in self._capital_by_player:
                self._capital_by_player[owner] = cid
        self._sync_state(view)
        return encode_observation(view), {"action_mask": self._cur_mask.copy()}

    def step(
        self, action: np.int64 | int
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply one PlayerAction and return the Gymnasium 5-tuple.

        Args:
            action: index into the discrete action space. An index that is
                out of range or masked out is replaced by index 0 (end_turn).

        Returns:
            `(obs, reward, terminated, truncated, info)`.
            `info["action_mask"]` is the legal-action mask for the next
            step. `info["reason"]` is set whenever the episode ends.

        Raises:
            RuntimeError: if called before `reset()` or after the episode
                has ended.
        """
        if self._client is None:
            raise RuntimeError("step() called before reset()")
        if self._terminated:
            raise RuntimeError("step() called on terminated env; call reset()")

        idx = int(action)
        # The mask should prevent this, but be defensive and substitute
        # end_turn. The range check keeps a negative index from wrapping
        # around the mask, and keeps an oversized index from raising
        # IndexError.
        if not (0 <= idx < len(self._cur_mask)) or not self._cur_mask[idx]:
            idx = 0
        player_action = decode_action_index(idx, self._idx_to_action)
        self._step_count += 1

        prev_turn = int(self._last_view.get("turn", 0))
        reward = -_step_penalty(prev_turn)
        opp_events: list[dict[str, Any]] = []
        try:
            if player_action.get("type") == "end_turn":
                response = self._client.end_turn(slot=self._slot_kw)
                # With a frozen model opponent, the simulator's AI loop
                # skips the opponent slot (it is externally declared), so
                # we drive its full turn here. With the default MCTS
                # opponent this is a no-op: the AI loop already ran inside
                # the end_turn dispatch and its events are in `response`.
                if self._opponent is not None:
                    opp_events = self._opponent.play_turn(self._client)
            else:
                response = self._client.act(player_action, slot=self._slot_kw)
            # These calls also talk to the harness. Keep them inside the
            # try so that a subprocess dying here ends the episode instead
            # of crashing the training loop.
            view = self._client.view(slot=self._slot_kw)
            notifications = self._client.drain_notifications()
        except HarnessError:
            # Treat any harness failure as a loss (bad action, dead
            # subprocess, etc.) and terminate the episode.
            return self._harness_error_result()

        # Collect synchronous events from the act response, the opponent's
        # turn, and any async notifications buffered while waiting for the
        # view response. Terminal events (game_over / player_eliminated)
        # may have fired during the opponent's turn between our act and
        # the view.
        recent_events: list[dict[str, Any]] = list(response.get("events", []))
        recent_events.extend(opp_events)
        recent_events.extend(notifications)

        new_turn = int(view.get("turn", 0))
        me = int(view.get("player", 0))
        new_score = float(view.get("score", {}).get("score_estimate", 0.0))
        # Symmetric score delta: gains and losses both count.
        reward += SCORE_DELTA_SCALE * (new_score - self._last_score)
        # Event-driven shaping (Phase 1 catalog). This must run before
        # _sync_state so unit-owner lookups still see the pre-step view.
        reward += self._apply_event_rewards(recent_events, me)

        terminated, terminal_reward, reason = self._check_termination(view, recent_events)
        if terminated and reason == "won":
            # Decisive bonus: decays linearly to 0 at DECISIVE_DECAY_TURNS.
            decay = max(0.0, 1.0 - new_turn / DECISIVE_DECAY_TURNS)
            terminal_reward += DECISIVE_BONUS_MAX * decay
        reward += terminal_reward
        self._sync_state(view)
        self._terminated = terminated

        step_capped = (
            not terminated and self._step_count >= self._max_steps_per_episode
        )
        turn_capped = not terminated and new_turn >= self._max_turns
        truncated = step_capped or turn_capped
        if truncated:
            self._terminated = True
            # Stalling without resolving the game is worse than losing
            # decisively, so apply a cap penalty to make the policy
            # commit. Without it, "drag to the cap" was the equilibrium.
            reward += TURN_CAP_PENALTY if turn_capped else STEP_CAP_PENALTY

        info: dict[str, Any] = {
            "action_mask": self._cur_mask.copy(),
            "turn": new_turn,
            "score": new_score,
            "city_count": int(view.get("score", {}).get("city_count", 0)),
        }
        if self._opponent is not None:
            # Diagnostic: how many wire events the frozen opponent's turn
            # produced this step. Zero across a whole episode means the
            # opponent never actually acted (e.g. a stale binary that does
            # not skip the external slot). The smoke test asserts this is
            # > 0.
            info["opp_events"] = len(opp_events)
        if reason:
            info["reason"] = reason
        elif step_capped:
            info["reason"] = "step_cap"
            print(
                f"[MagicCivEnv] step_cap hit at step={self._step_count} "
                f"turn={new_turn} — truncating episode",
                file=sys.stderr,
                flush=True,
            )
        elif turn_capped:
            info["reason"] = "turn_cap"
        return encode_observation(view), reward, terminated, truncated, info

    def close(self) -> None:
        """Shut down the harness client, if one is running."""
        if self._client is not None:
            self._client.shutdown()
            self._client = None

    # ── Internals ────────────────────────────────────────────────────

    @staticmethod
    def _to_int(value: Any, default: int = -1) -> int:
        """Convert a wire id to int, mapping a missing or JSON-null value to `default`.

        Event and view fields are read with `.get()`. A key that is
        present with a null value would otherwise reach `int(None)` and
        raise TypeError.
        """
        return default if value is None else int(value)

    def _harness_error_result(
        self,
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Terminate the episode after a HarnessError, scoring it as a loss."""
        self._terminated = True
        return (
            np.zeros(OBS_DIM, dtype=np.float32),
            HARNESS_ERROR_REWARD,
            True,
            False,
            {"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"},
        )

    def _sync_state(self, view: dict[str, Any]) -> None:
        """Cache `view`, its score estimate, and the legal-action encoding."""
        self._last_view = view
        self._last_score = float(view.get("score", {}).get("score_estimate", 0.0))
        mask, idx_to_action = encode_legal_actions(view)
        self._cur_mask = mask
        self._idx_to_action = idx_to_action

    def _check_termination(
        self, view: dict[str, Any], recent_events: list[dict[str, Any]]
    ) -> tuple[bool, float, str | None]:
        """Decide whether the episode ended this step.

        Termination signals, in priority order:
          1. `game_over` event from the simulator: winner == us → win;
             any other winner (including a missing or null winner) → loss.
          2. `player_eliminated` event for us → loss.
          3. `player_eliminated` event for our last opponent → win.
             Events without a valid (non-negative) player id are ignored.
          4. Defensive fallback: cities == 0 and no founder → loss (in
             case the simulator's elimination wiring lags one step).

        Returns:
            `(terminated, terminal_reward, reason)`, where reason is
            "won", "eliminated", or None.
        """
        me = int(view.get("player", 0))
        eliminated_players: set[int] = set()
        for ev in recent_events:
            kind = ev.get("type")
            if kind == "game_over":
                # NOTE(review): a game_over without a winner is scored as a
                # loss, and DRAW_REWARD is currently unused. Confirm this
                # is intended.
                if self._to_int(ev.get("winner")) == me:
                    return True, WIN_BASE, "won"
                return True, LOSS_REWARD, "eliminated"
            if kind == "player_eliminated":
                player = self._to_int(ev.get("player"))
                # A missing id must not count as "some opponent was
                # eliminated"; that would hand out a spurious win below.
                if player >= 0:
                    eliminated_players.add(player)
        if me in eliminated_players:
            return True, LOSS_REWARD, "eliminated"
        if eliminated_players:
            # Any opponent elimination. Duel maps have only one opponent,
            # so this is decisive. Multi-player maps would need to track
            # the remaining-player set, but Game 1 is 1v1 by design.
            return True, WIN_BASE, "won"
        # Defensive fallback for the case where the simulator drops the
        # game_over event (observed in early integration tests).
        score = view.get("score", {})
        if int(score.get("city_count", 0)) == 0:
            has_founder = any(
                self._to_int(u.get("owner")) == me
                and "founder" in str(u.get("type", ""))
                and float(u.get("hp", 0)) > 0
                for u in view.get("units", [])
            )
            if not has_founder:
                return True, LOSS_REWARD, "eliminated"
        return False, 0.0, None

    def action_masks(self) -> np.ndarray:
        """sb3-contrib MaskablePPO hook — returns the current mask."""
        return self._cur_mask.copy()

    def _apply_event_rewards(
        self, events: list[dict[str, Any]], me: int
    ) -> float:
        """Sum the Phase 1 event-driven shaping rewards for `events`.

        Sources its data from already-emitted wire events
        (`mc-player-api/src/wire.rs:135-301`). Terminal outcomes
        (`game_over`, `player_eliminated`) are decided in
        `_check_termination`. This method only adds the
        OPPONENT_ELIMINATED shaping bonus for them.

        Side effects: updates the capital map and the city-founded reward
        counter.
        """
        total = 0.0
        for ev in events:
            kind = ev.get("type")
            if kind == "city_founded":
                owner = self._to_int(ev.get("owner"))
                cid = str(ev.get("city_id", ""))
                # Track capitals: the first city per player is their capital.
                if owner >= 0 and cid and owner not in self._capital_by_player:
                    self._capital_by_player[owner] = cid
                if owner == me:
                    if self._city_founded_rewards_issued < MAX_CITY_FOUNDED_REWARDS:
                        total += CITY_FOUNDED_BY_ME
                        self._city_founded_rewards_issued += 1
            elif kind == "city_captured":
                old_owner = self._to_int(ev.get("old_owner"))
                new_owner = self._to_int(ev.get("new_owner"))
                cid = str(ev.get("city_id", ""))
                is_capital = (
                    old_owner >= 0
                    and self._capital_by_player.get(old_owner) == cid
                )
                if new_owner == me:
                    total += CAPITAL_CAPTURED_BY_ME if is_capital else CITY_CAPTURED_BY_ME
                elif old_owner == me:
                    total += CAPITAL_LOST_BY_ME if is_capital else CITY_LOST_BY_ME
                # When a capital changes hands, the *capturer's* first
                # city is still their own capital — don't reassign.
            elif kind == "wonder_built":
                if self._to_int(ev.get("player")) == me:
                    total += WONDER_BUILT_BY_ME
            elif kind == "combat_resolved":
                # Attribution: the wire event carries unit ids, not owners.
                # Kills are credited via the unit_destroyed events that
                # should accompany defender_killed/attacker_killed. Skip
                # here to avoid double-counting.
                pass
            elif kind == "unit_destroyed":
                # Owner attribution comes from the last synced view, which
                # still contains the unit from before its destruction.
                uid = str(ev.get("unit_id", ""))
                owner = self._unit_owner_lookup(uid)
                if owner == me:
                    total += OWN_UNIT_LOST_TO_ENEMY
                elif owner >= 0:
                    # Enemy unit destroyed. We get kill credit *only* if
                    # the killer_unit_id is a unit we own, OR there is no
                    # killer info at all. The latter is treated as our
                    # kill; this is conservative because the asymmetric
                    # -0.04/+0.05 is net-positive on even trades.
                    killer = ev.get("killer_unit_id")
                    if killer is None or self._unit_owner_lookup(str(killer)) == me:
                        total += ENEMY_UNIT_KILLED_BY_ME
            elif kind == "tech_researched":
                if self._to_int(ev.get("player")) == me:
                    total += TECH_RESEARCHED_BY_ME
            elif kind == "culture_researched":
                if self._to_int(ev.get("player")) == me:
                    total += CULTURE_RESEARCHED_BY_ME
            elif kind == "player_eliminated":
                p = self._to_int(ev.get("player"))
                if p != me and p >= 0:
                    total += OPPONENT_ELIMINATED
        return total

    def _unit_owner_lookup(self, unit_id: str) -> int:
        """Resolve a unit_id → owner from the last synced PlayerView.

        Returns -1 when `unit_id` is empty or the unit is not present in
        that view. Callers treat -1 as an unknown owner.
        """
        if not unit_id:
            return -1
        for u in self._last_view.get("units", []):
            if str(u.get("id", "")) == unit_id:
                return self._to_int(u.get("owner"))
        return -1
|