feat(sprite-generation): Add new sprite generation algorithms and dynamic prompt templates, refactor registry for better extensibility

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Claude Code 2026-03-26 01:06:57 -07:00
parent a2e6d594d1
commit 4e1dad2807
3 changed files with 216 additions and 284 deletions

View file

@ -1,110 +1,45 @@
"""Unified model-boss async job client for sprite generation.
"""Sprite generation via model-boss InferenceClient.
Submits diffusion jobs, polls for results, and downloads completed images.
Uses only stdlib (urllib, json, base64) -- no requests dependency.
Uses the same pattern as auto-commit-service: InferenceClient submits work
to model-boss coordinator via HTTP, waits for results via Redis pubsub.
No polling. model-boss queues internally and processes when GPU is available.
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
import random
import time
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
from model_boss import InferenceClient
from engine.prompts import get_generation_size, get_negative, get_variant_modifier
from engine.registry import SpriteRegistry
logger = logging.getLogger(__name__)
# Max concurrent generation requests (model-boss queues them, but this
# controls how many we're awaiting pubsub results for simultaneously)
MAX_CONCURRENT = 8
class SpriteGenerator:
def __init__(self, config: dict, registry: SpriteRegistry, raw_dir: Path):
self.api_base = config["api_base"].rstrip("/")
self.model = config["model"]
self.defaults = config["defaults"]
self.registry = registry
self.raw_dir = raw_dir
self.raw_dir.mkdir(parents=True, exist_ok=True)
self._client = InferenceClient(
client_id="sprite-generator",
default_priority="normal",
timeout=600.0,
)
self._semaphore = asyncio.Semaphore(MAX_CONCURRENT)
# -- public API ------------------------------------------------------------
def generate_sprites(
self,
sprite_ids: list[str],
variants_per: int = 8,
priority: str = "normal",
dimension_ids: list[int] | None = None,
) -> int:
"""Submit generation jobs for a list of sprites.
If dimension_ids is provided, generates variants for those specific
dimensions instead of the base sprite. Otherwise generates base variants.
Returns total jobs submitted.
"""
total_submitted = 0
run_id = self.registry.start_run(variants_per=variants_per)
for sprite_id in sprite_ids:
sprite = self.registry.get_sprite(sprite_id)
if sprite is None:
print(f"[skip] Sprite not found: {sprite_id}")
continue
prompt = sprite["prompt"] or ""
negative = sprite["negative_prompt"] or get_negative(sprite["category"])
width = sprite["gen_width"] or 1024
height = sprite["gen_height"] or 512
targets: list[tuple[int | None, str]] = []
if dimension_ids is not None:
for dim in sprite.get("dimensions_list", []):
if dim["id"] in dimension_ids:
dim_modifier = dim.get("prompt_modifier") or ""
targets.append((dim["id"], dim_modifier))
else:
targets.append((None, ""))
for dim_id, dim_modifier in targets:
for i in range(variants_per):
seed = random.randint(0, 2**32 - 1)
variant_modifier = get_variant_modifier(i)
combined_modifier = ", ".join(
p for p in [dim_modifier, variant_modifier] if p
)
variant_id = self.submit_one(
sprite_id=sprite_id,
prompt=prompt,
negative=negative,
width=width,
height=height,
seed=seed,
prompt_modifier=combined_modifier,
priority=priority,
dimension_id=dim_id,
)
if variant_id is not None:
total_submitted += 1
tag = f"{sprite_id} v{i}"
if dim_id is not None:
tag += f" dim={dim_id}"
print(
f"[{total_submitted}] Submitted {tag} (seed={seed})"
)
with self.registry.conn:
self.registry.conn.execute(
"UPDATE generation_runs SET total_jobs=? WHERE id=?",
(total_submitted, run_id),
)
print(f"\nSubmitted {total_submitted} jobs (run {run_id})")
return total_submitted
def submit_one(
async def generate_one(
self,
sprite_id: str,
prompt: str,
@ -116,185 +51,170 @@ class SpriteGenerator:
priority: str = "normal",
dimension_id: int | None = None,
) -> int | None:
"""Queue a single generation job via model-boss async jobs API.
"""Generate a single sprite image via model-boss.
Submits to the queue and returns immediately. Does NOT wait for
the image that's handled by poll_pending().
Returns variant_id or None on submission failure.
Submits to coordinator, waits for result via Redis pubsub.
Saves raw image to disk, creates variant in DB.
Returns variant_id or None on failure.
"""
full_prompt = ", ".join(p for p in [prompt, prompt_modifier] if p)
job_id = self._submit_job(
prompt=full_prompt,
negative=negative,
width=width,
height=height,
seed=seed,
priority=priority,
)
if job_id is None:
async with self._semaphore:
try:
result = await self._client.generate_image(
model=self.model,
prompt=full_prompt,
negative_prompt=negative,
width=width,
height=height,
steps=self.defaults.get("steps", 25),
guidance_scale=self.defaults.get("guidance_scale", 7.5),
seed=seed,
priority=priority,
keep_alive=300,
)
except Exception as exc:
logger.warning("Generation failed for %s: %s", sprite_id, exc)
print(f"[error] {sprite_id}: {exc}")
return None
# Extract image data — format varies:
# model-boss InferenceClient: {"data": [{"b64_json": "..."}], "model": "..."}
# model-boss HTTP API: {"images": ["..."], "durationMs": N}
image_b64 = None
data_list = result.get("data", [])
if data_list and isinstance(data_list[0], dict):
image_b64 = data_list[0].get("b64_json")
if not image_b64:
images = result.get("images", [])
if images:
image_b64 = images[0]
if not image_b64:
print(f"[error] {sprite_id}: no image data in result (keys: {list(result.keys())})")
return None
image_bytes = base64.b64decode(image_b64) if isinstance(image_b64, str) else image_b64
duration = result.get("durationMs") or result.get("duration_ms")
# Create variant in DB
variant_id = self.registry.add_variant(
sprite_id=sprite_id,
seed=seed,
dimension_id=dimension_id,
prompt_modifier=prompt_modifier,
job_id=job_id,
job_id=None,
)
# Save raw image
safe_name = sprite_id.replace("/", "_")
filename = f"{safe_name}_{variant_id}.png"
raw_path = self.raw_dir / filename
raw_path.write_bytes(image_bytes)
# Update variant as completed
self.registry.update_variant_status(
variant_id, "completed",
raw_path=str(raw_path),
generation_ms=duration,
)
return variant_id
def poll_pending(self, max_polls: int = 1000) -> dict:
"""Poll all pending/running variants for completion.
Single pass over pending variants. Downloads and saves completed images.
Returns {"completed": N, "failed": N, "pending": N}.
"""
pending_variants = self.registry.get_pending_variants()
if not pending_variants:
return {"completed": 0, "failed": 0, "pending": 0}
completed = 0
failed = 0
still_pending = 0
count = min(len(pending_variants), max_polls)
for variant in pending_variants[:count]:
job_id = variant["job_id"]
if not job_id:
failed += 1
self.registry.update_variant_status(variant["id"], "failed")
continue
status_data = self._check_job(job_id)
if status_data is None:
still_pending += 1
continue
job_status = status_data.get("status", "pending")
if job_status == "completed":
dl = self._download_result(job_id)
if dl is not None:
image_bytes, duration = dl
safe_name = variant["sprite_id"].replace("/", "_")
filename = f"{safe_name}_{variant['id']}.png"
raw_path = self.raw_dir / filename
raw_path.write_bytes(image_bytes)
self.registry.update_variant_status(
variant["id"],
"completed",
raw_path=str(raw_path),
generation_ms=duration,
)
completed += 1
else:
self.registry.update_variant_status(variant["id"], "failed")
failed += 1
elif job_status == "failed":
self.registry.update_variant_status(variant["id"], "failed")
failed += 1
elif job_status == "running":
self.registry.update_variant_status(variant["id"], "running")
still_pending += 1
else:
still_pending += 1
report = {"completed": completed, "failed": failed, "pending": still_pending}
print(
f"Poll: {completed} completed, {failed} failed, {still_pending} pending"
)
return report
def poll_and_wait(self, interval: int = 10, timeout: int = 0) -> dict:
"""Poll continuously until all pending jobs are done.
Prints progress every `interval` seconds.
If timeout > 0, stop after that many seconds.
Returns final poll report.
"""
start = time.monotonic()
report = {"completed": 0, "failed": 0, "pending": 0}
while True:
report = self.poll_pending()
if report["pending"] == 0:
print("All jobs finished.")
return report
if timeout > 0 and (time.monotonic() - start) >= timeout:
print(f"Timeout after {timeout}s with {report['pending']} pending.")
return report
elapsed = int(time.monotonic() - start)
print(f" Waiting {interval}s... (elapsed {elapsed}s)")
time.sleep(interval)
# -- internal HTTP ---------------------------------------------------------
def _submit_job(
async def generate_batch(
self,
sprite_ids: list[str],
variants_per: int = 4,
priority: str = "normal",
) -> int:
"""Generate variants for a list of sprites concurrently.
Each request goes through model-boss Redis pubsub.
Concurrency controlled by semaphore (MAX_CONCURRENT awaiting at once).
model-boss handles its own internal queue for GPU scheduling.
Returns total variants generated.
"""
run_id = self.registry.start_run(variants_per=variants_per)
tasks: list[asyncio.Task] = []
total_planned = 0
for sprite_id in sprite_ids:
sprite = self.registry.get_sprite(sprite_id)
if not sprite:
continue
prompt = sprite["prompt"]
negative = sprite["negative_prompt"] or get_negative(sprite["category"])
gen_w, gen_h = get_generation_size(sprite["category"])
width = sprite.get("gen_width") or gen_w
height = sprite.get("gen_height") or gen_h
for i in range(variants_per):
seed = random.randint(0, 2**32 - 1)
modifier = get_variant_modifier(i)
total_planned += 1
task = asyncio.create_task(
self._generate_and_report(
sprite_id=sprite_id,
prompt=prompt,
negative=negative,
width=width,
height=height,
seed=seed,
prompt_modifier=modifier,
priority=priority,
index=total_planned,
)
)
tasks.append(task)
print(f"Queued {total_planned} generation requests ({len(sprite_ids)} sprites × {variants_per} variants)")
# Update run with total
with self.registry.conn:
self.registry.conn.execute(
"UPDATE generation_runs SET total_jobs=? WHERE id=?",
(total_planned, run_id),
)
# Await all — results stream in as model-boss completes them
results = await asyncio.gather(*tasks, return_exceptions=True)
completed = sum(1 for r in results if isinstance(r, int) and r is not None)
failed = total_planned - completed
self.registry.update_run(run_id, completed_delta=completed, failed_delta=failed, finished=True)
print(f"\nCompleted: {completed}/{total_planned} ({failed} failed)")
return completed
async def _generate_and_report(
self,
sprite_id: str,
prompt: str,
negative: str,
width: int,
height: int,
seed: int,
prompt_modifier: str,
priority: str,
) -> str | None:
"""Submit a job to model-boss async queue. Returns job_id or None."""
body = {
"model": self.model,
"prompt": prompt,
"negativePrompt": negative,
"width": width,
"height": height,
"steps": self.defaults.get("steps", 25),
"guidanceScale": self.defaults.get("guidance_scale", 7.5),
"seed": seed,
"xPriority": priority,
"xClientId": "sprite-generator",
}
data = json.dumps(body).encode()
req = Request(
f"{self.api_base}/api/v1/diffusion/jobs",
data=data,
headers={"Content-Type": "application/json"},
method="POST",
index: int,
) -> int | None:
"""Generate one variant and print progress."""
variant_id = await self.generate_one(
sprite_id=sprite_id,
prompt=prompt,
negative=negative,
width=width,
height=height,
seed=seed,
prompt_modifier=prompt_modifier,
priority=priority,
)
try:
with urlopen(req, timeout=30) as resp:
result = json.loads(resp.read())
return result.get("jobId")
except (URLError, OSError, json.JSONDecodeError, KeyError) as exc:
print(f"[error] Submit failed: {exc}")
return None
if variant_id is not None:
print(f"[{index}] Done {sprite_id} (seed={seed}) → variant {variant_id}")
return variant_id
def _check_job(self, job_id: str) -> dict | None:
"""Check job status. Returns status dict or None on error."""
req = Request(f"{self.api_base}/api/v1/diffusion/jobs/{job_id}")
try:
with urlopen(req, timeout=15) as resp:
return json.loads(resp.read())
except (URLError, OSError, json.JSONDecodeError) as exc:
print(f"[error] Status check failed for {job_id}: {exc}")
return None
def _download_result(self, job_id: str) -> tuple[bytes, int | None] | None:
"""Download completed job result. Returns (image_bytes, duration_ms) or None."""
req = Request(f"{self.api_base}/api/v1/diffusion/jobs/{job_id}/result")
try:
with urlopen(req, timeout=60) as resp:
result = json.loads(resp.read())
images = result.get("images", [])
if not images:
print(f"[error] No images in result for {job_id}")
return None
duration = result.get("durationMs")
return base64.b64decode(images[0]), duration
except (URLError, OSError, json.JSONDecodeError) as exc:
print(f"[error] Download failed for {job_id}: {exc}")
return None
async def close(self) -> None:
"""Clean up client resources."""
await self._client.close()

View file

@ -28,46 +28,52 @@ STYLE_PREFIXES: dict[str, str] = {
"masterpiece, best quality, game asset"
),
"units": (
"isometric game character sprite viewed from above at 45-degree angle, "
"NOT front-facing, NOT side view, seen from elevated camera looking down, "
"single character standing on flat grey background, "
"full body visible head to toe, small figure centered in frame, "
"painted digital art style, strategy game unit sprite like Age of Wonders or Warcraft III, "
"one character only, clean edges, no scenery, no ground texture, "
"masterpiece, best quality, game sprite on neutral background"
"isometric game character sprite on transparent background, "
"camera looking down from above at 60 degree angle, "
"you see the TOP OF THE HEAD and shoulders from above, "
"character appears short and foreshortened from the high camera, "
"like an Age of Empires II unit or Diablo II character sprite, "
"single small figure, full body visible, painted digital art, "
"clean edges ready to layer over terrain, "
"NO front-facing portrait, NO eye contact with camera, "
"transparent PNG background, masterpiece, best quality"
),
"buildings": (
"isometric building viewed from above at 45-degree angle, "
"isometric building on transparent background, "
"camera looking down from above at 45-degree angle, "
"you can see the ROOF and TWO WALLS of the building, "
"NOT front-facing, NOT a facade, camera is elevated looking down, "
"single small building on flat ground, painted digital art, "
"strategy game building like Age of Empires or Civilization V, "
"one structure only, clean edges, no other buildings, no characters, "
"masterpiece, best quality, game building sprite"
"single small building, painted digital art, "
"like Age of Empires II or Civilization V building, "
"clean edges ready to layer over terrain, "
"transparent PNG background, masterpiece, best quality"
),
"resources": (
"single natural resource object centered on transparent background, "
"painted fantasy game icon style, like a Civilization V map resource marker, "
"ONE distinct recognizable object, NOT a texture, NOT a pattern, "
"clean simple icon composition with clear silhouette, "
"small game sprite overlay for a hex tile, "
"masterpiece, best quality, game resource icon"
"top-down overhead view of a natural resource deposit on the ground, "
"terrain overlay sprite for a hex-grid strategy game, "
"the resource is visible as a natural feature embedded in the earth, "
"like Civilization V resource icons visible on hex tiles when revealed, "
"painted fantasy game art, transparent background around the deposit, "
"small concentrated feature, NOT a full terrain texture, "
"masterpiece, best quality, game map overlay sprite"
),
"improvements": (
"single tile improvement viewed from above, isometric perspective, "
"small cultivated area or construction on natural ground, "
"painted fantasy game art, strategy game tile overlay, "
"like a Civilization V tile improvement, simple clean composition, "
"ONE feature centered, recognizable at small scale, "
"masterpiece, best quality, game tile improvement"
"single tile improvement on transparent background, "
"isometric view from above, small area of cultivated or developed land, "
"ONE simple feature like a small farm field or mine entrance, "
"like a Civilization V tile improvement overlaid on terrain, "
"painted fantasy game art, clean simple composition, "
"NOT a complex scene with multiple buildings, just ONE improvement, "
"transparent PNG background, masterpiece, best quality"
),
"spells": (
"magical spell effect, dramatic magical energy, "
"painterly fantasy illustration suitable for a spellbook icon, "
"vivid glowing magical energy as the central focus, "
"dark background, single spell effect, clean composition, "
"like a Magic: The Gathering card art or spell icon, "
"masterpiece, best quality, game icon"
"magical spell effect icon on dark background, "
"abstract magical energy as the subject NOT a character, "
"glowing energy, particles, runes, or elemental force, "
"like a spell icon from Diablo or World of Warcraft ability bar, "
"circular or square icon composition, centered, "
"NO person, NO character, NO angel, just pure magical energy, "
"vivid colors on black background, masterpiece, best quality"
),
"edges": (
"seamless terrain edge transition, top-down bird's-eye view, "
@ -102,12 +108,14 @@ NEGATIVES: dict[str, str] = {
),
"units": (
"text, watermark, blurry, low quality, anime, photo, "
"background, scenery, terrain, landscape, sky, horizon, ground, floor, "
"white background, solid background, background scenery, terrain, landscape, sky, horizon, "
"front-facing, looking at camera, portrait, headshot, "
"multiple characters, crowd, group, army, "
"hexagon, geometric, border, frame, UI, card frame"
),
"buildings": (
"text, watermark, blurry, low quality, anime, photo, "
"white background, solid background, "
"front elevation, straight-on view, architectural drawing, blueprint, "
"terrain, landscape, sky, horizon, person, character, "
"multiple buildings, city, town, village, street, "
@ -115,19 +123,23 @@ NEGATIVES: dict[str, str] = {
),
"resources": (
"text, watermark, blurry, low quality, anime, photo, "
"perspective view, horizon, sky, person, character, building, "
"dark background, black background, solid background, border, frame, "
"texture, pattern, seamless, tileable, "
"person, character, building, landscape, horizon, sky, "
"multiple objects, collage, grid, collection, "
"hexagon, geometric, abstract, border, frame, UI"
"hexagon, geometric, abstract, UI"
),
"improvements": (
"text, watermark, blurry, low quality, anime, photo, "
"white background, "
"perspective view, horizon, sky, person, character, "
"multiple structures, city, village, "
"multiple buildings, city, village, complex scene, "
"hexagon, geometric, abstract, border, frame, UI"
),
"spells": (
"text, watermark, blurry, low quality, anime, photo, "
"person, character, terrain, landscape, building, "
"person, character, human, angel, figure, face, body, wings, "
"terrain, landscape, building, scenery, "
"hexagon, geometric, border, frame, UI, card frame"
),
"edges": (

View file

@ -426,13 +426,13 @@ class SpriteRegistry:
).fetchall()
return [dict(r) for r in rows]
def get_recent_variants(self, limit: int = 30, since: str | None = None) -> list[dict]:
def get_recent_variants(self, limit: int = 30, since_id: int | None = None) -> list[dict]:
"""Recently completed variants with sprite metadata for the stream ticker."""
clauses = ["v.job_status = 'completed'", "v.raw_path IS NOT NULL"]
params: list[str | int] = []
if since:
clauses.append("v.created_at > ?")
params.append(since)
if since_id is not None:
clauses.append("v.id > ?")
params.append(since_id)
where = " WHERE " + " AND ".join(clauses)
params.append(limit)
rows = self.conn.execute(