magicciv/tooling/rl_self_play/train.py
2026-05-26 02:21:11 -07:00

195 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Train a MaskablePPO policy against the harness's built-in AI.

Usage (run from the project root, i.e. the directory that contains ``tooling/``):

    pip install -r tooling/rl_self_play/requirements.txt
    python -m tooling.rl_self_play.train --total-steps 1_000_000

For live training curves, point TensorBoard at the runs directory:

    tensorboard --logdir tooling/rl_self_play/runs/

The training loop:

1. K parallel `MagicCivEnv` instances are spawned (each owns a Godot
   harness subprocess; rule of thumb: K = min(physical cores // 2, 8)).
2. MaskablePPO collects on-policy rollouts across all K envs and learns
   for `--total-steps` environment steps.
3. Every `--eval-freq` steps we run a held-out eval against the same
   baseline and record the win-rate. When the win-rate crosses
   `--target-win-rate` (default 0.55) we save the model as
   `models/<run-name>/winner.zip` and exit.

This script is intentionally minimal — no curriculum, no
self-play against frozen snapshots, no league. Those are reasonable
extensions once the basic policy starts winning at all (which itself
will take hours on apricot).
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
import numpy as np
THIS_DIR = Path(__file__).resolve().parent
# Two levels up from this file: the directory that contains `tooling/`.
PROJECT_ROOT = THIS_DIR.parents[1]
# Make `tooling.rl_self_play` importable however the script is launched.
# The canonical `python -m tooling.rl_self_play.train` (run from the project
# root) already has the package on sys.path. A plain `python train.py` leaves
# __package__ as None, and `python -m train` from inside this directory leaves
# it as "". In both of those cases the project root is missing from sys.path.
# Testing `not __package__` (rather than `is None`) covers both.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.harness_client import HarnessConfig # type: ignore[import-not-found]
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Train MaskablePPO on Magic Civilization")
p.add_argument("--total-steps", type=int, default=1_000_000,
help="Total environment steps (default: 1M).")
p.add_argument("--num-envs", type=int, default=4,
help="Parallel envs; each spawns its own harness (default: 4).")
p.add_argument("--max-turns", type=int, default=200,
help="Per-episode turn limit before truncation (default: 200).")
p.add_argument("--map-size", default="duel",
help="MapGenerator size key (default: duel).")
p.add_argument("--players", type=int, default=2,
help="Total player slots in each game (default: 2).")
p.add_argument("--eval-freq", type=int, default=20_000,
help="Run eval every N steps (default: 20k).")
p.add_argument("--eval-episodes", type=int, default=20,
help="Episodes per eval (default: 20).")
p.add_argument("--target-win-rate", type=float, default=0.55,
help="Stop training once eval win-rate exceeds this (default: 0.55).")
p.add_argument("--run-name", default="duel-v1",
help="Subdirectory under runs/ + models/ (default: duel-v1).")
p.add_argument("--seed", type=int, default=42,
help="Base RNG seed; per-env seeds offset from this (default: 42).")
p.add_argument("--device", default="auto",
help=("Torch device for the policy net: 'auto' (default — "
"picks cuda if available, else cpu), 'cuda', "
"'cuda:1' (second GPU), 'mps' (Apple Silicon), or "
"'cpu'. On apricot, prefer 'cuda:1' so cuda:0 stays "
"free for model-boss / MCTS rollouts."))
return p
def _make_env_factory(args: argparse.Namespace, env_idx: int):
    """Return a zero-argument callable that builds one MagicCivEnv.

    sb3's vectorized envs take constructors rather than live env objects.
    Under SubprocVecEnv, each worker process therefore builds its own env
    (and with it its own harness) on its side of the process boundary.
    ``env_idx`` is added to the base ``--seed`` so each env gets a distinct
    seed.
    """
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    def _build_one() -> MagicCivEnv:
        # `args` is read here, when the env is built, not when the factory
        # is created.
        harness_cfg = HarnessConfig(
            map_type="continents",
            map_size=args.map_size,
            players=args.players,
            seed=args.seed + env_idx,
            # Slot the policy plays; presumably the other slots are the
            # harness's built-in AI — confirm against HarnessConfig.
            player_slot=0,
        )
        return MagicCivEnv(max_turns=args.max_turns, harness_config=harness_cfg)

    return _build_one
def _resolve_device(requested: str) -> str:
    """Map ``--device`` to a concrete torch device string and print it.

    ``'auto'`` becomes ``'cuda'`` when CUDA is available, else ``'mps'``
    when the Apple Metal backend is available, else ``'cpu'``. Any other
    value is returned unchanged for torch/sb3 to validate.
    """
    import torch  # type: ignore[import-not-found]

    cuda_ok = torch.cuda.is_available()
    if requested != "auto":
        resolved = requested
    elif cuda_ok:
        resolved = "cuda"
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        resolved = "mps"
    else:
        resolved = "cpu"
    # Printed explicitly so the device choice on a multi-GPU box (apricot has
    # 2× RTX 3090) can be confirmed at a glance.
    print(
        f"policy device: {resolved} "
        f"(cuda_available={cuda_ok}, "
        f"cuda_devices={torch.cuda.device_count() if cuda_ok else 0})"
    )
    return resolved


def _make_win_rate_stop_callback(target_win_rate: float, save_path: Path):
    """Build a ``callback_after_eval`` hook that stops training on a winning eval.

    After each evaluation round, the hook reads the parent eval callback's
    ``evaluations_successes`` history. stable-baselines3 fills that history
    only from ``info["is_success"]`` at episode end, and only when the eval
    callback has a ``log_path``. If the round's mean success rate exceeds
    ``target_win_rate``, the current policy is saved to ``save_path`` and
    training stops. If the env never reports ``is_success``, the hook does
    nothing and training runs for the full step budget.
    """
    from stable_baselines3.common.callbacks import BaseCallback  # type: ignore[import-not-found]

    class _StopOnWinRate(BaseCallback):
        def __init__(self) -> None:
            super().__init__(verbose=0)
            # Length of the parent's success history at the previous check.
            # A round that recorded no success flags leaves the history
            # unchanged, and must not re-read a stale entry.
            self._rounds_seen = 0

        def _on_step(self) -> bool:
            history = getattr(self.parent, "evaluations_successes", None) or []
            if len(history) == self._rounds_seen:
                return True  # no win/loss info this round
            self._rounds_seen = len(history)
            win_rate = float(np.mean(history[-1]))
            if win_rate <= target_win_rate:
                return True
            self.model.save(str(save_path))
            print(
                f"eval win-rate {win_rate:.3f} exceeds target "
                f"{target_win_rate:.3f}; saved {save_path}, stopping early"
            )
            return False

    return _StopOnWinRate()


def main() -> int:
    """Train MaskablePPO against the built-in AI and return the exit code.

    Builds ``--num-envs`` training envs plus one eval env. Trains for
    ``--total-steps`` environment steps with periodic masked evaluation,
    stopping early once the eval win-rate exceeds ``--target-win-rate``
    (saved as ``models/<run-name>/winner.zip``). Finally saves the policy to
    ``models/<run-name>/final.zip``. TensorBoard logs and eval results go
    under ``runs/<run-name>/``.
    """
    import contextlib

    parser = _build_argparser()
    args = parser.parse_args()
    # Zero envs would give an empty vec-env and a division by zero in
    # eval_freq below; reject it as a normal CLI error before spawning anything.
    if args.num_envs < 1:
        parser.error("--num-envs must be at least 1")

    # Lazy imports — sb3 + torch are heavy and only needed once we
    # commit to running. Lets `--help` stay fast.
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from sb3_contrib.common.maskable.callbacks import (  # type: ignore[import-not-found]
        MaskableEvalCallback,
    )
    from stable_baselines3.common.vec_env import (  # type: ignore[import-not-found]
        DummyVecEnv,
        SubprocVecEnv,
    )

    run_dir = THIS_DIR / "runs" / args.run_name
    model_dir = THIS_DIR / "models" / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Resolve (and log) the device before any harness subprocess exists, so a
    # torch import/device problem fails without spawning Godot.
    resolved_device = _resolve_device(args.device)

    factories = [_make_env_factory(args, i) for i in range(args.num_envs)]
    # SubprocVecEnv runs each env in its own process — necessary because
    # each env owns a Godot subprocess (we don't want one harness's
    # JSON-Lines pipe to block sibling envs). DummyVecEnv is the
    # single-process fallback for debugging.
    env_cls = SubprocVecEnv if args.num_envs > 1 else DummyVecEnv

    # ExitStack closes every env that was successfully created, however this
    # block is exited. That includes errors while building the eval env, the
    # callback or the model, and a failing close() of the other env. Without
    # it those paths would leak the harness subprocesses.
    with contextlib.ExitStack() as cleanup:
        train_env = env_cls(factories)
        cleanup.callback(train_env.close)
        # env_idx 1000 gives the eval env seed args.seed + 1000, distinct from
        # every training env's seed for any --num-envs <= 1000.
        eval_env = DummyVecEnv([_make_env_factory(args, 1000)])
        cleanup.callback(eval_env.close)

        eval_callback = MaskableEvalCallback(
            eval_env,
            # Runs after every eval round; implements --target-win-rate
            # (see _make_win_rate_stop_callback for its is_success caveat).
            callback_after_eval=_make_win_rate_stop_callback(
                args.target_win_rate, model_dir / "winner.zip"
            ),
            best_model_save_path=str(model_dir),
            # Keep log_path set: sb3 only records the per-round success flags
            # read by the win-rate stop when it is.
            log_path=str(run_dir / "eval"),
            # The callback fires once per vec-env step, i.e. every num_envs
            # environment steps, so divide to keep --eval-freq in env steps.
            eval_freq=max(args.eval_freq // args.num_envs, 1),
            n_eval_episodes=args.eval_episodes,
            # Stochastic eval: a barely-trained net's argmax over the
            # 322-dim action head has ~zero chance of being end_turn (idx 0),
            # so deterministic eval episodes never advance past turn 0 and
            # every one hits step_cap with reward 0. Sampling from the masked
            # softmax keeps end_turn reachable until the policy has
            # consolidated enough mass on a real strategy.
            deterministic=False,
            render=False,
        )

        model = MaskablePPO(
            "MlpPolicy",
            train_env,
            verbose=1,
            tensorboard_log=str(run_dir),
            seed=args.seed,
            device=resolved_device,
            n_steps=512,
            batch_size=128,
            learning_rate=3e-4,
            gamma=0.995,
            gae_lambda=0.95,
            ent_coef=0.01,
        )
        model.learn(
            total_timesteps=args.total_steps,
            callback=eval_callback,
            progress_bar=True,
        )

    model.save(str(model_dir / "final.zip"))
    print(f"training complete; model saved to {model_dir / 'final.zip'}")
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())