from io import BytesIO from typing import Any, Callable, cast from warnings import warn, catch_warnings, filterwarnings import numpy as np from torch import Tensor from einops import rearrange import PIL.Image as image import PIL.ImageCms as image_cms from PIL.Image import Image, Resampling from PIL.ImageCms import ( Direction, Intent, ImageCmsProfile, PyCMSError, createProfile, getDefaultIntent, isIntentSupported, profileToProfile ) from PIL.ImageOps import exif_transpose try: import pillow_jxl except ImportError: pass image.MAX_IMAGE_PIXELS = None _SRGB = createProfile(colorSpace='sRGB') _INTENT_FLAGS = { Intent.PERCEPTUAL: image_cms.FLAGS["HIGHRESPRECALC"], Intent.RELATIVE_COLORIMETRIC: ( image_cms.FLAGS["HIGHRESPRECALC"] | image_cms.FLAGS["BLACKPOINTCOMPENSATION"] ), Intent.ABSOLUTE_COLORIMETRIC: image_cms.FLAGS["HIGHRESPRECALC"] } class CMSWarning(UserWarning): def __init__( self, message: str, *, path: str | None = None, cms_info: dict[str, Any] | None = None, cause: Exception | None = None, ): super().__init__(message) self.__cause__ = cause self.path = path self.cms_info = cms_info self.add_note(f"path: {path}") self.add_note(f"info: {cms_info}") def _coalesce_intent(intent: Intent | int) -> Intent: if isinstance(intent, Intent): return intent match intent: case 0: return Intent.PERCEPTUAL case 1: return Intent.RELATIVE_COLORIMETRIC case 2: return Intent.SATURATION case 3: return Intent.ABSOLUTE_COLORIMETRIC case _: raise ValueError("invalid intent") def _add_info(info: dict[str, Any], source: object, key: str) -> None: try: if (value := getattr(source, key, None)) is not None: info[key] = value except Exception: pass def open_srgb( path: str, *, resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None, crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None, expect: tuple[int, int] | None = None, ) -> Image: with open(path, "rb", buffering=(1024 * 1024)) as file: img: Image = image.open(file) try: out = process_srgb(img, resize=resize, crop=crop, expect=expect) except: img.close() raise if img is not out: img.close() return out def process_srgb( img: Image, *, resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None, crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None, expect: tuple[int, int] | None = None, ) -> Image: img.load() try: exif_transpose(img, in_place=True) except Exception: pass # corrupt EXIF metadata is fine size = (img.width, img.height) if expect is not None and size != expect: raise RuntimeError( f"Image is {size[0]}x{size[1]}, " f"but expected {expect[0]}x{expect[1]}." ) if (icc_raw := img.info.get("icc_profile")) is not None: cms_info: dict[str, Any] = { "native_mode": img.mode, "transparency": img.has_transparency_data, } try: profile = ImageCmsProfile(BytesIO(icc_raw)) _add_info(cms_info, profile.profile, "profile_description") _add_info(cms_info, profile.profile, "target") _add_info(cms_info, profile.profile, "xcolor_space") _add_info(cms_info, profile.profile, "connection_space") _add_info(cms_info, profile.profile, "colorimetric_intent") _add_info(cms_info, profile.profile, "rendering_intent") working_mode = img.mode if img.mode.startswith(("RGB", "BGR", "P")): working_mode = "RGBA" if img.has_transparency_data else "RGB" elif img.mode.startswith(("L", "I", "F")) or img.mode == "1": working_mode = "LA" if img.has_transparency_data else "L" if img.mode != working_mode: cms_info["working_mode"] = working_mode img = img.convert(working_mode) mode = "RGBA" if img.has_transparency_data else "RGB" intent = Intent.RELATIVE_COLORIMETRIC if isIntentSupported(profile, intent, Direction.INPUT) != 1: intent = _coalesce_intent(getDefaultIntent(profile)) cms_info["conversion_intent"] = intent if (flags := _INTENT_FLAGS.get(intent)) is None: raise RuntimeError("Unsupported intent") if img.mode == mode: profileToProfile( img, profile, _SRGB, renderingIntent=intent, inPlace=True, flags=flags ) else: img = cast(Image, profileToProfile( img, profile, _SRGB, renderingIntent=intent, outputMode=mode, flags=flags )) except Exception as ex: pass if img.has_transparency_data: if img.mode != "RGBa": try: img = img.convert("RGBa") except ValueError: img = img.convert("RGBA").convert("RGBa") elif img.mode != "RGB": img = img.convert("RGB") if crop is not None and not isinstance(crop, tuple): crop = crop(size) if crop is not None: left, top, right, bottom = crop size = (right - left, top - bottom) if resize is not None and not isinstance(resize, tuple): resize = resize(size) if resize is not None and size != resize: img = img.resize( resize, Resampling.LANCZOS, box=crop, reducing_gap=3.0 ) crop = None if crop is not None: img = img.crop(crop) return img def put_srgb(img: Image, tensor: Tensor) -> None: if img.mode not in ("RGB", "RGBA", "RGBa"): raise ValueError(f"Image has non-RGB mode {img.mode}.") np.copyto(tensor.numpy(), np.asarray(img)[:, :, :3], casting="no") def put_srgb_patch( img: Image, patch_data: Tensor, patch_coord: Tensor, patch_valid: Tensor, patch_size: int ) -> None: if img.mode not in ("RGB", "RGBA", "RGBa"): raise ValueError(f"Image has non-RGB mode {img.mode}.") patches = rearrange( np.asarray(img)[:, :, :3], "(h p1) (w p2) c -> h w (p1 p2 c)", p1=patch_size, p2=patch_size ) coords = np.stack(np.meshgrid( np.arange(patches.shape[0], dtype=np.int16), np.arange(patches.shape[1], dtype=np.int16), indexing="ij" ), axis=-1) coords = rearrange(coords, "h w c -> (h w) c") patches = rearrange(patches, "h w p -> (h w) p") n = patches.shape[0] np.copyto(patch_data[:n].numpy(), patches, casting="no") np.copyto(patch_coord[:n].numpy(), coords, casting="no") patch_valid[:n] = True def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor: """ Scatter valid patches from (seqlen, ...) to (H, W, ...), using coords and valid mask. Args: input: Tensor of shape (seqlen, ...), patch data. coords: Tensor of shape (seqlen, 2), spatial coordinates [y, x] for each patch. valid: Tensor of shape (seqlen,), boolean mask for valid patches. Returns: Tensor of shape (H, W, ...), with valid patches scattered to their spatial locations. """ valid_coords = coords[0, valid[0]] # (n_valid, 2) valid_patches = input[valid[0]] # (n_valid, ...) h = int(valid_coords[:, 0].max().item()) + 1 w = int(valid_coords[:, 1].max().item()) + 1 output_shape = (h, w) + input.shape[1:] output = input.new_zeros(output_shape) output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches return output