magicciv/tooling/rl_self_play/bc_pretrain.py

237 lines
9.2 KiB
Python

"""Behavioural-cloning pre-trainer — Stage 6.1.6 step B.
Consumes the `(obs, action, mask)` triples recorded by
`record_expert.py` and supervised-trains a MaskablePPO policy to imitate
the scripted controller. Cross-entropy on the masked action
distribution: minimise `-log pi(expert_action | obs, legal_mask)`.
The output is a full MaskablePPO checkpoint (`.zip`), architecturally
identical to what `train.py` builds — so `train.py --init-from
<checkpoint>` loads it and continues with PPO. The policy *starts* from
~scripted-strength play; PPO then refines rather than discovers.
Only the policy (action) head is supervised here. The value head is
left at init — PPO's first rollouts correct it within a few thousand
steps, which is cheap relative to the exploration problem BC solves.
Usage:
    python -m tooling.rl_self_play.bc_pretrain \\
        --data tooling/rl_self_play/expert/duel \\
        --out tooling/rl_self_play/models/bc-duel/bc.zip \\
        --epochs 8 --device cuda:1
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
import numpy as np
THIS_DIR = Path(__file__).resolve().parent
# tooling/rl_self_play/ -> two levels up is the directory that contains the
# `tooling` package, i.e. the root the absolute imports below resolve from.
PROJECT_ROOT = THIS_DIR.parents[1]
# Running the file directly (`python bc_pretrain.py`) leaves __package__ as
# None, and importing it as a top-level module leaves it as "". In both cases
# the absolute `tooling.rl_self_play` imports need PROJECT_ROOT on sys.path.
# `python -m tooling.rl_self_play.bc_pretrain` sets __package__ and skips this.
# The membership test avoids stacking duplicate entries on re-import.
if not __package__ and str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.encoders import ACTION_DIM, OBS_DIM # type: ignore[import-not-found]
from tooling.rl_self_play.harness_client import HarnessConfig # type: ignore[import-not-found]
def _load_dataset(
    data_dir: Path, winners_only: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load and concatenate every `expert_*.npz` shard under `data_dir`.

    The search is recursive, so a parallel recording run (one worker subdir
    per process) merges into a single dataset. Each shard is validated on its
    own before concatenation so a bad shard is reported by path instead of
    surfacing as an opaque `np.concatenate` error or silently misaligned rows.

    Args:
        data_dir: Root directory searched (recursively) for shards.
        winners_only: Keep only rows whose `outcomes` entry equals 1. The
            `outcomes` array is read only when this is set.

    Returns:
        `(obs, actions, masks)` as float32 `(N, OBS_DIM)`, int64 `(N,)` and
        bool `(N, ACTION_DIM)` arrays. N may be 0 if every row was filtered.

    Raises:
        FileNotFoundError: no shard matches under `data_dir`.
        ValueError: a shard has the wrong ranks, inconsistent row counts, or
            widths that disagree with OBS_DIM / ACTION_DIM.
    """
    shards = sorted(data_dir.rglob("expert_*.npz"))
    if not shards:
        raise FileNotFoundError(f"no expert_*.npz shards under {data_dir}")
    obs_parts: list[np.ndarray] = []
    act_parts: list[np.ndarray] = []
    mask_parts: list[np.ndarray] = []
    for shard in shards:
        with np.load(shard) as z:
            obs, actions, masks = z["obs"], z["actions"], z["masks"]
            # Only winner filtering needs outcomes; don't demand the key otherwise.
            outcomes = z["outcomes"] if winners_only else None
        if obs.ndim != 2 or actions.ndim != 1 or masks.ndim != 2:
            raise ValueError(
                f"shard {shard} has unexpected ranks: obs {obs.shape}, "
                f"actions {actions.shape}, masks {masks.shape} "
                f"(expected 2-D, 1-D, 2-D)"
            )
        rows = obs.shape[0]
        if (
            actions.shape[0] != rows
            or masks.shape[0] != rows
            or (outcomes is not None and len(outcomes) != rows)
        ):
            raise ValueError(
                f"shard {shard} has inconsistent row counts: obs {obs.shape}, "
                f"actions {actions.shape}, masks {masks.shape}"
                + ("" if outcomes is None else f", outcomes {outcomes.shape}")
            )
        if obs.shape[1] != OBS_DIM or masks.shape[1] != ACTION_DIM:
            raise ValueError(
                f"shard shape mismatch in {shard}: obs {obs.shape}, masks {masks.shape} "
                f"vs OBS_DIM={OBS_DIM}, ACTION_DIM={ACTION_DIM}"
            )
        if outcomes is not None:
            keep = outcomes == 1
            obs, actions, masks = obs[keep], actions[keep], masks[keep]
        obs_parts.append(obs)
        act_parts.append(actions)
        mask_parts.append(masks)
    # copy=False: skip the second copy when the shards already hold the target dtype.
    obs = np.concatenate(obs_parts).astype(np.float32, copy=False)
    actions = np.concatenate(act_parts).astype(np.int64, copy=False)
    masks = np.concatenate(mask_parts).astype(bool, copy=False)
    return obs, actions, masks
def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the BC pre-trainer.

    Options are registered in the order they appear in `--help`. Only
    `--data` and `--out` are mandatory; every other option has a default.
    """
    option_specs = (
        ("--data", dict(
            required=True,
            help="Directory holding expert_*.npz shards from record_expert.py.",
        )),
        ("--out", dict(
            required=True,
            help="Output path for the BC MaskablePPO checkpoint (.zip).",
        )),
        ("--epochs", dict(
            type=int, default=8,
            help="Supervised passes over the dataset (default: 8).",
        )),
        ("--batch-size", dict(
            type=int, default=256,
            help="Minibatch size for the CE loss (default: 256).",
        )),
        ("--lr", dict(
            type=float, default=1e-3,
            help="Adam learning rate for BC (default: 1e-3).",
        )),
        ("--val-frac", dict(
            type=float, default=0.05,
            help="Fraction of triples held out for validation (default: 0.05).",
        )),
        ("--winners-only", dict(
            action="store_true",
            help="Train only on trajectories whose acting slot won.",
        )),
        ("--seed", dict(
            type=int, default=42,
            help="RNG seed for the split + shuffles (default: 42).",
        )),
        ("--device", dict(
            default="auto",
            help="Torch device: auto / cpu / cuda / cuda:1 / mps.",
        )),
    )
    parser = argparse.ArgumentParser(
        description="Behavioural-cloning pre-train for Magic Civilization"
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
def _resolve_device(requested: str) -> str:
    """Turn the `--device` value into a concrete torch device string.

    Anything other than "auto" is returned verbatim. For "auto" the first
    available backend wins: CUDA, then Apple MPS, then CPU.
    """
    import torch  # type: ignore[import-not-found]

    if requested == "auto":
        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds may not expose torch.backends.mps at all.
        mps_backend = getattr(torch.backends, "mps", None)
        has_mps = mps_backend is not None and mps_backend.is_available()
        return "mps" if has_mps else "cpu"
    return requested
def _validate_expert_actions(actions: np.ndarray, masks: np.ndarray) -> None:
    """Raise ValueError unless every expert action is in range and legal.

    An out-of-range index (e.g. a -1 sentinel) would otherwise wrap silently
    under NumPy fancy indexing. An illegal action means the recorder and
    encoder disagree, and CE would chase impossible (masked) logits.
    """
    n, action_dim = masks.shape
    out_of_range = (actions < 0) | (actions >= action_dim)
    if out_of_range.any():
        bad = int(out_of_range.sum())
        raise ValueError(
            f"{bad}/{n} recorded actions fall outside [0, {action_dim}) "
            f"— record_expert.py / encoders.py are out of sync"
        )
    legal = masks[np.arange(n), actions]
    if not legal.all():
        bad = int((~legal).sum())
        raise ValueError(
            f"{bad}/{n} recorded actions are illegal under their own mask "
            f"— record_expert.py / encoders.py are out of sync"
        )


def _split_train_val(
    n: int, val_frac: float, rng: np.random.Generator
) -> tuple[np.ndarray, np.ndarray]:
    """Shuffle `range(n)` and split it into `(train_idx, val_idx)`.

    The validation slice holds `int(n * val_frac)` rows clamped to
    `[1, n - 1]`, so neither split is empty for `n >= 2` (an empty split
    would divide by zero when averaging its epoch loss).
    """
    perm = rng.permutation(n)
    n_val = min(max(1, int(n * val_frac)), n - 1)
    return perm[n_val:], perm[:n_val]


def main() -> int:
    """Run behavioural cloning end to end and save a MaskablePPO checkpoint.

    Parses the CLI, loads and validates the expert triples, fits the policy
    with masked cross-entropy, and saves the checkpoint with the lowest
    validation CE to `--out`. Prints one `BC_SUMMARY <json>` line on success.

    Returns:
        0 on success; 1 when there are fewer than 2 triples, or when no
        finite validation loss was reached (so no checkpoint was written).

    Raises:
        FileNotFoundError / ValueError: propagated from dataset loading and
        expert-action validation. Invalid CLI values exit via `parser.error`.
    """
    parser = _build_argparser()
    args = parser.parse_args()
    # Fail fast on values that would crash or mislead later: batch size 0 is a
    # zero range() step, and zero epochs never writes a checkpoint.
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if not 0.0 <= args.val_frac < 1.0:
        parser.error("--val-frac must be in [0, 1)")

    rng = np.random.default_rng(args.seed)
    obs, actions, masks = _load_dataset(Path(args.data), args.winners_only)
    n = obs.shape[0]
    if n < 2:
        # Both the train and val splits must be non-empty.
        reason = "no triples to train on" if n == 0 else "need >= 2 triples for a train/val split"
        print(f"{reason} — aborting", file=sys.stderr)
        return 1
    _validate_expert_actions(actions, masks)
    train_idx, val_idx = _split_train_val(n, args.val_frac, rng)
    print(f"dataset: {n} triples ({len(train_idx)} train / {len(val_idx)} val)", flush=True)

    import torch  # type: ignore[import-not-found]
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from stable_baselines3.common.vec_env import DummyVecEnv  # type: ignore[import-not-found]
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    device = _resolve_device(args.device)
    print(f"BC device: {device}", flush=True)

    # MagicCivEnv only spawns a Godot harness on reset(); constructing it
    # (and the DummyVecEnv around it) just exposes observation/action
    # spaces. We never reset or step it — BC is offline. This gives
    # MaskablePPO a real, spec-correct env without any subprocess.
    spec_cfg = HarnessConfig(players=2, player_slot=0)
    space_env = DummyVecEnv([lambda: MagicCivEnv(harness_config=spec_cfg)])
    out_path = Path(args.out)
    best_val = float("inf")
    saved = False
    try:
        # Hyperparameters mirror train.py exactly so the saved .zip is a
        # drop-in checkpoint for `train.py --init-from`.
        model = MaskablePPO(
            "MlpPolicy",
            space_env,
            verbose=0,
            seed=args.seed,
            device=device,
            n_steps=512,
            batch_size=128,
            learning_rate=3e-4,
            gamma=0.995,
            gae_lambda=0.95,
            ent_coef=0.01,
        )
        policy = model.policy
        policy.set_training_mode(True)
        optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
        act_t = torch.as_tensor(actions, dtype=torch.int64, device=device)
        batch_size = args.batch_size

        def _epoch_pass(idx: np.ndarray, train: bool) -> tuple[float, float]:
            """One pass over `idx`. Returns (mean CE loss, top-1 accuracy)."""
            order = rng.permutation(idx) if train else idx
            total_loss = 0.0
            total_correct = 0
            for start in range(0, len(order), batch_size):
                batch = order[start:start + batch_size]
                b = torch.as_tensor(batch, dtype=torch.int64, device=device)
                o, a = obs_t[b], act_t[b]
                # The masked distribution takes a NumPy mask; index the host
                # array directly rather than round-tripping a device tensor.
                dist = policy.get_distribution(o, action_masks=masks[batch])
                loss = -dist.log_prob(a).mean()
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    # 0.5 matches the gradient clipping used by PPO itself.
                    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
                    optimizer.step()
                with torch.no_grad():
                    pred = dist.distribution.probs.argmax(dim=-1)
                    total_correct += int((pred == a).sum().item())
                total_loss += float(loss.item()) * len(batch)
            # len(order) > 0: _split_train_val guarantees non-empty splits.
            return total_loss / len(order), total_correct / len(order)

        t0 = time.time()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        for epoch in range(args.epochs):
            train_loss, train_acc = _epoch_pass(train_idx, train=True)
            policy.set_training_mode(False)
            with torch.no_grad():
                val_loss, val_acc = _epoch_pass(val_idx, train=False)
            policy.set_training_mode(True)
            print(
                f"epoch {epoch + 1}/{args.epochs} "
                f"train_ce={train_loss:.4f} train_acc={train_acc:.3f} "
                f"val_ce={val_loss:.4f} val_acc={val_acc:.3f}",
                flush=True,
            )
            # NaN compares False here, so a diverged epoch is never saved.
            if val_loss < best_val:
                best_val = val_loss
                model.save(str(out_path))
                saved = True
        elapsed = time.time() - t0
    finally:
        space_env.close()

    if not saved:
        print("validation CE was never finite — no checkpoint written", file=sys.stderr)
        return 1
    summary = {
        "triples": n,
        "epochs": args.epochs,
        "best_val_ce": round(best_val, 4),
        "checkpoint": str(out_path),
        "elapsed_sec": round(elapsed, 1),
    }
    print("BC_SUMMARY " + json.dumps(summary), flush=True)
    return 0
if __name__ == "__main__":
sys.exit(main())