feat(@projects/@magic-civilization): add ai backend probe and dispatch system

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Natalie 2026-05-04 16:24:19 -04:00
parent cd9b92879a
commit 039c31a079
7 changed files with 356 additions and 306 deletions

View file

@ -0,0 +1,172 @@
//! Boot-probed AI backend selector — single dispatch entry point for batched
//! rollouts. Replaces the per-call silent GPU→CPU fallback that
//! `gpu::inner::batch_simulate` previously implemented.
//!
//! # Contract
//!
//! - Probe runs **once** at boot via [`AiBackend::probe`] (env-overridable).
//! - The chosen backend is **fixed for the session** — there is NO per-call
//! fallback. If GPU dispatch fails mid-game,
//! [`AiBackend::batch_simulate`] returns `Err`; callers do NOT silently
//! degrade to CPU.
//! - Algorithm parity is enforced by `tests/gpu_rollout_parity.rs`. The CPU
//! path goes through [`crate::gpu::cpu_reference::batch_simulate_cpu`],
//! which is a thin wrapper around the canonical `rollout::walk` — the same
//! semantics the WGSL shader mirrors.
//!
//! # Env override
//!
//! - `MC_AI_BACKEND=cpu` — force `AiBackend::Cpu` regardless of probe.
//! - `MC_AI_BACKEND=gpu` — require a working GPU; **panic** if probe fails.
//! - Unset / any other value — probe normally (Gpu if available, else Cpu).
use crate::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
use crate::policy::PersonalityPriors;
#[cfg(feature = "gpu")]
use crate::gpu::inner::{GpuContext, GpuError};
/// Active AI backend for batched rollout dispatch. Probed once at boot.
///
/// The `Gpu` variant is feature-gated behind `gpu`; under
/// `cfg(not(feature = "gpu"))` the only variant is `Cpu` and probe always
/// returns `Cpu`. Mobile / minimal builds that omit wgpu therefore see no
/// dispatch-layer cfg branching — the gating is by variant existence, not
/// by `cfg` at every call site.
pub enum AiBackend {
/// GPU compute via wgpu. Holds the process-wide cached
/// [`GpuContext::shared`] — adapter probe + pipeline compile cost is
/// paid exactly once at boot.
#[cfg(feature = "gpu")]
Gpu(&'static GpuContext),
/// Canonical CPU rollout via
/// [`crate::gpu::cpu_reference::batch_simulate_cpu`]. Algorithm-equivalent
/// to the WGSL shader, byte-by-byte, on `AbstractRolloutState`.
Cpu,
}
/// Errors surfaced from [`AiBackend::batch_simulate`]. There is no fallback;
/// dispatch failures propagate up.
#[derive(Debug, thiserror::Error)]
pub enum BackendError {
/// The GPU pipeline failed to dispatch (queue submit, buffer map, device
/// lost, etc.). The session does NOT fall back to CPU — callers must
/// decide whether to retry, surface, or abort.
#[error("GPU dispatch failed: {0}")]
GpuDispatchFailed(String),
/// Catch-all for caller misuse (length mismatches, etc.) and any future
/// non-GPU error case.
#[error("backend error: {0}")]
Other(String),
}
impl AiBackend {
/// Probe at boot. Order: existing [`GpuContext::shared`] adapter probe
/// (Vulkan / Metal / DX12 / WebGPU) → `Cpu`. The `MC_AI_BACKEND` env var
/// overrides the probe — see crate-level docs.
///
/// Logs the chosen backend at info-level on stderr (same channel
/// `mc-turn` / `mc-compute` use). Mobile users running CPU see a line
/// like `[mc-ai backend] Cpu (no compute adapter)`; hosts with a working
/// adapter see `[mc-ai backend] Gpu (Vulkan)`.
#[must_use]
pub fn probe() -> Self {
let env = std::env::var("MC_AI_BACKEND").ok();
let env_norm = env.as_deref().map(str::to_ascii_lowercase);
match env_norm.as_deref() {
Some("cpu") => {
eprintln!("[mc-ai backend] Cpu (forced via MC_AI_BACKEND=cpu)");
AiBackend::Cpu
}
Some("gpu") => {
#[cfg(feature = "gpu")]
{
match GpuContext::shared() {
Some(ctx) => {
eprintln!(
"[mc-ai backend] Gpu ({}) (forced via MC_AI_BACKEND=gpu)",
ctx.backend
);
AiBackend::Gpu(ctx)
}
None => panic!(
"MC_AI_BACKEND=gpu requested but no compute adapter available"
),
}
}
#[cfg(not(feature = "gpu"))]
{
panic!(
"MC_AI_BACKEND=gpu requested but the `gpu` cargo feature is disabled"
);
}
}
_ => {
#[cfg(feature = "gpu")]
{
if let Some(ctx) = GpuContext::shared() {
eprintln!("[mc-ai backend] Gpu ({})", ctx.backend);
return AiBackend::Gpu(ctx);
}
}
eprintln!("[mc-ai backend] Cpu (no compute adapter)");
AiBackend::Cpu
}
}
}
/// Stable, lower-case name for diagnostics / log lines. `"gpu"` or
/// `"cpu"`.
#[must_use]
pub fn name(&self) -> &'static str {
match self {
#[cfg(feature = "gpu")]
AiBackend::Gpu(_) => "gpu",
AiBackend::Cpu => "cpu",
}
}
/// Dispatch a batched rollout through the active backend.
///
/// Returns one `f32` terminal score in `[0, 1]` per batch entry, in
/// input order. The algorithm is identical across backends — see
/// `tests/gpu_rollout_parity.rs` for the byte-equivalence contract.
///
/// Returns `Err(BackendError::Other)` on caller-side input length
/// mismatch and `Err(BackendError::GpuDispatchFailed)` on GPU runtime
/// failure. There is no silent fallback to CPU.
pub fn batch_simulate(
&self,
inputs: &[AbstractRolloutState],
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
seed: u64,
horizon: u32,
) -> Result<Vec<f32>, BackendError> {
if inputs.len() != priors_per_entry.len() {
return Err(BackendError::Other(format!(
"inputs len {} != priors len {}",
inputs.len(),
priors_per_entry.len()
)));
}
if inputs.is_empty() {
return Ok(Vec::new());
}
match self {
#[cfg(feature = "gpu")]
AiBackend::Gpu(ctx) => ctx
.batch_simulate(inputs, priors_per_entry, seed, horizon)
.map_err(|e: GpuError| BackendError::GpuDispatchFailed(e.to_string())),
AiBackend::Cpu => {
let out = crate::gpu::cpu_reference::batch_simulate_cpu(
inputs,
priors_per_entry,
seed,
horizon,
);
Ok(out.into_iter().map(|(s, _)| s).collect())
}
}
}
}

View file

@ -19,10 +19,23 @@ use wgpu::util::DeviceExt;
use crate::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
use crate::policy::PersonalityPriors;
use crate::rollout::{DEFAULT_ROLLOUT_HORIZON, DEFAULT_ROLLOUT_TEMPERATURE};
use crate::rollout::DEFAULT_ROLLOUT_TEMPERATURE;
use super::cpu_reference::batch_simulate_cpu;
use super::RolloutPath;
/// Runtime failure surfaced from [`GpuContext::batch_simulate`]. There is no
/// per-call CPU fallback inside this module — backend identity is fixed for
/// the session by [`crate::backend::AiBackend::probe`]. Callers translate
/// this into [`crate::backend::BackendError::GpuDispatchFailed`].
#[derive(Debug, thiserror::Error)]
pub enum GpuError {
/// Caller passed mismatched `inputs` / `priors_per_entry` lengths.
#[error("inputs len {inputs} != priors len {priors}")]
LengthMismatch { inputs: usize, priors: usize },
/// The wgpu pipeline dispatch failed (queue submit error, buffer map
/// failure, device lost). The cause string is whatever the wgpu /
/// `pollster` layer surfaced.
#[error("dispatch failed: {0}")]
DispatchFailed(String),
}
/// WGSL kernel source, compiled into the binary at build time.
const SHADER_SRC: &str = include_str!("rollout.wgsl");
@ -396,27 +409,33 @@ impl GpuContext {
/// Dispatch a full rollout batch through the GPU pipeline.
///
/// Returns `RolloutPath::Gpu`-tagged results on success. On any runtime
/// failure (dispatch error, map failure) falls back to the CPU reference
/// and returns `RolloutPath::Cpu`-tagged results — the caller gets a
/// valid answer either way.
#[must_use]
/// Returns scores on success. On dispatch failure returns
/// `Err(GpuError::DispatchFailed)` — there is **no** silent CPU
/// fallback. The boot-probed [`crate::backend::AiBackend`] decides
/// backend identity once; runtime failures propagate up so callers
/// know the dispatch did not produce GPU-quality answers.
pub fn batch_simulate(
&self,
inputs: &[AbstractRolloutState],
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
seed: u64,
horizon: u32,
) -> Vec<(f32, RolloutPath)> {
) -> Result<Vec<f32>, GpuError> {
if inputs.len() != priors_per_entry.len() {
return Vec::new();
return Err(GpuError::LengthMismatch {
inputs: inputs.len(),
priors: priors_per_entry.len(),
});
}
if inputs.is_empty() {
return Ok(Vec::new());
}
match self.dispatch_batch(inputs, priors_per_entry, seed, horizon) {
Some(scores) => scores
.into_iter()
.map(|s| (s, RolloutPath::Gpu))
.collect(),
None => batch_simulate_cpu(inputs, priors_per_entry, seed, horizon),
Some(scores) => Ok(scores),
None => Err(GpuError::DispatchFailed(
"wgpu pipeline dispatch returned None (queue submit / buffer map / device lost)"
.to_owned(),
)),
}
}
}
@ -464,48 +483,6 @@ fn create_storage_rw(dev: &wgpu::Device, size_bytes: usize, label: &str) -> wgpu
})
}
/// Top-level GPU-or-CPU dispatch entry point.
///
/// Uses the process-wide cached [`GpuContext::shared`] — the adapter probe
/// runs exactly once per process, not per call. On hosts with a working GPU
/// adapter this dispatches to the shader; on headless hosts or hosts where
/// the driver is wedged (see `TRY_INIT_TIMEOUT_MS`) it falls through to the
/// CPU reference silently. Result types are identical; only the
/// [`RolloutPath`] tag differs.
///
/// For hot loops that dispatch many batches, consider holding a
/// `&GpuContext` directly via `GpuContext::shared()` to skip the `OnceLock`
/// atomic load per call.
#[must_use]
pub fn batch_simulate(
inputs: &[AbstractRolloutState],
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
seed: u64,
horizon: u32,
) -> Vec<(f32, RolloutPath)> {
// Zero-length inputs never touch the GPU cache — fast path.
if inputs.is_empty() {
return Vec::new();
}
if inputs.len() != priors_per_entry.len() {
return Vec::new();
}
if let Some(ctx) = GpuContext::shared() {
return ctx.batch_simulate(inputs, priors_per_entry, seed, horizon);
}
batch_simulate_cpu(inputs, priors_per_entry, seed, horizon)
}
/// Convenience: `batch_simulate` with `DEFAULT_ROLLOUT_HORIZON`.
#[must_use]
pub fn batch_simulate_default_horizon(
inputs: &[AbstractRolloutState],
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
seed: u64,
) -> Vec<(f32, RolloutPath)> {
batch_simulate(inputs, priors_per_entry, seed, DEFAULT_ROLLOUT_HORIZON)
}
#[cfg(test)]
mod tests {
use super::*;
@ -558,43 +535,42 @@ mod tests {
use std::time::Instant;
// ── Tests that do NOT touch the GPU adapter ──────────────────────────
// ── Adapter-gated guard tests ────────────────────────────────────────
//
// These rely solely on the pre-dispatch guards in `batch_simulate` (empty
// input / mismatched lens) OR on `GpuContext::shared()` returning None
// after the one-time probe. Either way no single test is responsible for
// the probe cost — the first test that *does* need GPU state pays it
// once and caches.
// The boot-probed `AiBackend` no longer has a per-call dispatch shim
// here, so empty / mismatched-length tests hit `GpuContext::batch_simulate`
// directly. Both must short-circuit before any wgpu work; without an
// adapter we skip via `shared()` returning None.
#[test]
fn batch_simulate_empty_bypasses_gpu_probe() {
// Empty input returns Vec::new() before GpuContext::shared() is ever
// consulted. Must complete in microseconds even on a wedged-adapter
// host; assert a 100ms upper bound with generous slack for CI jitter.
fn ctx_batch_simulate_empty_returns_ok_empty() {
let Some(ctx) = GpuContext::shared() else {
eprintln!("[rollout-gpu] no adapter — skipping ctx_batch_simulate_empty_returns_ok_empty");
return;
};
let start = Instant::now();
let out = batch_simulate(&[], &[], 42, 20);
let out = ctx.batch_simulate(&[], &[], 42, 20).expect("empty must succeed");
let elapsed = start.elapsed();
assert!(out.is_empty());
assert!(
elapsed < Duration::from_millis(100),
"empty input must bypass GPU probe; took {:?}",
"empty input must short-circuit; took {:?}",
elapsed
);
}
#[test]
fn batch_simulate_mismatched_lengths_bypasses_gpu_probe() {
fn ctx_batch_simulate_mismatched_lengths_errors() {
let Some(ctx) = GpuContext::shared() else {
eprintln!("[rollout-gpu] no adapter — skipping ctx_batch_simulate_mismatched_lengths_errors");
return;
};
let pods = vec![make_entry()];
let priors: Vec<[PersonalityPriors; MAX_PLAYERS]> = vec![iron_vs_bh(), iron_vs_bh()];
let start = Instant::now();
let out = batch_simulate(&pods, &priors, 1, 20);
let elapsed = start.elapsed();
assert!(out.is_empty());
assert!(
elapsed < Duration::from_millis(100),
"length-mismatch must bypass GPU probe; took {:?}",
elapsed
);
let err = ctx
.batch_simulate(&pods, &priors, 1, 20)
.expect_err("mismatched lens must Err");
assert!(matches!(err, GpuError::LengthMismatch { .. }));
}
// ── Timeout contract ─────────────────────────────────────────────────
@ -712,86 +688,36 @@ mod tests {
);
}
// ── Scored-path tests ────────────────────────────────────────────────
// ── Adapter-gated dispatch tests ─────────────────────────────────────
//
// These exercise the full dispatch pipeline. The first one to run pays
// the probe cost (bounded by timeout). On a wedged-driver host all of
// these silently route through CPU and still produce valid results.
// Only run when a working adapter is actually present. On hosts without
// a compute adapter `shared()` returns None and these skip — no hang,
// no panic. The boot-probed `AiBackend` covers the no-adapter path; the
// tests in `tests/backend_probe.rs` and `tests/gpu_rollout_parity.rs`
// exercise the full dispatch surface.
#[test]
fn batch_simulate_produces_unit_interval_scores() {
let pods = vec![make_entry(); 4];
let priors = vec![iron_vs_bh(); 4];
let out = batch_simulate(&pods, &priors, 7, 20);
assert_eq!(out.len(), 4);
for (score, path) in &out {
assert!((0.0..=1.0).contains(score), "score {score} out of [0,1] on {path}");
}
}
#[test]
fn fallback_returns_valid_path_tag() {
let pods = vec![make_entry()];
let priors = vec![iron_vs_bh()];
let out = batch_simulate(&pods, &priors, 100, 20);
assert_eq!(out.len(), 1);
assert!(matches!(out[0].1, RolloutPath::Cpu | RolloutPath::Gpu));
}
#[test]
fn batch_simulate_is_deterministic_across_repeated_calls() {
let pods = vec![make_entry(); 3];
let priors = vec![iron_vs_bh(); 3];
let a = batch_simulate(&pods, &priors, 77, 20);
let b = batch_simulate(&pods, &priors, 77, 20);
assert_eq!(a.len(), b.len());
for (ra, rb) in a.iter().zip(b.iter()) {
assert_eq!(
ra.0.to_bits(),
rb.0.to_bits(),
"same seed must produce bit-identical results on {}",
ra.1
);
assert_eq!(ra.1, rb.1, "path tag must be stable across calls");
}
}
#[test]
fn default_horizon_helper_matches_explicit_call() {
let pods = vec![make_entry()];
let priors = vec![iron_vs_bh()];
let a = batch_simulate(&pods, &priors, 55, DEFAULT_ROLLOUT_HORIZON);
let b = batch_simulate_default_horizon(&pods, &priors, 55);
assert_eq!(a[0].0.to_bits(), b[0].0.to_bits());
assert_eq!(a[0].1, b[0].1);
}
// ── Adapter-gated tests ──────────────────────────────────────────────
//
// Only run when a working adapter is actually present. On wedged
// drivers `shared()` returns None and these skip — no hang, no panic.
#[test]
fn gpu_path_tags_when_adapter_available() {
fn gpu_returns_unit_interval_scores_when_adapter_available() {
let Some(ctx) = GpuContext::shared() else {
eprintln!("[rollout-gpu] no adapter — skipping gpu_path_tags_when_adapter_available");
eprintln!("[rollout-gpu] no adapter — skipping gpu_returns_unit_interval_scores_when_adapter_available");
return;
};
let pods = vec![make_entry()];
let priors = vec![iron_vs_bh()];
let out = ctx.batch_simulate(&pods, &priors, 123, 20);
assert_eq!(out.len(), 1);
assert_eq!(
out[0].1,
RolloutPath::Gpu,
"with adapter present, batch_simulate must tag Gpu (backend: {})",
ctx.backend
);
assert!((0.0..=1.0).contains(&out[0].0));
let pods = vec![make_entry(); 4];
let priors = vec![iron_vs_bh(); 4];
let out = ctx
.batch_simulate(&pods, &priors, 7, 20)
.expect("dispatch should succeed on a working adapter");
assert_eq!(out.len(), 4);
for score in &out {
assert!(
(0.0..=1.0).contains(score),
"score {score} out of [0,1] (backend: {})",
ctx.backend
);
}
}
/// Dispatch determinism when an adapter is present. Skips on headless or
/// wedged-driver hosts via the `shared()` None return.
/// Dispatch determinism when an adapter is present.
#[test]
fn gpu_dispatch_is_deterministic_when_adapter_available() {
let Some(ctx) = GpuContext::shared() else {
@ -800,12 +726,16 @@ mod tests {
};
let pods = vec![make_entry(); 4];
let priors = vec![iron_vs_bh(); 4];
let a = ctx.batch_simulate(&pods, &priors, 42, 20);
let b = ctx.batch_simulate(&pods, &priors, 42, 20);
let a = ctx
.batch_simulate(&pods, &priors, 42, 20)
.expect("first dispatch");
let b = ctx
.batch_simulate(&pods, &priors, 42, 20)
.expect("second dispatch");
for (ra, rb) in a.iter().zip(b.iter()) {
assert_eq!(
ra.0.to_bits(),
rb.0.to_bits(),
ra.to_bits(),
rb.to_bits(),
"GPU dispatch must be bit-deterministic on repeat (backend: {})",
ctx.backend
);

View file

@ -20,40 +20,7 @@ pub mod inner;
pub use cpu_reference::{batch_simulate_cpu, batch_simulate_cpu_default_horizon};
#[cfg(feature = "gpu")]
pub use inner::{batch_simulate, batch_simulate_default_horizon, GpuContext};
/// CPU-only fallback for `batch_simulate` when the `gpu` feature is disabled.
///
/// Present so callers can target a stable `batch_simulate` surface without
/// cfg-gating at every dispatch site. With the feature on, this symbol is
/// shadowed by [`inner::batch_simulate`] which attempts GPU first.
#[cfg(not(feature = "gpu"))]
#[must_use]
pub fn batch_simulate(
inputs: &[crate::abstract_state::AbstractRolloutState],
priors_per_entry: &[[crate::policy::PersonalityPriors; crate::abstract_state::MAX_PLAYERS]],
seed: u64,
horizon: u32,
) -> Vec<(f32, RolloutPath)> {
batch_simulate_cpu(inputs, priors_per_entry, seed, horizon)
}
/// CPU-only convenience fallback matching [`inner::batch_simulate_default_horizon`]
/// when the `gpu` feature is disabled.
#[cfg(not(feature = "gpu"))]
#[must_use]
pub fn batch_simulate_default_horizon(
inputs: &[crate::abstract_state::AbstractRolloutState],
priors_per_entry: &[[crate::policy::PersonalityPriors; crate::abstract_state::MAX_PLAYERS]],
seed: u64,
) -> Vec<(f32, RolloutPath)> {
batch_simulate_cpu(
inputs,
priors_per_entry,
seed,
crate::rollout::DEFAULT_ROLLOUT_HORIZON,
)
}
pub use inner::{GpuContext, GpuError};
/// Which execution path produced a rollout result.
///

View file

@ -6,6 +6,7 @@
//! leaf-value evaluator used by the tournament-mode strategy search.
pub mod abstract_state;
pub mod backend;
pub mod diplomacy;
pub mod evaluator;
pub mod game_state;
@ -17,15 +18,13 @@ pub mod rollout;
pub mod tactical;
pub use abstract_state::{AbstractPlayerState, AbstractRolloutState, MAX_PLAYERS};
pub use backend::{AiBackend, BackendError};
pub use diplomacy::{
evaluate_open_borders_accept, evaluate_open_borders_offer, evaluate_shared_map_accept,
evaluate_shared_map_offer, DiploDecision, DiplomacyCtx,
};
pub use evaluator::{LoadError, PersonalityDef, ScoringWeights};
pub use gpu::{
batch_simulate, batch_simulate_cpu, batch_simulate_cpu_default_horizon,
batch_simulate_default_horizon, RolloutPath,
};
pub use gpu::{batch_simulate_cpu, batch_simulate_cpu_default_horizon, RolloutPath};
pub use policy::{
decide_ransom_response, ransom_accept_probability, score_capture_postures, ActionKind,
CombatBalance, PersonalityPriors, RansomDecision,

View file

@ -9,9 +9,6 @@
use crate::mcts::XorShift64;
use rayon::prelude::*;
#[cfg(feature = "gpu")]
use crate::gpu::GpuContext;
/// State + action interface the tree MCTS operates over.
pub trait TreeState: Clone {
type Action: Clone;
@ -118,20 +115,11 @@ pub struct Tree<S: TreeState> {
/// Index of the player MCTS is deciding for. Rewards in `simulate()`
/// are evaluated from this player's perspective.
pub root_player: u8,
/// Optional process-wide GPU context. When `Some`, state types that
/// project to [`crate::abstract_state::AbstractRolloutState`] (currently
/// [`crate::rollout::GameRolloutState`]) can dispatch batched rollouts
/// through `batch_simulate_gpu`. When `None` the tree runs the serial
/// CPU path in `iterate` / `simulate`.
///
/// Feature-gated behind `gpu` to keep the non-wgpu build paths lean.
#[cfg(feature = "gpu")]
pub gpu_context: Option<&'static GpuContext>,
/// Count of GPU batch dispatches performed by this tree. Observable from
/// tests to confirm the GPU path actually ran instead of falling through
/// to CPU. Incremented once per successful `iterate_gpu_batched` call
/// that produced a non-empty batch — the CPU fallback inside that method
/// does NOT bump this counter.
/// Count of successful batched GPU dispatches performed by this tree.
/// Observable from tests to confirm the GPU path actually ran. Bumped
/// once per non-empty `iterate_gpu_batched` call where the active
/// [`crate::backend::AiBackend`] is `Gpu` AND dispatch returned `Ok`.
/// CPU-backend dispatches and dispatch errors do NOT bump this.
#[cfg(feature = "gpu")]
pub gpu_batch_count: u32,
}
@ -147,23 +135,10 @@ impl<S: TreeState> Tree<S> {
rollout_temperature: crate::rollout::DEFAULT_ROLLOUT_TEMPERATURE,
root_player: 0,
#[cfg(feature = "gpu")]
gpu_context: None,
#[cfg(feature = "gpu")]
gpu_batch_count: 0,
}
}
/// Install a GPU context so batched rollouts dispatch through
/// `mc_ai::gpu::batch_simulate`. Passing `None` restores the serial
/// CPU path.
///
/// Typical call: `tree.with_gpu_context(GpuContext::shared())`.
#[cfg(feature = "gpu")]
pub fn with_gpu_context(mut self, ctx: Option<&'static GpuContext>) -> Self {
self.gpu_context = ctx;
self
}
pub fn root(&self) -> &Node<S> {
&self.nodes[0]
}
@ -357,28 +332,31 @@ impl<S: TreeState> Tree<S> {
}
}
// ── GPU batched iteration for GameRolloutState ──────────────────────────────
// ── Batched iteration for GameRolloutState ──────────────────────────────────
/// Batched GPU rollout dispatch for trees whose state is
/// Batched rollout dispatch for trees whose state is
/// [`crate::rollout::GameRolloutState`]. Kept as a separate impl block because
/// [`crate::gpu::batch_simulate`] operates on `AbstractRolloutState` — the
/// projection is well-defined only for `GameRolloutState`, not arbitrary
/// `S: TreeState`.
/// [`crate::backend::AiBackend::batch_simulate`] operates on
/// `AbstractRolloutState` — the projection is well-defined only for
/// `GameRolloutState`, not arbitrary `S: TreeState`.
#[cfg(feature = "gpu")]
impl Tree<crate::rollout::GameRolloutState> {
/// Run one batched MCTS iteration: select + expand `batch_size` leaves,
/// dispatch their rollouts through [`crate::gpu::batch_simulate`] (which
/// routes to GPU when `gpu_context` is `Some`, else CPU), then
/// backpropagate the rewards in canonical (batch-index) order so visit
/// totals are seed-deterministic.
/// dispatch their rollouts through `backend.batch_simulate` (which is
/// either the GPU shader or the canonical CPU rollout depending on the
/// boot-probed [`crate::backend::AiBackend`]), then backpropagate the
/// rewards in canonical (batch-index) order so visit totals are
/// seed-deterministic.
///
/// Returns the number of leaves actually rolled out. Returns `0` when
/// `batch_size == 0` or the root is terminal with no expandable children.
/// `batch_size == 0`, the root is terminal with no expandable children,
/// OR the backend returns `Err` (per Phase-1 contract there is **no**
/// silent CPU fallback — a runtime GPU dispatch failure surfaces and the
/// caller decides what to do).
///
/// The `gpu_batch_count` counter bumps once per non-empty dispatch so
/// tests can assert the GPU path was exercised. When `gpu_context` is
/// `None` the dispatch silently uses the CPU reference — results are
/// valid but the counter stays put.
/// The `gpu_batch_count` counter bumps once per non-empty dispatch where
/// the backend is `Gpu` AND dispatch returned `Ok`. CPU-backend
/// dispatches and `Err` returns do NOT bump it.
///
/// `budget_ms` caps wall-clock time: the batch-collection loop exits early
/// once `Instant::now() - start >= budget_ms`. Already-collected leaves are
@ -389,14 +367,12 @@ impl Tree<crate::rollout::GameRolloutState> {
batch_size: usize,
base_seed: u64,
budget_ms: Option<u64>,
backend: &crate::backend::AiBackend,
) -> usize {
if batch_size == 0 {
return 0;
}
// Collect up to `batch_size` distinct (target_idx, state) pairs.
// Each iteration runs one select+expand walk off the root; duplicates
// are allowed when a leaf node accepts multiple rollouts through the
// same target child.
let mut targets: Vec<usize> = Vec::with_capacity(batch_size);
let mut states: Vec<crate::abstract_state::AbstractRolloutState> =
Vec::with_capacity(batch_size);
@ -422,26 +398,29 @@ impl Tree<crate::rollout::GameRolloutState> {
return 0;
}
// Prefer the explicit context stored on the tree; fall back to the
// process-wide shared one via the top-level dispatch so callers
// without a `with_gpu_context` setup still exercise GPU when present.
let results = if let Some(ctx) = self.gpu_context {
ctx.batch_simulate(&states, &priors, base_seed, self.rollout_horizon)
} else {
crate::gpu::batch_simulate(&states, &priors, base_seed, self.rollout_horizon)
let results = match backend.batch_simulate(
&states,
&priors,
base_seed,
self.rollout_horizon,
) {
Ok(scores) => scores,
Err(_e) => {
// No silent fallback. Surface the failure as zero
// rollouts dispatched for this batch; caller can inspect
// `gpu_batch_count` and decide whether to retry / abort.
return 0;
}
};
// Count a GPU dispatch only when at least one result carries the
// `Gpu` tag. Falling through to CPU silently is allowed and must not
// be counted.
if results.iter().any(|(_, path)| *path == crate::gpu::RolloutPath::Gpu) {
if matches!(backend, crate::backend::AiBackend::Gpu(_)) {
self.gpu_batch_count = self.gpu_batch_count.saturating_add(1);
}
// Backpropagate in canonical (batch-index) order so repeated runs
// with the same seed produce bit-identical visit counts even when
// the GPU reorders work internally.
for (target, (reward, _)) in targets.iter().zip(results.iter()) {
for (target, reward) in targets.iter().zip(results.iter()) {
self.backpropagate(*target, *reward);
}
states.len()

View file

@ -32,7 +32,7 @@
use std::collections::HashMap;
use mc_ai::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
use mc_ai::gpu::{batch_simulate_cpu, GpuContext, RolloutPath};
use mc_ai::gpu::{batch_simulate_cpu, GpuContext};
use mc_ai::mcts::XorShift64;
use mc_ai::policy::PersonalityPriors;
use mc_ai::rollout::DEFAULT_ROLLOUT_HORIZON;
@ -215,15 +215,13 @@ fn gpu_rollout_parity_small_batch() {
let (states, priors) = fixture_batch(N, SEED);
let gpu_out = ctx.batch_simulate(&states, &priors, SEED, HORIZON);
let gpu_out = ctx
.batch_simulate(&states, &priors, SEED, HORIZON)
.expect("dispatch should succeed on a working adapter");
let cpu_out = batch_simulate_cpu(&states, &priors, SEED, HORIZON);
assert_eq!(gpu_out.len(), N);
assert_eq!(cpu_out.len(), N);
for (i, (g, c)) in gpu_out.iter().zip(cpu_out.iter()).enumerate() {
assert_eq!(g.1, RolloutPath::Gpu, "entry {i}: GPU path must tag Gpu");
assert_eq!(c.1, RolloutPath::Cpu, "entry {i}: CPU path must tag Cpu");
}
report_agreement(&gpu_out, &cpu_out, "small_batch", ctx.backend.as_str());
}
@ -241,7 +239,9 @@ fn gpu_rollout_parity_multi_workgroup() {
const SEED: u64 = 0xDEAD_C0DE_1234_5678_u64;
let (states, priors) = fixture_batch(N, SEED);
let gpu_out = ctx.batch_simulate(&states, &priors, SEED, DEFAULT_ROLLOUT_HORIZON);
let gpu_out = ctx
.batch_simulate(&states, &priors, SEED, DEFAULT_ROLLOUT_HORIZON)
.expect("dispatch should succeed on a working adapter");
let cpu_out = batch_simulate_cpu(&states, &priors, SEED, DEFAULT_ROLLOUT_HORIZON);
assert_eq!(gpu_out.len(), N);

View file

@ -1,24 +1,23 @@
//! p0-20 integration — `batch_simulate_gpu` wired into
//! p0-20 integration — `AiBackend::batch_simulate` wired into
//! `mcts_tree::Tree<GameRolloutState>::iterate_gpu_batched`.
//!
//! Asserts that constructing a `Tree<GameRolloutState>` with a
//! `GpuContext::shared()` and calling `iterate_gpu_batched` over 100
//! rollouts actually exercises the GPU path (observable via
//! `Tree::gpu_batch_count`) and backpropagates valid rewards into the root.
//! Asserts that constructing a `Tree<GameRolloutState>` with an `AiBackend`
//! and calling `iterate_gpu_batched` over 100 rollouts actually exercises
//! the GPU path (observable via `Tree::gpu_batch_count`) when the backend
//! probes Gpu, and falls through to the CPU rollout (no counter bump) when
//! it probes Cpu.
//!
//! # Skip behavior
//!
//! On headless hosts / hosts without a working compute adapter,
//! `GpuContext::shared()` returns `None`. In that case the test falls back
//! to the CPU reference path (which is itself a thin wrapper around the
//! canonical rollout walker) and asserts only the CPU-observable invariants
//! (visit counts, reward in [0, 1]). No hang, no panic — matches the
//! skip-path used by `tests/gpu_rollout_parity.rs`.
//! `AiBackend::probe()` returns `Cpu`. The Gpu-path test in this file checks
//! `matches!(backend, AiBackend::Gpu(_))` and skips otherwise — no hang, no
//! panic. Matches the skip-path used by `tests/gpu_rollout_parity.rs`.
#![cfg(feature = "gpu")]
use mc_ai::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
use mc_ai::gpu::GpuContext;
use mc_ai::backend::AiBackend;
use mc_ai::mcts_tree::Tree;
use mc_ai::policy::PersonalityPriors;
use mc_ai::rollout::GameRolloutState;
@ -83,21 +82,24 @@ fn make_root_state() -> GameRolloutState {
#[test]
fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
let Some(ctx) = GpuContext::shared() else {
// Force-probe via env so we always end up on Gpu when an adapter exists,
// skipping otherwise.
let backend = AiBackend::probe();
if !matches!(backend, AiBackend::Gpu(_)) {
eprintln!(
"[skip] no GPU adapter — iterate_gpu_batched_exercises_gpu_path_when_adapter_available \
is a no-op on this host"
);
return;
};
}
let root = make_root_state();
let mut tree = Tree::new(root).with_gpu_context(Some(ctx));
let mut tree = Tree::new(root);
let mut rolled_out = 0_usize;
let mut batch_idx: u64 = 0;
while rolled_out < TOTAL_ROLLOUTS {
let n = tree.iterate_gpu_batched(BATCH_SIZE, 1000 + batch_idx, None);
let n = tree.iterate_gpu_batched(BATCH_SIZE, 1000 + batch_idx, None, &backend);
if n == 0 {
break;
}
@ -105,17 +107,12 @@ fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
batch_idx += 1;
}
// Every dispatch tagged Gpu bumps the counter — with a real adapter all
// dispatches should hit GPU, so the counter must be at least 1.
assert!(
tree.gpu_batch_count >= 1,
"expected ≥1 GPU dispatch with adapter present, got {} (backend: {})",
"expected ≥1 GPU dispatch with adapter present, got {}",
tree.gpu_batch_count,
ctx.backend
);
// Root visits accumulate one-per-rollout. With batch size {BATCH_SIZE}
// and {TOTAL_ROLLOUTS} target rollouts, root.visits >= TOTAL_ROLLOUTS.
assert!(
tree.root().visits as usize >= TOTAL_ROLLOUTS,
"expected ≥{} root visits, got {}",
@ -123,7 +120,6 @@ fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
tree.root().visits
);
// Root wins must be in [0, visits] because every reward is in [0, 1].
assert!(
tree.root().wins >= 0.0 && tree.root().wins <= tree.root().visits as f32,
"wins {} out of [0, {}]",
@ -133,23 +129,26 @@ fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
}
#[test]
fn iterate_gpu_batched_cpu_fallback_without_context() {
// No context installed → falls through to the top-level dispatch, which
// itself consults `GpuContext::shared()`. On a host with no adapter this
// lands on the CPU reference. Either way rewards must be valid and root
// visits must accumulate; the GPU counter is allowed to stay at 0 when
// the path resolves to CPU.
fn iterate_gpu_batched_cpu_backend_does_not_bump_counter() {
// Force-CPU backend regardless of host adapter — every dispatch returns
// valid CPU-rollout rewards, but `gpu_batch_count` must stay at 0.
let prev = std::env::var("MC_AI_BACKEND").ok();
std::env::set_var("MC_AI_BACKEND", "cpu");
let backend = AiBackend::probe();
if let Some(p) = prev {
std::env::set_var("MC_AI_BACKEND", p);
} else {
std::env::remove_var("MC_AI_BACKEND");
}
assert!(matches!(backend, AiBackend::Cpu));
let root = make_root_state();
let mut tree = Tree::new(root);
assert!(
tree.gpu_context.is_none(),
"constructor default for gpu_context must be None"
);
let mut rolled_out = 0_usize;
let mut batch_idx: u64 = 0;
while rolled_out < TOTAL_ROLLOUTS {
let n = tree.iterate_gpu_batched(BATCH_SIZE, 2000 + batch_idx, None);
let n = tree.iterate_gpu_batched(BATCH_SIZE, 2000 + batch_idx, None, &backend);
if n == 0 {
break;
}
@ -157,15 +156,19 @@ fn iterate_gpu_batched_cpu_fallback_without_context() {
batch_idx += 1;
}
assert_eq!(
tree.gpu_batch_count, 0,
"Cpu backend must never bump gpu_batch_count"
);
assert!(
tree.root().visits as usize >= TOTAL_ROLLOUTS,
"CPU fallback must still accumulate ≥{} visits, got {}",
"CPU backend must still accumulate ≥{} visits, got {}",
TOTAL_ROLLOUTS,
tree.root().visits
);
assert!(
tree.root().wins >= 0.0 && tree.root().wins <= tree.root().visits as f32,
"wins {} out of [0, {}] on CPU fallback",
"wins {} out of [0, {}] on CPU backend",
tree.root().wins,
tree.root().visits
);
@ -173,20 +176,20 @@ fn iterate_gpu_batched_cpu_fallback_without_context() {
#[test]
fn iterate_gpu_batched_is_seed_deterministic() {
// Same seed + same root state + same context installation policy → same
// visit/wins totals across repeated runs. Backprop order is
// batch-index-ordered inside `iterate_gpu_batched`, so parallelism in
// the GPU dispatch cannot leak into the tree.
let ctx = GpuContext::shared();
// Same seed + same root state + same backend → same visit/wins totals
// across repeated runs. Backprop order is batch-index-ordered inside
// `iterate_gpu_batched`, so parallelism in the GPU dispatch cannot
// leak into the tree.
let backend = AiBackend::probe();
let root_a = make_root_state();
let root_b = make_root_state();
let mut tree_a = Tree::new(root_a).with_gpu_context(ctx);
let mut tree_b = Tree::new(root_b).with_gpu_context(ctx);
let mut tree_a = Tree::new(root_a);
let mut tree_b = Tree::new(root_b);
for i in 0..2 {
tree_a.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None);
tree_b.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None);
tree_a.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None, &backend);
tree_b.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None, &backend);
}
assert_eq!(
@ -204,9 +207,9 @@ fn iterate_gpu_batched_is_seed_deterministic() {
#[test]
fn iterate_gpu_batched_zero_batch_is_noop() {
let ctx = GpuContext::shared();
let mut tree = Tree::new(make_root_state()).with_gpu_context(ctx);
let n = tree.iterate_gpu_batched(0, 42, None);
let backend = AiBackend::probe();
let mut tree = Tree::new(make_root_state());
let n = tree.iterate_gpu_batched(0, 42, None, &backend);
assert_eq!(n, 0);
assert_eq!(tree.root().visits, 0);
assert_eq!(tree.gpu_batch_count, 0);