2026-03-25 22:48:52 -07:00
|
|
|
"""Image post-processing for generated sprites — raw model output to game-ready assets."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-03-28 21:31:40 -07:00
|
|
|
import logging
|
2026-03-25 22:48:52 -07:00
|
|
|
from pathlib import Path
|
2026-03-28 21:31:40 -07:00
|
|
|
from typing import Optional
|
2026-03-25 22:48:52 -07:00
|
|
|
|
2026-03-28 21:31:40 -07:00
|
|
|
from PIL import Image, ImageDraw
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Maps a sprite category to the name of the SpriteProcessor method that
# handles it. SpriteProcessor.process() looks the method up with getattr(),
# so every value here must name a method defined on that class.
_CATEGORY_PROCESSORS = dict(
    terrain="_process_terrain",
    biome_grid="_process_biome_grid",
    edges="_process_edges",
    units="_process_units",
    buildings="_process_buildings",
    resources="_process_resources",
    improvements="_process_improvements",
    spells="_process_spells",
    ui="_process_ui",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpriteProcessor:
    """Convert raw generated images into game-ready sprites.

    ``process`` dispatches to a per-category ``_process_*`` method. The
    hex-tile processors (``_process_terrain``, ``_process_biome_grid``,
    ``_process_edges``) crop to the hex aspect ratio, resize to 384x332 and
    cut the tile out with a flat-top hex mask. The remaining processors
    remove the background with rembg (U2Net), center-crop to a square and
    resize to a fixed per-category sprite size.
    """

    # Output size (width, height) of hex tiles and of the generated hex mask.
    _HEX_TILE_SIZE: tuple[int, int] = (384, 332)

    def __init__(self, hex_mask_path: Path | None = None) -> None:
        """Load the hex mask from *hex_mask_path*, or generate a 384x332 one.

        Args:
            hex_mask_path: Optional grayscale mask image (white = inside the
                hex). If its size differs from the image being masked it is
                resized when applied. When the argument is omitted or the file
                does not exist, a flat-top hex mask is generated instead; a
                missing file is logged.
        """
        self.hex_mask: Image.Image | None = None
        if hex_mask_path and hex_mask_path.exists():
            # The context manager closes the file handle; convert() has
            # already loaded the pixels into a new, independent image.
            with Image.open(hex_mask_path) as src:
                self.hex_mask = src.convert("L")
        else:
            if hex_mask_path:
                logger.warning(
                    "Hex mask %s not found; generating a default hex mask",
                    hex_mask_path,
                )
            self.hex_mask = self.generate_hex_mask(*self._HEX_TILE_SIZE)

        # Lazy-loaded rembg session for background removal
        self._rembg_session: Optional[object] = None

    def process(self, raw_path: Path, category: str, output_path: Path) -> bool:
        """Open *raw_path*, apply the *category* processor, save to *output_path*.

        Args:
            raw_path: Raw generated image to read.
            category: Key into ``_CATEGORY_PROCESSORS`` selecting the recipe.
            output_path: Destination file; parent directories are created.

        Returns:
            True when the sprite was written. False when the category is
            unknown or the raw image cannot be opened; both cases are logged.

        Errors raised by the processor itself or while saving propagate to
        the caller.
        """
        method_name = _CATEGORY_PROCESSORS.get(category)
        if method_name is None:
            logger.warning(
                "Unknown category '%s', skipping %s", category, raw_path.name
            )
            return False

        try:
            # Close the source file as soon as the pixels are converted.
            with Image.open(raw_path) as src:
                img = src.convert("RGBA")
        except Exception as exc:  # any unreadable input is skipped, not fatal
            logger.warning("Failed to open %s: %s", raw_path, exc)
            return False

        result: Image.Image = getattr(self, method_name)(img)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        result.save(output_path)
        return True

    # -- category processors ---------------------------------------------------

    def _process_terrain(self, img: Image.Image) -> Image.Image:
        """Crop to the hex aspect ratio, resize to 384x332, apply hex mask.

        Cropping to the target aspect ratio (rather than to a square) keeps
        the resize uniform, so the artwork is not squashed vertically.
        """
        width, height = self._HEX_TILE_SIZE
        tile = self._center_crop_aspect(img, width, height)
        return self._apply_hex_mask(tile.resize((width, height), Image.LANCZOS))

    def _process_biome_grid(self, img: Image.Image) -> Image.Image:
        """Same as terrain — 384x332 hex-masked."""
        return self._process_terrain(img)

    def _process_edges(self, img: Image.Image) -> Image.Image:
        """384x332 with hex mask."""
        return self._process_terrain(img)

    def _process_units(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 256x256."""
        return self._cutout(img, 256)

    def _process_buildings(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 128x128."""
        return self._cutout(img, 128)

    def _process_resources(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 64x64."""
        return self._cutout(img, 64)

    def _process_improvements(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 64x64."""
        return self._cutout(img, 64)

    def _process_spells(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 128x128."""
        return self._cutout(img, 128)

    def _process_ui(self, img: Image.Image) -> Image.Image:
        """Remove background, center crop, resize to 64x64."""
        return self._cutout(img, 64)

    def _cutout(self, img: Image.Image, size: int) -> Image.Image:
        """Remove the background, center-crop to a square, resize to size x size."""
        img = self._remove_background(img)
        return self._center_crop_square(img).resize((size, size), Image.LANCZOS)

    # -- hex mask --------------------------------------------------------------

    def _apply_hex_mask(self, img: Image.Image) -> Image.Image:
        """Apply the hex mask as the alpha channel of the image.

        The existing alpha is kept where the mask is white (255), forced to 0
        where it is black, and interpolated in between. Returns *img*
        unchanged when no mask is set.
        """
        mask = self.hex_mask
        if mask is None:
            return img

        # A mask loaded from disk may not match the image size.
        if mask.size != img.size:
            mask = mask.resize(img.size, Image.LANCZOS)

        if img.mode != "RGBA":
            img = img.convert("RGBA")

        r, g, b, a = img.split()
        # Combine existing alpha with hex mask
        combined_alpha = Image.composite(a, Image.new("L", img.size, 0), mask)
        return Image.merge("RGBA", (r, g, b, combined_alpha))

    # -- background removal (rembg / U2Net) ------------------------------------

    def _get_rembg_session(self) -> object:
        """Lazy-load rembg session on first use.

        Uses U2Net with CPU provider to avoid CUDA dtype conflicts — the
        diffusion pipeline uses fp16 but rembg/onnxruntime expects fp32.
        CPU is fast enough for the small images we process (~1024x1024).
        """
        if self._rembg_session is not None:
            return self._rembg_session

        from rembg import new_session

        self._rembg_session = new_session(
            "u2net",
            providers=["CPUExecutionProvider"],
        )
        logger.info("rembg session initialized (u2net, CPU)")
        return self._rembg_session

    def _remove_background(self, img: Image.Image) -> Image.Image:
        """Remove background using U2Net neural segmentation via rembg.

        Unlike chroma keying, this works regardless of background or subject
        color — it segments the salient foreground object from everything else.
        Alpha matting is enabled for cleaner edges on small game sprites.

        The input is converted to RGB first, so any alpha it carries is
        discarded before segmentation.
        """
        from rembg import remove

        if img.mode != "RGB":
            img = img.convert("RGB")

        return remove(
            img,
            session=self._get_rembg_session(),
            alpha_matting=True,
            alpha_matting_foreground_threshold=240,
            alpha_matting_background_threshold=10,
        )

    # -- utilities -------------------------------------------------------------

    @staticmethod
    def _center_crop_square(img: Image.Image) -> Image.Image:
        """Crop the center of the image to a square."""
        w, h = img.size
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        return img.crop((left, top, left + side, top + side))

    @staticmethod
    def _center_crop_aspect(img: Image.Image, width: int, height: int) -> Image.Image:
        """Crop the largest centered region with aspect ratio width:height.

        Only one dimension is trimmed: the width when the image is wider than
        the target ratio, otherwise the height.
        """
        w, h = img.size
        if w * height > h * width:
            # Too wide for the target ratio: keep full height, trim the sides.
            crop_w, crop_h = max(1, round(h * width / height)), h
        else:
            # Too tall (or exact): keep full width, trim top and bottom.
            crop_w, crop_h = w, max(1, round(w * height / width))
        left = (w - crop_w) // 2
        top = (h - crop_h) // 2
        return img.crop((left, top, left + crop_w, top + crop_h))

    @staticmethod
    def generate_hex_mask(width: int = 384, height: int = 332) -> Image.Image:
        """Generate a flat-top hexagon mask programmatically.

        Flat-top hex vertices for a hex of width W and height H,
        centered at (W/2, H/2):
            (W/4, 0), (3*W/4, 0), (W, H/2), (3*W/4, H), (W/4, H), (0, H/2)

        Returns a grayscale Image where white=inside, black=outside.
        """
        mask = Image.new("L", (width, height), 0)
        draw = ImageDraw.Draw(mask)
        w4 = width / 4
        h2 = height / 2
        vertices = [
            (w4, 0),
            (3 * w4, 0),
            (width, h2),
            (3 * w4, height),
            (w4, height),
            (0, h2),
        ]
        draw.polygon(vertices, fill=255)
        return mask
|