Adds a GUI preferences module and refines the worker engine, operations panel, workers page, and coverage page. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
291 lines
No EOL
10 KiB
Python
291 lines
No EOL
10 KiB
Python
"""Background generation worker — submit → collect → score loop.
|
|
|
|
Runs in a daemon thread from the FastAPI server so the GUI can drive the
|
|
full pipeline without CLI commands.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import IO
|
|
|
|
from engine.factory import BACKENDS, create_generator, with_backend
|
|
from engine.ranker import SpriteRanker
|
|
from engine.registry import SpriteRegistry
|
|
# Directory two levels above this module; the worker loads "sprite-config.json" from it.
TOOL_DIR = Path(__file__).resolve().parent.parent
|
|
|
|
|
|
class GenerationWorker:
    """Thread-hosted asyncio loop for generate → collect → score.

    ``start()`` spawns a daemon thread that runs its own asyncio event loop
    (via ``asyncio.run``) and repeats submit → collect → rank iterations until
    ``stop()`` sets the stop event. Everything the worker logs is printed and
    mirrored to ``log_path``; ``status()`` returns progress counters plus the
    tail of that log.

    Class attributes (override on a subclass or instance to tune timing):
        LOG_TAIL_LINES: number of trailing log lines returned by ``status()``
            (must be positive).
        BUSY_POLL_SECONDS: pause between iterations while work remains.
        IDLE_POLL_SECONDS: pause between iterations when nothing is needed
            or queued.
        ERROR_BACKOFF_SECONDS: pause after an iteration raised.
    """

    LOG_TAIL_LINES = 30
    BUSY_POLL_SECONDS = 2
    IDLE_POLL_SECONDS = 10
    ERROR_BACKOFF_SECONDS = 5

    def __init__(self, registry: SpriteRegistry, raw_dir: Path, log_path: Path):
        """Store collaborators; no thread or file is created until ``start()``."""
        self.registry = registry
        self.raw_dir = raw_dir
        self.log_path = log_path
        self._thread: threading.Thread | None = None
        self._stop = threading.Event()
        self._log_file: IO[str] | None = None

    @property
    def running(self) -> bool:
        """True while the worker thread exists and is alive."""
        return self._thread is not None and self._thread.is_alive()

    def _log(self, msg: str) -> None:
        """Print *msg* (trailing whitespace stripped) and mirror it to the log file."""
        line = msg.rstrip()
        print(line, flush=True)
        if self._log_file:
            self._log_file.write(line + "\n")
            self._log_file.flush()

    def start(
        self,
        *,
        backend: str | None = None,
        variants: int = 3,
        category: str | None = None,
        starter_only: bool = False,
        max_attempts: int = 15,
        batch_size: int = 4,
    ) -> dict:
        """Launch the worker thread.

        Args:
            backend: optional backend override; must be a member of ``BACKENDS``.
            variants: variants requested per submitted sprite.
            category: optional category filter passed to the registry.
            starter_only: restrict work to the starter sprite ids.
            max_attempts: cap on how many times one sprite is re-queued for
                regeneration before it is left in review.
            batch_size: maximum sprites submitted per iteration.

        Returns:
            ``{"status": "started"}`` on success, ``{"status": "already_running"}``
            if a worker thread is alive, or ``{"status": "error", "message": ...}``
            for an unknown backend.
        """
        if self.running:
            return {"status": "already_running"}

        if backend and backend not in BACKENDS:
            return {"status": "error", "message": f"Unknown backend {backend!r}"}

        self._stop.clear()
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        # Mode "w": every run starts with a fresh log.
        self._log_file = open(self.log_path, "w", encoding="utf-8")

        try:
            self._thread = threading.Thread(
                target=self._thread_main,
                args=(backend, variants, category, starter_only, max_attempts, batch_size),
                daemon=True,
            )
            self._thread.start()
        except Exception:
            # The thread never ran, so its ``finally`` will not close the log.
            self._log_file.close()
            self._log_file = None
            raise
        return {"status": "started"}

    def stop(self) -> dict:
        """Ask the worker to stop; it exits after finishing the current iteration."""
        if not self.running:
            return {"status": "not_running"}
        self._stop.set()
        return {"status": "stopping"}

    def _read_log_tail(self) -> str:
        """Return the last ``LOG_TAIL_LINES`` lines of the log, or "" if it is missing."""
        try:
            text = self.log_path.read_text(encoding="utf-8", errors="replace")
        except FileNotFoundError:
            return ""
        return "\n".join(text.splitlines()[-self.LOG_TAIL_LINES:])

    def _progress_counts(self) -> dict:
        """Return queued-variant and per-status sprite counts from the registry."""
        queued = self.registry.conn.execute(
            "SELECT COUNT(*) FROM variants WHERE job_status = 'submitted'"
        ).fetchone()[0]
        totals = self.registry.get_stats().get("total", {})
        return {
            "queued_variants": queued,
            "needed": totals.get("needed", 0),
            "review": totals.get("review", 0),
            "approved": totals.get("approved", 0),
            "installed": totals.get("installed", 0),
        }

    def status(self) -> dict:
        """Return running flag, progress counters and the tail of the log file."""
        log_tail = self._read_log_tail()
        counts = self._progress_counts()
        return {
            "running": self.running,
            "queued_variants": counts["queued_variants"],
            "needed": counts["needed"],
            "review": counts["review"],
            "approved": counts["approved"],
            "installed": counts["installed"],
            "log_tail": log_tail,
        }

    def _thread_main(
        self,
        backend: str | None,
        variants: int,
        category: str | None,
        starter_only: bool,
        max_attempts: int,
        batch_size: int,
    ) -> None:
        """Thread entry point: run the async loop, log any crash, close the log."""
        try:
            asyncio.run(
                self._run_loop(
                    backend=backend,
                    variants=variants,
                    category=category,
                    starter_only=starter_only,
                    max_attempts=max_attempts,
                    batch_size=batch_size,
                )
            )
        except Exception as exc:
            # Top-level boundary of the thread: record the failure instead of
            # letting it vanish with the daemon thread.
            self._log(f"[worker-error] {type(exc).__name__}: {exc}")
        finally:
            if self._log_file:
                self._log_file.close()
                self._log_file = None

    async def _run_loop(
        self,
        *,
        backend: str | None,
        variants: int,
        category: str | None,
        starter_only: bool,
        max_attempts: int,
        batch_size: int,
    ) -> None:
        """Load config, build generator and ranker, then iterate until stopped.

        An exception inside one iteration is logged and followed by a pause of
        ``ERROR_BACKOFF_SECONDS``; the loop then continues.
        """
        config_path = TOOL_DIR / "sprite-config.json"
        config = json.loads(config_path.read_text(encoding="utf-8"))
        if backend:
            config = with_backend(config, backend)

        gen = create_generator(config=config, registry=self.registry, raw_dir=self.raw_dir)
        ranker = SpriteRanker(registry=self.registry, raw_dir=self.raw_dir)
        # sprite id -> number of times it has been sent back for regeneration.
        regen_counts: dict[str, int] = {}

        starter_ids = None
        if starter_only:
            from engine.starter import starter_sprite_ids

            starter_ids = set(starter_sprite_ids())
            if not starter_ids:
                self._log(
                    "starter_only requested but no starter sprite ids were found; "
                    "nothing will be submitted"
                )
        self._log(f"Worker started — {config['backend']} backend, {variants} variants/job")

        async def _on_complete(variant_id: int, sprite_id: str) -> None:
            await ranker.advance_sprite(sprite_id)

        loop_count = 0
        while not self._stop.is_set():
            loop_count += 1
            try:
                await self._run_loop_iteration(
                    loop_count=loop_count,
                    config=config,
                    gen=gen,
                    ranker=ranker,
                    regen_counts=regen_counts,
                    variants=variants,
                    category=category,
                    starter_only=starter_only,
                    starter_ids=starter_ids,
                    batch_size=batch_size,
                    max_attempts=max_attempts,
                    on_complete=_on_complete,
                )
            except Exception as exc:
                self._log(f"[loop {loop_count} error] {type(exc).__name__}: {exc}")
                await asyncio.sleep(self.ERROR_BACKOFF_SECONDS)

        self._log("Worker stopped.")

    def _select_sprites_to_submit(
        self,
        *,
        regen_counts: dict[str, int],
        category: str | None,
        starter_only: bool,
        starter_ids: set[str] | None,
        batch_size: int,
        max_attempts: int,
    ) -> list[str]:
        """Pick up to *batch_size* sprite ids with status 'needed' to submit.

        In starter mode the candidates are the starter ids, fewest existing
        variants first. Otherwise candidates come from the registry, skipping
        sprites that already have a submitted variant; sprites that exhausted
        *max_attempts* are moved to 'review' instead of being submitted.
        """
        if starter_only:
            if not starter_ids:
                # An empty starter set must not fall through to "all sprites".
                return []
            ids = list(starter_ids)
            placeholders = ",".join("?" * len(ids))
            # NOTE(review): unlike the branch below, this query does not skip
            # sprites that already have a submitted variant — confirm intended.
            needed_rows = self.registry.conn.execute(
                f"""
                SELECT s.id FROM sprites s
                LEFT JOIN variants v ON v.sprite_id = s.id
                WHERE s.id IN ({placeholders})
                AND s.status = 'needed'
                GROUP BY s.id
                ORDER BY COUNT(v.id) ASC, s.id ASC
                """,
                ids,
            ).fetchall()
            return [row[0] for row in needed_rows[:batch_size]]

        # Over-fetch so in-flight / exhausted sprites can be skipped while
        # still filling the batch.
        needed = self.registry.get_sprites(
            category=category,
            status="needed",
            limit=batch_size * 4,
        )
        if not needed:
            return []

        ids = [sprite["id"] for sprite in needed]
        placeholders = ",".join("?" * len(ids))
        rows = self.registry.conn.execute(
            f"SELECT DISTINCT sprite_id FROM variants "
            f"WHERE sprite_id IN ({placeholders}) AND job_status='submitted'",
            ids,
        ).fetchall()
        in_flight_ids = {row[0] for row in rows}

        to_submit: list[str] = []
        for sid in ids:
            if sid in in_flight_ids:
                continue
            if regen_counts.get(sid, 0) < max_attempts:
                to_submit.append(sid)
            else:
                self.registry.update_sprite_status(sid, "review")
                self._log(f" {sid}: max regen attempts, moving to review")
            if len(to_submit) >= batch_size:
                break
        return to_submit

    async def _rank_review_sprites(
        self,
        *,
        ranker,
        regen_counts: dict[str, int],
        category: str | None,
        restrict_ids: set[str] | None,
        max_attempts: int,
    ) -> None:
        """Rank every sprite in 'review'; re-queue those that need regeneration.

        Each re-queue increments the sprite's entry in *regen_counts*, which is
        what makes the *max_attempts* cap effective. When *restrict_ids* is not
        None, sprites outside it are skipped.
        """
        review_sprites = self.registry.get_sprites(
            category=category,
            status="review",
            limit=500,
        )
        for sprite in review_sprites or []:
            sid = sprite["id"]
            if restrict_ids is not None and sid not in restrict_ids:
                continue
            result = await ranker.rank_and_filter(sid)
            if result["needs_regen"]:
                attempt = regen_counts.get(sid, 0)
                if attempt < max_attempts:
                    regen_counts[sid] = attempt + 1
                    self.registry.update_sprite_status(sid, "needed")
                    self._log(f" {sid}: {result['good_count']} good — needs regen")
            elif result["ranked"]:
                best = result["ranked"][0]
                self._log(
                    f" {sid}: {result['good_count']} good — "
                    f"best={best['confidence']:.0%}"
                )

    async def _run_loop_iteration(
        self,
        *,
        loop_count: int,
        config: dict,
        gen,
        ranker,
        regen_counts: dict[str, int],
        variants: int,
        category: str | None,
        starter_only: bool,
        starter_ids: set[str] | None,
        batch_size: int,
        max_attempts: int,
        on_complete,
    ) -> None:
        """Run one submit → collect → rank pass, log progress, then pause.

        The "model-boss" backend gets the completion callback and a collect
        step; the "grok" backend has its just-submitted sprites ranked
        immediately. The pause is ``IDLE_POLL_SECONDS`` when no sprite is
        needed and nothing is queued, otherwise ``BUSY_POLL_SECONDS``.
        """
        backend = config["backend"]

        to_submit = self._select_sprites_to_submit(
            regen_counts=regen_counts,
            category=category,
            starter_only=starter_only,
            starter_ids=starter_ids,
            batch_size=batch_size,
            max_attempts=max_attempts,
        )

        submitted_this_loop = 0
        if to_submit:
            self._log(f"\n[loop {loop_count}] Submitting {len(to_submit)} sprites x {variants} variants...")
            submit_cb = on_complete if backend == "model-boss" else None
            submitted_this_loop = await gen.submit_batch(
                sprite_ids=to_submit,
                variants_per=variants,
                priority="high",
                on_complete=submit_cb,
            )

        collected = 0
        if backend == "model-boss":
            collected = await gen.collect_pending(on_complete=on_complete)
            if collected:
                self._log(f" Collected {collected} images")

        if backend == "grok":
            for sid in to_submit:
                await ranker.rank_and_filter(sid)

        if collected > 0 or submitted_this_loop > 0 or backend == "grok":
            # Restrict review processing whenever starter mode is on or a
            # non-empty starter set was supplied.
            restrict_ids = (starter_ids or set()) if (starter_only or starter_ids) else None
            await self._rank_review_sprites(
                ranker=ranker,
                regen_counts=regen_counts,
                category=category,
                restrict_ids=restrict_ids,
                max_attempts=max_attempts,
            )

        # Counts only: avoids re-reading the whole log file every iteration.
        st = self._progress_counts()
        self._log(
            f"[loop {loop_count}] needed={st['needed']} queued={st['queued_variants']} "
            f"review={st['review']} done={st['approved'] + st['installed']}"
        )

        if st["needed"] == 0 and st["queued_variants"] == 0:
            self._log("\nAll sprites processed. Worker idle (polling).")
            await asyncio.sleep(self.IDLE_POLL_SECONDS)
        else:
            await asyncio.sleep(self.BUSY_POLL_SECONDS)