96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
"""Unit tests for `tooling.rl_self_play.bc_pretrain._load_dataset`.

The training loop itself needs torch + sb3 and a Godot harness; those
are exercised by `bc_pretrain` on the apricot box. What we can prove
headless is the dataset loader contract — it concatenates shards
recursively, filters by `winners_only`, and rejects shape drift
between recorder and encoder (the load-bearing safety check from
commit e103928d2 that catches recorder/encoder skew before BC starts
chasing impossible logits).
"""
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from tooling.rl_self_play.bc_pretrain import _load_dataset
|
|
from tooling.rl_self_play.encoders import ACTION_DIM, OBS_DIM
|
|
|
|
|
|
def _write_shard(
    path: Path,
    *,
    n: int,
    outcome: int,
    action: int = 0,
) -> None:
    """Write a synthetic expert shard of ``n`` identical rows to ``path``.

    Every row has an all-zero observation, records ``action`` as the chosen
    action, carries a legal-action mask with only ``action`` set, and is
    labelled with ``outcome``. The seat (``slots``) and game-index (``games``)
    columns are zero-filled. Parent directories are created as needed, so
    callers can target nested worker subdirectories.
    """
    legal = np.zeros((n, ACTION_DIM), dtype=bool)
    legal[:, action] = True
    # Insertion order matches the original keyword order passed to savez.
    arrays = {
        "obs": np.zeros((n, OBS_DIM), dtype=np.float32),
        "actions": np.full(n, action, dtype=np.int64),
        "masks": legal,
        "slots": np.zeros(n, dtype=np.int8),
        "games": np.zeros(n, dtype=np.int32),
        "outcomes": np.full(n, outcome, dtype=np.int8),
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(path, **arrays)
|
|
|
|
|
|
def test_load_dataset_concatenates_shards(tmp_path: Path) -> None:
    """Rows from every shard are concatenated with obs/actions/masks aligned.

    The two shards record different actions so their rows can be told apart.
    The test checks the multiset of actions rather than their order, because
    shard ordering is not what this test pins down. It also checks that every
    row's mask still permits that row's own action; that check fails if the
    arrays were stacked out of step with each other.
    """
    last_action = ACTION_DIM - 1
    _write_shard(tmp_path / "expert_0000.npz", n=3, outcome=1, action=0)
    _write_shard(
        tmp_path / "expert_0001.npz", n=2, outcome=0, action=last_action
    )
    obs, actions, masks = _load_dataset(tmp_path, winners_only=False)

    assert obs.shape == (5, OBS_DIM)
    assert actions.shape == (5,)
    assert masks.shape == (5, ACTION_DIM)
    assert obs.dtype == np.float32
    assert actions.dtype == np.int64
    assert masks.dtype == bool

    # Both shards contributed all of their rows (order-independent check).
    assert sorted(actions.tolist()) == [0, 0, 0, last_action, last_action]
    # Each row's recorded action is legal under its own mask, so actions and
    # masks were concatenated in the same row order.
    assert masks[np.arange(len(actions)), actions].all()
|
|
|
|
|
|
def test_load_dataset_recursive_discovery(tmp_path: Path) -> None:
    """Parallel recording workers each write into their own subdir.

    `_load_dataset` must rglob, not glob. A single level of nesting would
    also be satisfied by a one-level pattern such as ``*/*.npz``, so one
    shard sits two directories deep.
    """
    _write_shard(tmp_path / "worker_a" / "expert_0000.npz", n=1, outcome=1)
    _write_shard(tmp_path / "worker_b" / "expert_0000.npz", n=2, outcome=0)
    _write_shard(
        tmp_path / "worker_c" / "run_0" / "expert_0000.npz", n=4, outcome=1
    )
    obs, _, _ = _load_dataset(tmp_path, winners_only=False)
    # Power-of-two row counts: missing any one shard yields a distinct total.
    assert obs.shape == (7, OBS_DIM)
|
|
|
|
|
|
def test_load_dataset_winners_only_filters(tmp_path: Path) -> None:
    """``winners_only=True`` keeps only the outcome-1 rows.

    Rows labelled 0 and rows labelled -1 (undecided) are both dropped. The
    shard sizes are chosen so that keeping either of those groups as well
    would change the row count.
    """
    rows_by_outcome = {1: 4, 0: 6, -1: 2}  # outcome label -> shard rows
    for index, (label, rows) in enumerate(rows_by_outcome.items()):
        shard = tmp_path / f"expert_{index:04d}.npz"
        _write_shard(shard, n=rows, outcome=label)
    obs, _, _ = _load_dataset(tmp_path, winners_only=True)
    assert obs.shape == (rows_by_outcome[1], OBS_DIM)
|
|
|
|
|
|
def test_load_dataset_missing_shards_raises(tmp_path: Path) -> None:
    """A dataset directory with no shards is an error, not an empty dataset."""
    empty_root = tmp_path  # pytest hands out a fresh, empty directory
    with pytest.raises(FileNotFoundError):
        _load_dataset(empty_root, winners_only=False)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("obs_dim", "action_dim"),
    [
        pytest.param(OBS_DIM + 1, ACTION_DIM, id="obs_dim_drift"),
        pytest.param(OBS_DIM, ACTION_DIM + 1, id="action_dim_drift"),
    ],
)
def test_load_dataset_shape_mismatch_raises(
    tmp_path: Path, obs_dim: int, action_dim: int
) -> None:
    """If the recorder and encoder drift apart (different OBS_DIM or
    ACTION_DIM), `_load_dataset` must reject the dataset before BC
    starts training on impossible logits.

    Each parametrized case widens exactly one dimension. The mask-width
    case covers the ACTION_DIM half of the contract, which a single
    obs-only fixture would leave untested.
    """
    n = 2
    masks = np.zeros((n, action_dim), dtype=bool)
    masks[:, 0] = True  # keep the recorded action legal; only width is off
    np.savez_compressed(
        tmp_path / "expert_0000.npz",
        obs=np.zeros((n, obs_dim), dtype=np.float32),
        actions=np.zeros(n, dtype=np.int64),
        masks=masks,
        slots=np.zeros(n, dtype=np.int8),
        games=np.zeros(n, dtype=np.int32),
        outcomes=np.ones(n, dtype=np.int8),
    )
    with pytest.raises(ValueError, match="shape mismatch"):
        _load_dataset(tmp_path, winners_only=False)
|