magicciv/tools/determinism-compare.py

214 lines
7.5 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""determinism-compare.py — Per-seed parity check for CPU↔GPU AI batches.
Companion to `tools/determinism-audit.sh` Scenario 2. Reads two batch output
directories (each containing `game_<stamp>_seed<N>/turn_stats.jsonl` dirs),
matches seeds, then:
- Integer fields (all non-float values in turn_stats schema) MUST match
byte-for-byte across turns.
- Float fields MUST match within `--float-tol` (default 1e-4 absolute).
Exits:
0 all seeds parity-green
1 at least one seed diverged (integer mismatch OR float beyond tolerance)
2 usage / I/O error
stdlib only — no pip installs. Mirrors the stdlib-only policy of autoplay-report.py.
"""
from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any
# --- Field classification tables --------------------------------------------
# walk_compare() looks these up by the *leaf* key name of a value's path, so a
# name applies wherever it appears in a turn_stats record.

# Numeric fields that must be exactly equal between the two runs.
INT_FIELDS_REQUIRED_EQUAL: set[str] = {
    # top level of each turn_stats line
    "turn",
    "winner_index",
    # aggregate.*
    "total_cities_captured",
    "total_cities_founded",
    "total_combats",
    "turn_first_city_captured",
    "turn_first_combat",
    # player_stats.<pid>.*
    "buildings",
    "cities",
    "cities_captured",
    "cities_lost",
    "food_total",
    "gold",
    "gold_peak",
    "happiness",
    "kills",
    "mil",
    "pop",
    "pop_peak",
    "production_total",
    "techs",
    "tiles",
    "turn_first_pop_3",
    "turn_first_pop_4",
    "units_lost",
}

# String fields that must match exactly; a tolerance is meaningless for text.
# NOTE(review): the comparison code in this file compares *every* string leaf
# exactly and does not consult this set — it is informational as written.
STR_FIELDS_REQUIRED_EQUAL: set[str] = {
    "outcome",
    "victory_type",
    "winner_personality",
}

# Keys whose values legitimately differ run-to-run (wall-clock timings, boot
# stamps, ...). They carry no determinism signal and are skipped entirely.
EXCLUDE_FIELDS: set[str] = {
    "finished_at",     # audit completion timestamp
    "start_stamp",     # ISO-8601 boot timestamp in meta.json
    "wall_clock_sec",  # per-turn wall-clock time
}
def find_seed_dirs(root: Path) -> dict[int, Path]:
    """Map each seed number to its chosen ``game_<stamp>_seed<N>`` directory.

    Only immediate subdirectories of *root* are considered (no recursion).
    A directory qualifies when its name starts with ``game_`` and ends with
    ``_seed`` followed by decimal digits. When the same seed appears more
    than once, the path that sorts last wins. That is the most recent run
    only if ``<stamp>`` sorts lexicographically in chronological order (e.g.
    zero-padded / ISO-style). NOTE(review): confirm against the batch
    runner's naming.

    Args:
        root: Batch output directory to scan.

    Returns:
        ``{seed: directory}``; empty when nothing matches.
    """
    by_seed: dict[int, list[Path]] = {}
    for entry in root.iterdir():
        if not entry.is_dir() or not entry.name.startswith("game_"):
            continue
        _prefix, sep, seed_text = entry.name.rpartition("_seed")
        # isdecimal(), not isdigit(): isdigit() also accepts characters such
        # as superscripts ("²") that int() rejects with ValueError, which
        # would abort the whole scan on one oddly named directory.
        if not sep or not seed_text.isdecimal():
            continue
        by_seed.setdefault(int(seed_text), []).append(entry)
    return {seed: max(dirs) for seed, dirs in by_seed.items()}
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Parse a JSON-Lines file into a list of decoded records.

    This is best-effort by design:
      * a missing file yields ``[]``;
      * blank lines are skipped;
      * lines that are not valid JSON are dropped.

    Dropped lines are reported on stderr with their 1-based line number
    rather than vanishing silently. A truncated or corrupt record changes
    which records get paired up for comparison.

    Lines are split only on real newlines (``\\n``, ``\\r``, ``\\r\\n``).
    ``str.splitlines()`` would also split on U+2028/U+2029, which may appear
    raw inside valid JSON strings.

    Args:
        path: Path to a ``.jsonl`` file, read as UTF-8.

    Returns:
        The decoded values, in file order.
    """
    if not path.exists():
        return []
    records: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as fh:
        for lineno, raw in enumerate(fh, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as err:
                print(
                    f"WARN: {path}:{lineno}: skipping malformed JSON line ({err.msg})",
                    file=sys.stderr,
                )
    return records
def walk_compare(
    a: Any, b: Any, path: str, tol: float, diffs: list[str]
) -> None:
    """Recursively compare two decoded JSON values; append divergences to *diffs*.

    Rules, applied at every depth:
      * dict keys in EXCLUDE_FIELDS are skipped. A key present on only one
        side is reported as missing on the other side.
      * lists must have equal length and are compared element-wise.
      * strings must be equal.
      * numbers whose leaf key is in INT_FIELDS_REQUIRED_EQUAL must have the
        same Python type and value. ``3`` vs ``3.0`` is a divergence because
        their JSON text differs.
      * other numbers must lie within *tol* (absolute) of each other. NaN
        only matches NaN.
      * any other same-typed pair (e.g. ``None``) must compare equal.
      * any other type difference is a mismatch.

    Args:
        a: Value from the first ("cpu") run.
        b: Value from the second ("gpu") run.
        path: Dotted location of the values; prefixes every diff message.
        tol: Absolute tolerance for non-integer-required numeric leaves.
        diffs: Output list; human-readable divergence lines are appended.
    """
    numeric = (int, float)
    if type(a) is not type(b):
        # Mixed numeric types (int/float; bool subclasses int) continue to the
        # numeric checks below; every other type difference is a mismatch.
        if not (isinstance(a, numeric) and isinstance(b, numeric)):
            diffs.append(f"{path}: type mismatch {type(a).__name__} vs {type(b).__name__}")
            return

    if isinstance(a, dict):
        for k in sorted(set(a) | set(b)):
            if k in EXCLUDE_FIELDS:
                continue
            child = f"{path}.{k}" if path else k
            if k not in a:
                diffs.append(f"{child}: missing on cpu side")
            elif k not in b:
                diffs.append(f"{child}: missing on gpu side")
            else:
                walk_compare(a[k], b[k], child, tol, diffs)
        return

    if isinstance(a, list):
        if len(a) != len(b):
            diffs.append(f"{path}: list length {len(a)} vs {len(b)}")
            return
        for i, (x, y) in enumerate(zip(a, b)):
            walk_compare(x, y, f"{path}[{i}]", tol, diffs)
        return

    # Leaf: classify by the last key name, ignoring any trailing "[i]" index.
    leaf_name = path.rsplit(".", 1)[-1].rsplit("[", 1)[0]

    if isinstance(a, str) and isinstance(b, str):
        if a != b:
            diffs.append(f"{path}: string '{a}' != '{b}'")
        return

    if isinstance(a, numeric) and isinstance(b, numeric):
        if leaf_name in INT_FIELDS_REQUIRED_EQUAL:
            # "Byte-equal" means the same JSON token: 3 == 3.0 in Python but
            # they serialize differently, so the types must agree as well.
            if type(a) is not type(b) or a != b:
                diffs.append(f"{path}: int-required {a} != {b} (must be byte-equal)")
            return
        fa, fb = float(a), float(b)
        a_nan, b_nan = math.isnan(fa), math.isnan(fb)
        if a_nan or b_nan:
            # abs(nan - x) > tol is always False, which would silently accept
            # a NaN on one side. NaN is treated as equal only to NaN.
            if not (a_nan and b_nan):
                diffs.append(f"{path}: float {a} != {b} (tol={tol})")
            return
        if abs(fa - fb) > tol:
            diffs.append(f"{path}: float {a} != {b} (tol={tol})")
        return

    if a != b:
        diffs.append(f"{path}: {a!r} != {b!r}")
def compare_seed(
    dir_a: Path, dir_b: Path, seed: int, tol: float
) -> list[str]:
    """Compare one seed's ``turn_stats.jsonl`` between two batch directories.

    Records are paired by position, i.e. in load_jsonl order after blank or
    malformed lines are dropped. Each pair is checked with walk_compare.

    Missing data can never pass:
      * a missing ``turn_stats.jsonl`` on either side is a divergence;
      * so is having no records on both sides.

    Previously, two missing files both loaded as ``[]`` and the seed was
    reported parity-green.

    Args:
        dir_a: The seed's game directory in the first batch.
        dir_b: The seed's game directory in the second batch.
        seed: Seed number, used only to label messages.
        tol: Absolute float tolerance forwarded to walk_compare.

    Returns:
        Human-readable divergence lines; an empty list means parity-green.
        When the record counts differ, only that single message is returned.
    """
    stats_a = dir_a / "turn_stats.jsonl"
    stats_b = dir_b / "turn_stats.jsonl"
    missing = [str(p) for p in (stats_a, stats_b) if not p.exists()]
    if missing:
        return [f"seed {seed}: turn_stats.jsonl missing: {', '.join(missing)}"]

    ts_a = load_jsonl(stats_a)
    ts_b = load_jsonl(stats_b)
    if len(ts_a) != len(ts_b):
        return [f"seed {seed}: turn count differs ({len(ts_a)} vs {len(ts_b)})"]
    if not ts_a:
        return [f"seed {seed}: no turn records on either side (nothing to compare)"]

    diffs: list[str] = []
    for i, (line_a, line_b) in enumerate(zip(ts_a, ts_b)):
        walk_compare(line_a, line_b, f"seed{seed}.line{i}", tol, diffs)
    return diffs
def _print_seed_diffs(seed: int, diffs: list[str], max_lines: int) -> None:
    """Print one seed's divergence header and at most *max_lines* diff lines."""
    print(f"seed {seed}: {len(diffs)} divergence(s)")
    for line in diffs[:max_lines]:
        print(f" {line}")
    if len(diffs) > max_lines:
        print(f" ... ({len(diffs) - max_lines} more)")


def main(argv: list[str]) -> int:
    """Command-line entry point; returns the process exit status.

    Args:
        argv: Full argument vector in ``sys.argv`` shape; ``argv[0]`` is ignored.

    Returns:
        0 when every common seed is parity-green, 1 when at least one seed
        diverged, and 2 on usage or I/O errors. I/O errors include:
          * a non-directory argument;
          * no overlapping seeds;
          * an ``OSError`` or undecodable file while scanning or reading.

        I/O failures must not surface as an uncaught traceback, whose exit
        status of 1 would be indistinguishable from "seed diverged".
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("dir_a", type=Path, help="First batch dir (e.g. CPU run)")
    ap.add_argument("dir_b", type=Path, help="Second batch dir (e.g. GPU run)")
    ap.add_argument("--float-tol", type=float, default=1e-4,
                    help="Absolute tolerance for float fields (default 1e-4)")
    ap.add_argument("--max-diff-lines", type=int, default=20,
                    help="Max diff lines to print per seed before truncating")
    args = ap.parse_args(argv[1:])

    for batch_dir in (args.dir_a, args.dir_b):
        if not batch_dir.is_dir():
            print(f"ERROR: {batch_dir} is not a directory", file=sys.stderr)
            return 2

    try:
        seeds_a = find_seed_dirs(args.dir_a)
        seeds_b = find_seed_dirs(args.dir_b)
    except OSError as err:
        print(f"ERROR: cannot list batch directory: {err}", file=sys.stderr)
        return 2

    common = sorted(set(seeds_a) & set(seeds_b))
    only_a = sorted(set(seeds_a) - set(seeds_b))
    only_b = sorted(set(seeds_b) - set(seeds_a))
    if not common:
        print(f"ERROR: no overlapping seeds between {args.dir_a} and {args.dir_b}", file=sys.stderr)
        return 2
    if only_a:
        print(f"WARN: seeds only in {args.dir_a.name}: {only_a}", file=sys.stderr)
    if only_b:
        print(f"WARN: seeds only in {args.dir_b.name}: {only_b}", file=sys.stderr)

    # A negative limit would make the slice drop lines from the end instead.
    max_lines = max(args.max_diff_lines, 0)
    total_diffs = 0
    failing_seeds: list[int] = []
    unreadable_seeds: list[int] = []
    for seed in common:
        try:
            diffs = compare_seed(seeds_a[seed], seeds_b[seed], seed, args.float_tol)
        except (OSError, UnicodeDecodeError) as err:
            # Keep going so the remaining seeds are still reported.
            unreadable_seeds.append(seed)
            print(f"ERROR: seed {seed}: could not read turn stats: {err}", file=sys.stderr)
            continue
        if diffs:
            failing_seeds.append(seed)
            total_diffs += len(diffs)
            _print_seed_diffs(seed, diffs, max_lines)
        else:
            print(f"seed {seed}: OK")

    passing = len(common) - len(failing_seeds) - len(unreadable_seeds)
    print("\n=== summary ===")
    print(f"seeds compared: {len(common)}")
    print(f"seeds passing: {passing}")
    print(f"seeds failing: {len(failing_seeds)} ({failing_seeds})")
    if unreadable_seeds:
        print(f"seeds unreadable: {len(unreadable_seeds)} ({unreadable_seeds})")
    print(f"total divergences: {total_diffs}")
    print(f"float tolerance: {args.float_tol}")
    if unreadable_seeds:
        return 2
    return 1 if failing_seeds else 0
if __name__ == "__main__":
    # Script entry: propagate main()'s status (0 / 1 / 2) as the exit code.
    raise SystemExit(main(sys.argv))