magicciv/tooling/rl_self_play/tests/test_bc_pretrain.py

96 lines
3.4 KiB
Python

"""Unit tests for `tooling.rl_self_play.bc_pretrain._load_dataset`.
The training loop itself needs torch + sb3 and a Godot harness; those
are exercised by `bc_pretrain` on the apricot box. What we can prove
headless is the dataset loader contract — it concatenates shards
recursively, filters by `winners_only`, and rejects shape drift
between recorder and encoder (the load-bearing safety check from
commit e103928d2 that catches recorder/encoder skew before BC starts
chasing impossible logits).
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from tooling.rl_self_play.bc_pretrain import _load_dataset
from tooling.rl_self_play.encoders import ACTION_DIM, OBS_DIM
def _write_shard(
    path: Path,
    *,
    n: int,
    outcome: int,
    action: int = 0,
) -> None:
    """Write a synthetic recorder shard of ``n`` identical rows to ``path``.

    Every row gets an all-zero observation of width ``OBS_DIM``, the given
    ``action``, a boolean mask of width ``ACTION_DIM`` with only ``action``
    set, and the given ``outcome``. The ``slots`` and ``games`` columns are
    zero-filled. Missing parent directories are created, so callers can
    target nested per-worker subdirectories.
    """
    columns = {
        "obs": np.zeros((n, OBS_DIM), dtype=np.float32),
        "actions": np.full(n, action, dtype=np.int64),
        "masks": np.zeros((n, ACTION_DIM), dtype=bool),
    }
    # Only the recorded action is marked as allowed in each row's mask.
    columns["masks"][:, action] = True
    outcome_column = np.full(n, outcome, dtype=np.int8)

    path.parent.mkdir(parents=True, exist_ok=True)
    columns["slots"] = np.zeros(n, dtype=np.int8)
    columns["games"] = np.zeros(n, dtype=np.int32)
    columns["outcomes"] = outcome_column
    np.savez_compressed(path, **columns)
def test_load_dataset_concatenates_shards(tmp_path: Path) -> None:
    """Rows from every shard are concatenated into row-aligned arrays.

    Checks shapes and dtypes of the returned (obs, actions, masks). Each
    shard also records a distinct action, which shows that both shards'
    contents, not just their row counts, made it into the result.
    """
    last_action = ACTION_DIM - 1
    _write_shard(tmp_path / "expert_0000.npz", n=3, outcome=1, action=0)
    _write_shard(
        tmp_path / "expert_0001.npz", n=2, outcome=0, action=last_action
    )

    obs, actions, masks = _load_dataset(tmp_path, winners_only=False)

    assert obs.shape == (5, OBS_DIM)
    assert actions.shape == (5,)
    assert masks.shape == (5, ACTION_DIM)
    assert obs.dtype == np.float32
    assert actions.dtype == np.int64
    assert masks.dtype == bool
    # Sorting keeps this independent of the order in which shards are
    # discovered on disk.
    assert sorted(actions.tolist()) == [0, 0, 0, last_action, last_action]
    # Each row's mask must admit that row's action; a misaligned
    # concatenation of actions vs. masks would break this.
    assert masks[np.arange(len(actions)), actions].all()
def test_load_dataset_recursive_discovery(tmp_path: Path) -> None:
    """Shards inside per-worker subdirectories are all discovered.

    Parallel recording workers each write into their own subdir, so
    `_load_dataset` must rglob, not glob.
    """
    worker_specs = (("worker_a", 1, 1), ("worker_b", 2, 0))
    for worker_dir, rows, result in worker_specs:
        _write_shard(
            tmp_path / worker_dir / "expert_0000.npz", n=rows, outcome=result
        )

    loaded_obs = _load_dataset(tmp_path, winners_only=False)[0]
    assert loaded_obs.shape == (3, OBS_DIM)
def test_load_dataset_winners_only_filters(tmp_path: Path) -> None:
    """`winners_only=True` keeps only rows whose outcome is 1.

    Rows with outcome 0 and undecided rows (-1) are dropped. The filter
    must apply to every returned array, not just obs. Otherwise BC would
    pair winners' observations with other games' actions.
    """
    winning_action = ACTION_DIM - 1
    _write_shard(
        tmp_path / "expert_0000.npz", n=4, outcome=1, action=winning_action
    )
    _write_shard(tmp_path / "expert_0001.npz", n=6, outcome=0)
    _write_shard(tmp_path / "expert_0002.npz", n=2, outcome=-1)  # undecided

    obs, actions, masks = _load_dataset(tmp_path, winners_only=True)

    assert obs.shape == (4, OBS_DIM)
    assert actions.shape == (4,)
    assert masks.shape == (4, ACTION_DIM)
    # Only the winners' shard recorded `winning_action`, so surviving
    # actions and masks must all come from that shard.
    assert actions.tolist() == [winning_action] * 4
    assert masks[:, winning_action].all()
def test_load_dataset_missing_shards_raises(tmp_path: Path) -> None:
    """A directory with no shards raises instead of yielding an empty set."""
    empty_dataset_dir = tmp_path
    with pytest.raises(FileNotFoundError):
        _load_dataset(empty_dataset_dir, winners_only=False)
def test_load_dataset_shape_mismatch_raises(tmp_path: Path) -> None:
    """Reject a dataset whose array shapes disagree with the encoder.

    If the recorder and encoder drift apart (different OBS_DIM or
    ACTION_DIM), `_load_dataset` must reject the dataset before BC
    starts training on impossible logits. This case exercises the
    OBS_DIM direction: observations are one column wider than
    `OBS_DIM`, while masks still match `ACTION_DIM`.
    """
    rows = 2
    drifted_columns = {
        "obs": np.zeros((rows, OBS_DIM + 1), dtype=np.float32),
        "actions": np.zeros(rows, dtype=np.int64),
        "masks": np.zeros((rows, ACTION_DIM), dtype=bool),
        "slots": np.zeros(rows, dtype=np.int8),
        "games": np.zeros(rows, dtype=np.int32),
        "outcomes": np.ones(rows, dtype=np.int8),
    }
    np.savez_compressed(tmp_path / "expert_0000.npz", **drifted_columns)

    with pytest.raises(ValueError, match="shape mismatch"):
        _load_dataset(tmp_path, winners_only=False)