diff --git a/tooling/rl_self_play/mine_divergence.py b/tooling/rl_self_play/mine_divergence.py new file mode 100644 index 00000000..ca4bd0f3 --- /dev/null +++ b/tooling/rl_self_play/mine_divergence.py @@ -0,0 +1,242 @@ +"""Mine a trained MaskablePPO policy for heuristic-patch candidates. + +Drives a scripted-vs-scripted duel (both slots controlled by the +built-in `suggest()` chain — the SAME trajectory p1-29c/29d measure +against) and, at the start of every turn, probes the trained policy on +the identical PlayerView WITHOUT applying its choice. Logs where the +policy's preferred action diverges from what the scripted AI actually +does, bucketed by board state (sole-city? trailing?). + +The divergence we can extract is constrained by the policy's action +space (encoders.py): there is NO research action — only end_turn, +per-unit micro (skip/fortify/sentry/found_city/move/attack) and per-city +`queue_production` over a 16-item building roster. So the high-resolution +signal is build-queue composition + unit posture (settle / fortify / +attack), nothing else. + +Usage: + python -m tooling.rl_self_play.mine_divergence \ + --model-path tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \ + --seeds 1,2,3,4,5,6,7,8,9,10 --turns 60 \ + --out tooling/rl_self_play/runs/duel-v4-encfix-s7/divergence.json +""" +from __future__ import annotations + +import argparse +import json +import sys +from collections import Counter, defaultdict +from pathlib import Path +from typing import Any + +THIS_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = THIS_DIR.parents[1] +if __package__ is None: + sys.path.insert(0, str(PROJECT_ROOT)) + +from tooling.rl_self_play.encoders import ( # noqa: E402 + decode_action_index, + encode_legal_actions, + encode_observation, +) +from tooling.rl_self_play.harness_client import ( # noqa: E402 + HarnessClient, + HarnessConfig, +) +from tooling.rl_self_play.record_expert import ( # noqa: E402 + _DropStats, + _resolve, +) + + +def _my_units(view: dict[str, Any]) -> list[dict[str, Any]]: + me = int(view.get("player", 0)) + return [u for u in view.get("units", []) if int(u.get("owner", -1)) == me] + + +def _my_cities(view: dict[str, Any]) -> list[dict[str, Any]]: + me = int(view.get("player", 0)) + return [c for c in view.get("cities", []) if int(c.get("owner", -1)) == me] + + +def _action_label(a: dict[str, Any]) -> str: + """Compact label: type, plus the queued item for queue_production.""" + t = a.get("type", "?") + if t == "queue_production": + return f"build:{a.get('item','?')}" + return t + + +def _policy_top_action(model, view: dict[str, Any]) -> dict[str, Any]: + mask, idx_to_action = encode_legal_actions(view) + if not mask.any(): + return {"type": "end_turn"} + obs = encode_observation(view) + action, _ = model.predict(obs, action_masks=mask, deterministic=True) + return decode_action_index(int(action), idx_to_action) + + +def _scripted_first_meaningful(client: HarnessClient, slot: int, + view: dict[str, Any]) -> dict[str, Any]: + """First scripted action that resolves to a legal index and is not a + bare end_turn (so we compare 'what the AI does this turn', not the + turn boundary).""" + chain = client.suggest(slot=slot) + _, idx_to_action = encode_legal_actions(view) + for a in chain: + if a.get("type") == "end_turn": + continue + idx = _resolve(a, view, idx_to_action, _DropStats()) + if idx is not None: + return idx_to_action[idx] + return {"type": "end_turn"} + + +def _advance_slot(client: HarnessClient, slot: int) -> None: + """Apply the scripted chain for one slot, then end its turn — the + diag_suggest advance logic.""" + for a in client.suggest(slot=slot): + v = client.view(slot=slot) + _, i2a = encode_legal_actions(v) + idx = _resolve(a, v, i2a, _DropStats()) + if idx is None: + continue + r = i2a[idx] + try: + if r.get("type") == "end_turn": + client.end_turn(slot=slot) + else: + client.act(r, slot=slot) + except Exception: # noqa: BLE001 best-effort advance + break + try: + client.end_turn(slot=slot) + except Exception: # noqa: BLE001 + pass + client.drain_notifications() + + +def _bucket(view: dict[str, Any]) -> str: + """Coarse board-state bucket for the probed slot.""" + ncities = len(_my_cities(view)) + units = _my_units(view) + nmil = sum(1 for u in units if "founder" not in str(u.get("type", ""))) + turn = int(view.get("turn", 0)) + phase = "early" if turn <= 20 else ("mid" if turn <= 50 else "late") + city_state = "solo" if ncities <= 1 else "multi" + mil_state = "weak" if nmil <= 2 else "armed" + return f"{phase}/{city_state}/{mil_state}" + + +def mine_seed(model, seed: int, turns: int, probe_slot: int) -> dict[str, Any]: + cfg = HarnessConfig( + seed=seed, players=2, player_slots=(0, 1), + map_size="duel", map_type="continents", victory_mode="domination", + ) + client = HarnessClient(cfg) + rows: list[dict[str, Any]] = [] + try: + for _ in range(turns): + view = client.view(slot=probe_slot) + if not view.get("legal_actions") and not view.get("units"): + break + policy_a = _policy_top_action(model, view) + scripted_a = _scripted_first_meaningful(client, probe_slot, view) + rows.append({ + "turn": int(view.get("turn", 0)), + "bucket": _bucket(view), + "ncities": len(_my_cities(view)), + "nunits": len(_my_units(view)), + "score": float(view.get("score", {}).get("score_estimate", 0.0)), + "policy": _action_label(policy_a), + "scripted": _action_label(scripted_a), + "agree": _action_label(policy_a) == _action_label(scripted_a), + }) + # Advance BOTH slots on the scripted path. + for s in (0, 1): + _advance_slot(client, s) + # Stop if probed slot lost all cities and has no founder. + v2 = client.view(slot=probe_slot) + cs = _my_cities(v2) + founders = [u for u in _my_units(v2) + if "founder" in str(u.get("type", ""))] + if not cs and not founders: + rows.append({"turn": int(v2.get("turn", 0)), + "bucket": "ELIMINATED", "ncities": 0, + "nunits": len(_my_units(v2)), "score": 0.0, + "policy": "-", "scripted": "-", "agree": True}) + break + finally: + client.shutdown() + return {"seed": seed, "probe_slot": probe_slot, "rows": rows} + + +def aggregate(results: list[dict[str, Any]]) -> dict[str, Any]: + # Divergences bucketed: when policy != scripted, what did each pick? + by_bucket: dict[str, dict[str, Counter]] = defaultdict( + lambda: {"policy": Counter(), "scripted": Counter()}) + pair_counts: Counter = Counter() + total = agree = 0 + elim_turns: list[int] = [] + for res in results: + for r in res["rows"]: + if r["bucket"] == "ELIMINATED": + elim_turns.append(r["turn"]) + continue + total += 1 + if r["agree"]: + agree += 1 + continue + by_bucket[r["bucket"]]["policy"][r["policy"]] += 1 + by_bucket[r["bucket"]]["scripted"][r["scripted"]] += 1 + pair_counts[(r["scripted"], r["policy"])] += 1 + return { + "decisions": total, + "agreements": agree, + "agree_rate": round(agree / max(total, 1), 3), + "eliminations": len(elim_turns), + "elim_turns": sorted(elim_turns), + "top_divergences (scripted -> policy)": [ + {"scripted": k[0], "policy": k[1], "count": v} + for k, v in pair_counts.most_common(20) + ], + "by_bucket": { + b: {"policy_prefers": dict(d["policy"].most_common(6)), + "scripted_prefers": dict(d["scripted"].most_common(6))} + for b, d in sorted(by_bucket.items()) + }, + } + + +def main() -> int: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--model-path", required=True, type=Path) + p.add_argument("--seeds", default="1,2,3,4,5,6,7,8,9,10") + p.add_argument("--turns", type=int, default=60) + p.add_argument("--probe-slot", type=int, default=1, + help="Slot to probe (1 = the trailing seat in p1-29d).") + p.add_argument("--out", type=Path, default=None) + args = p.parse_args() + + from sb3_contrib import MaskablePPO # type: ignore[import-not-found] + model = MaskablePPO.load(str(args.model_path)) + + seeds = [int(s) for s in args.seeds.split(",") if s.strip()] + results = [] + for seed in seeds: + print(f"[mine] seed {seed} probe_slot {args.probe_slot} ...", + file=sys.stderr, flush=True) + results.append(mine_seed(model, seed, args.turns, args.probe_slot)) + agg = aggregate(results) + out = {"model": str(args.model_path), "seeds": seeds, + "turns": args.turns, "probe_slot": args.probe_slot, + "aggregate": agg, "per_seed": results} + print(json.dumps(agg, indent=2, default=str)) + if args.out: + args.out.write_text(json.dumps(out, indent=2, default=str)) + print(f"\n[mine] wrote {args.out}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main())