162 lines
5.9 KiB
Python
162 lines
5.9 KiB
Python
"""Train a MaskablePPO policy against the harness's built-in AI.

Usage (run from the project root):

    pip install -r tooling/rl-self-play/requirements.txt
    python -m tooling.rl-self-play.train --total-steps 1_000_000

Monitor live training curves with TensorBoard:

    tensorboard --logdir tooling/rl-self-play/runs/

The training loop:

1. K parallel `MagicCivEnv` instances are spawned (set K with
   `--num-envs`). Each instance owns its own Godot harness subprocess.
   Rule of thumb: K = min(physical cores // 2, 8).
2. MaskablePPO collects on-policy rollouts across all K envs and trains
   for `--total-steps` environment steps.
3. Every `--eval-freq` steps we run a held-out evaluation against the
   same baseline and record the win rate. When the win rate exceeds
   `--target-win-rate` (default 0.55), we save the model as
   `models/<run-name>/winner.zip` and stop.

This script is intentionally minimal: no curriculum, no self-play
against frozen snapshots, no league. Those are reasonable extensions
once the basic policy starts winning at all (which itself will take
hours on apricot).
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
THIS_DIR = Path(__file__).resolve().parent
# This file lives at <project root>/tooling/<tool dir>/train.py.
PROJECT_ROOT = THIS_DIR.parents[1]

# Make `tooling.*` importable whether the script is invoked as a module
# (`python -m tooling.rl-self-play.train`, the canonical way) or as a plain
# script (`python train.py`, handy for quick iteration without
# re-installing). `__package__` is None for a plain script and "" when the
# file is imported as a top-level module. Both cases need the project
# root on sys.path; a non-empty package means it is already reachable.
# NOTE(review): the tool directory is spelled `rl-self-play` in the -m
# command and the docstring paths, but imported as `rl_self_play` below.
# Confirm which name (or alias) actually exists on disk.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))

from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
|
|
p = argparse.ArgumentParser(description="Train MaskablePPO on Magic Civilization")
|
|
p.add_argument("--total-steps", type=int, default=1_000_000,
|
|
help="Total environment steps (default: 1M).")
|
|
p.add_argument("--num-envs", type=int, default=4,
|
|
help="Parallel envs; each spawns its own harness (default: 4).")
|
|
p.add_argument("--max-turns", type=int, default=200,
|
|
help="Per-episode turn limit before truncation (default: 200).")
|
|
p.add_argument("--map-size", default="duel",
|
|
help="MapGenerator size key (default: duel).")
|
|
p.add_argument("--players", type=int, default=2,
|
|
help="Total player slots in each game (default: 2).")
|
|
p.add_argument("--eval-freq", type=int, default=20_000,
|
|
help="Run eval every N steps (default: 20k).")
|
|
p.add_argument("--eval-episodes", type=int, default=20,
|
|
help="Episodes per eval (default: 20).")
|
|
p.add_argument("--target-win-rate", type=float, default=0.55,
|
|
help="Stop training once eval win-rate exceeds this (default: 0.55).")
|
|
p.add_argument("--run-name", default="duel-v1",
|
|
help="Subdirectory under runs/ + models/ (default: duel-v1).")
|
|
p.add_argument("--seed", type=int, default=42,
|
|
help="Base RNG seed; per-env seeds offset from this (default: 42).")
|
|
return p
|
|
|
|
|
|
def _make_env_factory(args: argparse.Namespace, env_idx: int):
    """Return a zero-argument callable that constructs one `MagicCivEnv`.

    sb3's vectorised-env wrappers take constructors rather than env
    instances. Each env, and the harness it owns, is therefore created by
    whoever invokes the callable (for `SubprocVecEnv`, its worker
    process). The `MagicCivEnv` import itself runs here, when the factory
    is requested.

    Args:
        args: Parsed CLI namespace. `seed`, `players`, `map_size` and
            `max_turns` are read when the returned callable runs.
        env_idx: Offset added to `args.seed` so each env gets its own seed.
    """
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    def build_env() -> MagicCivEnv:
        # Every env controls slot 0 on a "continents" map; only the seed
        # differs between envs.
        harness_kwargs = dict(
            seed=args.seed + env_idx,
            players=args.players,
            player_slot=0,
            map_size=args.map_size,
            map_type="continents",
        )
        return MagicCivEnv(
            harness_config=HarnessConfig(**harness_kwargs),
            max_turns=args.max_turns,
        )

    return build_env
|
|
|
|
|
|
def main() -> int:
    """Train MaskablePPO on the harness and save the resulting policy.

    Parses CLI flags, builds `--num-envs` training envs plus one eval env,
    and trains for `--total-steps` environment steps. Models are written
    under `models/<run-name>/`:
      - the eval callback keeps the best-mean-reward model there;
      - `winner.zip` is saved (and training stops) once an evaluation's
        win rate exceeds `--target-win-rate`;
      - `final.zip` is saved after training ends.

    Returns:
        Process exit status: 0 on normal completion. Invalid flags exit
        via `parser.error`. Exceptions raised during training propagate
        after every env has been closed.
    """
    import contextlib

    parser = _build_argparser()
    args = parser.parse_args()
    # Fail fast on values that would crash later (`eval_freq // num_envs`)
    # or make the evaluation meaningless.
    if args.num_envs < 1:
        parser.error("--num-envs must be >= 1")
    if args.eval_episodes < 1:
        parser.error("--eval-episodes must be >= 1")
    if not 0.0 <= args.target_win_rate <= 1.0:
        parser.error("--target-win-rate must be between 0 and 1")

    # Lazy imports: sb3 + torch are heavy and only needed once we
    # commit to running. This keeps `--help` fast.
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from sb3_contrib.common.maskable.callbacks import (  # type: ignore[import-not-found]
        MaskableEvalCallback,
    )
    from stable_baselines3.common.callbacks import (  # type: ignore[import-not-found]
        BaseCallback,
    )
    from stable_baselines3.common.vec_env import (  # type: ignore[import-not-found]
        DummyVecEnv,
        SubprocVecEnv,
    )

    class _StopOnWinRate(BaseCallback):
        """Stop training once the latest evaluation's win rate beats a target.

        Installed as the eval callback's `callback_after_eval`, so it runs
        once per evaluation with `self.parent` set to that eval callback.
        The win rate is the fraction of eval episodes whose final `info`
        carried a truthy `is_success`.
        NOTE(review): this assumes MagicCivEnv reports `is_success` on
        terminal steps; confirm. If it never does, no win rate is
        available and training simply runs to `--total-steps`.
        """

        def __init__(self, target: float, save_path: Path) -> None:
            super().__init__(verbose=0)
            self.target = target
            self.save_path = save_path
            self._warned_missing = False

        def _on_step(self) -> bool:
            # sb3 resets `_is_success_buffer` before each evaluation and
            # fills it from episode-final `info["is_success"]` values. It
            # is a private attribute, hence the defensive getattr.
            outcomes = list(getattr(self.parent, "_is_success_buffer", None) or [])
            if not outcomes:
                if not self._warned_missing:
                    print("warning: eval episodes report no `is_success`; "
                          "--target-win-rate early stop is disabled")
                    self._warned_missing = True
                return True
            win_rate = sum(bool(o) for o in outcomes) / len(outcomes)
            print(f"eval win-rate: {win_rate:.3f} (target > {self.target:.3f})")
            if win_rate > self.target:
                self.model.save(str(self.save_path))
                print(f"target win-rate reached; model saved to {self.save_path}")
                return False  # tells model.learn() to stop
            return True

    run_dir = THIS_DIR / "runs" / args.run_name
    model_dir = THIS_DIR / "models" / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    factories = [_make_env_factory(args, i) for i in range(args.num_envs)]
    # SubprocVecEnv runs each env in its own process. This is necessary
    # because each env owns a Godot subprocess, and we don't want one
    # harness's JSON-Lines pipe to block sibling envs. DummyVecEnv is the
    # single-process fallback for debugging.
    env_cls = SubprocVecEnv if args.num_envs > 1 else DummyVecEnv

    # ExitStack closes every env that was successfully created, even if a
    # later constructor (eval env, callback, model) or training raises,
    # and even if one of the close() calls itself raises.
    with contextlib.ExitStack() as cleanup:
        train_env = env_cls(factories)
        cleanup.callback(train_env.close)
        eval_env = DummyVecEnv([_make_env_factory(args, 1000)])
        cleanup.callback(eval_env.close)

        eval_callback = MaskableEvalCallback(
            eval_env,
            callback_after_eval=_StopOnWinRate(
                args.target_win_rate, model_dir / "winner.zip"
            ),
            best_model_save_path=str(model_dir),
            log_path=str(run_dir / "eval"),
            # eval_freq counts callback calls, i.e. vectorised steps; each
            # call advances num_envs environment steps.
            eval_freq=max(args.eval_freq // args.num_envs, 1),
            n_eval_episodes=args.eval_episodes,
            deterministic=True,
            render=False,
        )

        model = MaskablePPO(
            "MlpPolicy",
            train_env,
            verbose=1,
            tensorboard_log=str(run_dir),
            seed=args.seed,
            n_steps=512,
            batch_size=128,
            learning_rate=3e-4,
            gamma=0.995,
            gae_lambda=0.95,
            ent_coef=0.01,
        )

        model.learn(
            total_timesteps=args.total_steps,
            callback=eval_callback,
            progress_bar=True,
        )

    model.save(str(model_dir / "final.zip"))
    print(f"training complete; model saved to {model_dir / 'final.zip'}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())
|