magicciv/tooling/rl_self_play/record_expert.py

477 lines
17 KiB
Python

"""Record scripted-AI expert games as (obs, action, mask) triples for BC.
Stage 6.1.6 — behavioural-cloning warm-start. PPO-from-scratch cannot
find wins through exploration on the sparse long-horizon reward
(`duel-v3-s7` and the `duel-v3-s7-t200` curriculum probe both stayed
flat-negative). The fix is to *give* the policy winning play to imitate
before PPO refines it.
This recorder drives BOTH player slots of a duel externally and uses the
read-only `suggest` wire request to harvest, at every decision point,
what the scripted controller would play — paired with the encoder's
fixed-shape observation and legal-action mask. The triples feed
`bc_pretrain.py`.
Usage:
python -m tooling.rl_self_play.record_expert \\
--games 1000 --out-dir tooling/rl_self_play/expert/duel
Output: sharded `.npz` files under `--out-dir`, each carrying
obs float32 [N, OBS_DIM]
actions int64 [N] -- encoder action index
masks bool [N, ACTION_DIM]
slots int8 [N] -- which player slot acted
games int32 [N] -- game id
outcomes int8 [N] -- 1 win / 0 loss / -1 undecided, for slots[i]
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
from typing import Any
import numpy as np
# Anchor paths on this file so the script works from any working directory.
THIS_DIR = Path(__file__).resolve().parent
# parents[1] is the directory that contains the `tooling/` package; it must be
# importable for the absolute `tooling.rl_self_play...` imports that follow.
PROJECT_ROOT = THIS_DIR.parents[1]
# Direct script execution sets `__package__` to None, while importing this
# file as a plain top-level module sets it to "". Neither case can resolve
# the absolute `tooling.` imports unaided, so handle both, not only None.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.encoders import ( # type: ignore[import-not-found]
CITY_QUEUE_DIM,
CITY_QUEUE_ITEMS,
MAX_CITIES,
MAX_UNITS,
PER_UNIT_ACTIONS,
encode_legal_actions,
encode_observation,
)
from tooling.rl_self_play.harness_client import ( # type: ignore[import-not-found]
HarnessClient,
HarnessConfig,
HarnessError,
)
class _DropStats:
    """Per-category tally of how expert suggestions map into the fixed
    322-index action space.

    Two categories count as recorded (a usable BC label was produced) and
    two count as dropped (no label could be produced). A high drop rate
    means BC learns from only part of the expert's play. Treat it as a
    warning sign, not an acceptable cost.
    """

    def __init__(self) -> None:
        # Recorded categories.
        self.move_projected = 0  # move/attack projected onto a legal adjacent step
        self.nonmove_matched = 0  # non-move action found in the index space
        # Dropped categories.
        self.move_unresolved = 0  # move/attack that could not be resolved to a legal step
        self.nonmove_unmatched = 0  # non-move action outside the index space

    @property
    def recorded(self) -> int:
        """Suggestions that produced a legal encoder index."""
        return sum((self.move_projected, self.nonmove_matched))

    @property
    def dropped(self) -> int:
        """Suggestions that could not be mapped to a legal encoder index."""
        return sum((self.move_unresolved, self.nonmove_unmatched))

    @property
    def suggested(self) -> int:
        """Every suggestion seen, recorded or dropped."""
        return self.recorded + self.dropped

    @property
    def drop_rate(self) -> float:
        """Fraction of suggestions dropped; 0.0 before anything is seen."""
        total = self.suggested
        if not total:
            return 0.0
        return self.dropped / total
class _ShardWriter:
    """Buffer BC triples and write them as compressed `.npz` shards.

    A shard `expert_NNNN.npz` is written to `out_dir` every `shard_games`
    finished games, and on an explicit `flush`. This bounds memory over a
    long run such as 1k games. Each row carries its acting slot, its game
    id, and an outcome label that `finish_game` backfills once the game ends.
    """

    def __init__(self, out_dir: Path, shard_games: int) -> None:
        self._out_dir = out_dir
        self._shard_games = shard_games
        self._shard_idx = 0
        self._games_in_shard = 0
        self._obs: list[np.ndarray] = []
        self._actions: list[int] = []
        self._masks: list[np.ndarray] = []
        self._slots: list[int] = []
        self._games: list[int] = []
        self._outcomes: list[int] = []
        # Buffer row indices per game id. `finish_game` then touches only the
        # finished game's rows instead of rescanning the whole shard buffer
        # for every game (O(rows x games) per shard).
        self._rows_by_game: dict[int, list[int]] = {}
        out_dir.mkdir(parents=True, exist_ok=True)

    def add(
        self,
        obs: np.ndarray,
        action_idx: int,
        mask: np.ndarray,
        slot: int,
        game_id: int,
    ) -> None:
        """Buffer one `(obs, action, mask)` triple for `slot` in `game_id`.

        The row's outcome stays -1 (undecided) until `finish_game` backfills it.
        """
        row = len(self._obs)
        self._obs.append(obs.astype(np.float32))
        self._actions.append(int(action_idx))
        self._masks.append(mask.astype(bool))
        self._slots.append(int(slot))
        self._games.append(int(game_id))
        self._outcomes.append(-1)  # backfilled by finish_game
        # Register the row only after every column was appended successfully.
        self._rows_by_game.setdefault(int(game_id), []).append(row)

    def finish_game(self, game_id: int, winner: int | None) -> None:
        """Backfill the outcome label for every buffered row of `game_id`.

        Rows get 1 if their slot is `winner` and 0 otherwise. When `winner`
        is None, all rows get -1. The game then counts toward the current
        shard, which is flushed once it holds `shard_games` games.
        """
        for row in self._rows_by_game.pop(int(game_id), ()):
            if winner is None:
                self._outcomes[row] = -1
            else:
                self._outcomes[row] = 1 if self._slots[row] == winner else 0
        self._games_in_shard += 1
        if self._games_in_shard >= self._shard_games:
            self.flush()

    def flush(self) -> None:
        """Write all buffered rows to the next shard file and reset the buffer.

        Nothing is written when the buffer is empty. Rows of games not yet
        finished are written with outcome -1.
        """
        if not self._obs:
            # A shard boundary can be reached by games that produced no
            # triples. Reset the count anyway so the next game does not
            # trigger a premature one-game shard.
            self._games_in_shard = 0
            return
        path = self._out_dir / f"expert_{self._shard_idx:04d}.npz"
        np.savez_compressed(
            path,
            obs=np.stack(self._obs),
            actions=np.asarray(self._actions, dtype=np.int64),
            masks=np.stack(self._masks),
            slots=np.asarray(self._slots, dtype=np.int8),
            games=np.asarray(self._games, dtype=np.int32),
            outcomes=np.asarray(self._outcomes, dtype=np.int8),
        )
        print(f" wrote {path.name} ({len(self._obs)} triples)", flush=True)
        self._shard_idx += 1
        self._games_in_shard = 0
        self._obs.clear()
        self._actions.clear()
        self._masks.clear()
        self._slots.clear()
        self._games.clear()
        self._outcomes.clear()
        # Row indices refer to the buffer just cleared; they are now invalid.
        self._rows_by_game.clear()
def _terminal_winner(events: list[dict[str, Any]]) -> tuple[bool, int | None]:
    """Scan wire events for a terminal signal.

    Returns `(game_ended, winner_slot_or_None)`:

    - The first `game_over` event ends the scan. `winner` is its `winner`
      field as an int, or -1 when the field is missing or null (e.g. a draw).
      It is never None, so it cannot be confused with the elimination case.
    - Otherwise, any `player_eliminated` event ends the game, because on a
      duel map one elimination is decisive. `winner` is None; the caller,
      which knows the slot set, resolves the survivor.
    - Otherwise `(False, None)`.
    """
    saw_elimination = False
    for ev in events:
        kind = ev.get("type")
        if kind == "game_over":
            raw = ev.get("winner")
            # `int(None)` raises TypeError, which would abort the whole
            # recording run; treat a null winner like a missing one.
            return True, (-1 if raw is None else int(raw))
        if kind == "player_eliminated":
            # Only the fact of an elimination matters here. The player id is
            # deliberately not parsed, so a null or odd id cannot crash the scan.
            saw_elimination = True
    if saw_elimination:
        return True, None
    return False, None
def _drain_events(client: HarnessClient, response: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect every event relevant to the last request.

    Returns the `events` carried inline on `response` (if any), followed by
    whatever `client.drain_notifications()` returns.
    """
    collected = [*response.get("events", [])]
    collected += client.drain_notifications()
    return collected
# Sub-action offsets within a unit's PER_UNIT_ACTIONS-slot block, mirroring
# `encoders.encode_legal_actions`. Move/attack occupy contiguous ranges.
# The unit in encoder slot `s` owns indices `2 + s * PER_UNIT_ACTIONS + sub`
# (see `_resolve`); indices 0 and 1 are end_turn and noop.
_UNIT_SUB: dict[str, int] = {
    "skip": 0, "fortify": 1, "sentry": 2, "found_city": 3, "unfortify": 4,
}
_MOVE_SUB_RANGE = range(5, 11)  # six directional move slots (subs 5-10)
# NOTE(review): five attack slots vs. six move directions — confirm against
# `encoders.encode_legal_actions` that the asymmetry is intentional.
_ATTACK_SUB_RANGE = range(11, 16)  # five directional attack slots (subs 11-15)
def _encoder_units(view: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the viewing player's units, capped at MAX_UNITS, in the
    same order `encode_legal_actions` assigns them to unit slots.

    A unit belongs to the viewer when its `owner` equals `view["player"]`
    (both compared as ints; `player` defaults to 0, `owner` to -1).
    """
    owner_id = int(view.get("player", 0))
    own_units: list[dict[str, Any]] = []
    for unit in view.get("units", []):
        if int(unit.get("owner", -1)) != owner_id:
            continue
        own_units.append(unit)
    return own_units[:MAX_UNITS]
def _unit_slot(view: dict[str, Any], unit_id: str) -> int | None:
    """Map an explicit unit id to its positional encoder slot.

    Returns the position of the first unit from `_encoder_units(view)`
    whose stringified `id` equals `unit_id`. Returns None when there is no
    such unit: not ours, past the MAX_UNITS cap, or absent.
    """
    slotted_ids = [str(unit.get("id")) for unit in _encoder_units(view)]
    return slotted_ids.index(unit_id) if unit_id in slotted_ids else None
def _city_slot(view: dict[str, Any], city_id: str) -> int | None:
    """Map an explicit city id to its positional encoder slot, or None.

    Only the first MAX_CITIES entries of `view["cities"]` are slotted; the
    first entry whose stringified `id` equals `city_id` wins.

    NOTE(review): unlike `_encoder_units`, this does not filter by owner.
    Confirm that `encode_legal_actions` slots cities from the same
    unfiltered list; otherwise queue_production labels may map to the
    wrong city.
    """
    capped = view.get("cities", [])[:MAX_CITIES]
    return next(
        (pos for pos, city in enumerate(capped) if str(city.get("id")) == city_id),
        None,
    )
def _project_directional(
    action: dict[str, Any],
    coord_key: str,
    sub_range: range,
    unit_base: int,
    idx_to_action: dict[int, dict[str, Any]],
) -> int | None:
    """Project a goal-oriented move/attack onto a single legal step.

    The scripted controller suggests goal coordinates that may be several
    hexes away, while the player-API legal set offers only adjacent steps.
    The candidates are the legal actions at `unit_base + sub` for each
    `sub` in `sub_range`. Returns the index of the candidate whose
    `coord_key` destination is closest to the goal by squared Euclidean
    distance; the earliest index wins ties. That step toward the goal is
    the single-step BC label.

    Returns None when the goal is malformed or no candidate carries a
    usable destination.
    """
    goal = action.get(coord_key)
    if not (isinstance(goal, (list, tuple)) and len(goal) >= 2):
        return None
    goal_x, goal_y = float(goal[0]), float(goal[1])

    candidates: list[tuple[float, int]] = []
    for index in (unit_base + sub for sub in sub_range):
        legal = idx_to_action.get(index)
        if legal is None:
            continue
        dest = legal.get(coord_key)
        if not (isinstance(dest, (list, tuple)) and len(dest) >= 2):
            continue
        dist = (float(dest[0]) - goal_x) ** 2 + (float(dest[1]) - goal_y) ** 2
        candidates.append((dist, index))

    if not candidates:
        return None
    # min() keeps the first of equal keys, preserving the earliest-index tie-break.
    return min(candidates, key=lambda pair: pair[0])[1]
def _resolve(
    action: dict[str, Any],
    view: dict[str, Any],
    idx_to_action: dict[int, dict[str, Any]],
    drops: _DropStats,
) -> int | None:
    """Translate one suggested expert action into a legal encoder index.

    Returns the index, or None when the action cannot be represented.
    Either way, exactly one `drops` counter is incremented.

    `suggest` actions name their subject explicitly (`unit_id` / `city_id`),
    whereas the encoder's index space is positional: identity is the slot.
    So the id is mapped to its encoder slot and the index is computed
    arithmetically, mirroring `encoders.encode_legal_actions`. Comparing
    JSON shapes would not work because they differ between the two sides.

    Layout: 0 = end_turn, 1 = noop, then MAX_UNITS blocks of
    PER_UNIT_ACTIONS unit sub-actions, then per-city blocks of
    CITY_QUEUE_DIM production entries.
    """
    kind = action.get("type")

    def _exact(index: int) -> int | None:
        # Non-directional actions: the computed index counts only if legal now.
        if index in idx_to_action:
            drops.nonmove_matched += 1
            return index
        drops.nonmove_unmatched += 1
        return None

    if kind == "end_turn":
        return _exact(0)
    if kind == "noop":
        return _exact(1)

    if kind == "queue_production":
        item = str(action.get("item", ""))
        city_pos = _city_slot(view, str(action.get("city_id")))
        if city_pos is None or item not in CITY_QUEUE_ITEMS:
            drops.nonmove_unmatched += 1
            return None
        city_block = 2 + MAX_UNITS * PER_UNIT_ACTIONS + city_pos * CITY_QUEUE_DIM
        return _exact(city_block + CITY_QUEUE_ITEMS.index(item))

    # Everything else is unit-scoped: move / attack / skip / fortify /
    # sentry / found_city / unfortify.
    directional = kind in ("move", "attack")
    unit_pos = _unit_slot(view, str(action.get("unit_id", action.get("unit"))))
    if unit_pos is None:
        if directional:
            drops.move_unresolved += 1
        else:
            drops.nonmove_unmatched += 1
        return None
    unit_block = 2 + unit_pos * PER_UNIT_ACTIONS

    if directional:
        if kind == "move":
            coord_key, sub_range = "to", _MOVE_SUB_RANGE
        else:
            coord_key, sub_range = "target", _ATTACK_SUB_RANGE
        index = _project_directional(action, coord_key, sub_range, unit_block, idx_to_action)
        if index is None:
            drops.move_unresolved += 1
        else:
            drops.move_projected += 1
        return index

    sub = _UNIT_SUB.get(kind)
    if sub is None:
        drops.nonmove_unmatched += 1
        return None
    return _exact(unit_block + sub)
def _eliminated_players(events: list[dict[str, Any]]) -> set[int]:
    """Return the player ids named by `player_eliminated` events.

    Events whose `player` field is missing or not an integer are skipped.
    """
    eliminated: set[int] = set()
    for ev in events:
        if ev.get("type") != "player_eliminated":
            continue
        try:
            eliminated.add(int(ev.get("player")))
        except (TypeError, ValueError):
            continue
    return eliminated


def _settle_turn_events(
    client: HarnessClient, response: dict[str, Any], slot: int,
) -> tuple[bool, int | None]:
    """Drain one request's events and resolve any terminal signal from the
    acting `slot`'s point of view.

    `_terminal_winner` reports eliminations without a winner. Here, such an
    elimination is credited to `slot` when `slot` is not among the
    eliminated players: in a duel, the acting slot survived, so it won.
    Event player ids are treated as slot numbers, just as the `game_over`
    winner already is. If the acting slot was eliminated, or no eliminated
    id could be parsed, the winner stays None for the caller to resolve.
    """
    events = _drain_events(client, response)
    ended, winner = _terminal_winner(events)
    if ended and winner is None:
        eliminated = _eliminated_players(events)
        if eliminated and slot not in eliminated:
            winner = slot
    return ended, winner


def _record_slot_turn(
    client: HarnessClient,
    slot: int,
    game_id: int,
    writer: _ShardWriter,
    drops: _DropStats,
) -> tuple[bool, int | None]:
    """Drive one slot's full turn and record it as BC triples.

    Asks the harness for the scripted controller's suggested chain. For each
    suggestion it re-`view`s the state, resolves the suggestion to a legal
    encoder index, records the `(obs, idx, mask)` triple, and applies the
    action. Only actions drawn from the engine's own legal set are applied,
    so `act` is not expected to reject. Unless the chain itself ended the
    turn, the turn is then closed with an explicit, recorded `end_turn`.

    Returns `(game_ended, winner)`. After an elimination, `winner` is this
    slot when it survived; otherwise it may be None (see `_settle_turn_events`).
    """
    chain = client.suggest(slot=slot)
    for action in chain:
        view = client.view(slot=slot)
        obs = encode_observation(view)
        mask, idx_to_action = encode_legal_actions(view)
        idx = _resolve(action, view, idx_to_action, drops)
        if idx is None:
            continue  # not representable in the 322-index space — skip
        writer.add(obs, idx, mask, slot, game_id)
        resolved = idx_to_action[idx]
        closes_turn = resolved.get("type") == "end_turn"
        try:
            if closes_turn:
                response = client.end_turn(slot=slot)
            else:
                response = client.act(resolved, slot=slot)
        except HarnessError as exc:
            # `resolved` came straight from the engine's legal set, so a
            # rejection here is a genuine protocol surprise, not stale
            # expert intent — surface it and close the turn.
            print(f" game {game_id} slot {slot}: legal action rejected "
                  f"({exc})", file=sys.stderr, flush=True)
            break
        ended, winner = _settle_turn_events(client, response, slot)
        if ended:
            return True, winner
        if closes_turn:
            # The turn is already closed (and recorded). Acting further, or
            # sending a second end_turn below, would act out of turn.
            return False, None
    # Scripted chains do not carry `end_turn` (Stage 6.1.6 diagnostic) —
    # close the turn explicitly and record the boundary so BC learns it.
    # The `duel-v3-s7` failure (30/30 episodes stuck at turn 12) was a
    # policy that never learned to end its turn; the dataset must not
    # share that blind spot.
    view = client.view(slot=slot)
    mask, _ = encode_legal_actions(view)
    if mask[0]:  # end_turn is index 0
        writer.add(encode_observation(view), 0, mask, slot, game_id)
    response = client.end_turn(slot=slot)
    ended, winner = _settle_turn_events(client, response, slot)
    return (True, winner) if ended else (False, None)
def _play_game(
    game_id: int,
    cfg: HarnessConfig,
    max_turns: int,
    writer: _ShardWriter,
    drops: _DropStats,
) -> int | None:
    """Record one scripted-vs-scripted duel end to end.

    Each turn, every slot in `cfg.effective_player_slots` is driven in order.
    Play stops at a terminal event, when the engine's turn counter reaches
    `max_turns`, or after `max_turns` loop iterations. The harness client is
    always shut down.

    Returns the winning slot, or None if the game ended undecided.
    """
    slots = cfg.effective_player_slots
    client = HarnessClient(cfg)
    try:
        for _ in range(max_turns):
            if int(client.view(slot=slots[0]).get("turn", 0)) >= max_turns:
                return None
            for slot in slots:
                ended, winner = _record_slot_turn(client, slot, game_id, writer, drops)
                if not ended:
                    continue
                if winner is not None:
                    return winner
                if len(slots) == 2:
                    # Elimination with no explicit winner: this fallback
                    # presumes the acting slot is the one eliminated and
                    # credits the other duel slot. The events themselves
                    # are not visible here.
                    return slots[1] if slot == slots[0] else slots[0]
                return None
        return None
    finally:
        client.shutdown()
def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the expert recorder."""
    parser = argparse.ArgumentParser(description="Record scripted-AI expert games for BC")
    options: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--games", {"type": int, "default": 1000,
                     "help": "Number of duels to record (default: 1000)."}),
        ("--out-dir", {"default": str(THIS_DIR / "expert" / "duel"),
                       "help": "Directory for the sharded .npz output."}),
        ("--shard-games", {"type": int, "default": 100,
                           "help": "Games per .npz shard (default: 100)."}),
        ("--max-turns", {"type": int, "default": 600,
                         "help": "Per-game turn cap (default: 600)."}),
        ("--seed", {"type": int, "default": 20000,
                    "help": "Base seed; game N uses seed+N (default: 20000)."}),
        ("--map-size", {"default": "duel",
                        "help": "MapGenerator size key (default: duel)."}),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser
def main() -> int:
    """Record `--games` expert duels into `.npz` shards and print a summary.

    Buffered triples are flushed in a `finally` block, so if a game raises,
    already-finished games are not lost. Rows of the interrupted game keep
    outcome -1 (undecided). A warning is printed to stderr when more than
    30% of suggested actions were dropped.

    Returns 0, used as the process exit status.
    """
    args = _build_argparser().parse_args()
    out_dir = Path(args.out_dir)
    writer = _ShardWriter(out_dir, args.shard_games)
    drops = _DropStats()
    wins = {0: 0, 1: 0}
    undecided = 0
    # Monotonic clock: elapsed time and rate must not jump if the wall clock
    # is adjusted during a long run.
    t0 = time.monotonic()
    print(f"recording {args.games} expert games → {out_dir}", flush=True)
    try:
        for game_id in range(args.games):
            cfg = HarnessConfig(
                seed=args.seed + game_id,
                players=2,
                player_slots=(0, 1),
                map_size=args.map_size,
                map_type="continents",
                victory_mode="domination",
            )
            winner = _play_game(game_id, cfg, args.max_turns, writer, drops)
            writer.finish_game(game_id, winner)
            if winner is None:
                undecided += 1
            else:
                wins[winner] = wins.get(winner, 0) + 1
            if (game_id + 1) % 10 == 0 or game_id == 0:
                # Clamp: on a coarse clock a fast first game can measure
                # 0 s, which would raise ZeroDivisionError.
                elapsed = max(time.monotonic() - t0, 1e-9)
                rate = (game_id + 1) / elapsed
                print(f" game {game_id + 1}/{args.games} "
                      f"({rate:.2f} games/s, drop_rate={drops.drop_rate:.1%})",
                      flush=True)
    finally:
        # Persist whatever is buffered, even when a game raised mid-run.
        writer.flush()
    summary = {
        "games": args.games,
        "triples": drops.recorded,
        "suggested_actions": drops.suggested,
        "dropped_actions": drops.dropped,
        "drop_rate": round(drops.drop_rate, 4),
        "move_projected": drops.move_projected,
        "nonmove_matched": drops.nonmove_matched,
        "move_unresolved": drops.move_unresolved,
        "nonmove_unmatched": drops.nonmove_unmatched,
        "wins_slot0": wins.get(0, 0),
        "wins_slot1": wins.get(1, 0),
        "undecided": undecided,
        "elapsed_sec": round(time.monotonic() - t0, 1),
    }
    print("RECORD_SUMMARY " + json.dumps(summary), flush=True)
    if drops.drop_rate > 0.30:
        print(f"WARNING: drop_rate {drops.drop_rate:.1%} exceeds 30% — the "
              f"322-index action space is missing a large slice of expert "
              f"play; BC will train on a biased subset.", file=sys.stderr,
              flush=True)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())