454 lines
17 KiB
Python
454 lines
17 KiB
Python
"""FastAPI server for the sprite generation review GUI.

Serves the REST API for browsing sprites, reviewing variants, and approving/rejecting.
Also serves raw/variant images and the static GUI build.

Usage:
    uvicorn server:app --port 5801 --reload
    # or via CLI:
    python3 cli.py review --port 5801
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from engine.registry import SpriteRegistry
|
|
|
|
# Filesystem layout. Every path is anchored at the directory holding this module,
# so the server works regardless of the current working directory.
TOOL_DIR = Path(__file__).resolve().parent
DB_PATH = TOOL_DIR.joinpath("spritegen.db")  # SQLite file for the default SpriteRegistry
RAW_DIR = TOOL_DIR.joinpath("raw")  # default root for /images/raw/...
VARIANTS_DIR = TOOL_DIR.joinpath("variants")  # default root for /images/variants/...
GUI_DIST = TOOL_DIR.joinpath("gui", "dist")  # static GUI build (served only if it exists)
|
|
|
|
|
|
class ApproveRequest(BaseModel):
    """Request body for the sprite/queue approve endpoints."""

    # ID of the variant to approve; forwarded to registry.approve_variant.
    variant_id: int
    # NOTE(review): accepted but not read by the approve endpoints in this module.
    dimension_id: int | None = None
|
|
|
|
|
|
class RejectRequest(BaseModel):
    """Request body for rejecting a sprite (POST /api/sprites/{id}/reject)."""

    # Forwarded to registry.reject_sprite; presumably None means "not limited to
    # one dimension" — TODO confirm against SpriteRegistry.reject_sprite.
    dimension_id: int | None = None
|
|
|
|
|
|
class RejectVariantRequest(BaseModel):
    """Optional request body for rejecting a single variant."""

    # Free-text reason forwarded to registry.reject_variant (may be omitted).
    reason: str | None = None
|
|
|
|
|
|
class RegenerateRequest(BaseModel):
    """Request body for resetting a sprite so it will be regenerated."""

    # If non-empty, replaces the sprite's stored prompt before the reset.
    prompt: str | None = None
    # Forwarded to registry.reject_sprite, like RejectRequest.dimension_id.
    dimension_id: int | None = None
    # NOTE(review): not read by the regenerate endpoint in this module.
    variants: int = 8
|
|
|
|
|
|
class PromptUpdate(BaseModel):
    """Request body for PUT /api/sprites/{id}/prompt."""

    # New prompt text stored verbatim in the sprite's `prompt` column.
    prompt: str
|
|
|
|
|
|
class PipelineRunRequest(BaseModel):
    """Options for spawning tools/advance.py; each maps onto a CLI flag."""

    # Passed as `--scorers <name> ...` when non-empty.
    scorers: list[str] | None = None
    # Adds `--skip-process`.
    skip_process: bool = False
    # Adds `--score-only`.
    score_only: bool = False
    # Adds `--rescore`.
    rescore: bool = False
|
|
|
|
|
|
class PinSeedRequest(BaseModel):
    """Request body for manually pinning a seed into the seed pool."""

    # Generation seed to pin.
    seed: int
    # Category the seed is pinned for.
    category: str
    # Optional entity the seed is associated with ("" when not entity-specific).
    entity_id: str = ""
    # Optional variant the seed was taken from.
    variant_id: int | None = None
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for POST /api/generate (submit jobs for 'needed' sprites)."""

    # Restrict candidates to one category (None = all categories).
    category: str | None = None
    # Restrict submission to this single sprite id.
    sprite_id: str | None = None
    # Variants to request per sprite (passed to submit_batch as variants_per).
    variants: int = 8
    # Passed through to SpriteGenerator.submit_batch unchanged.
    priority: str = "normal"
    # Cap on the number of sprites submitted; None/0 means up to 10000.
    max_sprites: int | None = None
|
|
|
|
|
|
def _resolve_within(base: Path, rel: str) -> Path | None:
    """Return the regular file ``base / rel``, or None if it is unsafe or missing.

    ``rel`` comes straight from a ``{...:path}`` URL segment. An absolute value
    would replace ``base`` entirely when joined (``Path("/a") / "/etc/x"`` is
    ``/etc/x``), and ``..`` components climb out of it. Both are rejected
    lexically, so symlinks the operator placed inside ``base`` keep working.

    Args:
        base: Directory the file must live under.
        rel: Caller-supplied relative path.

    Returns:
        The joined path if it names an existing regular file, else None.
    """
    rel_path = Path(rel)
    # `anchor` is non-empty for absolute paths and for Windows drive/UNC prefixes.
    if rel_path.anchor or ".." in rel_path.parts:
        return None
    candidate = base / rel_path
    # is_file() (unlike exists()) also rejects directories, which FileResponse
    # cannot send.
    return candidate if candidate.is_file() else None


def _read_log_tail(path: Path, max_lines: int = 20) -> str:
    """Return the last ``max_lines`` lines of ``path``, or "" if it does not exist.

    The log is written concurrently by the pipeline subprocess. A read can
    therefore land in the middle of a multi-byte character, so decode errors are
    replaced instead of raised. The encoding is left at the locale default, which
    is what the child's redirected stdout uses.
    """
    try:
        text = path.read_text(errors="replace")
    except FileNotFoundError:
        return ""
    return "\n".join(text.splitlines()[-max_lines:])


def create_app(
    registry: SpriteRegistry | None = None,
    raw_dir: Path = RAW_DIR,
    variants_dir: Path = VARIANTS_DIR,
) -> FastAPI:
    """Build the review-GUI FastAPI application.

    Every route handler is a closure over ``registry``, ``raw_dir`` and
    ``variants_dir``, so callers (e.g. tests) can inject their own.

    Args:
        registry: Registry to serve. Defaults to ``SpriteRegistry(DB_PATH)``.
        raw_dir: Root directory served under ``/images/raw/``.
        variants_dir: Root directory served under ``/images/variants/``.

    Returns:
        The configured FastAPI application.
    """
    import threading

    if registry is None:
        registry = SpriteRegistry(DB_PATH)

    app = FastAPI(title="Magic Civilization Sprite Generator", version="1.0.0")

    # NOTE(review): "*" lets any web page the reviewer has open call this API,
    # including the endpoints that spawn processes or submit generation jobs.
    # That is acceptable for a loopback-only tool; restrict origins before
    # exposing the server any further.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # ── Shared helpers ────────────────────────────────────────────────────

    def _require_sprite(sprite_id: str) -> dict:
        """Return the sprite record, or raise HTTP 404 if it does not exist."""
        sprite = registry.get_sprite(sprite_id)
        if not sprite:
            raise HTTPException(404, f"Sprite not found: {sprite_id}")
        return sprite

    def _approve(sprite_id: str, variant_id: int) -> dict:
        """Approve ``variant_id`` for ``sprite_id``; shared by both approve routes.

        Raises:
            HTTPException: 404 if the sprite or the variant does not exist, or
                400 if the variant belongs to a different sprite.
        """
        _require_sprite(sprite_id)
        variant = registry.get_variant(variant_id)
        if not variant:
            raise HTTPException(404, f"Variant not found: {variant_id}")
        # Without this check, any variant id would be approved no matter which
        # sprite the URL names.
        if str(variant["sprite_id"]) != sprite_id:
            raise HTTPException(
                400, f"Variant {variant_id} does not belong to sprite {sprite_id}"
            )
        registry.approve_variant(variant_id)
        return {"status": "approved", "variant_id": variant_id}

    def _set_prompt(sprite_id: str, prompt: str) -> None:
        """Store a new prompt on the sprite row, bump updated_at, and commit."""
        registry.conn.execute(
            "UPDATE sprites SET prompt = ?, updated_at = datetime('now') WHERE id = ?",
            (prompt, sprite_id),
        )
        registry.conn.commit()

    # ── Sprites ───────────────────────────────────────────────────────────

    @app.get("/api/sprites")
    def list_sprites(
        category: Annotated[str | None, Query()] = None,
        status: Annotated[str | None, Query()] = None,
        search: Annotated[str | None, Query()] = None,
        limit: Annotated[int, Query(ge=1, le=10000)] = 200,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> list[dict]:
        """List sprites, optionally filtered by category, status and search text."""
        return registry.get_sprites(
            category=category, status=status, search=search,
            limit=limit, offset=offset,
        )

    # Registered before the bare /api/sprites/{sprite_id:path} route: that
    # greedy path parameter would otherwise swallow the "/variants" suffix.
    @app.get("/api/sprites/{sprite_id:path}/variants")
    def get_variants(
        sprite_id: str,
        dimension_id: Annotated[int | None, Query()] = None,
    ) -> list[dict]:
        """List a sprite's variants, optionally for a single dimension."""
        _require_sprite(sprite_id)
        return registry.get_variants(sprite_id, dimension_id=dimension_id)

    @app.get("/api/sprites/{sprite_id:path}")
    def get_sprite(sprite_id: str) -> dict:
        """Return one sprite record."""
        return _require_sprite(sprite_id)

    @app.post("/api/sprites/{sprite_id:path}/approve")
    def approve_sprite(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve one of the sprite's variants."""
        return _approve(sprite_id, body.variant_id)

    @app.post("/api/sprites/{sprite_id:path}/reject")
    def reject_sprite(sprite_id: str, body: RejectRequest) -> dict:
        """Reject a sprite, optionally limited to one dimension."""
        _require_sprite(sprite_id)
        registry.reject_sprite(sprite_id, dimension_id=body.dimension_id)
        return {"status": "rejected"}

    @app.post("/api/sprites/{sprite_id:path}/skip")
    def skip_sprite(sprite_id: str) -> dict:
        """Set the sprite's status to "skip"."""
        _require_sprite(sprite_id)
        registry.update_sprite_status(sprite_id, "skip")
        return {"status": "skip"}

    @app.get("/api/variants/{variant_id:int}")
    def get_variant(variant_id: int) -> dict:
        """Return one variant, adding its sprite's category/entity_id when known."""
        variant = registry.get_variant(variant_id)
        if not variant:
            raise HTTPException(404, f"Variant not found: {variant_id}")
        sprite = registry.get_sprite(variant["sprite_id"])
        if sprite:
            variant["category"] = sprite["category"]
            variant["entity_id"] = sprite["entity_id"]
        return variant

    @app.get("/api/variants/{variant_id:int}/scores")
    def get_variant_scores(variant_id: int) -> list[dict]:
        """Get all per-scorer evaluations for a variant."""
        return registry.get_scores(variant_id)

    @app.post("/api/variants/{variant_id:int}/reject")
    def reject_variant(variant_id: int, body: RejectVariantRequest | None = None) -> dict:
        """Reject a single variant; the request body (with a reason) is optional."""
        if not registry.get_variant(variant_id):
            raise HTTPException(404, f"Variant not found: {variant_id}")
        registry.reject_variant(variant_id, body.reason if body is not None else None)
        return {"status": "rejected", "variant_id": variant_id}

    @app.post("/api/sprites/{sprite_id:path}/regenerate")
    def regenerate_sprite(sprite_id: str, body: RegenerateRequest) -> dict:
        """Optionally replace the prompt, then reject the sprite so it is regenerated.

        NOTE(review): ``body.variants`` is accepted but not used here.
        """
        _require_sprite(sprite_id)
        if body.prompt:
            _set_prompt(sprite_id, body.prompt)
        registry.reject_sprite(sprite_id, dimension_id=body.dimension_id)
        return {"status": "needed", "message": "Ready for regeneration"}

    @app.put("/api/sprites/{sprite_id:path}/prompt")
    def update_prompt(sprite_id: str, body: PromptUpdate) -> dict:
        """Replace a sprite's stored prompt."""
        _require_sprite(sprite_id)
        _set_prompt(sprite_id, body.prompt)
        return {"status": "updated", "prompt": body.prompt}

    # ── Progress & Queue ─────────────────────────────────────────────────

    @app.get("/api/stats")
    def get_stats() -> dict:
        """Return registry-wide statistics."""
        return registry.get_stats()

    @app.get("/api/progress")
    def get_progress() -> dict:
        """Return generation progress as reported by the registry."""
        return registry.get_progress()

    @app.get("/api/pipeline")
    def get_pipeline() -> dict:
        """Return the pipeline dashboard data."""
        return registry.get_pipeline_dashboard()

    # Most recently spawned advance.py subprocess. Sync handlers run in
    # FastAPI's threadpool, so the check-then-spawn in run_pipeline holds a lock.
    # Otherwise two concurrent requests could both see no live process and
    # start two runs writing to the same log.
    pipeline_log = TOOL_DIR / "advance.log"
    _pipeline_lock = threading.Lock()
    _pipeline_proc: dict[str, subprocess.Popen | None] = {"proc": None}

    @app.post("/api/pipeline/run")
    def run_pipeline(body: PipelineRunRequest) -> dict:
        """Spawn tools/advance.py as a background subprocess.

        Output (stdout + stderr) goes to ``advance.log`` in TOOL_DIR, which is
        truncated per run. If a run is still alive, no new one is started.

        Raises:
            HTTPException: 400 if a scorer name looks like a CLI option.
        """
        scorers = body.scorers or []
        # The names are spliced into argv; a leading "-" would be parsed by
        # advance.py as an extra flag rather than a scorer name.
        if any(name.startswith("-") for name in scorers):
            raise HTTPException(400, "Scorer names must not start with '-'")

        cmd = [sys.executable, "-u", str(TOOL_DIR / "tools" / "advance.py")]
        if scorers:
            cmd += ["--scorers", *scorers]
        if body.rescore:
            cmd.append("--rescore")
        if body.skip_process:
            cmd.append("--skip-process")
        if body.score_only:
            cmd.append("--score-only")

        with _pipeline_lock:
            proc = _pipeline_proc["proc"]
            if proc is not None and proc.poll() is None:
                return {"status": "already_running", "pid": proc.pid}

            # The child inherits its own copy of the descriptor, so the
            # parent's handle can (and must) be closed right after spawning.
            with open(pipeline_log, "wb") as log_file:
                proc = subprocess.Popen(
                    cmd,
                    stdout=log_file,
                    stderr=subprocess.STDOUT,
                    cwd=str(TOOL_DIR),
                )
            _pipeline_proc["proc"] = proc
        return {"status": "started", "pid": proc.pid}

    @app.get("/api/pipeline/status")
    def pipeline_run_status() -> dict:
        """Check if the advance.py process is still running, with the last 20 log lines."""
        proc = _pipeline_proc["proc"]
        if proc is None:
            return {"running": False, "pid": None, "returncode": None, "log_tail": ""}

        log_tail = _read_log_tail(pipeline_log)
        poll = proc.poll()
        return {
            "running": poll is None,
            "pid": proc.pid,
            "returncode": poll,
            "log_tail": log_tail,
        }

    @app.get("/api/queue")
    def get_review_queue(
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
    ) -> list[dict]:
        """Return up to ``limit`` entries of the review queue."""
        return registry.get_review_queue(limit=limit)

    @app.post("/api/queue/{sprite_id:path}/approve")
    def approve_from_queue(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve a variant from the review queue (same rules as approve_sprite)."""
        return _approve(sprite_id, body.variant_id)

    @app.get("/api/variants/recent")
    def get_recent_variants(
        limit: Annotated[int, Query(ge=1, le=5000)] = 30,
    ) -> list[dict]:
        """Return the most recent variants."""
        return registry.get_recent_variants(limit=limit)

    @app.get("/api/variants/review")
    def get_review_variants(
        limit: Annotated[int, Query(ge=1, le=5000)] = 500,
    ) -> list[dict]:
        """Return variants awaiting review."""
        return registry.get_review_variants(limit=limit)

    @app.get("/api/terrain-grid")
    def terrain_grid(
        elevation: Annotated[str, Query(pattern=r"^(lowland|highland|alpine)$")] = "lowland",
    ) -> dict:
        """Biome grid cells for one elevation tier, with best variant per cell."""
        cells = registry.get_terrain_grid(elevation)
        return {"cells": cells}

    @app.get("/api/theater")
    def theater_variants(
        mode: Annotated[str, Query(pattern=r"^(all|review)$")] = "all",
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> dict:
        """Paginated, server-filtered variants for the Sprite Theater UI."""
        items, total = registry.query_variants(mode=mode, limit=limit, offset=offset)
        return {"items": items, "total": total}

    @app.get("/api/variants/browse")
    def browse_variants(
        stage: Annotated[str, Query(pattern=r"^(all|completed|review|processed|approved|installed|scored_[a-z0-9_]+)$")] = "completed",
        limit: Annotated[int, Query(ge=1, le=500)] = 120,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> dict:
        """Paginated variant browse by funnel stage — used by the dashboard funnel click-throughs."""
        items, total = registry.query_variants(mode=stage, limit=limit, offset=offset)
        return {"items": items, "total": total}

    @app.get("/api/stream/variants")
    async def stream_variants() -> StreamingResponse:
        """Server-Sent Events feed of new variants and recent rating changes.

        Polls every 3 s. New variants are pushed immediately. Every 5th idle
        poll, the 30 most recent variants are re-sent if any rating changed.
        Otherwise a keepalive comment is sent.
        """

        async def event_generator():
            # Own connection to avoid SQLite threading collisions with main registry.
            # NOTE(review): this opens DB_PATH, not an injected registry's DB.
            sse_registry = SpriteRegistry(DB_PATH)
            try:
                current = sse_registry.get_recent_variants(limit=1)
                last_id = current[0]["variant_id"] if current else 0
                last_rating_snapshot = ""
                poll_count = 0
                while True:
                    await asyncio.sleep(3)
                    new_variants = sse_registry.get_recent_variants(limit=10, since_id=last_id)
                    if new_variants:
                        last_id = max(v["variant_id"] for v in new_variants)
                        yield f"data: {json.dumps(new_variants)}\n\n"
                        continue

                    # Check for rating updates every ~15s (5 cycles)
                    poll_count += 1
                    if poll_count % 5 == 0:
                        updated = sse_registry.get_recent_variants(limit=30)
                        snapshot = "|".join(
                            f"{v['variant_id']}:{v['rating']}" for v in updated
                        )
                        if snapshot != last_rating_snapshot:
                            last_rating_snapshot = snapshot
                            yield f"data: {json.dumps(updated)}\n\n"
                            continue

                    yield ": keepalive\n\n"
            finally:
                # Runs when the client disconnects and the generator is
                # cancelled/closed; without it each SSE client leaked one
                # database connection.
                sse_registry.conn.close()

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    @app.get("/api/runs")
    def get_runs() -> list[dict]:
        """Return the recorded generation runs."""
        return registry.get_runs()

    # ── Analytics + Seed pool ─────────────────────────────────────────────

    @app.get("/api/analytics/quality")
    def quality_analytics(
        category: Annotated[str | None, Query()] = None,
        scorer: Annotated[str | None, Query()] = None,
    ) -> dict:
        """Quality dimension averages and gate failure rates across scored variants."""
        return registry.get_quality_analytics(category=category, scorer=scorer)

    @app.get("/api/seeds")
    def list_seeds(
        category: Annotated[str | None, Query()] = None,
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
    ) -> list[dict]:
        """Top seeds from the pool for a category."""
        return registry.get_seed_pool_report(category=category, limit=limit)

    @app.post("/api/seeds/pin")
    def pin_seed(body: PinSeedRequest) -> dict:
        """Manually pin a seed to the pool with max quality (user override)."""
        registry.pin_seed(
            seed=body.seed,
            category=body.category,
            entity_id=body.entity_id,
            variant_id=body.variant_id,
        )
        return {"status": "pinned", "seed": body.seed, "category": body.category}

    # ── Generation trigger ────────────────────────────────────────────────

    @app.post("/api/generate")
    async def trigger_generate(body: GenerateRequest) -> dict:
        """Submit generation jobs for sprites currently in 'needed' status.

        Candidates are filtered by ``category`` and, if given, by ``sprite_id``.
        ``max_sprites`` caps how many are submitted (None/0 means up to 10000).
        """
        from engine.generator import SpriteGenerator

        query_cap = 10000
        # With a specific sprite_id, max_sprites must not become the query
        # limit. The target could fall outside the first N rows and be
        # reported as "nothing needed".
        limit = query_cap if body.sprite_id else (body.max_sprites or query_cap)
        sprites = registry.get_sprites(
            category=body.category,
            status="needed",
            limit=limit,
        )
        if body.sprite_id:
            sprites = [s for s in sprites if s["id"] == body.sprite_id]

        if not sprites:
            return {"submitted": 0, "message": "No sprites in 'needed' status"}

        # Config/generator are only needed once there is something to submit.
        config = json.loads((TOOL_DIR / "sprite-config.json").read_text(encoding="utf-8"))
        gen = SpriteGenerator(config=config, registry=registry, raw_dir=raw_dir)

        sprite_ids = [s["id"] for s in sprites]
        submitted = await gen.submit_batch(
            sprite_ids=sprite_ids,
            variants_per=body.variants,
            priority=body.priority,
        )
        return {"submitted": submitted, "sprites": len(sprite_ids)}

    # ── Image serving ─────────────────────────────────────────────────────

    @app.get("/images/raw/{file_path:path}")
    def serve_raw(file_path: str) -> FileResponse:
        """Serve a file from ``raw_dir`` as image/png; 404 if absent or unsafe."""
        full = _resolve_within(raw_dir, file_path)
        if full is None:
            raise HTTPException(404, f"Image not found: {file_path}")
        return FileResponse(full, media_type="image/png")

    @app.get("/images/variants/{file_path:path}")
    def serve_variant(file_path: str) -> FileResponse:
        """Serve a file from ``variants_dir`` as image/png; 404 if absent or unsafe."""
        full = _resolve_within(variants_dir, file_path)
        if full is None:
            raise HTTPException(404, f"Image not found: {file_path}")
        return FileResponse(full, media_type="image/png")

    # ── Static GUI (SPA fallback) ────────────────────────────────────────

    if GUI_DIST.exists():
        index_html = GUI_DIST / "index.html"
        assets_dir = GUI_DIST / "assets"

        # StaticFiles raises at construction if the directory is missing, which
        # would make a partial GUI build crash the whole server.
        if assets_dir.is_dir():
            app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")

        # Registered last so it only catches paths no API route matched.
        @app.get("/{full_path:path}")
        def spa_fallback(full_path: str) -> FileResponse:
            """Serve a built GUI file if one matches, otherwise index.html."""
            if full_path.startswith(("api/", "images/")):
                raise HTTPException(404)
            if full_path:
                candidate = _resolve_within(GUI_DIST, full_path)
                if candidate is not None:
                    return FileResponse(candidate)
            if not index_html.is_file():
                raise HTTPException(404, "GUI build has no index.html")
            return FileResponse(index_html)

    return app
|
|
|
|
|
|
# Module-level application for `uvicorn server:app`. Building it at import time
# constructs the default SpriteRegistry(DB_PATH) via create_app().
app: FastAPI = create_app()
|