| | import torch
|
| | from .tools import VariantSupport
|
| | from comfy_execution.graph_utils import GraphBuilder
|
| |
|
class TestLazyMixImages:
    """Test node that blends two images through a mask.

    Both image inputs are declared with ``{"lazy": True}``;
    ``check_lazy_status`` reports which of them are still required for the
    given mask.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two lazy IMAGE inputs and one MASK input, all required."""
        required = {
            "image1": ("IMAGE", {"lazy": True}),
            "image2": ("IMAGE", {"lazy": True}),
            "mask": ("MASK",),
        }
        return {"required": required}

    def check_lazy_status(self, mask, image1, image2):
        """Return the names of the image inputs that still have to be supplied.

        An image is requested only when it is currently ``None`` and the mask
        is not uniformly the value that makes that image irrelevant to
        ``mix`` (all ones for ``image1``, all zeros for ``image2``).
        """
        lowest, highest = mask.min(), mask.max()
        pending = []
        if image1 is None and not (lowest == 1.0 and highest == 1.0):
            pending.append("image1")
        if image2 is None and not (lowest == 0.0 and highest == 0.0):
            pending.append("image2")
        return pending

    def mix(self, mask, image1, image2):
        """Blend ``image1`` and ``image2`` as ``image1 * (1 - mask) + image2 * mask``.

        A mask that is uniformly 0 returns ``image1`` untouched and a mask that
        is uniformly 1 returns ``image2`` untouched, so the other image may be
        ``None`` in those cases. No batch-size reconciliation is performed.

        Returns:
            A one-element tuple holding the resulting image.
        """
        lowest, highest = mask.min(), mask.max()
        if lowest == 0.0 and highest == 0.0:
            return (image1,)
        if lowest == 1.0 and highest == 1.0:
            return (image2,)

        # Bring the mask up to four dimensions: a 2-D mask gains a leading
        # axis, a 3-D mask gains a trailing axis.
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(3)

        # Tile the mask along its last axis when that axis is narrower than
        # the matching axis of image1.
        last_axis = image1.shape[3]
        if mask.shape[3] < last_axis:
            mask = mask.repeat(1, 1, 1, last_axis)

        blended = image1 * (1. - mask) + image2 * mask
        return (blended,)
|
| |
|
class TestVariadicAverage:
    """Test node that averages ``input1`` with consecutively numbered extras."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "variadic_average"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required IMAGE input named ``input1``."""
        return {"required": {"input1": ("IMAGE",)}}

    def variadic_average(self, input1, **kwargs):
        """Return the element-wise mean of ``input1``, ``input2``, ``input3``, ...

        Extra inputs are read from ``kwargs`` in numeric order starting at
        ``input2``; collection stops at the first missing number, so any
        higher-numbered keys after a gap are ignored.

        Returns:
            A one-element tuple holding the averaged tensor.
        """
        images = [input1]
        slot = 2
        while f"input{slot}" in kwargs:
            images.append(kwargs[f"input{slot}"])
            slot += 1
        return (torch.stack(images).mean(dim=0),)
|
| |
|
| |
|
class TestCustomIsChanged:
    """Test node that passes its image through and exposes a custom IS_CHANGED."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a required IMAGE input and an optional ``should_change`` flag."""
        return {
            "required": {"image": ("IMAGE",)},
            "optional": {"should_change": ("BOOL", {"default": False})},
        }

    def custom_is_changed(self, image, should_change=False):
        """Return ``image`` unchanged as a one-element tuple."""
        return (image,)

    @classmethod
    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
        """Return NaN when ``should_change`` is truthy, otherwise ``False``.

        NaN never compares equal to itself; presumably the caller relies on
        that to treat the node as changed on every run (confirm against the
        executor).
        """
        return float("NaN") if should_change else False
|
| |
|
class TestIsChangedWithConstants:
    """Test node that scales an image and derives IS_CHANGED from its inputs."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a required IMAGE input and a FLOAT ``value`` (default 1.0, 0-10)."""
        required = {
            "image": ("IMAGE",),
            "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
        }
        return {"required": required}

    def custom_is_changed(self, image, value):
        """Return ``image * value`` as a one-element tuple."""
        scaled = image * value
        return (scaled,)

    @classmethod
    def IS_CHANGED(cls, image, value):
        """Return a change token computed from the inputs.

        When ``image`` is ``None`` the token is ``value`` itself; otherwise it
        is the image's mean (as a Python float) multiplied by ``value``.
        """
        if image is not None:
            return image.mean().item() * value
        return value
|
| |
|
class TestCustomValidation1:
    """Test node accepting IMAGE-or-FLOAT inputs, validated by value type."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation1"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two required inputs typed ``"IMAGE,FLOAT"``."""
        required = {
            "input1": ("IMAGE,FLOAT",),
            "input2": ("IMAGE,FLOAT",),
        }
        return {"required": required}

    def custom_validation1(self, input1, input2):
        """Multiply the two inputs and return the product as a one-element tuple.

        When both inputs are Python floats the product is broadcast onto a
        ``[1, 512, 512, 3]`` tensor of ones; otherwise the inputs are
        multiplied directly.
        """
        both_scalars = isinstance(input1, float) and isinstance(input2, float)
        if not both_scalars:
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1=None, input2=None):
        """Check that each supplied input is a tensor or a float.

        ``None`` values are skipped.

        Returns:
            ``True`` when every supplied value is acceptable, otherwise an
            error message naming the first offending input.
        """
        accepted = (torch.Tensor, float)
        if input1 is not None and not isinstance(input1, accepted):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, accepted):
            return f"Invalid type of input2: {type(input2)}"
        return True
|
| |
|
class TestCustomValidation2:
    """Test node accepting IMAGE-or-FLOAT inputs, validated by value and declared type."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation2"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two required inputs typed ``"IMAGE,FLOAT"``."""
        required = {
            "input1": ("IMAGE,FLOAT",),
            "input2": ("IMAGE,FLOAT",),
        }
        return {"required": required}

    def custom_validation2(self, input1, input2):
        """Multiply the two inputs and return the product as a one-element tuple.

        When both inputs are Python floats the product is broadcast onto a
        ``[1, 512, 512, 3]`` tensor of ones; otherwise the inputs are
        multiplied directly.
        """
        both_scalars = isinstance(input1, float) and isinstance(input2, float)
        if not both_scalars:
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
        """Validate supplied values first, then the declared type names.

        Args:
            input_types: Mapping from input name to a type-name string; only
                the ``'input1'`` and ``'input2'`` keys are consulted, and only
                when present.
            input1: Value to check, skipped when ``None``.
            input2: Value to check, skipped when ``None``.

        Returns:
            ``True`` when everything is acceptable, otherwise an error message
            describing the first failure. Value checks run before type-name
            checks.
        """
        accepted_values = (torch.Tensor, float)
        if input1 is not None and not isinstance(input1, accepted_values):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, accepted_values):
            return f"Invalid type of input2: {type(input2)}"

        accepted_names = ("IMAGE", "FLOAT")
        if 'input1' in input_types and input_types['input1'] not in accepted_names:
            return f"Invalid type of input1: {input_types['input1']}"
        if 'input2' in input_types and input_types['input2'] not in accepted_names:
            return f"Invalid type of input2: {input_types['input2']}"

        return True
|
| |
|
@VariantSupport()
class TestCustomValidation3:
    """Test node accepting IMAGE-or-FLOAT inputs without its own validator.

    The class is wrapped by ``VariantSupport()`` from ``.tools``; what that
    decorator adds is not visible here.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation3"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two required inputs typed ``"IMAGE,FLOAT"``."""
        required = {
            "input1": ("IMAGE,FLOAT",),
            "input2": ("IMAGE,FLOAT",),
        }
        return {"required": required}

    def custom_validation3(self, input1, input2):
        """Multiply the two inputs and return the product as a one-element tuple.

        When both inputs are Python floats the product is broadcast onto a
        ``[1, 512, 512, 3]`` tensor of ones; otherwise the inputs are
        multiplied directly.
        """
        both_scalars = isinstance(input1, float) and isinstance(input2, float)
        if not both_scalars:
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)
|
| |
|
class TestCustomValidation4:
    """Test node taking two FLOAT inputs, validated to be Python floats."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation4"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two required FLOAT inputs."""
        required = {
            "input1": ("FLOAT",),
            "input2": ("FLOAT",),
        }
        return {"required": required}

    def custom_validation4(self, input1, input2):
        """Return a ``[1, 512, 512, 3]`` tensor filled with ``input1 * input2``."""
        canvas = torch.ones([1, 512, 512, 3])
        return (canvas * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1, input2):
        """Check that each non-``None`` input is a Python float.

        Returns:
            ``True`` when both values pass, otherwise an error message naming
            the first offending input.
        """
        if input1 is not None and not isinstance(input1, float):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, float):
            return f"Invalid type of input2: {type(input2)}"
        return True
|
| |
|
class TestCustomValidation5:
    """Test node taking two FLOAT inputs with a keyword-only style validator."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two required FLOAT inputs, each limited to the range 0.0-1.0."""
        return {
            "required": {
                "input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
                "input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation5"

    CATEGORY = "Testing/Nodes"

    def custom_validation5(self, input1, input2):
        """Return a ``[1, 512, 512, 3]`` tensor filled with ``input1 * input2``."""
        value = input1 * input2
        return (torch.ones([1, 512, 512, 3]) * value,)

    @classmethod
    def VALIDATE_INPUTS(cls, **kwargs):
        """Reject an ``input2`` equal to 7.0; accept everything else.

        The validator receives its inputs as ``**kwargs``, so ``input2`` may be
        absent from the mapping. It is read with ``dict.get`` so that a missing
        key validates successfully instead of raising ``KeyError``.

        Returns:
            ``True`` when the input is acceptable, otherwise an error message.
        """
        if kwargs.get('input2') == 7.0:
            return "7s are not allowed. I've never liked 7s."
        return True
|
| |
|
class TestDynamicDependencyCycle:
    """Test node whose expansion deliberately contains a dependency cycle."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "dynamic_dependency_cycle"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare two required IMAGE inputs."""
        required = {
            "input1": ("IMAGE",),
            "input2": ("IMAGE",),
        }
        return {"required": required}

    def dynamic_dependency_cycle(self, input1, input2):
        """Build a sub-graph in which two mix nodes feed each other.

        Returns:
            A dict with ``"result"`` (the second mix node's output reference)
            and ``"expand"`` (the finalized builder graph).
        """
        builder = GraphBuilder()
        half_mask = builder.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        first_mix = builder.node("TestLazyMixImages", image1=input1, mask=half_mask.out(0))
        second_mix = builder.node(
            "TestLazyMixImages",
            image1=first_mix.out(0),
            image2=input2,
            mask=half_mask.out(0),
        )

        # Close the loop: the first mix now consumes the second mix's output,
        # while the second mix already consumes the first's.
        first_mix.set_input("image2", second_mix.out(0))

        expansion = {
            "result": (second_mix.out(0),),
            "expand": builder.finalize(),
        }
        return expansion
|
| |
|
class TestMixedExpansionReturns:
    """Test node whose return shape (tuple, dict, or expansion) depends on its input."""

    RETURN_TYPES = ("IMAGE", "IMAGE")
    FUNCTION = "mixed_expansion_returns"
    CATEGORY = "Testing/Nodes"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required FLOAT input named ``input1``."""
        return {"required": {"input1": ("FLOAT",)}}

    def mixed_expansion_returns(self, input1):
        """Return two images using one of three return conventions.

        * ``input1 <= 0.1``: a plain tuple ``(0.1-filled image, white image)``.
        * ``input1 <= 0.2``: a dict with only ``"result"``, holding
          ``(0.2-filled image, white image)``.
        * otherwise: a dict with ``"result"`` (a mix-node output reference and
          the white image) plus ``"expand"`` (the finalized builder graph).

        All concrete images are ``[1, 512, 512, 3]`` tensors.
        """
        white_image = torch.ones([1, 512, 512, 3])

        if input1 <= 0.1:
            return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)

        if input1 <= 0.2:
            return {"result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image)}

        # Above both thresholds: the first output is produced by a sub-graph.
        builder = GraphBuilder()
        mask_node = builder.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
        black_node = builder.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        white_node = builder.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mix_node = builder.node(
            "TestLazyMixImages",
            image1=black_node.out(0),
            image2=white_node.out(0),
            mask=mask_node.out(0),
        )
        return {
            "result": (mix_node.out(0), white_image),
            "expand": builder.finalize(),
        }
|
| |
|
# Registry mapping each node identifier string to the class that implements it.
TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
    "TestCustomIsChanged": TestCustomIsChanged,
    "TestIsChangedWithConstants": TestIsChangedWithConstants,
    "TestCustomValidation1": TestCustomValidation1,
    "TestCustomValidation2": TestCustomValidation2,
    "TestCustomValidation3": TestCustomValidation3,
    "TestCustomValidation4": TestCustomValidation4,
    "TestCustomValidation5": TestCustomValidation5,
    "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
    "TestMixedExpansionReturns": TestMixedExpansionReturns,
}
|
| |
|
# Human-readable display name for each node identifier; the keys are the same
# identifiers used in TEST_NODE_CLASS_MAPPINGS.
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestLazyMixImages": "Lazy Mix Images",
    "TestVariadicAverage": "Variadic Average",
    "TestCustomIsChanged": "Custom IsChanged",
    "TestIsChangedWithConstants": "IsChanged With Constants",
    "TestCustomValidation1": "Custom Validation 1",
    "TestCustomValidation2": "Custom Validation 2",
    "TestCustomValidation3": "Custom Validation 3",
    "TestCustomValidation4": "Custom Validation 4",
    "TestCustomValidation5": "Custom Validation 5",
    "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
    "TestMixedExpansionReturns": "Mixed Expansion Returns",
}
|
| |
|