import warnings
from typing import Optional

import torch
from jaxtyping import Float
from lxml import etree

def load_asc_cdl(cdl_path: str, device: torch.device = torch.device("cpu")) -> dict:
    """
    Loads ASC CDL parameters from an XML file.

    Parameters:
        cdl_path (str): Path to the ASC CDL XML file
        device (torch.device): Device on which the returned tensors are allocated

    Returns:
        dict: slope, offset, power, and saturation values as torch tensors
    """
    try:
        tree = etree.parse(cdl_path)
        root = tree.getroot()
    except Exception as e:
        raise ValueError(f"Error loading ASC CDL from {cdl_path}: {e}")

    sop_node = root.find(".//SOPNode")
    if sop_node is None:
        raise ValueError(f"No SOPNode found in {cdl_path}")
    slope = torch.tensor(
        [float(x) for x in sop_node.find("Slope").text.split()], device=device
    )
    offset = torch.tensor(
        [float(x) for x in sop_node.find("Offset").text.split()], device=device
    )
    power = torch.tensor(
        [float(x) for x in sop_node.find("Power").text.split()], device=device
    )

    sat_node = root.find(".//SatNode")
    if sat_node is None:
        raise ValueError(f"No SatNode found in {cdl_path}")
    saturation = torch.tensor(float(sat_node.find("Saturation").text), device=device)

    return {"slope": slope, "offset": offset, "power": power, "saturation": saturation}

def save_asc_cdl(cdl_dict: dict, cdl_path: Optional[str]) -> Optional[str]:
    """
    Saves ASC CDL parameters to an XML file.

    Parameters:
        cdl_dict (dict): Dictionary containing slope, offset, power, and
            saturation values
        cdl_path (Optional[str]): Output path. If None, the XML document is
            returned as a string instead of being written to disk.

    Returns:
        Optional[str]: The serialized XML string if cdl_path is None,
            otherwise None.
    """
    root = etree.Element("ASC_CDL")
    sop_node = etree.SubElement(root, "SOPNode")
    etree.SubElement(sop_node, "Slope").text = " ".join(
        str(x) for x in cdl_dict["slope"].detach().cpu().numpy()
    )
    etree.SubElement(sop_node, "Offset").text = " ".join(
        str(x) for x in cdl_dict["offset"].detach().cpu().numpy()
    )
    etree.SubElement(sop_node, "Power").text = " ".join(
        str(x) for x in cdl_dict["power"].detach().cpu().numpy()
    )
    sat_node = etree.SubElement(root, "SatNode")
    etree.SubElement(sat_node, "Saturation").text = str(
        cdl_dict["saturation"].detach().cpu().numpy()
    )

    tree = etree.ElementTree(root)
    if cdl_path is not None:
        try:
            tree.write(
                cdl_path, pretty_print=True, xml_declaration=True, encoding="utf-8"
            )
        except Exception as e:
            raise ValueError(f"Error saving ASC CDL to {cdl_path}: {e}")
    else:
        return etree.tostring(
            root, pretty_print=True, xml_declaration=True, encoding="utf-8"
        ).decode("utf-8")
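
# A minimal sketch of a load/save round trip (paths are hypothetical):
#
#   cdl = load_asc_cdl("grade.cdl")
#   xml_str = save_asc_cdl(cdl, None)    # serialize to a string
#   save_asc_cdl(cdl, "grade_copy.cdl")  # or write to disk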

def apply_sop(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies Slope, Offset, and Power adjustments.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        clamp (bool): If True, clamps the slope/offset result to [0, 1]
            before the power is applied.

    Returns:
        torch.Tensor: Image after SOP adjustments.
    """
    # Slope and offset, broadcast per channel over the spatial dimensions.
    so = img * slope.unsqueeze(-1).unsqueeze(-1) + offset.unsqueeze(-1).unsqueeze(-1)
    if clamp:
        so = torch.clamp(so, min=0.0, max=1.0)
    # Apply the power only where the base is positive; the inner clamp keeps
    # torch.pow away from non-positive bases, whose gradients would be NaN.
    return torch.where(
        so > 1e-7, torch.pow(so.clamp(min=1e-7), power.unsqueeze(-1).unsqueeze(-1)), so
    )
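
# Per channel this computes the standard CDL SOP transfer function,
#     out = (in * slope + offset) ** power,
# e.g. in = 0.5, slope = 2.0, offset = -0.2, power = 1.5 gives
# (0.5 * 2.0 - 0.2) ** 1.5 = 0.8 ** 1.5 ≈ 0.7155.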

def apply_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies saturation adjustment.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W)
        saturation (torch.Tensor): Saturation factor (*B)

    Returns:
        torch.Tensor: Image after saturation adjustment.
    """
    # Luma from the Rec. 709 weights.
    lum = (
        0.2126 * img[..., 0, :, :]
        + 0.7152 * img[..., 1, :, :]
        + 0.0722 * img[..., 2, :, :]
    )
    lum = lum.unsqueeze(-3)
    # Blend each pixel between its luma (saturation 0) and itself (saturation 1).
    return lum + (img - lum) * saturation.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
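
# A minimal sketch: a pure red pixel desaturated halfway toward its luma.
#
#   img = torch.tensor([1.0, 0.0, 0.0]).view(3, 1, 1)
#   apply_saturation(img, torch.tensor(0.5))
#   # luma = 0.2126, so the channels become [0.6063, 0.1063, 0.1063]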

def asc_cdl_forward(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies the ASC CDL transformation in Fwd or FwdNoClamp mode.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        saturation (torch.Tensor): Saturation factor (*B)
        clamp (bool): If True, clamps output to [0, 1] (Fwd mode).
            If False, no clamping (FwdNoClamp mode).

    Returns:
        torch.Tensor: Transformed image tensor.
    """
    if (saturation < 0).any():
        warnings.warn("Saturation is below 0; this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0; this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0; this will result in a color shift.")

    img_batch_dim = img.shape[:-3]

    # Prepend singleton batch dimensions to unbatched parameters so they
    # broadcast against the image's batch dimensions.
    if slope.ndim == 1:
        slope = slope.view(*[1] * len(img_batch_dim), *slope.shape)
    if offset.ndim == 1:
        offset = offset.view(*[1] * len(img_batch_dim), *offset.shape)
    if power.ndim == 1:
        power = power.view(*[1] * len(img_batch_dim), *power.shape)
    if saturation.ndim == 0:
        saturation = saturation.view(*[1] * len(img_batch_dim), *saturation.shape)

    assert slope.ndim == len(img_batch_dim) + 1
    assert offset.ndim == len(img_batch_dim) + 1
    assert power.ndim == len(img_batch_dim) + 1
    assert saturation.ndim == len(img_batch_dim)

    img = apply_sop(img, slope, offset, power, clamp=clamp)
    img = apply_saturation(img, saturation)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    return img
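
# A minimal sketch grading a random batch (values are illustrative):
#
#   img = torch.rand(4, 3, 32, 32)  # batch of 4 RGB images in [0, 1]
#   graded = asc_cdl_forward(
#       img,
#       slope=torch.tensor([1.1, 1.0, 0.9]),
#       offset=torch.tensor([0.01, 0.0, -0.01]),
#       power=torch.tensor([1.0, 1.0, 1.2]),
#       saturation=torch.tensor(0.9),
#   )
#   graded.shape  # torch.Size([4, 3, 32, 32])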

def inverse_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Reverts the saturation adjustment.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W)
        saturation (torch.Tensor): Saturation factor (*B)

    Returns:
        torch.Tensor: Image after reversing the saturation adjustment.
    """
    # The Rec. 709 luma is invariant under apply_saturation (the weights sum
    # to 1), so it can be recomputed here and the blend inverted by dividing
    # instead of multiplying.
    lum = (
        0.2126 * img[..., 0, :, :]
        + 0.7152 * img[..., 1, :, :]
        + 0.0722 * img[..., 2, :, :]
    )
    lum = lum.unsqueeze(-3)
    return lum + (img - lum) / saturation.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

def asc_cdl_reverse(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies the reverse ASC CDL transformation.

    Parameters:
        img (torch.Tensor): Transformed image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        saturation (torch.Tensor): Saturation factor (*B)
        clamp (bool): If True, clamps output to [0, 1].

    Returns:
        torch.Tensor: Recovered input image tensor.
    """
    if (saturation < 0).any():
        warnings.warn("Saturation is below 0; this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0; this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0; this will result in a color shift.")

    img_batch_dim = img.shape[:-3]

    # Prepend singleton batch dimensions to unbatched parameters so they
    # broadcast against the image's batch dimensions.
    if slope.ndim == 1:
        slope = slope.view(*[1] * len(img_batch_dim), *slope.shape)
    if offset.ndim == 1:
        offset = offset.view(*[1] * len(img_batch_dim), *offset.shape)
    if power.ndim == 1:
        power = power.view(*[1] * len(img_batch_dim), *power.shape)
    if saturation.ndim == 0:
        saturation = saturation.view(*[1] * len(img_batch_dim), *saturation.shape)

    assert slope.ndim == len(img_batch_dim) + 1
    assert offset.ndim == len(img_batch_dim) + 1
    assert power.ndim == len(img_batch_dim) + 1
    assert saturation.ndim == len(img_batch_dim)

    # Undo the forward steps in reverse order: saturation, then power, then
    # slope and offset.
    img = inverse_saturation(img, saturation)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    # As in apply_sop, the inner clamp keeps torch.pow away from non-positive
    # bases, whose gradients would be NaN.
    img = torch.where(
        img > 1e-7,
        torch.pow(img.clamp(min=1e-7), 1 / power.unsqueeze(-1).unsqueeze(-1)),
        img,
    )
    img = (img - offset.unsqueeze(-1).unsqueeze(-1)) / slope.unsqueeze(-1).unsqueeze(-1)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    return img
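
if __name__ == "__main__":
    # A minimal round-trip sanity check (illustrative values, not part of the
    # module's API): with clamp=False the reverse transform should recover the
    # input up to numerical precision, as long as no values were clipped.
    img = torch.rand(2, 3, 8, 8) * 0.8 + 0.1  # keep values away from 0 and 1
    slope = torch.tensor([1.1, 1.0, 0.9])
    offset = torch.tensor([0.05, 0.0, 0.02])
    power = torch.tensor([1.2, 1.0, 0.8])
    saturation = torch.tensor(0.9)

    graded = asc_cdl_forward(img, slope, offset, power, saturation, clamp=False)
    recovered = asc_cdl_reverse(graded, slope, offset, power, saturation, clamp=False)
    print("max round-trip error:", (img - recovered).abs().max().item())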