"""FastAPI server for the sprite generation review GUI. Serves the REST API for browsing sprites, reviewing variants, and approving/rejecting. Also serves raw/variant images and the static GUI build. Usage: uvicorn server:app --port 5801 --reload # or via CLI: python3 cli.py review --port 5801 """ from __future__ import annotations import asyncio import json import subprocess import sys from pathlib import Path from typing import Annotated from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from engine.registry import SpriteRegistry TOOL_DIR = Path(__file__).resolve().parent DB_PATH = TOOL_DIR / "spritegen.db" RAW_DIR = TOOL_DIR / "raw" VARIANTS_DIR = TOOL_DIR / "variants" GUI_DIST = TOOL_DIR / "gui" / "dist" class ApproveRequest(BaseModel): variant_id: int dimension_id: int | None = None class RejectRequest(BaseModel): dimension_id: int | None = None class RejectVariantRequest(BaseModel): reason: str | None = None class RegenerateRequest(BaseModel): prompt: str | None = None dimension_id: int | None = None variants: int = 8 class PromptUpdate(BaseModel): prompt: str class PipelineRunRequest(BaseModel): scorers: list[str] | None = None skip_process: bool = False score_only: bool = False rescore: bool = False class PinSeedRequest(BaseModel): seed: int category: str entity_id: str = "" variant_id: int | None = None class GenerateRequest(BaseModel): category: str | None = None sprite_id: str | None = None variants: int = 8 priority: str = "normal" max_sprites: int | None = None def create_app( registry: SpriteRegistry | None = None, raw_dir: Path = RAW_DIR, variants_dir: Path = VARIANTS_DIR, ) -> FastAPI: if registry is None: registry = SpriteRegistry(DB_PATH) app = FastAPI(title="Magic Civilization Sprite Generator", version="1.0.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ── Sprites ─────────────────────────────────────────────────────────── @app.get("/api/sprites") def list_sprites( category: Annotated[str | None, Query()] = None, status: Annotated[str | None, Query()] = None, search: Annotated[str | None, Query()] = None, limit: Annotated[int, Query(ge=1, le=10000)] = 200, offset: Annotated[int, Query(ge=0)] = 0, ) -> list[dict]: return registry.get_sprites( category=category, status=status, search=search, limit=limit, offset=offset, ) @app.get("/api/sprites/{sprite_id:path}/variants") def get_variants( sprite_id: str, dimension_id: Annotated[int | None, Query()] = None, ) -> list[dict]: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") return registry.get_variants(sprite_id, dimension_id=dimension_id) @app.get("/api/sprites/{sprite_id:path}") def get_sprite(sprite_id: str) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") return sprite @app.post("/api/sprites/{sprite_id:path}/approve") def approve_sprite(sprite_id: str, body: ApproveRequest) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") registry.approve_variant(body.variant_id) return {"status": "approved", "variant_id": body.variant_id} @app.post("/api/sprites/{sprite_id:path}/reject") def reject_sprite(sprite_id: str, body: RejectRequest) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") registry.reject_sprite(sprite_id, dimension_id=body.dimension_id) return {"status": "rejected"} @app.post("/api/sprites/{sprite_id:path}/skip") def skip_sprite(sprite_id: str) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") registry.update_sprite_status(sprite_id, "skip") return {"status": "skip"} @app.get("/api/variants/{variant_id:int}") def get_variant(variant_id: int) -> dict: variant = registry.get_variant(variant_id) if not variant: raise HTTPException(404, f"Variant not found: {variant_id}") sprite = registry.get_sprite(variant["sprite_id"]) if sprite: variant["category"] = sprite["category"] variant["entity_id"] = sprite["entity_id"] return variant @app.get("/api/variants/{variant_id:int}/scores") def get_variant_scores(variant_id: int) -> list[dict]: """Get all per-scorer evaluations for a variant.""" return registry.get_scores(variant_id) @app.post("/api/variants/{variant_id:int}/reject") def reject_variant(variant_id: int, body: RejectVariantRequest = RejectVariantRequest()) -> dict: registry.reject_variant(variant_id, body.reason) return {"status": "rejected", "variant_id": variant_id} @app.post("/api/sprites/{sprite_id:path}/regenerate") def regenerate_sprite(sprite_id: str, body: RegenerateRequest) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") if body.prompt: registry.conn.execute( "UPDATE sprites SET prompt = ?, updated_at = datetime('now') WHERE id = ?", (body.prompt, sprite_id), ) registry.conn.commit() registry.reject_sprite(sprite_id, dimension_id=body.dimension_id) return {"status": "needed", "message": "Ready for regeneration"} @app.put("/api/sprites/{sprite_id:path}/prompt") def update_prompt(sprite_id: str, body: PromptUpdate) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") registry.conn.execute( "UPDATE sprites SET prompt = ?, updated_at = datetime('now') WHERE id = ?", (body.prompt, sprite_id), ) registry.conn.commit() return {"status": "updated", "prompt": body.prompt} # ── Progress & Queue ───────────────────────────────────────────────── @app.get("/api/stats") def get_stats() -> dict: return registry.get_stats() @app.get("/api/progress") def get_progress() -> dict: return registry.get_progress() @app.get("/api/pipeline") def get_pipeline() -> dict: return registry.get_pipeline_dashboard() # Track the advance.py subprocess _pipeline_proc: dict[str, subprocess.Popen | None] = {"proc": None} @app.post("/api/pipeline/run") def run_pipeline(body: PipelineRunRequest) -> dict: """Spawn tools/advance.py as a background subprocess.""" proc = _pipeline_proc["proc"] if proc is not None and proc.poll() is None: return {"status": "already_running", "pid": proc.pid} cmd = [sys.executable, "-u", str(TOOL_DIR / "tools" / "advance.py")] if body.scorers: cmd += ["--scorers"] + body.scorers if body.rescore: cmd.append("--rescore") if body.skip_process: cmd.append("--skip-process") if body.score_only: cmd.append("--score-only") log_path = TOOL_DIR / "advance.log" log_file = open(log_path, "w") proc = subprocess.Popen( cmd, stdout=log_file, stderr=subprocess.STDOUT, cwd=str(TOOL_DIR), ) _pipeline_proc["proc"] = proc return {"status": "started", "pid": proc.pid} @app.get("/api/pipeline/status") def pipeline_run_status() -> dict: """Check if the advance.py process is still running.""" proc = _pipeline_proc["proc"] if proc is None: return {"running": False, "pid": None, "returncode": None, "log_tail": ""} log_path = TOOL_DIR / "advance.log" log_tail = "" if log_path.exists(): lines = log_path.read_text().splitlines() log_tail = "\n".join(lines[-20:]) poll = proc.poll() return { "running": poll is None, "pid": proc.pid, "returncode": poll, "log_tail": log_tail, } @app.get("/api/queue") def get_review_queue( limit: Annotated[int, Query(ge=1, le=200)] = 50, ) -> list[dict]: return registry.get_review_queue(limit=limit) @app.post("/api/queue/{sprite_id:path}/approve") def approve_from_queue(sprite_id: str, body: ApproveRequest) -> dict: sprite = registry.get_sprite(sprite_id) if not sprite: raise HTTPException(404, f"Sprite not found: {sprite_id}") registry.approve_variant(body.variant_id) return {"status": "approved", "variant_id": body.variant_id} @app.get("/api/variants/recent") def get_recent_variants( limit: Annotated[int, Query(ge=1, le=5000)] = 30, ) -> list[dict]: return registry.get_recent_variants(limit=limit) @app.get("/api/variants/review") def get_review_variants( limit: Annotated[int, Query(ge=1, le=5000)] = 500, ) -> list[dict]: return registry.get_review_variants(limit=limit) @app.get("/api/terrain-grid") def terrain_grid( elevation: Annotated[str, Query(pattern=r"^(lowland|highland|alpine)$")] = "lowland", ) -> dict: """Biome grid cells for one elevation tier, with best variant per cell.""" cells = registry.get_terrain_grid(elevation) return {"cells": cells} @app.get("/api/theater") def theater_variants( mode: Annotated[str, Query(pattern=r"^(all|review)$")] = "all", limit: Annotated[int, Query(ge=1, le=200)] = 50, offset: Annotated[int, Query(ge=0)] = 0, ) -> dict: """Paginated, server-filtered variants for the Sprite Theater UI.""" items, total = registry.query_variants(mode=mode, limit=limit, offset=offset) return {"items": items, "total": total} @app.get("/api/variants/browse") def browse_variants( stage: Annotated[str, Query(pattern=r"^(all|completed|review|processed|approved|installed|scored_[a-z0-9_]+)$")] = "completed", limit: Annotated[int, Query(ge=1, le=500)] = 120, offset: Annotated[int, Query(ge=0)] = 0, ) -> dict: """Paginated variant browse by funnel stage — used by the dashboard funnel click-throughs.""" items, total = registry.query_variants(mode=stage, limit=limit, offset=offset) return {"items": items, "total": total} @app.get("/api/stream/variants") async def stream_variants() -> StreamingResponse: async def event_generator(): # Own connection to avoid SQLite threading collisions with main registry sse_registry = SpriteRegistry(DB_PATH) current = sse_registry.get_recent_variants(limit=1) last_id = current[0]["variant_id"] if current else 0 last_rating_snapshot = "" poll_count = 0 while True: await asyncio.sleep(3) new_variants = sse_registry.get_recent_variants(limit=10, since_id=last_id) if new_variants: last_id = max(v["variant_id"] for v in new_variants) yield f"data: {json.dumps(new_variants)}\n\n" continue # Check for rating updates every ~15s (5 cycles) poll_count += 1 if poll_count % 5 == 0: updated = sse_registry.get_recent_variants(limit=30) snapshot = "|".join( f"{v['variant_id']}:{v['rating']}" for v in updated ) if snapshot != last_rating_snapshot: last_rating_snapshot = snapshot yield f"data: {json.dumps(updated)}\n\n" continue yield ": keepalive\n\n" return StreamingResponse( event_generator(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) @app.get("/api/runs") def get_runs() -> list[dict]: return registry.get_runs() # ── Analytics + Seed pool ───────────────────────────────────────────── @app.get("/api/analytics/quality") def quality_analytics( category: Annotated[str | None, Query()] = None, scorer: Annotated[str | None, Query()] = None, ) -> dict: """Quality dimension averages and gate failure rates across scored variants.""" return registry.get_quality_analytics(category=category, scorer=scorer) @app.get("/api/seeds") def list_seeds( category: Annotated[str | None, Query()] = None, limit: Annotated[int, Query(ge=1, le=200)] = 50, ) -> list[dict]: """Top seeds from the pool for a category.""" return registry.get_seed_pool_report(category=category, limit=limit) @app.post("/api/seeds/pin") def pin_seed(body: PinSeedRequest) -> dict: """Manually pin a seed to the pool with max quality (user override).""" registry.pin_seed( seed=body.seed, category=body.category, entity_id=body.entity_id, variant_id=body.variant_id, ) return {"status": "pinned", "seed": body.seed, "category": body.category} # ── Generation trigger ──────────────────────────────────────────────── @app.post("/api/generate") async def trigger_generate(body: GenerateRequest) -> dict: from engine.generator import SpriteGenerator config = json.loads((TOOL_DIR / "sprite-config.json").read_text()) gen = SpriteGenerator(config=config, registry=registry, raw_dir=raw_dir) sprites = registry.get_sprites( category=body.category, status="needed", limit=body.max_sprites or 10000, ) if body.sprite_id: sprites = [s for s in sprites if s["id"] == body.sprite_id] if not sprites: return {"submitted": 0, "message": "No sprites in 'needed' status"} sprite_ids = [s["id"] for s in sprites] submitted = await gen.submit_batch( sprite_ids=sprite_ids, variants_per=body.variants, priority=body.priority, ) return {"submitted": submitted, "sprites": len(sprite_ids)} # ── Image serving ───────────────────────────────────────────────────── @app.get("/images/raw/{file_path:path}") def serve_raw(file_path: str) -> FileResponse: full = raw_dir / file_path if not full.exists(): raise HTTPException(404, f"Image not found: {file_path}") return FileResponse(full, media_type="image/png") @app.get("/images/variants/{file_path:path}") def serve_variant(file_path: str) -> FileResponse: full = variants_dir / file_path if not full.exists(): raise HTTPException(404, f"Image not found: {file_path}") return FileResponse(full, media_type="image/png") # ── Static GUI (SPA fallback) ──────────────────────────────────────── if GUI_DIST.exists(): index_html = GUI_DIST / "index.html" app.mount("/assets", StaticFiles(directory=str(GUI_DIST / "assets")), name="assets") @app.get("/{full_path:path}") def spa_fallback(full_path: str) -> FileResponse: if full_path.startswith("api/") or full_path.startswith("images/"): raise HTTPException(404) candidate = GUI_DIST / full_path if full_path and candidate.is_file(): return FileResponse(candidate) return FileResponse(index_html) return app # Direct run support app = create_app()