# Change note: adds a worker abstraction (engine/worker.py) with registry wiring and a
# server operations endpoint, surfaced in the GUI via new Workers and Operations pages plus
# dashboard/coverage/theater/variant page updates. Refreshes the ranker.
# Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
"""FastAPI server for the sprite generation review GUI.
|
|
|
|
Serves the REST API for browsing sprites, reviewing variants, and approving/rejecting.
|
|
Also serves raw/variant images and the static GUI build.
|
|
|
|
Usage:
|
|
uvicorn server:app --port 5801 --reload
|
|
# or via CLI:
|
|
python3 cli.py review --port 5801
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from engine.registry import SpriteRegistry
|
|
|
|
# Directory holding this server module; every tool-local path hangs off it.
TOOL_DIR = Path(__file__).resolve().parent
# Repository root, two levels above the tool directory.
PROJECT = TOOL_DIR.parent.parent

# Tool-local state and build output.
DB_PATH = TOOL_DIR.joinpath("spritegen.db")
RAW_DIR = TOOL_DIR.joinpath("raw")
VARIANTS_DIR = TOOL_DIR.joinpath("variants")
GUI_DIST = TOOL_DIR.joinpath("gui", "dist")

# Game-side destinations that installed sprites and sprite data are written into.
_GAME_ROOT = PROJECT.joinpath("public", "games", "age-of-dwarves")
ASSETS_DIR = _GAME_ROOT.joinpath("assets")
LOCAL_DATA = _GAME_ROOT.joinpath("data")
|
|
|
|
|
|
class ApproveRequest(BaseModel):
    """Body for the sprite/queue approve endpoints."""

    # Variant to approve.
    variant_id: int
    # Accepted for API compatibility; the approve handlers do not read it.
    dimension_id: int | None = None
|
|
|
|
|
|
class RejectRequest(BaseModel):
    """Body for rejecting a sprite (POST /api/sprites/{id}/reject)."""

    # Optional dimension scope, forwarded as-is to registry.reject_sprite.
    dimension_id: int | None = None
|
|
|
|
|
|
class RejectVariantRequest(BaseModel):
    """Optional body for rejecting a single variant."""

    # Free-text reason, forwarded to registry.reject_variant.
    reason: str | None = None
|
|
|
|
|
|
class RegenerateRequest(BaseModel):
    """Body for marking a sprite for regeneration, optionally with a new prompt."""

    # When non-empty, replaces the sprite's stored prompt before it is reset.
    prompt: str | None = None
    # Optional dimension scope, forwarded to registry.reject_sprite.
    dimension_id: int | None = None
    # NOTE(review): not read by the regenerate handler; generation counts come
    # from the generate/worker endpoints instead.
    variants: int = 8
|
|
|
|
|
|
class PromptUpdate(BaseModel):
    """Body for replacing a sprite's generation prompt."""

    # New prompt text stored on the sprite row.
    prompt: str
|
|
|
|
|
|
class PipelineRunRequest(BaseModel):
    """Options translated into tools/advance.py command-line flags."""

    # Passed as `--scorers <names...>` when non-empty.
    scorers: list[str] | None = None
    # Adds `--skip-process`.
    skip_process: bool = False
    # Adds `--score-only`.
    score_only: bool = False
    # Adds `--rescore`.
    rescore: bool = False
|
|
|
|
|
|
class PinSeedRequest(BaseModel):
    """Body for manually pinning a seed into the seed pool."""

    # Generation seed to pin.
    seed: int
    # Category the seed is pooled under.
    category: str
    # Optional entity the seed came from ("" when not tied to one).
    entity_id: str = ""
    # Optional source variant, forwarded to registry.pin_seed.
    variant_id: int | None = None
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Body for POST /api/generate (single sprite or batch)."""

    # Batch mode: restrict candidate sprites to this category.
    category: str | None = None
    # Single-sprite mode: generate only for this sprite (overrides batch filters).
    sprite_id: str | None = None
    # Variants requested per sprite (passed as `variants_per`).
    variants: int = 8
    # Submission priority passed through to the generator.
    priority: str = "normal"
    # Batch mode cap on sprites; None/0 fetches up to 10000 per status.
    max_sprites: int | None = None
    backend: str | None = None # "model-boss" | "grok"; defaults to sprite-config.json
    include_review: bool = False # batch mode: also queue review sprites with deficit
|
|
|
|
|
|
class WorkerStartRequest(BaseModel):
    """Options forwarded verbatim to GenerationWorker.start()."""

    # Backend override; None uses the worker's default.
    backend: str | None = None
    # Variants per sprite.
    variants: int = 3
    # Optional category restriction.
    category: str | None = None
    # Restrict work to the starter sprite set.
    starter_only: bool = False
    # Passed to the worker as-is; exact retry semantics live in engine.worker.
    max_attempts: int = 15
    # Passed to the worker as-is; exact batching semantics live in engine.worker.
    batch_size: int = 4
|
|
|
|
|
|
class InstallRequest(BaseModel):
    """Body for installing approved sprites into the game assets directory."""

    # Optional category restriction.
    category: str | None = None
    # Forwarded to SpriteInstaller.install_approved; echoed back in the response.
    dry_run: bool = False
|
|
|
|
|
|
class ShipVariantRequest(BaseModel):
    """Optional body for shipping (approve + install) a single variant."""

    # Forwarded to SpritePipeline.approve_and_install as `alt_name`.
    alt_name: str | None = None
|
|
|
|
|
|
class CollectRequest(BaseModel):
    """Optional body for the one-shot collect endpoint."""

    # Backend override applied via with_backend(); only "model-boss" collects.
    backend: str | None = None
|
|
|
|
|
|
class StarterRegisterRequest(BaseModel):
    """Optional body for registering the starter sprite set."""

    # Forwarded to register_starter_set as `reset_status`.
    reset_status: bool = True
|
|
|
|
|
|
def create_app(
    registry: SpriteRegistry | None = None,
    raw_dir: Path = RAW_DIR,
    variants_dir: Path = VARIANTS_DIR,
) -> FastAPI:
    """Build the FastAPI app for the sprite review GUI.

    Args:
        registry: Registry backing every endpoint. When omitted, a
            ``SpriteRegistry`` is opened on ``DB_PATH``.
        raw_dir: Root that ``/images/raw/...`` serves from; also handed to the
            generator, ranker, generation worker and ship pipeline.
        variants_dir: Root that ``/images/variants/...`` serves from; also
            handed to the ship pipeline.

    Returns:
        The configured application. The advance.py process handle and the
        generation worker are kept in this closure, so each app instance
        tracks its own.

    Note:
        The SSE stream opens its own registry on ``DB_PATH`` even when a
        custom *registry* is injected.
    """
    if registry is None:
        registry = SpriteRegistry(DB_PATH)

    app = FastAPI(title="Magic Civilization Sprite Generator", version="1.0.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # ── Shared helpers ────────────────────────────────────────────────────

    def _load_config() -> dict:
        """Read sprite-config.json from disk; re-read per call so edits apply live."""
        return json.loads((TOOL_DIR / "sprite-config.json").read_text())

    def _require_sprite(sprite_id: str) -> dict:
        """Return the sprite, or raise 404 if it is not registered."""
        sprite = registry.get_sprite(sprite_id)
        if not sprite:
            raise HTTPException(404, f"Sprite not found: {sprite_id}")
        return sprite

    def _require_variant(variant_id: int) -> dict:
        """Return the variant, or raise 404 if it does not exist."""
        variant = registry.get_variant(variant_id)
        if not variant:
            raise HTTPException(404, f"Variant not found: {variant_id}")
        return variant

    def _approve_for_sprite(sprite_id: str, variant_id: int) -> dict:
        """Approve *variant_id* after checking it exists and belongs to *sprite_id*.

        Raises:
            HTTPException: 404 for an unknown sprite or variant, 400 when the
                variant belongs to a different sprite.
        """
        _require_sprite(sprite_id)
        variant = _require_variant(variant_id)
        # Without this check, an approve on sprite A could approve a variant of sprite B.
        if variant["sprite_id"] != sprite_id:
            raise HTTPException(
                400, f"Variant {variant_id} does not belong to sprite {sprite_id}"
            )
        registry.approve_variant(variant_id)
        return {"status": "approved", "variant_id": variant_id}

    def _set_prompt(sprite_id: str, prompt: str) -> None:
        """Persist a new prompt for *sprite_id* and bump its updated_at."""
        registry.conn.execute(
            "UPDATE sprites SET prompt = ?, updated_at = datetime('now') WHERE id = ?",
            (prompt, sprite_id),
        )
        registry.conn.commit()

    def _read_log_tail(path: Path, lines: int = 30) -> str:
        """Return the last *lines* lines of *path*, or "" if it does not exist.

        Undecodable bytes are replaced so stray non-UTF-8 subprocess output
        cannot turn a status request into a 500.
        """
        if not path.exists():
            return ""
        all_lines = path.read_text(encoding="utf-8", errors="replace").splitlines()
        return "\n".join(all_lines[-lines:])

    def _file_within(base: Path, rel: str) -> Path | None:
        """Return the regular file at *rel* under *base*, or None.

        *rel* comes straight from the URL. ``base / "/etc/x"`` discards *base*
        entirely, and ``..`` climbs out of it, so both are rejected. The path
        is checked lexically, which leaves symlinks inside *base* working.
        """
        rel_path = Path(rel)
        if rel_path.anchor or ".." in rel_path.parts:
            return None
        candidate = base / rel_path
        return candidate if candidate.is_file() else None

    # ── Sprites ───────────────────────────────────────────────────────────

    @app.get("/api/sprites")
    def list_sprites(
        category: Annotated[str | None, Query()] = None,
        status: Annotated[str | None, Query()] = None,
        search: Annotated[str | None, Query()] = None,
        limit: Annotated[int, Query(ge=1, le=10000)] = 200,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> list[dict]:
        """List sprites with optional category/status/search filters and paging."""
        return registry.get_sprites(
            category=category, status=status, search=search,
            limit=limit, offset=offset,
        )

    @app.get("/api/sprites/{sprite_id:path}/variants")
    def get_variants(
        sprite_id: str,
        dimension_id: Annotated[int | None, Query()] = None,
    ) -> list[dict]:
        """Variants of one sprite, optionally restricted to a dimension."""
        _require_sprite(sprite_id)
        return registry.get_variants(sprite_id, dimension_id=dimension_id)

    @app.get("/api/sprites/{sprite_id:path}")
    def get_sprite(sprite_id: str) -> dict:
        """A single sprite row."""
        return _require_sprite(sprite_id)

    @app.post("/api/sprites/{sprite_id:path}/approve")
    def approve_sprite(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve one of this sprite's variants."""
        return _approve_for_sprite(sprite_id, body.variant_id)

    @app.post("/api/sprites/{sprite_id:path}/reject")
    def reject_sprite(sprite_id: str, body: RejectRequest) -> dict:
        """Reject a sprite (optionally scoped to a dimension)."""
        _require_sprite(sprite_id)
        registry.reject_sprite(sprite_id, dimension_id=body.dimension_id)
        return {"status": "rejected"}

    @app.post("/api/sprites/{sprite_id:path}/skip")
    def skip_sprite(sprite_id: str) -> dict:
        """Mark a sprite as skipped."""
        _require_sprite(sprite_id)
        registry.update_sprite_status(sprite_id, "skip")
        return {"status": "skip"}

    @app.get("/api/variants/{variant_id:int}")
    def get_variant(variant_id: int) -> dict:
        """A single variant, enriched with its sprite's category and entity_id."""
        variant = _require_variant(variant_id)
        sprite = registry.get_sprite(variant["sprite_id"])
        if sprite:
            variant["category"] = sprite["category"]
            variant["entity_id"] = sprite["entity_id"]
        return variant

    @app.get("/api/variants/{variant_id:int}/scores")
    def get_variant_scores(variant_id: int) -> list[dict]:
        """Get all per-scorer evaluations for a variant."""
        return registry.get_scores(variant_id)

    @app.post("/api/variants/{variant_id:int}/reject")
    def reject_variant(variant_id: int, body: RejectVariantRequest = RejectVariantRequest()) -> dict:
        """Reject a single variant, forwarding the optional reason to the registry."""
        _require_variant(variant_id)
        registry.reject_variant(variant_id, body.reason)
        return {"status": "rejected", "variant_id": variant_id}

    @app.post("/api/sprites/{sprite_id:path}/regenerate")
    def regenerate_sprite(sprite_id: str, body: RegenerateRequest) -> dict:
        """Optionally replace the prompt, then reject the sprite so it is regenerated."""
        _require_sprite(sprite_id)
        if body.prompt:
            _set_prompt(sprite_id, body.prompt)
        registry.reject_sprite(sprite_id, dimension_id=body.dimension_id)
        return {"status": "needed", "message": "Ready for regeneration"}

    @app.put("/api/sprites/{sprite_id:path}/prompt")
    def update_prompt(sprite_id: str, body: PromptUpdate) -> dict:
        """Replace a sprite's generation prompt."""
        _require_sprite(sprite_id)
        _set_prompt(sprite_id, body.prompt)
        return {"status": "updated", "prompt": body.prompt}

    # ── Progress & Queue ─────────────────────────────────────────────────

    @app.get("/api/stats")
    def get_stats() -> dict:
        return registry.get_stats()

    @app.get("/api/progress")
    def get_progress() -> dict:
        return registry.get_progress()

    @app.get("/api/pipeline")
    def get_pipeline() -> dict:
        return registry.get_pipeline_dashboard()

    # Handle of the advance.py subprocess. A dict so the nested handlers can
    # rebind it without `nonlocal`.
    _pipeline_proc: dict[str, subprocess.Popen | None] = {"proc": None}

    # Lazily created GenerationWorker for this app (see _get_worker).
    _gen_worker: dict[str, object] = {"instance": None}

    def _get_worker():
        """Return this app's GenerationWorker, creating it on first use."""
        from engine.worker import GenerationWorker

        if _gen_worker["instance"] is None:
            _gen_worker["instance"] = GenerationWorker(
                registry=registry,
                raw_dir=raw_dir,
                log_path=TOOL_DIR / "worker.log",
            )
        return _gen_worker["instance"]

    @app.post("/api/pipeline/run")
    def run_pipeline(body: PipelineRunRequest) -> dict:
        """Spawn tools/advance.py as a background subprocess."""
        proc = _pipeline_proc["proc"]
        if proc is not None and proc.poll() is None:
            return {"status": "already_running", "pid": proc.pid}

        cmd = [sys.executable, "-u", str(TOOL_DIR / "tools" / "advance.py")]
        if body.scorers:
            cmd += ["--scorers"] + body.scorers
        if body.rescore:
            cmd.append("--rescore")
        if body.skip_process:
            cmd.append("--skip-process")
        if body.score_only:
            cmd.append("--score-only")

        log_path = TOOL_DIR / "advance.log"
        # The child gets its own copy of the descriptor, so the parent's handle
        # can be closed once Popen returns instead of leaking one per run.
        with open(log_path, "w") as log_file:
            proc = subprocess.Popen(
                cmd,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                cwd=str(TOOL_DIR),
            )
        _pipeline_proc["proc"] = proc
        return {"status": "started", "pid": proc.pid}

    @app.get("/api/pipeline/status")
    def pipeline_run_status() -> dict:
        """Check if the advance.py process is still running."""
        proc = _pipeline_proc["proc"]
        if proc is None:
            return {"running": False, "pid": None, "returncode": None, "log_tail": ""}

        # Poll before reading the log: once a return code is reported, the log is complete.
        poll = proc.poll()
        return {
            "running": poll is None,
            "pid": proc.pid,
            "returncode": poll,
            "log_tail": _read_log_tail(TOOL_DIR / "advance.log", 20),
        }

    @app.get("/api/queue")
    def get_review_queue(
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
    ) -> list[dict]:
        return registry.get_review_queue(limit=limit)

    @app.post("/api/queue/{sprite_id:path}/approve")
    def approve_from_queue(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve a variant from the review queue (same checks as the sprite route)."""
        return _approve_for_sprite(sprite_id, body.variant_id)

    @app.get("/api/variants/recent")
    def get_recent_variants(
        limit: Annotated[int, Query(ge=1, le=5000)] = 30,
    ) -> list[dict]:
        return registry.get_recent_variants(limit=limit)

    @app.get("/api/variants/review")
    def get_review_variants(
        limit: Annotated[int, Query(ge=1, le=5000)] = 500,
    ) -> list[dict]:
        return registry.get_review_variants(limit=limit)

    @app.get("/api/terrain-grid")
    def terrain_grid(
        elevation: Annotated[str, Query(pattern=r"^(lowland|highland|alpine)$")] = "lowland",
    ) -> dict:
        """Biome grid cells for one elevation tier, with best variant per cell."""
        cells = registry.get_terrain_grid(elevation)
        return {"cells": cells}

    @app.get("/api/theater")
    def theater_variants(
        mode: Annotated[str, Query(pattern=r"^(all|review)$")] = "all",
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> dict:
        """Paginated, server-filtered variants for the Sprite Theater UI."""
        items, total = registry.query_variants(mode=mode, limit=limit, offset=offset)
        return {"items": items, "total": total}

    @app.get("/api/variants/browse")
    def browse_variants(
        stage: Annotated[str, Query(pattern=r"^(all|completed|review|processed|approved|installed|scored_[a-z0-9_]+)$")] = "completed",
        limit: Annotated[int, Query(ge=1, le=500)] = 120,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> dict:
        """Paginated variant browse by funnel stage — used by the dashboard funnel click-throughs."""
        items, total = registry.query_variants(mode=stage, limit=limit, offset=offset)
        return {"items": items, "total": total}

    @app.get("/api/stream/variants")
    async def stream_variants() -> StreamingResponse:
        """Server-sent events feed of new variants plus periodic rating refreshes."""

        async def event_generator():
            # Own connection to avoid SQLite threading collisions with main registry
            sse_registry = SpriteRegistry(DB_PATH)
            try:
                current = sse_registry.get_recent_variants(limit=1)
                last_id = current[0]["variant_id"] if current else 0
                last_rating_snapshot = ""
                poll_count = 0
                while True:
                    await asyncio.sleep(3)
                    new_variants = sse_registry.get_recent_variants(limit=10, since_id=last_id)
                    if new_variants:
                        last_id = max(v["variant_id"] for v in new_variants)
                        yield f"data: {json.dumps(new_variants)}\n\n"
                        continue

                    # Check for rating updates every ~15s (5 idle cycles)
                    poll_count += 1
                    if poll_count % 5 == 0:
                        updated = sse_registry.get_recent_variants(limit=30)
                        snapshot = "|".join(
                            f"{v['variant_id']}:{v['rating']}" for v in updated
                        )
                        if snapshot != last_rating_snapshot:
                            last_rating_snapshot = snapshot
                            yield f"data: {json.dumps(updated)}\n\n"
                            continue

                    yield ": keepalive\n\n"
            finally:
                # Runs when the generator is closed or cancelled (e.g. client
                # disconnect), so each subscriber's connection is released.
                sse_registry.conn.close()

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    @app.get("/api/runs")
    def get_runs() -> list[dict]:
        return registry.get_runs()

    # ── Analytics + Seed pool ─────────────────────────────────────────────

    @app.get("/api/analytics/quality")
    def quality_analytics(
        category: Annotated[str | None, Query()] = None,
        scorer: Annotated[str | None, Query()] = None,
    ) -> dict:
        """Quality dimension averages and gate failure rates across scored variants."""
        return registry.get_quality_analytics(category=category, scorer=scorer)

    @app.get("/api/seeds")
    def list_seeds(
        category: Annotated[str | None, Query()] = None,
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
    ) -> list[dict]:
        """Top seeds from the pool for a category."""
        return registry.get_seed_pool_report(category=category, limit=limit)

    @app.post("/api/seeds/pin")
    def pin_seed(body: PinSeedRequest) -> dict:
        """Manually pin a seed to the pool with max quality (user override)."""
        registry.pin_seed(
            seed=body.seed,
            category=body.category,
            entity_id=body.entity_id,
            variant_id=body.variant_id,
        )
        return {"status": "pinned", "seed": body.seed, "category": body.category}

    # ── Config ────────────────────────────────────────────────────────────

    @app.get("/api/config")
    def get_config() -> dict:
        """Available backends, the configured default, and the configured model."""
        from engine.factory import BACKENDS, backend_summary

        config = _load_config()
        return {
            "backends": list(BACKENDS),
            "default_backend": config.get("backend", "model-boss"),
            "backend_summary": backend_summary(config),
            "model": config.get("model"),
        }

    # ── Generation worker (submit → collect → score) ─────────────────────

    @app.post("/api/worker/start")
    def start_worker(body: WorkerStartRequest) -> dict:
        """Start (creating if needed) the background generation worker."""
        return _get_worker().start(
            backend=body.backend,
            variants=body.variants,
            category=body.category,
            starter_only=body.starter_only,
            max_attempts=body.max_attempts,
            batch_size=body.batch_size,
        )

    @app.post("/api/worker/stop")
    def stop_worker() -> dict:
        """Stop the generation worker if one was ever created."""
        w = _gen_worker["instance"]
        if w is None:
            return {"status": "not_running"}
        return w.stop()

    def _generation_worker_status(log_lines: int = 30) -> dict:
        """Worker status plus log tail; DB-derived counts when no worker exists yet."""
        w = _gen_worker["instance"]
        log_path = TOOL_DIR / "worker.log"
        log_rel = str(log_path.relative_to(TOOL_DIR))
        if w is None:
            queued = registry.conn.execute(
                "SELECT COUNT(*) FROM variants WHERE job_status = 'submitted'"
            ).fetchone()[0]
            total_row = registry.get_stats().get("total", {})
            return {
                "running": False,
                "queued_variants": queued,
                "needed": total_row.get("needed", 0),
                "review": total_row.get("review", 0),
                "approved": total_row.get("approved", 0),
                "installed": total_row.get("installed", 0),
                "log_tail": _read_log_tail(log_path, log_lines),
                "log_path": log_rel,
            }
        st = w.status()
        st["log_tail"] = _read_log_tail(log_path, log_lines)
        st["log_path"] = log_rel
        return st

    def _scoring_pipeline_status(log_lines: int = 30) -> dict:
        """advance.py process status plus the tail of advance.log."""
        proc = _pipeline_proc["proc"]
        log_path = TOOL_DIR / "advance.log"
        log_rel = str(log_path.relative_to(TOOL_DIR))
        if proc is None:
            return {
                "running": False,
                "pid": None,
                "returncode": None,
                "log_tail": _read_log_tail(log_path, log_lines),
                "log_path": log_rel,
            }
        poll = proc.poll()
        return {
            "running": poll is None,
            "pid": proc.pid,
            "returncode": poll,
            "log_tail": _read_log_tail(log_path, log_lines),
            "log_path": log_rel,
        }

    @app.get("/api/worker/status")
    def worker_status(
        log_lines: Annotated[int, Query(ge=10, le=500)] = 30,
    ) -> dict:
        return _generation_worker_status(log_lines)

    @app.get("/api/workers")
    def workers_overview(
        log_lines: Annotated[int, Query(ge=10, le=500)] = 80,
    ) -> dict:
        """Unified status for generation worker + scoring pipeline."""
        from engine.factory import BACKENDS, backend_summary

        config = _load_config()
        total_row = registry.get_stats().get("total", {})
        in_flight = registry.conn.execute(
            "SELECT COUNT(DISTINCT sprite_id) FROM variants WHERE job_status = 'submitted'"
        ).fetchone()[0]
        return {
            "generation": _generation_worker_status(log_lines),
            "scoring": _scoring_pipeline_status(log_lines),
            "sprites": {
                "needed": total_row.get("needed", 0),
                "generating": total_row.get("generating", 0),
                "review": total_row.get("review", 0),
                "approved": total_row.get("approved", 0),
                "installed": total_row.get("installed", 0),
            },
            "in_flight_sprites": in_flight,
            "config": {
                "backends": list(BACKENDS),
                "default_backend": config.get("backend", "model-boss"),
                "backend_summary": backend_summary(config),
            },
        }

    @app.post("/api/collect")
    async def collect_pending(body: CollectRequest = CollectRequest()) -> dict:
        """One-shot: collect model-boss results and score affected sprites."""
        from engine.factory import create_generator, with_backend
        from engine.ranker import SpriteRanker

        config = _load_config()
        if body.backend:
            config = with_backend(config, body.backend)
        if config["backend"] != "model-boss":
            return {"collected": 0, "message": f"{config['backend']} completes at generate time"}

        gen = create_generator(config=config, registry=registry, raw_dir=raw_dir)
        ranker = SpriteRanker(registry=registry, raw_dir=raw_dir)
        affected: set[str] = set()

        async def _on_complete(variant_id: int, sprite_id: str) -> None:
            affected.add(sprite_id)
            await ranker.advance_sprite(sprite_id)

        collected = await gen.collect_pending(on_complete=_on_complete)
        for sid in affected:
            await ranker.rank_and_filter(sid)
        return {"collected": collected, "sprites_scored": len(affected)}

    # ── Install / ship ────────────────────────────────────────────────────

    @app.post("/api/install")
    def install_approved(body: InstallRequest = InstallRequest()) -> dict:
        """Install approved sprites into ASSETS_DIR (optionally a dry run)."""
        from engine.installer import SpriteInstaller

        installer = SpriteInstaller(assets_dir=ASSETS_DIR, registry=registry)
        count = installer.install_approved(category=body.category, dry_run=body.dry_run)
        return {"installed": count, "dry_run": body.dry_run}

    @app.post("/api/variants/{variant_id:int}/ship")
    def ship_variant(variant_id: int, body: ShipVariantRequest = ShipVariantRequest()) -> dict:
        """Approve → rembg → install → manifest (cli.py approve)."""
        from engine.pipeline import SpritePipeline

        _require_variant(variant_id)
        pipeline = SpritePipeline(
            registry=registry,
            raw_dir=raw_dir,
            variants_dir=variants_dir,
            assets_dir=ASSETS_DIR,
            game_db_path=LOCAL_DATA / "sprites.db",
        )
        result = pipeline.approve_and_install(variant_id, alt_name=body.alt_name)
        if not result:
            raise HTTPException(500, "Install pipeline failed")
        return {"status": "shipped", "path": str(result)}

    # ── Starter set ───────────────────────────────────────────────────────

    @app.get("/api/starter/status")
    def starter_status() -> dict:
        """Per-sprite variant counts for the starter set, plus a status histogram."""
        from engine.starter import starter_sprite_ids

        ids = starter_sprite_ids()
        placeholders = ",".join("?" * len(ids))
        # `good` counts positively rated variants; NULL ratings compare as NULL
        # and fall through to ELSE 0.
        rows = registry.conn.execute(
            f"""
            SELECT s.id, s.status, s.category,
                   COUNT(v.id) AS variants,
                   SUM(CASE WHEN v.job_status='completed' THEN 1 ELSE 0 END) AS completed,
                   SUM(CASE WHEN v.rating > 0 THEN 1 ELSE 0 END) AS good
            FROM sprites s
            LEFT JOIN variants v ON v.sprite_id = s.id
            WHERE s.id IN ({placeholders})
            GROUP BY s.id, s.status, s.category
            ORDER BY s.id
            """,
            ids,
        ).fetchall()
        by_status: dict[str, int] = {}
        sprites = []
        for r in rows:
            by_status[r["status"]] = by_status.get(r["status"], 0) + 1
            sprites.append(dict(r))
        return {"total": len(ids), "by_status": by_status, "sprites": sprites}

    @app.post("/api/starter/register")
    def starter_register(body: StarterRegisterRequest = StarterRegisterRequest()) -> dict:
        """Register the starter sprite set in the registry."""
        from engine.starter import register_starter_set

        report = register_starter_set(registry, reset_status=body.reset_status)
        return {
            "registered": len(report.registered),
            "reset": report.reset,
            "missing_data": report.missing_data,
        }

    # ── Generation trigger ────────────────────────────────────────────────

    def _sprites_for_generate(body: GenerateRequest) -> tuple[list[dict], str | None]:
        """Resolve sprite list for generate. Returns (sprites, error_message)."""
        if body.sprite_id:
            sprite = registry.get_sprite(body.sprite_id)
            if not sprite:
                return [], f"Sprite not found: {body.sprite_id}"
            if sprite["status"] in ("skip", "rejected"):
                return [], f"Cannot generate for status '{sprite['status']}'"
            return [sprite], None

        statuses = ("needed", "review") if body.include_review else ("needed",)
        sprites: list[dict] = []
        for status in statuses:
            sprites.extend(registry.get_sprites(
                category=body.category,
                status=status,
                limit=body.max_sprites or 10000,
            ))
        if body.include_review:
            # Only sprites still missing gate-passing variants
            dashboard = registry.get_pipeline_dashboard()
            deficit_ids = {
                r["sprite_id"] for r in dashboard["sprite_coverage"] if r["deficit"] > 0
            }
            sprites = [s for s in sprites if s["id"] in deficit_ids]
        if body.max_sprites:
            sprites = sprites[: body.max_sprites]
        if not sprites:
            return [], "No eligible sprites to generate"
        return sprites, None

    @app.post("/api/generate")
    async def trigger_generate(body: GenerateRequest) -> dict:
        """Submit generation jobs for one sprite or a filtered batch."""
        from engine.factory import create_generator, with_backend
        from engine.ranker import SpriteRanker

        config = _load_config()
        if body.backend:
            config = with_backend(config, body.backend)
        gen = create_generator(config=config, registry=registry, raw_dir=raw_dir)

        sprites, err = _sprites_for_generate(body)
        if not sprites:
            return {"submitted": 0, "sprites": 0, "message": err}

        sprite_ids = [s["id"] for s in sprites]
        ranker = SpriteRanker(registry=registry, raw_dir=raw_dir)

        async def _on_complete(variant_id: int, sprite_id: str) -> None:
            await ranker.advance_sprite(sprite_id)

        # Only model-boss completes asynchronously; grok is ranked right after submit below.
        on_complete = _on_complete if config["backend"] == "model-boss" else None
        submitted = await gen.submit_batch(
            sprite_ids=sprite_ids,
            variants_per=body.variants,
            priority=body.priority,
            on_complete=on_complete,
        )

        if config["backend"] == "grok" and submitted:
            for sid in sprite_ids:
                await ranker.rank_and_filter(sid)

        return {
            "submitted": submitted,
            "sprites": len(sprite_ids),
            "backend": config["backend"],
            "message": None if submitted else "Submission failed — check server logs",
        }

    # ── Image serving ─────────────────────────────────────────────────────

    @app.get("/images/raw/{file_path:path}")
    def serve_raw(file_path: str) -> FileResponse:
        """Serve a raw PNG from *raw_dir*; paths escaping it are treated as missing."""
        full = _file_within(raw_dir, file_path)
        if full is None:
            raise HTTPException(404, f"Image not found: {file_path}")
        return FileResponse(full, media_type="image/png")

    @app.get("/images/variants/{file_path:path}")
    def serve_variant(file_path: str) -> FileResponse:
        """Serve a variant PNG from *variants_dir*; paths escaping it are treated as missing."""
        full = _file_within(variants_dir, file_path)
        if full is None:
            raise HTTPException(404, f"Image not found: {file_path}")
        return FileResponse(full, media_type="image/png")

    # ── Static GUI (SPA fallback) ────────────────────────────────────────

    if GUI_DIST.exists():
        index_html = GUI_DIST / "index.html"
        gui_assets = GUI_DIST / "assets"

        # StaticFiles raises at construction if its directory is missing.
        if gui_assets.is_dir():
            app.mount("/assets", StaticFiles(directory=str(gui_assets)), name="assets")

        @app.get("/{full_path:path}")
        def spa_fallback(full_path: str) -> FileResponse:
            """Serve a built GUI file if it exists, else index.html for client routing."""
            if full_path.startswith(("api/", "images/")):
                raise HTTPException(404)
            candidate = _file_within(GUI_DIST, full_path)
            return FileResponse(candidate or index_html)

    return app
|
|
|
|
|
|
# Module-level ASGI app so `uvicorn server:app` (see the module docstring) can import it.
# Built at import time with the default registry opened on DB_PATH.
app = create_app()
|