477 lines
17 KiB
Python
477 lines
17 KiB
Python
"""Record scripted-AI expert games as (obs, action, mask) triples for BC.
|
|
|
|
Stage 6.1.6 — behavioural-cloning warm-start. PPO-from-scratch cannot
|
|
find wins through exploration on the sparse long-horizon reward
|
|
(`duel-v3-s7` and the `duel-v3-s7-t200` curriculum probe both stayed
|
|
flat-negative). The fix is to *give* the policy winning play to imitate
|
|
before PPO refines it.
|
|
|
|
This recorder drives BOTH player slots of a duel externally and uses the
|
|
read-only `suggest` wire request to harvest, at every decision point,
|
|
what the scripted controller would play — paired with the encoder's
|
|
fixed-shape observation and legal-action mask. The triples feed
|
|
`bc_pretrain.py`.
|
|
|
|
Usage:
|
|
python -m tooling.rl_self_play.record_expert \\
|
|
--games 1000 --out-dir tooling/rl_self_play/expert/duel
|
|
|
|
Output: sharded `.npz` files under `--out-dir`, each carrying
|
|
obs float32 [N, OBS_DIM]
|
|
actions int64 [N] -- encoder action index
|
|
masks bool [N, ACTION_DIM]
|
|
slots int8 [N] -- which player slot acted
|
|
games int32 [N] -- game id
|
|
outcomes int8 [N] -- 1 win / 0 loss / -1 undecided, for slots[i]
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]
# Without a parent package (run as a plain script, where `__package__` is
# None, or imported as a top-level module, where it is ""), the absolute
# `tooling.rl_self_play...` imports below only resolve if the project root
# is on sys.path.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from tooling.rl_self_play.encoders import ( # type: ignore[import-not-found]
|
|
CITY_QUEUE_DIM,
|
|
CITY_QUEUE_ITEMS,
|
|
MAX_CITIES,
|
|
MAX_UNITS,
|
|
PER_UNIT_ACTIONS,
|
|
encode_legal_actions,
|
|
encode_observation,
|
|
)
|
|
from tooling.rl_self_play.harness_client import ( # type: ignore[import-not-found]
|
|
HarnessClient,
|
|
HarnessConfig,
|
|
HarnessError,
|
|
)
|
|
|
|
|
|
class _DropStats:
|
|
"""Tracks how much of the expert trajectory the fixed 322-index
|
|
action space captures, by category. A high drop rate means BC learns
|
|
from a fraction of expert play — a flag, not an acceptable cost."""
|
|
|
|
def __init__(self) -> None:
|
|
self.move_projected = 0 # move resolved onto a legal adjacent step
|
|
self.nonmove_matched = 0 # non-move exact-matched into the index space
|
|
self.move_unresolved = 0 # move for a unit with no legal move this step
|
|
self.nonmove_unmatched = 0 # non-move outside the 322-index space
|
|
|
|
@property
|
|
def suggested(self) -> int:
|
|
return (self.move_projected + self.nonmove_matched
|
|
+ self.move_unresolved + self.nonmove_unmatched)
|
|
|
|
@property
|
|
def recorded(self) -> int:
|
|
return self.move_projected + self.nonmove_matched
|
|
|
|
@property
|
|
def dropped(self) -> int:
|
|
return self.move_unresolved + self.nonmove_unmatched
|
|
|
|
@property
|
|
def drop_rate(self) -> float:
|
|
return self.dropped / self.suggested if self.suggested else 0.0
|
|
|
|
|
|
class _ShardWriter:
|
|
"""Buffers triples and flushes a compressed `.npz` every
|
|
`shard_games` games to bound memory over a 1k-game run."""
|
|
|
|
def __init__(self, out_dir: Path, shard_games: int) -> None:
|
|
self._out_dir = out_dir
|
|
self._shard_games = shard_games
|
|
self._shard_idx = 0
|
|
self._games_in_shard = 0
|
|
self._obs: list[np.ndarray] = []
|
|
self._actions: list[int] = []
|
|
self._masks: list[np.ndarray] = []
|
|
self._slots: list[int] = []
|
|
self._games: list[int] = []
|
|
self._outcomes: list[int] = []
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
def add(
|
|
self,
|
|
obs: np.ndarray,
|
|
action_idx: int,
|
|
mask: np.ndarray,
|
|
slot: int,
|
|
game_id: int,
|
|
) -> None:
|
|
self._obs.append(obs.astype(np.float32))
|
|
self._actions.append(int(action_idx))
|
|
self._masks.append(mask.astype(bool))
|
|
self._slots.append(int(slot))
|
|
self._games.append(int(game_id))
|
|
self._outcomes.append(-1) # backfilled by finish_game
|
|
|
|
def finish_game(self, game_id: int, winner: int | None) -> None:
|
|
"""Backfill the outcome label for every row of `game_id`."""
|
|
for i in range(len(self._games)):
|
|
if self._games[i] != game_id:
|
|
continue
|
|
if winner is None:
|
|
self._outcomes[i] = -1
|
|
else:
|
|
self._outcomes[i] = 1 if self._slots[i] == winner else 0
|
|
self._games_in_shard += 1
|
|
if self._games_in_shard >= self._shard_games:
|
|
self.flush()
|
|
|
|
def flush(self) -> None:
|
|
if not self._obs:
|
|
return
|
|
path = self._out_dir / f"expert_{self._shard_idx:04d}.npz"
|
|
np.savez_compressed(
|
|
path,
|
|
obs=np.stack(self._obs),
|
|
actions=np.asarray(self._actions, dtype=np.int64),
|
|
masks=np.stack(self._masks),
|
|
slots=np.asarray(self._slots, dtype=np.int8),
|
|
games=np.asarray(self._games, dtype=np.int32),
|
|
outcomes=np.asarray(self._outcomes, dtype=np.int8),
|
|
)
|
|
print(f" wrote {path.name} ({len(self._obs)} triples)", flush=True)
|
|
self._shard_idx += 1
|
|
self._games_in_shard = 0
|
|
self._obs.clear()
|
|
self._actions.clear()
|
|
self._masks.clear()
|
|
self._slots.clear()
|
|
self._games.clear()
|
|
self._outcomes.clear()
|
|
|
|
|
|
def _terminal_winner(events: list[dict[str, Any]]) -> tuple[bool, int | None]:
|
|
"""Scan wire events for a terminal signal. Returns
|
|
(game_ended, winner_slot_or_None)."""
|
|
eliminated: set[int] = set()
|
|
for ev in events:
|
|
kind = ev.get("type")
|
|
if kind == "game_over":
|
|
return True, int(ev.get("winner", -1))
|
|
if kind == "player_eliminated":
|
|
eliminated.add(int(ev.get("player", -1)))
|
|
if eliminated:
|
|
# Duel map: one elimination is decisive. Winner is the other slot.
|
|
# Resolved by the caller, which knows the slot set.
|
|
return True, None
|
|
return False, None
|
|
|
|
|
|
def _drain_events(client: HarnessClient, response: dict[str, Any]) -> list[dict[str, Any]]:
|
|
events: list[dict[str, Any]] = list(response.get("events", []))
|
|
events.extend(client.drain_notifications())
|
|
return events
|
|
|
|
|
|
# Sub-action offsets within a unit's PER_UNIT_ACTIONS-slot block, mirroring
|
|
# `encoders.encode_legal_actions`. Move/attack occupy contiguous ranges.
|
|
_UNIT_SUB: dict[str, int] = {
|
|
"skip": 0, "fortify": 1, "sentry": 2, "found_city": 3, "unfortify": 4,
|
|
}
|
|
_MOVE_SUB_RANGE = range(5, 11) # six directional move slots
|
|
_ATTACK_SUB_RANGE = range(11, 16) # five directional attack slots
|
|
|
|
|
|
def _encoder_units(view: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the viewing player's units, capped at `MAX_UNITS`.

    A unit belongs to the viewer when its `owner` equals the view's
    `player` (default 0). Units keep their order from `view["units"]`;
    this is meant to reproduce the order in which `encode_legal_actions`
    assigns unit slots, so keep the two in sync.
    """
    viewer = int(view.get("player", 0))
    owned: list[dict[str, Any]] = []
    for unit in view.get("units", []):
        if int(unit.get("owner", -1)) == viewer:
            owned.append(unit)
    return owned[:MAX_UNITS]
|
|
|
|
|
|
def _unit_slot(view: dict[str, Any], unit_id: str) -> int | None:
    """Return the encoder unit slot whose unit `id` (compared as a string)
    equals `unit_id`, or None when no slotted unit matches."""
    return next(
        (
            slot
            for slot, unit in enumerate(_encoder_units(view))
            if str(unit.get("id")) == unit_id
        ),
        None,
    )
|
|
|
|
|
|
def _city_slot(view: dict[str, Any], city_id: str) -> int | None:
    """Return the encoder city slot whose city `id` (compared as a string)
    equals `city_id`, or None. Only the first `MAX_CITIES` entries of
    `view["cities"]` are slotted.

    NOTE(review): unlike `_encoder_units`, cities are not filtered by owner
    here. This is only correct if `view["cities"]` lists cities in exactly
    the encoder's slot order (e.g. only the viewer's cities) — confirm
    against `encode_legal_actions`.
    """
    slotted = view.get("cities", [])[:MAX_CITIES]
    matches = (slot for slot, city in enumerate(slotted) if str(city.get("id")) == city_id)
    return next(matches, None)
|
|
|
|
|
|
def _project_directional(
|
|
action: dict[str, Any],
|
|
coord_key: str,
|
|
sub_range: range,
|
|
unit_base: int,
|
|
idx_to_action: dict[int, dict[str, Any]],
|
|
) -> int | None:
|
|
"""For a goal-oriented move/attack, pick the legal directional slot
|
|
whose destination is closest (squared Euclidean) to the suggested
|
|
target. The scripted controller emits multi-hex goal moves; the
|
|
player-API legal set is the six adjacent hexes — projecting onto the
|
|
legal step toward the goal is the right single-step BC label."""
|
|
goal = action.get(coord_key)
|
|
if not isinstance(goal, (list, tuple)) or len(goal) < 2:
|
|
return None
|
|
gx, gy = float(goal[0]), float(goal[1])
|
|
best_idx: int | None = None
|
|
best_d: float | None = None
|
|
for sub in sub_range:
|
|
offset = unit_base + sub
|
|
a = idx_to_action.get(offset)
|
|
if a is None:
|
|
continue
|
|
dest = a.get(coord_key)
|
|
if not isinstance(dest, (list, tuple)) or len(dest) < 2:
|
|
continue
|
|
d = (float(dest[0]) - gx) ** 2 + (float(dest[1]) - gy) ** 2
|
|
if best_d is None or d < best_d:
|
|
best_d, best_idx = d, offset
|
|
return best_idx
|
|
|
|
|
|
def _resolve(
    action: dict[str, Any],
    view: dict[str, Any],
    idx_to_action: dict[int, dict[str, Any]],
    drops: _DropStats,
) -> int | None:
    """Resolve a suggested expert action to its legal encoder index, or
    None, counting the outcome in `drops` by category.

    The `suggest` chain is flat — its actions carry explicit `unit_id`
    (or `unit`) / `city_id` values. The encoder's index space is
    positional — identity is the slot. So the explicit id is mapped to the
    encoder slot and the index computed arithmetically (mirroring
    `encoders.encode_legal_actions`) rather than by matching JSON shapes,
    which differ between the two.

    Index layout assumed here:

    * `0` / `1` — `end_turn` / `noop`;
    * `2 + unit_slot * PER_UNIT_ACTIONS + sub` — per-unit actions, with
      `sub` from `_UNIT_SUB` or the move/attack sub-ranges;
    * `2 + MAX_UNITS * PER_UNIT_ACTIONS + city_slot * CITY_QUEUE_DIM + i`
      — queueing `CITY_QUEUE_ITEMS[i]` in a city.

    A computed index is returned only if it is a key of `idx_to_action`
    (legal this step). An action carrying no unit/city id is dropped rather
    than stringified to "None", which could match a slotted entity that
    itself lacks an id.
    """
    n_global = 2  # end_turn (0) and noop (1) precede the per-unit blocks
    t = action.get("type")

    if t in ("end_turn", "noop"):
        idx = 0 if t == "end_turn" else 1
        if idx in idx_to_action:
            drops.nonmove_matched += 1
            return idx
        drops.nonmove_unmatched += 1
        return None

    if t == "queue_production":
        item = str(action.get("item", ""))
        city_id = action.get("city_id")
        cslot = None if city_id is None else _city_slot(view, str(city_id))
        if item not in CITY_QUEUE_ITEMS or cslot is None:
            drops.nonmove_unmatched += 1
            return None
        offset = (n_global + MAX_UNITS * PER_UNIT_ACTIONS
                  + cslot * CITY_QUEUE_DIM + CITY_QUEUE_ITEMS.index(item))
        if offset in idx_to_action:
            drops.nonmove_matched += 1
            return offset
        drops.nonmove_unmatched += 1
        return None

    # Unit-scoped: move / attack / skip / fortify / sentry / found_city /
    # unfortify.
    is_directional = t in ("move", "attack")
    unit_id = action.get("unit_id", action.get("unit"))
    uslot = None if unit_id is None else _unit_slot(view, str(unit_id))
    if uslot is None:
        if is_directional:
            drops.move_unresolved += 1
        else:
            drops.nonmove_unmatched += 1
        return None
    base = n_global + uslot * PER_UNIT_ACTIONS

    if is_directional:
        coord_key = "to" if t == "move" else "target"
        sub_range = _MOVE_SUB_RANGE if t == "move" else _ATTACK_SUB_RANGE
        idx = _project_directional(action, coord_key, sub_range, base, idx_to_action)
        if idx is None:
            drops.move_unresolved += 1
        else:
            drops.move_projected += 1
        return idx

    sub = _UNIT_SUB.get(t)
    if sub is None:
        drops.nonmove_unmatched += 1
        return None
    offset = base + sub
    if offset in idx_to_action:
        drops.nonmove_matched += 1
        return offset
    drops.nonmove_unmatched += 1
    return None
|
|
|
|
|
|
def _record_slot_turn(
    client: HarnessClient,
    slot: int,
    game_id: int,
    writer: _ShardWriter,
    drops: _DropStats,
) -> tuple[bool, int | None]:
    """Drive one slot's full turn and record the expert triples it yields.

    Requests the scripted chain via `suggest`. For each suggested action it
    re-`view`s, resolves the action to a legal encoder index, records the
    `(obs, idx, mask)` triple and applies it. Unresolvable suggestions are
    skipped (and counted in `drops` by `_resolve`). Only actions drawn from
    the engine's own legal set are applied, so `act` is not expected to
    reject.

    If the chain itself ends the turn, the rest of the chain is not applied
    and no second `end_turn` is sent: the slot no longer has the move.
    Otherwise the turn is closed explicitly afterwards (when `end_turn` is
    legal) and that boundary is recorded too.

    Returns `(game_ended, winner)`; `winner` may be None even when the
    game ended (no explicit winner in the events).
    """
    chain = client.suggest(slot=slot)

    for action in chain:
        view = client.view(slot=slot)
        obs = encode_observation(view)
        mask, idx_to_action = encode_legal_actions(view)
        idx = _resolve(action, view, idx_to_action, drops)
        if idx is None:
            continue  # not representable in the 322-index space — skip
        writer.add(obs, idx, mask, slot, game_id)
        resolved = idx_to_action[idx]
        chain_ends_turn = resolved.get("type") == "end_turn"
        try:
            if chain_ends_turn:
                response = client.end_turn(slot=slot)
            else:
                response = client.act(resolved, slot=slot)
        except HarnessError as exc:
            # `resolved` came straight from the engine's legal set, so a
            # rejection here is a genuine protocol surprise, not stale
            # expert intent — surface it and close the turn.
            print(f" game {game_id} slot {slot}: legal action rejected "
                  f"({exc})", file=sys.stderr, flush=True)
            break
        ended, winner = _terminal_winner(_drain_events(client, response))
        if ended:
            return True, winner
        if chain_ends_turn:
            # The turn is already over: applying the remaining chain or
            # sending another end_turn would act out of turn.
            return False, None

    # Scripted chains do not normally carry `end_turn` (Stage 6.1.6
    # diagnostic) — close the turn explicitly and record the boundary so BC
    # learns it. The `duel-v3-s7` failure (30/30 episodes stuck at turn 12)
    # was a policy that never learned to end its turn; the dataset must not
    # share that blind spot.
    view = client.view(slot=slot)
    mask, idx_to_action = encode_legal_actions(view)
    if mask[0]:  # end_turn is index 0
        writer.add(encode_observation(view), 0, mask, slot, game_id)
    response = client.end_turn(slot=slot)
    ended, winner = _terminal_winner(_drain_events(client, response))
    return (True, winner) if ended else (False, None)
|
|
|
|
|
|
def _play_game(
    game_id: int,
    cfg: HarnessConfig,
    max_turns: int,
    writer: _ShardWriter,
    drops: _DropStats,
) -> int | None:
    """Play and record one scripted-vs-scripted duel on a fresh harness.

    Each round, every configured player slot takes one full recorded turn.
    Play stops when a turn reports a terminal event, when the engine's
    `turn` counter (read from the first slot's view) reaches `max_turns`,
    or after `max_turns` rounds. The client is always shut down.

    Returns the winner slot, or None if the game hit the turn cap
    undecided. When a two-slot game ends without an explicit winner, the
    slot that did not just act is credited.
    """
    player_slots = cfg.effective_player_slots
    client = HarnessClient(cfg)
    try:
        for _ in range(max_turns):
            engine_turn = int(client.view(slot=player_slots[0]).get("turn", 0))
            if engine_turn >= max_turns:
                return None
            for acting in player_slots:
                ended, reported = _record_slot_turn(client, acting, game_id, writer, drops)
                if not ended:
                    continue
                if reported is not None:
                    return reported
                if len(player_slots) == 2:
                    # Elimination with no explicit winner — the surviving
                    # slot took it.
                    return player_slots[1] if acting == player_slots[0] else player_slots[0]
                return None
        return None
    finally:
        client.shutdown()
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
    """Build the recorder's command-line parser.

    Options, in registration (and help) order: `--games`, `--out-dir`,
    `--shard-games`, `--max-turns`, `--seed`, `--map-size`.
    """
    parser = argparse.ArgumentParser(description="Record scripted-AI expert games for BC")
    options: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--games", {"type": int, "default": 1000,
                     "help": "Number of duels to record (default: 1000)."}),
        ("--out-dir", {"default": str(THIS_DIR / "expert" / "duel"),
                       "help": "Directory for the sharded .npz output."}),
        ("--shard-games", {"type": int, "default": 100,
                           "help": "Games per .npz shard (default: 100)."}),
        ("--max-turns", {"type": int, "default": 600,
                         "help": "Per-game turn cap (default: 600)."}),
        ("--seed", {"type": int, "default": 20000,
                    "help": "Base seed; game N uses seed+N (default: 20000)."}),
        ("--map-size", {"default": "duel",
                        "help": "MapGenerator size key (default: duel)."}),
    )
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser
|
|
|
|
|
|
def main() -> int:
    """Record `--games` scripted-vs-scripted duels and print a summary.

    Game N is seeded with `--seed + N`. Triples are sharded by
    `_ShardWriter`. Progress is printed after the first game and every 10
    games, then one `RECORD_SUMMARY {json}` line. A warning goes to stderr
    when more than 30% of suggested expert actions were dropped.

    Buffered games are flushed even if the run is interrupted, so completed
    games of the current shard are not lost. Rows of a game cut off
    mid-play keep outcome -1 (undecided).

    Returns the process exit status (0).
    """
    args = _build_argparser().parse_args()
    out_dir = Path(args.out_dir)
    writer = _ShardWriter(out_dir, args.shard_games)
    drops = _DropStats()
    wins = {0: 0, 1: 0}
    undecided = 0
    # Monotonic clock: elapsed/rate figures must not jump with wall-clock
    # adjustments during a long run.
    t0 = time.monotonic()

    print(f"recording {args.games} expert games → {out_dir}", flush=True)
    try:
        for game_id in range(args.games):
            cfg = HarnessConfig(
                seed=args.seed + game_id,
                players=2,
                player_slots=(0, 1),
                map_size=args.map_size,
                map_type="continents",
                victory_mode="domination",
            )
            winner = _play_game(game_id, cfg, args.max_turns, writer, drops)
            writer.finish_game(game_id, winner)
            if winner is None:
                undecided += 1
            else:
                wins[winner] = wins.get(winner, 0) + 1
            if (game_id + 1) % 10 == 0 or game_id == 0:
                elapsed = time.monotonic() - t0
                # Guard against a zero interval on a coarse clock.
                rate = (game_id + 1) / elapsed if elapsed > 0 else float("inf")
                print(f" game {game_id + 1}/{args.games} "
                      f"({rate:.2f} games/s, drop_rate={drops.drop_rate:.1%})",
                      flush=True)
    finally:
        writer.flush()

    summary = {
        "games": args.games,
        "triples": drops.recorded,
        "suggested_actions": drops.suggested,
        "dropped_actions": drops.dropped,
        "drop_rate": round(drops.drop_rate, 4),
        "move_projected": drops.move_projected,
        "nonmove_matched": drops.nonmove_matched,
        "move_unresolved": drops.move_unresolved,
        "nonmove_unmatched": drops.nonmove_unmatched,
        "wins_slot0": wins.get(0, 0),
        "wins_slot1": wins.get(1, 0),
        "undecided": undecided,
        "elapsed_sec": round(time.monotonic() - t0, 1),
    }
    print("RECORD_SUMMARY " + json.dumps(summary), flush=True)
    if drops.drop_rate > 0.30:
        print(f"WARNING: drop_rate {drops.drop_rate:.1%} exceeds 30% — the "
              f"322-index action space is missing a large slice of expert "
              f"play; BC will train on a biased subset.", file=sys.stderr,
              flush=True)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
|