| import warnings |
| from typing import Optional |
|
|
| import torch |
| from jaxtyping import Float |
| from lxml import etree |
|
|
|
|
def _find_by_local_name(parent, *names):
    """
    Return the first descendant of ``parent`` whose tag, ignoring any
    ``{namespace}`` prefix, is one of ``names``; ``None`` if there is none.
    """
    for element in parent.iter():
        # iter() yields ``parent`` itself first; comments and processing
        # instructions have a non-str tag. Skip both.
        if element is parent or not isinstance(element.tag, str):
            continue
        if element.tag.rsplit("}", 1)[-1] in names:
            return element
    return None


def load_asc_cdl(cdl_path: str, device: torch.device = torch.device("cpu")) -> dict:
    """
    Loads ASC CDL parameters from an XML file.

    Elements are matched by local name, so namespaced files (for example
    ``<ColorDecisionList xmlns="urn:ASC:CDL:v1.01">``) load as well as the
    un-namespaced ``ASC_CDL`` documents written by ``save_asc_cdl``. The first
    ``SOPNode`` and the first ``SatNode`` (or ``SATNode``) in document order
    are used.

    Parameters:
        cdl_path (str): Path to the ASC CDL XML file
        device (torch.device): Device on which the returned tensors are created

    Returns:
        Dict:
            "slope", "offset", "power": 1-D tensors holding the numbers listed
            in the file; "saturation": 0-d tensor

    Raises:
        ValueError: If the file cannot be parsed, or a required element is
            missing, empty, or holds non-numeric text.
    """
    try:
        tree = etree.parse(cdl_path)
        root = tree.getroot()
    except Exception as e:
        raise ValueError(f"Error loading ASC CDL from {cdl_path}: {e}") from e

    def require(parent, *names):
        # Locate a mandatory element or fail with a descriptive ValueError.
        node = _find_by_local_name(parent, *names)
        if node is None:
            raise ValueError(
                f"Error loading ASC CDL from {cdl_path}: missing <{names[0]}> element"
            )
        return node

    def read_numbers(parent, name):
        # Parse the whitespace-separated floats inside a mandatory element.
        text = require(parent, name).text or ""
        try:
            values = [float(x) for x in text.split()]
        except ValueError as e:
            raise ValueError(
                f"Error loading ASC CDL from {cdl_path}: bad <{name}> value: {e}"
            ) from e
        if not values:
            raise ValueError(
                f"Error loading ASC CDL from {cdl_path}: empty <{name}> element"
            )
        return values

    sop_node = require(root, "SOPNode")
    slope = torch.tensor(read_numbers(sop_node, "Slope"), device=device)
    offset = torch.tensor(read_numbers(sop_node, "Offset"), device=device)
    power = torch.tensor(read_numbers(sop_node, "Power"), device=device)

    # Some tools spell the saturation container "SATNode".
    sat_node = require(root, "SatNode", "SATNode")
    sat_values = read_numbers(sat_node, "Saturation")
    if len(sat_values) != 1:
        raise ValueError(
            f"Error loading ASC CDL from {cdl_path}: <Saturation> must hold a "
            f"single value, got {len(sat_values)}"
        )
    saturation = torch.tensor(sat_values[0], device=device)

    return {"slope": slope, "offset": offset, "power": power, "saturation": saturation}
|
|
|
|
def _format_cdl_values(values) -> str:
    """
    Render the numbers in ``values`` (a tensor or anything ``torch.as_tensor``
    accepts, any shape) as space-separated text, flattened in row-major order.
    """
    flat = torch.as_tensor(values).detach().cpu().reshape(-1).numpy()
    return " ".join(str(x) for x in flat)


def save_asc_cdl(cdl_dict: dict, cdl_path: Optional[str] = None) -> Optional[str]:
    """
    Saves ASC CDL parameters to an XML file, or returns them as an XML string.

    Parameters:
        cdl_dict (dict): Dictionary containing "slope", "offset", "power" and
            "saturation" values (tensors or anything ``torch.as_tensor``
            accepts). Saturation must hold exactly one value; a 0-d or
            single-element tensor is written as a plain number.
        cdl_path (Optional[str]): Destination file. If None, nothing is written
            and the XML document is returned instead.

    Returns:
        Optional[str]: The pretty-printed UTF-8 XML document (with XML
        declaration) when ``cdl_path`` is None, otherwise None.

    Raises:
        ValueError: If saturation does not hold exactly one value, or the file
            cannot be written.
    """
    saturation = torch.as_tensor(cdl_dict["saturation"])
    if saturation.numel() != 1:
        # A multi-element saturation would be written in a form
        # load_asc_cdl cannot read back.
        raise ValueError(
            f"Saturation must be a single value, got shape {tuple(saturation.shape)}"
        )

    root = etree.Element("ASC_CDL")
    sop_node = etree.SubElement(root, "SOPNode")
    for tag, key in (("Slope", "slope"), ("Offset", "offset"), ("Power", "power")):
        etree.SubElement(sop_node, tag).text = _format_cdl_values(cdl_dict[key])
    sat_node = etree.SubElement(root, "SatNode")
    etree.SubElement(sat_node, "Saturation").text = _format_cdl_values(saturation)

    if cdl_path is None:
        return etree.tostring(
            root, pretty_print=True, xml_declaration=True, encoding="utf-8"
        ).decode("utf-8")

    try:
        etree.ElementTree(root).write(
            cdl_path, pretty_print=True, xml_declaration=True, encoding="utf-8"
        )
    except Exception as e:
        raise ValueError(f"Error saving ASC CDL to {cdl_path}: {e}") from e
    return None
|
|
|
|
def apply_sop(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies Slope, Offset, and Power adjustments.

    Computes ``img * slope + offset``, optionally clamps it to [0, 1], then
    raises every value above 1e-7 to ``power``; values at or below 1e-7
    (including negatives when ``clamp`` is False) pass through unchanged.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        clamp (bool): If True, clamp ``img * slope + offset`` to [0, 1]
            before the power step.

    Returns:
        torch.Tensor: Image after SOP adjustments.
    """
    # (*B, C) -> (*B, C, 1, 1): one value per channel, broadcast over H and W.
    slope_px = slope[..., None, None]
    offset_px = offset[..., None, None]
    power_px = power[..., None, None]

    graded = img * slope_px + offset_px
    if clamp:
        graded = graded.clamp(min=0.0, max=1.0)

    eps = 1e-7
    # Clamping the base keeps pow finite in the branch torch.where discards,
    # which matters when gradients flow through this function.
    powered = graded.clamp(min=eps).pow(power_px)
    return torch.where(graded > eps, powered, graded)
|
|
|
|
def apply_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies saturation adjustment.

    Each channel is moved away from (saturation > 1) or towards
    (saturation < 1) the pixel's Rec. 709 luma computed from channels 0-2.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W)
        saturation (torch.Tensor): Saturation factor (*B)

    Returns:
        torch.Tensor: Image after saturation adjustment.
    """
    red = img[..., 0, :, :]
    green = img[..., 1, :, :]
    blue = img[..., 2, :, :]
    # Rec. 709 luma weights; they sum to 1, so the blend below keeps luma fixed.
    luma = (0.2126 * red + 0.7152 * green + 0.0722 * blue)[..., None, :, :]
    # (*B,) -> (*B, 1, 1, 1): one factor for every channel and pixel of an image.
    factor = saturation[..., None, None, None]
    return luma + (img - luma) * factor
|
|
|
|
def asc_cdl_forward(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies ASC CDL transformation in Fwd or FwdNoClamp mode.

    Slope/offset/power may also be given as 1-D ``(C,)`` tensors and
    saturation as a 0-d tensor (the layout returned by ``load_asc_cdl``).
    They are then shared by every image in the batch, and this also works
    when ``img`` is a single unbatched ``(C, H, W)`` image.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C) or (C,)
        offset (torch.Tensor): Offset per channel (*B, C) or (C,)
        power (torch.Tensor): Power per channel (*B, C) or (C,)
        saturation (torch.Tensor): Saturation factor (*B) or 0-d
        clamp (bool): If True, clamps output to [0, 1] (Fwd mode).
            If False, no clamping (FwdNoClamp mode).

    Returns:
        torch.Tensor: Transformed image tensor.

    Raises:
        ValueError: If ``img`` has fewer than 3 dims, or a parameter's rank
            does not match the batch rank of ``img``.
    """
    if (saturation < 0).any():
        warnings.warn("Saturation is below 0, this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0, this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0, this will result in a color shift.")

    if img.ndim < 3:
        raise ValueError(f"img must have shape (*B, C, H, W), got {tuple(img.shape)}")
    batch_ndim = img.ndim - 3

    # Leading singleton dims let unbatched parameters broadcast over the batch.
    # Passing the shape as a single tuple also works when it is empty
    # (unbatched img with 0-d saturation), where unpacking it into
    # view(*shape) would call view() with no size argument.
    lead = (1,) * batch_ndim
    if slope.ndim == 1:
        slope = slope.reshape(lead + tuple(slope.shape))
    if offset.ndim == 1:
        offset = offset.reshape(lead + tuple(offset.shape))
    if power.ndim == 1:
        power = power.reshape(lead + tuple(power.shape))
    if saturation.ndim == 0:
        saturation = saturation.reshape(lead)

    # Explicit checks instead of asserts, which are stripped under `python -O`.
    for name, param, expected_ndim in (
        ("slope", slope, batch_ndim + 1),
        ("offset", offset, batch_ndim + 1),
        ("power", power, batch_ndim + 1),
        ("saturation", saturation, batch_ndim),
    ):
        if param.ndim != expected_ndim:
            raise ValueError(
                f"{name} must have {expected_ndim} dim(s) for img of shape "
                f"{tuple(img.shape)}, got shape {tuple(param.shape)}"
            )

    img = apply_sop(img, slope, offset, power, clamp=clamp)
    img = apply_saturation(img, saturation)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    return img
|
|
|
|
def inverse_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Reverts saturation adjustment.

    Recomputes Rec. 709 luma from ``img`` and divides each channel's distance
    from it by ``saturation``. The luma weights sum to 1, so a forward
    ``luma + (x - luma) * s`` step leaves luma unchanged. This therefore
    inverts it up to floating-point rounding. A saturation of 0 is not
    invertible: the division produces inf/NaN.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W)
        saturation (torch.Tensor): Saturation factor (*B)

    Returns:
        torch.Tensor: Image after reversing saturation adjustment.
    """
    channels = (img[..., 0, :, :], img[..., 1, :, :], img[..., 2, :, :])
    luma = 0.2126 * channels[0] + 0.7152 * channels[1] + 0.0722 * channels[2]
    luma = luma[..., None, :, :]
    # (*B,) -> (*B, 1, 1, 1) so each image is divided by its own factor.
    divisor = saturation[..., None, None, None]
    return luma + (img - luma) / divisor
|
|
|
|
def asc_cdl_reverse(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies reverse ASC CDL transformation.

    Undoes saturation, then power, then slope/offset. As in the forward
    transform, 1-D ``(C,)`` slope/offset/power and a 0-d saturation are
    shared across the batch, and an unbatched ``(C, H, W)`` image is
    supported. A zero slope or zero saturation is not invertible and yields
    inf/NaN.

    Parameters:
        img (torch.Tensor): Transformed image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C) or (C,)
        offset (torch.Tensor): Offset per channel (*B, C) or (C,)
        power (torch.Tensor): Power per channel (*B, C) or (C,)
        saturation (torch.Tensor): Saturation factor (*B) or 0-d
        clamp (bool): If True, clamps to [0, 1] after undoing saturation and
            again on the output.

    Returns:
        torch.Tensor: Recovered input image tensor.

    Raises:
        ValueError: If ``img`` has fewer than 3 dims, or a parameter's rank
            does not match the batch rank of ``img``.
    """
    if (saturation < 0).any():
        warnings.warn("Saturation is below 0, this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0, this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0, this will result in a color shift.")

    if img.ndim < 3:
        raise ValueError(f"img must have shape (*B, C, H, W), got {tuple(img.shape)}")
    batch_ndim = img.ndim - 3

    # Leading singleton dims let unbatched parameters broadcast over the batch.
    # Passing the shape as a single tuple also works when it is empty
    # (unbatched img with 0-d saturation), where unpacking it into
    # view(*shape) would call view() with no size argument.
    lead = (1,) * batch_ndim
    if slope.ndim == 1:
        slope = slope.reshape(lead + tuple(slope.shape))
    if offset.ndim == 1:
        offset = offset.reshape(lead + tuple(offset.shape))
    if power.ndim == 1:
        power = power.reshape(lead + tuple(power.shape))
    if saturation.ndim == 0:
        saturation = saturation.reshape(lead)

    # Explicit checks instead of asserts, which are stripped under `python -O`.
    for name, param, expected_ndim in (
        ("slope", slope, batch_ndim + 1),
        ("offset", offset, batch_ndim + 1),
        ("power", power, batch_ndim + 1),
        ("saturation", saturation, batch_ndim),
    ):
        if param.ndim != expected_ndim:
            raise ValueError(
                f"{name} must have {expected_ndim} dim(s) for img of shape "
                f"{tuple(img.shape)}, got shape {tuple(param.shape)}"
            )

    img = inverse_saturation(img, saturation)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)

    # Undo the power step; values at or below eps pass through unchanged.
    # The base is clamped before pow so that the branch torch.where discards
    # stays finite. A negative base with a fractional exponent is NaN
    # (reachable with clamp=False), and pow(0, 1/power) has an infinite
    # derivative when power > 1. Either would turn gradients into NaN.
    # Selected values are unaffected because the clamp is a no-op there.
    eps = 1e-7
    inv_power = 1 / power[..., None, None]
    img = torch.where(img > eps, torch.pow(img.clamp(min=eps), inv_power), img)
    img = (img - offset[..., None, None]) / slope[..., None, None]

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    return img
|
|