feat(@projects/@magic-civilization): add gpu-optimized abstract rollout action path

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Natalie 2026-05-04 16:57:43 -04:00
parent 467d7ee951
commit a4827d5257
5 changed files with 418 additions and 6 deletions

View file

@ -522,6 +522,162 @@ impl GdMcTreeController {
r#"{{"action":"{action_str}","win_rate":{win_rate:.4},"root_idle":{visits_idle},"root_found":{visits_found},"root_spawn":{visits_spawn}}}"#
))
}
/// p0-20 Phase A v3 — choose an action via the abstract `GameRolloutState`
/// path. Coexists with `choose_action` / `choose_action_with_stats`
/// (which still drive `Tree<McSnapshot>`), giving the GPU MCTS path its
/// own callable entry point without disturbing the live-game flow.
///
/// Pipeline:
/// 1. Parse `game_state_json` → `mc_turn::GameState`.
/// 2. Project to `AbstractRolloutState` via
/// `mc_turn::abstract_projection::to_abstract_rollout_state`.
/// 3. Build a `Tree<GameRolloutState>` rooted at `(pod, priors_per_player)`,
/// with `priors_per_player` decoded from `priors_json` — a JSON array
/// of up to `MAX_PLAYERS` `PersonalityPriors` objects (missing slots
/// default).
/// 4. Outer loop: call `tree.iterate_gpu_batched(batch_size, …)` until the
/// rollout budget is met OR the wall-clock budget (`budget_ms`) expires.
/// Routing GPU vs CPU is decided by the boot-probed
/// [`mc_ai::backend::AiBackend`].
/// 5. Return the highest-visit-count action via
/// `Tree::most_visited_action_at_root` as the `ActionKind` debug name
/// (`"Build"`, `"Settle"`, etc.). Empty trees return `"Idle"`.
///
/// Determinism: same `(game_state_json, priors_json, root_player, base_seed,
/// rollout_budget)` produces the same action across runs and across
/// platforms with the same probed backend (CPU and GPU paths are
/// byte-equivalent — see `tests/gpu_rollout_parity.rs`).
#[func]
fn choose_action_via_abstract(
&self,
game_state_json: GString,
priors_json: GString,
root_player: i64,
base_seed: i64,
) -> GString {
use mc_ai::abstract_state::MAX_PLAYERS;
use mc_ai::policy::{ActionKind, PersonalityPriors};
use mc_ai::rollout::GameRolloutState;
use mc_turn::abstract_projection::to_abstract_rollout_state;
// 1. Parse GameState.
let state: GameState = match serde_json::from_str(&game_state_json.to_string()) {
Ok(s) => s,
Err(e) => {
godot_error!(
"GdMcTreeController::choose_action_via_abstract parse error: {}", e
);
return GString::from("Idle");
}
};
// 2. Project to abstract POD.
let pod = to_abstract_rollout_state(&state);
// 3. Decode priors into a fixed-size array, defaulting missing slots.
let priors_arr: [PersonalityPriors; MAX_PLAYERS] = {
let mut out = [PersonalityPriors::default(); MAX_PLAYERS];
let raw = priors_json.to_string();
if !raw.trim().is_empty() && raw.trim() != "null" {
match serde_json::from_str::<Vec<PersonalityPriors>>(&raw) {
Ok(v) => {
for (i, p) in v.into_iter().take(MAX_PLAYERS).enumerate() {
out[i] = p;
}
}
Err(e) => {
godot_error!(
"GdMcTreeController::choose_action_via_abstract priors parse error: {}",
e
);
}
}
}
out
};
// 4. Build the rollout-state tree. The root carries the abstract POD
// + per-player priors; downstream `iterate_gpu_batched` clones the
// POD into the GPU/CPU batch path on every leaf expansion.
let root_state = GameRolloutState::new(pod, priors_arr);
let mut tree = Tree::new(root_state);
tree.use_priors = self.priors_enabled;
let pi = root_player.max(0) as usize;
tree.root_player = pi.min(MAX_PLAYERS - 1) as u8;
// Backend + outer N-rollout loop. AiBackend is process-static — probe
// once and cache so successive calls don't pay the adapter probe cost.
static BACKEND: OnceLock<AiBackend> = OnceLock::new();
let backend = BACKEND.get_or_init(AiBackend::probe);
// Tunable: 32 leaves per dispatch matches the GPU shader workgroup
// size and is the same number the existing `gpu_tree_integration`
// tests use. Smaller batches lose throughput; larger batches risk
// overshooting the rollout budget on the last dispatch.
const BATCH_SIZE: usize = 32;
let total_budget = self.rollout_budget as usize;
let wall_budget = if self.budget_ms > 0 {
Some(self.budget_ms)
} else {
None
};
let start = Instant::now();
let mut completed: usize = 0;
let bs = base_seed as u64;
while completed < total_budget {
if let Some(b) = wall_budget {
if start.elapsed() >= Duration::from_millis(b) {
break;
}
}
let remaining = total_budget - completed;
let this_batch = remaining.min(BATCH_SIZE);
// Each batch carries its own seed offset so identical inputs
// produce identical visit counts across runs.
let dispatched = tree.iterate_gpu_batched(
this_batch,
bs.wrapping_add(completed as u64),
wall_budget,
backend,
);
if dispatched == 0 {
// Either terminal-at-root, batch_size=0, or backend Err.
// Surface and stop — no silent fallback (see
// `iterate_gpu_batched` docstring).
break;
}
completed += dispatched;
}
// 5. Read the highest-visit child action.
let action_name = match tree.most_visited_action_at_root() {
Some(a) => action_kind_name(a),
None => "Idle",
};
GString::from(action_name)
}
}
/// Stable, ordinal-frozen lower-CamelCase debug name for an
/// [`mc_ai::policy::ActionKind`]. Matches the variant identifier so GDScript
/// can switch on the exact string without consulting a lookup table.
fn action_kind_name(kind: mc_ai::policy::ActionKind) -> &'static str {
use mc_ai::policy::ActionKind;
match kind {
ActionKind::Build => "Build",
ActionKind::Attack => "Attack",
ActionKind::Settle => "Settle",
ActionKind::Research => "Research",
ActionKind::Defend => "Defend",
ActionKind::Trade => "Trade",
ActionKind::ContinueWar => "ContinueWar",
ActionKind::MakePeace => "MakePeace",
ActionKind::Idle => "Idle",
ActionKind::CommandFormation => "CommandFormation",
ActionKind::SetRallyPoint => "SetRallyPoint",
}
}
// ── GdAiController ───────────────────────────────────────────────────────────

View file

@ -127,6 +127,22 @@ impl AiBackend {
}
}
/// True iff the active backend dispatches through the GPU compute path.
///
/// Universal — callable from non-`gpu`-feature builds where the `Gpu`
/// variant doesn't exist (always returns `false` there). Used by
/// [`crate::mcts_tree::Tree::iterate_gpu_batched`] to bump
/// `gpu_batch_count` only when the GPU path actually ran, without
/// requiring `#[cfg(feature = "gpu")]` at the call site (p0-20 v3).
#[must_use]
pub fn is_gpu(&self) -> bool {
match self {
#[cfg(feature = "gpu")]
AiBackend::Gpu(_) => true,
AiBackend::Cpu => false,
}
}
/// Dispatch a batched rollout through the active backend.
///
/// Returns one `f32` terminal score in `[0, 1]` per batch entry, in

View file

@ -118,9 +118,11 @@ pub struct Tree<S: TreeState> {
/// Count of successful batched GPU dispatches performed by this tree.
/// Observable from tests to confirm the GPU path actually ran. Bumped
/// once per non-empty `iterate_gpu_batched` call where the active
/// [`crate::backend::AiBackend`] is `Gpu` AND dispatch returned `Ok`.
/// CPU-backend dispatches and dispatch errors do NOT bump this.
#[cfg(feature = "gpu")]
/// [`crate::backend::AiBackend`] reports `is_gpu() == true` AND dispatch
/// returned `Ok`. CPU-backend dispatches and dispatch errors do NOT bump
/// this. Universal across `gpu`/no-`gpu` builds (p0-20 v3) — under
/// `cfg(not(feature = "gpu"))` the counter exists but always reads 0
/// because `AiBackend::is_gpu()` is constant `false`.
pub gpu_batch_count: u32,
}
@ -134,7 +136,6 @@ impl<S: TreeState> Tree<S> {
rollout_horizon: crate::rollout::DEFAULT_ROLLOUT_HORIZON,
rollout_temperature: crate::rollout::DEFAULT_ROLLOUT_TEMPERATURE,
root_player: 0,
#[cfg(feature = "gpu")]
gpu_batch_count: 0,
}
}
@ -339,7 +340,12 @@ impl<S: TreeState> Tree<S> {
/// [`crate::backend::AiBackend::batch_simulate`] operates on
/// `AbstractRolloutState` — the projection is well-defined only for
/// `GameRolloutState`, not arbitrary `S: TreeState`.
#[cfg(feature = "gpu")]
///
/// Universal across `gpu`/no-`gpu` builds (p0-20 v3): the impl block is no
/// longer cfg-gated. Mobile-CPU-only builds compile this method and route
/// dispatch through `AiBackend::Cpu` (which is universal); the `gpu/` module
/// itself stays feature-gated, so `AiBackend::Gpu(_)` only exists under
/// `feature = "gpu"`.
impl Tree<crate::rollout::GameRolloutState> {
/// Run one batched MCTS iteration: select + expand `batch_size` leaves,
/// dispatch their rollouts through `backend.batch_simulate` (which is
@ -413,7 +419,7 @@ impl Tree<crate::rollout::GameRolloutState> {
}
};
if matches!(backend, crate::backend::AiBackend::Gpu(_)) {
if backend.is_gpu() {
self.gpu_batch_count = self.gpu_batch_count.saturating_add(1);
}
@ -425,6 +431,54 @@ impl Tree<crate::rollout::GameRolloutState> {
}
states.len()
}
/// Return the [`crate::policy::ActionKind`] of the root child with the
/// highest visit count. Ties are broken by lowest `ActionKind` ordinal
/// (the position in [`crate::policy::ActionKind::ALL`]) so identical
/// rollout budgets and seeds produce identical action picks across
/// platforms.
///
/// Returns `None` when the root has no expanded children (zero-budget
/// search or terminal-at-root). p0-20 v3 — used by api-gdext's
/// `choose_action_via_abstract`.
#[must_use]
pub fn most_visited_action_at_root(&self) -> Option<crate::policy::ActionKind> {
let root = &self.nodes[0];
if root.children.is_empty() {
return None;
}
// Build (ordinal, visits, action) triples and pick by (-visits, ordinal).
let mut best: Option<(u32, usize, crate::policy::ActionKind)> = None;
for &ci in &root.children {
let n = &self.nodes[ci];
let action = match n.action {
Some(a) => a,
None => continue,
};
// Ordinal = index in ActionKind::ALL; falls back to a sentinel
// (=ALL.len()) for non-rollout actions (CommandFormation,
// SetRallyPoint), which sort last on tie-break.
let ordinal = crate::policy::ActionKind::ALL
.iter()
.position(|k| *k == action)
.unwrap_or(crate::policy::ActionKind::ALL.len());
let candidate = (n.visits, ordinal, action);
best = Some(match best {
None => candidate,
Some(prev) => {
// Higher visits wins; on tie, lower ordinal wins.
if candidate.0 > prev.0
|| (candidate.0 == prev.0 && candidate.1 < prev.1)
{
candidate
} else {
prev
}
}
});
}
best.map(|(_, _, a)| a)
}
}
// ── Rollout helpers for McSnapshot ──────────────────────────────────────────

View file

@ -123,6 +123,179 @@ pub struct SearchActionResult {
pub path: String,
}
// ── Abstract-rollout IPC mirror (p0-20 Phase A v3) ──────────────────────────
//
// `AbstractJobState` is a serde mirror of `mc_ai::abstract_state::AbstractRolloutState`.
// The POD itself is `#[repr(C)] + bytemuck::Pod` (no serde) because the GPU
// shader reads it as a raw byte buffer and we don't want serde to cross the
// memory-layout contract. The mirror is a flat serde-friendly form that
// round-trips through the IPC frame; converters at the boundary copy field-by-
// field. Phase A ships **schema only** — there is no `run_search_action_via_abstract`
// runner yet (Phase C).
/// Per-player serde mirror of `mc_ai::abstract_state::AbstractPlayerState`.
///
/// Field set is the full 14-logical-field surface (excluding internal
/// `_padN` slots, which serde reconstructs as zero via `Default`). Mirrors
/// the per-byte layout of `AbstractPlayerState` for unambiguous conversion.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct AbstractPlayerStateMirror {
/// `gold` — treasury, signed.
pub gold: i32,
/// `science` — science accumulator toward current tech.
pub science: i32,
/// `pop_total` — sum of city populations.
pub pop_total: u32,
/// `city_count` — number of cities owned.
pub city_count: u16,
/// `tech_index` — normalized tech progress in `[0, 100]`.
pub tech_index: u16,
/// `unit_counts[T1, T2]` — counts at tier 1 and tier 2.
pub unit_counts: [u8; 2],
/// `formation_count` — number of active formations (saturated u8).
pub formation_count: u8,
/// `happiness_pool` — signed; negative under unrest.
pub happiness_pool: i16,
/// `force_rel[opp]` — relative military force vs each opponent slot.
pub force_rel: [u16; 4],
/// `axes` — strategic axes via `mc_ai::game_state::axes_to_flat`.
pub axes: [u8; 8],
/// `relations[opp]` — diplomatic relation per opponent (-1/0/+1).
pub relations: [i8; 4],
/// `formation_strength[T1..T4]` — mean strength per tier bucket, 0255.
pub formation_strength: [u8; 4],
/// `rng_state` — per-player SplitMix64 state.
pub rng_state: u64,
/// `turn` — game turn number.
pub turn: u32,
}
impl AbstractPlayerStateMirror {
/// Project from the POD into a serde-friendly mirror. Padding fields
/// are dropped — they reconstruct as zero via `Default`.
#[must_use]
pub fn from_pod(pod: &mc_ai::abstract_state::AbstractPlayerState) -> Self {
Self {
gold: pod.gold,
science: pod.science,
pop_total: pod.pop_total,
city_count: pod.city_count,
tech_index: pod.tech_index,
unit_counts: pod.unit_counts,
formation_count: pod.formation_count,
happiness_pool: pod.happiness_pool,
force_rel: pod.force_rel,
axes: pod.axes,
relations: pod.relations,
formation_strength: pod.formation_strength,
rng_state: pod.rng_state,
turn: pod.turn,
}
}
/// Reconstruct the POD from the mirror. All `_padN` fields are zeroed
/// per the `bytemuck::Pod` contract (no uninitialised bytes).
#[must_use]
pub fn to_pod(&self) -> mc_ai::abstract_state::AbstractPlayerState {
mc_ai::abstract_state::AbstractPlayerState {
gold: self.gold,
science: self.science,
pop_total: self.pop_total,
city_count: self.city_count,
tech_index: self.tech_index,
unit_counts: self.unit_counts,
formation_count: self.formation_count,
_pad_uc: 0,
happiness_pool: self.happiness_pool,
_pad0: 0,
force_rel: self.force_rel,
axes: self.axes,
relations: self.relations,
formation_strength: self.formation_strength,
rng_state: self.rng_state,
turn: self.turn,
_pad2: [0; 4],
}
}
}
/// Serde mirror of `mc_ai::abstract_state::AbstractRolloutState` — the
/// 256-byte 4-player POD that travels through the GPU/CPU rollout shader.
///
/// Used by [`SearchActionViaAbstractJob`] to ship a search request through
/// the IPC frame without forcing serde onto the `bytemuck::Pod` POD itself.
/// Field-by-field copy at the boundary keeps the WGSL layout pristine.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct AbstractJobState {
/// Up to 4 player slots in slot order. Missing slots zero on deserialise
/// via `Default`.
pub players: Vec<AbstractPlayerStateMirror>,
}
impl AbstractJobState {
/// Project from the POD into the IPC mirror.
#[must_use]
pub fn from_pod(pod: &mc_ai::abstract_state::AbstractRolloutState) -> Self {
Self {
players: pod
.players
.iter()
.map(AbstractPlayerStateMirror::from_pod)
.collect(),
}
}
/// Reconstruct the POD from the mirror. Missing slots stay zero-initialised.
#[must_use]
pub fn to_pod(&self) -> mc_ai::abstract_state::AbstractRolloutState {
let mut out = mc_ai::abstract_state::AbstractRolloutState::zeroed();
for (i, slot) in self
.players
.iter()
.take(mc_ai::abstract_state::MAX_PLAYERS)
.enumerate()
{
out.players[i] = slot.to_pod();
}
out
}
}
/// Per-player personality priors mirror.
///
/// `mc_ai::policy::PersonalityPriors` already derives `Serialize` /
/// `Deserialize`, so this is a transparent re-alias rather than a
/// hand-written mirror — kept named for symmetry with `AbstractJobState`
/// and to give the IPC schema one self-contained module.
pub type PersonalityPriorsMirror = mc_ai::policy::PersonalityPriors;
/// p0-20 Phase A v3 — full MCTS tree search rooted at an
/// [`AbstractJobState`]. Routes through the GPU/CPU `iterate_gpu_batched`
/// path on the server side.
///
/// **Phase A ships the schema only**; the runner
/// (`run_search_action_via_abstract`) lands in Phase C alongside the rest of
/// the GPU MCTS service plumbing.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchActionViaAbstractJob {
/// Root rollout state — projected from `mc_turn::GameState` via
/// `mc_turn::abstract_projection::to_abstract_rollout_state`.
pub abstract_state: AbstractJobState,
/// Per-player priors aligned with `abstract_state.players`. Slot order
/// matches the POD; missing entries default.
pub priors: [PersonalityPriorsMirror; mc_ai::abstract_state::MAX_PLAYERS],
/// Player slot the MCTS is deciding for. Mirrors `Tree::root_player`.
pub root_player: u8,
/// Total rollout budget — the outer loop ends once `iterate_gpu_batched`
/// has dispatched this many leaves (or `budget_ms` expires).
pub rollout_budget: u32,
/// Base seed; per-batch seeds are derived as `base_seed + completed_count`.
pub base_seed: u64,
/// Optional per-decision wall-clock budget. `None` = unbounded.
/// Mirrors `Tree::iterate_gpu_batched`'s `budget_ms`.
pub budget_ms: Option<u64>,
}
/// Requests submitted to the MCTS service.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum Request {
@ -134,6 +307,11 @@ pub enum Request {
MctsBatch { jobs: Vec<MctsJob> },
/// Run a full MCTS tree search and return the best action (p1-27c).
SearchAction(SearchActionJob),
/// p0-20 Phase A v3 (schema only) — abstract-rollout MCTS search.
/// Server routing via `iterate_gpu_batched` + `most_visited_action_at_root`
/// lands in Phase C; this variant exists in Phase A so callers (api-gdext,
/// integration tests) can encode requests against a stable wire shape.
SearchActionViaAbstract(SearchActionViaAbstractJob),
}
/// Responses produced by the MCTS service.

View file

@ -97,6 +97,14 @@ fn dispatch(request: Request) -> Response {
Ok(result) => Response::SearchActionResult(result),
Err(msg) => Response::Error { message: msg },
},
// p0-20 Phase A v3 — schema-only stub. The runner lands in Phase C.
// Returning a structured error here lets callers exercise the wire
// shape (encode/decode round-trip) without claiming server support.
Request::SearchActionViaAbstract(_) => Response::Error {
message:
"SearchActionViaAbstract is schema-only in Phase A; runner ships in Phase C"
.to_owned(),
},
}
}