magicciv/tools/sprite-generation/engine/worker.py
Natalie e12307b43d feat(@projects/@magic-civilization): 🖥️ sprite-gen worker preferences + operations/coverage GUI updates
Adds a GUI preferences module and refines the worker engine, operations panel,
workers page, and coverage page.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-21 07:59:40 -05:00

291 lines
No EOL
10 KiB
Python

"""Background generation worker — submit → collect → score loop.
Runs in a daemon thread from the FastAPI server so the GUI can drive the
full pipeline without CLI commands.
"""
from __future__ import annotations
import asyncio
import json
import threading
from pathlib import Path
from typing import IO
from engine.factory import BACKENDS, create_generator, with_backend
from engine.ranker import SpriteRanker
from engine.registry import SpriteRegistry
# Directory two levels above this file; ``sprite-config.json`` is loaded from here.
TOOL_DIR = Path(__file__).resolve().parents[1]
class GenerationWorker:
    """Thread-hosted asyncio loop for generate → collect → score.

    ``start`` spawns a daemon thread that runs its own asyncio event loop,
    ``stop`` asks that loop to exit once the current iteration finishes, and
    ``status`` reports progress counts together with the tail of the log file.
    """

    def __init__(self, registry: SpriteRegistry, raw_dir: Path, log_path: Path):
        """Store collaborators; no thread or file is created until ``start``.

        Args:
            registry: Sprite registry queried and updated by the worker.
            raw_dir: Directory handed to the generator and the ranker.
            log_path: File that receives a copy of every log line.
        """
        self.registry = registry
        self.raw_dir = raw_dir
        self.log_path = log_path
        self._thread: threading.Thread | None = None
        self._stop = threading.Event()
        self._log_file: IO[str] | None = None

    @property
    def running(self) -> bool:
        """Whether the worker thread exists and is still alive."""
        return self._thread is not None and self._thread.is_alive()

    def _log(self, msg: str) -> None:
        """Print *msg* (trailing whitespace stripped) and mirror it to the log file."""
        line = msg.rstrip()
        print(line, flush=True)
        if self._log_file:
            self._log_file.write(line + "\n")
            # Flush per line so ``status`` can read the tail while the worker runs.
            self._log_file.flush()

    def start(
        self,
        *,
        backend: str | None = None,
        variants: int = 3,
        category: str | None = None,
        starter_only: bool = False,
        max_attempts: int = 15,
        batch_size: int = 4,
    ) -> dict:
        """Launch the worker thread unless one is already running.

        Args:
            backend: Optional backend override; must be a key of ``BACKENDS``.
            variants: Variants requested per submitted sprite.
            category: Optional category filter for registry queries.
            starter_only: Restrict work to the starter sprite ids.
            max_attempts: Maximum number of times a sprite is sent back from
                review to be regenerated.
            batch_size: Maximum number of sprites submitted per loop iteration.

        Returns:
            ``{"status": "started"}`` on success, ``{"status": "already_running"}``
            if a worker thread is alive, or ``{"status": "error", "message": ...}``
            for an unknown backend.
        """
        if self.running:
            return {"status": "already_running"}
        if backend and backend not in BACKENDS:
            return {"status": "error", "message": f"Unknown backend {backend!r}"}

        self._stop.clear()
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        # Mode "w" truncates, so every run starts with an empty log. The handle
        # outlives this call; the worker thread closes it in ``_thread_main``.
        self._log_file = open(self.log_path, "w", encoding="utf-8")
        self._thread = threading.Thread(
            target=self._thread_main,
            args=(backend, variants, category, starter_only, max_attempts, batch_size),
            daemon=True,
        )
        try:
            self._thread.start()
        except Exception:
            # The thread never ran, so its ``finally`` will not close the file.
            self._log_file.close()
            self._log_file = None
            self._thread = None
            raise
        return {"status": "started"}

    def stop(self) -> dict:
        """Signal the worker loop to exit; does not wait for the thread to finish."""
        if not self.running:
            return {"status": "not_running"}
        self._stop.set()
        return {"status": "stopping"}

    def status(self) -> dict:
        """Return running state, registry counts, and the last 30 log lines."""
        log_tail = ""
        if self.log_path.exists():
            lines = self.log_path.read_text(encoding="utf-8", errors="replace").splitlines()
            log_tail = "\n".join(lines[-30:])

        queued = self.registry.conn.execute(
            "SELECT COUNT(*) FROM variants WHERE job_status = 'submitted'"
        ).fetchone()[0]
        stats = self.registry.get_stats()
        total_row = stats.get("total", {})
        return {
            "running": self.running,
            "queued_variants": queued,
            "needed": total_row.get("needed", 0),
            "review": total_row.get("review", 0),
            "approved": total_row.get("approved", 0),
            "installed": total_row.get("installed", 0),
            "log_tail": log_tail,
        }

    def _thread_main(
        self,
        backend: str | None,
        variants: int,
        category: str | None,
        starter_only: bool,
        max_attempts: int,
        batch_size: int,
    ) -> None:
        """Thread target: run ``_run_loop`` in a new event loop, then close the log.

        Any exception escaping the loop is logged rather than propagated.
        """
        try:
            asyncio.run(
                self._run_loop(
                    backend=backend,
                    variants=variants,
                    category=category,
                    starter_only=starter_only,
                    max_attempts=max_attempts,
                    batch_size=batch_size,
                )
            )
        except Exception as exc:
            self._log(f"[worker-error] {type(exc).__name__}: {exc}")
        finally:
            if self._log_file:
                self._log_file.close()
                self._log_file = None

    async def _run_loop(
        self,
        *,
        backend: str | None,
        variants: int,
        category: str | None,
        starter_only: bool,
        max_attempts: int,
        batch_size: int,
    ) -> None:
        """Build the generator and ranker, then iterate until ``stop`` is requested."""
        # Read the config as UTF-8 explicitly instead of the platform default.
        config = json.loads((TOOL_DIR / "sprite-config.json").read_text(encoding="utf-8"))
        if backend:
            config = with_backend(config, backend)
        gen = create_generator(config=config, registry=self.registry, raw_dir=self.raw_dir)
        ranker = SpriteRanker(registry=self.registry, raw_dir=self.raw_dir)
        # sprite id -> number of times it was sent back from review for regeneration.
        regen_counts: dict[str, int] = {}

        starter_ids = None
        if starter_only:
            from engine.starter import starter_sprite_ids

            starter_ids = set(starter_sprite_ids())

        self._log(f"Worker started — {config['backend']} backend, {variants} variants/job")

        async def _on_complete(variant_id: int, sprite_id: str) -> None:
            await ranker.advance_sprite(sprite_id)

        loop_count = 0
        while not self._stop.is_set():
            loop_count += 1
            try:
                await self._run_loop_iteration(
                    loop_count=loop_count,
                    config=config,
                    gen=gen,
                    ranker=ranker,
                    regen_counts=regen_counts,
                    variants=variants,
                    category=category,
                    starter_only=starter_only,
                    starter_ids=starter_ids,
                    batch_size=batch_size,
                    max_attempts=max_attempts,
                    on_complete=_on_complete,
                )
            except Exception as exc:
                # A failed iteration skipped its own trailing sleep, so pause
                # here before retrying.
                self._log(f"[loop {loop_count} error] {type(exc).__name__}: {exc}")
                await asyncio.sleep(5)
        self._log("Worker stopped.")

    def _select_sprites_to_submit(
        self,
        *,
        regen_counts: dict[str, int],
        category: str | None,
        starter_only: bool,
        starter_ids: set[str] | None,
        batch_size: int,
        max_attempts: int,
    ) -> list[str]:
        """Pick the sprite ids to submit this iteration.

        In starter mode the ``needed`` starter sprites are ordered by how few
        variants they have. Otherwise ``needed`` sprites are taken from the
        registry, skipping those with a submitted variant; a sprite that has
        used up ``max_attempts`` is moved to ``review`` instead of submitted.
        """
        # NOTE(review): an empty ``starter_ids`` with ``starter_only`` set falls
        # through to the generic path below — confirm that is intended.
        if starter_only and starter_ids:
            placeholders = ",".join("?" * len(starter_ids))
            needed_rows = self.registry.conn.execute(
                f"""
                SELECT s.id FROM sprites s
                LEFT JOIN variants v ON v.sprite_id = s.id
                WHERE s.id IN ({placeholders})
                AND s.status = 'needed'
                GROUP BY s.id
                ORDER BY COUNT(v.id) ASC, s.id ASC
                """,
                list(starter_ids),
            ).fetchall()
            return [row[0] for row in needed_rows[:batch_size]]

        # Over-fetch so that skipping in-flight sprites can still fill a batch.
        needed = self.registry.get_sprites(
            category=category,
            status="needed",
            limit=batch_size * 4,
        ) or []

        in_flight_ids: set[str] = set()
        if needed:
            placeholders = ",".join("?" * len(needed))
            ids = [s["id"] for s in needed]
            rows = self.registry.conn.execute(
                f"SELECT DISTINCT sprite_id FROM variants "
                f"WHERE sprite_id IN ({placeholders}) AND job_status='submitted'",
                ids,
            ).fetchall()
            in_flight_ids = {r[0] for r in rows}

        to_submit: list[str] = []
        for candidate in needed:
            sid = candidate["id"]
            if sid in in_flight_ids:
                continue
            if regen_counts.get(sid, 0) < max_attempts:
                to_submit.append(sid)
            else:
                self.registry.update_sprite_status(sid, "review")
                self._log(f" {sid}: max regen attempts, moving to review")
            if len(to_submit) >= batch_size:
                break
        return to_submit

    async def _rerank_review_sprites(
        self,
        *,
        ranker,
        regen_counts: dict[str, int],
        category: str | None,
        starter_ids: set[str] | None,
        max_attempts: int,
    ) -> None:
        """Re-rank sprites in ``review``; requeue those needing regeneration.

        A sprite is sent back to ``needed`` at most ``max_attempts`` times.
        Each requeue is recorded in ``regen_counts``; once the cap is reached
        the sprite is left in ``review``.
        """
        review_sprites = self.registry.get_sprites(
            category=category,
            status="review",
            limit=500,
        )
        for sprite in review_sprites or []:
            sid = sprite["id"]
            if starter_ids and sid not in starter_ids:
                continue
            result = await ranker.rank_and_filter(sid)
            if result["needs_regen"]:
                attempt = regen_counts.get(sid, 0)
                if attempt < max_attempts:
                    self.registry.update_sprite_status(sid, "needed")
                    # Record the requeue; without this the cap never triggers
                    # and a failing sprite is regenerated indefinitely.
                    regen_counts[sid] = attempt + 1
                    self._log(
                        f" {sid}: {result['good_count']} good — needs regen"
                    )
            elif result["ranked"]:
                best = result["ranked"][0]
                self._log(
                    f" {sid}: {result['good_count']} good — "
                    f"best={best['confidence']:.0%}"
                )

    async def _run_loop_iteration(
        self,
        *,
        loop_count: int,
        config: dict,
        gen,
        ranker,
        regen_counts: dict[str, int],
        variants: int,
        category: str | None,
        starter_only: bool,
        starter_ids: set[str] | None,
        batch_size: int,
        max_attempts: int,
        on_complete,
    ) -> None:
        """Run one submit → collect → score pass, then sleep before the next.

        Sleeps 10 seconds when nothing is needed or queued, otherwise 2 seconds.
        """
        to_submit = self._select_sprites_to_submit(
            regen_counts=regen_counts,
            category=category,
            starter_only=starter_only,
            starter_ids=starter_ids,
            batch_size=batch_size,
            max_attempts=max_attempts,
        )

        submitted_this_loop = 0
        if to_submit:
            self._log(f"\n[loop {loop_count}] Submitting {len(to_submit)} sprites x {variants} variants...")
            # Only the "model-boss" backend is given the completion callback at submit time.
            submit_cb = on_complete if config["backend"] == "model-boss" else None
            submitted_this_loop = await gen.submit_batch(
                sprite_ids=to_submit,
                variants_per=variants,
                priority="high",
                on_complete=submit_cb,
            )

        collected = 0
        if config["backend"] == "model-boss":
            collected = await gen.collect_pending(on_complete=on_complete)
            if collected:
                self._log(f" Collected {collected} images")

        if config["backend"] == "grok" and to_submit:
            for sid in to_submit:
                await ranker.rank_and_filter(sid)

        if collected > 0 or submitted_this_loop > 0 or config["backend"] == "grok":
            await self._rerank_review_sprites(
                ranker=ranker,
                regen_counts=regen_counts,
                category=category,
                starter_ids=starter_ids,
                max_attempts=max_attempts,
            )

        st = self.status()
        self._log(
            f"[loop {loop_count}] needed={st['needed']} queued={st['queued_variants']} "
            f"review={st['review']} done={st['approved'] + st['installed']}"
        )
        if st["needed"] == 0 and st["queued_variants"] == 0:
            self._log("\nAll sprites processed. Worker idle (polling).")
            await asyncio.sleep(10)
        else:
            await asyncio.sleep(2)