| import torch |
| from .tools import VariantSupport |
| from comfy_execution.graph_utils import GraphBuilder |
|
|
class TestLazyMixImages:
    """Blend two images through a mask, requesting each image input lazily.

    The output is ``image1 * (1 - mask) + image2 * mask``: where the mask is 0
    the result comes from ``image1``, where it is 1 it comes from ``image2``.
    When the mask is uniformly 0 (or uniformly 1) the other image is never
    requested via ``check_lazy_status``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE", {"lazy": True}),
                "image2": ("IMAGE", {"lazy": True}),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"

    CATEGORY = "Testing/Nodes"

    def check_lazy_status(self, mask, image1, image2):
        """Return the names of lazy image inputs that ``mix`` still needs.

        An image argument of ``None`` is treated as not yet evaluated.
        ``image1`` is needed unless the mask is all 1s; ``image2`` is needed
        unless the mask is all 0s.
        """
        mask_min = mask.min()
        mask_max = mask.max()
        needed = []
        if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
            needed.append("image1")
        if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
            needed.append("image2")
        return needed

    def mix(self, mask, image1, image2):
        """Return ``(blended_image,)``.

        A uniform mask short-circuits to ``(image1,)`` (all 0s) or
        ``(image2,)`` (all 1s), so the image that was not requested (and may
        be ``None``) is never touched.

        A 2-D or 3-D mask is expanded to 4-D, and its last axis is repeated to
        match ``image1.shape[3]``. This assumes images are channels-last
        (batch, height, width, channels). TODO: confirm.
        """
        mask_min = mask.min()
        mask_max = mask.max()
        if mask_min == 0.0 and mask_max == 0.0:
            return (image1,)
        if mask_min == 1.0 and mask_max == 1.0:
            return (image2,)

        if mask.ndim == 2:
            mask = mask.unsqueeze(0)  # add a leading batch axis
        if mask.ndim == 3:
            mask = mask.unsqueeze(3)  # add a trailing axis to line up with image1's dim 3
        if mask.shape[3] < image1.shape[3]:
            mask = mask.repeat(1, 1, 1, image1.shape[3])

        # Linear blend. Compute the tensor directly rather than via a stray
        # trailing-comma tuple that then has to be unwrapped.
        result = image1 * (1.0 - mask) + image2 * mask
        return (result,)
|
|
class TestVariadicAverage:
    """Average IMAGE inputs named ``input1``, ``input2``, ``input3``, ...

    Only ``input1`` is declared. Further inputs arrive as keyword arguments
    and are collected in numeric order until the first missing index.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"input1": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "variadic_average"

    CATEGORY = "Testing/Nodes"

    def variadic_average(self, input1, **kwargs):
        collected = [input1]
        index = 2
        # Stop at the first gap: input2, input3, ... must be contiguous.
        while f"input{index}" in kwargs:
            collected.append(kwargs[f"input{index}"])
            index += 1
        averaged = torch.stack(collected).mean(dim=0)
        return (averaged,)
|
|
|
|
class TestCustomIsChanged:
    """Pass an image through unchanged; optionally report "changed" on every check."""

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {"image": ("IMAGE",)}
        optional_inputs = {"should_change": ("BOOL", {"default": False})}
        return {"required": required_inputs, "optional": optional_inputs}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"

    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, should_change=False):
        # should_change only matters to IS_CHANGED; the image is returned as-is.
        return (image,)

    @classmethod
    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
        # NaN != NaN, so if the executor compares successive IS_CHANGED
        # results with == (presumably), a NaN always reads as "changed".
        return float("NaN") if should_change else False
|
|
class TestIsChangedWithConstants:
    """Scale an image by a constant FLOAT; IS_CHANGED folds in both inputs."""

    @classmethod
    def INPUT_TYPES(cls):
        value_spec = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0})
        return {"required": {"image": ("IMAGE",), "value": value_spec}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"

    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, value):
        scaled = image * value
        return (scaled,)

    @classmethod
    def IS_CHANGED(cls, image, value):
        # image may be None (not available at check time); then only the
        # constant contributes to the fingerprint.
        if image is not None:
            return image.mean().item() * value
        return value
|
|
class TestCustomValidation1:
    """Multiply two IMAGE-or-FLOAT inputs, validating value types only."""

    @classmethod
    def INPUT_TYPES(cls):
        either = ("IMAGE,FLOAT",)
        return {"required": {"input1": either, "input2": either}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation1"

    CATEGORY = "Testing/Nodes"

    def custom_validation1(self, input1, input2):
        # Two plain floats produce a constant 1x512x512x3 image; otherwise
        # the operands are multiplied directly.
        if not (isinstance(input1, float) and isinstance(input2, float)):
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1=None, input2=None):
        # None is skipped (value not available at validation time);
        # anything else must be a tensor or a float.
        if input1 is not None and not isinstance(input1, (torch.Tensor, float)):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, (torch.Tensor, float)):
            return f"Invalid type of input2: {type(input2)}"
        return True
|
|
class TestCustomValidation2:
    """Multiply two IMAGE-or-FLOAT inputs, validating values and link types."""

    @classmethod
    def INPUT_TYPES(cls):
        either = ("IMAGE,FLOAT",)
        return {"required": {"input1": either, "input2": either}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation2"

    CATEGORY = "Testing/Nodes"

    def custom_validation2(self, input1, input2):
        both_scalars = isinstance(input1, float) and isinstance(input2, float)
        if both_scalars:
            product = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            product = input1 * input2
        return (product,)

    @classmethod
    def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
        # Value checks come first; None means the value is not known yet.
        if input1 is not None and not isinstance(input1, (torch.Tensor, float)):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, (torch.Tensor, float)):
            return f"Invalid type of input2: {type(input2)}"

        # Then check the declared types of whichever inputs appear in input_types.
        if 'input1' in input_types and input_types['input1'] not in ["IMAGE", "FLOAT"]:
            return f"Invalid type of input1: {input_types['input1']}"
        if 'input2' in input_types and input_types['input2'] not in ["IMAGE", "FLOAT"]:
            return f"Invalid type of input2: {input_types['input2']}"

        return True
|
|
@VariantSupport()
class TestCustomValidation3:
    """Multiply two IMAGE-or-FLOAT inputs without a VALIDATE_INPUTS hook.

    The class is wrapped by the project's ``VariantSupport`` decorator (from
    ``.tools``). Presumably the decorator is what lets the comma-separated
    "IMAGE,FLOAT" type validate here. TODO: confirm against ``.tools``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        either = ("IMAGE,FLOAT",)
        return {"required": {"input1": either, "input2": either}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation3"

    CATEGORY = "Testing/Nodes"

    def custom_validation3(self, input1, input2):
        scalars_only = isinstance(input1, float) and isinstance(input2, float)
        product = torch.ones([1, 512, 512, 3]) * input1 * input2 if scalars_only else input1 * input2
        return (product,)
|
|
class TestCustomValidation4:
    """Build a constant image from the product of two FLOATs."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"input1": ("FLOAT",), "input2": ("FLOAT",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation4"

    CATEGORY = "Testing/Nodes"

    def custom_validation4(self, input1, input2):
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1, input2):
        # Strict float check: ints are rejected; None (value not known yet) passes.
        if input1 is not None and not isinstance(input1, float):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, float):
            return f"Invalid type of input2: {type(input2)}"
        return True
|
|
class TestCustomValidation5:
    """Build a constant image from the product of two FLOATs, rejecting input2 == 7.0."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
                "input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation5"

    CATEGORY = "Testing/Nodes"

    def custom_validation5(self, input1, input2):
        value = input1 * input2
        return (torch.ones([1, 512, 512, 3]) * value,)

    @classmethod
    def VALIDATE_INPUTS(cls, **kwargs):
        """Return an error string if ``input2`` is exactly 7.0, else ``True``.

        NOTE(review): presumably, accepting ``**kwargs`` makes the executor
        skip its built-in min/max checks, which is how 7.0 can reach this
        method despite the declared 0..1 range. Confirm.

        ``.get`` is used so that a prompt omitting ``input2`` gets a normal
        validation result here instead of a ``KeyError``.
        """
        if kwargs.get('input2') == 7.0:
            return "7s are not allowed. I've never liked 7s."
        return True
|
|
class TestDynamicDependencyCycle:
    """Expand into a subgraph whose two mix nodes depend on each other."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"input1": ("IMAGE",), "input2": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "dynamic_dependency_cycle"

    CATEGORY = "Testing/Nodes"

    def dynamic_dependency_cycle(self, input1, input2):
        graph = GraphBuilder()
        half_mask = graph.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        first_mix = graph.node("TestLazyMixImages", image1=input1, mask=half_mask.out(0))
        second_mix = graph.node(
            "TestLazyMixImages",
            image1=first_mix.out(0),
            image2=input2,
            mask=half_mask.out(0),
        )

        # Close the loop: first_mix now takes image2 from second_mix, which
        # already takes image1 from first_mix. Given the node's name, the
        # cycle is presumably intentional.
        first_mix.set_input("image2", second_mix.out(0))

        output_ref = second_mix.out(0)
        expansion = graph.finalize()
        return {"result": (output_ref,), "expand": expansion}
|
|
class TestMixedExpansionReturns:
    """Return results in three shapes depending on ``input1``.

    ``input1 <= 0.1`` returns a plain tuple, ``<= 0.2`` returns a dict with
    only ``"result"``, and anything else (including NaN) returns a dict with
    ``"result"`` plus an ``"expand"`` subgraph.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"input1": ("FLOAT",)}}

    RETURN_TYPES = ("IMAGE", "IMAGE")
    FUNCTION = "mixed_expansion_returns"

    CATEGORY = "Testing/Nodes"

    def mixed_expansion_returns(self, input1):
        white_image = torch.ones([1, 512, 512, 3])

        if input1 <= 0.1:
            return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)

        if input1 <= 0.2:
            return {"result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image)}

        # Node creation order is kept stable: mask, black, white, mix.
        graph = GraphBuilder()
        mask_node = graph.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
        black_node = graph.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        white_node = graph.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        blend = graph.node(
            "TestLazyMixImages",
            image1=black_node.out(0),
            image2=white_node.out(0),
            mask=mask_node.out(0),
        )
        blend_ref = blend.out(0)
        expansion = graph.finalize()
        return {"result": (blend_ref, white_image), "expand": expansion}
|
|
# Node type name -> implementing class. The names are the strings used to
# reference these nodes, e.g. graph.node("TestLazyMixImages", ...).
TEST_NODE_CLASS_MAPPINGS = dict(
    TestLazyMixImages=TestLazyMixImages,
    TestVariadicAverage=TestVariadicAverage,
    TestCustomIsChanged=TestCustomIsChanged,
    TestIsChangedWithConstants=TestIsChangedWithConstants,
    TestCustomValidation1=TestCustomValidation1,
    TestCustomValidation2=TestCustomValidation2,
    TestCustomValidation3=TestCustomValidation3,
    TestCustomValidation4=TestCustomValidation4,
    TestCustomValidation5=TestCustomValidation5,
    TestDynamicDependencyCycle=TestDynamicDependencyCycle,
    TestMixedExpansionReturns=TestMixedExpansionReturns,
)
|
|
# Node type name -> human-readable display name (same keys, same order as
# the class mapping).
TEST_NODE_DISPLAY_NAME_MAPPINGS = dict(
    TestLazyMixImages="Lazy Mix Images",
    TestVariadicAverage="Variadic Average",
    TestCustomIsChanged="Custom IsChanged",
    TestIsChangedWithConstants="IsChanged With Constants",
    TestCustomValidation1="Custom Validation 1",
    TestCustomValidation2="Custom Validation 2",
    TestCustomValidation3="Custom Validation 3",
    TestCustomValidation4="Custom Validation 4",
    TestCustomValidation5="Custom Validation 5",
    TestDynamicDependencyCycle="Dynamic Dependency Cycle",
    TestMixedExpansionReturns="Mixed Expansion Returns",
)
|
|