diff --git a/tools/sprite-generation/engine/generator.py b/tools/sprite-generation/engine/generator.py index 44ca08b2..00d77d28 100644 --- a/tools/sprite-generation/engine/generator.py +++ b/tools/sprite-generation/engine/generator.py @@ -1,110 +1,45 @@ -"""Unified model-boss async job client for sprite generation. +"""Sprite generation via model-boss InferenceClient. -Submits diffusion jobs, polls for results, and downloads completed images. -Uses only stdlib (urllib, json, base64) -- no requests dependency. +Uses the same pattern as auto-commit-service: InferenceClient submits work +to model-boss coordinator via HTTP, waits for results via Redis pubsub. +No polling. model-boss queues internally and processes when GPU is available. """ - from __future__ import annotations +import asyncio import base64 import json +import logging import random -import time from pathlib import Path -from urllib.error import URLError -from urllib.request import Request, urlopen + +from model_boss import InferenceClient from engine.prompts import get_generation_size, get_negative, get_variant_modifier from engine.registry import SpriteRegistry +logger = logging.getLogger(__name__) + +# Max concurrent generation requests (model-boss queues them, but this +# controls how many we're awaiting pubsub results for simultaneously) +MAX_CONCURRENT = 8 + class SpriteGenerator: def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path): - self.api_base = config["api_base"].rstrip("/") self.model = config["model"] self.defaults = config["defaults"] self.registry = registry self.raw_dir = raw_dir self.raw_dir.mkdir(parents=True, exist_ok=True) + self._client = InferenceClient( + client_id="sprite-generator", + default_priority="normal", + timeout=600.0, + ) + self._semaphore = asyncio.Semaphore(MAX_CONCURRENT) - # -- public API ------------------------------------------------------------ - - def generate_sprites( - self, - sprite_ids: list[str], - variants_per: int = 8, - priority: str = "normal", - dimension_ids: list[int] | None = None, - ) -> int: - """Submit generation jobs for a list of sprites. - - If dimension_ids is provided, generates variants for those specific - dimensions instead of the base sprite. Otherwise generates base variants. - - Returns total jobs submitted. - """ - total_submitted = 0 - run_id = self.registry.start_run(variants_per=variants_per) - - for sprite_id in sprite_ids: - sprite = self.registry.get_sprite(sprite_id) - if sprite is None: - print(f"[skip] Sprite not found: {sprite_id}") - continue - - prompt = sprite["prompt"] or "" - negative = sprite["negative_prompt"] or get_negative(sprite["category"]) - width = sprite["gen_width"] or 1024 - height = sprite["gen_height"] or 512 - - targets: list[tuple[int | None, str]] = [] - - if dimension_ids is not None: - for dim in sprite.get("dimensions_list", []): - if dim["id"] in dimension_ids: - dim_modifier = dim.get("prompt_modifier") or "" - targets.append((dim["id"], dim_modifier)) - else: - targets.append((None, "")) - - for dim_id, dim_modifier in targets: - for i in range(variants_per): - seed = random.randint(0, 2**32 - 1) - variant_modifier = get_variant_modifier(i) - combined_modifier = ", ".join( - p for p in [dim_modifier, variant_modifier] if p - ) - - variant_id = self.submit_one( - sprite_id=sprite_id, - prompt=prompt, - negative=negative, - width=width, - height=height, - seed=seed, - prompt_modifier=combined_modifier, - priority=priority, - dimension_id=dim_id, - ) - if variant_id is not None: - total_submitted += 1 - tag = f"{sprite_id} v{i}" - if dim_id is not None: - tag += f" dim={dim_id}" - print( - f"[{total_submitted}] Submitted {tag} (seed={seed})" - ) - - with self.registry.conn: - self.registry.conn.execute( - "UPDATE generation_runs SET total_jobs=? WHERE id=?", - (total_submitted, run_id), - ) - - print(f"\nSubmitted {total_submitted} jobs (run {run_id})") - return total_submitted - - def submit_one( + async def generate_one( self, sprite_id: str, prompt: str, @@ -116,185 +51,170 @@ class SpriteGenerator: priority: str = "normal", dimension_id: int | None = None, ) -> int | None: - """Queue a single generation job via model-boss async jobs API. + """Generate a single sprite image via model-boss. - Submits to the queue and returns immediately. Does NOT wait for - the image — that's handled by poll_pending(). - Returns variant_id or None on submission failure. + Submits to coordinator, waits for result via Redis pubsub. + Saves raw image to disk, creates variant in DB. + Returns variant_id or None on failure. """ full_prompt = ", ".join(p for p in [prompt, prompt_modifier] if p) - job_id = self._submit_job( - prompt=full_prompt, - negative=negative, - width=width, - height=height, - seed=seed, - priority=priority, - ) - if job_id is None: + async with self._semaphore: + try: + result = await self._client.generate_image( + model=self.model, + prompt=full_prompt, + negative_prompt=negative, + width=width, + height=height, + steps=self.defaults.get("steps", 25), + guidance_scale=self.defaults.get("guidance_scale", 7.5), + seed=seed, + priority=priority, + keep_alive=300, + ) + except Exception as exc: + logger.warning("Generation failed for %s: %s", sprite_id, exc) + print(f"[error] {sprite_id}: {exc}") + return None + + # Extract image data — format varies: + # model-boss InferenceClient: {"data": [{"b64_json": "..."}], "model": "..."} + # model-boss HTTP API: {"images": ["..."], "durationMs": N} + image_b64 = None + data_list = result.get("data", []) + if data_list and isinstance(data_list[0], dict): + image_b64 = data_list[0].get("b64_json") + if not image_b64: + images = result.get("images", []) + if images: + image_b64 = images[0] + if not image_b64: + print(f"[error] {sprite_id}: no image data in result (keys: {list(result.keys())})") return None + image_bytes = base64.b64decode(image_b64) if isinstance(image_b64, str) else image_b64 + duration = result.get("durationMs") or result.get("duration_ms") + + # Create variant in DB variant_id = self.registry.add_variant( sprite_id=sprite_id, seed=seed, dimension_id=dimension_id, prompt_modifier=prompt_modifier, - job_id=job_id, + job_id=None, ) + + # Save raw image + safe_name = sprite_id.replace("/", "_") + filename = f"{safe_name}_{variant_id}.png" + raw_path = self.raw_dir / filename + raw_path.write_bytes(image_bytes) + + # Update variant as completed + self.registry.update_variant_status( + variant_id, "completed", + raw_path=str(raw_path), + generation_ms=duration, + ) + return variant_id - def poll_pending(self, max_polls: int = 1000) -> dict: - """Poll all pending/running variants for completion. - - Single pass over pending variants. Downloads and saves completed images. - Returns {"completed": N, "failed": N, "pending": N}. - """ - pending_variants = self.registry.get_pending_variants() - if not pending_variants: - return {"completed": 0, "failed": 0, "pending": 0} - - completed = 0 - failed = 0 - still_pending = 0 - - count = min(len(pending_variants), max_polls) - for variant in pending_variants[:count]: - job_id = variant["job_id"] - if not job_id: - failed += 1 - self.registry.update_variant_status(variant["id"], "failed") - continue - - status_data = self._check_job(job_id) - if status_data is None: - still_pending += 1 - continue - - job_status = status_data.get("status", "pending") - - if job_status == "completed": - dl = self._download_result(job_id) - if dl is not None: - image_bytes, duration = dl - safe_name = variant["sprite_id"].replace("/", "_") - filename = f"{safe_name}_{variant['id']}.png" - raw_path = self.raw_dir / filename - raw_path.write_bytes(image_bytes) - - self.registry.update_variant_status( - variant["id"], - "completed", - raw_path=str(raw_path), - generation_ms=duration, - ) - completed += 1 - else: - self.registry.update_variant_status(variant["id"], "failed") - failed += 1 - - elif job_status == "failed": - self.registry.update_variant_status(variant["id"], "failed") - failed += 1 - - elif job_status == "running": - self.registry.update_variant_status(variant["id"], "running") - still_pending += 1 - - else: - still_pending += 1 - - report = {"completed": completed, "failed": failed, "pending": still_pending} - print( - f"Poll: {completed} completed, {failed} failed, {still_pending} pending" - ) - return report - - def poll_and_wait(self, interval: int = 10, timeout: int = 0) -> dict: - """Poll continuously until all pending jobs are done. - - Prints progress every `interval` seconds. - If timeout > 0, stop after that many seconds. - Returns final poll report. - """ - start = time.monotonic() - report = {"completed": 0, "failed": 0, "pending": 0} - - while True: - report = self.poll_pending() - if report["pending"] == 0: - print("All jobs finished.") - return report - - if timeout > 0 and (time.monotonic() - start) >= timeout: - print(f"Timeout after {timeout}s with {report['pending']} pending.") - return report - - elapsed = int(time.monotonic() - start) - print(f" Waiting {interval}s... (elapsed {elapsed}s)") - time.sleep(interval) - - # -- internal HTTP --------------------------------------------------------- - - def _submit_job( + async def generate_batch( self, + sprite_ids: list[str], + variants_per: int = 4, + priority: str = "normal", + ) -> int: + """Generate variants for a list of sprites concurrently. + + Each request goes through model-boss → Redis pubsub. + Concurrency controlled by semaphore (MAX_CONCURRENT awaiting at once). + model-boss handles its own internal queue for GPU scheduling. + + Returns total variants generated. + """ + run_id = self.registry.start_run(variants_per=variants_per) + tasks: list[asyncio.Task] = [] + total_planned = 0 + + for sprite_id in sprite_ids: + sprite = self.registry.get_sprite(sprite_id) + if not sprite: + continue + + prompt = sprite["prompt"] + negative = sprite["negative_prompt"] or get_negative(sprite["category"]) + gen_w, gen_h = get_generation_size(sprite["category"]) + width = sprite.get("gen_width") or gen_w + height = sprite.get("gen_height") or gen_h + + for i in range(variants_per): + seed = random.randint(0, 2**32 - 1) + modifier = get_variant_modifier(i) + total_planned += 1 + + task = asyncio.create_task( + self._generate_and_report( + sprite_id=sprite_id, + prompt=prompt, + negative=negative, + width=width, + height=height, + seed=seed, + prompt_modifier=modifier, + priority=priority, + index=total_planned, + ) + ) + tasks.append(task) + + print(f"Queued {total_planned} generation requests ({len(sprite_ids)} sprites × {variants_per} variants)") + + # Update run with total + with self.registry.conn: + self.registry.conn.execute( + "UPDATE generation_runs SET total_jobs=? WHERE id=?", + (total_planned, run_id), + ) + + # Await all — results stream in as model-boss completes them + results = await asyncio.gather(*tasks, return_exceptions=True) + + completed = sum(1 for r in results if isinstance(r, int) and r is not None) + failed = total_planned - completed + self.registry.update_run(run_id, completed_delta=completed, failed_delta=failed, finished=True) + + print(f"\nCompleted: {completed}/{total_planned} ({failed} failed)") + return completed + + async def _generate_and_report( + self, + sprite_id: str, prompt: str, negative: str, width: int, height: int, seed: int, + prompt_modifier: str, priority: str, - ) -> str | None: - """Submit a job to model-boss async queue. Returns job_id or None.""" - body = { - "model": self.model, - "prompt": prompt, - "negativePrompt": negative, - "width": width, - "height": height, - "steps": self.defaults.get("steps", 25), - "guidanceScale": self.defaults.get("guidance_scale", 7.5), - "seed": seed, - "xPriority": priority, - "xClientId": "sprite-generator", - } - data = json.dumps(body).encode() - req = Request( - f"{self.api_base}/api/v1/diffusion/jobs", - data=data, - headers={"Content-Type": "application/json"}, - method="POST", + index: int, + ) -> int | None: + """Generate one variant and print progress.""" + variant_id = await self.generate_one( + sprite_id=sprite_id, + prompt=prompt, + negative=negative, + width=width, + height=height, + seed=seed, + prompt_modifier=prompt_modifier, + priority=priority, ) - try: - with urlopen(req, timeout=30) as resp: - result = json.loads(resp.read()) - return result.get("jobId") - except (URLError, OSError, json.JSONDecodeError, KeyError) as exc: - print(f"[error] Submit failed: {exc}") - return None + if variant_id is not None: + print(f"[{index}] Done {sprite_id} (seed={seed}) → variant {variant_id}") + return variant_id - def _check_job(self, job_id: str) -> dict | None: - """Check job status. Returns status dict or None on error.""" - req = Request(f"{self.api_base}/api/v1/diffusion/jobs/{job_id}") - try: - with urlopen(req, timeout=15) as resp: - return json.loads(resp.read()) - except (URLError, OSError, json.JSONDecodeError) as exc: - print(f"[error] Status check failed for {job_id}: {exc}") - return None - - def _download_result(self, job_id: str) -> tuple[bytes, int | None] | None: - """Download completed job result. Returns (image_bytes, duration_ms) or None.""" - req = Request(f"{self.api_base}/api/v1/diffusion/jobs/{job_id}/result") - try: - with urlopen(req, timeout=60) as resp: - result = json.loads(resp.read()) - images = result.get("images", []) - if not images: - print(f"[error] No images in result for {job_id}") - return None - duration = result.get("durationMs") - return base64.b64decode(images[0]), duration - except (URLError, OSError, json.JSONDecodeError) as exc: - print(f"[error] Download failed for {job_id}: {exc}") - return None + async def close(self) -> None: + """Clean up client resources.""" + await self._client.close() diff --git a/tools/sprite-generation/engine/prompts.py b/tools/sprite-generation/engine/prompts.py index 16806979..fae27906 100644 --- a/tools/sprite-generation/engine/prompts.py +++ b/tools/sprite-generation/engine/prompts.py @@ -28,46 +28,52 @@ STYLE_PREFIXES: dict[str, str] = { "masterpiece, best quality, game asset" ), "units": ( - "isometric game character sprite viewed from above at 45-degree angle, " - "NOT front-facing, NOT side view, seen from elevated camera looking down, " - "single character standing on flat grey background, " - "full body visible head to toe, small figure centered in frame, " - "painted digital art style, strategy game unit sprite like Age of Wonders or Warcraft III, " - "one character only, clean edges, no scenery, no ground texture, " - "masterpiece, best quality, game sprite on neutral background" + "isometric game character sprite on transparent background, " + "camera looking down from above at 60 degree angle, " + "you see the TOP OF THE HEAD and shoulders from above, " + "character appears short and foreshortened from the high camera, " + "like an Age of Empires II unit or Diablo II character sprite, " + "single small figure, full body visible, painted digital art, " + "clean edges ready to layer over terrain, " + "NO front-facing portrait, NO eye contact with camera, " + "transparent PNG background, masterpiece, best quality" ), "buildings": ( - "isometric building viewed from above at 45-degree angle, " + "isometric building on transparent background, " + "camera looking down from above at 45-degree angle, " "you can see the ROOF and TWO WALLS of the building, " "NOT front-facing, NOT a facade, camera is elevated looking down, " - "single small building on flat ground, painted digital art, " - "strategy game building like Age of Empires or Civilization V, " - "one structure only, clean edges, no other buildings, no characters, " - "masterpiece, best quality, game building sprite" + "single small building, painted digital art, " + "like Age of Empires II or Civilization V building, " + "clean edges ready to layer over terrain, " + "transparent PNG background, masterpiece, best quality" ), "resources": ( - "single natural resource object centered on transparent background, " - "painted fantasy game icon style, like a Civilization V map resource marker, " - "ONE distinct recognizable object, NOT a texture, NOT a pattern, " - "clean simple icon composition with clear silhouette, " - "small game sprite overlay for a hex tile, " - "masterpiece, best quality, game resource icon" + "top-down overhead view of a natural resource deposit on the ground, " + "terrain overlay sprite for a hex-grid strategy game, " + "the resource is visible as a natural feature embedded in the earth, " + "like Civilization V resource icons visible on hex tiles when revealed, " + "painted fantasy game art, transparent background around the deposit, " + "small concentrated feature, NOT a full terrain texture, " + "masterpiece, best quality, game map overlay sprite" ), "improvements": ( - "single tile improvement viewed from above, isometric perspective, " - "small cultivated area or construction on natural ground, " - "painted fantasy game art, strategy game tile overlay, " - "like a Civilization V tile improvement, simple clean composition, " - "ONE feature centered, recognizable at small scale, " - "masterpiece, best quality, game tile improvement" + "single tile improvement on transparent background, " + "isometric view from above, small area of cultivated or developed land, " + "ONE simple feature like a small farm field or mine entrance, " + "like a Civilization V tile improvement overlaid on terrain, " + "painted fantasy game art, clean simple composition, " + "NOT a complex scene with multiple buildings, just ONE improvement, " + "transparent PNG background, masterpiece, best quality" ), "spells": ( - "magical spell effect, dramatic magical energy, " - "painterly fantasy illustration suitable for a spellbook icon, " - "vivid glowing magical energy as the central focus, " - "dark background, single spell effect, clean composition, " - "like a Magic: The Gathering card art or spell icon, " - "masterpiece, best quality, game icon" + "magical spell effect icon on dark background, " + "abstract magical energy as the subject NOT a character, " + "glowing energy, particles, runes, or elemental force, " + "like a spell icon from Diablo or World of Warcraft ability bar, " + "circular or square icon composition, centered, " + "NO person, NO character, NO angel, just pure magical energy, " + "vivid colors on black background, masterpiece, best quality" ), "edges": ( "seamless terrain edge transition, top-down bird's-eye view, " @@ -102,12 +108,14 @@ NEGATIVES: dict[str, str] = { ), "units": ( "text, watermark, blurry, low quality, anime, photo, " - "background, scenery, terrain, landscape, sky, horizon, ground, floor, " + "white background, solid background, background scenery, terrain, landscape, sky, horizon, " + "front-facing, looking at camera, portrait, headshot, " "multiple characters, crowd, group, army, " "hexagon, geometric, border, frame, UI, card frame" ), "buildings": ( "text, watermark, blurry, low quality, anime, photo, " + "white background, solid background, " "front elevation, straight-on view, architectural drawing, blueprint, " "terrain, landscape, sky, horizon, person, character, " "multiple buildings, city, town, village, street, " @@ -115,19 +123,23 @@ NEGATIVES: dict[str, str] = { ), "resources": ( "text, watermark, blurry, low quality, anime, photo, " - "perspective view, horizon, sky, person, character, building, " + "dark background, black background, solid background, border, frame, " + "texture, pattern, seamless, tileable, " + "person, character, building, landscape, horizon, sky, " "multiple objects, collage, grid, collection, " - "hexagon, geometric, abstract, border, frame, UI" + "hexagon, geometric, abstract, UI" ), "improvements": ( "text, watermark, blurry, low quality, anime, photo, " + "white background, " "perspective view, horizon, sky, person, character, " - "multiple structures, city, village, " + "multiple buildings, city, village, complex scene, " "hexagon, geometric, abstract, border, frame, UI" ), "spells": ( "text, watermark, blurry, low quality, anime, photo, " - "person, character, terrain, landscape, building, " + "person, character, human, angel, figure, face, body, wings, " + "terrain, landscape, building, scenery, " "hexagon, geometric, border, frame, UI, card frame" ), "edges": ( diff --git a/tools/sprite-generation/engine/registry.py b/tools/sprite-generation/engine/registry.py index fdadac8d..5f407fd1 100644 --- a/tools/sprite-generation/engine/registry.py +++ b/tools/sprite-generation/engine/registry.py @@ -426,13 +426,13 @@ class SpriteRegistry: ).fetchall() return [dict(r) for r in rows] - def get_recent_variants(self, limit: int = 30, since: str | None = None) -> list[dict]: + def get_recent_variants(self, limit: int = 30, since_id: int | None = None) -> list[dict]: """Recently completed variants with sprite metadata for the stream ticker.""" clauses = ["v.job_status = 'completed'", "v.raw_path IS NOT NULL"] params: list[str | int] = [] - if since: - clauses.append("v.created_at > ?") - params.append(since) + if since_id is not None: + clauses.append("v.id > ?") + params.append(since_id) where = " WHERE " + " AND ".join(clauses) params.append(limit) rows = self.conn.execute(