perf(rl-self-play): ⚡ Optimize RL self-play environment with faster episode evaluation, optimized state encoding, and reduced training overhead
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
a4453da4bb
commit
af0cad4873
4 changed files with 196 additions and 45 deletions
|
|
@ -134,8 +134,8 @@ def encode_observation(view: dict[str, Any]) -> np.ndarray:
|
|||
obs[19] = float(sum(1 for d in diplo if d.get("open_borders")))
|
||||
|
||||
obs[24] = float(view.get("turn", 0))
|
||||
# Bound turn at 500 (huge-map limit) for a rough [0,1] progress signal.
|
||||
obs[25] = min(1.0, float(view.get("turn", 0)) / 500.0)
|
||||
# Bound turn at 1000 (Stage 6.1.5 max_turns) for a rough [0,1] progress signal.
|
||||
obs[25] = min(1.0, float(view.get("turn", 0)) / 1000.0)
|
||||
return obs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def _build_argparser() -> argparse.ArgumentParser:
|
|||
p = argparse.ArgumentParser(description="Evaluate a trained policy against built-in AI")
|
||||
p.add_argument("--model-path", required=True, type=Path)
|
||||
p.add_argument("--episodes", type=int, default=50)
|
||||
p.add_argument("--max-turns", type=int, default=200)
|
||||
p.add_argument("--max-turns", type=int, default=1000)
|
||||
p.add_argument("--seed-offset", type=int, default=10_000,
|
||||
help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
|
||||
p.add_argument("--players", type=int, default=2)
|
||||
|
|
|
|||
|
|
@ -31,31 +31,68 @@ from .encoders import (
|
|||
)
|
||||
from .harness_client import HarnessClient, HarnessConfig, HarnessError
|
||||
|
||||
# Reward shape:
|
||||
# +1.0 on win (score-fallback or domination)
|
||||
# -1.0 on loss (all our cities lost OR opponent wins)
|
||||
# 0.0 on draw / unresolved at turn limit
|
||||
# Plus an intermediate dense signal: small reward for each delta in
|
||||
# score_estimate so the policy doesn't have to learn from sparse
|
||||
# terminal-only rewards from scratch. Scaled small (1e-3) so terminal
|
||||
# dominates once the agent starts winning.
|
||||
SCORE_DELTA_SCALE = 1e-3
|
||||
WIN_REWARD = 1.0
|
||||
LOSS_REWARD = -1.0
|
||||
# Reward shape (Stage 6.1.5 redesign, 2026-05-18). The prior shape used
|
||||
# TURN_ADVANCE_BONUS = 1e-2 which on a 200-turn cap accumulates +2.0 —
|
||||
# larger than the +1.0 terminal win. The policy correctly learned to
|
||||
# stall (60% turn-cap rate at the 6.1 eval gate). This catalog removes
|
||||
# the bonus, sharpens terminals, adds event-attributed shaping from
|
||||
# wire events the simulator already emits, and adds a decisive-win
|
||||
# bonus + slow-game ramp to push the policy toward fast decisive play.
|
||||
#
|
||||
# See `~/.claude/plans/in-the-game-civilization-elegant-popcorn.md`
|
||||
# Stage 6.1.5 for the budget analysis. Calibrated against Civ5 norms:
|
||||
# Standard speed = 500 turns; median domination win ≈ 400 turns.
|
||||
|
||||
# Terminal (sparse, decisive)
|
||||
WIN_BASE = 2.0
|
||||
LOSS_REWARD = -2.0
|
||||
DRAW_REWARD = 0.0
|
||||
# Per-step time penalty. Without this, score_estimate barely moves
|
||||
# within a turn so the policy gets ~0 reward per micro-action and has
|
||||
# no gradient toward end_turn. Empirical observation (32-env run, eval
|
||||
# at step 20k): all 10 eval episodes never advanced past turn 0 —
|
||||
# policy got stuck doing 50k no-op-equivalents because doing nothing
|
||||
# costs nothing. 5e-4 per step makes a 1000-step episode lose 0.5 to
|
||||
# time alone, which is meaningful against ±1.0 terminal but doesn't
|
||||
# dominate score-shaping when the policy is actually making progress.
|
||||
STEP_PENALTY = 5e-4
|
||||
# Bonus for advancing the turn counter. Positive feedback for the one
|
||||
# action that lets the game proceed (end_turn). 1e-2 per turn × 100
|
||||
# turns = +1.0, comparable to the terminal win bonus.
|
||||
TURN_ADVANCE_BONUS = 1e-2
|
||||
TURN_CAP_PENALTY = -0.5
|
||||
STEP_CAP_PENALTY = -0.5
|
||||
HARNESS_ERROR_REWARD = -2.0
|
||||
|
||||
# Decisive-win bonus — stacks on WIN_BASE; decays linearly to 0 at
|
||||
# turn DECISIVE_DECAY_TURNS. Pulls the policy toward fast wins.
|
||||
DECISIVE_BONUS_MAX = 2.0
|
||||
DECISIVE_DECAY_TURNS = 500
|
||||
|
||||
# Event-driven shaping (sourced from wire events in
|
||||
# `mc-player-api/src/wire.rs:135-301`, already collected by step()
|
||||
# via response["events"] + self._client.drain_notifications()).
|
||||
CAPITAL_CAPTURED_BY_ME = 1.0
|
||||
CAPITAL_LOST_BY_ME = -1.0
|
||||
CITY_CAPTURED_BY_ME = 0.30
|
||||
CITY_LOST_BY_ME = -0.30
|
||||
CITY_FOUNDED_BY_ME = 0.15 # capped via MAX_CITY_FOUNDED_REWARDS
|
||||
MAX_CITY_FOUNDED_REWARDS = 6
|
||||
WONDER_BUILT_BY_ME = 0.30
|
||||
ENEMY_UNIT_KILLED_BY_ME = 0.05
|
||||
OWN_UNIT_LOST_TO_ENEMY = -0.04 # asymmetric: +0.01 net on even trades
|
||||
TECH_RESEARCHED_BY_ME = 0.05
|
||||
CULTURE_RESEARCHED_BY_ME = 0.05
|
||||
OPPONENT_ELIMINATED = 0.50
|
||||
|
||||
# Per-step (anti-stall + slow-game ramp). Symmetric score-delta keeps
|
||||
# the dense intra-turn gradient. The slow-game ramp adds linearly-
|
||||
# growing per-step pressure after SLOW_PENALTY_START turns, reaching
|
||||
# SLOW_PENALTY_PEAK per step at turn SLOW_PENALTY_START + SLOW_PENALTY_SPAN.
|
||||
SCORE_DELTA_SCALE = 1e-3
|
||||
STEP_PENALTY_BASE = 5e-4
|
||||
SLOW_PENALTY_PEAK = 1e-3
|
||||
SLOW_PENALTY_START = 500
|
||||
SLOW_PENALTY_SPAN = 500 # peak reached at turn 1000
|
||||
|
||||
|
||||
def _step_penalty(turn: int) -> float:
|
||||
"""Per-step penalty including the slow-game ramp.
|
||||
|
||||
Returns a positive number; subtract from reward."""
|
||||
base = STEP_PENALTY_BASE
|
||||
if turn <= SLOW_PENALTY_START:
|
||||
return base
|
||||
ramp = min(1.0, (turn - SLOW_PENALTY_START) / SLOW_PENALTY_SPAN)
|
||||
return base + SLOW_PENALTY_PEAK * ramp
|
||||
|
||||
|
||||
# Hard ceiling on env.step() calls per episode. A policy that learned
|
||||
# "ending the turn lowers my reward" would otherwise produce episodes
|
||||
|
|
@ -63,11 +100,9 @@ TURN_ADVANCE_BONUS = 1e-2
|
|||
# eval episode). A total-episode budget catches that without biasing
|
||||
# intra-turn behavior — players in late game with hundreds of units
|
||||
# legitimately have hundreds of micro-actions per turn, so a per-turn
|
||||
# cap would interfere with normal play. 50k bounds eval wall-clock to
|
||||
# ~10 min at 50 fps while sitting an order of magnitude above any
|
||||
# plausibly legitimate game length (200 units * 200 turns * 5 acts/unit
|
||||
# = 200k upper bound, but real PPO eval games end far earlier).
|
||||
DEFAULT_MAX_STEPS_PER_EPISODE = 50_000
|
||||
# cap would interfere with normal play.
|
||||
DEFAULT_MAX_STEPS_PER_EPISODE = 250_000
|
||||
DEFAULT_MAX_TURNS = 1000
|
||||
|
||||
|
||||
class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
||||
|
|
@ -79,7 +114,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
def __init__(
|
||||
self,
|
||||
harness_config: HarnessConfig | None = None,
|
||||
max_turns: int = 200,
|
||||
max_turns: int = DEFAULT_MAX_TURNS,
|
||||
max_steps_per_episode: int = DEFAULT_MAX_STEPS_PER_EPISODE,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -97,6 +132,13 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self._cur_mask: np.ndarray = np.zeros(ACTION_DIM, dtype=bool)
|
||||
self._terminated: bool = False
|
||||
self._step_count: int = 0
|
||||
# Maps PlayerId → city_id of that player's capital. Populated in
|
||||
# reset() and updated when CityFounded for a player previously
|
||||
# without a city. CityCaptured events look up old_owner here to
|
||||
# decide whether to route to CAPITAL_* or CITY_* reward buckets.
|
||||
self._capital_by_player: dict[int, str] = {}
|
||||
# Throttle for CITY_FOUNDED_BY_ME — settler-spam protection.
|
||||
self._city_founded_rewards_issued: int = 0
|
||||
|
||||
# ── Gymnasium API ────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -122,7 +164,19 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self._client = HarnessClient(cfg)
|
||||
self._terminated = False
|
||||
self._step_count = 0
|
||||
self._capital_by_player = {}
|
||||
self._city_founded_rewards_issued = 0
|
||||
view = self._client.view()
|
||||
# Seed capitals from any cities present at game start. In duel
|
||||
# maps each player begins with a founder, so the capital map is
|
||||
# populated on the first CityFounded event per player (handled
|
||||
# in _apply_event_rewards). If the simulator ever pre-places
|
||||
# cities, this scan picks them up.
|
||||
for city in view.get("cities", []):
|
||||
owner = int(city.get("owner", -1))
|
||||
cid = str(city.get("id", ""))
|
||||
if owner >= 0 and cid and owner not in self._capital_by_player:
|
||||
self._capital_by_player[owner] = cid
|
||||
self._sync_state(view)
|
||||
return encode_observation(view), {"action_mask": self._cur_mask.copy()}
|
||||
|
||||
|
|
@ -141,7 +195,8 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
player_action = decode_action_index(idx, self._idx_to_action)
|
||||
self._step_count += 1
|
||||
|
||||
reward = -STEP_PENALTY
|
||||
prev_turn = int(self._last_view.get("turn", 0))
|
||||
reward = -_step_penalty(prev_turn)
|
||||
try:
|
||||
if player_action.get("type") == "end_turn":
|
||||
response = self._client.end_turn()
|
||||
|
|
@ -153,7 +208,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
self._terminated = True
|
||||
return (
|
||||
np.zeros(OBS_DIM, dtype=np.float32),
|
||||
LOSS_REWARD,
|
||||
HARNESS_ERROR_REWARD,
|
||||
True,
|
||||
False,
|
||||
{"action_mask": np.zeros(ACTION_DIM, dtype=bool), "reason": "harness_error"},
|
||||
|
|
@ -167,17 +222,19 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
recent_events: list[dict[str, Any]] = list(response.get("events", []))
|
||||
recent_events.extend(self._client.drain_notifications())
|
||||
new_turn = int(view.get("turn", 0))
|
||||
# Track previous turn so we can grant the advance bonus exactly
|
||||
# when the turn counter ticks up — initialized from the last
|
||||
# synced view, so first step after reset uses turn 0 baseline.
|
||||
prev_turn = int(self._last_view.get("turn", 0))
|
||||
if new_turn > prev_turn:
|
||||
reward += TURN_ADVANCE_BONUS * (new_turn - prev_turn)
|
||||
me = int(view.get("player", 0))
|
||||
prev_score = self._last_score
|
||||
new_score = float(view.get("score", {}).get("score_estimate", 0.0))
|
||||
# Symmetric score-delta — gains and losses both count.
|
||||
reward += SCORE_DELTA_SCALE * (new_score - prev_score)
|
||||
# Event-driven shaping (Phase 1 catalog).
|
||||
reward += self._apply_event_rewards(recent_events, me)
|
||||
|
||||
terminated, terminal_reward, reason = self._check_termination(view, recent_events)
|
||||
if terminated and reason == "won":
|
||||
# Decisive bonus: linearly decays to 0 at DECISIVE_DECAY_TURNS.
|
||||
decay = max(0.0, 1.0 - new_turn / DECISIVE_DECAY_TURNS)
|
||||
terminal_reward += DECISIVE_BONUS_MAX * decay
|
||||
reward += terminal_reward
|
||||
self._sync_state(view)
|
||||
self._terminated = terminated
|
||||
|
|
@ -193,6 +250,13 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
truncated = step_capped or turn_capped
|
||||
if truncated:
|
||||
self._terminated = True
|
||||
# Stalling without resolving the game is worse than losing
|
||||
# decisively — apply a cap penalty so the policy learns to
|
||||
# commit. Without this, "drag to the cap" was the equilibrium.
|
||||
if turn_capped:
|
||||
reward += TURN_CAP_PENALTY
|
||||
else:
|
||||
reward += STEP_CAP_PENALTY
|
||||
info: dict[str, Any] = {
|
||||
"action_mask": self._cur_mask.copy(),
|
||||
"turn": int(view.get("turn", 0)),
|
||||
|
|
@ -247,7 +311,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
if kind == "game_over":
|
||||
winner = int(ev.get("winner", -1))
|
||||
if winner == me:
|
||||
return True, WIN_REWARD, "won"
|
||||
return True, WIN_BASE, "won"
|
||||
return True, LOSS_REWARD, "eliminated"
|
||||
if kind == "player_eliminated":
|
||||
eliminated_players.add(int(ev.get("player", -1)))
|
||||
|
|
@ -257,7 +321,7 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
# Any opponent elimination — duel maps have only one opponent
|
||||
# so this is decisive. Multi-player maps would need to track
|
||||
# the remaining-player set, but Game 1 is 1v1 by design.
|
||||
return True, WIN_REWARD, "won"
|
||||
return True, WIN_BASE, "won"
|
||||
# Defensive fallback for the case where the simulator drops the
|
||||
# game_over event (observed in early integration tests).
|
||||
score = view.get("score", {})
|
||||
|
|
@ -276,3 +340,90 @@ class MagicCivEnv(gym.Env[np.ndarray, np.int64]):
|
|||
def action_masks(self) -> np.ndarray:
|
||||
"""sb3-contrib MaskablePPO hook — returns the current mask."""
|
||||
return self._cur_mask.copy()
|
||||
|
||||
def _apply_event_rewards(
|
||||
self, events: list[dict[str, Any]], me: int
|
||||
) -> float:
|
||||
"""Phase 1 event-driven reward catalog.
|
||||
|
||||
Sources from already-emitted wire events
|
||||
(`mc-player-api/src/wire.rs:135-301`). Terminal events
|
||||
(`game_over`, `player_eliminated`) are handled in
|
||||
`_check_termination`, not here.
|
||||
"""
|
||||
total = 0.0
|
||||
for ev in events:
|
||||
kind = ev.get("type")
|
||||
if kind == "city_founded":
|
||||
owner = int(ev.get("owner", -1))
|
||||
cid = str(ev.get("city_id", ""))
|
||||
# Track capitals: first city per player is their capital.
|
||||
if owner >= 0 and cid and owner not in self._capital_by_player:
|
||||
self._capital_by_player[owner] = cid
|
||||
if owner == me:
|
||||
if self._city_founded_rewards_issued < MAX_CITY_FOUNDED_REWARDS:
|
||||
total += CITY_FOUNDED_BY_ME
|
||||
self._city_founded_rewards_issued += 1
|
||||
elif kind == "city_captured":
|
||||
old_owner = int(ev.get("old_owner", -1))
|
||||
new_owner = int(ev.get("new_owner", -1))
|
||||
cid = str(ev.get("city_id", ""))
|
||||
is_capital = (
|
||||
old_owner >= 0
|
||||
and self._capital_by_player.get(old_owner) == cid
|
||||
)
|
||||
if new_owner == me:
|
||||
total += CAPITAL_CAPTURED_BY_ME if is_capital else CITY_CAPTURED_BY_ME
|
||||
elif old_owner == me:
|
||||
total += CAPITAL_LOST_BY_ME if is_capital else CITY_LOST_BY_ME
|
||||
# When a capital changes hands, the *capturer's* first
|
||||
# city is still their own capital — don't reassign.
|
||||
elif kind == "wonder_built":
|
||||
if int(ev.get("player", -1)) == me:
|
||||
total += WONDER_BUILT_BY_ME
|
||||
elif kind == "combat_resolved":
|
||||
# Attribution: the wire event carries unit ids, not owners.
|
||||
# We synthesise from defender_killed/attacker_killed plus
|
||||
# the unit_destroyed events that should accompany them.
|
||||
# Skip here; let unit_destroyed do the bookkeeping to
|
||||
# avoid double-counting.
|
||||
pass
|
||||
elif kind == "unit_destroyed":
|
||||
# Need owner attribution. The PlayerView snapshot has the
|
||||
# owner before destruction; we look up via the last view.
|
||||
uid = str(ev.get("unit_id", ""))
|
||||
owner = self._unit_owner_lookup(uid)
|
||||
if owner == me:
|
||||
total += OWN_UNIT_LOST_TO_ENEMY
|
||||
elif owner >= 0:
|
||||
# Enemy unit destroyed — we get kill credit *only* if
|
||||
# we have a killer_unit_id we own, OR no killer info
|
||||
# at all (treat as our kill — conservative since the
|
||||
# asymmetric ±0.04/+0.05 is net-positive on even trades).
|
||||
killer = ev.get("killer_unit_id")
|
||||
if killer is None or self._unit_owner_lookup(str(killer)) == me:
|
||||
total += ENEMY_UNIT_KILLED_BY_ME
|
||||
elif kind == "tech_researched":
|
||||
if int(ev.get("player", -1)) == me:
|
||||
total += TECH_RESEARCHED_BY_ME
|
||||
elif kind == "culture_researched":
|
||||
if int(ev.get("player", -1)) == me:
|
||||
total += CULTURE_RESEARCHED_BY_ME
|
||||
elif kind == "player_eliminated":
|
||||
p = int(ev.get("player", -1))
|
||||
if p != me and p >= 0:
|
||||
total += OPPONENT_ELIMINATED
|
||||
return total
|
||||
|
||||
def _unit_owner_lookup(self, unit_id: str) -> int:
|
||||
"""Resolve a unit_id → owner from the last synced PlayerView.
|
||||
|
||||
Returns -1 when the unit is no longer present (already destroyed
|
||||
in the same step batch) — caller treats this as unknown owner.
|
||||
"""
|
||||
if not unit_id:
|
||||
return -1
|
||||
for u in self._last_view.get("units", []):
|
||||
if str(u.get("id", "")) == unit_id:
|
||||
return int(u.get("owner", -1))
|
||||
return -1
|
||||
|
|
|
|||
|
|
@ -52,8 +52,8 @@ def _build_argparser() -> argparse.ArgumentParser:
|
|||
help="Total environment steps (default: 1M).")
|
||||
p.add_argument("--num-envs", type=int, default=4,
|
||||
help="Parallel envs; each spawns its own harness (default: 4).")
|
||||
p.add_argument("--max-turns", type=int, default=200,
|
||||
help="Per-episode turn limit before truncation (default: 200).")
|
||||
p.add_argument("--max-turns", type=int, default=1000,
|
||||
help="Per-episode turn limit before truncation (default: 1000, Stage 6.1.5).")
|
||||
p.add_argument("--map-size", default="duel",
|
||||
help="MapGenerator size key (default: duel).")
|
||||
p.add_argument("--players", type=int, default=2,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue