| import torch |
| import numpy as np |
| from .processors import Processor_id |
|
|
|
|
class ControlNetConfigUnit:
    """Plain record describing one ControlNet to be set up.

    Every constructor argument is stored unchanged under an attribute of the
    same name; no validation or loading happens here.

    Attributes:
        processor_id: Which processor to use (typed as ``Processor_id``).
        model_path: Presumably the location of the ControlNet weights —
            confirm against the loader that consumes this record.
        scale: Numeric weight, 1.0 by default.
        skip_processor: Boolean flag, False by default; its effect is decided
            by the code that consumes this record.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
        self.processor_id, self.model_path = processor_id, model_path
        self.scale, self.skip_processor = scale, skip_processor
|
|
|
|
class ControlNetUnit:
    """Groups a processor, a ControlNet model and the scale for its output.

    ``MultiControlNetManager`` reads the ``processor``, ``model`` and
    ``scale`` attributes of each unit it is given.
    """

    def __init__(self, processor, model, scale=1.0):
        # Stored verbatim; no validation is performed here.
        self.processor, self.model, self.scale = processor, model, scale
|
|
|
|
class MultiControlNetManager:
    """Run several ControlNet units together and sum their scaled outputs.

    Each unit contributes a ``processor`` (applied to input images), a
    ``model`` (called during denoising, returning a list of tensors) and a
    ``scale`` (multiplier applied to that model's outputs before summing).
    """

    def __init__(self, controlnet_units=None):
        """Collect processors, models and scales from ``controlnet_units``.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes (e.g. ``ControlNetUnit``).
                ``None`` (the default) means no units.
        """
        # Materialize once: the units are read three times below, and a
        # one-shot iterator (e.g. a generator) would otherwise be exhausted by
        # the first comprehension, leaving ``models`` and ``scales`` empty.
        # Using ``None`` as the default also avoids a shared mutable default.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def cpu(self):
        """Move every ControlNet model to the CPU via ``model.cpu()``.

        NOTE(review): unlike ``to``, this does not touch the processors —
        confirm whether processors should be offloaded as well.
        """
        for model in self.models:
            model.cpu()

    def to(self, device):
        """Move every model and every processor to ``device``."""
        for model in self.models:
            model.to(device)
        for processor in self.processors:
            processor.to(device)

    def process_image(self, image, processor_id=None):
        """Run processors on ``image`` and stack their outputs into one tensor.

        Args:
            image: Input passed unchanged to each selected processor.
            processor_id: Index into ``self.processors``. When ``None``, every
                processor is applied; otherwise only the one at that index.

        Returns:
            A float32 tensor of shape ``(num_outputs, A2, A0, A1)``, where
            ``(A0, A1, A2)`` is the shape of each processor output after
            ``np.array`` conversion, with values divided by 255. Each output is
            presumably an ``(H, W, C)`` image with values in [0, 255], giving
            ``(N, C, H, W)`` in [0, 1] — TODO confirm against the processors.
            With no processors, ``torch.concat`` raises on the empty list.
        """
        if processor_id is None:
            outputs = [processor(image) for processor in self.processors]
        else:
            outputs = [self.processors[processor_id](image)]
        return torch.concat([
            # 3-D array -> move last axis first -> add a leading batch axis.
            torch.Tensor(np.array(output, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for output in outputs
        ], dim=0)

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        """Call each model on its conditioning and sum the scaled outputs.

        Units are paired with ``conditionings`` positionally through ``zip``,
        which stops at the shorter sequence, so unmatched entries on either
        side are silently ignored. Each model receives ``sample``,
        ``timestep``, ``encoder_hidden_states``, its conditioning, any extra
        ``**kwargs``, the tiling options and its processor's ``processor_id``.

        Returns:
            A list whose i-th element is the sum over units of
            ``scale * output[i]`` (presumably residual tensors), or ``None``
            when no unit/conditioning pair was processed.
        """
        res_stack = None
        for processor, conditioning, model, scale in zip(
            self.processors, conditionings, self.models, self.scales
        ):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                processor_id=processor.processor_id,
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                # Element-wise accumulation across units.
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
|
|
|
|
class FluxMultiControlNetManager(MultiControlNetManager):
    """Multi-ControlNet manager whose models return two output stacks.

    Compared with the base class, ``process_image`` returns the raw processor
    outputs without tensor conversion, and ``__call__`` expects each model to
    return a pair of lists that are scaled and summed independently.
    """

    def __init__(self, controlnet_units=None):
        """Register ``controlnet_units`` (any iterable; ``None`` means none)."""
        # Materialize once so a one-shot iterator survives the base class
        # reading it for processors, models and scales separately; ``None``
        # replaces the shared mutable ``[]`` default.
        units = [] if controlnet_units is None else list(controlnet_units)
        super().__init__(controlnet_units=units)

    def process_image(self, image, processor_id=None):
        """Apply all processors, or only ``self.processors[processor_id]``.

        Returns:
            A list of processor outputs, exactly as the processors returned
            them.
        """
        if processor_id is None:
            return [processor(image) for processor in self.processors]
        return [self.processors[processor_id](image)]

    def __call__(self, conditionings, **kwargs):
        """Call each model on its conditioning and sum both scaled output stacks.

        Each model is called as ``model(controlnet_conditioning=...,
        processor_id=..., **kwargs)`` and must return a pair of lists. Both
        lists are multiplied by the unit's scale and accumulated element-wise
        across units. Units are paired with ``conditionings`` positionally via
        ``zip``, which stops at the shorter sequence.

        Returns:
            ``(res_stack, single_res_stack)``, or ``(None, None)`` when no
            unit/conditioning pair was processed.
        """
        res_stack, single_res_stack = None, None
        for processor, conditioning, model, scale in zip(
            self.processors, conditionings, self.models, self.scales
        ):
            res_stack_, single_res_stack_ = model(
                controlnet_conditioning=conditioning,
                processor_id=processor.processor_id,
                **kwargs,
            )
            res_stack_ = [res * scale for res in res_stack_]
            single_res_stack_ = [res * scale for res in single_res_stack_]
            if res_stack is None:
                res_stack = res_stack_
                single_res_stack = single_res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack
|
|