diff --git a/tooling/rl_self_play/_export_onnx_p1_29f.py b/tooling/rl_self_play/_export_onnx_p1_29f.py new file mode 100644 index 00000000..97a3edd7 --- /dev/null +++ b/tooling/rl_self_play/_export_onnx_p1_29f.py @@ -0,0 +1,215 @@ +"""Export the trained MaskablePPO policy net (obs -> action logits) to ONNX, +and capture parity fixtures for the in-engine (Rust) reimplementation. + +p1-29f bullet 1: produce a runtime-loadable policy artifact the engine can +evaluate without a Python runtime. We export ONLY the action-logit head +(features_extractor -> policy_net -> action_net); the value head is unused at +inference. Masking + argmax/softmax are applied on the Rust side, exactly as +MaskablePPO.predict(deterministic=True) does on the Python side. + +Outputs (paths are CLI args; build script places them): + - ONNX graph: input "obs" [1,32] f32 -> output "logits" [1,322] f32 + - JSON: list of {view, obs, mask, logits, masked_logits, + argmax, topk} captured from REAL PlayerViews driven + through the live harness, for Rust parity tests. + +Run on apricot (RL env + system torch/sb3/onnx present): + python3 export_onnx.py --model --onnx-out \ + --fixtures-out [--fixture-seeds 1 2 3] [--turns 25] +""" +from __future__ import annotations + +import argparse +import json +import sys +import warnings +from pathlib import Path + +warnings.filterwarnings("ignore") + +import numpy as np +import torch +import torch.nn as nn +from sb3_contrib import MaskablePPO + +THIS_DIR = Path(__file__).resolve().parent +# encoders / harness_client live in the rl_self_play package on apricot. +PROJECT_ROOT = Path("/var/home/lilith/Code/@projects/@magic-civilization") +sys.path.insert(0, str(PROJECT_ROOT)) + +from tooling.rl_self_play.encoders import ( # noqa: E402 + ACTION_DIM, + OBS_DIM, + encode_legal_actions, + encode_observation, +) +from tooling.rl_self_play.harness_client import ( # noqa: E402 + HarnessClient, + HarnessConfig, +) + + +class LogitHead(nn.Module): + """obs -> action logits, mirroring MaskableActorCriticPolicy's actor path. + + SB3 forward (actor): features = features_extractor(obs); + latent_pi = mlp_extractor.forward_actor(features); + logits = action_net(latent_pi). + """ + + def __init__(self, policy) -> None: + super().__init__() + self.features_extractor = policy.features_extractor + self.mlp_extractor = policy.mlp_extractor + self.action_net = policy.action_net + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + features = self.features_extractor(obs) + latent_pi = self.mlp_extractor.forward_actor(features) + return self.action_net(latent_pi) + + +def export(model_path: str, onnx_out: str) -> LogitHead: + model = MaskablePPO.load(model_path, device="cpu") + policy = model.policy.to("cpu").eval() + head = LogitHead(policy).eval() + + # Parity self-check. The raw `action_net` logits we export differ from + # `policy.get_distribution().logits` by a UNIFORM per-row offset (SB3's + # MaskableCategorical normalises its stored logits). A uniform offset is + # invariant under both argmax and softmax, so we verify argmax + softmax + # equivalence rather than raw-logit equality. + def _softmax_t(x: torch.Tensor) -> torch.Tensor: + z = x - x.max(dim=-1, keepdim=True).values + e = torch.exp(z) + return e / e.sum(dim=-1, keepdim=True) + + torch.manual_seed(0) + with torch.no_grad(): + max_err = 0.0 + for _ in range(8): + probe = torch.randn((1, OBS_DIM), dtype=torch.float32) * 5.0 + ref_logits = policy.get_distribution(probe).distribution.logits.detach() + head_logits = head(probe) + assert int(head_logits.argmax()) == int(ref_logits.argmax()), ( + "LogitHead argmax diverges from policy" + ) + sm_err = float((_softmax_t(head_logits) - _softmax_t(ref_logits)).abs().max()) + max_err = max(max_err, sm_err) + assert max_err < 1e-5, f"LogitHead softmax diverges from policy: {max_err}" + + dummy = torch.zeros((1, OBS_DIM), dtype=torch.float32) + Path(onnx_out).parent.mkdir(parents=True, exist_ok=True) + torch.onnx.export( + head, + dummy, + onnx_out, + input_names=["obs"], + output_names=["logits"], + dynamic_axes=None, # fixed batch=1 — engine evaluates one obs at a time + opset_version=13, + do_constant_folding=True, + dynamo=False, # legacy TorchScript exporter — avoids onnxscript dep (torch 2.9) + ) + print(f"[export] wrote {onnx_out} (head self-check max_err={max_err:.2e})") + return model, head + + +def _masked_logits(logits: np.ndarray, mask: np.ndarray) -> np.ndarray: + out = logits.astype(np.float64).copy() + out[~mask] = -1e30 # MaskablePPO uses -inf; -1e30 is exact under argmax/softmax + return out + + +def _softmax(x: np.ndarray) -> np.ndarray: + z = x - np.max(x) + e = np.exp(z) + return e / np.sum(e) + + +def capture_fixtures(model, head: LogitHead, seeds, turns: int) -> list: + fixtures = [] + for seed in seeds: + cfg = HarnessConfig( + seed=seed, players=2, player_slots=(0, 1), + map_size="duel", map_type="continents", victory_mode="domination", + ) + client = HarnessClient(cfg) + try: + for _ in range(turns): + view = client.view(slot=0) + mask, idx_to_action = encode_legal_actions(view) + if not mask.any(): + break + obs = encode_observation(view) + with torch.no_grad(): + logits = head(torch.from_numpy(obs).unsqueeze(0)).numpy()[0] + masked = _masked_logits(logits, mask) + argmax = int(np.argmax(masked)) + probs = _softmax(masked) + topk = [int(i) for i in np.argsort(-masked)[:5] if mask[int(i)]] + # SB3 reference predict (deterministic) — must equal our argmax. + ref_action, _ = model.predict(obs, action_masks=mask, deterministic=True) + fixtures.append({ + "seed": seed, + "turn": int(view.get("turn", 0)), + "view": view, + "obs": [float(x) for x in obs], + "mask": [bool(b) for b in mask], + "logits": [float(x) for x in logits], + "masked_logits": [float(x) for x in masked], + "probs": [float(x) for x in probs], + "argmax": argmax, + "topk": topk, + "ref_predict": int(ref_action), + }) + assert argmax == int(ref_action), ( + f"seed {seed} turn {view.get('turn')}: our argmax {argmax} " + f"!= SB3 predict {int(ref_action)}" + ) + # advance both slots via scripted suggest to progress the game + for s in (0, 1): + _advance_slot(client, s) + finally: + client.shutdown() + print(f"[fixtures] captured {len(fixtures)} fixtures (ACTION_DIM={ACTION_DIM})") + return fixtures + + +def _advance_slot(client: HarnessClient, slot: int) -> None: + try: + for a in client.suggest(slot=slot): + v = client.view(slot=slot) + t = a.get("type") + try: + if t == "end_turn": + client.end_turn(slot=slot) + else: + client.act(a, slot=slot) + except Exception: # noqa: BLE001 + break + client.end_turn(slot=slot) + except Exception: # noqa: BLE001 + pass + client.drain_notifications() + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--model", required=True) + ap.add_argument("--onnx-out", required=True) + ap.add_argument("--fixtures-out", required=True) + ap.add_argument("--fixture-seeds", type=int, nargs="+", default=[1, 2, 3]) + ap.add_argument("--turns", type=int, default=25) + args = ap.parse_args() + + model, head = export(args.model, args.onnx_out) + fixtures = capture_fixtures(model, head, args.fixture_seeds, args.turns) + Path(args.fixtures_out).parent.mkdir(parents=True, exist_ok=True) + with open(args.fixtures_out, "w") as f: + json.dump({"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": fixtures}, f) + print(f"[fixtures] wrote {args.fixtures_out}") + + +if __name__ == "__main__": + main()