fix(api-gdext): 🐛 remove ai_gpu_rollout dependency
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
587d4c9934
commit
d840855d81
8 changed files with 488 additions and 1234 deletions
|
|
@ -130,7 +130,9 @@ static func _apply_mcts_strategic_override(player: RefCounted) -> void:
|
|||
else MCTS_ROLLOUT_COUNT_EARLY)
|
||||
ctrl.set_rollout_budget(budget)
|
||||
ctrl.set_rollout_depth(MCTS_ROLLOUT_DEPTH)
|
||||
ctrl.set_gpu_enabled(OS.get_environment("AI_GPU_ROLLOUT") in ["1", "true", "TRUE", "True"])
|
||||
# p0-20 Phase C — backend (CPU vs GPU) is now decided once at boot by
|
||||
# AiBackend::probe(); the AI_GPU_ROLLOUT env var was deleted alongside
|
||||
# the McSnapshot strategic-MCTS path. MC_AI_BACKEND is the only knob.
|
||||
ctrl.set_priors_enabled(
|
||||
OS.get_environment("AI_MCTS_PRIORS") in ["1", "true", "TRUE", "True"]
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -6,7 +6,9 @@ use tracing::instrument;
|
|||
|
||||
use crate::error::ServiceError;
|
||||
use crate::framing::{read_frame, write_frame};
|
||||
use crate::protocol::{MctsJob, MctsResult, Request, Response, SearchActionJob, SearchActionResult};
|
||||
use crate::protocol::{
|
||||
MctsJob, MctsResult, Request, Response, SearchActionResult, SearchActionViaAbstractJob,
|
||||
};
|
||||
|
||||
/// A single-connection client that round-trips one [`Request`] to the service.
|
||||
///
|
||||
|
|
@ -81,22 +83,24 @@ pub async fn submit_batch(
|
|||
}
|
||||
}
|
||||
|
||||
/// Submit a full MCTS tree search job and return the best action with stats.
|
||||
/// Submit a full abstract-rollout MCTS tree search and return the best action
|
||||
/// with stats. p0-20 Phase C — only strategic-search entry point; the legacy
|
||||
/// McSnapshot-shaped `submit_search_action` is gone.
|
||||
///
|
||||
/// Maps `Response::Error` to [`ServiceError::Remote`].
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the transport fails or the service responds with `Error`.
|
||||
pub async fn submit_search_action(
|
||||
pub async fn submit_search_action_via_abstract(
|
||||
socket_path: impl AsRef<Path> + std::fmt::Debug,
|
||||
job: SearchActionJob,
|
||||
job: SearchActionViaAbstractJob,
|
||||
) -> Result<SearchActionResult, ServiceError> {
|
||||
match round_trip(socket_path, Request::SearchAction(job)).await? {
|
||||
Response::SearchActionResult(r) => Ok(r),
|
||||
match round_trip(socket_path, Request::SearchActionViaAbstract(job)).await? {
|
||||
Response::SearchActionViaAbstractResult(r) => Ok(r),
|
||||
Response::Error { message } => Err(ServiceError::Remote(message)),
|
||||
other => Err(ServiceError::Remote(format!(
|
||||
"unexpected response to SearchAction request: {other:?}"
|
||||
"unexpected response to SearchActionViaAbstract request: {other:?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,10 +18,15 @@
|
|||
//! Version 1 (p1-27b): `Request::Mcts(MctsJob)` / `Response::MctsResult(MctsResult)` and
|
||||
//! `Request::MctsBatch` / `Response::MctsBatchResult`. State encoded as
|
||||
//! `MctsJobState` JSON inside each job.
|
||||
//! Version 2 (p1-27c): `Request::SearchAction(SearchActionJob)` /
|
||||
//! `Response::SearchActionResult(SearchActionResult)`. Carries a full
|
||||
//! `McSnapshot` JSON; server runs `Tree::simulate_parallel` and returns
|
||||
//! `{ action, win_rate, n_rollouts, took_ms, path }`.
|
||||
//! Version 3 (p0-20 Phase C):
|
||||
//! `Request::SearchActionViaAbstract(SearchActionViaAbstractJob)` /
|
||||
//! `Response::SearchActionViaAbstractResult(SearchActionResult)`. Carries an
|
||||
//! `AbstractJobState` (mirror of `mc_ai::abstract_state::AbstractRolloutState`)
|
||||
//! plus per-player `PersonalityPriors`; server runs
|
||||
//! `Tree<GameRolloutState>::iterate_gpu_batched` (CPU or GPU per probed
|
||||
//! `AiBackend`) and returns
|
||||
//! `{ action, win_rate, n_rollouts, took_ms, path }`. The McSnapshot-shaped
|
||||
//! `Request::SearchAction` from version 2 is **deleted**.
|
||||
//!
|
||||
//! ## Relationship to @model-boss
|
||||
//!
|
||||
|
|
|
|||
|
|
@ -83,43 +83,23 @@ pub struct MctsResult {
|
|||
pub took_ms: u32,
|
||||
}
|
||||
|
||||
/// Input for a full MCTS tree search producing an action decision.
|
||||
/// Result of a full MCTS tree search rooted at an [`AbstractJobState`].
|
||||
///
|
||||
/// Carries a JSON-serialised [`mc_turn::snapshot::McSnapshot`] so the server
|
||||
/// can run `Tree::simulate_parallel` over the exact same state type
|
||||
/// `GdMcTreeController` uses locally, giving byte-compatible action selection.
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SearchActionJob {
|
||||
/// JSON-encoded `McSnapshot` (serde form of `mc_turn::snapshot::McSnapshot`).
|
||||
pub snapshot_json: String,
|
||||
/// Index of the player MCTS is deciding for (matches `McSnapshot::active_player`).
|
||||
pub root_player: u8,
|
||||
/// Number of MCTS iterations (`Tree::simulate_parallel` budget).
|
||||
pub n_rollouts: u32,
|
||||
/// Rollout horizon — turns per random playout.
|
||||
pub depth: u32,
|
||||
/// Base RNG seed for the tree simulation.
|
||||
pub seed: u64,
|
||||
/// When `true`, use PUCT selection with per-node priors (p0-38). When
|
||||
/// `false`, fall back to classical UCB1. Mirrors `GdMcTreeController::priors_enabled`.
|
||||
pub use_priors: bool,
|
||||
/// Per-decision wall-clock budget in milliseconds (`0` = unbounded). Mirrors
|
||||
/// `GdMcTreeController::budget_ms`.
|
||||
pub budget_ms: u64,
|
||||
}
|
||||
|
||||
/// Result of a [`Request::SearchAction`] tree search.
|
||||
/// p0-20 Phase C — replaces the legacy McSnapshot-shaped `SearchActionResult`.
|
||||
/// The action string is one of [`mc_ai::policy::ActionKind`]'s canonical
|
||||
/// debug names (`"Build"`, `"Settle"`, `"Idle"`, …).
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SearchActionResult {
|
||||
/// The chosen action: `"Idle"`, `"FoundCity"`, or `"SpawnUnit"`.
|
||||
/// The chosen [`mc_ai::policy::ActionKind`] as a string.
|
||||
pub action: String,
|
||||
/// Win-rate estimate for the chosen action from the best child node.
|
||||
pub win_rate: f32,
|
||||
/// Total iterations completed inside `Tree::simulate_parallel`.
|
||||
/// Total iterations completed inside `Tree::iterate_gpu_batched`.
|
||||
pub n_rollouts: u32,
|
||||
/// Wall-clock milliseconds the search took.
|
||||
pub took_ms: u32,
|
||||
/// Compute path used by the server (`"cpu"` for v1; `"gpu"` reserved).
|
||||
/// Compute path used by the server: `"gpu"` when the boot-probed
|
||||
/// [`mc_ai::backend::AiBackend`] reports `is_gpu()`, else `"cpu"`.
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
|
|
@ -305,12 +285,9 @@ pub enum Request {
|
|||
Mcts(MctsJob),
|
||||
/// Run multiple flat-rollout jobs sequentially (p1-27b primitive).
|
||||
MctsBatch { jobs: Vec<MctsJob> },
|
||||
/// Run a full MCTS tree search and return the best action (p1-27c).
|
||||
SearchAction(SearchActionJob),
|
||||
/// p0-20 Phase A v3 (schema only) — abstract-rollout MCTS search.
|
||||
/// Server routing via `iterate_gpu_batched` + `most_visited_action_at_root`
|
||||
/// lands in Phase C; this variant exists in Phase A so callers (api-gdext,
|
||||
/// integration tests) can encode requests against a stable wire shape.
|
||||
/// p0-20 Phase C — abstract-rollout MCTS search. Routes through
|
||||
/// `Tree<GameRolloutState>::iterate_gpu_batched` server-side, returning
|
||||
/// the best [`mc_ai::policy::ActionKind`] as a string.
|
||||
SearchActionViaAbstract(SearchActionViaAbstractJob),
|
||||
}
|
||||
|
||||
|
|
@ -323,8 +300,8 @@ pub enum Response {
|
|||
MctsResult(MctsResult),
|
||||
/// Results for a [`Request::MctsBatch`], one entry per input job.
|
||||
MctsBatchResult { results: Vec<MctsResult> },
|
||||
/// Result for [`Request::SearchAction`].
|
||||
SearchActionResult(SearchActionResult),
|
||||
/// Result for [`Request::SearchActionViaAbstract`] (p0-20 Phase C).
|
||||
SearchActionViaAbstractResult(SearchActionResult),
|
||||
/// The service encountered an error processing the request.
|
||||
Error { message: String },
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,16 +1,23 @@
|
|||
/// Async Unix-socket server that processes [`Request`](crate::protocol::Request) frames.
|
||||
use std::path::Path;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tracing::{error, info, instrument, warn};
|
||||
|
||||
use crate::error::ServiceError;
|
||||
use crate::framing::{read_frame, write_frame};
|
||||
use crate::protocol::{MctsJob, MctsJobState, MctsResult, Request, Response, SearchActionJob, SearchActionResult};
|
||||
use crate::protocol::{
|
||||
MctsJob, MctsJobState, MctsResult, Request, Response, SearchActionResult,
|
||||
SearchActionViaAbstractJob,
|
||||
};
|
||||
|
||||
/// Default socket path used when none is provided.
|
||||
pub const DEFAULT_SOCKET_PATH: &str = "/tmp/mc-mcts.sock";
|
||||
|
||||
static BACKEND: OnceLock<mc_ai::backend::AiBackend> = OnceLock::new();
|
||||
|
||||
/// Run the MCTS service, listening on `socket_path`.
|
||||
///
|
||||
/// Removes a stale socket file before binding (handles unclean prior shutdown).
|
||||
|
|
@ -26,13 +33,9 @@ pub async fn run(socket_path: impl AsRef<Path> + std::fmt::Debug) -> Result<(),
|
|||
let _ = tokio::fs::remove_file(path).await;
|
||||
let listener = UnixListener::bind(path).map_err(ServiceError::Bind)?;
|
||||
|
||||
// Phase 1 of p0-20: probe the AI backend at startup so the chosen path
|
||||
// is observable in service logs. The strategic search call site below
|
||||
// still uses CPU rollouts via `Tree::simulate_parallel` — Phase 2 wires
|
||||
// the boot-probed backend into the search itself. This call exists so
|
||||
// operators see "[mc-ai backend] Cpu (...)" in `mcts-server.log` and
|
||||
// can confirm the deployed binary is on the expected backend.
|
||||
let ai_backend = mc_ai::backend::AiBackend::probe();
|
||||
// Probe the AI backend at startup. The strategic search runner below
|
||||
// dispatches through this backend (Cpu or Gpu(...)).
|
||||
let ai_backend = BACKEND.get_or_init(mc_ai::backend::AiBackend::probe);
|
||||
info!(backend = %ai_backend.name(), "AiBackend probed");
|
||||
|
||||
info!("listening");
|
||||
|
|
@ -93,18 +96,10 @@ fn dispatch(request: Request) -> Response {
|
|||
}
|
||||
Response::MctsBatchResult { results }
|
||||
}
|
||||
Request::SearchAction(job) => match run_search_action(&job) {
|
||||
Ok(result) => Response::SearchActionResult(result),
|
||||
Request::SearchActionViaAbstract(job) => match run_search_action_via_abstract(&job) {
|
||||
Ok(result) => Response::SearchActionViaAbstractResult(result),
|
||||
Err(msg) => Response::Error { message: msg },
|
||||
},
|
||||
// p0-20 Phase A v3 — schema-only stub. The runner lands in Phase C.
|
||||
// Returning a structured error here lets callers exercise the wire
|
||||
// shape (encode/decode round-trip) without claiming server support.
|
||||
Request::SearchActionViaAbstract(_) => Response::Error {
|
||||
message:
|
||||
"SearchActionViaAbstract is schema-only in Phase A; runner ships in Phase C"
|
||||
.to_owned(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -165,78 +160,106 @@ fn run_job(job: &MctsJob) -> Result<MctsResult, String> {
|
|||
})
|
||||
}
|
||||
|
||||
/// Run a full MCTS tree search from a serialised [`McSnapshot`] and return
|
||||
/// the best [`McAction`] as a string.
|
||||
/// Run a full abstract-rollout MCTS search and return the best
|
||||
/// [`mc_ai::policy::ActionKind`] as a [`SearchActionResult`].
|
||||
///
|
||||
/// Mirrors `GdMcTreeController::choose_action_with_stats` exactly — uses the
|
||||
/// same `Tree::simulate_parallel` + rollout path so action quality is identical
|
||||
/// to the local in-process path. GPU context is `None` for v1 (CPU only);
|
||||
/// `path: "cpu"` is set in the response accordingly.
|
||||
fn run_search_action(job: &SearchActionJob) -> Result<SearchActionResult, String> {
|
||||
use mc_ai::mcts::XorShift64;
|
||||
use mc_ai::mcts_tree::{rollout_snapshot, Tree};
|
||||
use mc_turn::snapshot::{McAction, McSnapshot};
|
||||
/// Pipeline:
|
||||
/// 1. Rebuild `AbstractRolloutState` from the [`SearchActionViaAbstractJob`]'s
|
||||
/// flat mirror.
|
||||
/// 2. Construct `Tree<GameRolloutState>` with the supplied per-player priors.
|
||||
/// 3. Loop `tree.iterate_gpu_batched(BATCH_SIZE, …)` until `rollout_budget`
|
||||
/// is met OR `budget_ms` expires.
|
||||
/// 4. Read `tree.most_visited_action_at_root()` — return as the canonical
|
||||
/// debug-name string. Empty trees return `"Idle"`.
|
||||
///
|
||||
/// p0-20 Phase C — replaces the legacy `run_search_action` runner that drove
|
||||
/// `Tree<McSnapshot>::simulate_parallel`. The McSnapshot strategic tree is
|
||||
/// gone; the only path is abstract.
|
||||
fn run_search_action_via_abstract(
|
||||
job: &SearchActionViaAbstractJob,
|
||||
) -> Result<SearchActionResult, String> {
|
||||
use mc_ai::backend::AiBackend;
|
||||
use mc_ai::mcts_tree::Tree;
|
||||
use mc_ai::policy::ActionKind;
|
||||
use mc_ai::rollout::GameRolloutState;
|
||||
|
||||
let mut snapshot: McSnapshot = serde_json::from_str(&job.snapshot_json)
|
||||
.map_err(|e| format!("snapshot_json parse error: {e}"))?;
|
||||
snapshot.active_player = job.root_player;
|
||||
let pi = job.root_player as usize;
|
||||
let depth = job.depth;
|
||||
let base_seed = job.seed;
|
||||
let n_rollouts = job.n_rollouts.max(1) as usize;
|
||||
let budget = if job.budget_ms > 0 { Some(job.budget_ms) } else { None };
|
||||
let pod = job.abstract_state.to_pod();
|
||||
let priors = job.priors;
|
||||
let mut tree = Tree::new(GameRolloutState::new(pod, priors));
|
||||
tree.use_priors = true;
|
||||
tree.root_player = job.root_player;
|
||||
|
||||
let mut tree = Tree::new(snapshot);
|
||||
tree.use_priors = job.use_priors;
|
||||
let backend = BACKEND.get_or_init(AiBackend::probe);
|
||||
|
||||
let rollout_fn = move |snap: &McSnapshot, rng: &mut XorShift64| -> f32 {
|
||||
let step_fn = |s: &McSnapshot, _d: u32, rng: &mut XorShift64| {
|
||||
let actions = s.legal_actions();
|
||||
if actions.is_empty() {
|
||||
return s.clone();
|
||||
const BATCH_SIZE: usize = 1024;
|
||||
let total_budget = job.rollout_budget as usize;
|
||||
let wall_budget = job.budget_ms;
|
||||
|
||||
let start = Instant::now();
|
||||
let mut completed: usize = 0;
|
||||
while completed < total_budget {
|
||||
if let Some(b) = wall_budget {
|
||||
if start.elapsed() >= Duration::from_millis(b) {
|
||||
break;
|
||||
}
|
||||
let idx = rng.next_u64() as usize % actions.len();
|
||||
s.step(&actions[idx])
|
||||
};
|
||||
let score_fn = |s: &McSnapshot| -> f32 {
|
||||
if let Some(winner) = s.winner() {
|
||||
if winner == pi { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
s.heuristic_value(pi.min(s.players.len().saturating_sub(1)))
|
||||
}
|
||||
};
|
||||
rollout_snapshot(snap, rng, depth, &step_fn, &score_fn)
|
||||
}
|
||||
let remaining = total_budget - completed;
|
||||
let this_batch = remaining.min(BATCH_SIZE);
|
||||
let dispatched = tree.iterate_gpu_batched(
|
||||
this_batch,
|
||||
job.base_seed.wrapping_add(completed as u64),
|
||||
wall_budget,
|
||||
backend,
|
||||
);
|
||||
if dispatched == 0 {
|
||||
break;
|
||||
}
|
||||
completed += dispatched;
|
||||
}
|
||||
|
||||
let action = tree
|
||||
.most_visited_action_at_root()
|
||||
.unwrap_or(ActionKind::Idle);
|
||||
|
||||
// Win rate at the chosen child.
|
||||
let mut chosen_visits = 0u32;
|
||||
let mut chosen_wins = 0.0f32;
|
||||
for &ci in &tree.root().children {
|
||||
let n = &tree.nodes[ci];
|
||||
if n.action == Some(action) {
|
||||
chosen_visits = n.visits;
|
||||
chosen_wins = n.wins;
|
||||
break;
|
||||
}
|
||||
}
|
||||
let win_rate = if chosen_visits > 0 {
|
||||
chosen_wins / chosen_visits as f32
|
||||
} else {
|
||||
0.5
|
||||
};
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
tree.simulate_parallel(n_rollouts, base_seed, rollout_fn, budget);
|
||||
let took_ms = start.elapsed().as_millis().min(u32::MAX as u128) as u32;
|
||||
|
||||
// Robust child: highest visit count.
|
||||
let root_children = tree.root().children.clone();
|
||||
let best_child = root_children
|
||||
.into_iter()
|
||||
.max_by_key(|&ci| tree.nodes[ci].visits);
|
||||
|
||||
let (action, win_rate) = if let Some(ci) = best_child {
|
||||
let n = &tree.nodes[ci];
|
||||
let rate = if n.visits > 0 { n.wins / n.visits as f32 } else { 0.5 };
|
||||
(n.action.clone().unwrap_or(McAction::Idle), rate)
|
||||
} else {
|
||||
(McAction::Idle, 0.5)
|
||||
let action_name = match action {
|
||||
ActionKind::Build => "Build",
|
||||
ActionKind::Attack => "Attack",
|
||||
ActionKind::Settle => "Settle",
|
||||
ActionKind::Research => "Research",
|
||||
ActionKind::Defend => "Defend",
|
||||
ActionKind::Trade => "Trade",
|
||||
ActionKind::ContinueWar => "ContinueWar",
|
||||
ActionKind::MakePeace => "MakePeace",
|
||||
ActionKind::Idle => "Idle",
|
||||
ActionKind::CommandFormation => "CommandFormation",
|
||||
ActionKind::SetRallyPoint => "SetRallyPoint",
|
||||
};
|
||||
|
||||
let n_completed = tree.root().visits;
|
||||
let path = if backend.is_gpu() { "gpu" } else { "cpu" }.to_string();
|
||||
|
||||
Ok(SearchActionResult {
|
||||
action: match action {
|
||||
McAction::Idle => "Idle".to_owned(),
|
||||
McAction::FoundCity => "FoundCity".to_owned(),
|
||||
McAction::SpawnUnit => "SpawnUnit".to_owned(),
|
||||
},
|
||||
action: action_name.to_string(),
|
||||
win_rate,
|
||||
n_rollouts: n_completed,
|
||||
n_rollouts: tree.root().visits,
|
||||
took_ms,
|
||||
path: "cpu".to_owned(),
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ pub mod game_state;
|
|||
pub mod combat_event;
|
||||
pub mod processor;
|
||||
pub mod prologue;
|
||||
pub mod snapshot;
|
||||
pub mod spatial_index;
|
||||
pub mod victory;
|
||||
pub mod courier_resolver;
|
||||
|
|
@ -65,6 +64,5 @@ pub use prologue::{
|
|||
StartMode, Wanderer, WandererDirection, DEFAULT_LUCKY_INWARD_BIAS_PROB, LUCKY_MAX_BONUS_POP,
|
||||
LUCKY_POP_PER_EXTRA_WANDERERS, MIN_WANDERERS_TO_FORM_TRIBE, TRIBE_CONVERGENCE_RADIUS,
|
||||
};
|
||||
pub use snapshot::{McAction, McSnapshot, PlayerSnap};
|
||||
pub use spatial_index::LairIndexCsr;
|
||||
pub use victory::{VictoryConfig, VictoryType};
|
||||
|
|
|
|||
|
|
@ -1,378 +0,0 @@
|
|||
//! Compact, cloneable game snapshot for MCTS rollouts.
|
||||
//!
|
||||
//! `McSnapshot` captures the fields `TurnProcessor::step` actually mutates
|
||||
//! (economy, city count, unit count, turn counter) without carrying the full
|
||||
//! `GridState` — keeping it cheap to clone across rayon threads.
|
||||
//!
|
||||
//! `McSnapshot::step` is byte-identical to `TurnProcessor::step` for every
|
||||
//! field it tracks. Tests in this module assert that invariant.
|
||||
|
||||
use crate::game_state::{GameState, PlayerState};
|
||||
use crate::processor::{LairCombatConfig, TurnProcessor};
|
||||
use mc_ai::evaluator::ScoringWeights;
|
||||
use mc_ai::mcts_tree::TreeState;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Action ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// An atomic AI decision that `McSnapshot::step` can apply.
|
||||
///
|
||||
/// Only economy-relevant choices are modelled here; spatial decisions
|
||||
/// (unit movement, attack targeting) are out of scope until B-phase GPU work.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum McAction {
|
||||
/// Do nothing — let the economy phase run with no strategic override.
|
||||
Idle,
|
||||
/// Attempt to found a new city this turn (burns expansion points).
|
||||
FoundCity,
|
||||
/// Produce a unit in the first city that can afford it.
|
||||
SpawnUnit,
|
||||
}
|
||||
|
||||
// ── Per-player compact state ─────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlayerSnap {
|
||||
pub gold: i32,
|
||||
pub city_count: u32,
|
||||
pub unit_count: u32,
|
||||
pub expansion_points: u32,
|
||||
pub culture_total: i64,
|
||||
/// Copied from `PlayerState::strategic_axes["wealth"]`.
|
||||
pub wealth: u8,
|
||||
/// Copied from `PlayerState::strategic_axes["expansion"]`.
|
||||
pub expansion_axis: u8,
|
||||
/// Copied from `PlayerState::strategic_axes["production"]`.
|
||||
pub production_axis: u8,
|
||||
pub scoring_weights: ScoringWeights,
|
||||
}
|
||||
|
||||
impl PlayerSnap {
|
||||
pub fn from_player(p: &PlayerState) -> Self {
|
||||
Self {
|
||||
gold: p.gold,
|
||||
city_count: p.cities.len() as u32,
|
||||
unit_count: p.units.len() as u32,
|
||||
expansion_points: p.expansion_points,
|
||||
culture_total: p.culture_total,
|
||||
wealth: *p.strategic_axes.get("wealth").unwrap_or(&2),
|
||||
expansion_axis: *p.strategic_axes.get("expansion").unwrap_or(&2),
|
||||
production_axis: *p.strategic_axes.get("production").unwrap_or(&2),
|
||||
scoring_weights: p.scoring_weights.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Snapshot ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Lightweight game snapshot. `Clone + Send` — safe to scatter across rayon threads.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McSnapshot {
|
||||
pub turn: u32,
|
||||
pub players: Vec<PlayerSnap>,
|
||||
pub config: LairCombatConfig,
|
||||
pub victory_city_count: u8,
|
||||
/// Index of the player whose turn is being decided (set by the MCTS caller).
|
||||
/// Used by `TreeState::action_prior` to look up that player's `scoring_weights`.
|
||||
/// Defaults to 0; `ai.rs::choose_action` sets it to `player_index`.
|
||||
pub active_player: u8,
|
||||
}
|
||||
|
||||
impl McSnapshot {
|
||||
pub fn from_game_state(state: &GameState, processor: &TurnProcessor) -> Self {
|
||||
Self {
|
||||
turn: state.turn,
|
||||
players: state.players.iter().map(PlayerSnap::from_player).collect(),
|
||||
config: processor.lair_combat_config.clone(),
|
||||
victory_city_count: processor.victory_city_count,
|
||||
active_player: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance one turn deterministically. Economy + production + founding +
|
||||
/// unit-spawn phases only (no spatial movement or combat — grid is absent).
|
||||
///
|
||||
/// Byte-identical to TurnProcessor::step for the fields McSnapshot tracks.
|
||||
pub fn step(&self, action: &McAction) -> McSnapshot {
|
||||
let mut next = self.clone();
|
||||
next.turn += 1;
|
||||
|
||||
for p in &mut next.players {
|
||||
// Phase 1: economy
|
||||
let gold_in = p.wealth as i32
|
||||
* p.city_count as i32
|
||||
* self.config.gold_per_wealth_per_city;
|
||||
p.gold = p.gold.saturating_add(gold_in);
|
||||
|
||||
// Phase 1b: culture (uses expansion_axis as culture proxy, matching processor)
|
||||
let culture_per_turn = p.expansion_axis as i64 * p.city_count as i64 * 25;
|
||||
p.culture_total += culture_per_turn;
|
||||
|
||||
// Phase 3: expansion points
|
||||
p.expansion_points = p
|
||||
.expansion_points
|
||||
.saturating_add(self.config.expansion_per_axis_per_turn * p.expansion_axis as u32);
|
||||
|
||||
// Phase 3: city founding (only when action matches and player has enough points)
|
||||
let max_cities = self.config.max_cities_per_player_base as u32
|
||||
+ 3 * p.expansion_axis as u32;
|
||||
if p.city_count < max_cities
|
||||
&& p.expansion_points >= self.config.city_founding_cost
|
||||
&& *action == McAction::FoundCity
|
||||
{
|
||||
p.city_count += 1;
|
||||
p.expansion_points -= self.config.city_founding_cost;
|
||||
}
|
||||
|
||||
// Phase 4: unit production
|
||||
let prod_budget =
|
||||
self.config.prod_per_axis_per_city * p.production_axis as u32 * p.city_count;
|
||||
if prod_budget >= self.config.unit_spawn_cost && *action == McAction::SpawnUnit {
|
||||
p.unit_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
next
|
||||
}
|
||||
|
||||
/// Heuristic leaf-value for player `pi` in [0, 1].
|
||||
///
|
||||
/// Mirrors the `ScoringWeights` MCTS leaf evaluator fields. Normalized
|
||||
/// against a soft maximum so values stay in a comparable range.
|
||||
pub fn heuristic_value(&self, pi: usize) -> f32 {
|
||||
let p = &self.players[pi];
|
||||
let w = &p.scoring_weights;
|
||||
let raw = w.city_expansion * p.city_count as f32
|
||||
+ w.yield_gold * p.gold.max(0) as f32 * 0.01
|
||||
+ w.yield_culture * p.culture_total as f32 * 0.0001
|
||||
+ w.pop_value * p.unit_count as f32;
|
||||
// Soft-normalize: sigmoid-like squash so value stays in (0, 1).
|
||||
raw / (raw + 50.0)
|
||||
}
|
||||
|
||||
/// True if any player has met the city-count victory condition.
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
self.players
|
||||
.iter()
|
||||
.any(|p| p.city_count >= self.victory_city_count as u32)
|
||||
}
|
||||
|
||||
/// Winner index if terminal, else None.
|
||||
pub fn winner(&self) -> Option<usize> {
|
||||
self.players
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, p)| p.city_count >= self.victory_city_count as u32)
|
||||
.map(|(i, _)| i)
|
||||
}
|
||||
|
||||
/// All legal actions from this snapshot. Empty only if terminal.
|
||||
pub fn legal_actions(&self) -> Vec<McAction> {
|
||||
if self.is_terminal() {
|
||||
return Vec::new();
|
||||
}
|
||||
vec![McAction::Idle, McAction::FoundCity, McAction::SpawnUnit]
|
||||
}
|
||||
}
|
||||
|
||||
// ── TreeState impl ───────────────────────────────────────────────────────────
|
||||
|
||||
impl TreeState for McSnapshot {
|
||||
type Action = McAction;
|
||||
|
||||
fn legal_actions(&self) -> Vec<McAction> {
|
||||
self.legal_actions()
|
||||
}
|
||||
|
||||
fn apply(&self, action: &McAction) -> McSnapshot {
|
||||
self.step(action)
|
||||
}
|
||||
|
||||
fn is_terminal(&self) -> bool {
|
||||
self.is_terminal()
|
||||
}
|
||||
|
||||
/// Personality-weighted action prior for PUCT selection (p0-38).
|
||||
///
|
||||
/// Maps `McAction` variants to the `active_player`'s `ScoringWeights`:
|
||||
/// - `SpawnUnit` → `military_base` (aggressive clans prefer early armies)
|
||||
/// - `FoundCity` → `expansion` (expansionist clans prefer new cities)
|
||||
/// - `Idle` → 1.0 baseline (everyone considers doing nothing)
|
||||
///
|
||||
/// Returns unnormalised weights — `Tree::best_puct_child` normalises via
|
||||
/// the PUCT formula rather than requiring pre-normalised priors.
|
||||
fn action_prior(&self, action: &McAction) -> f32 {
|
||||
let Some(p) = self.players.get(self.active_player as usize) else {
|
||||
return 1.0;
|
||||
};
|
||||
let w = &p.scoring_weights;
|
||||
match action {
|
||||
McAction::SpawnUnit => w.military_base.max(0.1),
|
||||
McAction::FoundCity => w.expansion_base.max(0.1),
|
||||
McAction::Idle => 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::game_state::{CityEcology, MapUnit, PlayerState};
|
||||
use mc_ai::evaluator::ScoringWeights;
|
||||
use mc_city::CityState;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
fn make_player(gold: i32, cities: usize, units: usize) -> PlayerState {
|
||||
let mut axes = BTreeMap::new();
|
||||
axes.insert("wealth".into(), 3u8);
|
||||
axes.insert("expansion".into(), 2u8);
|
||||
axes.insert("production".into(), 2u8);
|
||||
axes.insert("culture".into(), 2u8);
|
||||
PlayerState {
|
||||
player_index: 0,
|
||||
gold,
|
||||
cities: (0..cities).map(|_| CityState::default()).collect(),
|
||||
unit_upkeep: vec![0; units],
|
||||
strategic_axes: axes,
|
||||
scoring_weights: ScoringWeights::default(),
|
||||
expansion_points: 0,
|
||||
city_buildings: vec![vec![]; cities],
|
||||
city_improvements: Default::default(),
|
||||
city_ecology: vec![CityEcology::default(); cities],
|
||||
tech_state: None,
|
||||
science_pool: 0,
|
||||
player_tech: None,
|
||||
science_yield: 0,
|
||||
units: (0..units)
|
||||
.map(|_| MapUnit {
|
||||
col: 0,
|
||||
row: 0,
|
||||
hp: 10,
|
||||
max_hp: 10,
|
||||
attack: 5,
|
||||
defense: 5,
|
||||
is_fortified: false,
|
||||
unit_id: "dwarf_warrior".into(),
|
||||
held_resources: Vec::new(),
|
||||
patrol_order: None,
|
||||
..Default::default()
|
||||
})
|
||||
.collect(),
|
||||
city_positions: vec![(0, 0); cities],
|
||||
capital_position: Some((0, 0)),
|
||||
culture_total: 0,
|
||||
culture_pool: mc_culture::CulturePool::default(),
|
||||
arcane_lore_pop_deducted: false,
|
||||
traded_luxuries: Default::default(),
|
||||
relations: Default::default(),
|
||||
strategic_ledger: Default::default(),
|
||||
wonders_built: Default::default(),
|
||||
explored_deposits: Default::default(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_state(cities: usize) -> GameState {
|
||||
GameState {
|
||||
turn: 0,
|
||||
players: vec![make_player(100, cities, 2), make_player(80, cities, 1)],
|
||||
grid: None,
|
||||
pending_pvp_attacks: Default::default(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Economy formula: gold_in = wealth * city_count * gold_per_city.
|
||||
/// Verify McSnapshot::step matches TurnProcessor::step for 10 varying states.
|
||||
#[test]
|
||||
fn step_economy_matches_turn_processor_for_10_seeds() {
|
||||
let processor = TurnProcessor::new(300);
|
||||
|
||||
for seed in 0u32..10 {
|
||||
let mut state = make_state((seed % 3 + 1) as usize);
|
||||
state.players[0].gold = seed as i32 * 50;
|
||||
state.players[1].gold = seed as i32 * 30;
|
||||
|
||||
let snap = McSnapshot::from_game_state(&state, &processor);
|
||||
let next_snap = snap.step(&McAction::Idle);
|
||||
|
||||
// Run full processor step (no grid, so no movement/combat)
|
||||
processor.step(&mut state);
|
||||
|
||||
for (pi, p) in state.players.iter().enumerate() {
|
||||
assert_eq!(
|
||||
next_snap.players[pi].gold,
|
||||
p.gold,
|
||||
"seed={seed} pi={pi}: snapshot gold {} != processor gold {}",
|
||||
next_snap.players[pi].gold,
|
||||
p.gold
|
||||
);
|
||||
assert_eq!(
|
||||
next_snap.players[pi].city_count,
|
||||
p.cities.len() as u32,
|
||||
"seed={seed} pi={pi}: city count mismatch"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn step_idle_does_not_found_city() {
|
||||
let processor = TurnProcessor::new(300);
|
||||
let state = make_state(2);
|
||||
let mut snap = McSnapshot::from_game_state(&state, &processor);
|
||||
snap.players[0].expansion_points = 100;
|
||||
let next = snap.step(&McAction::Idle);
|
||||
assert_eq!(next.players[0].city_count, 2, "Idle must not found a city");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn step_found_city_spends_expansion_points() {
|
||||
let processor = TurnProcessor::new(300);
|
||||
let state = make_state(1);
|
||||
let mut snap = McSnapshot::from_game_state(&state, &processor);
|
||||
snap.players[0].expansion_points = 50;
|
||||
let next = snap.step(&McAction::FoundCity);
|
||||
assert_eq!(next.players[0].city_count, 2, "FoundCity must add a city");
|
||||
let cost = snap.config.city_founding_cost;
|
||||
let earned =
|
||||
snap.config.expansion_per_axis_per_turn * snap.players[0].expansion_axis as u32;
|
||||
assert_eq!(
|
||||
next.players[0].expansion_points,
|
||||
50 - cost + earned,
|
||||
"expansion_points after founding"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heuristic_value_returns_value_in_unit_interval() {
|
||||
let processor = TurnProcessor::new(300);
|
||||
let state = make_state(3);
|
||||
let snap = McSnapshot::from_game_state(&state, &processor);
|
||||
for pi in 0..snap.players.len() {
|
||||
let v = snap.heuristic_value(pi);
|
||||
assert!(v >= 0.0 && v < 1.0, "heuristic_value out of range: {v}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_terminal_triggers_at_victory_threshold() {
|
||||
let processor = TurnProcessor::new(300);
|
||||
let state = make_state(1);
|
||||
let mut snap = McSnapshot::from_game_state(&state, &processor);
|
||||
snap.players[0].city_count = snap.victory_city_count as u32;
|
||||
assert!(snap.is_terminal());
|
||||
assert_eq!(snap.winner(), Some(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn legal_actions_empty_when_terminal() {
|
||||
let processor = TurnProcessor::new(300);
|
||||
let state = make_state(1);
|
||||
let mut snap = McSnapshot::from_game_state(&state, &processor);
|
||||
snap.players[0].city_count = snap.victory_city_count as u32;
|
||||
assert!(snap.legal_actions().is_empty());
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue