| | import torch |
| | import asyncio |
| | from typing import Dict |
| | from comfy.utils import ProgressBar |
| | from comfy_execution.graph_utils import GraphBuilder |
| | from comfy.comfy_types.node_typing import ComfyNodeABC |
| | from comfy.comfy_types import IO |
| |
|
| |
|
class TestAsyncValidation(ComfyNodeABC):
    """Test node whose VALIDATE_INPUTS hook is a coroutine.

    Validation rejects a ``value`` strictly greater than ``threshold``;
    the execute function itself is an ordinary synchronous method.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "value": ("FLOAT", {"default": 5.0}),
            "threshold": ("FLOAT", {"default": 10.0}),
        }
        return {"required": required}

    @classmethod
    async def VALIDATE_INPUTS(cls, value, threshold):
        """Asynchronously check ``value`` against ``threshold``.

        Returns:
            True when the inputs are acceptable, otherwise an error message.
        """
        # Yield to the event loop so the hook is genuinely asynchronous.
        await asyncio.sleep(0.05)
        if value > threshold:
            return f"Value {value} exceeds threshold {threshold}"
        return True

    def process(self, value, threshold):
        """Return a 1x512x512x3 tensor filled with ``value / 10``."""
        fill_level = value / 10.0
        return (torch.ones([1, 512, 512, 3]) * fill_level,)
| |
|
| |
|
class TestAsyncError(ComfyNodeABC):
    """Test node that raises from inside its async execute function."""

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "error_execution"
    CATEGORY = "_for_testing/async"

    @classmethod
    def INPUT_TYPES(cls):
        delay_options = {"default": 0.1, "min": 0.0, "max": 10.0}
        return {
            "required": {
                "value": (IO.ANY, {}),
                "error_after": ("FLOAT", delay_options),
            },
        }

    async def error_execution(self, value, error_after):
        """Sleep ``error_after`` seconds, then unconditionally fail.

        Raises:
            RuntimeError: always, after the delay.
        """
        await asyncio.sleep(error_after)
        failure = RuntimeError("Intentional async execution error for testing")
        raise failure
| |
|
| |
|
class TestAsyncValidationError(ComfyNodeABC):
    """Test node whose async validation fails when ``value`` exceeds ``max_value``.

    Validation also rejects ``max_value == 0``: ``process`` divides by
    ``max_value``, so a zero would otherwise pass validation (whenever
    ``value <= 0``) and then crash with ZeroDivisionError at execution time.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "max_value": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, max_value):
        """Asynchronously validate the inputs.

        Returns:
            True when valid; otherwise an error message string.
        """
        # Yield to the event loop so the hook is genuinely asynchronous.
        await asyncio.sleep(0.05)

        # Checked first so the original failure message is unchanged.
        if value > max_value:
            return f"Async validation failed: {value} > {max_value}"
        # process() divides by max_value; reject zero up front.
        if max_value == 0:
            return "Async validation failed: max_value must be non-zero"
        return True

    def process(self, value, max_value):
        """Return a 1x512x512x3 tensor filled with ``value / max_value``."""
        ratio = value / max_value
        image = torch.ones([1, 512, 512, 3]) * ratio
        return (image,)
| |
|
| |
|
class TestAsyncTimeout(ComfyNodeABC):
    """Test node that simulates timeout scenarios.

    Runs a simulated operation lasting ``operation_time`` seconds under an
    ``asyncio.wait_for`` deadline of ``timeout`` seconds. It raises RuntimeError
    if the deadline is hit; otherwise it passes ``value`` through unchanged.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
                "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "timeout_execution"
    CATEGORY = "_for_testing/async"

    async def timeout_execution(self, value, timeout, operation_time):
        """Pass ``value`` through if the simulated operation beats the deadline.

        Raises:
            RuntimeError: if the operation does not finish within ``timeout``
                seconds. The underlying ``asyncio.TimeoutError`` is chained
                as ``__cause__``.
        """
        try:
            # asyncio.sleep stands in for real async work of a known duration.
            await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
        except asyncio.TimeoutError as err:
            raise RuntimeError(f"Operation timed out after {timeout} seconds") from err
        # Kept outside the try so only the awaited operation is guarded.
        return (value,)
| |
|
| |
|
class TestSyncError(ComfyNodeABC):
    """Test node that raises synchronously, for mixing with async nodes in a workflow."""

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sync_error"
    CATEGORY = "_for_testing/async"

    @classmethod
    def INPUT_TYPES(cls):
        any_input = (IO.ANY, {})
        return {"required": {"value": any_input}}

    def sync_error(self, value):
        """Ignore ``value`` and always raise.

        Raises:
            RuntimeError: on every call.
        """
        raise RuntimeError("Intentional sync execution error for testing")
| |
|
| |
|
class TestAsyncLazyCheck(ComfyNodeABC):
    """Test node whose check_lazy_status hook is a coroutine."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "input1": (IO.ANY, {"lazy": True}),
            "input2": (IO.ANY, {"lazy": True}),
            "condition": ("BOOLEAN", {"default": True}),
        }
        return {"required": required}

    async def check_lazy_status(self, condition, input1, input2):
        """Return the names of lazy inputs that still need to be evaluated.

        When ``condition`` is truthy only ``input1`` is relevant; otherwise
        only ``input2``. The relevant input is requested only if it is
        still None.
        """
        # Yield to the event loop so the hook is genuinely asynchronous.
        await asyncio.sleep(0.05)

        relevant_name = "input1" if condition else "input2"
        relevant_value = input1 if condition else input2
        return [relevant_name] if relevant_value is None else []

    def process(self, input1, input2, condition):
        """Ignore all inputs and return a 1x512x512x3 tensor of ones."""
        ones = torch.ones([1, 512, 512, 3])
        return (ones,)
| |
|
| |
|
class TestDynamicAsyncGeneration(ComfyNodeABC):
    """Test node that expands into a subgraph of async TestSleep nodes."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_async_workflow"
    CATEGORY = "_for_testing/async"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "image1": ("IMAGE",),
            "image2": ("IMAGE",),
            "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
            "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
        }
        return {"required": required}

    def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
        """Build an expansion of ``num_async_nodes`` TestSleep nodes.

        Sleep nodes alternate between ``image1`` (even index) and ``image2``
        (odd index). A single sleep node's output is returned directly;
        otherwise all outputs feed a TestVariadicAverage node as
        ``input1``, ``input2``, ...

        Returns:
            A dict with ``"result"`` (the output link tuple) and ``"expand"``
            (the finalized graph).
        """
        graph = GraphBuilder()
        sources = (image1, image2)

        sleepers = [
            graph.node("TestSleep", value=sources[idx % 2], seconds=sleep_duration)
            for idx in range(num_async_nodes)
        ]

        if len(sleepers) == 1:
            output_node = sleepers[0]
        else:
            # Index 0 is accessed explicitly, exactly as before.
            average_inputs = {"input1": sleepers[0].out(0)}
            average_inputs.update(
                (f"input{pos}", sleeper.out(0))
                for pos, sleeper in enumerate(sleepers[1:], start=2)
            )
            output_node = graph.node("TestVariadicAverage", **average_inputs)

        return {
            "result": (output_node.out(0),),
            "expand": graph.finalize(),
        }
| |
|
| |
|
class TestAsyncResourceUser(ComfyNodeABC):
    """Test node that claims a named resource for the duration of an async call."""

    # Class attribute, so it is shared by every instance: resource_id -> in-use flag.
    _active_resources: Dict[str, bool] = {}

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "use_resource"
    CATEGORY = "_for_testing/async"

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "value": (IO.ANY, {}),
            "resource_id": ("STRING", {"default": "resource_0"}),
            "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
        }
        return {"required": required}

    async def use_resource(self, value, resource_id, duration):
        """Mark ``resource_id`` busy, sleep ``duration`` seconds, release it, pass ``value`` through.

        Raises:
            RuntimeError: if ``resource_id`` is already marked as in use.
        """
        in_use = self._active_resources
        if in_use.get(resource_id):
            raise RuntimeError(f"Resource {resource_id} is already in use!")

        # There is no await between the check above and this claim, so no
        # other coroutine on the same event loop can interleave here.
        in_use[resource_id] = True
        try:
            await asyncio.sleep(duration)
        finally:
            # Released even if the sleep is cancelled or raises.
            in_use[resource_id] = False
        return (value,)
| |
|
| |
|
class TestAsyncBatchProcessing(ComfyNodeABC):
    """Test async processing of batched inputs with per-item progress reporting."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process_batch"
    CATEGORY = "_for_testing/async"

    async def process_batch(self, images, process_time_per_item, unique_id):
        """Invert every image in the batch (``1.0 - x``), one item at a time.

        Before each item it sleeps ``process_time_per_item`` seconds. After
        each item it advances a ProgressBar registered for ``unique_id``.
        An empty batch (``images.shape[0] == 0``) yields an empty tensor of
        the same shape instead of crashing.

        Returns:
            A one-tuple holding the inverted batch.
        """
        batch_size = images.shape[0]
        pbar = ProgressBar(batch_size, node_id=unique_id)

        processed = []
        for i in range(batch_size):
            await asyncio.sleep(process_time_per_item)
            # Slice i:i+1 (not i) so each piece keeps the leading batch dim
            # and torch.cat can rejoin them along dim 0.
            processed.append(1.0 - images[i:i + 1])
            pbar.update(1)

        if not processed:
            # torch.cat raises on an empty list; inverting the empty batch
            # directly gives the equivalent empty result.
            return (1.0 - images,)
        return (torch.cat(processed, dim=0),)
| |
|
| |
|
class TestAsyncConcurrentLimit(ComfyNodeABC):
    """Test concurrent execution limits for async nodes.

    At most ``_MAX_CONCURRENT`` executions of this node sleep at the same
    time within one event loop; further executions wait for a free slot.
    """

    # Maximum number of simultaneous executions per event loop.
    _MAX_CONCURRENT = 2
    # Lazily created semaphore plus the loop it belongs to. asyncio primitives
    # bind to a single event loop (at construction on 3.9, on first contention
    # on 3.10+). A semaphore created once at import time therefore breaks if
    # the host later runs this node in a different event loop.
    _semaphore = None
    _semaphore_loop = None

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
                "node_id": ("INT", {"default": 0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "limited_execution"
    CATEGORY = "_for_testing/async"

    @classmethod
    def _get_semaphore(cls):
        """Return the limiting semaphore for the currently running event loop."""
        loop = asyncio.get_running_loop()
        if cls._semaphore is None or cls._semaphore_loop is not loop:
            cls._semaphore = asyncio.Semaphore(cls._MAX_CONCURRENT)
            cls._semaphore_loop = loop
        return cls._semaphore

    async def limited_execution(self, value, duration, node_id):
        """Hold a concurrency slot for ``duration`` seconds, then pass ``value`` through.

        ``node_id`` is accepted but not used by the computation.
        """
        async with self._get_semaphore():
            await asyncio.sleep(duration)
        return (value,)
| |
|
| |
|
| | |
# Async test node registry as (node type name, node class, display name) triples.
# Both public mappings below are derived from it, in this order.
_ASYNC_TEST_NODES = (
    ("TestAsyncValidation", TestAsyncValidation, "Test Async Validation"),
    ("TestAsyncError", TestAsyncError, "Test Async Error"),
    ("TestAsyncValidationError", TestAsyncValidationError, "Test Async Validation Error"),
    ("TestAsyncTimeout", TestAsyncTimeout, "Test Async Timeout"),
    ("TestSyncError", TestSyncError, "Test Sync Error"),
    ("TestAsyncLazyCheck", TestAsyncLazyCheck, "Test Async Lazy Check"),
    ("TestDynamicAsyncGeneration", TestDynamicAsyncGeneration, "Test Dynamic Async Generation"),
    ("TestAsyncResourceUser", TestAsyncResourceUser, "Test Async Resource User"),
    ("TestAsyncBatchProcessing", TestAsyncBatchProcessing, "Test Async Batch Processing"),
    ("TestAsyncConcurrentLimit", TestAsyncConcurrentLimit, "Test Async Concurrent Limit"),
)

# Node type name -> implementing class.
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
    type_name: node_cls for type_name, node_cls, _ in _ASYNC_TEST_NODES
}

# Node type name -> display name.
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    type_name: label for type_name, _, label in _ASYNC_TEST_NODES
}

# Keep the module namespace identical to the original two-dict layout.
del _ASYNC_TEST_NODES
| |
|