fix(api-gdext): 🐛 remove ai_gpu_rollout dependency

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Natalie 2026-05-04 17:57:02 -04:00
parent 587d4c9934
commit d840855d81
8 changed files with 488 additions and 1234 deletions

View file

@ -130,7 +130,9 @@ static func _apply_mcts_strategic_override(player: RefCounted) -> void:
else MCTS_ROLLOUT_COUNT_EARLY)
ctrl.set_rollout_budget(budget)
ctrl.set_rollout_depth(MCTS_ROLLOUT_DEPTH)
ctrl.set_gpu_enabled(OS.get_environment("AI_GPU_ROLLOUT") in ["1", "true", "TRUE", "True"])
# p0-20 Phase C — backend (CPU vs GPU) is now decided once at boot by
# AiBackend::probe(); the AI_GPU_ROLLOUT env var was deleted alongside
# the McSnapshot strategic-MCTS path. MC_AI_BACKEND is the only knob.
ctrl.set_priors_enabled(
OS.get_environment("AI_MCTS_PRIORS") in ["1", "true", "TRUE", "True"]
)

File diff suppressed because it is too large Load diff

View file

@ -6,7 +6,9 @@ use tracing::instrument;
use crate::error::ServiceError;
use crate::framing::{read_frame, write_frame};
use crate::protocol::{MctsJob, MctsResult, Request, Response, SearchActionJob, SearchActionResult};
use crate::protocol::{
MctsJob, MctsResult, Request, Response, SearchActionResult, SearchActionViaAbstractJob,
};
/// A single-connection client that round-trips one [`Request`] to the service.
///
@ -81,22 +83,24 @@ pub async fn submit_batch(
}
}
/// Submit a full MCTS tree search job and return the best action with stats.
/// Submit a full abstract-rollout MCTS tree search and return the best action
/// with stats. p0-20 Phase C — only strategic-search entry point; the legacy
/// McSnapshot-shaped `submit_search_action` is gone.
///
/// Maps `Response::Error` to [`ServiceError::Remote`].
///
/// # Errors
///
/// Returns an error if the transport fails or the service responds with `Error`.
pub async fn submit_search_action(
pub async fn submit_search_action_via_abstract(
socket_path: impl AsRef<Path> + std::fmt::Debug,
job: SearchActionJob,
job: SearchActionViaAbstractJob,
) -> Result<SearchActionResult, ServiceError> {
match round_trip(socket_path, Request::SearchAction(job)).await? {
Response::SearchActionResult(r) => Ok(r),
match round_trip(socket_path, Request::SearchActionViaAbstract(job)).await? {
Response::SearchActionViaAbstractResult(r) => Ok(r),
Response::Error { message } => Err(ServiceError::Remote(message)),
other => Err(ServiceError::Remote(format!(
"unexpected response to SearchAction request: {other:?}"
"unexpected response to SearchActionViaAbstract request: {other:?}"
))),
}
}

View file

@ -18,10 +18,15 @@
//! Version 1 (p1-27b): `Request::Mcts(MctsJob)` / `Response::MctsResult(MctsResult)` and
//! `Request::MctsBatch` / `Response::MctsBatchResult`. State encoded as
//! `MctsJobState` JSON inside each job.
//! Version 2 (p1-27c): `Request::SearchAction(SearchActionJob)` /
//! `Response::SearchActionResult(SearchActionResult)`. Carries a full
//! `McSnapshot` JSON; server runs `Tree::simulate_parallel` and returns
//! `{ action, win_rate, n_rollouts, took_ms, path }`.
//! Version 3 (p0-20 Phase C):
//! `Request::SearchActionViaAbstract(SearchActionViaAbstractJob)` /
//! `Response::SearchActionViaAbstractResult(SearchActionResult)`. Carries an
//! `AbstractJobState` (mirror of `mc_ai::abstract_state::AbstractRolloutState`)
//! plus per-player `PersonalityPriors`; server runs
//! `Tree<GameRolloutState>::iterate_gpu_batched` (CPU or GPU per probed
//! `AiBackend`) and returns
//! `{ action, win_rate, n_rollouts, took_ms, path }`. The McSnapshot-shaped
//! `Request::SearchAction` from version 2 is **deleted**.
//!
//! ## Relationship to @model-boss
//!

View file

@ -83,43 +83,23 @@ pub struct MctsResult {
pub took_ms: u32,
}
/// Input for a full MCTS tree search producing an action decision.
/// Result of a full MCTS tree search rooted at an [`AbstractJobState`].
///
/// Carries a JSON-serialised [`mc_turn::snapshot::McSnapshot`] so the server
/// can run `Tree::simulate_parallel` over the exact same state type
/// `GdMcTreeController` uses locally, giving byte-compatible action selection.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchActionJob {
/// JSON-encoded `McSnapshot` (serde form of `mc_turn::snapshot::McSnapshot`).
pub snapshot_json: String,
/// Index of the player MCTS is deciding for (matches `McSnapshot::active_player`).
pub root_player: u8,
/// Number of MCTS iterations (`Tree::simulate_parallel` budget).
pub n_rollouts: u32,
/// Rollout horizon — turns per random playout.
pub depth: u32,
/// Base RNG seed for the tree simulation.
pub seed: u64,
/// When `true`, use PUCT selection with per-node priors (p0-38). When
/// `false`, fall back to classical UCB1. Mirrors `GdMcTreeController::priors_enabled`.
pub use_priors: bool,
/// Per-decision wall-clock budget in milliseconds (`0` = unbounded). Mirrors
/// `GdMcTreeController::budget_ms`.
pub budget_ms: u64,
}
/// Result of a [`Request::SearchAction`] tree search.
/// p0-20 Phase C — replaces the legacy McSnapshot-shaped `SearchActionResult`.
/// The action string is one of [`mc_ai::policy::ActionKind`]'s canonical
/// debug names (`"Build"`, `"Settle"`, `"Idle"`, …).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchActionResult {
/// The chosen action: `"Idle"`, `"FoundCity"`, or `"SpawnUnit"`.
/// The chosen [`mc_ai::policy::ActionKind`] as a string.
pub action: String,
/// Win-rate estimate for the chosen action from the best child node.
pub win_rate: f32,
/// Total iterations completed inside `Tree::simulate_parallel`.
/// Total iterations completed inside `Tree::iterate_gpu_batched`.
pub n_rollouts: u32,
/// Wall-clock milliseconds the search took.
pub took_ms: u32,
/// Compute path used by the server (`"cpu"` for v1; `"gpu"` reserved).
/// Compute path used by the server: `"gpu"` when the boot-probed
/// [`mc_ai::backend::AiBackend`] reports `is_gpu()`, else `"cpu"`.
pub path: String,
}
@ -305,12 +285,9 @@ pub enum Request {
Mcts(MctsJob),
/// Run multiple flat-rollout jobs sequentially (p1-27b primitive).
MctsBatch { jobs: Vec<MctsJob> },
/// Run a full MCTS tree search and return the best action (p1-27c).
SearchAction(SearchActionJob),
/// p0-20 Phase A v3 (schema only) — abstract-rollout MCTS search.
/// Server routing via `iterate_gpu_batched` + `most_visited_action_at_root`
/// lands in Phase C; this variant exists in Phase A so callers (api-gdext,
/// integration tests) can encode requests against a stable wire shape.
/// p0-20 Phase C — abstract-rollout MCTS search. Routes through
/// `Tree<GameRolloutState>::iterate_gpu_batched` server-side, returning
/// the best [`mc_ai::policy::ActionKind`] as a string.
SearchActionViaAbstract(SearchActionViaAbstractJob),
}
@ -323,8 +300,8 @@ pub enum Response {
MctsResult(MctsResult),
/// Results for a [`Request::MctsBatch`], one entry per input job.
MctsBatchResult { results: Vec<MctsResult> },
/// Result for [`Request::SearchAction`].
SearchActionResult(SearchActionResult),
/// Result for [`Request::SearchActionViaAbstract`] (p0-20 Phase C).
SearchActionViaAbstractResult(SearchActionResult),
/// The service encountered an error processing the request.
Error { message: String },
}

View file

@ -1,16 +1,23 @@
/// Async Unix-socket server that processes [`Request`](crate::protocol::Request) frames.
use std::path::Path;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use tokio::net::{UnixListener, UnixStream};
use tracing::{error, info, instrument, warn};
use crate::error::ServiceError;
use crate::framing::{read_frame, write_frame};
use crate::protocol::{MctsJob, MctsJobState, MctsResult, Request, Response, SearchActionJob, SearchActionResult};
use crate::protocol::{
MctsJob, MctsJobState, MctsResult, Request, Response, SearchActionResult,
SearchActionViaAbstractJob,
};
/// Default socket path used when none is provided.
pub const DEFAULT_SOCKET_PATH: &str = "/tmp/mc-mcts.sock";
static BACKEND: OnceLock<mc_ai::backend::AiBackend> = OnceLock::new();
/// Run the MCTS service, listening on `socket_path`.
///
/// Removes a stale socket file before binding (handles unclean prior shutdown).
@ -26,13 +33,9 @@ pub async fn run(socket_path: impl AsRef<Path> + std::fmt::Debug) -> Result<(),
let _ = tokio::fs::remove_file(path).await;
let listener = UnixListener::bind(path).map_err(ServiceError::Bind)?;
// Phase 1 of p0-20: probe the AI backend at startup so the chosen path
// is observable in service logs. The strategic search call site below
// still uses CPU rollouts via `Tree::simulate_parallel` — Phase 2 wires
// the boot-probed backend into the search itself. This call exists so
// operators see "[mc-ai backend] Cpu (...)" in `mcts-server.log` and
// can confirm the deployed binary is on the expected backend.
let ai_backend = mc_ai::backend::AiBackend::probe();
// Probe the AI backend at startup. The strategic search runner below
// dispatches through this backend (Cpu or Gpu(...)).
let ai_backend = BACKEND.get_or_init(mc_ai::backend::AiBackend::probe);
info!(backend = %ai_backend.name(), "AiBackend probed");
info!("listening");
@ -93,18 +96,10 @@ fn dispatch(request: Request) -> Response {
}
Response::MctsBatchResult { results }
}
Request::SearchAction(job) => match run_search_action(&job) {
Ok(result) => Response::SearchActionResult(result),
Request::SearchActionViaAbstract(job) => match run_search_action_via_abstract(&job) {
Ok(result) => Response::SearchActionViaAbstractResult(result),
Err(msg) => Response::Error { message: msg },
},
// p0-20 Phase A v3 — schema-only stub. The runner lands in Phase C.
// Returning a structured error here lets callers exercise the wire
// shape (encode/decode round-trip) without claiming server support.
Request::SearchActionViaAbstract(_) => Response::Error {
message:
"SearchActionViaAbstract is schema-only in Phase A; runner ships in Phase C"
.to_owned(),
},
}
}
@ -165,78 +160,106 @@ fn run_job(job: &MctsJob) -> Result<MctsResult, String> {
})
}
/// Run a full MCTS tree search from a serialised [`McSnapshot`] and return
/// the best [`McAction`] as a string.
/// Run a full abstract-rollout MCTS search and return the best
/// [`mc_ai::policy::ActionKind`] as a [`SearchActionResult`].
///
/// Mirrors `GdMcTreeController::choose_action_with_stats` exactly — uses the
/// same `Tree::simulate_parallel` + rollout path so action quality is identical
/// to the local in-process path. GPU context is `None` for v1 (CPU only);
/// `path: "cpu"` is set in the response accordingly.
fn run_search_action(job: &SearchActionJob) -> Result<SearchActionResult, String> {
use mc_ai::mcts::XorShift64;
use mc_ai::mcts_tree::{rollout_snapshot, Tree};
use mc_turn::snapshot::{McAction, McSnapshot};
/// Pipeline:
/// 1. Rebuild `AbstractRolloutState` from the [`SearchActionViaAbstractJob`]'s
/// flat mirror.
/// 2. Construct `Tree<GameRolloutState>` with the supplied per-player priors.
/// 3. Loop `tree.iterate_gpu_batched(BATCH_SIZE, …)` until `rollout_budget`
/// is met OR `budget_ms` expires.
/// 4. Read `tree.most_visited_action_at_root()` — return as the canonical
/// debug-name string. Empty trees return `"Idle"`.
///
/// p0-20 Phase C — replaces the legacy `run_search_action` runner that drove
/// `Tree<McSnapshot>::simulate_parallel`. The McSnapshot strategic tree is
/// gone; the only path is abstract.
fn run_search_action_via_abstract(
job: &SearchActionViaAbstractJob,
) -> Result<SearchActionResult, String> {
use mc_ai::backend::AiBackend;
use mc_ai::mcts_tree::Tree;
use mc_ai::policy::ActionKind;
use mc_ai::rollout::GameRolloutState;
let mut snapshot: McSnapshot = serde_json::from_str(&job.snapshot_json)
.map_err(|e| format!("snapshot_json parse error: {e}"))?;
snapshot.active_player = job.root_player;
let pi = job.root_player as usize;
let depth = job.depth;
let base_seed = job.seed;
let n_rollouts = job.n_rollouts.max(1) as usize;
let budget = if job.budget_ms > 0 { Some(job.budget_ms) } else { None };
let pod = job.abstract_state.to_pod();
let priors = job.priors;
let mut tree = Tree::new(GameRolloutState::new(pod, priors));
tree.use_priors = true;
tree.root_player = job.root_player;
let mut tree = Tree::new(snapshot);
tree.use_priors = job.use_priors;
let backend = BACKEND.get_or_init(AiBackend::probe);
let rollout_fn = move |snap: &McSnapshot, rng: &mut XorShift64| -> f32 {
let step_fn = |s: &McSnapshot, _d: u32, rng: &mut XorShift64| {
let actions = s.legal_actions();
if actions.is_empty() {
return s.clone();
const BATCH_SIZE: usize = 1024;
let total_budget = job.rollout_budget as usize;
let wall_budget = job.budget_ms;
let start = Instant::now();
let mut completed: usize = 0;
while completed < total_budget {
if let Some(b) = wall_budget {
if start.elapsed() >= Duration::from_millis(b) {
break;
}
let idx = rng.next_u64() as usize % actions.len();
s.step(&actions[idx])
};
let score_fn = |s: &McSnapshot| -> f32 {
if let Some(winner) = s.winner() {
if winner == pi { 1.0 } else { 0.0 }
} else {
s.heuristic_value(pi.min(s.players.len().saturating_sub(1)))
}
};
rollout_snapshot(snap, rng, depth, &step_fn, &score_fn)
}
let remaining = total_budget - completed;
let this_batch = remaining.min(BATCH_SIZE);
let dispatched = tree.iterate_gpu_batched(
this_batch,
job.base_seed.wrapping_add(completed as u64),
wall_budget,
backend,
);
if dispatched == 0 {
break;
}
completed += dispatched;
}
let action = tree
.most_visited_action_at_root()
.unwrap_or(ActionKind::Idle);
// Win rate at the chosen child.
let mut chosen_visits = 0u32;
let mut chosen_wins = 0.0f32;
for &ci in &tree.root().children {
let n = &tree.nodes[ci];
if n.action == Some(action) {
chosen_visits = n.visits;
chosen_wins = n.wins;
break;
}
}
let win_rate = if chosen_visits > 0 {
chosen_wins / chosen_visits as f32
} else {
0.5
};
let start = std::time::Instant::now();
tree.simulate_parallel(n_rollouts, base_seed, rollout_fn, budget);
let took_ms = start.elapsed().as_millis().min(u32::MAX as u128) as u32;
// Robust child: highest visit count.
let root_children = tree.root().children.clone();
let best_child = root_children
.into_iter()
.max_by_key(|&ci| tree.nodes[ci].visits);
let (action, win_rate) = if let Some(ci) = best_child {
let n = &tree.nodes[ci];
let rate = if n.visits > 0 { n.wins / n.visits as f32 } else { 0.5 };
(n.action.clone().unwrap_or(McAction::Idle), rate)
} else {
(McAction::Idle, 0.5)
let action_name = match action {
ActionKind::Build => "Build",
ActionKind::Attack => "Attack",
ActionKind::Settle => "Settle",
ActionKind::Research => "Research",
ActionKind::Defend => "Defend",
ActionKind::Trade => "Trade",
ActionKind::ContinueWar => "ContinueWar",
ActionKind::MakePeace => "MakePeace",
ActionKind::Idle => "Idle",
ActionKind::CommandFormation => "CommandFormation",
ActionKind::SetRallyPoint => "SetRallyPoint",
};
let n_completed = tree.root().visits;
let path = if backend.is_gpu() { "gpu" } else { "cpu" }.to_string();
Ok(SearchActionResult {
action: match action {
McAction::Idle => "Idle".to_owned(),
McAction::FoundCity => "FoundCity".to_owned(),
McAction::SpawnUnit => "SpawnUnit".to_owned(),
},
action: action_name.to_string(),
win_rate,
n_rollouts: n_completed,
n_rollouts: tree.root().visits,
took_ms,
path: "cpu".to_owned(),
path,
})
}

View file

@ -30,7 +30,6 @@ pub mod game_state;
pub mod combat_event;
pub mod processor;
pub mod prologue;
pub mod snapshot;
pub mod spatial_index;
pub mod victory;
pub mod courier_resolver;
@ -65,6 +64,5 @@ pub use prologue::{
StartMode, Wanderer, WandererDirection, DEFAULT_LUCKY_INWARD_BIAS_PROB, LUCKY_MAX_BONUS_POP,
LUCKY_POP_PER_EXTRA_WANDERERS, MIN_WANDERERS_TO_FORM_TRIBE, TRIBE_CONVERGENCE_RADIUS,
};
pub use snapshot::{McAction, McSnapshot, PlayerSnap};
pub use spatial_index::LairIndexCsr;
pub use victory::{VictoryConfig, VictoryType};

View file

@ -1,378 +0,0 @@
//! Compact, cloneable game snapshot for MCTS rollouts.
//!
//! `McSnapshot` captures the fields `TurnProcessor::step` actually mutates
//! (economy, city count, unit count, turn counter) without carrying the full
//! `GridState` — keeping it cheap to clone across rayon threads.
//!
//! `McSnapshot::step` is intended to be byte-identical to `TurnProcessor::step`
//! for every field it tracks. Tests in this module check that invariant for
//! gold and city count under `McAction::Idle`.
use crate::game_state::{GameState, PlayerState};
use crate::processor::{LairCombatConfig, TurnProcessor};
use mc_ai::evaluator::ScoringWeights;
use mc_ai::mcts_tree::TreeState;
use serde::{Deserialize, Serialize};
// ── Action ──────────────────────────────────────────────────────────────────
/// An atomic AI decision that `McSnapshot::step` can apply.
///
/// Only economy-relevant choices are modelled here; spatial decisions
/// (unit movement, attack targeting) are out of scope until B-phase GPU work.
///
/// `McSnapshot::legal_actions` offers all three variants whenever the
/// snapshot is not terminal.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum McAction {
/// Do nothing — let the economy phase run with no strategic override.
Idle,
/// Attempt to found a new city this turn (burns expansion points).
///
/// In `McSnapshot::step` this only takes effect when the player is below
/// the city cap and holds at least `city_founding_cost` expansion points.
FoundCity,
/// Produce a unit in the first city that can afford it.
///
/// In `McSnapshot::step` this adds one unit when the turn's production
/// budget is at least `unit_spawn_cost`.
SpawnUnit,
}
// ── Per-player compact state ─────────────────────────────────────────────────
/// Compact per-player state carried inside an `McSnapshot`.
///
/// Built from a full `PlayerState` by `PlayerSnap::from_player`; holds the
/// counters, strategic axes and scoring weights used by the snapshot's step,
/// heuristic and action-prior code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlayerSnap {
/// Current gold, copied from `PlayerState::gold`.
pub gold: i32,
/// Number of cities (`PlayerState::cities.len()` at capture time).
pub city_count: u32,
/// Number of units (`PlayerState::units.len()` at capture time).
pub unit_count: u32,
/// Expansion points banked towards founding the next city.
pub expansion_points: u32,
/// Accumulated culture, copied from `PlayerState::culture_total`.
pub culture_total: i64,
/// Copied from `PlayerState::strategic_axes["wealth"]`.
pub wealth: u8,
/// Copied from `PlayerState::strategic_axes["expansion"]`.
pub expansion_axis: u8,
/// Copied from `PlayerState::strategic_axes["production"]`.
pub production_axis: u8,
/// Personality weights, cloned from `PlayerState::scoring_weights`.
pub scoring_weights: ScoringWeights,
}
impl PlayerSnap {
    /// Reduce a full [`PlayerState`] to the compact record used by rollouts.
    ///
    /// City and unit counts come from the lengths of `p.cities` and
    /// `p.units`. Each strategic axis (`"wealth"`, `"expansion"`,
    /// `"production"`) falls back to `2` when its key is missing from
    /// `p.strategic_axes`.
    pub fn from_player(p: &PlayerState) -> Self {
        // Read one strategic axis by name, defaulting to 2 when it is unset.
        let axis = |name: &str| -> u8 { p.strategic_axes.get(name).copied().unwrap_or(2) };

        let city_count = p.cities.len() as u32;
        let unit_count = p.units.len() as u32;
        let wealth = axis("wealth");
        let expansion_axis = axis("expansion");
        let production_axis = axis("production");

        Self {
            gold: p.gold,
            city_count,
            unit_count,
            expansion_points: p.expansion_points,
            culture_total: p.culture_total,
            wealth,
            expansion_axis,
            production_axis,
            scoring_weights: p.scoring_weights.clone(),
        }
    }
}
// ── Snapshot ─────────────────────────────────────────────────────────────────
/// Lightweight game snapshot. `Clone + Send` — safe to scatter across rayon threads.
///
/// Built by `McSnapshot::from_game_state` and advanced one turn at a time by
/// `McSnapshot::step`, which returns a new snapshot rather than mutating in
/// place.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McSnapshot {
/// Turn counter; `step` increments it by one.
pub turn: u32,
/// One entry per player, in the same order as `GameState::players`.
pub players: Vec<PlayerSnap>,
/// Tuning values copied from `TurnProcessor::lair_combat_config`.
pub config: LairCombatConfig,
/// City count at which a player is treated as the winner
/// (see `is_terminal` / `winner`).
pub victory_city_count: u8,
/// Index of the player whose turn is being decided (set by the MCTS caller).
/// Used by `TreeState::action_prior` to look up that player's `scoring_weights`.
/// Defaults to 0; `ai.rs::choose_action` sets it to `player_index`.
pub active_player: u8,
}
impl McSnapshot {
    /// Capture a snapshot of `state`, copying tuning values from `processor`.
    ///
    /// Every player is reduced to a [`PlayerSnap`]; `config` and
    /// `victory_city_count` are copied from the processor. `active_player`
    /// starts at `0` and is expected to be overwritten by the MCTS caller.
    pub fn from_game_state(state: &GameState, processor: &TurnProcessor) -> Self {
        Self {
            turn: state.turn,
            players: state.players.iter().map(PlayerSnap::from_player).collect(),
            config: processor.lair_combat_config.clone(),
            victory_city_count: processor.victory_city_count,
            active_player: 0,
        }
    }

    /// Advance one turn deterministically. Economy + production + founding +
    /// unit-spawn phases only (no spatial movement or combat — grid is absent).
    ///
    /// Intended to be byte-identical to `TurnProcessor::step` for the fields
    /// `McSnapshot` tracks. Returns a new snapshot; `self` is left untouched.
    ///
    /// NOTE(review): `action` is applied to every player in the loop, not only
    /// to `active_player` — confirm this matches the intended rollout model.
    pub fn step(&self, action: &McAction) -> McSnapshot {
        let cfg = &self.config;
        let wants_city = *action == McAction::FoundCity;
        let wants_unit = *action == McAction::SpawnUnit;

        let mut next = self.clone();
        next.turn += 1;

        for p in &mut next.players {
            // Phase 1: economy
            let gold_in = p.wealth as i32 * p.city_count as i32 * cfg.gold_per_wealth_per_city;
            p.gold = p.gold.saturating_add(gold_in);

            // Phase 1b: culture (uses expansion_axis as culture proxy, matching processor)
            let culture_per_turn = p.expansion_axis as i64 * p.city_count as i64 * 25;
            p.culture_total += culture_per_turn;

            // Phase 3: expansion points are earned before any founding cost is paid
            p.expansion_points = p
                .expansion_points
                .saturating_add(cfg.expansion_per_axis_per_turn * p.expansion_axis as u32);

            // Phase 3: city founding (only when action matches and player has enough points)
            let max_cities = cfg.max_cities_per_player_base as u32 + 3 * p.expansion_axis as u32;
            if p.city_count < max_cities
                && p.expansion_points >= cfg.city_founding_cost
                && wants_city
            {
                p.city_count += 1;
                p.expansion_points -= cfg.city_founding_cost;
            }

            // Phase 4: unit production (uses the city count after any founding above)
            let prod_budget = cfg.prod_per_axis_per_city * p.production_axis as u32 * p.city_count;
            if prod_budget >= cfg.unit_spawn_cost && wants_unit {
                p.unit_count += 1;
            }
        }

        next
    }

    /// Heuristic leaf-value for player `pi` in [0, 1].
    ///
    /// Combines city count, non-negative gold, culture and unit count using
    /// the player's `ScoringWeights`, then squashes the sum with
    /// `raw / (raw + 50)`. A negative or NaN weighted sum is clamped to `0`
    /// first, so the squash can never leave the documented range or divide
    /// by zero at `raw == -50`.
    ///
    /// # Panics
    ///
    /// Panics if `pi` is not a valid index into `self.players`.
    pub fn heuristic_value(&self, pi: usize) -> f32 {
        let p = &self.players[pi];
        let w = &p.scoring_weights;
        let raw = w.city_expansion * p.city_count as f32
            + w.yield_gold * p.gold.max(0) as f32 * 0.01
            + w.yield_culture * p.culture_total as f32 * 0.0001
            + w.pop_value * p.unit_count as f32;
        // `f32::max` returns the non-NaN operand, so this also maps NaN to 0.
        let raw = raw.max(0.0);
        // Soft-normalize: sigmoid-like squash so value stays in [0, 1).
        raw / (raw + 50.0)
    }

    /// True if any player has met the city-count victory condition.
    pub fn is_terminal(&self) -> bool {
        self.winner().is_some()
    }

    /// Winner index if terminal, else None.
    ///
    /// When several players meet the threshold, the lowest index is returned.
    pub fn winner(&self) -> Option<usize> {
        let threshold = self.victory_city_count as u32;
        self.players
            .iter()
            .position(|p| p.city_count >= threshold)
    }

    /// All legal actions from this snapshot. Empty only if terminal.
    pub fn legal_actions(&self) -> Vec<McAction> {
        if self.is_terminal() {
            return Vec::new();
        }
        vec![McAction::Idle, McAction::FoundCity, McAction::SpawnUnit]
    }
}
// ── TreeState impl ───────────────────────────────────────────────────────────
impl TreeState for McSnapshot {
    type Action = McAction;

    /// Delegates to the inherent `McSnapshot::legal_actions`.
    fn legal_actions(&self) -> Vec<Self::Action> {
        // Inherent methods take priority over trait methods of the same
        // name, so this is not a recursive call.
        self.legal_actions()
    }

    /// Applies `action` by advancing the snapshot one turn via `step`.
    fn apply(&self, action: &Self::Action) -> Self {
        self.step(action)
    }

    /// Delegates to the inherent `McSnapshot::is_terminal`.
    fn is_terminal(&self) -> bool {
        self.is_terminal()
    }

    /// Personality-weighted action prior for PUCT selection (p0-38).
    ///
    /// Reads the `active_player`'s `ScoringWeights`:
    /// - `SpawnUnit` → `military_base`, floored at 0.1
    /// - `FoundCity` → `expansion_base`, floored at 0.1
    /// - `Idle`      → 1.0 baseline
    ///
    /// Returns 1.0 for every action when `active_player` does not index a
    /// player. Weights are unnormalised — `Tree::best_puct_child` normalises
    /// via the PUCT formula rather than requiring pre-normalised priors.
    fn action_prior(&self, action: &Self::Action) -> f32 {
        match self.players.get(usize::from(self.active_player)) {
            None => 1.0,
            Some(player) => {
                let weights = &player.scoring_weights;
                match action {
                    McAction::Idle => 1.0,
                    McAction::FoundCity => weights.expansion_base.max(0.1),
                    McAction::SpawnUnit => weights.military_base.max(0.1),
                }
            }
        }
    }
}
// ── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use crate::game_state::{CityEcology, MapUnit, PlayerState};
    use mc_ai::evaluator::ScoringWeights;
    use mc_city::CityState;
    use std::collections::BTreeMap;

    /// Build a `PlayerState` with the given gold, `cities` default cities and
    /// `units` identical warrior units. Strategic axes are fixed at
    /// wealth = 3, expansion = 2, production = 2, culture = 2.
    fn make_player(gold: i32, cities: usize, units: usize) -> PlayerState {
        let mut axes = BTreeMap::new();
        axes.insert("wealth".into(), 3u8);
        axes.insert("expansion".into(), 2u8);
        axes.insert("production".into(), 2u8);
        axes.insert("culture".into(), 2u8);
        PlayerState {
            player_index: 0,
            gold,
            cities: (0..cities).map(|_| CityState::default()).collect(),
            unit_upkeep: vec![0; units],
            strategic_axes: axes,
            scoring_weights: ScoringWeights::default(),
            expansion_points: 0,
            city_buildings: vec![vec![]; cities],
            city_improvements: Default::default(),
            city_ecology: vec![CityEcology::default(); cities],
            tech_state: None,
            science_pool: 0,
            player_tech: None,
            science_yield: 0,
            units: (0..units)
                .map(|_| MapUnit {
                    col: 0,
                    row: 0,
                    hp: 10,
                    max_hp: 10,
                    attack: 5,
                    defense: 5,
                    is_fortified: false,
                    unit_id: "dwarf_warrior".into(),
                    held_resources: Vec::new(),
                    patrol_order: None,
                    ..Default::default()
                })
                .collect(),
            city_positions: vec![(0, 0); cities],
            capital_position: Some((0, 0)),
            culture_total: 0,
            culture_pool: mc_culture::CulturePool::default(),
            arcane_lore_pop_deducted: false,
            traded_luxuries: Default::default(),
            relations: Default::default(),
            strategic_ledger: Default::default(),
            wonders_built: Default::default(),
            explored_deposits: Default::default(),
            ..Default::default()
        }
    }

    /// Two-player, grid-less game state where each player owns `cities` cities.
    fn make_state(cities: usize) -> GameState {
        GameState {
            turn: 0,
            players: vec![make_player(100, cities, 2), make_player(80, cities, 1)],
            grid: None,
            pending_pvp_attacks: Default::default(),
            ..Default::default()
        }
    }

    /// Economy formula: gold_in = wealth * city_count * gold_per_city.
    /// Verify McSnapshot::step matches TurnProcessor::step for 10 varying states.
    /// Only gold and city count are compared, and only for `McAction::Idle`.
    #[test]
    fn step_economy_matches_turn_processor_for_10_seeds() {
        let processor = TurnProcessor::new(300);
        for seed in 0u32..10 {
            let mut state = make_state((seed % 3 + 1) as usize);
            state.players[0].gold = seed as i32 * 50;
            state.players[1].gold = seed as i32 * 30;
            let snap = McSnapshot::from_game_state(&state, &processor);
            let next_snap = snap.step(&McAction::Idle);
            // Run full processor step (no grid, so no movement/combat)
            processor.step(&mut state);
            for (pi, p) in state.players.iter().enumerate() {
                assert_eq!(
                    next_snap.players[pi].gold,
                    p.gold,
                    "seed={seed} pi={pi}: snapshot gold {} != processor gold {}",
                    next_snap.players[pi].gold,
                    p.gold
                );
                assert_eq!(
                    next_snap.players[pi].city_count,
                    p.cities.len() as u32,
                    "seed={seed} pi={pi}: city count mismatch"
                );
            }
        }
    }

    /// `Idle` must leave the city count alone even with points to spare.
    #[test]
    fn step_idle_does_not_found_city() {
        let processor = TurnProcessor::new(300);
        let state = make_state(2);
        let mut snap = McSnapshot::from_game_state(&state, &processor);
        snap.players[0].expansion_points = 100;
        let next = snap.step(&McAction::Idle);
        assert_eq!(next.players[0].city_count, 2, "Idle must not found a city");
    }

    /// `FoundCity` adds a city and deducts the founding cost from the points
    /// held after this turn's expansion income.
    #[test]
    fn step_found_city_spends_expansion_points() {
        let processor = TurnProcessor::new(300);
        let state = make_state(1);
        let mut snap = McSnapshot::from_game_state(&state, &processor);
        snap.players[0].expansion_points = 50;
        let next = snap.step(&McAction::FoundCity);
        assert_eq!(next.players[0].city_count, 2, "FoundCity must add a city");
        let cost = snap.config.city_founding_cost;
        let earned =
            snap.config.expansion_per_axis_per_turn * snap.players[0].expansion_axis as u32;
        // `step` earns points before paying the cost; adding first mirrors
        // that order and avoids unsigned underflow when `cost` exceeds 50.
        assert_eq!(
            next.players[0].expansion_points,
            50 + earned - cost,
            "expansion_points after founding"
        );
    }

    /// The heuristic must stay inside [0, 1) for a freshly built state.
    #[test]
    fn heuristic_value_returns_value_in_unit_interval() {
        let processor = TurnProcessor::new(300);
        let state = make_state(3);
        let snap = McSnapshot::from_game_state(&state, &processor);
        for pi in 0..snap.players.len() {
            let v = snap.heuristic_value(pi);
            assert!(v >= 0.0 && v < 1.0, "heuristic_value out of range: {v}");
        }
    }

    /// Reaching `victory_city_count` cities makes the snapshot terminal and
    /// names that player as the winner.
    #[test]
    fn is_terminal_triggers_at_victory_threshold() {
        let processor = TurnProcessor::new(300);
        let state = make_state(1);
        let mut snap = McSnapshot::from_game_state(&state, &processor);
        snap.players[0].city_count = snap.victory_city_count as u32;
        assert!(snap.is_terminal());
        assert_eq!(snap.winner(), Some(0));
    }

    /// A terminal snapshot offers no further actions.
    #[test]
    fn legal_actions_empty_when_terminal() {
        let processor = TurnProcessor::new(300);
        let state = make_state(1);
        let mut snap = McSnapshot::from_game_state(&state, &processor);
        snap.players[0].city_count = snap.victory_city_count as u32;
        assert!(snap.legal_actions().is_empty());
    }
}