feat(rl-self-play): ✨ Add mine divergence metric for evaluating strategy differences in RL self-play
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
2d7357550e
commit
bb15503079
1 changed files with 242 additions and 0 deletions
242
tooling/rl_self_play/mine_divergence.py
Normal file
242
tooling/rl_self_play/mine_divergence.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Mine a trained MaskablePPO policy for heuristic-patch candidates.
|
||||
|
||||
Drives a scripted-vs-scripted duel (both slots controlled by the
|
||||
built-in `suggest()` chain — the SAME trajectory p1-29c/29d measure
|
||||
against) and, at the start of every turn, probes the trained policy on
|
||||
the identical PlayerView WITHOUT applying its choice. Logs where the
|
||||
policy's preferred action diverges from what the scripted AI actually
|
||||
does, bucketed by board state (sole-city? trailing?).
|
||||
|
||||
The divergence we can extract is constrained by the policy's action
|
||||
space (encoders.py): there is NO research action — only end_turn,
|
||||
per-unit micro (skip/fortify/sentry/found_city/move/attack) and per-city
|
||||
`queue_production` over a 16-item building roster. So the high-resolution
|
||||
signal is build-queue composition + unit posture (settle / fortify /
|
||||
attack), nothing else.
|
||||
|
||||
Usage:
|
||||
python -m tooling.rl_self_play.mine_divergence \
|
||||
--model-path tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \
|
||||
--seeds 1,2,3,4,5,6,7,8,9,10 --turns 60 \
|
||||
--out tooling/rl_self_play/runs/duel-v4-encfix-s7/divergence.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
THIS_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = THIS_DIR.parents[1]
|
||||
if __package__ is None:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from tooling.rl_self_play.encoders import ( # noqa: E402
|
||||
decode_action_index,
|
||||
encode_legal_actions,
|
||||
encode_observation,
|
||||
)
|
||||
from tooling.rl_self_play.harness_client import ( # noqa: E402
|
||||
HarnessClient,
|
||||
HarnessConfig,
|
||||
)
|
||||
from tooling.rl_self_play.record_expert import ( # noqa: E402
|
||||
_DropStats,
|
||||
_resolve,
|
||||
)
|
||||
|
||||
|
||||
def _my_units(view: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
me = int(view.get("player", 0))
|
||||
return [u for u in view.get("units", []) if int(u.get("owner", -1)) == me]
|
||||
|
||||
|
||||
def _my_cities(view: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
me = int(view.get("player", 0))
|
||||
return [c for c in view.get("cities", []) if int(c.get("owner", -1)) == me]
|
||||
|
||||
|
||||
def _action_label(a: dict[str, Any]) -> str:
|
||||
"""Compact label: type, plus the queued item for queue_production."""
|
||||
t = a.get("type", "?")
|
||||
if t == "queue_production":
|
||||
return f"build:{a.get('item','?')}"
|
||||
return t
|
||||
|
||||
|
||||
def _policy_top_action(model, view: dict[str, Any]) -> dict[str, Any]:
|
||||
mask, idx_to_action = encode_legal_actions(view)
|
||||
if not mask.any():
|
||||
return {"type": "end_turn"}
|
||||
obs = encode_observation(view)
|
||||
action, _ = model.predict(obs, action_masks=mask, deterministic=True)
|
||||
return decode_action_index(int(action), idx_to_action)
|
||||
|
||||
|
||||
def _scripted_first_meaningful(client: HarnessClient, slot: int,
|
||||
view: dict[str, Any]) -> dict[str, Any]:
|
||||
"""First scripted action that resolves to a legal index and is not a
|
||||
bare end_turn (so we compare 'what the AI does this turn', not the
|
||||
turn boundary)."""
|
||||
chain = client.suggest(slot=slot)
|
||||
_, idx_to_action = encode_legal_actions(view)
|
||||
for a in chain:
|
||||
if a.get("type") == "end_turn":
|
||||
continue
|
||||
idx = _resolve(a, view, idx_to_action, _DropStats())
|
||||
if idx is not None:
|
||||
return idx_to_action[idx]
|
||||
return {"type": "end_turn"}
|
||||
|
||||
|
||||
def _advance_slot(client: HarnessClient, slot: int) -> None:
|
||||
"""Apply the scripted chain for one slot, then end its turn — the
|
||||
diag_suggest advance logic."""
|
||||
for a in client.suggest(slot=slot):
|
||||
v = client.view(slot=slot)
|
||||
_, i2a = encode_legal_actions(v)
|
||||
idx = _resolve(a, v, i2a, _DropStats())
|
||||
if idx is None:
|
||||
continue
|
||||
r = i2a[idx]
|
||||
try:
|
||||
if r.get("type") == "end_turn":
|
||||
client.end_turn(slot=slot)
|
||||
else:
|
||||
client.act(r, slot=slot)
|
||||
except Exception: # noqa: BLE001 best-effort advance
|
||||
break
|
||||
try:
|
||||
client.end_turn(slot=slot)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
client.drain_notifications()
|
||||
|
||||
|
||||
def _bucket(view: dict[str, Any]) -> str:
|
||||
"""Coarse board-state bucket for the probed slot."""
|
||||
ncities = len(_my_cities(view))
|
||||
units = _my_units(view)
|
||||
nmil = sum(1 for u in units if "founder" not in str(u.get("type", "")))
|
||||
turn = int(view.get("turn", 0))
|
||||
phase = "early" if turn <= 20 else ("mid" if turn <= 50 else "late")
|
||||
city_state = "solo" if ncities <= 1 else "multi"
|
||||
mil_state = "weak" if nmil <= 2 else "armed"
|
||||
return f"{phase}/{city_state}/{mil_state}"
|
||||
|
||||
|
||||
def mine_seed(model, seed: int, turns: int, probe_slot: int) -> dict[str, Any]:
|
||||
cfg = HarnessConfig(
|
||||
seed=seed, players=2, player_slots=(0, 1),
|
||||
map_size="duel", map_type="continents", victory_mode="domination",
|
||||
)
|
||||
client = HarnessClient(cfg)
|
||||
rows: list[dict[str, Any]] = []
|
||||
try:
|
||||
for _ in range(turns):
|
||||
view = client.view(slot=probe_slot)
|
||||
if not view.get("legal_actions") and not view.get("units"):
|
||||
break
|
||||
policy_a = _policy_top_action(model, view)
|
||||
scripted_a = _scripted_first_meaningful(client, probe_slot, view)
|
||||
rows.append({
|
||||
"turn": int(view.get("turn", 0)),
|
||||
"bucket": _bucket(view),
|
||||
"ncities": len(_my_cities(view)),
|
||||
"nunits": len(_my_units(view)),
|
||||
"score": float(view.get("score", {}).get("score_estimate", 0.0)),
|
||||
"policy": _action_label(policy_a),
|
||||
"scripted": _action_label(scripted_a),
|
||||
"agree": _action_label(policy_a) == _action_label(scripted_a),
|
||||
})
|
||||
# Advance BOTH slots on the scripted path.
|
||||
for s in (0, 1):
|
||||
_advance_slot(client, s)
|
||||
# Stop if probed slot lost all cities and has no founder.
|
||||
v2 = client.view(slot=probe_slot)
|
||||
cs = _my_cities(v2)
|
||||
founders = [u for u in _my_units(v2)
|
||||
if "founder" in str(u.get("type", ""))]
|
||||
if not cs and not founders:
|
||||
rows.append({"turn": int(v2.get("turn", 0)),
|
||||
"bucket": "ELIMINATED", "ncities": 0,
|
||||
"nunits": len(_my_units(v2)), "score": 0.0,
|
||||
"policy": "-", "scripted": "-", "agree": True})
|
||||
break
|
||||
finally:
|
||||
client.shutdown()
|
||||
return {"seed": seed, "probe_slot": probe_slot, "rows": rows}
|
||||
|
||||
|
||||
def aggregate(results: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
# Divergences bucketed: when policy != scripted, what did each pick?
|
||||
by_bucket: dict[str, dict[str, Counter]] = defaultdict(
|
||||
lambda: {"policy": Counter(), "scripted": Counter()})
|
||||
pair_counts: Counter = Counter()
|
||||
total = agree = 0
|
||||
elim_turns: list[int] = []
|
||||
for res in results:
|
||||
for r in res["rows"]:
|
||||
if r["bucket"] == "ELIMINATED":
|
||||
elim_turns.append(r["turn"])
|
||||
continue
|
||||
total += 1
|
||||
if r["agree"]:
|
||||
agree += 1
|
||||
continue
|
||||
by_bucket[r["bucket"]]["policy"][r["policy"]] += 1
|
||||
by_bucket[r["bucket"]]["scripted"][r["scripted"]] += 1
|
||||
pair_counts[(r["scripted"], r["policy"])] += 1
|
||||
return {
|
||||
"decisions": total,
|
||||
"agreements": agree,
|
||||
"agree_rate": round(agree / max(total, 1), 3),
|
||||
"eliminations": len(elim_turns),
|
||||
"elim_turns": sorted(elim_turns),
|
||||
"top_divergences (scripted -> policy)": [
|
||||
{"scripted": k[0], "policy": k[1], "count": v}
|
||||
for k, v in pair_counts.most_common(20)
|
||||
],
|
||||
"by_bucket": {
|
||||
b: {"policy_prefers": dict(d["policy"].most_common(6)),
|
||||
"scripted_prefers": dict(d["scripted"].most_common(6))}
|
||||
for b, d in sorted(by_bucket.items())
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--model-path", required=True, type=Path)
|
||||
p.add_argument("--seeds", default="1,2,3,4,5,6,7,8,9,10")
|
||||
p.add_argument("--turns", type=int, default=60)
|
||||
p.add_argument("--probe-slot", type=int, default=1,
|
||||
help="Slot to probe (1 = the trailing seat in p1-29d).")
|
||||
p.add_argument("--out", type=Path, default=None)
|
||||
args = p.parse_args()
|
||||
|
||||
from sb3_contrib import MaskablePPO # type: ignore[import-not-found]
|
||||
model = MaskablePPO.load(str(args.model_path))
|
||||
|
||||
seeds = [int(s) for s in args.seeds.split(",") if s.strip()]
|
||||
results = []
|
||||
for seed in seeds:
|
||||
print(f"[mine] seed {seed} probe_slot {args.probe_slot} ...",
|
||||
file=sys.stderr, flush=True)
|
||||
results.append(mine_seed(model, seed, args.turns, args.probe_slot))
|
||||
agg = aggregate(results)
|
||||
out = {"model": str(args.model_path), "seeds": seeds,
|
||||
"turns": args.turns, "probe_slot": args.probe_slot,
|
||||
"aggregate": agg, "per_seed": results}
|
||||
print(json.dumps(agg, indent=2, default=str))
|
||||
if args.out:
|
||||
args.out.write_text(json.dumps(out, indent=2, default=str))
|
||||
print(f"\n[mine] wrote {args.out}", file=sys.stderr)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Loading…
Add table
Reference in a new issue