247 lines
8 KiB
Python
247 lines
8 KiB
Python
"""Calibration tool: compare local SigLIP2 scores against the Sonnet answer key.

Scores all labeled images with the local scorer, then generates a comparison
report showing per-dimension correlation, systematic biases, and disagreements.

Usage:
    python3 -m engine.calibration [--device cuda:0] [--limit 20]
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import sqlite3
|
|
import statistics
|
|
from pathlib import Path
|
|
|
|
from engine.local_scorer import LocalSpriteScorer
|
|
|
|
logger = logging.getLogger(__name__)

# Directory two levels above this file (presumably the tool's root), and the
# SQLite database that holds the `sprites` and `variants` tables read below.
TOOL_DIR = Path(__file__).resolve().parent.parent
DB_PATH = TOOL_DIR / "spritegen.db"

# The ten per-unit scoring dimensions of the current rubric.
UNIT_DIMS = (
    "camera_angle",
    "facing_direction",
    "composition",
    "subject_type",
    "race_accuracy",
    "gender_accuracy",
    "equipment_accuracy",
    "pose_quality",
    "background_compliance",
    "art_style",
)
|
|
|
|
|
|
def _parse_entity_metadata(entity_id: str) -> dict[str, str]:
|
|
"""Extract race and gender from entity_id like 'bowmen_dwarves_f'."""
|
|
parts = entity_id
|
|
meta: dict[str, str] = {}
|
|
|
|
if parts.endswith("_m"):
|
|
meta["gender"] = "male"
|
|
parts = parts[:-2]
|
|
elif parts.endswith("_f"):
|
|
meta["gender"] = "female"
|
|
parts = parts[:-2]
|
|
|
|
for race in ("dwarves", "humans", "high_elves", "orcs"):
|
|
suffix = f"_{race}"
|
|
if parts.endswith(suffix):
|
|
meta["race"] = race
|
|
parts = parts[:-len(suffix)]
|
|
break
|
|
|
|
meta["base_unit"] = parts
|
|
return meta
|
|
|
|
|
|
def _pearson_r(xs: list[float], ys: list[float]) -> float:
|
|
"""Pearson correlation coefficient between two lists."""
|
|
n = len(xs)
|
|
if n < 3:
|
|
return 0.0
|
|
mean_x = sum(xs) / n
|
|
mean_y = sum(ys) / n
|
|
dx = [x - mean_x for x in xs]
|
|
dy = [y - mean_y for y in ys]
|
|
num = sum(a * b for a, b in zip(dx, dy))
|
|
den_x = sum(a * a for a in dx) ** 0.5
|
|
den_y = sum(a * a for a in dy) ** 0.5
|
|
if den_x == 0 or den_y == 0:
|
|
return 0.0
|
|
return num / (den_x * den_y)
|
|
|
|
|
|
def load_sonnet_scores(limit: int | None = None) -> list[dict]:
    """Load Sonnet-scored variants that have an image file on disk.

    Reads ``DB_PATH`` and returns every variant whose ``notes`` column holds a
    JSON object of Sonnet scores, whose rating is neither NULL nor -1, and
    whose ``raw_path`` exists. All sprite categories are returned, and the
    score dicts may use either the older 4-dimension names or the newer
    10-dimension names. Rows are ordered by variant id so limited runs are
    reproducible.

    Args:
        limit: Maximum number of variants to return. ``None``, 0, or a
            negative value means no limit. The cap is applied *after* rows
            with a missing image are dropped, so a limited run still gets
            ``limit`` usable variants when that many exist.

    Returns:
        One dict per variant with ``variant_id``, ``raw_path``, ``entity_id``,
        ``prompt``, ``category``, the parsed ``sonnet_scores`` dict, its
        sorted key tuple ``sonnet_dims``, and ``race`` / ``gender`` /
        ``base_unit`` parsed from ``entity_id`` ("" when absent).
    """
    # Non-positive limits have always meant "no limit" (0 is falsy, and
    # SQLite treats a negative LIMIT as unbounded); keep that meaning.
    cap = limit if limit is not None and limit > 0 else None

    query = """
        SELECT v.id as variant_id, v.raw_path, v.notes, s.entity_id, s.prompt, s.category
        FROM variants v
        JOIN sprites s ON v.sprite_id = s.id
        WHERE v.notes IS NOT NULL
          AND v.rating IS NOT NULL AND v.rating != -1
          AND v.raw_path IS NOT NULL
        ORDER BY v.id
    """

    results: list[dict] = []
    conn = sqlite3.connect(str(DB_PATH))
    try:
        conn.row_factory = sqlite3.Row
        # Iterate the cursor lazily so we can stop as soon as `cap` usable
        # rows have been collected.
        for r in conn.execute(query):
            path = Path(r["raw_path"])
            if not path.exists():
                continue

            try:
                sonnet_scores = json.loads(r["notes"])
            except (TypeError, ValueError):  # JSONDecodeError is a ValueError
                logger.warning(
                    "Skipping variant %s: notes are not valid JSON", r["variant_id"]
                )
                continue
            if not isinstance(sonnet_scores, dict):
                logger.warning(
                    "Skipping variant %s: notes JSON is not an object of scores",
                    r["variant_id"],
                )
                continue

            meta = _parse_entity_metadata(r["entity_id"])
            results.append({
                "variant_id": r["variant_id"],
                "raw_path": str(path),
                "entity_id": r["entity_id"],
                "prompt": r["prompt"],
                "category": r["category"],
                "sonnet_scores": sonnet_scores,
                "sonnet_dims": tuple(sorted(sonnet_scores.keys())),
                "race": meta.get("race", ""),
                "gender": meta.get("gender", ""),
                "base_unit": meta.get("base_unit", ""),
            })
            if cap is not None and len(results) >= cap:
                break
    finally:
        # The sqlite3 connection context manager only commits/rolls back; it
        # does not close, so close explicitly even if parsing fails mid-loop.
        conn.close()

    return results
|
|
|
|
|
|
# Sonnet's older 4-dimension rubric (perspective, composition,
# subject_accuracy, production_quality) mapped onto the closest local-scorer
# dimension. The correspondence is approximate. Dimensions not listed here
# are looked up under their own name.
_OLD_TO_LOCAL_MAP = {
    "perspective": "camera_angle",
    "composition": "composition",
    "subject_accuracy": "subject_type",
    "production_quality": "art_style",
}


def _score_variant(scorer: LocalSpriteScorer, item: dict):
    """Score one labeled variant with the local scorer and return ``.scores``.

    Variants whose category is ``"units"`` are scored with their race,
    gender, and a space-separated base-unit description. Every other
    category is scored with only the image path and entity id, so the
    scorer falls back to its generic queries.
    """
    kwargs = {"image_path": item["raw_path"], "entity_id": item["entity_id"]}
    if item["category"] == "units":
        kwargs.update(
            race=item["race"],
            gender=item["gender"],
            entity_description=item["base_unit"].replace("_", " "),
        )
    return scorer.score_image(**kwargs).scores


def run_calibration(
    device: str = "cuda:0",
    limit: int | None = None,
    *,
    disagreement_threshold: float = 0.3,
    top_n: int = 20,
) -> dict:
    """Score all labeled images with the local scorer and compare to Sonnet.

    Args:
        device: Device string passed to ``LocalSpriteScorer`` and to its
            ``load_sync`` call (e.g. ``"cuda:0"``).
        limit: Forwarded to ``load_sonnet_scores`` as its row limit.
        disagreement_threshold: A (variant, dimension) pair is recorded as a
            disagreement when ``|sonnet - local|`` exceeds this value.
        top_n: Number of largest disagreements to print and return
            (expected to be non-negative).

    Returns:
        ``{}`` when there is no labeled data. Otherwise a report dict with:

        * ``n_variants``
        * ``per_dimension``: per Sonnet dimension, the keys ``n`` (paired
          scores), ``n_missing_local``, ``pearson_r``, ``sonnet_avg``,
          ``local_avg``, and ``bias`` (local - sonnet). The statistics are
          ``None`` when a dimension has no pairs.
        * ``n_disagreements``
        * ``top_disagreements``: sorted by ``diff`` (absolute difference);
          each entry also carries the signed ``delta`` (local - sonnet).
        * ``n_missing_local``

        A human-readable version of the report is printed to stdout.
    """
    labeled = load_sonnet_scores(limit=limit)
    if not labeled:
        print("No labeled data found.")
        return {}

    print(f"Loaded {len(labeled)} Sonnet-scored variants with images on disk.")

    scorer = LocalSpriteScorer(device=device)
    scorer.load_sync(device=device)

    # One report row per Sonnet dimension seen anywhere in the dataset.
    all_dims = sorted({dim for item in labeled for dim in item["sonnet_scores"]})
    per_dim_sonnet: dict[str, list[float]] = {d: [] for d in all_dims}
    per_dim_local: dict[str, list[float]] = {d: [] for d in all_dims}
    missing_local: dict[str, int] = dict.fromkeys(all_dims, 0)
    disagreements: list[dict] = []

    for i, item in enumerate(labeled, start=1):
        local_scores = _score_variant(scorer, item)

        for dim, s_val in item["sonnet_scores"].items():
            local_dim = _OLD_TO_LOCAL_MAP.get(dim, dim)
            if local_dim not in local_scores:
                # The local scorer produced nothing comparable. A placeholder
                # value would inject fabricated points into the correlation,
                # the averages, and the disagreement list, so count the gap.
                missing_local[dim] += 1
                continue
            l_val = local_scores[local_dim]

            per_dim_sonnet[dim].append(s_val)
            per_dim_local[dim].append(l_val)

            delta = l_val - s_val
            if abs(delta) > disagreement_threshold:
                disagreements.append({
                    "variant_id": item["variant_id"],
                    "entity_id": item["entity_id"],
                    "dimension": dim,
                    "local_dimension": local_dim,
                    "sonnet": s_val,
                    "local": l_val,
                    "diff": abs(delta),
                    "delta": delta,
                })

        if i % 10 == 0:
            print(f" Scored {i}/{len(labeled)}...")

    # Per-dimension statistics
    report: dict[str, dict] = {}
    rule = "=" * 80
    print(f"\n{rule}")
    print(f"CALIBRATION REPORT — {len(labeled)} variants")
    print(rule)
    print(
        f"\n{'Dimension':<25} {'n':>5} {'Pearson r':>10} "
        f"{'Sonnet avg':>12} {'Local avg':>12} {'Bias':>8}"
    )
    print("-" * 77)

    for dim in all_dims:
        s_vals = per_dim_sonnet[dim]
        l_vals = per_dim_local[dim]
        n = len(s_vals)

        if n == 0:
            # statistics.mean() raises on empty data; report the gap instead.
            report[dim] = {
                "pearson_r": None,
                "sonnet_avg": None,
                "local_avg": None,
                "bias": None,
                "n": 0,
                "n_missing_local": missing_local[dim],
            }
            print(f"? {dim:<23} {0:>5}   no local scores ({missing_local[dim]} skipped)")
            continue

        r = _pearson_r(s_vals, l_vals)
        s_avg = statistics.mean(s_vals)
        l_avg = statistics.mean(l_vals)
        bias = l_avg - s_avg

        report[dim] = {
            "pearson_r": round(r, 3),
            "sonnet_avg": round(s_avg, 3),
            "local_avg": round(l_avg, 3),
            "bias": round(bias, 3),
            "n": n,
            "n_missing_local": missing_local[dim],
        }

        if n < 3:
            # _pearson_r returns a 0.0 placeholder below three pairs, so the
            # correlation is inconclusive rather than poor.
            indicator = "?"
        elif r > 0.7:
            indicator = "✓"
        elif r > 0.4:
            indicator = "~"
        else:
            indicator = "✗"
        print(
            f"{indicator} {dim:<23} {n:>5} {r:>10.3f} "
            f"{s_avg:>12.3f} {l_avg:>12.3f} {bias:>+8.3f}"
        )

    n_missing_total = sum(missing_local.values())
    if n_missing_total:
        print(
            f"\nSkipped {n_missing_total} (variant, dimension) pairs "
            f"with no matching local score."
        )

    # Top disagreements, largest absolute difference first
    disagreements.sort(key=lambda d: d["diff"], reverse=True)
    top = disagreements[:top_n]
    print(f"\n{rule}")
    print(
        f"TOP DISAGREEMENTS (|diff| > {disagreement_threshold}) — "
        f"{len(disagreements)} total"
    )
    print(rule)
    for d in top:
        print(
            f" #{d['variant_id']} {d['entity_id']:30s} {d['dimension']:25s} "
            f"sonnet={d['sonnet']:.2f} local={d['local']:.2f} diff={d['delta']:+.2f}"
        )

    return {
        "n_variants": len(labeled),
        "per_dimension": report,
        "n_disagreements": len(disagreements),
        "top_disagreements": top,
        "n_missing_local": n_missing_total,
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, enable INFO logging, and run
    # one calibration pass.
    import argparse

    logging.basicConfig(level=logging.INFO)

    cli = argparse.ArgumentParser(
        description="Calibrate local scorer against Sonnet",
    )
    cli.add_argument(
        "--device",
        default="cuda:0",
        help="GPU device (default: cuda:0)",
    )
    cli.add_argument(
        "--limit",
        type=int,
        help="Max variants to score (for quick testing)",
    )
    options = cli.parse_args()

    run_calibration(device=options.device, limit=options.limit)
|