refactor(rl-self-play): ♻️ Optimize ONNX export script for RL self-play model (p1_29f) to improve compatibility and performance

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-06-02 22:59:04 -07:00
parent 7c33434676
commit 55935afbd2

View file

@ -0,0 +1,215 @@
"""Export the trained MaskablePPO policy net (obs -> action logits) to ONNX,
and capture parity fixtures for the in-engine (Rust) reimplementation.
p1-29f bullet 1: produce a runtime-loadable policy artifact the engine can
evaluate without a Python runtime. We export ONLY the action-logit head
(features_extractor -> policy_net -> action_net); the value head is unused at
inference. Masking + argmax/softmax are applied on the Rust side, exactly as
MaskablePPO.predict(deterministic=True) does on the Python side.
Outputs (paths are CLI args; build script places them):
- <onnx_out> ONNX graph: input "obs" [1,32] f32 -> output "logits" [1,322] f32
- <fixtures_out> JSON: list of {view, obs, mask, logits, masked_logits,
argmax, topk} captured from REAL PlayerViews driven
through the live harness, for Rust parity tests.
Run on apricot (RL env + system torch/sb3/onnx present):
python3 export_onnx.py --model <best_model.zip> --onnx-out <f.onnx> \
--fixtures-out <f.json> [--fixture-seeds 1 2 3] [--turns 25]
"""
from __future__ import annotations
import argparse
import json
import sys
import warnings
from pathlib import Path
warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from sb3_contrib import MaskablePPO
THIS_DIR = Path(__file__).resolve().parent
# encoders / harness_client live in the rl_self_play package on apricot.
PROJECT_ROOT = Path("/var/home/lilith/Code/@projects/@magic-civilization")
sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.encoders import ( # noqa: E402
ACTION_DIM,
OBS_DIM,
encode_legal_actions,
encode_observation,
)
from tooling.rl_self_play.harness_client import ( # noqa: E402
HarnessClient,
HarnessConfig,
)
class LogitHead(nn.Module):
"""obs -> action logits, mirroring MaskableActorCriticPolicy's actor path.
SB3 forward (actor): features = features_extractor(obs);
latent_pi = mlp_extractor.forward_actor(features);
logits = action_net(latent_pi).
"""
def __init__(self, policy) -> None:
super().__init__()
self.features_extractor = policy.features_extractor
self.mlp_extractor = policy.mlp_extractor
self.action_net = policy.action_net
def forward(self, obs: torch.Tensor) -> torch.Tensor:
features = self.features_extractor(obs)
latent_pi = self.mlp_extractor.forward_actor(features)
return self.action_net(latent_pi)
def export(model_path: str, onnx_out: str) -> LogitHead:
model = MaskablePPO.load(model_path, device="cpu")
policy = model.policy.to("cpu").eval()
head = LogitHead(policy).eval()
# Parity self-check. The raw `action_net` logits we export differ from
# `policy.get_distribution().logits` by a UNIFORM per-row offset (SB3's
# MaskableCategorical normalises its stored logits). A uniform offset is
# invariant under both argmax and softmax, so we verify argmax + softmax
# equivalence rather than raw-logit equality.
def _softmax_t(x: torch.Tensor) -> torch.Tensor:
z = x - x.max(dim=-1, keepdim=True).values
e = torch.exp(z)
return e / e.sum(dim=-1, keepdim=True)
torch.manual_seed(0)
with torch.no_grad():
max_err = 0.0
for _ in range(8):
probe = torch.randn((1, OBS_DIM), dtype=torch.float32) * 5.0
ref_logits = policy.get_distribution(probe).distribution.logits.detach()
head_logits = head(probe)
assert int(head_logits.argmax()) == int(ref_logits.argmax()), (
"LogitHead argmax diverges from policy"
)
sm_err = float((_softmax_t(head_logits) - _softmax_t(ref_logits)).abs().max())
max_err = max(max_err, sm_err)
assert max_err < 1e-5, f"LogitHead softmax diverges from policy: {max_err}"
dummy = torch.zeros((1, OBS_DIM), dtype=torch.float32)
Path(onnx_out).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(
head,
dummy,
onnx_out,
input_names=["obs"],
output_names=["logits"],
dynamic_axes=None, # fixed batch=1 — engine evaluates one obs at a time
opset_version=13,
do_constant_folding=True,
dynamo=False, # legacy TorchScript exporter — avoids onnxscript dep (torch 2.9)
)
print(f"[export] wrote {onnx_out} (head self-check max_err={max_err:.2e})")
return model, head
def _masked_logits(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
out = logits.astype(np.float64).copy()
out[~mask] = -1e30 # MaskablePPO uses -inf; -1e30 is exact under argmax/softmax
return out
def _softmax(x: np.ndarray) -> np.ndarray:
z = x - np.max(x)
e = np.exp(z)
return e / np.sum(e)
def capture_fixtures(model, head: LogitHead, seeds, turns: int) -> list:
fixtures = []
for seed in seeds:
cfg = HarnessConfig(
seed=seed, players=2, player_slots=(0, 1),
map_size="duel", map_type="continents", victory_mode="domination",
)
client = HarnessClient(cfg)
try:
for _ in range(turns):
view = client.view(slot=0)
mask, idx_to_action = encode_legal_actions(view)
if not mask.any():
break
obs = encode_observation(view)
with torch.no_grad():
logits = head(torch.from_numpy(obs).unsqueeze(0)).numpy()[0]
masked = _masked_logits(logits, mask)
argmax = int(np.argmax(masked))
probs = _softmax(masked)
topk = [int(i) for i in np.argsort(-masked)[:5] if mask[int(i)]]
# SB3 reference predict (deterministic) — must equal our argmax.
ref_action, _ = model.predict(obs, action_masks=mask, deterministic=True)
fixtures.append({
"seed": seed,
"turn": int(view.get("turn", 0)),
"view": view,
"obs": [float(x) for x in obs],
"mask": [bool(b) for b in mask],
"logits": [float(x) for x in logits],
"masked_logits": [float(x) for x in masked],
"probs": [float(x) for x in probs],
"argmax": argmax,
"topk": topk,
"ref_predict": int(ref_action),
})
assert argmax == int(ref_action), (
f"seed {seed} turn {view.get('turn')}: our argmax {argmax} "
f"!= SB3 predict {int(ref_action)}"
)
# advance both slots via scripted suggest to progress the game
for s in (0, 1):
_advance_slot(client, s)
finally:
client.shutdown()
print(f"[fixtures] captured {len(fixtures)} fixtures (ACTION_DIM={ACTION_DIM})")
return fixtures
def _advance_slot(client: HarnessClient, slot: int) -> None:
try:
for a in client.suggest(slot=slot):
v = client.view(slot=slot)
t = a.get("type")
try:
if t == "end_turn":
client.end_turn(slot=slot)
else:
client.act(a, slot=slot)
except Exception: # noqa: BLE001
break
client.end_turn(slot=slot)
except Exception: # noqa: BLE001
pass
client.drain_notifications()
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--model", required=True)
ap.add_argument("--onnx-out", required=True)
ap.add_argument("--fixtures-out", required=True)
ap.add_argument("--fixture-seeds", type=int, nargs="+", default=[1, 2, 3])
ap.add_argument("--turns", type=int, default=25)
args = ap.parse_args()
model, head = export(args.model, args.onnx_out)
fixtures = capture_fixtures(model, head, args.fixture_seeds, args.turns)
Path(args.fixtures_out).parent.mkdir(parents=True, exist_ok=True)
with open(args.fixtures_out, "w") as f:
json.dump({"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": fixtures}, f)
print(f"[fixtures] wrote {args.fixtures_out}")
if __name__ == "__main__":
main()