feat(@projects/@magic-civilization): ✨ add ai backend probe and dispatch system
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
cd9b92879a
commit
039c31a079
7 changed files with 356 additions and 306 deletions
172
src/simulator/crates/mc-ai/src/backend.rs
Normal file
172
src/simulator/crates/mc-ai/src/backend.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
//! Boot-probed AI backend selector — single dispatch entry point for batched
|
||||
//! rollouts. Replaces the per-call silent GPU→CPU fallback that
|
||||
//! `gpu::inner::batch_simulate` previously implemented.
|
||||
//!
|
||||
//! # Contract
|
||||
//!
|
||||
//! - Probe runs **once** at boot via [`AiBackend::probe`] (env-overridable).
|
||||
//! - The chosen backend is **fixed for the session** — there is NO per-call
|
||||
//! fallback. If GPU dispatch fails mid-game,
|
||||
//! [`AiBackend::batch_simulate`] returns `Err`; callers do NOT silently
|
||||
//! degrade to CPU.
|
||||
//! - Algorithm parity is enforced by `tests/gpu_rollout_parity.rs`. The CPU
|
||||
//! path goes through [`crate::gpu::cpu_reference::batch_simulate_cpu`],
|
||||
//! which is a thin wrapper around the canonical `rollout::walk` — the same
|
||||
//! semantics the WGSL shader mirrors.
|
||||
//!
|
||||
//! # Env override
|
||||
//!
|
||||
//! - `MC_AI_BACKEND=cpu` — force `AiBackend::Cpu` regardless of probe.
|
||||
//! - `MC_AI_BACKEND=gpu` — require a working GPU; **panic** if probe fails.
|
||||
//! - Unset / any other value — probe normally (Gpu if available, else Cpu).
|
||||
|
||||
use crate::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
|
||||
use crate::policy::PersonalityPriors;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::gpu::inner::{GpuContext, GpuError};
|
||||
|
||||
/// The AI backend used for batched rollout dispatch, chosen once at boot.
///
/// The `Gpu` variant only exists when the `gpu` cargo feature is enabled.
/// Under `cfg(not(feature = "gpu"))` the enum has the single variant `Cpu`
/// and the probe always yields `Cpu`, so builds that omit wgpu (mobile /
/// minimal) need no `cfg` branching at dispatch call sites — the gating is
/// expressed by which variants exist.
///
/// The type is `Copy`: its only payload is a `&'static` reference, so handing
/// a backend around by value is as cheap as passing the reference itself.
#[derive(Clone, Copy)]
pub enum AiBackend {
    /// GPU compute via wgpu. Carries the process-wide cached
    /// [`GpuContext::shared`] handle, so the adapter probe and pipeline
    /// compile cost is paid once at boot rather than per dispatch.
    #[cfg(feature = "gpu")]
    Gpu(&'static GpuContext),
    /// Canonical CPU rollout via
    /// [`crate::gpu::cpu_reference::batch_simulate_cpu`]. Documented as
    /// algorithm-equivalent to the WGSL shader, byte-by-byte, on
    /// `AbstractRolloutState`.
    Cpu,
}

/// Hand-written so that `AiBackend` is printable (`{:?}`, `dbg!`,
/// `Result::unwrap`) without requiring `GpuContext` itself to implement
/// `Debug`. Only the variant name is rendered; the GPU payload is elided.
impl std::fmt::Debug for AiBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "gpu")]
            AiBackend::Gpu(_) => f.write_str("Gpu(..)"),
            AiBackend::Cpu => f.write_str("Cpu"),
        }
    }
}
|
||||
|
||||
/// Errors surfaced from [`AiBackend::batch_simulate`]. There is no fallback;
|
||||
/// dispatch failures propagate up.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum BackendError {
|
||||
/// The GPU pipeline failed to dispatch (queue submit, buffer map, device
|
||||
/// lost, etc.). The session does NOT fall back to CPU — callers must
|
||||
/// decide whether to retry, surface, or abort.
|
||||
#[error("GPU dispatch failed: {0}")]
|
||||
GpuDispatchFailed(String),
|
||||
/// Catch-all for caller misuse (length mismatches, etc.) and any future
|
||||
/// non-GPU error case.
|
||||
#[error("backend error: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl AiBackend {
|
||||
/// Probe at boot. Order: existing [`GpuContext::shared`] adapter probe
|
||||
/// (Vulkan / Metal / DX12 / WebGPU) → `Cpu`. The `MC_AI_BACKEND` env var
|
||||
/// overrides the probe — see crate-level docs.
|
||||
///
|
||||
/// Logs the chosen backend at info-level on stderr (same channel
|
||||
/// `mc-turn` / `mc-compute` use). Mobile users running CPU see a line
|
||||
/// like `[mc-ai backend] Cpu (no compute adapter)`; hosts with a working
|
||||
/// adapter see `[mc-ai backend] Gpu (Vulkan)`.
|
||||
#[must_use]
|
||||
pub fn probe() -> Self {
|
||||
let env = std::env::var("MC_AI_BACKEND").ok();
|
||||
let env_norm = env.as_deref().map(str::to_ascii_lowercase);
|
||||
|
||||
match env_norm.as_deref() {
|
||||
Some("cpu") => {
|
||||
eprintln!("[mc-ai backend] Cpu (forced via MC_AI_BACKEND=cpu)");
|
||||
AiBackend::Cpu
|
||||
}
|
||||
Some("gpu") => {
|
||||
#[cfg(feature = "gpu")]
|
||||
{
|
||||
match GpuContext::shared() {
|
||||
Some(ctx) => {
|
||||
eprintln!(
|
||||
"[mc-ai backend] Gpu ({}) (forced via MC_AI_BACKEND=gpu)",
|
||||
ctx.backend
|
||||
);
|
||||
AiBackend::Gpu(ctx)
|
||||
}
|
||||
None => panic!(
|
||||
"MC_AI_BACKEND=gpu requested but no compute adapter available"
|
||||
),
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
{
|
||||
panic!(
|
||||
"MC_AI_BACKEND=gpu requested but the `gpu` cargo feature is disabled"
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
#[cfg(feature = "gpu")]
|
||||
{
|
||||
if let Some(ctx) = GpuContext::shared() {
|
||||
eprintln!("[mc-ai backend] Gpu ({})", ctx.backend);
|
||||
return AiBackend::Gpu(ctx);
|
||||
}
|
||||
}
|
||||
eprintln!("[mc-ai backend] Cpu (no compute adapter)");
|
||||
AiBackend::Cpu
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stable, lower-case name for diagnostics / log lines. `"gpu"` or
|
||||
/// `"cpu"`.
|
||||
#[must_use]
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
#[cfg(feature = "gpu")]
|
||||
AiBackend::Gpu(_) => "gpu",
|
||||
AiBackend::Cpu => "cpu",
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatch a batched rollout through the active backend.
|
||||
///
|
||||
/// Returns one `f32` terminal score in `[0, 1]` per batch entry, in
|
||||
/// input order. The algorithm is identical across backends — see
|
||||
/// `tests/gpu_rollout_parity.rs` for the byte-equivalence contract.
|
||||
///
|
||||
/// Returns `Err(BackendError::Other)` on caller-side input length
|
||||
/// mismatch and `Err(BackendError::GpuDispatchFailed)` on GPU runtime
|
||||
/// failure. There is no silent fallback to CPU.
|
||||
pub fn batch_simulate(
|
||||
&self,
|
||||
inputs: &[AbstractRolloutState],
|
||||
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
|
||||
seed: u64,
|
||||
horizon: u32,
|
||||
) -> Result<Vec<f32>, BackendError> {
|
||||
if inputs.len() != priors_per_entry.len() {
|
||||
return Err(BackendError::Other(format!(
|
||||
"inputs len {} != priors len {}",
|
||||
inputs.len(),
|
||||
priors_per_entry.len()
|
||||
)));
|
||||
}
|
||||
if inputs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
match self {
|
||||
#[cfg(feature = "gpu")]
|
||||
AiBackend::Gpu(ctx) => ctx
|
||||
.batch_simulate(inputs, priors_per_entry, seed, horizon)
|
||||
.map_err(|e: GpuError| BackendError::GpuDispatchFailed(e.to_string())),
|
||||
AiBackend::Cpu => {
|
||||
let out = crate::gpu::cpu_reference::batch_simulate_cpu(
|
||||
inputs,
|
||||
priors_per_entry,
|
||||
seed,
|
||||
horizon,
|
||||
);
|
||||
Ok(out.into_iter().map(|(s, _)| s).collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -19,10 +19,23 @@ use wgpu::util::DeviceExt;
|
|||
|
||||
use crate::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
|
||||
use crate::policy::PersonalityPriors;
|
||||
use crate::rollout::{DEFAULT_ROLLOUT_HORIZON, DEFAULT_ROLLOUT_TEMPERATURE};
|
||||
use crate::rollout::DEFAULT_ROLLOUT_TEMPERATURE;
|
||||
|
||||
use super::cpu_reference::batch_simulate_cpu;
|
||||
use super::RolloutPath;
|
||||
/// Runtime failure surfaced from [`GpuContext::batch_simulate`]. There is no
|
||||
/// per-call CPU fallback inside this module — backend identity is fixed for
|
||||
/// the session by [`crate::backend::AiBackend::probe`]. Callers translate
|
||||
/// this into [`crate::backend::BackendError::GpuDispatchFailed`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum GpuError {
|
||||
/// Caller passed mismatched `inputs` / `priors_per_entry` lengths.
|
||||
#[error("inputs len {inputs} != priors len {priors}")]
|
||||
LengthMismatch { inputs: usize, priors: usize },
|
||||
/// The wgpu pipeline dispatch failed (queue submit error, buffer map
|
||||
/// failure, device lost). The cause string is whatever the wgpu /
|
||||
/// `pollster` layer surfaced.
|
||||
#[error("dispatch failed: {0}")]
|
||||
DispatchFailed(String),
|
||||
}
|
||||
|
||||
/// WGSL kernel source, compiled into the binary at build time.
|
||||
const SHADER_SRC: &str = include_str!("rollout.wgsl");
|
||||
|
|
@ -396,27 +409,33 @@ impl GpuContext {
|
|||
|
||||
/// Dispatch a full rollout batch through the GPU pipeline.
|
||||
///
|
||||
/// Returns `RolloutPath::Gpu`-tagged results on success. On any runtime
|
||||
/// failure (dispatch error, map failure) falls back to the CPU reference
|
||||
/// and returns `RolloutPath::Cpu`-tagged results — the caller gets a
|
||||
/// valid answer either way.
|
||||
#[must_use]
|
||||
/// Returns scores on success. On dispatch failure returns
|
||||
/// `Err(GpuError::DispatchFailed)` — there is **no** silent CPU
|
||||
/// fallback. The boot-probed [`crate::backend::AiBackend`] decides
|
||||
/// backend identity once; runtime failures propagate up so callers
|
||||
/// know the dispatch did not produce GPU-quality answers.
|
||||
pub fn batch_simulate(
|
||||
&self,
|
||||
inputs: &[AbstractRolloutState],
|
||||
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
|
||||
seed: u64,
|
||||
horizon: u32,
|
||||
) -> Vec<(f32, RolloutPath)> {
|
||||
) -> Result<Vec<f32>, GpuError> {
|
||||
if inputs.len() != priors_per_entry.len() {
|
||||
return Vec::new();
|
||||
return Err(GpuError::LengthMismatch {
|
||||
inputs: inputs.len(),
|
||||
priors: priors_per_entry.len(),
|
||||
});
|
||||
}
|
||||
if inputs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
match self.dispatch_batch(inputs, priors_per_entry, seed, horizon) {
|
||||
Some(scores) => scores
|
||||
.into_iter()
|
||||
.map(|s| (s, RolloutPath::Gpu))
|
||||
.collect(),
|
||||
None => batch_simulate_cpu(inputs, priors_per_entry, seed, horizon),
|
||||
Some(scores) => Ok(scores),
|
||||
None => Err(GpuError::DispatchFailed(
|
||||
"wgpu pipeline dispatch returned None (queue submit / buffer map / device lost)"
|
||||
.to_owned(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -464,48 +483,6 @@ fn create_storage_rw(dev: &wgpu::Device, size_bytes: usize, label: &str) -> wgpu
|
|||
})
|
||||
}
|
||||
|
||||
/// Top-level GPU-or-CPU dispatch entry point.
|
||||
///
|
||||
/// Uses the process-wide cached [`GpuContext::shared`] — the adapter probe
|
||||
/// runs exactly once per process, not per call. On hosts with a working GPU
|
||||
/// adapter this dispatches to the shader; on headless hosts or hosts where
|
||||
/// the driver is wedged (see `TRY_INIT_TIMEOUT_MS`) it falls through to the
|
||||
/// CPU reference silently. Result types are identical; only the
|
||||
/// [`RolloutPath`] tag differs.
|
||||
///
|
||||
/// For hot loops that dispatch many batches, consider holding a
|
||||
/// `&GpuContext` directly via `GpuContext::shared()` to skip the `OnceLock`
|
||||
/// atomic load per call.
|
||||
#[must_use]
|
||||
pub fn batch_simulate(
|
||||
inputs: &[AbstractRolloutState],
|
||||
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
|
||||
seed: u64,
|
||||
horizon: u32,
|
||||
) -> Vec<(f32, RolloutPath)> {
|
||||
// Zero-length inputs never touch the GPU cache — fast path.
|
||||
if inputs.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
if inputs.len() != priors_per_entry.len() {
|
||||
return Vec::new();
|
||||
}
|
||||
if let Some(ctx) = GpuContext::shared() {
|
||||
return ctx.batch_simulate(inputs, priors_per_entry, seed, horizon);
|
||||
}
|
||||
batch_simulate_cpu(inputs, priors_per_entry, seed, horizon)
|
||||
}
|
||||
|
||||
/// Convenience: `batch_simulate` with `DEFAULT_ROLLOUT_HORIZON`.
|
||||
#[must_use]
|
||||
pub fn batch_simulate_default_horizon(
|
||||
inputs: &[AbstractRolloutState],
|
||||
priors_per_entry: &[[PersonalityPriors; MAX_PLAYERS]],
|
||||
seed: u64,
|
||||
) -> Vec<(f32, RolloutPath)> {
|
||||
batch_simulate(inputs, priors_per_entry, seed, DEFAULT_ROLLOUT_HORIZON)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -558,43 +535,42 @@ mod tests {
|
|||
|
||||
use std::time::Instant;
|
||||
|
||||
// ── Tests that do NOT touch the GPU adapter ──────────────────────────
|
||||
// ── Adapter-gated guard tests ────────────────────────────────────────
|
||||
//
|
||||
// These rely solely on the pre-dispatch guards in `batch_simulate` (empty
|
||||
// input / mismatched lens) OR on `GpuContext::shared()` returning None
|
||||
// after the one-time probe. Either way no single test is responsible for
|
||||
// the probe cost — the first test that *does* need GPU state pays it
|
||||
// once and caches.
|
||||
// The boot-probed `AiBackend` no longer has a per-call dispatch shim
|
||||
// here, so empty / mismatched-length tests hit `GpuContext::batch_simulate`
|
||||
// directly. Both must short-circuit before any wgpu work; without an
|
||||
// adapter we skip via `shared()` returning None.
|
||||
|
||||
#[test]
|
||||
fn batch_simulate_empty_bypasses_gpu_probe() {
|
||||
// Empty input returns Vec::new() before GpuContext::shared() is ever
|
||||
// consulted. Must complete in microseconds even on a wedged-adapter
|
||||
// host; assert a 100ms upper bound with generous slack for CI jitter.
|
||||
fn ctx_batch_simulate_empty_returns_ok_empty() {
|
||||
let Some(ctx) = GpuContext::shared() else {
|
||||
eprintln!("[rollout-gpu] no adapter — skipping ctx_batch_simulate_empty_returns_ok_empty");
|
||||
return;
|
||||
};
|
||||
let start = Instant::now();
|
||||
let out = batch_simulate(&[], &[], 42, 20);
|
||||
let out = ctx.batch_simulate(&[], &[], 42, 20).expect("empty must succeed");
|
||||
let elapsed = start.elapsed();
|
||||
assert!(out.is_empty());
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(100),
|
||||
"empty input must bypass GPU probe; took {:?}",
|
||||
"empty input must short-circuit; took {:?}",
|
||||
elapsed
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_simulate_mismatched_lengths_bypasses_gpu_probe() {
|
||||
fn ctx_batch_simulate_mismatched_lengths_errors() {
|
||||
let Some(ctx) = GpuContext::shared() else {
|
||||
eprintln!("[rollout-gpu] no adapter — skipping ctx_batch_simulate_mismatched_lengths_errors");
|
||||
return;
|
||||
};
|
||||
let pods = vec![make_entry()];
|
||||
let priors: Vec<[PersonalityPriors; MAX_PLAYERS]> = vec![iron_vs_bh(), iron_vs_bh()];
|
||||
let start = Instant::now();
|
||||
let out = batch_simulate(&pods, &priors, 1, 20);
|
||||
let elapsed = start.elapsed();
|
||||
assert!(out.is_empty());
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(100),
|
||||
"length-mismatch must bypass GPU probe; took {:?}",
|
||||
elapsed
|
||||
);
|
||||
let err = ctx
|
||||
.batch_simulate(&pods, &priors, 1, 20)
|
||||
.expect_err("mismatched lens must Err");
|
||||
assert!(matches!(err, GpuError::LengthMismatch { .. }));
|
||||
}
|
||||
|
||||
// ── Timeout contract ─────────────────────────────────────────────────
|
||||
|
|
@ -712,86 +688,36 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
// ── Scored-path tests ────────────────────────────────────────────────
|
||||
// ── Adapter-gated dispatch tests ─────────────────────────────────────
|
||||
//
|
||||
// These exercise the full dispatch pipeline. The first one to run pays
|
||||
// the probe cost (bounded by timeout). On a wedged-driver host all of
|
||||
// these silently route through CPU and still produce valid results.
|
||||
// Only run when a working adapter is actually present. On hosts without
|
||||
// a compute adapter `shared()` returns None and these skip — no hang,
|
||||
// no panic. The boot-probed `AiBackend` covers the no-adapter path; the
|
||||
// tests in `tests/backend_probe.rs` and `tests/gpu_rollout_parity.rs`
|
||||
// exercise the full dispatch surface.
|
||||
|
||||
#[test]
|
||||
fn batch_simulate_produces_unit_interval_scores() {
|
||||
let pods = vec![make_entry(); 4];
|
||||
let priors = vec![iron_vs_bh(); 4];
|
||||
let out = batch_simulate(&pods, &priors, 7, 20);
|
||||
assert_eq!(out.len(), 4);
|
||||
for (score, path) in &out {
|
||||
assert!((0.0..=1.0).contains(score), "score {score} out of [0,1] on {path}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_returns_valid_path_tag() {
|
||||
let pods = vec![make_entry()];
|
||||
let priors = vec![iron_vs_bh()];
|
||||
let out = batch_simulate(&pods, &priors, 100, 20);
|
||||
assert_eq!(out.len(), 1);
|
||||
assert!(matches!(out[0].1, RolloutPath::Cpu | RolloutPath::Gpu));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_simulate_is_deterministic_across_repeated_calls() {
|
||||
let pods = vec![make_entry(); 3];
|
||||
let priors = vec![iron_vs_bh(); 3];
|
||||
let a = batch_simulate(&pods, &priors, 77, 20);
|
||||
let b = batch_simulate(&pods, &priors, 77, 20);
|
||||
assert_eq!(a.len(), b.len());
|
||||
for (ra, rb) in a.iter().zip(b.iter()) {
|
||||
assert_eq!(
|
||||
ra.0.to_bits(),
|
||||
rb.0.to_bits(),
|
||||
"same seed must produce bit-identical results on {}",
|
||||
ra.1
|
||||
);
|
||||
assert_eq!(ra.1, rb.1, "path tag must be stable across calls");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_horizon_helper_matches_explicit_call() {
|
||||
let pods = vec![make_entry()];
|
||||
let priors = vec![iron_vs_bh()];
|
||||
let a = batch_simulate(&pods, &priors, 55, DEFAULT_ROLLOUT_HORIZON);
|
||||
let b = batch_simulate_default_horizon(&pods, &priors, 55);
|
||||
assert_eq!(a[0].0.to_bits(), b[0].0.to_bits());
|
||||
assert_eq!(a[0].1, b[0].1);
|
||||
}
|
||||
|
||||
// ── Adapter-gated tests ──────────────────────────────────────────────
|
||||
//
|
||||
// Only run when a working adapter is actually present. On wedged
|
||||
// drivers `shared()` returns None and these skip — no hang, no panic.
|
||||
|
||||
#[test]
|
||||
fn gpu_path_tags_when_adapter_available() {
|
||||
fn gpu_returns_unit_interval_scores_when_adapter_available() {
|
||||
let Some(ctx) = GpuContext::shared() else {
|
||||
eprintln!("[rollout-gpu] no adapter — skipping gpu_path_tags_when_adapter_available");
|
||||
eprintln!("[rollout-gpu] no adapter — skipping gpu_returns_unit_interval_scores_when_adapter_available");
|
||||
return;
|
||||
};
|
||||
let pods = vec![make_entry()];
|
||||
let priors = vec![iron_vs_bh()];
|
||||
let out = ctx.batch_simulate(&pods, &priors, 123, 20);
|
||||
assert_eq!(out.len(), 1);
|
||||
assert_eq!(
|
||||
out[0].1,
|
||||
RolloutPath::Gpu,
|
||||
"with adapter present, batch_simulate must tag Gpu (backend: {})",
|
||||
ctx.backend
|
||||
);
|
||||
assert!((0.0..=1.0).contains(&out[0].0));
|
||||
let pods = vec![make_entry(); 4];
|
||||
let priors = vec![iron_vs_bh(); 4];
|
||||
let out = ctx
|
||||
.batch_simulate(&pods, &priors, 7, 20)
|
||||
.expect("dispatch should succeed on a working adapter");
|
||||
assert_eq!(out.len(), 4);
|
||||
for score in &out {
|
||||
assert!(
|
||||
(0.0..=1.0).contains(score),
|
||||
"score {score} out of [0,1] (backend: {})",
|
||||
ctx.backend
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatch determinism when an adapter is present. Skips on headless or
|
||||
/// wedged-driver hosts via the `shared()` None return.
|
||||
/// Dispatch determinism when an adapter is present.
|
||||
#[test]
|
||||
fn gpu_dispatch_is_deterministic_when_adapter_available() {
|
||||
let Some(ctx) = GpuContext::shared() else {
|
||||
|
|
@ -800,12 +726,16 @@ mod tests {
|
|||
};
|
||||
let pods = vec![make_entry(); 4];
|
||||
let priors = vec![iron_vs_bh(); 4];
|
||||
let a = ctx.batch_simulate(&pods, &priors, 42, 20);
|
||||
let b = ctx.batch_simulate(&pods, &priors, 42, 20);
|
||||
let a = ctx
|
||||
.batch_simulate(&pods, &priors, 42, 20)
|
||||
.expect("first dispatch");
|
||||
let b = ctx
|
||||
.batch_simulate(&pods, &priors, 42, 20)
|
||||
.expect("second dispatch");
|
||||
for (ra, rb) in a.iter().zip(b.iter()) {
|
||||
assert_eq!(
|
||||
ra.0.to_bits(),
|
||||
rb.0.to_bits(),
|
||||
ra.to_bits(),
|
||||
rb.to_bits(),
|
||||
"GPU dispatch must be bit-deterministic on repeat (backend: {})",
|
||||
ctx.backend
|
||||
);
|
||||
|
|
|
|||
|
|
@ -20,40 +20,7 @@ pub mod inner;
|
|||
pub use cpu_reference::{batch_simulate_cpu, batch_simulate_cpu_default_horizon};
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use inner::{batch_simulate, batch_simulate_default_horizon, GpuContext};
|
||||
|
||||
/// CPU-only fallback for `batch_simulate` when the `gpu` feature is disabled.
|
||||
///
|
||||
/// Present so callers can target a stable `batch_simulate` surface without
|
||||
/// cfg-gating at every dispatch site. With the feature on, this symbol is
|
||||
/// shadowed by [`inner::batch_simulate`] which attempts GPU first.
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
#[must_use]
|
||||
pub fn batch_simulate(
|
||||
inputs: &[crate::abstract_state::AbstractRolloutState],
|
||||
priors_per_entry: &[[crate::policy::PersonalityPriors; crate::abstract_state::MAX_PLAYERS]],
|
||||
seed: u64,
|
||||
horizon: u32,
|
||||
) -> Vec<(f32, RolloutPath)> {
|
||||
batch_simulate_cpu(inputs, priors_per_entry, seed, horizon)
|
||||
}
|
||||
|
||||
/// CPU-only convenience fallback matching [`inner::batch_simulate_default_horizon`]
|
||||
/// when the `gpu` feature is disabled.
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
#[must_use]
|
||||
pub fn batch_simulate_default_horizon(
|
||||
inputs: &[crate::abstract_state::AbstractRolloutState],
|
||||
priors_per_entry: &[[crate::policy::PersonalityPriors; crate::abstract_state::MAX_PLAYERS]],
|
||||
seed: u64,
|
||||
) -> Vec<(f32, RolloutPath)> {
|
||||
batch_simulate_cpu(
|
||||
inputs,
|
||||
priors_per_entry,
|
||||
seed,
|
||||
crate::rollout::DEFAULT_ROLLOUT_HORIZON,
|
||||
)
|
||||
}
|
||||
pub use inner::{GpuContext, GpuError};
|
||||
|
||||
/// Which execution path produced a rollout result.
|
||||
///
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
//! leaf-value evaluator used by the tournament-mode strategy search.
|
||||
|
||||
pub mod abstract_state;
|
||||
pub mod backend;
|
||||
pub mod diplomacy;
|
||||
pub mod evaluator;
|
||||
pub mod game_state;
|
||||
|
|
@ -17,15 +18,13 @@ pub mod rollout;
|
|||
pub mod tactical;
|
||||
|
||||
pub use abstract_state::{AbstractPlayerState, AbstractRolloutState, MAX_PLAYERS};
|
||||
pub use backend::{AiBackend, BackendError};
|
||||
pub use diplomacy::{
|
||||
evaluate_open_borders_accept, evaluate_open_borders_offer, evaluate_shared_map_accept,
|
||||
evaluate_shared_map_offer, DiploDecision, DiplomacyCtx,
|
||||
};
|
||||
pub use evaluator::{LoadError, PersonalityDef, ScoringWeights};
|
||||
pub use gpu::{
|
||||
batch_simulate, batch_simulate_cpu, batch_simulate_cpu_default_horizon,
|
||||
batch_simulate_default_horizon, RolloutPath,
|
||||
};
|
||||
pub use gpu::{batch_simulate_cpu, batch_simulate_cpu_default_horizon, RolloutPath};
|
||||
pub use policy::{
|
||||
decide_ransom_response, ransom_accept_probability, score_capture_postures, ActionKind,
|
||||
CombatBalance, PersonalityPriors, RansomDecision,
|
||||
|
|
|
|||
|
|
@ -9,9 +9,6 @@
|
|||
use crate::mcts::XorShift64;
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::gpu::GpuContext;
|
||||
|
||||
/// State + action interface the tree MCTS operates over.
|
||||
pub trait TreeState: Clone {
|
||||
type Action: Clone;
|
||||
|
|
@ -118,20 +115,11 @@ pub struct Tree<S: TreeState> {
|
|||
/// Index of the player MCTS is deciding for. Rewards in `simulate()`
|
||||
/// are evaluated from this player's perspective.
|
||||
pub root_player: u8,
|
||||
/// Optional process-wide GPU context. When `Some`, state types that
|
||||
/// project to [`crate::abstract_state::AbstractRolloutState`] (currently
|
||||
/// [`crate::rollout::GameRolloutState`]) can dispatch batched rollouts
|
||||
/// through `batch_simulate_gpu`. When `None` the tree runs the serial
|
||||
/// CPU path in `iterate` / `simulate`.
|
||||
///
|
||||
/// Feature-gated behind `gpu` to keep the non-wgpu build paths lean.
|
||||
#[cfg(feature = "gpu")]
|
||||
pub gpu_context: Option<&'static GpuContext>,
|
||||
/// Count of GPU batch dispatches performed by this tree. Observable from
|
||||
/// tests to confirm the GPU path actually ran instead of falling through
|
||||
/// to CPU. Incremented once per successful `iterate_gpu_batched` call
|
||||
/// that produced a non-empty batch — the CPU fallback inside that method
|
||||
/// does NOT bump this counter.
|
||||
/// Count of successful batched GPU dispatches performed by this tree.
|
||||
/// Observable from tests to confirm the GPU path actually ran. Bumped
|
||||
/// once per non-empty `iterate_gpu_batched` call where the active
|
||||
/// [`crate::backend::AiBackend`] is `Gpu` AND dispatch returned `Ok`.
|
||||
/// CPU-backend dispatches and dispatch errors do NOT bump this.
|
||||
#[cfg(feature = "gpu")]
|
||||
pub gpu_batch_count: u32,
|
||||
}
|
||||
|
|
@ -147,23 +135,10 @@ impl<S: TreeState> Tree<S> {
|
|||
rollout_temperature: crate::rollout::DEFAULT_ROLLOUT_TEMPERATURE,
|
||||
root_player: 0,
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu_context: None,
|
||||
#[cfg(feature = "gpu")]
|
||||
gpu_batch_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Install a GPU context so batched rollouts dispatch through
|
||||
/// `mc_ai::gpu::batch_simulate`. Passing `None` restores the serial
|
||||
/// CPU path.
|
||||
///
|
||||
/// Typical call: `tree.with_gpu_context(GpuContext::shared())`.
|
||||
#[cfg(feature = "gpu")]
|
||||
pub fn with_gpu_context(mut self, ctx: Option<&'static GpuContext>) -> Self {
|
||||
self.gpu_context = ctx;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn root(&self) -> &Node<S> {
|
||||
&self.nodes[0]
|
||||
}
|
||||
|
|
@ -357,28 +332,31 @@ impl<S: TreeState> Tree<S> {
|
|||
}
|
||||
}
|
||||
|
||||
// ── GPU batched iteration for GameRolloutState ──────────────────────────────
|
||||
// ── Batched iteration for GameRolloutState ──────────────────────────────────
|
||||
|
||||
/// Batched GPU rollout dispatch for trees whose state is
|
||||
/// Batched rollout dispatch for trees whose state is
|
||||
/// [`crate::rollout::GameRolloutState`]. Kept as a separate impl block because
|
||||
/// [`crate::gpu::batch_simulate`] operates on `AbstractRolloutState` — the
|
||||
/// projection is well-defined only for `GameRolloutState`, not arbitrary
|
||||
/// `S: TreeState`.
|
||||
/// [`crate::backend::AiBackend::batch_simulate`] operates on
|
||||
/// `AbstractRolloutState` — the projection is well-defined only for
|
||||
/// `GameRolloutState`, not arbitrary `S: TreeState`.
|
||||
#[cfg(feature = "gpu")]
|
||||
impl Tree<crate::rollout::GameRolloutState> {
|
||||
/// Run one batched MCTS iteration: select + expand `batch_size` leaves,
|
||||
/// dispatch their rollouts through [`crate::gpu::batch_simulate`] (which
|
||||
/// routes to GPU when `gpu_context` is `Some`, else CPU), then
|
||||
/// backpropagate the rewards in canonical (batch-index) order so visit
|
||||
/// totals are seed-deterministic.
|
||||
/// dispatch their rollouts through `backend.batch_simulate` (which is
|
||||
/// either the GPU shader or the canonical CPU rollout depending on the
|
||||
/// boot-probed [`crate::backend::AiBackend`]), then backpropagate the
|
||||
/// rewards in canonical (batch-index) order so visit totals are
|
||||
/// seed-deterministic.
|
||||
///
|
||||
/// Returns the number of leaves actually rolled out. Returns `0` when
|
||||
/// `batch_size == 0` or the root is terminal with no expandable children.
|
||||
/// `batch_size == 0`, the root is terminal with no expandable children,
|
||||
/// OR the backend returns `Err` (per Phase-1 contract there is **no**
|
||||
/// silent CPU fallback — a runtime GPU dispatch failure surfaces and the
|
||||
/// caller decides what to do).
|
||||
///
|
||||
/// The `gpu_batch_count` counter bumps once per non-empty dispatch so
|
||||
/// tests can assert the GPU path was exercised. When `gpu_context` is
|
||||
/// `None` the dispatch silently uses the CPU reference — results are
|
||||
/// valid but the counter stays put.
|
||||
/// The `gpu_batch_count` counter bumps once per non-empty dispatch where
|
||||
/// the backend is `Gpu` AND dispatch returned `Ok`. CPU-backend
|
||||
/// dispatches and `Err` returns do NOT bump it.
|
||||
///
|
||||
/// `budget_ms` caps wall-clock time: the batch-collection loop exits early
|
||||
/// once `Instant::now() - start >= budget_ms`. Already-collected leaves are
|
||||
|
|
@ -389,14 +367,12 @@ impl Tree<crate::rollout::GameRolloutState> {
|
|||
batch_size: usize,
|
||||
base_seed: u64,
|
||||
budget_ms: Option<u64>,
|
||||
backend: &crate::backend::AiBackend,
|
||||
) -> usize {
|
||||
if batch_size == 0 {
|
||||
return 0;
|
||||
}
|
||||
// Collect up to `batch_size` distinct (target_idx, state) pairs.
|
||||
// Each iteration runs one select+expand walk off the root; duplicates
|
||||
// are allowed when a leaf node accepts multiple rollouts through the
|
||||
// same target child.
|
||||
let mut targets: Vec<usize> = Vec::with_capacity(batch_size);
|
||||
let mut states: Vec<crate::abstract_state::AbstractRolloutState> =
|
||||
Vec::with_capacity(batch_size);
|
||||
|
|
@ -422,26 +398,29 @@ impl Tree<crate::rollout::GameRolloutState> {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// Prefer the explicit context stored on the tree; fall back to the
|
||||
// process-wide shared one via the top-level dispatch so callers
|
||||
// without a `with_gpu_context` setup still exercise GPU when present.
|
||||
let results = if let Some(ctx) = self.gpu_context {
|
||||
ctx.batch_simulate(&states, &priors, base_seed, self.rollout_horizon)
|
||||
} else {
|
||||
crate::gpu::batch_simulate(&states, &priors, base_seed, self.rollout_horizon)
|
||||
let results = match backend.batch_simulate(
|
||||
&states,
|
||||
&priors,
|
||||
base_seed,
|
||||
self.rollout_horizon,
|
||||
) {
|
||||
Ok(scores) => scores,
|
||||
Err(_e) => {
|
||||
// No silent fallback. Surface the failure as zero
|
||||
// rollouts dispatched for this batch; caller can inspect
|
||||
// `gpu_batch_count` and decide whether to retry / abort.
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Count a GPU dispatch only when at least one result carries the
|
||||
// `Gpu` tag. Falling through to CPU silently is allowed and must not
|
||||
// be counted.
|
||||
if results.iter().any(|(_, path)| *path == crate::gpu::RolloutPath::Gpu) {
|
||||
if matches!(backend, crate::backend::AiBackend::Gpu(_)) {
|
||||
self.gpu_batch_count = self.gpu_batch_count.saturating_add(1);
|
||||
}
|
||||
|
||||
// Backpropagate in canonical (batch-index) order so repeated runs
|
||||
// with the same seed produce bit-identical visit counts even when
|
||||
// the GPU reorders work internally.
|
||||
for (target, (reward, _)) in targets.iter().zip(results.iter()) {
|
||||
for (target, reward) in targets.iter().zip(results.iter()) {
|
||||
self.backpropagate(*target, *reward);
|
||||
}
|
||||
states.len()
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use mc_ai::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
|
||||
use mc_ai::gpu::{batch_simulate_cpu, GpuContext, RolloutPath};
|
||||
use mc_ai::gpu::{batch_simulate_cpu, GpuContext};
|
||||
use mc_ai::mcts::XorShift64;
|
||||
use mc_ai::policy::PersonalityPriors;
|
||||
use mc_ai::rollout::DEFAULT_ROLLOUT_HORIZON;
|
||||
|
|
@ -215,15 +215,13 @@ fn gpu_rollout_parity_small_batch() {
|
|||
|
||||
let (states, priors) = fixture_batch(N, SEED);
|
||||
|
||||
let gpu_out = ctx.batch_simulate(&states, &priors, SEED, HORIZON);
|
||||
let gpu_out = ctx
|
||||
.batch_simulate(&states, &priors, SEED, HORIZON)
|
||||
.expect("dispatch should succeed on a working adapter");
|
||||
let cpu_out = batch_simulate_cpu(&states, &priors, SEED, HORIZON);
|
||||
|
||||
assert_eq!(gpu_out.len(), N);
|
||||
assert_eq!(cpu_out.len(), N);
|
||||
for (i, (g, c)) in gpu_out.iter().zip(cpu_out.iter()).enumerate() {
|
||||
assert_eq!(g.1, RolloutPath::Gpu, "entry {i}: GPU path must tag Gpu");
|
||||
assert_eq!(c.1, RolloutPath::Cpu, "entry {i}: CPU path must tag Cpu");
|
||||
}
|
||||
|
||||
report_agreement(&gpu_out, &cpu_out, "small_batch", ctx.backend.as_str());
|
||||
}
|
||||
|
|
@ -241,7 +239,9 @@ fn gpu_rollout_parity_multi_workgroup() {
|
|||
const SEED: u64 = 0xDEAD_C0DE_1234_5678_u64;
|
||||
|
||||
let (states, priors) = fixture_batch(N, SEED);
|
||||
let gpu_out = ctx.batch_simulate(&states, &priors, SEED, DEFAULT_ROLLOUT_HORIZON);
|
||||
let gpu_out = ctx
|
||||
.batch_simulate(&states, &priors, SEED, DEFAULT_ROLLOUT_HORIZON)
|
||||
.expect("dispatch should succeed on a working adapter");
|
||||
let cpu_out = batch_simulate_cpu(&states, &priors, SEED, DEFAULT_ROLLOUT_HORIZON);
|
||||
|
||||
assert_eq!(gpu_out.len(), N);
|
||||
|
|
|
|||
|
|
@ -1,24 +1,23 @@
|
|||
//! p0-20 integration — `batch_simulate_gpu` wired into
|
||||
//! p0-20 integration — `AiBackend::batch_simulate` wired into
|
||||
//! `mcts_tree::Tree<GameRolloutState>::iterate_gpu_batched`.
|
||||
//!
|
||||
//! Asserts that constructing a `Tree<GameRolloutState>` with a
|
||||
//! `GpuContext::shared()` and calling `iterate_gpu_batched` over 100
|
||||
//! rollouts actually exercises the GPU path (observable via
|
||||
//! `Tree::gpu_batch_count`) and backpropagates valid rewards into the root.
|
||||
//! Asserts that constructing a `Tree<GameRolloutState>` with an `AiBackend`
|
||||
//! and calling `iterate_gpu_batched` over 100 rollouts actually exercises
|
||||
//! the GPU path (observable via `Tree::gpu_batch_count`) when the backend
|
||||
//! probes Gpu, and falls through to the CPU rollout (no counter bump) when
|
||||
//! it probes Cpu.
|
||||
//!
|
||||
//! # Skip behavior
|
||||
//!
|
||||
//! On headless hosts / hosts without a working compute adapter,
|
||||
//! `GpuContext::shared()` returns `None`. In that case the test falls back
|
||||
//! to the CPU reference path (which is itself a thin wrapper around the
|
||||
//! canonical rollout walker) and asserts only the CPU-observable invariants
|
||||
//! (visit counts, reward in [0, 1]). No hang, no panic — matches the
|
||||
//! skip-path used by `tests/gpu_rollout_parity.rs`.
|
||||
//! `AiBackend::probe()` returns `Cpu`. The Gpu-path test in this file checks
|
||||
//! `matches!(backend, AiBackend::Gpu(_))` and skips otherwise — no hang, no
|
||||
//! panic. Matches the skip-path used by `tests/gpu_rollout_parity.rs`.
|
||||
|
||||
#![cfg(feature = "gpu")]
|
||||
|
||||
use mc_ai::abstract_state::{AbstractRolloutState, MAX_PLAYERS};
|
||||
use mc_ai::gpu::GpuContext;
|
||||
use mc_ai::backend::AiBackend;
|
||||
use mc_ai::mcts_tree::Tree;
|
||||
use mc_ai::policy::PersonalityPriors;
|
||||
use mc_ai::rollout::GameRolloutState;
|
||||
|
|
@ -83,21 +82,24 @@ fn make_root_state() -> GameRolloutState {
|
|||
|
||||
#[test]
|
||||
fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
|
||||
let Some(ctx) = GpuContext::shared() else {
|
||||
// Probe normally (an existing `MC_AI_BACKEND` override is honoured; none is
|
||||
// set here). Run the GPU assertions only if the probe yields Gpu, else skip.
|
||||
let backend = AiBackend::probe();
|
||||
if !matches!(backend, AiBackend::Gpu(_)) {
|
||||
eprintln!(
|
||||
"[skip] no GPU adapter — iterate_gpu_batched_exercises_gpu_path_when_adapter_available \
|
||||
is a no-op on this host"
|
||||
);
|
||||
return;
|
||||
};
|
||||
}
|
||||
|
||||
let root = make_root_state();
|
||||
let mut tree = Tree::new(root).with_gpu_context(Some(ctx));
|
||||
let mut tree = Tree::new(root);
|
||||
|
||||
let mut rolled_out = 0_usize;
|
||||
let mut batch_idx: u64 = 0;
|
||||
while rolled_out < TOTAL_ROLLOUTS {
|
||||
let n = tree.iterate_gpu_batched(BATCH_SIZE, 1000 + batch_idx, None);
|
||||
let n = tree.iterate_gpu_batched(BATCH_SIZE, 1000 + batch_idx, None, &backend);
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
|
|
@ -105,17 +107,12 @@ fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
|
|||
batch_idx += 1;
|
||||
}
|
||||
|
||||
// Every dispatch tagged Gpu bumps the counter — with a real adapter all
|
||||
// dispatches should hit GPU, so the counter must be at least 1.
|
||||
assert!(
|
||||
tree.gpu_batch_count >= 1,
|
||||
"expected ≥1 GPU dispatch with adapter present, got {} (backend: {})",
|
||||
"expected ≥1 GPU dispatch with adapter present, got {}",
|
||||
tree.gpu_batch_count,
|
||||
ctx.backend
|
||||
);
|
||||
|
||||
// Root visits accumulate one-per-rollout. With batch size {BATCH_SIZE}
|
||||
// and {TOTAL_ROLLOUTS} target rollouts, root.visits >= TOTAL_ROLLOUTS.
|
||||
assert!(
|
||||
tree.root().visits as usize >= TOTAL_ROLLOUTS,
|
||||
"expected ≥{} root visits, got {}",
|
||||
|
|
@ -123,7 +120,6 @@ fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
|
|||
tree.root().visits
|
||||
);
|
||||
|
||||
// Root wins must be in [0, visits] because every reward is in [0, 1].
|
||||
assert!(
|
||||
tree.root().wins >= 0.0 && tree.root().wins <= tree.root().visits as f32,
|
||||
"wins {} out of [0, {}]",
|
||||
|
|
@ -133,23 +129,26 @@ fn iterate_gpu_batched_exercises_gpu_path_when_adapter_available() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn iterate_gpu_batched_cpu_fallback_without_context() {
|
||||
// No context installed → falls through to the top-level dispatch, which
|
||||
// itself consults `GpuContext::shared()`. On a host with no adapter this
|
||||
// lands on the CPU reference. Either way rewards must be valid and root
|
||||
// visits must accumulate; the GPU counter is allowed to stay at 0 when
|
||||
// the path resolves to CPU.
|
||||
fn iterate_gpu_batched_cpu_backend_does_not_bump_counter() {
|
||||
// Force-CPU backend regardless of host adapter — every dispatch returns
|
||||
// valid CPU-rollout rewards, but `gpu_batch_count` must stay at 0.
|
||||
let prev = std::env::var("MC_AI_BACKEND").ok();
|
||||
std::env::set_var("MC_AI_BACKEND", "cpu");
|
||||
let backend = AiBackend::probe();
|
||||
if let Some(p) = prev {
|
||||
std::env::set_var("MC_AI_BACKEND", p);
|
||||
} else {
|
||||
std::env::remove_var("MC_AI_BACKEND");
|
||||
}
|
||||
assert!(matches!(backend, AiBackend::Cpu));
|
||||
|
||||
let root = make_root_state();
|
||||
let mut tree = Tree::new(root);
|
||||
assert!(
|
||||
tree.gpu_context.is_none(),
|
||||
"constructor default for gpu_context must be None"
|
||||
);
|
||||
|
||||
let mut rolled_out = 0_usize;
|
||||
let mut batch_idx: u64 = 0;
|
||||
while rolled_out < TOTAL_ROLLOUTS {
|
||||
let n = tree.iterate_gpu_batched(BATCH_SIZE, 2000 + batch_idx, None);
|
||||
let n = tree.iterate_gpu_batched(BATCH_SIZE, 2000 + batch_idx, None, &backend);
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
|
|
@ -157,15 +156,19 @@ fn iterate_gpu_batched_cpu_fallback_without_context() {
|
|||
batch_idx += 1;
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
tree.gpu_batch_count, 0,
|
||||
"Cpu backend must never bump gpu_batch_count"
|
||||
);
|
||||
assert!(
|
||||
tree.root().visits as usize >= TOTAL_ROLLOUTS,
|
||||
"CPU fallback must still accumulate ≥{} visits, got {}",
|
||||
"CPU backend must still accumulate ≥{} visits, got {}",
|
||||
TOTAL_ROLLOUTS,
|
||||
tree.root().visits
|
||||
);
|
||||
assert!(
|
||||
tree.root().wins >= 0.0 && tree.root().wins <= tree.root().visits as f32,
|
||||
"wins {} out of [0, {}] on CPU fallback",
|
||||
"wins {} out of [0, {}] on CPU backend",
|
||||
tree.root().wins,
|
||||
tree.root().visits
|
||||
);
|
||||
|
|
@ -173,20 +176,20 @@ fn iterate_gpu_batched_cpu_fallback_without_context() {
|
|||
|
||||
#[test]
|
||||
fn iterate_gpu_batched_is_seed_deterministic() {
|
||||
// Same seed + same root state + same context installation policy → same
|
||||
// visit/wins totals across repeated runs. Backprop order is
|
||||
// batch-index-ordered inside `iterate_gpu_batched`, so parallelism in
|
||||
// the GPU dispatch cannot leak into the tree.
|
||||
let ctx = GpuContext::shared();
|
||||
// Same seed + same root state + same backend → same visit/wins totals
|
||||
// across repeated runs. Backprop order is batch-index-ordered inside
|
||||
// `iterate_gpu_batched`, so parallelism in the GPU dispatch cannot
|
||||
// leak into the tree.
|
||||
let backend = AiBackend::probe();
|
||||
let root_a = make_root_state();
|
||||
let root_b = make_root_state();
|
||||
|
||||
let mut tree_a = Tree::new(root_a).with_gpu_context(ctx);
|
||||
let mut tree_b = Tree::new(root_b).with_gpu_context(ctx);
|
||||
let mut tree_a = Tree::new(root_a);
|
||||
let mut tree_b = Tree::new(root_b);
|
||||
|
||||
for i in 0..2 {
|
||||
tree_a.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None);
|
||||
tree_b.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None);
|
||||
tree_a.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None, &backend);
|
||||
tree_b.iterate_gpu_batched(BATCH_SIZE, 7000 + i as u64, None, &backend);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
|
|
@ -204,9 +207,9 @@ fn iterate_gpu_batched_is_seed_deterministic() {
|
|||
|
||||
#[test]
|
||||
fn iterate_gpu_batched_zero_batch_is_noop() {
|
||||
let ctx = GpuContext::shared();
|
||||
let mut tree = Tree::new(make_root_state()).with_gpu_context(ctx);
|
||||
let n = tree.iterate_gpu_batched(0, 42, None);
|
||||
let backend = AiBackend::probe();
|
||||
let mut tree = Tree::new(make_root_state());
|
||||
let n = tree.iterate_gpu_batched(0, 42, None, &backend);
|
||||
assert_eq!(n, 0);
|
||||
assert_eq!(tree.root().visits, 0);
|
||||
assert_eq!(tree.gpu_batch_count, 0);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue