"""Image post-processing for generated sprites — raw model output to game-ready assets.""" from __future__ import annotations from pathlib import Path import numpy as np from PIL import Image, ImageDraw, ImageFilter _CATEGORY_PROCESSORS = { "terrain": "_process_terrain", "biome_grid": "_process_biome_grid", "edges": "_process_edges", "units": "_process_units", "buildings": "_process_buildings", "resources": "_process_resources", "improvements": "_process_improvements", "spells": "_process_spells", "ui": "_process_ui", } class SpriteProcessor: def __init__(self, hex_mask_path: Path | None = None) -> None: """Load hex mask if provided. Mask is a 384x332 flat-top hex polygon.""" self.hex_mask: Image.Image | None = None if hex_mask_path and hex_mask_path.exists(): self.hex_mask = Image.open(hex_mask_path).convert("L") else: self.hex_mask = self.generate_hex_mask() def process(self, raw_path: Path, category: str, output_path: Path) -> bool: """Open raw image, apply category processing, save to output_path.""" method_name = _CATEGORY_PROCESSORS.get(category) if method_name is None: print(f"Unknown category '{category}', skipping {raw_path.name}") return False try: img = Image.open(raw_path).convert("RGBA") except Exception as exc: print(f"Failed to open {raw_path}: {exc}") return False method = getattr(self, method_name) result: Image.Image = method(img) output_path.parent.mkdir(parents=True, exist_ok=True) result.save(output_path) return True # -- category processors --------------------------------------------------- def _process_terrain(self, img: Image.Image) -> Image.Image: """Center crop to square, resize to 384x332, apply hex mask.""" return self._apply_hex_mask(self._center_crop_square(img).resize((384, 332), Image.LANCZOS)) def _process_biome_grid(self, img: Image.Image) -> Image.Image: """Same as terrain — 384x332 hex-masked.""" return self._process_terrain(img) def _process_edges(self, img: Image.Image) -> Image.Image: """384x332 with hex mask.""" return self._process_terrain(img) def _process_units(self, img: Image.Image) -> Image.Image: """Remove green chroma key, center crop, resize to 256x256.""" img = self._remove_chroma_key(img) return self._center_crop_square(img).resize((256, 256), Image.LANCZOS) def _process_buildings(self, img: Image.Image) -> Image.Image: """Remove green chroma key, center crop, resize to 128x128.""" img = self._remove_chroma_key(img) return self._center_crop_square(img).resize((128, 128), Image.LANCZOS) def _process_resources(self, img: Image.Image) -> Image.Image: """Remove green chroma key, center crop, resize to 64x64.""" img = self._remove_chroma_key(img) return self._center_crop_square(img).resize((64, 64), Image.LANCZOS) def _process_improvements(self, img: Image.Image) -> Image.Image: """Remove green chroma key, center crop, resize to 64x64.""" img = self._remove_chroma_key(img) return self._center_crop_square(img).resize((64, 64), Image.LANCZOS) def _process_spells(self, img: Image.Image) -> Image.Image: """Center crop to square, resize to 128x128. No chroma key — spells use black bg.""" return self._center_crop_square(img).resize((128, 128), Image.LANCZOS) def _process_ui(self, img: Image.Image) -> Image.Image: """Remove green chroma key, center crop, resize to 64x64.""" img = self._remove_chroma_key(img) return self._center_crop_square(img).resize((64, 64), Image.LANCZOS) # -- hex mask -------------------------------------------------------------- def _apply_hex_mask(self, img: Image.Image) -> Image.Image: """Apply the hex mask as the alpha channel of the image.""" mask = self.hex_mask if mask is None: return img if mask.size != img.size: mask = mask.resize(img.size, Image.LANCZOS) if img.mode != "RGBA": img = img.convert("RGBA") r, g, b, a = img.split() # Combine existing alpha with hex mask combined_alpha = Image.composite(a, Image.new("L", img.size, 0), mask) return Image.merge("RGBA", (r, g, b, combined_alpha)) # -- chroma key removal ---------------------------------------------------- @staticmethod def _remove_chroma_key( img: Image.Image, green_dominance: int = 30, min_green: int = 80, edge_feather: int = 3, ) -> Image.Image: """Remove green chroma key background, replacing with transparency. Detects pixels where green channel is bright AND dominates R and B. The two thresholds work together: - green_dominance: G must exceed max(R,B) by this amount (filters out dark greens, olive, etc. that appear in clothing/foliage subjects) - min_green: G channel must be at least this bright (filters out dark shadows that happen to be greenish) Edge pixels are feathered with gaussian blur for anti-aliased edges. """ if img.mode != "RGBA": img = img.convert("RGBA") arr = np.array(img, dtype=np.float32) r, g, b = arr[:, :, 0], arr[:, :, 1], arr[:, :, 2] # Chroma key detection: bright green that dominates other channels green_score = g - np.maximum(r, b) is_chroma = (green_score > green_dominance) & (g > min_green) # Create alpha mask: 0 for chroma pixels, 255 for subject alpha = np.where(is_chroma, 0, 255).astype(np.uint8) alpha_img = Image.fromarray(alpha, mode="L") # Feather edges for smooth anti-aliasing if edge_feather > 0: alpha_img = alpha_img.filter(ImageFilter.GaussianBlur(radius=edge_feather)) # Boost contrast to keep hard interior while softening edges alpha_arr = np.array(alpha_img, dtype=np.float32) alpha_arr = np.clip(alpha_arr * 1.5, 0, 255).astype(np.uint8) alpha_img = Image.fromarray(alpha_arr, mode="L") # Combine with existing alpha channel result = img.copy() orig_alpha = np.array(img.split()[3]) new_alpha = np.minimum(orig_alpha, np.array(alpha_img)) result.putalpha(Image.fromarray(new_alpha, mode="L")) return result # -- utilities ------------------------------------------------------------- @staticmethod def _center_crop_square(img: Image.Image) -> Image.Image: """Crop the center of the image to a square.""" w, h = img.size side = min(w, h) left = (w - side) // 2 top = (h - side) // 2 return img.crop((left, top, left + side, top + side)) @staticmethod def generate_hex_mask(width: int = 384, height: int = 332) -> Image.Image: """Generate a flat-top hexagon mask programmatically. Flat-top hex vertices for a hex of width W and height H, centered at (W/2, H/2): (W/4, 0), (3*W/4, 0), (W, H/2), (3*W/4, H), (W/4, H), (0, H/2) Returns a grayscale Image where white=inside, black=outside. """ mask = Image.new("L", (width, height), 0) draw = ImageDraw.Draw(mask) w4 = width / 4 h2 = height / 2 vertices = [ (w4, 0), (3 * w4, 0), (width, h2), (3 * w4, height), (w4, height), (0, h2), ] draw.polygon(vertices, fill=255) return mask