130 lines
5 KiB
Python
130 lines
5 KiB
Python
"""Capture model-free encoder-parity fixtures for the mp-v1 Rust↔Python check.
|
|
|
|
Phase 2a (obs normalization) changes the encoder contract — the Rust
|
|
`encode_observation` (`mc-player-api/src/learned/encoder.rs`) and the Python
|
|
`encode_observation` (`encoders.py`) must produce byte-equivalent normalized
|
|
observations. This script records `{view, obs, mask}` from REAL PlayerViews
|
|
driven through the live harness via the scripted `suggest` chain, so the Rust
|
|
parity test (`tests/learned_parity.rs::learned_encoder_parity`) can assert
|
|
equality WITHOUT a trained policy (which does not exist until Phase 2c).
|
|
|
|
This differs from `_export_onnx_p1_29f.py::capture_fixtures`, which needs the
|
|
SB3 model to record logits/argmax. Here we only need the encoder, so the
|
|
games advance purely on the scripted `suggest()` chain (both slots), exactly
|
|
like `_export_onnx_p1_29f.py::_advance_slot`.
|
|
|
|
Run on apricot (needs the harness binary + the rl_self_play package):
|
|
python3 -m tooling.rl_self_play.capture_encoder_fixtures \
|
|
--out src/simulator/crates/mc-player-api/tests/fixtures/learned_mp_v1_encoder_parity.json \
|
|
--seeds 1 2 3 --turns 14 --players 4 --map-size small
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
# Absolute directory holding this module, so paths do not depend on the CWD.
THIS_DIR = Path(__file__).resolve().parents[0]
# Two levels above this package directory (tooling/rl_self_play): the
# directory that contains the top-level ``tooling`` package.
PROJECT_ROOT = THIS_DIR.parents[1]

if __package__ is None:
    # No package context (e.g. launched by file path instead of ``python -m``):
    # put the project root first on sys.path so ``tooling.*`` imports resolve.
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from tooling.rl_self_play.encoders import ( # noqa: E402
|
|
ACTION_DIM,
|
|
OBS_DIM,
|
|
encode_legal_actions,
|
|
encode_observation,
|
|
)
|
|
from tooling.rl_self_play.harness_client import ( # noqa: E402
|
|
HarnessClient,
|
|
HarnessConfig,
|
|
)
|
|
|
|
|
|
def _advance_slot(client: HarnessClient, slot: int) -> None:
    """Drive one slot through a full turn via the scripted suggest chain.

    Actions from ``client.suggest(slot=slot)`` are replayed in order:

    * An ``end_turn`` suggestion ends the turn and stops the replay, since
      nothing after it belongs to this turn.
    * If the chain has no ``end_turn``, or an action is rejected part-way,
      the turn is ended explicitly afterwards.

    On the normal path exactly one ``end_turn`` is therefore issued, instead
    of a second, redundant one after an in-chain ``end_turn``.

    The advance is best-effort. If ``suggest`` or the closing ``end_turn``
    raises, a warning goes to stderr instead of aborting the capture. Pending
    notifications are drained in every case.

    Args:
        client: live harness client driving the game.
        slot: player slot whose turn is advanced.
    """
    turn_ended = False
    try:
        for action in client.suggest(slot=slot):
            try:
                if action.get("type") == "end_turn":
                    client.end_turn(slot=slot)
                    turn_ended = True
                    break
                client.act(action, slot=slot)
            except Exception:  # noqa: BLE001 - a rejected action just stops the replay
                break
        if not turn_ended:
            # The chain ran out, or was cut short, before ending the turn.
            client.end_turn(slot=slot)
    except Exception as exc:  # noqa: BLE001 - best-effort advance, but not silent
        print(
            f"[capture] warning: slot {slot} turn advance failed: {exc!r}",
            file=sys.stderr,
        )
    client.drain_notifications()
|
|
|
|
|
|
# asinh compresses any plausible raw feature magnitude far below this bound
# (asinh(1e6) ~= 14.5), while an un-normalized value such as
# score_estimate=240 exceeds it immediately.
_NORMALIZED_OBS_LIMIT = 20.0


def _check_normalized(obs, seed: int, turn: object) -> None:
    """Raise RuntimeError unless ``obs`` is non-empty, finite and asinh-normalized.

    The checks use explicit ``raise`` rather than ``assert`` so they still fire
    under ``python -O``. That keeps "the committed fixtures are actually
    normalized" a loud failure instead of a silent Rust-parity mismatch
    downstream.

    Non-finite values are rejected on their own. Otherwise they would be
    reported as a misleading magnitude, and they have no strict-JSON
    representation.
    """
    magnitudes = np.abs(np.asarray(obs, dtype=np.float64))
    if magnitudes.size == 0:
        raise RuntimeError(f"encoder returned an empty obs. seed={seed} turn={turn}")
    if not np.all(np.isfinite(magnitudes)):
        raise RuntimeError(
            f"obs contains non-finite values (NaN/inf). seed={seed} turn={turn}"
        )
    peak = float(np.max(magnitudes))
    if peak >= _NORMALIZED_OBS_LIMIT:
        raise RuntimeError(
            f"obs magnitude {peak:.1f} >= {_NORMALIZED_OBS_LIMIT:g} — encoder is NOT "
            f"applying asinh normalization (stale encoders.py?). seed={seed} "
            f"turn={turn}"
        )


def _fixture_record(seed: int, players: int, view: dict, obs, mask) -> dict:
    """Build one fixture entry using plain Python floats/bools so it is JSON-serializable."""
    return {
        "seed": seed,
        "turn": int(view.get("turn", 0)),
        "players": players,
        "view": view,
        "obs": [float(x) for x in obs],
        "mask": [bool(b) for b in mask],
    }


def capture(seeds: list[int], turns: int, players: int, map_size: str) -> list[dict]:
    """Play scripted harness games and record encoder fixtures from slot 0's view.

    For each seed, a fresh ``HarnessClient`` game is started with every slot in
    ``range(players)`` scripted. The client is always shut down, even on error.

    Each of the ``turns`` iterations does the following:

    1. Encodes slot 0's current PlayerView.
    2. Validates that the encoding is normalized.
    3. Records the fixture.
    4. Advances every slot by one turn via ``_advance_slot``, pushing the game
       toward mid-game magnitudes (the distribution the normalization targets).

    Args:
        seeds: harness seeds; one game per seed.
        turns: number of fixtures recorded per game.
        players: number of player slots (all scripted).
        map_size: harness map size passed through to ``HarnessConfig``.

    Returns:
        ``len(seeds) * turns`` dicts with keys ``seed``, ``turn``, ``players``,
        ``view``, ``obs`` (list of float) and ``mask`` (list of bool).

    Raises:
        RuntimeError: if an encoded observation is empty, non-finite, or not
            asinh-normalized.
    """
    slots = tuple(range(players))
    fixtures: list[dict] = []
    for seed in seeds:
        cfg = HarnessConfig(
            seed=seed, players=players, player_slots=slots,
            map_size=map_size, map_type="continents", victory_mode="domination",
        )
        client = HarnessClient(cfg)
        try:
            for _turn_index in range(turns):
                view = client.view(slot=0)
                mask, _legal_actions = encode_legal_actions(view)
                obs = encode_observation(view)
                _check_normalized(obs, seed, view.get("turn"))
                fixtures.append(_fixture_record(seed, players, view, obs, mask))
                for s in slots:
                    _advance_slot(client, s)
        finally:
            # NOTE(review): resolves ``client.shutdown`` indirectly; the reason
            # for the indirection is not visible here, so it is kept as-is.
            getattr(client, "shut" + "down")()
    return fixtures
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments, capture fixtures and write them as one JSON document.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        Process exit status (0 on success). Invalid arguments exit with status
        2 via ``ArgumentParser.error``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path, help="output JSON path")
    ap.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3, 4],
                    help="harness seeds (one game per seed)")
    ap.add_argument("--turns", type=int, default=14,
                    help="fixtures recorded per game")
    ap.add_argument("--players", type=int, default=4,
                    help="number of scripted player slots")
    ap.add_argument("--map-size", default="small", help="harness map size")
    args = ap.parse_args(argv)
    # A zero count would silently write a vacuous fixture file.
    if args.turns < 1:
        ap.error("--turns must be >= 1")
    if args.players < 1:
        ap.error("--players must be >= 1")

    fixtures = capture(args.seeds, args.turns, args.players, args.map_size)
    # Serialize fully before touching the output file, so an error cannot
    # leave a truncated fixture behind. allow_nan=False keeps the document
    # strict JSON (no NaN/Infinity tokens) for the Rust parity test's parser.
    payload = json.dumps(
        {"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": fixtures},
        allow_nan=False,
    )
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(payload, encoding="utf-8")
    print(f"[capture] wrote {len(fixtures)} fixtures to {args.out} "
          f"(ACTION_DIM={ACTION_DIM}, OBS_DIM={OBS_DIM})")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(main())
|