magicciv/tooling/rl_self_play/magic_civ_env.py
Natalie e1f3a66a67
Some checks failed
ci / regression gate (push) Failing after 54s
tune(rl): drop SCORE_DELTA_SCALE 1e-3 -> 1e-4 for the unified raw score
score_estimate is now the unbounded unified score (~10-20x the old clamped [0,1000] magnitude);
scale the per-turn score-delta reward down to keep it in range with the other reward terms.
An empirical retune is planned for when the self-play stable resumes training.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 20:40:48 -04:00

572 lines
26 KiB
Python

"""Gymnasium environment wrapping the Magic Civilization player-API harness.
One Gym `step()` corresponds to one PlayerAction. When the policy's
chosen action does not advance the turn (i.e. is anything except
`end_turn`), we keep collecting actions inside the same Gym step's
trajectory until the policy emits `end_turn` or the per-turn-action
budget is exhausted. This mirrors how the built-in AI takes "a turn":
many micro-actions then an `end_turn` boundary.
The opponent is whatever AI the harness ships with for the non-Claude
slots — that's our frozen baseline. As the policy trains, we measure
its win rate against this baseline; the policy is considered to have
"beat the built-in AI" when it crosses a configurable threshold
(default 55%).
"""
from __future__ import annotations
import json
import os
import random
import sys
import time
from dataclasses import replace
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from .encoders import (
ACTION_DIM,
OBS_DIM,
decode_action_index,
encode_legal_actions,
encode_observation,
)
from .harness_client import HarnessClient, HarnessConfig, HarnessError
from .opponent import ModelOpponent
# ── Reward catalog ───────────────────────────────────────────────────
# Stage 6.1.5 redesign (2026-05-18). The previous shape paid
# TURN_ADVANCE_BONUS = 1e-2 per turn, which over a 200-turn cap sums to
# +2.0 — more than the +1.0 terminal win — so the policy correctly learned
# to stall (60% turn-cap rate at the 6.1 eval gate). This catalog removes
# that bonus, sharpens the terminals, adds event-attributed shaping from
# wire events the simulator already emits, and adds a decisive-win bonus
# plus a slow-game ramp to push the policy toward fast, decisive play.
#
# Budget analysis: Stage 6.1.5 of
# `~/.claude/plans/in-the-game-civilization-elegant-popcorn.md`.
# Calibrated against Civ5 norms: Standard speed = 500 turns; median
# domination win ≈ 400 turns.

# Terminal rewards (sparse, decisive).
WIN_BASE: float = 2.0
LOSS_REWARD: float = -2.0
DRAW_REWARD: float = 0.0
TURN_CAP_PENALTY: float = -0.5
STEP_CAP_PENALTY: float = -0.5
HARNESS_ERROR_REWARD: float = -2.0

# Decisive-win bonus: added on top of WIN_BASE. It shrinks linearly from
# DECISIVE_BONUS_MAX toward nothing at turn DECISIVE_DECAY_TURNS, so faster
# wins are worth more.
DECISIVE_BONUS_MAX: float = 2.0
DECISIVE_DECAY_TURNS: int = 500

# Event-driven shaping. Each term is paid off a wire event defined in
# `mc-player-api/src/wire.rs:135-301`; step() already gathers those events
# from response["events"] plus self._client.drain_notifications().
CAPITAL_CAPTURED_BY_ME: float = 1.0
CAPITAL_LOST_BY_ME: float = -1.0
CITY_CAPTURED_BY_ME: float = 0.30
CITY_LOST_BY_ME: float = -0.30
CITY_FOUNDED_BY_ME: float = 0.15  # paid at most MAX_CITY_FOUNDED_REWARDS times
MAX_CITY_FOUNDED_REWARDS: int = 6
WONDER_BUILT_BY_ME: float = 0.30
ENEMY_UNIT_KILLED_BY_ME: float = 0.05
OWN_UNIT_LOST_TO_ENEMY: float = -0.04  # deliberately asymmetric: an even trade nets +0.01
TECH_RESEARCHED_BY_ME: float = 0.05
CULTURE_RESEARCHED_BY_ME: float = 0.05
OPPONENT_ELIMINATED: float = 0.50

# Per-step terms. The symmetric score delta (gains and losses both count)
# keeps a dense intra-turn gradient. The anti-stall penalty is a flat
# STEP_PENALTY_BASE per step. A slow-game ramp starts after turn
# SLOW_PENALTY_START and grows to an extra SLOW_PENALTY_PEAK per step once
# SLOW_PENALTY_SPAN further turns have passed (turn 1000 with these values).
#
# NOTE: score_estimate is now the UNIFIED raw score (mc-score
# ScoreController, unbounded), roughly 10-20x the old clamped [0,1000]
# scale. SCORE_DELTA_SCALE was therefore dropped from 1e-3 to 1e-4 to keep
# the per-turn score reward in the same range as the other terms. Retune
# empirically once the self-play stable resumes training on the unified
# objective.
SCORE_DELTA_SCALE: float = 1e-4
STEP_PENALTY_BASE: float = 5e-4
SLOW_PENALTY_PEAK: float = 1e-3
SLOW_PENALTY_START: int = 500
SLOW_PENALTY_SPAN: int = 500
def _step_penalty(turn: int) -> float:
    """Return the per-step penalty magnitude for a step taken on `turn`.

    The value is positive; callers subtract it from the reward. Up to and
    including SLOW_PENALTY_START it is the flat STEP_PENALTY_BASE. Past that
    turn a slow-game term grows linearly with the overrun and saturates at
    SLOW_PENALTY_PEAK once the overrun reaches SLOW_PENALTY_SPAN turns.
    """
    overrun = turn - SLOW_PENALTY_START
    if overrun > 0:
        # Fraction of the ramp completed, clamped so the penalty plateaus.
        fraction = min(1.0, overrun / SLOW_PENALTY_SPAN)
        return STEP_PENALTY_BASE + SLOW_PENALTY_PEAK * fraction
    return STEP_PENALTY_BASE
# Episode-length guards.
#
# DEFAULT_MAX_STEPS_PER_EPISODE caps the total number of env.step() calls in
# one episode. Without it, a policy that learned "ending the turn lowers my
# reward" could produce unbounded episodes (one eval episode was observed
# making 1.3M harness round-trips). The budget is per episode, not per turn,
# on purpose. Late-game players with hundreds of units legitimately issue
# hundreds of micro-actions per turn, and a per-turn cap would distort that
# normal play.
DEFAULT_MAX_STEPS_PER_EPISODE = 250_000
DEFAULT_MAX_TURNS = 1000

# Repository root (two levels above the directory containing this module)
# and the repo-relative path of the per-clan reward-overlay JSON.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
_OVERLAYS_REL = "public/games/age-of-dwarves/data/ai/reward_overlays.json"
def _load_reward_overlays() -> dict[str, dict[str, float]]:
    """Load per-clan reward-shaping overlays as ``{clan: {group: multiplier}}``.

    The path is ``$MC_REWARD_OVERLAYS`` when set, otherwise the repo-relative
    default. Behavior:

    * Missing file: returns ``{}`` silently, so every clan trains on the
      neutral catalog.
    * Unreadable, undecodable or structurally malformed file (top level not
      an object, ``"overlays"`` not an object): returns ``{}`` and warns on
      stderr.
    * Individual entries that are unusable (clan value not an object,
      non-numeric multiplier): dropped with a warning, so they cannot crash
      the reward arithmetic later.

    Multipliers are returned as floats.
    """
    path = os.environ.get("MC_REWARD_OVERLAYS") or os.path.join(_REPO_ROOT, _OVERLAYS_REL)

    def _warn(detail: str) -> None:
        print(
            f"[MagicCivEnv] reward overlays {path!r}: {detail}",
            file=sys.stderr,
            flush=True,
        )

    try:
        with open(path, encoding="utf-8") as fh:
            doc = json.load(fh)
    except FileNotFoundError:
        # Absent file is the documented "no overlays" configuration.
        return {}
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        _warn(f"unreadable ({err}); using neutral catalog")
        return {}

    overlays = doc.get("overlays", {}) if isinstance(doc, dict) else None
    if not isinstance(overlays, dict):
        _warn("malformed (expected an object with an 'overlays' object); using neutral catalog")
        return {}

    cleaned: dict[str, dict[str, float]] = {}
    skipped = 0
    for clan, groups in overlays.items():
        if not isinstance(groups, dict):
            skipped += 1
            continue
        multipliers: dict[str, float] = {}
        for group, mult in groups.items():
            # bool is an int subclass; reject it explicitly.
            if isinstance(mult, (int, float)) and not isinstance(mult, bool):
                multipliers[group] = float(mult)
            else:
                skipped += 1
        cleaned[clan] = multipliers
    if skipped:
        _warn(f"ignored {skipped} malformed entr{'y' if skipped == 1 else 'ies'}")
    return cleaned
class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
    """Single-learner Gym wrapper: our policy controls one harness slot.

    The learner's slot is `HarnessConfig.player_slot` (historically slot 0).
    The opponent on the remaining slot(s) is one of:
      * the harness's built-in MCTS (default — `opponent=None`), driven
        internally by the simulator's `apply_end_turn` AI loop; or
      * a frozen `ModelOpponent` (self-play curriculum), driven in-process
        over the multi-slot wire. When a model opponent is supplied, both
        the learner's slot and the opponent slots are *externally*
        controlled, so the harness's internal AI loop skips them (see
        dispatch.rs Stage 4) and this env advances the opponent's turn
        after the learner's `end_turn`.

    Observations are `encode_observation(view)` vectors of length OBS_DIM.
    Actions are indices into `Discrete(ACTION_DIM)`; the legal subset is
    exposed via `info["action_mask"]` and `action_masks()`.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        harness_config: HarnessConfig | None = None,
        max_turns: int = DEFAULT_MAX_TURNS,
        max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE,
        opponent: ModelOpponent | None = None,
        clan_list: tuple[str, ...] = (),
    ) -> None:
        """Build the env. The harness is not spawned until `reset()`.

        Args:
            harness_config: Harness launch settings; defaults to `HarnessConfig()`.
            max_turns: The episode is truncated once the game turn reaches this.
            max_steps_per_episode: The episode is truncated once this many
                `step()` calls have been made.
            opponent: Frozen model opponent for self-play, or None for the
                harness's built-in AI.
            clan_list: Clans to sample one-per-episode for clan-conditioned
                training; empty trains a clan-less generalist.
        """
        super().__init__()
        self._config = harness_config or HarnessConfig()
        self._max_turns = max_turns
        self._max_steps_per_episode = max_steps_per_episode
        self._opponent = opponent
        self._my_slot = self._config.player_slot
        # When a model opponent drives the other slot(s), every wire call
        # MUST name its slot (multi-slot adapter contract). With the
        # default MCTS opponent we keep the legacy single-slot wire shape
        # (slot omitted) so nothing about the shipping path changes.
        self._multi_slot = opponent is not None
        self._slot_kw = self._my_slot if self._multi_slot else None

        self.observation_space = spaces.Box(
            low=-1e6, high=1e6, shape=(OBS_DIM,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(ACTION_DIM)

        self._client: HarnessClient | None = None
        self._last_view: dict[str, Any] = {}
        self._last_score: float = 0.0
        self._idx_to_action: dict[int, dict[str, Any]] = {}
        self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
        self._terminated: bool = False
        self._step_count: int = 0

        # Maps PlayerId → city_id of that player's capital. Populated in
        # reset() and updated when CityFounded for a player previously
        # without a city. CityCaptured events look up old_owner here to
        # decide whether to route to CAPITAL_* or CITY_* reward buckets.
        self._capital_by_player: dict[int, str] = {}
        # Throttle for CITY_FOUNDED_BY_ME — settler-spam protection.
        self._city_founded_rewards_issued: int = 0
        # Players still in the game (not yet eliminated). Initialised in
        # reset() to every slot 0..players-1; `player_eliminated` events
        # prune it. A learner win in a multi-player (3-5p) game requires the
        # learner alive AND every opponent eliminated — not just *any*
        # opponent elimination (the old duel-only 1v1 shortcut). The
        # authoritative `game_over` event still takes priority when present.
        self._live_players: set[int] = set()

        # Clan-conditioned RL. Each episode the env samples a clan from
        # `clan_list`, stamps it on the learner slot (CP_LEARNER_CLAN → the obs
        # clan one-hot) and scales the SHAPING rewards by that clan's overlay
        # (group -> multiplier). Terminal win/loss/decisive stay universal so
        # every clan equally wants to win. Empty list = generalist (no clan,
        # neutral catalog). Seeded RNG → reproducible clan sequence per run.
        self._clan_list: tuple[str, ...] = tuple(clan_list)
        self._overlays: dict[str, dict[str, float]] = _load_reward_overlays()
        self._clan_rng = random.Random(self._config.seed)
        self._cur_clan: str = ""
        self._cur_overlay: dict[str, float] = {}

    def _ov(self, group: str) -> float:
        """Reward-shaping multiplier for the current episode's clan (1.0 if
        generalist / unknown group)."""
        return self._cur_overlay.get(group, 1.0)

    @staticmethod
    def _as_int(value: Any, default: int = -1) -> int:
        """Coerce a wire field to int, mapping a missing / JSON-null value
        to `default` (plain `int(None)` would raise TypeError)."""
        return default if value is None else int(value)

    def _shutdown_client(self) -> None:
        """Best-effort teardown of the current harness client.

        Always clears `self._client`, even when `shutdown()` raises. A client
        whose subprocess already died must not block a reset, retry or
        close, so a failing shutdown is logged and otherwise ignored.
        """
        client, self._client = self._client, None
        if client is None:
            return
        try:
            client.shutdown()
        except Exception as e:  # teardown must never abort reset/retry/close
            print(
                f"[MagicCivEnv] harness shutdown failed ({e}); ignoring",
                file=sys.stderr,
                flush=True,
            )

    # ── Gymnasium API ────────────────────────────────────────────────

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new game and return `(observation, {"action_mask": ...})`.

        Steps:
          1. Tear down any previous harness.
          2. Build this episode's HarnessConfig: external slots for
             self-play, optional seed override, and the sampled clan.
          3. Reset per-episode bookkeeping.
          4. Spawn a fresh harness, retrying transient boot failures.

        `options` is accepted for Gymnasium API compatibility and ignored.

        Raises:
            HarnessError: if the harness fails to boot on every retry.
        """
        # Gymnasium contract: seeds `self.np_random` when a seed is given.
        super().reset(seed=seed)
        # Best-effort teardown of the previous episode's harness. After a
        # `harness_error` termination the old subprocess may already be
        # dead; its failing shutdown must not abort this reset. Clearing
        # `_client` also means a failed respawn below cannot leave a stale,
        # already-shut-down client behind.
        self._shutdown_client()

        # `replace` preserves every other field (player_slots, victory_mode,
        # player_controllers, …) — the old field-by-field rebuild silently
        # dropped them, which would have un-declared the external slots.
        cfg = self._config
        # When self-playing against a model opponent, both the learner's
        # slot and the opponent's slot(s) are externally driven — declare
        # them so the simulator's AI loop skips them.
        if self._multi_slot and self._opponent is not None:
            cfg = replace(cfg, player_slots=(self._my_slot, *self._opponent.slots))
        if seed is not None:
            cfg = replace(cfg, seed=seed)

        # Clan-conditioned RL: sample this episode's clan, stamp it on the
        # learner slot (CP_LEARNER_CLAN), and select its reward overlay.
        if self._clan_list:
            self._cur_clan = self._clan_rng.choice(self._clan_list)
            cfg = replace(cfg, learner_clan=self._cur_clan)
            self._cur_overlay = self._overlays.get(self._cur_clan, {})
        else:
            self._cur_clan = ""
            self._cur_overlay = {}

        self._terminated = False
        self._step_count = 0
        self._capital_by_player = {}
        self._city_founded_rewards_issued = 0
        # Every configured slot starts alive. `cfg.players` is the total slot
        # count (learner + opponents); eliminations prune this set.
        self._live_players = set(range(int(cfg.players)))

        # Bounded retry on the harness spawn + first view. Under heavy
        # concurrent load (16+ Godot workers in heavy-tests.slice with
        # CPUWeight=20, plus other jobs on the box), a freshly-spawned Godot
        # can lose the boot race and EOF on the first wire request — which,
        # un-retried, aborts a multi-hour training run from a single transient
        # worker death (observed: gen0 died at the first eval, 9 min in).
        # A SYSTEMATIC failure (bad build, missing data) still surfaces: it
        # exhausts the retries and re-raises, and MC_HARNESS_STDERR_DIR
        # captures the Godot-side reason.
        view = self._spawn_with_retry(cfg)

        # Seed capitals from any cities present at game start. In duel
        # maps each player begins with a founder, so the capital map is
        # populated on the first CityFounded event per player (handled
        # in _apply_event_rewards). If the simulator ever pre-places
        # cities, this scan picks them up.
        for city in view.get("cities", []):
            owner = self._as_int(city.get("owner"))
            cid = str(city.get("id", ""))
            if owner >= 0 and cid and owner not in self._capital_by_player:
                self._capital_by_player[owner] = cid

        self._sync_state(view)
        return encode_observation(view), {"action_mask": self._cur_mask.copy()}

    def _spawn_with_retry(
        self, cfg: HarnessConfig, attempts: int = 3
    ) -> dict[str, Any]:
        """Spawn the harness and return its first view, retrying HarnessError.

        Each failed attempt reaps the half-started client before respawning.
        It then backs off (1s, 2s, …) so a competing worker can free
        resources between tries.

        Raises:
            ValueError: if `attempts` < 1.
            HarnessError: the last failure, once every attempt has failed.
        """
        if attempts < 1:
            raise ValueError(f"attempts must be >= 1, got {attempts}")
        last_err: HarnessError | None = None
        for attempt in range(attempts):
            try:
                self._client = HarnessClient(cfg)
                return self._client.view(slot=self._slot_kw)
            except HarnessError as e:
                last_err = e
                # Reap the half-dead client so we don't leak a scope and make
                # contention worse, then back off before respawn.
                self._shutdown_client()
                if attempt < attempts - 1:
                    print(
                        f"[MagicCivEnv] harness spawn attempt {attempt + 1}/"
                        f"{attempts} failed ({e}); reaped + retrying",
                        file=sys.stderr, flush=True,
                    )
                    time.sleep(1.0 * (attempt + 1))
        # Every failed attempt sets last_err and attempts >= 1, so this only
        # narrows the type for checkers.
        assert last_err is not None
        raise last_err

    def step(
        self, action: np.int64 | int
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply one learner action and return the Gymnasium 5-tuple.

        The reward is built from these terms:
          * minus the per-step anti-stall penalty for the pre-action turn;
          * plus the score delta × SCORE_DELTA_SCALE × the "economy" overlay;
          * plus the event-driven shaping terms;
          * plus the terminal reward (with the decaying decisive bonus on a
            win);
          * plus a cap penalty when the step/turn cap truncates the episode.

        Any HarnessError while acting, playing the opponent's turn or reading
        back the view ends the episode with HARNESS_ERROR_REWARD and
        `info["reason"] == "harness_error"`.

        Raises:
            RuntimeError: if called before `reset()` or after the episode ended.
        """
        if self._client is None:
            raise RuntimeError("step() called before reset()")
        if self._terminated:
            raise RuntimeError("step() called on terminated env; call reset()")

        idx = int(action)
        # The mask should prevent this, but be defensive: substitute end_turn
        # (index 0). The range check stops a negative index from silently
        # wrapping around to the tail of the mask.
        if not (0 <= idx < len(self._cur_mask)) or not self._cur_mask[idx]:
            idx = 0
        player_action = decode_action_index(idx, self._idx_to_action)
        self._step_count += 1

        prev_turn = int(self._last_view.get("turn", 0))
        reward = -_step_penalty(prev_turn)

        opp_events: list[dict[str, Any]] = []
        try:
            if player_action.get("type") == "end_turn":
                response = self._client.end_turn(slot=self._slot_kw)
                # With a frozen model opponent, the simulator's AI loop
                # skips the opponent slot (it is externally declared) — so
                # we drive its full turn here. With the default MCTS
                # opponent this is a no-op: the AI loop already ran inside
                # the end_turn dispatch and its events are in `response`.
                if self._opponent is not None:
                    opp_events = self._opponent.play_turn(self._client)
            else:
                response = self._client.act(player_action, slot=self._slot_kw)
            # The read-back talks to the same harness. A subprocess that dies
            # here is the same failure as one that dies mid-action, and must
            # not escape as an unhandled exception.
            view = self._client.view(slot=self._slot_kw)
            notifications = self._client.drain_notifications()
        except HarnessError:
            # Treat any harness failure as a loss — bad action, dead
            # subprocess, etc. Terminate the episode.
            return self._harness_error_result()

        # Collect synchronous events from the act response + the opponent's
        # turn + any async notifications buffered while waiting for view's
        # response. Terminal events (game_over / player_eliminated) may
        # have fired during the opponent's turn between our act and view.
        recent_events: list[dict[str, Any]] = list(response.get("events", []))
        recent_events.extend(opp_events)
        recent_events.extend(notifications)

        new_turn = int(view.get("turn", 0))
        me = int(view.get("player", 0))
        prev_score = self._last_score
        new_score = float(view.get("score", {}).get("score_estimate", 0.0))
        # Symmetric score-delta — gains and losses both count.
        reward += SCORE_DELTA_SCALE * (new_score - prev_score) * self._ov("economy")
        # Event-driven shaping (Phase 1 catalog). Must run before
        # _sync_state(): unit-owner lookups read the pre-action view.
        reward += self._apply_event_rewards(recent_events, me)

        terminated, terminal_reward, reason = self._check_termination(view, recent_events)
        if terminated and reason == "won":
            # Decisive bonus: linearly decays to 0 at DECISIVE_DECAY_TURNS.
            decay = max(0.0, 1.0 - new_turn / DECISIVE_DECAY_TURNS)
            terminal_reward += DECISIVE_BONUS_MAX * decay
        reward += terminal_reward

        self._sync_state(view)
        self._terminated = terminated

        step_capped = (
            not terminated
            and self._step_count >= self._max_steps_per_episode
        )
        turn_capped = not terminated and new_turn >= self._max_turns
        truncated = step_capped or turn_capped
        if truncated:
            self._terminated = True
            # Stalling without resolving the game is worse than losing
            # decisively — apply a cap penalty so the policy learns to
            # commit. Without this, "drag to the cap" was the equilibrium.
            reward += TURN_CAP_PENALTY if turn_capped else STEP_CAP_PENALTY

        info: dict[str, Any] = {
            "action_mask": self._cur_mask.copy(),
            "turn": new_turn,
            "score": new_score,
            "city_count": int(view.get("score", {}).get("city_count", 0)),
        }
        if self._opponent is not None:
            # Diagnostic: how many wire events the frozen opponent's turn
            # produced this step. Zero across a whole episode means the
            # opponent never actually acted (e.g. stale binary not skipping
            # the external slot) — the smoke asserts this is >0.
            info["opp_events"] = len(opp_events)
        if reason:
            info["reason"] = reason
        elif step_capped:
            info["reason"] = "step_cap"
            print(
                f"[MagicCivEnv] step_cap hit at step={self._step_count} "
                f"turn={new_turn} — truncating episode",
                file=sys.stderr,
                flush=True,
            )
        elif turn_capped:
            info["reason"] = "turn_cap"
        return encode_observation(view), reward, terminated, truncated, info

    def _harness_error_result(
        self,
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Terminate the episode after a harness failure and build the
        `step()` result. No valid view exists, so the observation is all
        zeros and the action mask is all-False."""
        self._terminated = True
        return (
            np.zeros(OBS_DIM, dtype=np.float32),
            HARNESS_ERROR_REWARD,
            True,
            False,
            {"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"},
        )

    def close(self) -> None:
        """Shut down the harness (best-effort) and drop the client."""
        self._shutdown_client()

    # ── Internals ────────────────────────────────────────────────────

    def _sync_state(self, view: dict[str, Any]) -> None:
        """Cache `view`, its score estimate, and its legal-action mask and
        index→action map for the next step."""
        self._last_view = view
        self._last_score = float(view.get("score", {}).get("score_estimate", 0.0))
        mask, idx_to_action = encode_legal_actions(view)
        self._cur_mask = mask
        self._idx_to_action = idx_to_action

    def _check_termination(
        self, view: dict[str, Any], recent_events: list[dict[str, Any]]
    ) -> tuple[bool, float, str | None]:
        """Decide whether the episode ended this step.

        Returns `(terminated, terminal_reward, reason)`, where `reason` is
        "won", "eliminated", or None. Side effect: prunes
        `self._live_players` for each `player_eliminated` event processed.

        Termination signals, in priority order:
          1. `game_over` event from the simulator: winner == us → win;
             any other winner (including a missing / null winner) → loss.
          2. `player_eliminated` event for us → loss.
          3. Every opponent eliminated (the learner is the sole survivor) →
             win. Tracked via `self._live_players` so this is correct for
             3-5-player games, not just the 1v1 duel where *any* opponent
             elimination was decisive.
          4. Defensive fallback: cities==0 and no founder → loss (in
             case the simulator's elimination wiring lags one step).
        """
        me = int(view.get("player", 0))
        for ev in recent_events:
            kind = ev.get("type")
            if kind == "game_over":
                # NOTE(review): DRAW_REWARD is never applied in this file, so
                # a winnerless game_over scores as a loss. Confirm that is
                # intended before mapping it to a draw (beware re-opening
                # the stall incentive).
                if self._as_int(ev.get("winner")) == me:
                    return True, WIN_BASE, "won"
                return True, LOSS_REWARD, "eliminated"
            if kind == "player_eliminated":
                self._live_players.discard(self._as_int(ev.get("player")))
                if me not in self._live_players:
                    return True, LOSS_REWARD, "eliminated"

        # Sole-survivor win: the learner is alive and is the ONLY player
        # still live. In a duel this fires on the single opponent's
        # elimination (identical to the old behaviour); in a 3-5p game it
        # holds the win until the last opponent falls.
        if self._live_players == {me}:
            return True, WIN_BASE, "won"

        # Defensive fallback for the case where the simulator drops the
        # game_over event (observed in early integration tests).
        score = view.get("score", {})
        if int(score.get("city_count", 0)) == 0:
            has_founder = any(
                self._as_int(u.get("owner")) == me
                and "founder" in str(u.get("type", ""))
                and float(u.get("hp", 0)) > 0
                for u in view.get("units", [])
            )
            if not has_founder:
                return True, LOSS_REWARD, "eliminated"
        return False, 0.0, None

    def action_masks(self) -> np.ndarray:
        """sb3-contrib MaskablePPO hook — returns the current mask."""
        return self._cur_mask.copy()

    def _apply_event_rewards(
        self, events: list[dict[str, Any]], me: int
    ) -> float:
        """Phase 1 event-driven reward catalog; returns the summed shaping.

        Sources from already-emitted wire events
        (`mc-player-api/src/wire.rs:135-301`). Terminal outcomes
        (`game_over`, learner elimination) are decided in
        `_check_termination`; only the OPPONENT_ELIMINATED shaping bonus
        is paid here.

        Must run BEFORE `_sync_state()` for the step: unit-owner lookups
        read `self._last_view`, the pre-action view in which destroyed
        units were still present.

        Args:
            events: Wire events collected during this step.
            me: The learner's player id.
        """
        total = 0.0
        for ev in events:
            kind = ev.get("type")
            if kind == "city_founded":
                owner = self._as_int(ev.get("owner"))
                cid = str(ev.get("city_id", ""))
                # Track capitals: first city per player is their capital.
                if owner >= 0 and cid and owner not in self._capital_by_player:
                    self._capital_by_player[owner] = cid
                if owner == me:
                    if self._city_founded_rewards_issued < MAX_CITY_FOUNDED_REWARDS:
                        total += CITY_FOUNDED_BY_ME * self._ov("expansion")
                        self._city_founded_rewards_issued += 1
            elif kind == "city_captured":
                old_owner = self._as_int(ev.get("old_owner"))
                new_owner = self._as_int(ev.get("new_owner"))
                cid = str(ev.get("city_id", ""))
                # When a capital changes hands, the *capturer's* first
                # city is still their own capital — don't reassign.
                # NOTE(review): the map is never cleared either. The same
                # original capital therefore pays CAPITAL_CAPTURED_BY_ME
                # every time it is recaptured from its original owner.
                # Confirm a capture/lose/recapture loop cannot be farmed.
                is_capital = (
                    old_owner >= 0
                    and self._capital_by_player.get(old_owner) == cid
                )
                if new_owner == me:
                    total += (
                        CAPITAL_CAPTURED_BY_ME if is_capital else CITY_CAPTURED_BY_ME
                    ) * self._ov("combat")
                elif old_owner == me:
                    total += CAPITAL_LOST_BY_ME if is_capital else CITY_LOST_BY_ME
            elif kind == "wonder_built":
                if self._as_int(ev.get("player")) == me:
                    total += WONDER_BUILT_BY_ME * self._ov("production")
            elif kind == "combat_resolved":
                # Attribution: the wire event carries unit ids, not owners.
                # Kill/loss bookkeeping is left to the accompanying
                # unit_destroyed events to avoid double-counting.
                pass
            elif kind == "unit_destroyed":
                # Owner attribution comes from the last (pre-action) view,
                # which still contains the unit.
                uid = str(ev.get("unit_id", ""))
                owner = self._unit_owner_lookup(uid)
                if owner == me:
                    total += OWN_UNIT_LOST_TO_ENEMY
                elif owner >= 0:
                    # Enemy unit destroyed — we get kill credit *only* if
                    # we have a killer_unit_id we own, OR no killer info
                    # at all (treat as our kill — conservative since the
                    # asymmetric ±0.04/+0.05 is net-positive on even trades).
                    killer = ev.get("killer_unit_id")
                    if killer is None or self._unit_owner_lookup(str(killer)) == me:
                        total += ENEMY_UNIT_KILLED_BY_ME * self._ov("combat")
            elif kind == "tech_researched":
                if self._as_int(ev.get("player")) == me:
                    total += TECH_RESEARCHED_BY_ME * self._ov("tech")
            elif kind == "culture_researched":
                if self._as_int(ev.get("player")) == me:
                    total += CULTURE_RESEARCHED_BY_ME * self._ov("tech")
            elif kind == "player_eliminated":
                p = self._as_int(ev.get("player"))
                if p != me and p >= 0:
                    total += OPPONENT_ELIMINATED * self._ov("combat")
        return total

    def _unit_owner_lookup(self, unit_id: str) -> int:
        """Resolve a unit_id → owner from the last synced PlayerView.

        Returns -1 when the id is empty or the unit is not in that view
        (e.g. created and destroyed within the same step batch, or never
        visible); callers treat -1 as an unknown owner.
        """
        if not unit_id:
            return -1
        for u in self._last_view.get("units", []):
            if str(u.get("id", "")) == unit_id:
                return self._as_int(u.get("owner"))
        return -1