195 lines
7.6 KiB
Python
195 lines
7.6 KiB
Python
"""Train a MaskablePPO policy against the harness's built-in AI.
|
||
|
||
Usage:
|
||
cd tooling/rl-self-play
|
||
pip install -r requirements.txt
|
||
python -m tooling.rl-self-play.train --total-steps 1_000_000
|
||
|
||
Run via TensorBoard for live curves:
|
||
tensorboard --logdir tooling/rl-self-play/runs/
|
||
|
||
The training loop:
|
||
|
||
1. K parallel `MagicCivEnv` instances are spawned (each owns a Godot
|
||
harness subprocess; rule of thumb: K = min(physical cores // 2, 8)).
|
||
2. MaskablePPO collects on-policy rollouts across all K envs, learns
|
||
for `total_timesteps`.
|
||
3. Every `eval_freq` steps we run a held-out eval against the same
|
||
baseline and record win-rate. When win-rate crosses
|
||
`--target-win-rate` (default 0.55) we save the model as
|
||
`models/winner.zip` and exit.
|
||
|
||
This script is intentionally minimal — no curriculum, no
|
||
self-play-against-frozen-snapshots, no league. Those are reasonable
|
||
extensions once the basic policy starts winning at all (which itself
|
||
will take hours on apricot).
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
|
||
# Directory containing this script.
THIS_DIR = Path(__file__).resolve().parent
# Two levels above THIS_DIR — presumably the repo root that contains
# `tooling/` (it is what makes `tooling.rl_self_play` importable below).
PROJECT_ROOT = THIS_DIR.parents[1]

# Resolve module path so the script works whether invoked as a module
# (`python -m tooling.rl-self-play.train`) or as a plain script
# (`python train.py`). Both paths matter — the former is the canonical
# way; the latter helps quick iteration without re-installing.
# `__package__` is None when this file is run as a plain script (not via
# `-m` or a normal import); only then is PROJECT_ROOT pushed to the front
# of sys.path. This must execute before the project import below.
if __package__ is None:
    sys.path.insert(0, str(PROJECT_ROOT))

from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
|
||
|
||
|
||
def _build_argparser() -> argparse.ArgumentParser:
|
||
p = argparse.ArgumentParser(description="Train MaskablePPO on Magic Civilization")
|
||
p.add_argument("--total-steps", type=int, default=1_000_000,
|
||
help="Total environment steps (default: 1M).")
|
||
p.add_argument("--num-envs", type=int, default=4,
|
||
help="Parallel envs; each spawns its own harness (default: 4).")
|
||
p.add_argument("--max-turns", type=int, default=200,
|
||
help="Per-episode turn limit before truncation (default: 200).")
|
||
p.add_argument("--map-size", default="duel",
|
||
help="MapGenerator size key (default: duel).")
|
||
p.add_argument("--players", type=int, default=2,
|
||
help="Total player slots in each game (default: 2).")
|
||
p.add_argument("--eval-freq", type=int, default=20_000,
|
||
help="Run eval every N steps (default: 20k).")
|
||
p.add_argument("--eval-episodes", type=int, default=20,
|
||
help="Episodes per eval (default: 20).")
|
||
p.add_argument("--target-win-rate", type=float, default=0.55,
|
||
help="Stop training once eval win-rate exceeds this (default: 0.55).")
|
||
p.add_argument("--run-name", default="duel-v1",
|
||
help="Subdirectory under runs/ + models/ (default: duel-v1).")
|
||
p.add_argument("--seed", type=int, default=42,
|
||
help="Base RNG seed; per-env seeds offset from this (default: 42).")
|
||
p.add_argument("--device", default="auto",
|
||
help=("Torch device for the policy net: 'auto' (default — "
|
||
"picks cuda if available, else cpu), 'cuda', "
|
||
"'cuda:1' (second GPU), 'mps' (Apple Silicon), or "
|
||
"'cpu'. On apricot, prefer 'cuda:1' so cuda:0 stays "
|
||
"free for model-boss / MCTS rollouts."))
|
||
return p
|
||
|
||
|
||
def _make_env_factory(args: argparse.Namespace, env_idx: int):
    """Return a zero-argument constructor for the ``env_idx``-th MagicCivEnv.

    sb3 vector envs take factories rather than env instances, so each
    SubprocVecEnv worker builds its own env (and therefore its own harness).
    Env ``env_idx`` is seeded with ``args.seed + env_idx``; ``args`` is read
    when the thunk is called, not when it is built.
    """
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    def _make() -> MagicCivEnv:
        # Player slot 0 on a "continents" map is the agent-controlled seat.
        harness_cfg = HarnessConfig(
            map_type="continents",
            map_size=args.map_size,
            player_slot=0,
            players=args.players,
            seed=args.seed + env_idx,
        )
        return MagicCivEnv(max_turns=args.max_turns, harness_config=harness_cfg)

    return _make
|
||
|
||
|
||
def _resolve_device(requested: str) -> str:
    """Turn the ``--device`` flag into a concrete torch device string.

    ``"auto"`` resolves to ``"cuda"`` when CUDA is available, otherwise to
    ``"mps"`` when this torch build exposes an available MPS backend,
    otherwise to ``"cpu"``. Any other value is returned unchanged.
    """
    import torch  # type: ignore[import-not-found]

    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    # `torch.backends.mps` may be absent on some torch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def main() -> int:
    """Train MaskablePPO as configured by the CLI flags and save the policy.

    Spawns ``--num-envs`` training envs plus one eval env, trains for
    ``--total-steps`` with periodic masked evaluation, then saves
    ``models/<run-name>/final.zip`` next to this script.

    Returns:
        0 on success. An invalid ``--num-envs`` exits through ``argparse``
        (status 2). Exceptions raised while building the envs/model or
        during training propagate after every env created so far has been
        closed.
    """
    parser = _build_argparser()
    args = parser.parse_args()
    # num_envs sizes the VecEnv and divides eval_freq below; reject bad
    # values before any harness subprocess is spawned.
    if args.num_envs < 1:
        parser.error(f"--num-envs must be >= 1 (got {args.num_envs})")

    import contextlib

    # Lazy imports — sb3 + torch are heavy and only needed once we
    # commit to running. Lets `--help` stay fast.
    import torch  # type: ignore[import-not-found]
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from sb3_contrib.common.maskable.callbacks import (  # type: ignore[import-not-found]
        MaskableEvalCallback,
    )
    from stable_baselines3.common.vec_env import (  # type: ignore[import-not-found]
        DummyVecEnv,
        SubprocVecEnv,
        VecMonitor,
    )

    run_dir = THIS_DIR / "runs" / args.run_name
    model_dir = THIS_DIR / "models" / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Resolve `--device` for logging clarity — sb3 accepts 'auto' but we
    # want to print exactly which device the rollouts will land on so a
    # multi-GPU box (apricot has 2× RTX 3090) can be confirmed at a glance.
    resolved_device = _resolve_device(args.device)
    cuda_available = torch.cuda.is_available()
    print(
        f"policy device: {resolved_device} "
        f"(cuda_available={cuda_available}, "
        f"cuda_devices={torch.cuda.device_count() if cuda_available else 0})"
    )

    # NOTE(review): `args.target_win_rate` is not consulted anywhere, so
    # training always runs for the full `--total-steps`; the eval callback
    # only keeps the best checkpoint in `model_dir`.

    factories = [_make_env_factory(args, i) for i in range(args.num_envs)]
    # SubprocVecEnv runs each env in its own process — necessary because
    # each env owns a Godot subprocess (we don't want one harness's
    # JSON-Lines pipe to block sibling envs). DummyVecEnv is the
    # single-process fallback for debugging.
    env_cls = SubprocVecEnv if args.num_envs > 1 else DummyVecEnv

    # Each env owns harness subprocess(es). Registering close() right after
    # construction guarantees cleanup even if a later step (eval env,
    # callback, model construction, training) raises.
    with contextlib.ExitStack() as stack:
        raw_train_env = env_cls(factories)
        stack.callback(raw_train_env.close)
        # VecMonitor attaches per-episode return/length to `info`, which sb3
        # needs to log episode-reward curves and to score eval episodes.
        train_env = VecMonitor(raw_train_env)

        # Seed offset 1000 keeps the eval seed distinct from the training
        # seeds (seed + 0 .. seed + num_envs - 1) while num_envs <= 1000.
        raw_eval_env = DummyVecEnv([_make_env_factory(args, 1000)])
        stack.callback(raw_eval_env.close)
        eval_env = VecMonitor(raw_eval_env)

        eval_callback = MaskableEvalCallback(
            eval_env,
            best_model_save_path=str(model_dir),
            log_path=str(run_dir / "eval"),
            # The callback counts vec-env steps (one per call, covering all
            # envs), so divide to evaluate roughly every --eval-freq total
            # environment steps.
            eval_freq=max(args.eval_freq // args.num_envs, 1),
            n_eval_episodes=args.eval_episodes,
            # Stochastic eval: a barely-trained net's argmax over the
            # 322-dim action head has ~zero chance of being end_turn (idx 0),
            # so deterministic eval episodes never advance past turn 0 and
            # every one of them hits step_cap with reward 0. Sampling from
            # the masked softmax keeps end_turn reachable until the policy
            # has consolidated enough mass on a real strategy.
            deterministic=False,
            render=False,
        )

        model = MaskablePPO(
            "MlpPolicy",
            train_env,
            verbose=1,
            tensorboard_log=str(run_dir),
            seed=args.seed,
            device=resolved_device,
            n_steps=512,
            batch_size=128,
            learning_rate=3e-4,
            gamma=0.995,
            gae_lambda=0.95,
            ent_coef=0.01,
        )
        model.learn(
            total_timesteps=args.total_steps,
            callback=eval_callback,
            progress_bar=True,
        )

    # Saving needs only the policy, not the (now closed) envs.
    final_path = model_dir / "final.zip"
    model.save(str(final_path))
    print(f"training complete; model saved to {final_path}")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
sys.exit(main())
|