237 lines
9.2 KiB
Python
237 lines
9.2 KiB
Python
"""Behavioural-cloning pre-trainer — Stage 6.1.6 step B.
|
|
|
|
Consumes the `(obs, action, mask)` triples recorded by
|
|
`record_expert.py` and supervised-trains a MaskablePPO policy to imitate
|
|
the scripted controller. Cross-entropy on the masked action
|
|
distribution: minimise `-log pi(expert_action | obs, legal_mask)`.
|
|
|
|
The output is a full MaskablePPO checkpoint (`.zip`), architecturally
|
|
identical to what `train.py` builds — so `train.py --init-from
|
|
<checkpoint>` loads it and continues with PPO. The policy *starts* from
|
|
~scripted-strength play; PPO then refines rather than discovers.
|
|
|
|
Only the policy (action) head is supervised here. The value head is
|
|
left at init — PPO's first rollouts correct it within a few thousand
|
|
steps, which is cheap relative to the exploration problem BC solves.
|
|
|
|
Usage:
|
|
python -m tooling.rl_self_play.bc_pretrain \\
|
|
--data tooling/rl_self_play/expert/duel \\
|
|
--out tooling/rl_self_play/models/bc-duel/bc.zip \\
|
|
--epochs 8 --device cuda:1
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]
# The absolute `tooling.rl_self_play` imports below need PROJECT_ROOT on
# sys.path. A plain script run leaves `__package__` as None, and an import
# as a top-level module sets it to "" — neither guarantees the project
# root is importable, so prepend it in both cases. Under `python -m ...`
# `__package__` is the dotted package name and nothing is changed.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from tooling.rl_self_play.encoders import ACTION_DIM, OBS_DIM # type: ignore[import-not-found]
|
|
from tooling.rl_self_play.harness_client import HarnessConfig # type: ignore[import-not-found]
|
|
|
|
|
|
def _load_dataset(
    data_dir: Path, winners_only: bool
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load and merge every ``expert_*.npz`` shard found under *data_dir*.

    Shards are discovered recursively and sorted, so a parallel recording
    run (one worker subdir per process) merges into a single dataset in a
    deterministic order.

    Args:
        data_dir: Root directory searched for shards.
        winners_only: Keep only triples whose ``outcomes`` entry equals 1.
            The ``outcomes`` array is only read when this is set.

    Returns:
        ``(obs, actions, masks)`` as ``float32 (N, OBS_DIM)``,
        ``int64 (N,)`` and ``bool (N, ACTION_DIM)`` arrays. ``N`` may be 0
        when every shard is empty or filtered out.

    Raises:
        FileNotFoundError: No shard matched the pattern.
        ValueError: A shard's arrays have the wrong shape or disagree on
            their row count; the message names the offending shard.
    """
    shards = sorted(data_dir.rglob("expert_*.npz"))
    if not shards:
        raise FileNotFoundError(f"no expert_*.npz shards under {data_dir}")

    obs_parts: list[np.ndarray] = []
    act_parts: list[np.ndarray] = []
    mask_parts: list[np.ndarray] = []
    for shard in shards:
        with np.load(shard) as z:
            obs, actions, masks = z["obs"], z["actions"], z["masks"]
            outcomes = z["outcomes"] if winners_only else None

        # An empty shard carries no data and may not even be 2-D, which
        # np.concatenate would reject; skip it rather than fail the run.
        if obs.ndim >= 1 and obs.shape[0] == 0:
            continue

        # Validate per shard so a bad file is named in the error, and so
        # arrays of different lengths can never silently misalign pairs.
        if obs.ndim != 2 or obs.shape[1] != OBS_DIM:
            raise ValueError(
                f"{shard}: obs shape {obs.shape} does not match OBS_DIM={OBS_DIM}"
            )
        rows = obs.shape[0]
        if masks.shape != (rows, ACTION_DIM):
            raise ValueError(
                f"{shard}: masks shape {masks.shape}, expected ({rows}, {ACTION_DIM}) "
                f"for ACTION_DIM={ACTION_DIM}"
            )
        if actions.shape != (rows,):
            raise ValueError(
                f"{shard}: actions shape {actions.shape}, expected ({rows},)"
            )

        if outcomes is not None:
            if outcomes.shape != (rows,):
                raise ValueError(
                    f"{shard}: outcomes shape {outcomes.shape}, expected ({rows},)"
                )
            keep = outcomes == 1
            obs, actions, masks = obs[keep], actions[keep], masks[keep]

        obs_parts.append(obs)
        act_parts.append(actions)
        mask_parts.append(masks)

    if not obs_parts:
        # Every shard was empty: return correctly-shaped empty arrays so the
        # caller's "no triples" check fires instead of np.concatenate([]).
        return (
            np.empty((0, OBS_DIM), dtype=np.float32),
            np.empty((0,), dtype=np.int64),
            np.empty((0, ACTION_DIM), dtype=bool),
        )

    obs = np.concatenate(obs_parts).astype(np.float32)
    actions = np.concatenate(act_parts).astype(np.int64)
    masks = np.concatenate(mask_parts).astype(bool)
    return obs, actions, masks
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
|
|
p = argparse.ArgumentParser(description="Behavioural-cloning pre-train for Magic Civilization")
|
|
p.add_argument("--data", required=True,
|
|
help="Directory holding expert_*.npz shards from record_expert.py.")
|
|
p.add_argument("--out", required=True,
|
|
help="Output path for the BC MaskablePPO checkpoint (.zip).")
|
|
p.add_argument("--epochs", type=int, default=8,
|
|
help="Supervised passes over the dataset (default: 8).")
|
|
p.add_argument("--batch-size", type=int, default=256,
|
|
help="Minibatch size for the CE loss (default: 256).")
|
|
p.add_argument("--lr", type=float, default=1e-3,
|
|
help="Adam learning rate for BC (default: 1e-3).")
|
|
p.add_argument("--val-frac", type=float, default=0.05,
|
|
help="Fraction of triples held out for validation (default: 0.05).")
|
|
p.add_argument("--winners-only", action="store_true",
|
|
help="Train only on trajectories whose acting slot won.")
|
|
p.add_argument("--seed", type=int, default=42,
|
|
help="RNG seed for the split + shuffles (default: 42).")
|
|
p.add_argument("--device", default="auto",
|
|
help="Torch device: auto / cpu / cuda / cuda:1 / mps.")
|
|
return p
|
|
|
|
|
|
def _resolve_device(requested: str) -> str:
|
|
import torch # type: ignore[import-not-found]
|
|
|
|
if requested != "auto":
|
|
return requested
|
|
if torch.cuda.is_available():
|
|
return "cuda"
|
|
mps = getattr(torch.backends, "mps", None)
|
|
if mps is not None and mps.is_available():
|
|
return "mps"
|
|
return "cpu"
|
|
|
|
|
|
def main() -> int:
    """Run behavioural-cloning pre-training from the command line.

    Loads the expert shards and sanity-checks the recorded actions against
    their masks. It then holds out a validation split and supervised-trains
    a fresh MaskablePPO policy with masked cross-entropy. The checkpoint
    with the lowest validation loss is saved to ``--out``, and a one-line
    ``BC_SUMMARY {json}`` is printed at the end.

    Returns:
        Process exit code. 0 means success. 1 means there was nothing to
        train on (no triples, or none left after the validation split), or
        no epoch beat the initial ``inf`` validation loss, so no checkpoint
        was written. Invalid CLI values exit with status 2 through
        ``argparse``.

    Raises:
        ValueError: Recorded actions fall outside ``[0, ACTION_DIM)`` or are
            illegal under their own mask. Also propagated from
            ``_load_dataset``, along with FileNotFoundError.
    """
    parser = _build_argparser()
    args = parser.parse_args()
    # Reject values that would otherwise fail late: an empty training split
    # (division by zero in the epoch loop), a zero-step range() when
    # batching, or a run that never writes a checkpoint.
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if not 0.0 <= args.val_frac < 1.0:
        parser.error("--val-frac must be in [0, 1)")
    rng = np.random.default_rng(args.seed)

    data_dir = Path(args.data)
    obs, actions, masks = _load_dataset(data_dir, args.winners_only)
    n = obs.shape[0]
    if n == 0:
        print("no triples to train on — aborting", file=sys.stderr)
        return 1

    # Range-check before the fancy indexing below. numpy silently wraps a
    # negative action onto the last mask column, which would let a corrupt
    # action pass the legality check.
    lo, hi = int(actions.min()), int(actions.max())
    if lo < 0 or hi >= ACTION_DIM:
        raise ValueError(
            f"recorded actions span [{lo}, {hi}] but ACTION_DIM={ACTION_DIM} "
            f"— record_expert.py / encoders.py are out of sync"
        )
    # Sanity: every expert action must be legal under its own mask, else
    # the recorder/encoder disagree and CE would chase impossible logits.
    legal = masks[np.arange(n), actions]
    if not legal.all():
        bad = int((~legal).sum())
        raise ValueError(
            f"{bad}/{n} recorded actions are illegal under their own mask "
            f"— record_expert.py / encoders.py are out of sync"
        )

    perm = rng.permutation(n)
    # At least one triple is always held out, because best-checkpoint
    # selection below is driven by the validation loss.
    n_val = max(1, int(n * args.val_frac))
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    if len(train_idx) == 0:
        print(
            f"only {n} triple(s): nothing left to train on after holding out "
            f"{n_val} for validation — aborting",
            file=sys.stderr,
        )
        return 1
    print(f"dataset: {n} triples ({len(train_idx)} train / {n_val} val)", flush=True)

    import torch  # type: ignore[import-not-found]
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from stable_baselines3.common.vec_env import DummyVecEnv  # type: ignore[import-not-found]

    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    device = _resolve_device(args.device)
    print(f"BC device: {device}", flush=True)

    # MagicCivEnv only spawns a Godot harness on reset(). Constructing it
    # (and the DummyVecEnv around it) just exposes the observation/action
    # spaces. We never reset or step it, because BC is offline. This gives
    # MaskablePPO a real, spec-correct env without any subprocess.
    spec_cfg = HarnessConfig(players=2, player_slot=0)
    space_env = DummyVecEnv([lambda: MagicCivEnv(harness_config=spec_cfg)])
    try:
        # Hyperparameters mirror train.py exactly so the saved .zip is a
        # drop-in checkpoint for `train.py --init-from`.
        model = MaskablePPO(
            "MlpPolicy",
            space_env,
            verbose=0,
            seed=args.seed,
            device=device,
            n_steps=512,
            batch_size=128,
            learning_rate=3e-4,
            gamma=0.995,
            gae_lambda=0.95,
            ent_coef=0.01,
        )
        policy = model.policy
        policy.set_training_mode(True)
        optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)

        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
        act_t = torch.as_tensor(actions, dtype=torch.int64, device=device)
        # Masks stay as the host-side numpy array. get_distribution() is fed
        # numpy masks, so slicing them directly avoids a device->host copy
        # on every batch.

        def _epoch_pass(idx: np.ndarray, train: bool) -> tuple[float, float]:
            """One pass over the non-empty `idx`. Returns (mean CE loss, top-1 accuracy)."""
            order = rng.permutation(idx) if train else idx
            total_loss = 0.0
            total_correct = 0
            for start in range(0, len(order), args.batch_size):
                batch = order[start:start + args.batch_size]
                b = torch.as_tensor(batch, dtype=torch.int64, device=device)
                o, a = obs_t[b], act_t[b]
                dist = policy.get_distribution(o, action_masks=masks[batch])
                loss = -dist.log_prob(a).mean()
                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
                    optimizer.step()
                with torch.no_grad():
                    pred = dist.distribution.probs.argmax(dim=-1)
                    total_correct += int((pred == a).sum().item())
                total_loss += float(loss.item()) * len(batch)
            # len(order) > 0: both splits are guaranteed non-empty above.
            return total_loss / len(order), total_correct / len(order)

        t0 = time.time()
        best_val = float("inf")
        best_epoch = 0  # 1-based epoch of the saved checkpoint; 0 = none yet
        out_path = Path(args.out)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        for epoch in range(1, args.epochs + 1):
            train_loss, train_acc = _epoch_pass(train_idx, train=True)
            policy.set_training_mode(False)
            with torch.no_grad():
                val_loss, val_acc = _epoch_pass(val_idx, train=False)
            policy.set_training_mode(True)
            print(
                f"epoch {epoch}/{args.epochs} "
                f"train_ce={train_loss:.4f} train_acc={train_acc:.3f} "
                f"val_ce={val_loss:.4f} val_acc={val_acc:.3f}",
                flush=True,
            )
            # A NaN/inf loss compares False here, so a diverged epoch never
            # overwrites an earlier good checkpoint.
            if val_loss < best_val:
                best_val = val_loss
                best_epoch = epoch
                model.save(str(out_path))
    finally:
        space_env.close()

    if best_epoch == 0:
        print(
            "no epoch produced a finite validation loss — no checkpoint written",
            file=sys.stderr,
        )
        return 1

    summary = {
        "triples": n,
        "epochs": args.epochs,
        "best_epoch": best_epoch,
        "best_val_ce": round(best_val, 4),
        "checkpoint": str(out_path),
        "elapsed_sec": round(time.time() - t0, 1),
    }
    print("BC_SUMMARY " + json.dumps(summary), flush=True)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: main()'s integer return becomes the process exit status.
    raise SystemExit(main())
|