magicciv/tooling/rl_self_play/train.py
Natalie b7891991a4 feat(@projects/@magic-civilization): add rl_self_play tooling for self-play training
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-17 03:54:40 -07:00

162 lines
5.9 KiB
Python

"""Train a MaskablePPO policy against the harness's built-in AI.
Usage (run from the project root, i.e. the directory containing `tooling/`):
    pip install -r tooling/rl_self_play/requirements.txt
    python -m tooling.rl_self_play.train --total-steps 1_000_000
Monitor live training curves with TensorBoard:
    tensorboard --logdir tooling/rl_self_play/runs/
The training loop:
1. K parallel `MagicCivEnv` instances are spawned (each owns a Godot
harness subprocess; rule of thumb: K = min(physical cores // 2, 8)).
2. MaskablePPO collects on-policy rollouts across all K envs and trains
for `--total-steps` environment steps.
3. Every `--eval-freq` environment steps we run a held-out eval against
the same baseline and record win-rate. When win-rate exceeds
`--target-win-rate` (default 0.55) we save the model as
`models/<run-name>/winner.zip` and exit.
This script is intentionally minimal — no curriculum, no
self-play-against-frozen-snapshots, no league. Those are reasonable
extensions once the basic policy starts winning at all (which itself
will take hours on apricot).
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import numpy as np
THIS_DIR = Path(__file__).resolve().parent
# Two levels above this file's directory: the directory that contains the
# `tooling` package, which must be on sys.path for the absolute
# `tooling.rl_self_play...` imports in this file to resolve.
PROJECT_ROOT = THIS_DIR.parents[1]
# Resolve module path so the script works whether invoked as a module
# from the project root (`python -m tooling.rl_self_play.train`) or as a
# plain script (`python train.py`). Both paths matter — the former is the
# canonical way; the latter helps quick iteration without re-installing.
# `__package__` is None for a plain script and "" for a top-level module
# run (e.g. `python -m train` from inside this directory); in both cases
# `tooling` is not importable yet, so put PROJECT_ROOT first on sys.path.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Train MaskablePPO on Magic Civilization")
p.add_argument("--total-steps", type=int, default=1_000_000,
help="Total environment steps (default: 1M).")
p.add_argument("--num-envs", type=int, default=4,
help="Parallel envs; each spawns its own harness (default: 4).")
p.add_argument("--max-turns", type=int, default=200,
help="Per-episode turn limit before truncation (default: 200).")
p.add_argument("--map-size", default="duel",
help="MapGenerator size key (default: duel).")
p.add_argument("--players", type=int, default=2,
help="Total player slots in each game (default: 2).")
p.add_argument("--eval-freq", type=int, default=20_000,
help="Run eval every N steps (default: 20k).")
p.add_argument("--eval-episodes", type=int, default=20,
help="Episodes per eval (default: 20).")
p.add_argument("--target-win-rate", type=float, default=0.55,
help="Stop training once eval win-rate exceeds this (default: 0.55).")
p.add_argument("--run-name", default="duel-v1",
help="Subdirectory under runs/ + models/ (default: duel-v1).")
p.add_argument("--seed", type=int, default=42,
help="Base RNG seed; per-env seeds offset from this (default: 42).")
return p
def _make_env_factory(args: argparse.Namespace, env_idx: int):
    """Return a zero-argument constructor for the ``env_idx``-th MagicCivEnv.

    sb3 vec-envs are handed factories rather than env instances; with
    SubprocVecEnv each factory is invoked inside its worker process, so the
    env (and the harness it owns) is built there, not in the parent.

    Args:
        args: Parsed CLI options. ``seed``, ``players``, ``map_size`` and
            ``max_turns`` are read when the returned thunk is called.
        env_idx: Offset added to ``args.seed`` so each env gets its own seed.

    Returns:
        A callable taking no arguments that builds one ``MagicCivEnv``.
    """
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    def build_env() -> MagicCivEnv:
        # player_slot=0 is presumably the learning agent's seat — confirm
        # against HarnessConfig / MagicCivEnv.
        harness_cfg = HarnessConfig(
            map_type="continents",
            map_size=args.map_size,
            player_slot=0,
            players=args.players,
            seed=args.seed + env_idx,
        )
        return MagicCivEnv(max_turns=args.max_turns, harness_config=harness_cfg)

    return build_env
def _make_win_rate_stop_callback(target_win_rate: float, save_path: Path):
    """Build an sb3 callback that ends training once eval win-rate beats a target.

    Intended to be passed as ``callback_after_eval`` to
    ``MaskableEvalCallback``. That callback registers itself as this one's
    ``parent`` and invokes it after every evaluation round.

    Args:
        target_win_rate: Training stops when the fraction of won eval
            episodes is strictly greater than this value, matching the
            ``--target-win-rate`` help text ("exceeds").
        save_path: Where the model is saved before training is stopped.

    Returns:
        An sb3 ``BaseCallback`` instance.
    """
    from stable_baselines3.common.callbacks import (  # type: ignore[import-not-found]
        BaseCallback,
    )

    class _StopOnWinRate(BaseCallback):
        def _on_step(self) -> bool:
            # The eval callback collects `info["is_success"]` from the final
            # step of each eval episode into `_is_success_buffer` (a private
            # sb3 attribute that is reset before every eval round).
            # NOTE(review): assumes MagicCivEnv reports `is_success` meaning
            # "the agent won" at episode end — TODO confirm. Without it the
            # buffer stays empty and training simply runs to completion.
            outcomes = getattr(self.parent, "_is_success_buffer", None)
            if not outcomes:
                return True
            win_rate = float(np.mean(outcomes))
            if win_rate <= target_win_rate:
                return True
            self.model.save(str(save_path))
            print(
                f"eval win-rate {win_rate:.2%} exceeds target {target_win_rate:.2%}; "
                f"saved {save_path} and stopping"
            )
            # Returning False from a callback makes sb3 end learn() early.
            return False

    return _StopOnWinRate()


def main() -> int:
    """Train MaskablePPO on Magic Civilization using the CLI options.

    Builds ``--num-envs`` training envs plus one eval env and trains for
    ``--total-steps`` environment steps, evaluating periodically. Training
    stops early, saving ``winner.zip``, if the eval win-rate exceeds
    ``--target-win-rate``. TensorBoard logs go to ``runs/<run-name>/``;
    the best and final models go to ``models/<run-name>/``. Both paths are
    relative to this file's directory.

    Returns:
        Process exit code (0 on success).
    """
    import contextlib

    args = _build_argparser().parse_args()
    # Lazy imports — sb3 + torch are heavy and only needed once we
    # commit to running. Lets `--help` stay fast.
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from sb3_contrib.common.maskable.callbacks import (  # type: ignore[import-not-found]
        MaskableEvalCallback,
    )
    from stable_baselines3.common.vec_env import (  # type: ignore[import-not-found]
        DummyVecEnv,
        SubprocVecEnv,
    )

    run_dir = THIS_DIR / "runs" / args.run_name
    model_dir = THIS_DIR / "models" / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    factories = [_make_env_factory(args, i) for i in range(args.num_envs)]
    # SubprocVecEnv runs each env in its own process — necessary because
    # each env owns a Godot subprocess (we don't want one harness's
    # JSON-Lines pipe to block sibling envs). DummyVecEnv is the
    # single-process fallback for debugging.
    env_cls = SubprocVecEnv if args.num_envs > 1 else DummyVecEnv

    # Register close() as soon as each vec-env exists, so the harness
    # subprocesses are shut down even if a later constructor (eval env,
    # callback, model) or training itself raises.
    with contextlib.ExitStack() as cleanup:
        train_env = env_cls(factories)
        cleanup.callback(train_env.close)
        # env_idx 1000 keeps the eval env's seed (args.seed + 1000) clear of
        # the training envs' seeds for any --num-envs up to 1000.
        eval_env = DummyVecEnv([_make_env_factory(args, 1000)])
        cleanup.callback(eval_env.close)

        eval_callback = MaskableEvalCallback(
            eval_env,
            callback_after_eval=_make_win_rate_stop_callback(
                args.target_win_rate, model_dir / "winner.zip"
            ),
            best_model_save_path=str(model_dir),
            log_path=str(run_dir / "eval"),
            # sb3 counts eval_freq in callback calls, and each call advances
            # all num_envs envs by one step, so divide to get env steps.
            eval_freq=max(args.eval_freq // args.num_envs, 1),
            n_eval_episodes=args.eval_episodes,
            deterministic=True,
            render=False,
        )
        model = MaskablePPO(
            "MlpPolicy",
            train_env,
            verbose=1,
            tensorboard_log=str(run_dir),
            seed=args.seed,
            n_steps=512,
            batch_size=128,
            learning_rate=3e-4,
            gamma=0.995,
            gae_lambda=0.95,
            ent_coef=0.01,
        )
        model.learn(
            total_timesteps=args.total_steps,
            callback=eval_callback,
            progress_bar=True,
        )

    final_path = model_dir / "final.zip"
    model.save(str(final_path))
    print(f"training complete; model saved to {final_path}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status. This is
    # equivalent to sys.exit(), which just raises SystemExit.
    raise SystemExit(main())