refactor(rl-self-play): ♻️ Optimize ONNX export script for RL self-play model (p1_29f) to improve compatibility and performance
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
7c33434676
commit
55935afbd2
1 changed files with 215 additions and 0 deletions
215
tooling/rl_self_play/_export_onnx_p1_29f.py
Normal file
215
tooling/rl_self_play/_export_onnx_p1_29f.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"""Export the trained MaskablePPO policy net (obs -> action logits) to ONNX,
|
||||
and capture parity fixtures for the in-engine (Rust) reimplementation.
|
||||
|
||||
p1-29f bullet 1: produce a runtime-loadable policy artifact the engine can
|
||||
evaluate without a Python runtime. We export ONLY the action-logit head
|
||||
(features_extractor -> policy_net -> action_net); the value head is unused at
|
||||
inference. Masking + argmax/softmax are applied on the Rust side, exactly as
|
||||
MaskablePPO.predict(deterministic=True) does on the Python side.
|
||||
|
||||
Outputs (paths are CLI args; build script places them):
|
||||
- <onnx_out> ONNX graph: input "obs" [1,32] f32 -> output "logits" [1,322] f32
|
||||
- <fixtures_out> JSON: list of {view, obs, mask, logits, masked_logits,
|
||||
argmax, topk} captured from REAL PlayerViews driven
|
||||
through the live harness, for Rust parity tests.
|
||||
|
||||
Run on apricot (RL env + system torch/sb3/onnx present):
|
||||
python3 export_onnx.py --model <best_model.zip> --onnx-out <f.onnx> \
|
||||
--fixtures-out <f.json> [--fixture-seeds 1 2 3] [--turns 25]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from sb3_contrib import MaskablePPO
|
||||
|
||||
THIS_DIR = Path(__file__).resolve().parent
|
||||
# encoders / harness_client live in the rl_self_play package on apricot.
|
||||
PROJECT_ROOT = Path("/var/home/lilith/Code/@projects/@magic-civilization")
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from tooling.rl_self_play.encoders import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
OBS_DIM,
|
||||
encode_legal_actions,
|
||||
encode_observation,
|
||||
)
|
||||
from tooling.rl_self_play.harness_client import ( # noqa: E402
|
||||
HarnessClient,
|
||||
HarnessConfig,
|
||||
)
|
||||
|
||||
|
||||
class LogitHead(nn.Module):
|
||||
"""obs -> action logits, mirroring MaskableActorCriticPolicy's actor path.
|
||||
|
||||
SB3 forward (actor): features = features_extractor(obs);
|
||||
latent_pi = mlp_extractor.forward_actor(features);
|
||||
logits = action_net(latent_pi).
|
||||
"""
|
||||
|
||||
def __init__(self, policy) -> None:
|
||||
super().__init__()
|
||||
self.features_extractor = policy.features_extractor
|
||||
self.mlp_extractor = policy.mlp_extractor
|
||||
self.action_net = policy.action_net
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
features = self.features_extractor(obs)
|
||||
latent_pi = self.mlp_extractor.forward_actor(features)
|
||||
return self.action_net(latent_pi)
|
||||
|
||||
|
||||
def export(model_path: str, onnx_out: str) -> LogitHead:
|
||||
model = MaskablePPO.load(model_path, device="cpu")
|
||||
policy = model.policy.to("cpu").eval()
|
||||
head = LogitHead(policy).eval()
|
||||
|
||||
# Parity self-check. The raw `action_net` logits we export differ from
|
||||
# `policy.get_distribution().logits` by a UNIFORM per-row offset (SB3's
|
||||
# MaskableCategorical normalises its stored logits). A uniform offset is
|
||||
# invariant under both argmax and softmax, so we verify argmax + softmax
|
||||
# equivalence rather than raw-logit equality.
|
||||
def _softmax_t(x: torch.Tensor) -> torch.Tensor:
|
||||
z = x - x.max(dim=-1, keepdim=True).values
|
||||
e = torch.exp(z)
|
||||
return e / e.sum(dim=-1, keepdim=True)
|
||||
|
||||
torch.manual_seed(0)
|
||||
with torch.no_grad():
|
||||
max_err = 0.0
|
||||
for _ in range(8):
|
||||
probe = torch.randn((1, OBS_DIM), dtype=torch.float32) * 5.0
|
||||
ref_logits = policy.get_distribution(probe).distribution.logits.detach()
|
||||
head_logits = head(probe)
|
||||
assert int(head_logits.argmax()) == int(ref_logits.argmax()), (
|
||||
"LogitHead argmax diverges from policy"
|
||||
)
|
||||
sm_err = float((_softmax_t(head_logits) - _softmax_t(ref_logits)).abs().max())
|
||||
max_err = max(max_err, sm_err)
|
||||
assert max_err < 1e-5, f"LogitHead softmax diverges from policy: {max_err}"
|
||||
|
||||
dummy = torch.zeros((1, OBS_DIM), dtype=torch.float32)
|
||||
Path(onnx_out).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.onnx.export(
|
||||
head,
|
||||
dummy,
|
||||
onnx_out,
|
||||
input_names=["obs"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes=None, # fixed batch=1 — engine evaluates one obs at a time
|
||||
opset_version=13,
|
||||
do_constant_folding=True,
|
||||
dynamo=False, # legacy TorchScript exporter — avoids onnxscript dep (torch 2.9)
|
||||
)
|
||||
print(f"[export] wrote {onnx_out} (head self-check max_err={max_err:.2e})")
|
||||
return model, head
|
||||
|
||||
|
||||
def _masked_logits(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
||||
out = logits.astype(np.float64).copy()
|
||||
out[~mask] = -1e30 # MaskablePPO uses -inf; -1e30 is exact under argmax/softmax
|
||||
return out
|
||||
|
||||
|
||||
def _softmax(x: np.ndarray) -> np.ndarray:
|
||||
z = x - np.max(x)
|
||||
e = np.exp(z)
|
||||
return e / np.sum(e)
|
||||
|
||||
|
||||
def capture_fixtures(model, head: LogitHead, seeds, turns: int) -> list:
|
||||
fixtures = []
|
||||
for seed in seeds:
|
||||
cfg = HarnessConfig(
|
||||
seed=seed, players=2, player_slots=(0, 1),
|
||||
map_size="duel", map_type="continents", victory_mode="domination",
|
||||
)
|
||||
client = HarnessClient(cfg)
|
||||
try:
|
||||
for _ in range(turns):
|
||||
view = client.view(slot=0)
|
||||
mask, idx_to_action = encode_legal_actions(view)
|
||||
if not mask.any():
|
||||
break
|
||||
obs = encode_observation(view)
|
||||
with torch.no_grad():
|
||||
logits = head(torch.from_numpy(obs).unsqueeze(0)).numpy()[0]
|
||||
masked = _masked_logits(logits, mask)
|
||||
argmax = int(np.argmax(masked))
|
||||
probs = _softmax(masked)
|
||||
topk = [int(i) for i in np.argsort(-masked)[:5] if mask[int(i)]]
|
||||
# SB3 reference predict (deterministic) — must equal our argmax.
|
||||
ref_action, _ = model.predict(obs, action_masks=mask, deterministic=True)
|
||||
fixtures.append({
|
||||
"seed": seed,
|
||||
"turn": int(view.get("turn", 0)),
|
||||
"view": view,
|
||||
"obs": [float(x) for x in obs],
|
||||
"mask": [bool(b) for b in mask],
|
||||
"logits": [float(x) for x in logits],
|
||||
"masked_logits": [float(x) for x in masked],
|
||||
"probs": [float(x) for x in probs],
|
||||
"argmax": argmax,
|
||||
"topk": topk,
|
||||
"ref_predict": int(ref_action),
|
||||
})
|
||||
assert argmax == int(ref_action), (
|
||||
f"seed {seed} turn {view.get('turn')}: our argmax {argmax} "
|
||||
f"!= SB3 predict {int(ref_action)}"
|
||||
)
|
||||
# advance both slots via scripted suggest to progress the game
|
||||
for s in (0, 1):
|
||||
_advance_slot(client, s)
|
||||
finally:
|
||||
client.shutdown()
|
||||
print(f"[fixtures] captured {len(fixtures)} fixtures (ACTION_DIM={ACTION_DIM})")
|
||||
return fixtures
|
||||
|
||||
|
||||
def _advance_slot(client: HarnessClient, slot: int) -> None:
|
||||
try:
|
||||
for a in client.suggest(slot=slot):
|
||||
v = client.view(slot=slot)
|
||||
t = a.get("type")
|
||||
try:
|
||||
if t == "end_turn":
|
||||
client.end_turn(slot=slot)
|
||||
else:
|
||||
client.act(a, slot=slot)
|
||||
except Exception: # noqa: BLE001
|
||||
break
|
||||
client.end_turn(slot=slot)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
client.drain_notifications()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--onnx-out", required=True)
|
||||
ap.add_argument("--fixtures-out", required=True)
|
||||
ap.add_argument("--fixture-seeds", type=int, nargs="+", default=[1, 2, 3])
|
||||
ap.add_argument("--turns", type=int, default=25)
|
||||
args = ap.parse_args()
|
||||
|
||||
model, head = export(args.model, args.onnx_out)
|
||||
fixtures = capture_fixtures(model, head, args.fixture_seeds, args.turns)
|
||||
Path(args.fixtures_out).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(args.fixtures_out, "w") as f:
|
||||
json.dump({"action_dim": ACTION_DIM, "obs_dim": OBS_DIM, "fixtures": fixtures}, f)
|
||||
print(f"[fixtures] wrote {args.fixtures_out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Reference in a new issue