feat(rl-self-play): ✨ Add learned opponent policy evaluation options to RL self-play evaluation script
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
bb15503079
commit
e2e578cdab
1 changed files with 28 additions and 1 deletions
|
|
@ -41,6 +41,16 @@ def _build_argparser() -> argparse.ArgumentParser:
|
|||
help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
|
||||
p.add_argument("--players", type=int, default=2)
|
||||
p.add_argument("--map-size", default="duel")
|
||||
p.add_argument("--opponent-model", default=None, type=Path,
|
||||
help=("Frozen MaskablePPO snapshot (.zip) to use as the "
|
||||
"opponent instead of the harness MCTS. Set to the "
|
||||
"graduated snapshot to measure win-rate against a "
|
||||
"learned policy (self-play curriculum gate)."))
|
||||
p.add_argument("--opponent-slots", default="1",
|
||||
help="Comma-separated opponent slot indices (default '1').")
|
||||
p.add_argument("--opponent-device", default="cpu")
|
||||
p.add_argument("--opponent-deterministic", action="store_true",
|
||||
help="Argmax opponent actions (default: stochastic sampling).")
|
||||
return p
|
||||
|
||||
|
||||
|
|
@ -75,9 +85,24 @@ def main() -> int:
|
|||
from sb3_contrib import MaskablePPO # type: ignore[import-not-found]
|
||||
from tooling.rl_self_play.harness_client import HarnessConfig # type: ignore[import-not-found]
|
||||
from tooling.rl_self_play.magic_civ_env import MagicCivEnv # type: ignore[import-not-found]
|
||||
from tooling.rl_self_play.opponent import ModelOpponent # type: ignore[import-not-found]
|
||||
|
||||
model = MaskablePPO.load(str(args.model_path))
|
||||
|
||||
opp_slots: tuple[int, ...] = tuple(
|
||||
int(s) for s in str(args.opponent_slots).split(",") if s.strip()
|
||||
)
|
||||
|
||||
def _make_opponent() -> ModelOpponent | None:
|
||||
if not args.opponent_model:
|
||||
return None
|
||||
return ModelOpponent(
|
||||
model_path=str(args.opponent_model),
|
||||
slots=opp_slots,
|
||||
device=args.opponent_device,
|
||||
deterministic=args.opponent_deterministic,
|
||||
)
|
||||
|
||||
wins = losses = draws = step_caps = turn_caps = 0
|
||||
turns_per_episode: list[int] = []
|
||||
for episode in range(args.episodes):
|
||||
|
|
@ -87,7 +112,9 @@ def main() -> int:
|
|||
player_slot=0,
|
||||
map_size=args.map_size,
|
||||
)
|
||||
env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns)
|
||||
env = MagicCivEnv(
|
||||
harness_config=cfg, max_turns=args.max_turns, opponent=_make_opponent()
|
||||
)
|
||||
try:
|
||||
obs, info = env.reset()
|
||||
done = False
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue