feat(@projects/@magic-civilization): ✨ add gpu combat resolve kernel support
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
314086bc47
commit
720947087a
3 changed files with 641 additions and 7 deletions
630
src/simulator/crates/mc-turn/src/gpu/combat_resolve.rs
Normal file
630
src/simulator/crates/mc-turn/src/gpu/combat_resolve.rs
Normal file
|
|
@ -0,0 +1,630 @@
|
|||
//! GPU compute adapter for the combat_resolve kernel.
|
||||
//!
|
||||
//! Wraps buffer upload, dispatch, and readback for the WGSL
|
||||
//! `combat_resolve.wgsl` kernel. One workgroup invocation per combat;
|
||||
//! workgroup_size=64 — no shared RNG state, fully parallel.
|
||||
//!
|
||||
//! Falls back silently (returns `None`) when `GpuContext::try_init` finds
|
||||
//! no suitable adapter so tests pass in headless CI.
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod inner {
|
||||
use super::super::GpuContext;
|
||||
use mc_combat::keywords::Keyword;
|
||||
use mc_combat::resolver::{CombatOutcome, CombatParams, CombatResult, CombatType};
|
||||
use wgpu::util::DeviceExt;
|
||||
|
||||
fn keywords_to_mask(keywords: &[Keyword]) -> u32 {
|
||||
keywords.iter().fold(0u32, |acc, kw| acc | (1u32 << (*kw as u32)))
|
||||
}
|
||||
|
||||
const SHADER_SRC: &str = include_str!("combat_resolve.wgsl");
|
||||
|
||||
// ── Base XP constant ──────────────────────────────────────────────────────
|
||||
//
|
||||
/// Base XP granted per combat engagement — matches resolver.rs BASE_COMBAT_XP.
|
||||
const BASE_COMBAT_XP: u32 = 5;
|
||||
|
||||
// ── GPU-friendly flat combat descriptor ───────────────────────────────────
|
||||
//
|
||||
// Fields ordered to match the WGSL `GpuCombat` struct layout (std430).
|
||||
// All i32/u32/f32 — no pointers, no strings.
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
|
||||
pub struct GpuCombat {
|
||||
// Attacker stats
|
||||
atk_hp: i32,
|
||||
atk_max_hp: i32,
|
||||
atk_attack: i32,
|
||||
atk_defense: i32,
|
||||
atk_ranged_attack: i32,
|
||||
// Defender stats
|
||||
def_hp: i32,
|
||||
def_max_hp: i32,
|
||||
def_attack: i32,
|
||||
def_defense: i32,
|
||||
// Attack bonuses
|
||||
atk_flanking_allies: i32,
|
||||
atk_support_units: i32,
|
||||
atk_terrain_defense: f32,
|
||||
atk_fortification: f32,
|
||||
atk_river_crossing: u32,
|
||||
// Defense bonuses
|
||||
def_flanking_allies: i32,
|
||||
def_support_units: i32,
|
||||
def_terrain_defense: f32,
|
||||
def_fortification: f32,
|
||||
def_city_wall_bonus: f32,
|
||||
def_river_crossing: u32,
|
||||
// Keywords
|
||||
atk_keywords: u32,
|
||||
def_keywords: u32,
|
||||
// Metadata
|
||||
combat_type: u32,
|
||||
city_hp: i32,
|
||||
city_wall_tier: u32,
|
||||
city_has_garrison:u32,
|
||||
attacker_is_siege:u32,
|
||||
}
|
||||
|
||||
impl GpuCombat {
|
||||
/// Convert a CPU `CombatParams` to the GPU flat struct.
|
||||
pub fn from_params(p: &CombatParams) -> Self {
|
||||
let combat_type = match p.combat_type {
|
||||
CombatType::Melee => 0,
|
||||
CombatType::Ranged => 1,
|
||||
CombatType::Siege => 2,
|
||||
};
|
||||
Self {
|
||||
atk_hp: p.attacker.hp,
|
||||
atk_max_hp: p.attacker.max_hp,
|
||||
atk_attack: p.attacker.attack,
|
||||
atk_defense: p.attacker.defense,
|
||||
atk_ranged_attack: p.attacker.ranged_attack,
|
||||
|
||||
def_hp: p.defender.hp,
|
||||
def_max_hp:p.defender.max_hp,
|
||||
def_attack: p.defender.attack,
|
||||
def_defense:p.defender.defense,
|
||||
|
||||
atk_flanking_allies: p.attacker_bonuses.flanking_allies,
|
||||
atk_support_units: p.attacker_bonuses.support_units,
|
||||
atk_terrain_defense: p.attacker_bonuses.terrain_defense,
|
||||
atk_fortification: p.attacker_bonuses.fortification,
|
||||
atk_river_crossing: p.attacker_bonuses.river_crossing as u32,
|
||||
|
||||
def_flanking_allies: p.defender_bonuses.flanking_allies,
|
||||
def_support_units: p.defender_bonuses.support_units,
|
||||
def_terrain_defense: p.defender_bonuses.terrain_defense,
|
||||
def_fortification: p.defender_bonuses.fortification,
|
||||
def_city_wall_bonus: p.defender_bonuses.city_wall_bonus,
|
||||
def_river_crossing: p.defender_bonuses.river_crossing as u32,
|
||||
|
||||
atk_keywords: keywords_to_mask(&p.attacker_keywords),
|
||||
def_keywords: keywords_to_mask(&p.defender_keywords),
|
||||
|
||||
combat_type,
|
||||
city_hp: p.city_hp.unwrap_or(0),
|
||||
city_wall_tier: p.city_wall_tier as u32,
|
||||
city_has_garrison: p.city_has_garrison as u32,
|
||||
attacker_is_siege: p.attacker_is_siege as u32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── GPU result flat struct ────────────────────────────────────────────────
|
||||
//
|
||||
// Must match WGSL `GpuResult` layout exactly.
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Default, bytemuck::Pod, bytemuck::Zeroable)]
|
||||
pub struct GpuResultRaw {
|
||||
pub defender_damage: i32,
|
||||
pub attacker_damage: i32,
|
||||
pub attacker_outcome: u32,
|
||||
pub defender_outcome: u32,
|
||||
pub attacker_hp: i32,
|
||||
pub defender_hp: i32,
|
||||
pub city_damage: i32,
|
||||
pub city_hp_remaining: i32,
|
||||
pub life_drain_heal: i32,
|
||||
pub attacker_xp: i32,
|
||||
pub defender_xp: i32,
|
||||
pub _pad: u32,
|
||||
}
|
||||
|
||||
impl GpuResultRaw {
|
||||
/// Convert GPU result back to the CPU `CombatResult` type.
|
||||
pub fn to_combat_result(self, original_city_hp: Option<i32>) -> CombatResult {
|
||||
CombatResult {
|
||||
defender_damage: self.defender_damage,
|
||||
attacker_damage: self.attacker_damage,
|
||||
attacker_outcome: if self.attacker_outcome == 0 { CombatOutcome::Survived } else { CombatOutcome::Killed },
|
||||
defender_outcome: if self.defender_outcome == 0 { CombatOutcome::Survived } else { CombatOutcome::Killed },
|
||||
attacker_hp: self.attacker_hp,
|
||||
defender_hp: self.defender_hp,
|
||||
city_damage: self.city_damage,
|
||||
city_hp_remaining:if original_city_hp.is_some() { self.city_hp_remaining } else { 0 },
|
||||
life_drain_heal: self.life_drain_heal,
|
||||
attacker_xp: self.attacker_xp,
|
||||
defender_xp: self.defender_xp,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Uniforms block ────────────────────────────────────────────────────────
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
|
||||
struct GpuUniforms {
|
||||
n_combats: u32,
|
||||
/// Base XP per engagement — matches resolver.rs BASE_COMBAT_XP = 5.
|
||||
base_xp: u32,
|
||||
}
|
||||
|
||||
// ── Bind-group layout helpers (shared pattern from fauna_encounter) ───────
|
||||
|
||||
fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn upload_ro(dev: &wgpu::Device, data: &[u8], label: &str) -> wgpu::Buffer {
|
||||
dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: Some(label),
|
||||
contents: data,
|
||||
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
|
||||
})
|
||||
}
|
||||
|
||||
fn create_rw(dev: &wgpu::Device, size_bytes: usize, label: &str) -> wgpu::Buffer {
|
||||
dev.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some(label),
|
||||
size: size_bytes as u64,
|
||||
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Public dispatch API ───────────────────────────────────────────────────
|
||||
|
||||
/// Dispatch the combat_resolve kernel over `combats`, returning one
|
||||
/// `GpuResultRaw` per entry. Returns `None` if the GPU context is unavailable.
|
||||
pub fn dispatch_combat_batch(
|
||||
ctx: &GpuContext,
|
||||
combats: &[GpuCombat],
|
||||
) -> Vec<GpuResultRaw> {
|
||||
if combats.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let dev = &ctx.device;
|
||||
let n = combats.len() as u32;
|
||||
|
||||
let uniforms = GpuUniforms { n_combats: n, base_xp: BASE_COMBAT_XP };
|
||||
|
||||
let buf_combats = upload_ro(dev, bytemuck::cast_slice(combats), "combats");
|
||||
let buf_results = create_rw(dev, combats.len() * std::mem::size_of::<GpuResultRaw>(), "results");
|
||||
let buf_uniforms = upload_ro(dev, bytemuck::bytes_of(&uniforms), "combat_uniforms");
|
||||
|
||||
let bgl = dev.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some("combat_bgl"),
|
||||
entries: &[
|
||||
bgl_entry(0, true), // combats
|
||||
bgl_entry(1, false), // results (read_write)
|
||||
bgl_entry(2, true), // uniforms
|
||||
],
|
||||
});
|
||||
|
||||
let pipeline_layout = dev.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some("combat_pl"),
|
||||
bind_group_layouts: &[&bgl],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let shader = dev.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some("combat_resolve"),
|
||||
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
|
||||
});
|
||||
|
||||
let pipeline = dev.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: Some("combat_resolve"),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &shader,
|
||||
entry_point: "main",
|
||||
compilation_options: wgpu::PipelineCompilationOptions::default(),
|
||||
});
|
||||
|
||||
let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: Some("combat_bg"),
|
||||
layout: &bgl,
|
||||
entries: &[
|
||||
wgpu::BindGroupEntry { binding: 0, resource: buf_combats.as_entire_binding() },
|
||||
wgpu::BindGroupEntry { binding: 1, resource: buf_results.as_entire_binding() },
|
||||
wgpu::BindGroupEntry { binding: 2, resource: buf_uniforms.as_entire_binding() },
|
||||
],
|
||||
});
|
||||
|
||||
let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||
label: Some("combat_enc"),
|
||||
});
|
||||
{
|
||||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: Some("combat_pass"),
|
||||
timestamp_writes: None,
|
||||
});
|
||||
pass.set_pipeline(&pipeline);
|
||||
pass.set_bind_group(0, &bind_group, &[]);
|
||||
// workgroup_size=64; ceil(n/64) groups covers all combats.
|
||||
let groups = n.div_ceil(64);
|
||||
pass.dispatch_workgroups(groups, 1, 1);
|
||||
}
|
||||
|
||||
let staging_size = (combats.len() * std::mem::size_of::<GpuResultRaw>()) as u64;
|
||||
let staging = dev.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("combat_staging"),
|
||||
size: staging_size,
|
||||
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
encoder.copy_buffer_to_buffer(&buf_results, 0, &staging, 0, staging_size);
|
||||
ctx.queue.submit(std::iter::once(encoder.finish()));
|
||||
|
||||
let slice = staging.slice(..);
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
slice.map_async(wgpu::MapMode::Read, move |r| { let _ = tx.send(r); });
|
||||
dev.poll(wgpu::Maintain::Wait);
|
||||
rx.recv().unwrap().unwrap();
|
||||
|
||||
let mapped = slice.get_mapped_range();
|
||||
let out: Vec<GpuResultRaw> = bytemuck::cast_slice(&mapped).to_vec();
|
||||
drop(mapped);
|
||||
staging.unmap();
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mc_combat::resolver::{CombatParams, CombatResolver, CombatType, UnitStats};
|
||||
use mc_combat::keywords::Keyword;
|
||||
use mc_combat::bonuses::CombatBonuses;
|
||||
|
||||
fn warrior() -> UnitStats {
|
||||
UnitStats { hp: 60, max_hp: 60, attack: 12, defense: 1, ranged_attack: 0, range: 0, movement: 2 }
|
||||
}
|
||||
|
||||
fn crossbow() -> UnitStats {
|
||||
UnitStats { hp: 40, max_hp: 40, attack: 5, defense: 1, ranged_attack: 18, range: 2, movement: 2 }
|
||||
}
|
||||
|
||||
fn catapult() -> UnitStats {
|
||||
UnitStats { hp: 50, max_hp: 50, attack: 6, defense: 0, ranged_attack: 20, range: 3, movement: 1 }
|
||||
}
|
||||
|
||||
fn strong() -> UnitStats {
|
||||
UnitStats { hp: 100, max_hp: 100, attack: 40, defense: 5, ranged_attack: 0, range: 0, movement: 2 }
|
||||
}
|
||||
|
||||
fn weak() -> UnitStats {
|
||||
UnitStats { hp: 20, max_hp: 20, attack: 10, defense: 0, ranged_attack: 0, range: 0, movement: 2 }
|
||||
}
|
||||
|
||||
fn half_hp_warrior() -> UnitStats {
|
||||
UnitStats { hp: 30, ..warrior() }
|
||||
}
|
||||
|
||||
/// Build the 1000-scenario suite covering all main code paths.
|
||||
fn build_test_suite() -> Vec<CombatParams> {
|
||||
let mut suite = Vec::with_capacity(1024);
|
||||
|
||||
// Basic melee variants — equal, flanked, fortified, terrain
|
||||
for flanking in [0, 2, 4] {
|
||||
for fort in [0.0f32, 0.25, 0.50] {
|
||||
for terrain in [0.0f32, 0.25, 0.50] {
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_bonuses: CombatBonuses { flanking_allies: flanking, ..Default::default() },
|
||||
defender_bonuses: CombatBonuses { fortification: fort, terrain_defense: terrain, ..Default::default() },
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ranged — no retaliation
|
||||
suite.push(CombatParams {
|
||||
attacker: crossbow(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Ranged,
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Siege vs city — various wall tiers
|
||||
for wall_tier in [0, 1, 2] {
|
||||
for garrison in [false, true] {
|
||||
suite.push(CombatParams {
|
||||
attacker: catapult(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Ranged,
|
||||
city_hp: Some(250),
|
||||
city_wall_tier: wall_tier,
|
||||
city_has_garrison: garrison,
|
||||
attacker_is_siege: true,
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Melee vs walled city
|
||||
for wall_tier in [0, 1, 2] {
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
city_hp: Some(200),
|
||||
city_wall_tier: wall_tier,
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
// Keyword combos: FirstStrike, LifeDrain, Charge, Brace, Skirmish
|
||||
suite.push(CombatParams {
|
||||
attacker: strong(),
|
||||
defender: weak(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::FirstStrike],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::LifeDrain],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::Charge],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
defender_keywords: vec![Keyword::Brace],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::Skirmish],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
defender_keywords: vec![Keyword::NoMeleeRetaliation],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::Flying],
|
||||
..Default::default()
|
||||
});
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
defender_keywords: vec![Keyword::Flying],
|
||||
..Default::default()
|
||||
});
|
||||
// BonusVsFortified with fortified defender
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::BonusVsFortified],
|
||||
defender_bonuses: CombatBonuses { fortification: 0.50, ..Default::default() },
|
||||
..Default::default()
|
||||
});
|
||||
// PackTactics
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::PackTactics],
|
||||
attacker_bonuses: CombatBonuses { flanking_allies: 3, ..Default::default() },
|
||||
..Default::default()
|
||||
});
|
||||
// IgnoreTerrainDefense
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::IgnoreTerrainDefense],
|
||||
defender_bonuses: CombatBonuses { terrain_defense: 0.50, ..Default::default() },
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Damaged attacker
|
||||
suite.push(CombatParams {
|
||||
attacker: half_hp_warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// River crossing penalty
|
||||
suite.push(CombatParams {
|
||||
attacker: warrior(),
|
||||
defender: warrior(),
|
||||
combat_type: CombatType::Melee,
|
||||
attacker_bonuses: CombatBonuses { river_crossing: true, ..Default::default() },
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Pad to 1000 by repeating with varied HP levels
|
||||
let base_len = suite.len();
|
||||
let mut hp = 60i32;
|
||||
while suite.len() < 1000 {
|
||||
hp = (hp - 5).max(10);
|
||||
let i = suite.len() % base_len;
|
||||
let mut p = suite[i].clone();
|
||||
p.attacker.hp = hp.min(p.attacker.max_hp);
|
||||
suite.push(p);
|
||||
}
|
||||
|
||||
suite
|
||||
}
|
||||
|
||||
/// 1000-combat parity test: GPU and CPU must produce identical results
|
||||
/// across all fields for the full test suite.
|
||||
#[test]
|
||||
fn combat_resolve_matches_cpu_1000() {
|
||||
let ctx = match GpuContext::try_init() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
eprintln!("[gpu-test] No GPU adapter — skipping combat_resolve_matches_cpu_1000");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let suite = build_test_suite();
|
||||
assert_eq!(suite.len(), 1000, "test suite must be exactly 1000 combats");
|
||||
|
||||
// CPU reference
|
||||
let cpu_results: Vec<CombatResult> = suite.iter().map(CombatResolver::resolve).collect();
|
||||
|
||||
// GPU
|
||||
let gpu_flat: Vec<GpuCombat> = suite.iter().map(GpuCombat::from_params).collect();
|
||||
let gpu_raw = dispatch_combat_batch(&ctx, &gpu_flat);
|
||||
assert_eq!(gpu_raw.len(), 1000);
|
||||
|
||||
let gpu_results: Vec<CombatResult> = gpu_raw.iter().zip(suite.iter())
|
||||
.map(|(r, p)| r.to_combat_result(p.city_hp))
|
||||
.collect();
|
||||
|
||||
for (i, (cpu, gpu)) in cpu_results.iter().zip(gpu_results.iter()).enumerate() {
|
||||
assert_eq!(cpu.defender_damage, gpu.defender_damage, "combat {i}: defender_damage mismatch");
|
||||
assert_eq!(cpu.attacker_damage, gpu.attacker_damage, "combat {i}: attacker_damage mismatch");
|
||||
assert_eq!(cpu.attacker_outcome, gpu.attacker_outcome, "combat {i}: attacker_outcome mismatch");
|
||||
assert_eq!(cpu.defender_outcome, gpu.defender_outcome, "combat {i}: defender_outcome mismatch");
|
||||
assert_eq!(cpu.attacker_hp, gpu.attacker_hp, "combat {i}: attacker_hp mismatch");
|
||||
assert_eq!(cpu.defender_hp, gpu.defender_hp, "combat {i}: defender_hp mismatch");
|
||||
assert_eq!(cpu.city_damage, gpu.city_damage, "combat {i}: city_damage mismatch");
|
||||
assert_eq!(cpu.city_hp_remaining, gpu.city_hp_remaining, "combat {i}: city_hp_remaining mismatch");
|
||||
assert_eq!(cpu.life_drain_heal, gpu.life_drain_heal, "combat {i}: life_drain_heal mismatch");
|
||||
assert_eq!(cpu.attacker_xp, gpu.attacker_xp, "combat {i}: attacker_xp mismatch");
|
||||
assert_eq!(cpu.defender_xp, gpu.defender_xp, "combat {i}: defender_xp mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
/// Scalar keyword parity: verify each keyword path individually so
|
||||
/// a single bitmask bug is caught before the integration test.
|
||||
#[test]
|
||||
fn combat_keyword_parity_scalar() {
|
||||
let ctx = match GpuContext::try_init() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
eprintln!("[gpu-test] No GPU adapter — skipping combat_keyword_parity_scalar");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let cases: Vec<(&str, CombatParams)> = vec![
|
||||
("first_strike_kills_no_ret", CombatParams {
|
||||
attacker: strong(), defender: weak(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::FirstStrike], ..Default::default()
|
||||
}),
|
||||
("life_drain", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::LifeDrain], ..Default::default()
|
||||
}),
|
||||
("charge", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::Charge], ..Default::default()
|
||||
}),
|
||||
("brace", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
defender_keywords: vec![Keyword::Brace], ..Default::default()
|
||||
}),
|
||||
("skirmish_no_ret", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::Skirmish], ..Default::default()
|
||||
}),
|
||||
("no_melee_ret", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
defender_keywords: vec![Keyword::NoMeleeRetaliation], ..Default::default()
|
||||
}),
|
||||
("flying_def_vs_ground", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
defender_keywords: vec![Keyword::Flying], ..Default::default()
|
||||
}),
|
||||
("bonus_vs_fortified", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::BonusVsFortified],
|
||||
defender_bonuses: CombatBonuses { fortification: 0.50, ..Default::default() },
|
||||
..Default::default()
|
||||
}),
|
||||
("pack_tactics_3_allies", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::PackTactics],
|
||||
attacker_bonuses: CombatBonuses { flanking_allies: 3, ..Default::default() },
|
||||
..Default::default()
|
||||
}),
|
||||
("ignore_terrain", CombatParams {
|
||||
attacker: warrior(), defender: warrior(), combat_type: CombatType::Melee,
|
||||
attacker_keywords: vec![Keyword::IgnoreTerrainDefense],
|
||||
defender_bonuses: CombatBonuses { terrain_defense: 0.50, ..Default::default() },
|
||||
..Default::default()
|
||||
}),
|
||||
];
|
||||
|
||||
let params: Vec<&CombatParams> = cases.iter().map(|(_, p)| p).collect();
|
||||
let cpu_results: Vec<CombatResult> = params.iter().map(|p| CombatResolver::resolve(p)).collect();
|
||||
|
||||
let gpu_flat: Vec<GpuCombat> = params.iter().map(|p| GpuCombat::from_params(p)).collect();
|
||||
let gpu_raw = dispatch_combat_batch(&ctx, &gpu_flat);
|
||||
|
||||
for (i, (label, p)) in cases.iter().enumerate() {
|
||||
let cpu = &cpu_results[i];
|
||||
let gpu = gpu_raw[i].to_combat_result(p.city_hp);
|
||||
assert_eq!(cpu.defender_damage, gpu.defender_damage, "[{label}] defender_damage");
|
||||
assert_eq!(cpu.attacker_damage, gpu.attacker_damage, "[{label}] attacker_damage");
|
||||
assert_eq!(cpu.attacker_outcome, gpu.attacker_outcome, "[{label}] attacker_outcome");
|
||||
assert_eq!(cpu.defender_outcome, gpu.defender_outcome, "[{label}] defender_outcome");
|
||||
assert_eq!(cpu.life_drain_heal, gpu.life_drain_heal, "[{label}] life_drain_heal");
|
||||
assert_eq!(cpu.attacker_xp, gpu.attacker_xp, "[{label}] attacker_xp");
|
||||
}
|
||||
}
|
||||
|
||||
/// Graceful fallback: try_init must not panic even without a GPU.
|
||||
#[test]
|
||||
fn graceful_when_no_gpu() {
|
||||
let result = std::panic::catch_unwind(GpuContext::try_init);
|
||||
assert!(result.is_ok(), "try_init must not panic even without GPU");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use inner::{dispatch_combat_batch, GpuCombat, GpuResultRaw};
|
||||
|
|
@ -125,8 +125,7 @@ struct GpuCombat {
|
|||
def_terrain_defense: f32,
|
||||
def_fortification: f32,
|
||||
def_city_wall_bonus: f32,
|
||||
def_city_defense_pct: f32,
|
||||
def_river_crossing: u32, // unused but pad for alignment
|
||||
def_river_crossing: u32,
|
||||
|
||||
// Keywords (bitmask)
|
||||
atk_keywords: u32,
|
||||
|
|
@ -178,8 +177,7 @@ fn total_defense_modifier(c: GpuCombat, ignore_terrain: bool) -> f32 {
|
|||
let fort = min(c.def_fortification, FORTIFICATION_CAP);
|
||||
let walls = c.def_city_wall_bonus;
|
||||
let flanking = min(f32(c.def_flanking_allies) * FLANKING_BONUS_PER_UNIT, FLANKING_BONUS_CAP);
|
||||
let city_def = c.def_city_defense_pct;
|
||||
return terrain + fort + walls + flanking + city_def;
|
||||
return terrain + fort + walls + flanking;
|
||||
}
|
||||
|
||||
// ── Helper: keyword attack bonus (mirrors keywords::keyword_attack_bonus) ─────
|
||||
|
|
@ -221,6 +219,10 @@ fn keyword_defense_bonus(
|
|||
}
|
||||
|
||||
// ── Helper: XP from combat (mirrors promotions::xp_from_combat) ───────────────
|
||||
//
|
||||
// Uses floor(x + 0.5) instead of round() to match Rust's half-away-from-zero
|
||||
// rounding. WGSL round() uses banker's rounding (half-to-even), which diverges
|
||||
// at exactly 0.5 boundaries (e.g. 2.5 → 2 in WGSL, 3 in Rust).
|
||||
fn xp_from_combat(base_xp: i32, strength_ratio: f32) -> i32 {
|
||||
var multiplier: f32;
|
||||
if strength_ratio > 1.0 {
|
||||
|
|
@ -228,7 +230,7 @@ fn xp_from_combat(base_xp: i32, strength_ratio: f32) -> i32 {
|
|||
} else {
|
||||
multiplier = max(0.5, strength_ratio);
|
||||
}
|
||||
return i32(round(f32(base_xp) * multiplier));
|
||||
return i32(floor(f32(base_xp) * multiplier + 0.5));
|
||||
}
|
||||
|
||||
// ── Main kernel ───────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ mod inner {
|
|||
|
||||
/// Lazy-initialized wgpu context. Created once, reused across dispatches.
|
||||
pub struct GpuContext {
|
||||
device: wgpu::Device,
|
||||
queue: wgpu::Queue,
|
||||
pub(crate) device: wgpu::Device,
|
||||
pub(crate) queue: wgpu::Queue,
|
||||
pipeline: wgpu::ComputePipeline,
|
||||
bind_group_layout: wgpu::BindGroupLayout,
|
||||
pub backend: String,
|
||||
|
|
@ -546,3 +546,5 @@ mod inner {
|
|||
pub use inner::GpuContext;
|
||||
#[cfg(feature = "gpu")]
|
||||
pub use inner::GpuUnit;
|
||||
|
||||
pub mod combat_resolve;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue