462 lines
17 KiB
Python
462 lines
17 KiB
Python
#!/usr/bin/env python3
|
||
"""Sprite generation pipeline CLI for Magic Civilization.

Usage:
    python3 tools/sprite-generation/cli.py scan
    python3 tools/sprite-generation/cli.py status
    python3 tools/sprite-generation/cli.py generate --category terrain --variants 8
    python3 tools/sprite-generation/cli.py rank
    python3 tools/sprite-generation/cli.py run
    python3 tools/sprite-generation/cli.py start --port 5850
    python3 tools/sprite-generation/cli.py install
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Directory holding this CLI together with its config and working files.
TOOL_DIR = Path(__file__).resolve().parent
# Repository root: the script lives two levels down (tools/sprite-generation/).
PROJECT = TOOL_DIR.parent.parent
# Search TOOL_DIR first so the lazily imported `engine` / `server` modules resolve here.
sys.path.insert(0, str(TOOL_DIR))

# Files and directories kept alongside the CLI.
CONFIG_PATH = TOOL_DIR / "sprite-config.json"
DB_PATH = TOOL_DIR / "sprites.db"
RAW_DIR = TOOL_DIR / "raw"
VARIANTS_DIR = TOOL_DIR / "variants"
HEX_MASK_PATH = TOOL_DIR / "hex_mask.png"

# Game-side locations: scan input data and the install target for assets.
_GAME_DIR = PROJECT / "games" / "age-of-dwarves"
LOCAL_DATA = _GAME_DIR / "data"
DEMO_DATA = TOOL_DIR / "demo-data"
ASSETS_DIR = _GAME_DIR / "assets"
|
||
|
||
|
||
def _data_dir(args: argparse.Namespace) -> Path:
|
||
if hasattr(args, "data_dir") and args.data_dir:
|
||
return Path(args.data_dir)
|
||
if hasattr(args, "demo") and args.demo:
|
||
return DEMO_DATA
|
||
return LOCAL_DATA
|
||
|
||
|
||
def _load_config() -> dict:
|
||
return json.loads(CONFIG_PATH.read_text())
|
||
|
||
|
||
def _registry():
    """Return a ``SpriteRegistry`` opened on ``DB_PATH``.

    The engine module is imported inside the function, like the other
    ``engine.*`` imports in this file.
    """
    from engine.registry import SpriteRegistry as _Registry

    registry = _Registry(DB_PATH)
    return registry
|
||
|
||
|
||
def cmd_scan(args: argparse.Namespace) -> None:
    """Scan game data and populate the sprite registry.

    Reads from the directory chosen by ``_data_dir`` (``--data-dir`` /
    ``--demo``). With ``--demo`` the scanner is told to skip the biome-grid
    and UI passes. Prints per-category counts plus new/existing totals for
    sprites and dimensions from the scan report.
    """
    from engine.scanner import SpriteScanner

    reg = _registry()
    data = _data_dir(args)
    print(f"Scanning data from: {data}")

    scanner = SpriteScanner(data_dir=data, assets_dir=ASSETS_DIR, registry=reg)
    # Demo data is minimal (see --demo help), so skip the biome-grid and UI passes.
    scan_options = (
        {"skip_biome_grid": True, "skip_ui": True}
        if getattr(args, "demo", False)
        else {}
    )
    report = scanner.scan_all(**scan_options)

    print("\nBy category:")
    for cat, count in sorted(report.categories.items()):
        print(f" {cat:20s} {count:4d} sprites")
    print(f"\n {'TOTAL':20s} {report.new_sprites:4d} new, {report.existing_sprites} existing")
    print(f" {'DIMENSIONS':20s} {report.new_dimensions:4d} new, {report.existing_dimensions} existing")
|
||
|
||
|
||
# (status key, column width) for each status column of the `status` table, in
# display order; a TOTAL column (width 8) always follows.
_STATUS_COLUMNS = (
    ("needed", 8),
    ("generating", 12),
    ("review", 8),
    ("approved", 10),
    ("installed", 10),
    ("rejected", 10),
    ("skip", 6),
)


def _format_status_row(label: str, counts: dict) -> str:
    """Render one status-table row: label, per-status counts, and the row total."""
    cells = [f"{label:<20s}"]
    cells += [f"{counts.get(status, 0):>{width}d}" for status, width in _STATUS_COLUMNS]
    # The total sums every status in the row, including ones without a column.
    cells.append(f"{sum(counts.values()):>8d}")
    return " ".join(cells)


def cmd_status(args: argparse.Namespace) -> None:
    """Print sprite counts per category and status, followed by a TOTAL row.

    Prints a hint to run ``scan`` when the registry has no categories yet.
    """
    reg = _registry()
    stats = reg.get_stats()

    if not stats["by_category"]:
        print("No sprites in registry. Run 'scan' first.")
        return

    header = " ".join(
        [f"{'Category':<20s}"]
        + [f"{status:>{width}s}" for status, width in _STATUS_COLUMNS]
        + [f"{'TOTAL':>8s}"]
    )
    # Derive the rule from the header so the two never drift apart.
    rule = "─" * len(header)

    print(f"\n{header}")
    print(rule)
    for cat in sorted(stats["by_category"]):
        print(_format_status_row(cat, stats["by_category"][cat]))
    print(rule)
    print(_format_status_row("TOTAL", stats["total"]))
|
||
|
||
|
||
def cmd_generate(args: argparse.Namespace) -> None:
    """Submit image-generation jobs for sprites in 'needed' status.

    Candidates are narrowed by ``--category`` and ``--sprite``. ``--max`` caps
    how many sprites are submitted (up to 10000 when unset). Each sprite gets
    ``--variants`` jobs at ``--priority``. ``--pose-ref`` / ``--strength`` are
    passed through for img2img. ``--dry-run`` only lists the planned jobs.
    """
    import asyncio

    from engine.generator import SpriteGenerator

    reg = _registry()
    limit = args.max or 10000

    if args.sprite:
        # Match the requested id before applying --max: using --max as the
        # query limit first could drop the sprite before it is ever compared.
        candidates = reg.get_sprites(category=args.category, status="needed", limit=10000)
        sprites = [s for s in candidates if s["id"] == args.sprite][:limit]
    else:
        sprites = reg.get_sprites(category=args.category, status="needed", limit=limit)

    if not sprites:
        print("No sprites in 'needed' status matching filters.")
        return

    sprite_ids = [s["id"] for s in sprites]

    if args.dry_run:
        preview = 50  # how many ids to list before summarising the rest
        print(f"Would generate {len(sprite_ids)} sprites × {args.variants} variants = {len(sprite_ids) * args.variants} jobs\n")
        for sid in sprite_ids[:preview]:
            print(f" {sid}")
        if len(sprite_ids) > preview:
            print(f" ... and {len(sprite_ids) - preview} more")
        return

    # Only a real run needs the config and a generator; a dry run works without them.
    gen = SpriteGenerator(config=_load_config(), registry=reg, raw_dir=RAW_DIR)
    pose_ref = Path(args.pose_ref) if args.pose_ref else None
    completed = asyncio.run(gen.generate_batch(
        sprite_ids=sprite_ids,
        variants_per=args.variants,
        priority=args.priority,
        pose_reference=pose_ref,
        img2img_strength=args.strength,
    ))
    print(f"\nGenerated {completed} images. Run 'rank' to score them.")
|
||
|
||
|
||
# Confidence at or above which `rank --sprite` marks a variant with ✓.
# NOTE(review): display-only; presumably mirrors the ranker's own "good"
# cut-off — confirm against engine.ranker.
_GOOD_CONFIDENCE = 0.7


def cmd_rank(args: argparse.Namespace) -> None:
    """AI-rank generated variants.

    With ``--sprite``, rank that one sprite and list every variant's confidence
    and per-dimension scores, plus the deficit if it needs regeneration.
    Otherwise rank all sprites in 'review' status (optionally limited to
    ``--category``) and print a ready / need-regen summary.
    """
    import asyncio

    from engine.ranker import SpriteRanker

    reg = _registry()
    ranker = SpriteRanker(registry=reg, raw_dir=RAW_DIR)

    if not args.sprite:
        print("Ranking all sprites in 'review' status...")
        summary = asyncio.run(ranker.rank_all_review(category=args.category))
        print(f"\nSummary: {summary['ready']} ready, {summary['need_regen']} need re-generation (of {summary['total']} total)")
        return

    result = asyncio.run(ranker.rank_and_filter(args.sprite))
    print(f"\n{args.sprite}: {result['good_count']}/{len(result['ranked'])} good variants")
    for r in result["ranked"]:
        flag = "✓" if r["confidence"] >= _GOOD_CONFIDENCE else "✗"
        dims = " ".join(f"{k}={v:.2f}" for k, v in r["scores"].items())
        print(f" {flag} variant {r['variant_id']} (seed {r['seed']}): confidence={r['confidence']:.2f} {dims}")
    if result["needs_regen"]:
        print(f"\n Needs {result['deficit']} more good variants — re-generate with higher priority")
|
||
|
||
|
||
def cmd_install(args: argparse.Namespace) -> None:
    """Install approved sprites (optionally one ``--category``) into the game assets.

    ``--dry-run`` is forwarded to the installer, and the report then reads
    "Would install" instead of "Installed".
    """
    from engine.installer import SpriteInstaller

    installer = SpriteInstaller(assets_dir=ASSETS_DIR, registry=_registry())
    installed = installer.install_approved(category=args.category, dry_run=args.dry_run)
    verb = "Would install" if args.dry_run else "Installed"
    print(f"\n{verb}: {installed} sprites")
|
||
|
||
|
||
def cmd_approve(args: argparse.Namespace) -> None:
    """Approve one variant and ship it through the pipeline (process → install → manifest).

    ``--alt`` installs it under an alternate name. The game-side ``sprites.db``
    is written in the game data directory. That directory honours the global
    ``--data-dir`` override and otherwise defaults to ``LOCAL_DATA``.
    Prints what was shipped, or a failure notice when the pipeline returns
    nothing.
    """
    from engine.pipeline import SpritePipeline

    reg = _registry()
    # --data-dir is documented as overriding the *game* data directory, which is
    # where the game's sprites.db lives. --demo only swaps the scan input, so it
    # is not consulted here.
    game_data = Path(args.data_dir) if getattr(args, "data_dir", None) else LOCAL_DATA
    pipeline = SpritePipeline(
        registry=reg,
        raw_dir=RAW_DIR,
        variants_dir=VARIANTS_DIR,
        assets_dir=ASSETS_DIR,
        game_db_path=game_data / "sprites.db",
    )
    result = pipeline.approve_and_install(args.variant, alt_name=args.alt)
    if result:
        print(f"\nShipped: {result}")
    else:
        print("\nFailed to install.")
|
||
|
||
|
||
def cmd_reset(args: argparse.Namespace) -> None:
    """Set the sprite named by ``--sprite`` back to 'needed' status.

    'needed' is the status ``generate`` and ``run`` pick work from. Without a
    sprite id, a usage message is printed instead.
    """
    reg = _registry()
    sprite_id = args.sprite
    if sprite_id:
        reg.update_sprite_status(sprite_id, "needed")
        print(f"Reset {sprite_id} to 'needed'")
    else:
        print("--sprite is required for reset")
|
||
|
||
|
||
def cmd_review(args: argparse.Namespace) -> None:
    """Serve the review GUI on ``args.port``; ``uvicorn.run`` blocks until shutdown.

    If the ``server`` module or ``uvicorn`` cannot be imported, prints the
    error and an install hint and returns. Errors raised while building or
    running the app propagate with their real traceback.
    """
    print(f"Starting review GUI server on port {args.port}...")
    print(f"Open http://localhost:{args.port} in your browser")
    # Keep the try to the imports only: an ImportError raised later (inside the
    # app) must not be misreported as missing fastapi/uvicorn.
    try:
        from server import create_app
        import uvicorn
    except ImportError as e:
        print(f"Error: {e}")
        print("Install dependencies: pip install fastapi uvicorn")
        return

    app = create_app(
        registry=_registry(),
        raw_dir=RAW_DIR,
        variants_dir=VARIANTS_DIR,
    )
    # NOTE(review): 0.0.0.0 makes the review/approve UI reachable from other
    # hosts; use 127.0.0.1 if only local access is intended.
    uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||
|
||
|
||
# Max times `run` regenerates one sprite (within a single process) before it is
# left in 'review' so a human can pick from whatever variants exist.
_MAX_REGEN_ATTEMPTS = 5


def _start_review_server(registry, port: int) -> None:
    """Serve the review GUI on ``port`` from a daemon thread (it dies with the CLI)."""
    import threading

    def _serve() -> None:
        from server import create_app
        import uvicorn

        app = create_app(registry=registry, raw_dir=RAW_DIR, variants_dir=VARIANTS_DIR)
        uvicorn.run(app, host="0.0.0.0", port=port, log_level="warning")

    # NOTE(review): the server thread and the pipeline loop share `registry`
    # (including its DB connection); confirm SpriteRegistry tolerates use from
    # both threads at once.
    threading.Thread(target=_serve, daemon=True).start()


def _next_sprite_to_generate(reg, category, regen_counts: dict[str, int]):
    """Return the first of up to 10 'needed' sprites still under the regen cap, else None.

    Capped sprites found along the way are moved to 'review'.
    """
    for candidate in reg.get_sprites(category=category, status="needed", limit=10) or []:
        sid = candidate["id"]
        if regen_counts.get(sid, 0) < _MAX_REGEN_ATTEMPTS:
            return candidate
        # Out of attempts: move to review so a human can pick from what exists.
        reg.update_sprite_status(sid, "review")
        print(f" {sid}: max regen attempts reached, moving to review")
    return None


def _rank_fresh_sprite(reg, ranker, sprite_id: str, attempt: int) -> None:
    """Rank a just-generated sprite, requeue it or hand it to review, and show its best variant."""
    import asyncio

    print(f" Ranking {sprite_id}...", end=" ", flush=True)
    result = asyncio.run(ranker.rank_and_filter(sprite_id))
    good = result["good_count"]
    total = len(result["ranked"])
    if not result["needs_regen"]:
        print(f"{good}/{total} good — ready for review")
    elif attempt >= _MAX_REGEN_ATTEMPTS:
        reg.update_sprite_status(sprite_id, "review")
        print(f"{good}/{total} good — max attempts, moving to review anyway")
    else:
        reg.update_sprite_status(sprite_id, "needed")
        print(f"{good}/{total} good — needs regen (attempt {attempt})")

    # Score breakdown for the first entry of `ranked` (treated as the best variant).
    if result["ranked"]:
        best = result["ranked"][0]
        dims = " ".join(f"{k}={v:.0%}" for k, v in best["scores"].items())
        print(f" Best #{best['variant_id']}: {best['confidence']:.0%} — {dims}")


def _rank_unranked_sprites(reg, ranker, batch_size: int, regen_counts: dict[str, int], loop_count: int) -> None:
    """Rank up to ``batch_size`` 'review' sprites that have no completed variant with notes."""
    import asyncio

    unranked = reg.conn.execute("""
        SELECT DISTINCT s.id FROM sprites s
        WHERE s.status = 'review'
        AND NOT EXISTS (
            SELECT 1 FROM variants v
            WHERE v.sprite_id = s.id
            AND v.job_status = 'completed'
            AND v.notes IS NOT NULL
        )
        LIMIT ?
    """, (batch_size,)).fetchall()
    if not unranked:
        return

    print(f"\n[loop {loop_count}] Ranking {len(unranked)} unranked sprites...")
    for row in unranked:
        sid = row["id"]
        result = asyncio.run(ranker.rank_and_filter(sid))
        good = result["good_count"]
        total = len(result["ranked"])
        if not result["needs_regen"]:
            print(f" {sid}: {good}/{total} good — ready for review")
        elif regen_counts.get(sid, 0) >= _MAX_REGEN_ATTEMPTS:
            # Already out of attempts: requeueing would only bounce it back to
            # review next loop (and get re-ranked again), forever.
            print(f" {sid}: {good}/{total} good — max attempts, leaving in review")
        else:
            reg.update_sprite_status(sid, "needed")
            print(f" {sid}: {good}/{total} good — needs regen")


def _print_run_status(reg, loop_count: int) -> None:
    """Print the registry-wide needed / review / done counts for this loop."""
    total_row = reg.get_stats()["total"]
    needed_count = total_row.get("needed", 0)
    review_count = total_row.get("review", 0)
    done_count = total_row.get("approved", 0) + total_row.get("installed", 0)
    total_count = sum(total_row.values())
    print(f"\n[loop {loop_count}] Status: {needed_count} needed, {review_count} review, {done_count} done / {total_count} total")
    if needed_count == 0 and review_count == 0:
        print("\nAll sprites processed. Waiting for new work...")


def cmd_run(args: argparse.Namespace) -> None:
    """Run the generate → rank → regen loop continuously, with the review GUI.

    The GUI is served on ``--port`` in a background thread; a human reviews and
    approves there. Each loop does the following:

    - Generates ``--variants`` variants for one 'needed' sprite and ranks it
      right away.
    - Ranks up to ``--batch`` 'review' sprites whose completed variants carry
      no notes.
    - Prints status, then sleeps: 5s after generating, 30s when idle.

    A sprite is regenerated at most ``_MAX_REGEN_ATTEMPTS`` times per process
    before being left in review. The loop never ends on its own; stop it with
    Ctrl+C.
    """
    import asyncio
    import time

    from engine.generator import SpriteGenerator
    from engine.ranker import SpriteRanker

    reg = _registry()
    config = _load_config()
    gen = SpriteGenerator(config=config, registry=reg, raw_dir=RAW_DIR)
    ranker = SpriteRanker(registry=reg, raw_dir=RAW_DIR)

    pose_ref = Path(args.pose_ref) if args.pose_ref else None
    batch_size = args.batch or 4
    variants_per = args.variants

    _start_review_server(reg, args.port)
    print(f"Review GUI: http://localhost:{args.port}/?spriteTheater=true")

    regen_counts: dict[str, int] = {}
    loop_count = 0
    # Handle Ctrl+C around the whole loop, not only the sleep, so interrupting a
    # generation or ranking call also stops cleanly.
    try:
        while True:
            loop_count += 1

            # --- Phase 1: generate ONE sprite and rank it (tight feedback loop) ---
            sprite = _next_sprite_to_generate(reg, args.category, regen_counts)
            if sprite:
                sprite_id = sprite["id"]
                attempt = regen_counts.get(sprite_id, 0) + 1
                regen_counts[sprite_id] = attempt
                print(f"\n[loop {loop_count}] Generating {sprite_id} × {variants_per} variants (attempt {attempt}/{_MAX_REGEN_ATTEMPTS})...")
                completed = asyncio.run(gen.generate_batch(
                    sprite_ids=[sprite_id],
                    variants_per=variants_per,
                    priority="high",
                    pose_reference=pose_ref,
                    img2img_strength=args.strength,
                ))
                print(f" Generated {completed} images for {sprite_id}")
                _rank_fresh_sprite(reg, ranker, sprite_id, attempt)

            # --- Phase 2: rank any remaining unranked sprites ---
            _rank_unranked_sprites(reg, ranker, batch_size, regen_counts, loop_count)

            # --- Status ---
            _print_run_status(reg, loop_count)

            # Loop quickly while there is generation work; back off when idle.
            pause = 5 if sprite else 30
            print(f" Next loop in {pause}s... (Ctrl+C to stop)")
            time.sleep(pause)
    except KeyboardInterrupt:
        print("\nStopping pipeline.")
|
||
|
||
|
||
# Upper bound on rows fetched by `export`.
_EXPORT_LIMIT = 100000


def cmd_export(args: argparse.Namespace) -> None:
    """Write every registered sprite to stdout as CSV (default) or JSON.

    stdout carries only the data so it can be redirected to a file; notices
    (empty registry, row cap reached) go to stderr. CSV columns follow the key
    order of the first record.
    """
    import csv

    sprites = _registry().get_sprites(limit=_EXPORT_LIMIT)
    if len(sprites) >= _EXPORT_LIMIT:
        print(
            f"Warning: export capped at {_EXPORT_LIMIT} sprites; output may be incomplete.",
            file=sys.stderr,
        )

    if args.format == "json":
        print(json.dumps(sprites, indent=2))
        return

    if not sprites:
        print("No sprites in registry.", file=sys.stderr)
        return

    writer = csv.DictWriter(sys.stdout, fieldnames=list(sprites[0].keys()))
    writer.writeheader()
    writer.writerows(sprites)
|
||
|
||
|
||
# Default port for the review GUI (`run`, `start`, and the no-subcommand case).
_DEFAULT_PORT = 5850


def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main() -> None:
    """Parse the command line and dispatch to the matching ``cmd_*`` handler.

    With no subcommand, the review GUI is started on the default port.
    """
    parser = argparse.ArgumentParser(
        description="Sprite generation pipeline for Magic Civilization",
        prog="sprite-gen",
    )
    parser.add_argument("--data-dir", type=str, help="Override game data directory path")
    parser.add_argument("--demo", action="store_true", help="Use minimal demo data (1 per domain)")
    sub = parser.add_subparsers(dest="command")

    # run — full pipeline orchestrator
    p = sub.add_parser("run", help="Run full pipeline: generate → rank → regen loop + GUI")
    p.add_argument("--port", type=int, default=_DEFAULT_PORT, help=f"GUI server port (default: {_DEFAULT_PORT})")
    p.add_argument("--category", type=str, help="Only process one category")
    p.add_argument("--variants", type=_positive_int, default=4, help="Variants per sprite (default: 4)")
    p.add_argument("--batch", type=_positive_int, default=4, help="Sprites per batch (default: 4)")
    p.add_argument("--pose-ref", type=str, help="Pose reference image for img2img")
    p.add_argument("--strength", type=float, default=0.6, help="img2img strength (default: 0.6)")
    p.set_defaults(func=cmd_run)

    # start — launch GUI server only
    p = sub.add_parser("start", help="Launch review GUI server")
    p.add_argument("--port", type=int, default=_DEFAULT_PORT, help=f"Server port (default: {_DEFAULT_PORT})")
    p.set_defaults(func=cmd_review)

    # scan
    p = sub.add_parser("scan", help="Scan game data, populate sprite registry")
    p.set_defaults(func=cmd_scan)

    # status
    p = sub.add_parser("status", help="Show sprite counts by category and status")
    p.set_defaults(func=cmd_status)

    # generate
    p = sub.add_parser("generate", help="Submit generation jobs to model-boss")
    p.add_argument("--category", type=str, help="Only generate one category")
    p.add_argument("--sprite", type=str, help="Generate single sprite by ID")
    p.add_argument("--variants", type=_positive_int, default=8, help="Variants per sprite (default: 8)")
    p.add_argument("--priority", choices=["low", "normal", "high"], default="normal")
    p.add_argument("--max", type=_positive_int, help="Max sprites to generate")
    p.add_argument("--dry-run", action="store_true", help="Show what would be generated")
    p.add_argument("--pose-ref", type=str, help="Path to pose reference image for img2img (consistent facing)")
    p.add_argument("--strength", type=float, default=0.75, help="img2img denoising strength (0=copy, 1=ignore ref)")
    p.set_defaults(func=cmd_generate)

    # rank
    p = sub.add_parser("rank", help="AI-rank variants using Haiku vision")
    p.add_argument("--category", type=str, help="Only rank one category")
    p.add_argument("--sprite", type=str, help="Rank single sprite by ID")
    p.set_defaults(func=cmd_rank)

    # install
    p = sub.add_parser("install", help="Copy approved sprites to game assets")
    p.add_argument("--category", type=str, help="Only install one category")
    p.add_argument("--dry-run", action="store_true")
    p.set_defaults(func=cmd_install)

    # approve
    p = sub.add_parser("approve", help="Approve variant → process → install → manifest")
    p.add_argument("variant", type=int, help="Variant ID to approve (e.g. 129)")
    p.add_argument("--alt", type=str, help="Install as alternate (e.g. --alt front)")
    p.set_defaults(func=cmd_approve)

    # reset
    p = sub.add_parser("reset", help="Reset sprite back to needed status")
    p.add_argument("--sprite", type=str, required=True, help="Sprite ID to reset")
    p.set_defaults(func=cmd_reset)

    # export
    p = sub.add_parser("export", help="Export registry as CSV or JSON")
    p.add_argument("--format", choices=["csv", "json"], default="csv")
    p.set_defaults(func=cmd_export)

    args = parser.parse_args()
    if args.command is None:
        # No subcommand: behave like `start` on the default port.
        args.port = _DEFAULT_PORT
        cmd_review(args)
    else:
        args.func(args)


if __name__ == "__main__":
    main()
|