242 lines
9.1 KiB
Python
242 lines
9.1 KiB
Python
"""Mine a trained MaskablePPO policy for heuristic-patch candidates.
|
|
|
|
Drives a scripted-vs-scripted duel (both slots controlled by the
|
|
built-in `suggest()` chain — the SAME trajectory p1-29c/29d measure
|
|
against) and, at the start of every turn, probes the trained policy on
|
|
the identical PlayerView WITHOUT applying its choice. Logs where the
|
|
policy's preferred action diverges from what the scripted AI actually
|
|
does, bucketed by board state (sole-city? trailing?).
|
|
|
|
The divergence we can extract is constrained by the policy's action
|
|
space (encoders.py): there is NO research action — only end_turn,
|
|
per-unit micro (skip/fortify/sentry/found_city/move/attack) and per-city
|
|
`queue_production` over a 16-item building roster. So the high-resolution
|
|
signal is build-queue composition + unit posture (settle / fortify /
|
|
attack), nothing else.
|
|
|
|
Usage:
|
|
python -m tooling.rl_self_play.mine_divergence \
|
|
--model-path tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \
|
|
--seeds 1,2,3,4,5,6,7,8,9,10 --turns 60 \
|
|
--out tooling/rl_self_play/runs/duel-v4-encfix-s7/divergence.json
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]
# The absolute `tooling.rl_self_play.*` imports below need PROJECT_ROOT on
# sys.path. Under `python -m tooling.rl_self_play...` __package__ is the
# dotted package name and the path is already right. It is falsy in two
# other cases: None when run as a plain script, and "" when imported as a
# top-level module. Both need the insert, so test truthiness rather than
# `is None`.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from tooling.rl_self_play.encoders import ( # noqa: E402
|
|
decode_action_index,
|
|
encode_legal_actions,
|
|
encode_observation,
|
|
)
|
|
from tooling.rl_self_play.harness_client import ( # noqa: E402
|
|
HarnessClient,
|
|
HarnessConfig,
|
|
)
|
|
from tooling.rl_self_play.record_expert import ( # noqa: E402
|
|
_DropStats,
|
|
_resolve,
|
|
)
|
|
|
|
|
|
def _my_units(view: dict[str, Any]) -> list[dict[str, Any]]:
|
|
me = int(view.get("player", 0))
|
|
return [u for u in view.get("units", []) if int(u.get("owner", -1)) == me]
|
|
|
|
|
|
def _my_cities(view: dict[str, Any]) -> list[dict[str, Any]]:
|
|
me = int(view.get("player", 0))
|
|
return [c for c in view.get("cities", []) if int(c.get("owner", -1)) == me]
|
|
|
|
|
|
def _action_label(a: dict[str, Any]) -> str:
|
|
"""Compact label: type, plus the queued item for queue_production."""
|
|
t = a.get("type", "?")
|
|
if t == "queue_production":
|
|
return f"build:{a.get('item','?')}"
|
|
return t
|
|
|
|
|
|
def _policy_top_action(model, view: dict[str, Any]) -> dict[str, Any]:
    """Query *model* for its deterministic, mask-respecting choice on *view*.

    The chosen action is only returned, never applied. When the legal-action
    mask has no True entry, the model is not consulted and a bare end_turn
    action is returned instead.
    """
    legal_mask, index_map = encode_legal_actions(view)
    if legal_mask.any():
        observation = encode_observation(view)
        chosen, _state = model.predict(observation, action_masks=legal_mask,
                                       deterministic=True)
        return decode_action_index(int(chosen), index_map)
    return {"type": "end_turn"}
|
|
|
|
|
|
def _scripted_first_meaningful(client: HarnessClient, slot: int,
                               view: dict[str, Any]) -> dict[str, Any]:
    """Return the first scripted suggestion for *slot* that is legal in *view*.

    Bare end_turn suggestions are skipped, so the comparison reflects what the
    AI does this turn rather than the turn boundary. Each remaining
    suggestion is resolved against *view*'s legal-action table, and the first
    one that resolves is returned in its encoded (legal-table) form. If none
    resolve, a bare end_turn action is returned.
    """
    suggestions = client.suggest(slot=slot)
    _mask, index_map = encode_legal_actions(view)
    for suggestion in suggestions:
        if suggestion.get("type") != "end_turn":
            resolved = _resolve(suggestion, view, index_map, _DropStats())
            if resolved is not None:
                return index_map[resolved]
    return {"type": "end_turn"}
|
|
|
|
|
|
def _advance_slot(client: HarnessClient, slot: int) -> None:
    """Play one scripted turn for *slot*: apply its suggest() chain, then end the turn.

    This mirrors the diag_suggest advance logic. Each suggestion is resolved
    against a fresh view, because earlier actions change the legal set.
    Suggestions that do not resolve to a legal index are skipped. The advance
    is best-effort: the first harness error stops applying the chain.

    The turn is ended exactly once. If the chain itself contains an end_turn,
    the chain stops there, since anything after it would be applied in the
    following turn, and no second end_turn is sent afterwards.
    """
    turn_ended = False
    for suggestion in client.suggest(slot=slot):
        view = client.view(slot=slot)
        _, idx_to_action = encode_legal_actions(view)
        idx = _resolve(suggestion, view, idx_to_action, _DropStats())
        if idx is None:
            continue
        action = idx_to_action[idx]
        try:
            if action.get("type") == "end_turn":
                client.end_turn(slot=slot)
                turn_ended = True
                break
            client.act(action, slot=slot)
        except Exception:  # noqa: BLE001 best-effort advance
            break
    if not turn_ended:
        try:
            client.end_turn(slot=slot)
        except Exception:  # noqa: BLE001 best-effort; failure here is tolerated
            pass
    client.drain_notifications()
|
|
|
|
|
|
def _bucket(view: dict[str, Any]) -> str:
    """Classify the probed slot's position as ``"<phase>/<cities>/<military>"``.

    phase:    ``early`` for turn <= 20, ``mid`` for turn <= 50, else ``late``.
    cities:   ``solo`` for at most one owned city (zero included), else ``multi``.
    military: ``weak`` for at most two owned units whose type string lacks
              ``"founder"``, else ``armed``. Every non-founder unit counts.
    """
    city_count = len(_my_cities(view))
    fighters = [u for u in _my_units(view)
                if "founder" not in str(u.get("type", ""))]
    turn = int(view.get("turn", 0))

    if turn <= 20:
        phase = "early"
    elif turn <= 50:
        phase = "mid"
    else:
        phase = "late"
    city_state = "multi" if city_count > 1 else "solo"
    mil_state = "armed" if len(fighters) > 2 else "weak"

    return f"{phase}/{city_state}/{mil_state}"
|
|
|
|
|
|
def _probe_row(model, client: HarnessClient, slot: int,
               view: dict[str, Any]) -> dict[str, Any]:
    """Record the policy's and the scripted AI's choice for *slot* on *view* (nothing applied)."""
    policy_label = _action_label(_policy_top_action(model, view))
    scripted_label = _action_label(_scripted_first_meaningful(client, slot, view))
    # A missing/None score block or estimate is recorded as 0.0 instead of crashing.
    score = (view.get("score") or {}).get("score_estimate") or 0.0
    return {
        "turn": int(view.get("turn", 0)),
        "bucket": _bucket(view),
        "ncities": len(_my_cities(view)),
        "nunits": len(_my_units(view)),
        "score": float(score),
        "policy": policy_label,
        "scripted": scripted_label,
        "agree": policy_label == scripted_label,
    }


def _is_eliminated(view: dict[str, Any]) -> bool:
    """True when the viewing player owns no city and no founder unit."""
    founders = [u for u in _my_units(view)
                if "founder" in str(u.get("type", ""))]
    return not _my_cities(view) and not founders


def _eliminated_row(view: dict[str, Any]) -> dict[str, Any]:
    """Terminal marker row; aggregate() skips it when counting decisions."""
    return {"turn": int(view.get("turn", 0)),
            "bucket": "ELIMINATED", "ncities": 0,
            "nunits": len(_my_units(view)), "score": 0.0,
            "policy": "-", "scripted": "-", "agree": True}


def mine_seed(model, seed: int, turns: int, probe_slot: int) -> dict[str, Any]:
    """Replay one scripted-vs-scripted duel and probe *model* along the way.

    Each turn, both slots are advanced on the scripted path in slot order.
    Immediately before the probed slot acts, the policy and the scripted
    chain are both queried on that slot's current view. That is the same
    state the slot's scripted turn is then played from, so "scripted" is
    what the AI actually does next rather than a guess on a stale pre-turn
    state.

    Play stops early in two cases:
    - the probed slot has neither legal actions nor units;
    - it has lost every city and founder, which is recorded as an
      ``ELIMINATED`` row.

    Args:
        model: loaded policy exposing
            ``predict(obs, action_masks=..., deterministic=...)``.
        seed: harness game seed.
        turns: maximum number of turns to play.
        probe_slot: duel slot to probe (0 or 1).

    Returns:
        ``{"seed": ..., "probe_slot": ..., "rows": [...]}`` with one row per probe.

    Raises:
        ValueError: if *probe_slot* is not one of the duel slots.
    """
    slots = (0, 1)
    if probe_slot not in slots:
        raise ValueError(f"probe_slot must be one of {slots}, got {probe_slot}")
    cfg = HarnessConfig(
        seed=seed, players=2, player_slots=slots,
        map_size="duel", map_type="continents", victory_mode="domination",
    )
    client = HarnessClient(cfg)
    rows: list[dict[str, Any]] = []
    try:
        for _ in range(turns):
            stalled = False
            # Advance BOTH slots on the scripted path, probing right before
            # the probed slot plays.
            for s in slots:
                if s == probe_slot:
                    view = client.view(slot=s)
                    if not view.get("legal_actions") and not view.get("units"):
                        stalled = True
                        break
                    rows.append(_probe_row(model, client, s, view))
                _advance_slot(client, s)
            if stalled:
                break
            after = client.view(slot=probe_slot)
            if _is_eliminated(after):
                rows.append(_eliminated_row(after))
                break
    finally:
        client.shutdown()
    return {"seed": seed, "probe_slot": probe_slot, "rows": rows}
|
|
|
|
|
|
def aggregate(results: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize per-seed probe rows into agreement and divergence statistics.

    ELIMINATED rows are excluded from the decision counts and reported only
    through ``eliminations`` / ``elim_turns``. For every row where the policy
    and the scripted AI disagree, the labels chosen by each side are tallied
    per board-state bucket. The (scripted, policy) label pairs are also
    tallied overall; the top 20 pairs and the top 6 labels per bucket side
    are reported.
    """
    per_bucket: dict[str, dict[str, Counter]] = {}
    pairs: Counter = Counter()
    decisions = 0
    agreements = 0
    eliminated_at: list[int] = []

    all_rows = (row for res in results for row in res["rows"])
    for row in all_rows:
        bucket = row["bucket"]
        if bucket == "ELIMINATED":
            eliminated_at.append(row["turn"])
            continue
        decisions += 1
        if row["agree"]:
            agreements += 1
            continue
        if bucket not in per_bucket:
            per_bucket[bucket] = {"policy": Counter(), "scripted": Counter()}
        per_bucket[bucket]["policy"][row["policy"]] += 1
        per_bucket[bucket]["scripted"][row["scripted"]] += 1
        pairs[(row["scripted"], row["policy"])] += 1

    top_pairs = [
        {"scripted": scripted, "policy": policy, "count": count}
        for (scripted, policy), count in pairs.most_common(20)
    ]
    bucket_summary: dict[str, dict[str, dict[str, int]]] = {}
    for name in sorted(per_bucket):
        tallies = per_bucket[name]
        bucket_summary[name] = {
            "policy_prefers": dict(tallies["policy"].most_common(6)),
            "scripted_prefers": dict(tallies["scripted"].most_common(6)),
        }

    return {
        "decisions": decisions,
        "agreements": agreements,
        "agree_rate": round(agreements / max(decisions, 1), 3),
        "eliminations": len(eliminated_at),
        "elim_turns": sorted(eliminated_at),
        "top_divergences (scripted -> policy)": top_pairs,
        "by_bucket": bucket_summary,
    }
|
|
|
|
|
|
def main() -> int:
    """CLI entry point: mine every requested seed, print and optionally save the summary.

    The aggregate is printed to stdout as JSON, and progress goes to stderr.
    With ``--out``, the aggregate plus per-seed rows is also written to that
    path; missing parent directories are created.

    Returns:
        Process exit status (0 on success).
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--model-path", required=True, type=Path)
    p.add_argument("--seeds", default="1,2,3,4,5,6,7,8,9,10")
    p.add_argument("--turns", type=int, default=60)
    p.add_argument("--probe-slot", type=int, default=1, choices=(0, 1),
                   help="Slot to probe (1 = the trailing seat in p1-29d).")
    p.add_argument("--out", type=Path, default=None)
    args = p.parse_args()

    # Validate seeds before the (slow) model load so a typo fails fast with
    # a usage message instead of a traceback.
    try:
        seeds = [int(s) for s in args.seeds.split(",") if s.strip()]
    except ValueError:
        p.error(f"--seeds must be a comma-separated list of integers, got {args.seeds!r}")

    # Imported after argument parsing, so --help and usage errors do not
    # require sb3_contrib to be installed.
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    model = MaskablePPO.load(str(args.model_path))

    results = []
    for seed in seeds:
        print(f"[mine] seed {seed} probe_slot {args.probe_slot} ...",
              file=sys.stderr, flush=True)
        results.append(mine_seed(model, seed, args.turns, args.probe_slot))
    agg = aggregate(results)
    out = {"model": str(args.model_path), "seeds": seeds,
           "turns": args.turns, "probe_slot": args.probe_slot,
           "aggregate": agg, "per_seed": results}
    print(json.dumps(agg, indent=2, default=str))
    if args.out:
        # The documented --out target (runs/<model>/divergence.json) usually
        # does not exist yet; create it instead of failing after all the mining.
        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text(json.dumps(out, indent=2, default=str), encoding="utf-8")
        print(f"\n[mine] wrote {args.out}", file=sys.stderr)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
|