| import torch |
| import numpy as np |
| from .processors import Processor_id |
|
|
|
|
class ControlNetConfigUnit:
    """Declarative description of one ControlNet, stored without validation.

    Attributes:
        processor_id: Identifier of the image processor (annotated with
            ``Processor_id`` from ``.processors``).
        model_path: Location of the ControlNet model, kept exactly as given.
        scale: Scalar weight for this ControlNet, kept exactly as given
            (defaults to 1.0).
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        # Plain attribute storage, assigned in declaration order.
        self.processor_id, self.model_path, self.scale = processor_id, model_path, scale
|
|
|
|
class ControlNetUnit:
    """Ready-to-use ControlNet: a processor, a model and a scale.

    ``MultiControlNetManager`` reads these three attributes. It calls
    ``processor`` on input images and ``model`` during denoising, and
    multiplies the model's outputs by ``scale``.
    """

    def __init__(self, processor, model, scale=1.0):
        # Store the three collaborators verbatim, in declaration order.
        self.processor, self.model, self.scale = processor, model, scale
|
|
|
|
class MultiControlNetManager:
    """Bundle several ControlNet units and combine their outputs.

    Each unit contributes a ``processor`` (used by :meth:`process_image`), a
    ``model`` and a ``scale`` (both used by :meth:`__call__`). They are kept in
    three parallel lists indexed by unit position.
    """

    def __init__(self, controlnet_units=None):
        """Collect processors, models and scales from ``controlnet_units``.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes (e.g. ``ControlNetUnit``).
                ``None`` (the default) means no units.
        """
        # Materialize once: the units are read three times below, and a one-shot
        # iterator/generator would otherwise leave ``models``/``scales`` empty.
        # ``None`` replaces the former shared mutable ``[]`` default.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, processor_id=None):
        """Run processor(s) on ``image`` and stack the results into one tensor.

        Args:
            image: Input handed unchanged to each processor.
            processor_id: Index into ``self.processors``. When ``None`` every
                processor is applied in order; otherwise only that one is.

        Returns:
            A float32 tensor of shape ``(K, C, H, W)``, where ``K`` is the number
            of processors applied. Each processor output is converted with
            ``np.array(..., dtype=np.float32)``, divided by 255 and permuted from
            ``(H, W, C)`` to ``(C, H, W)``. All outputs must therefore be 3-D and
            share one shape. Values are presumably 0-255 pixels (so the result
            lies in [0, 1]); confirm against the processors.
        """
        if processor_id is None:
            processors = self.processors
        else:
            processors = [self.processors[processor_id]]
        processed = [processor(image) for processor in processors]
        tensors = []
        for image_ in processed:
            array = np.array(image_, dtype=np.float32) / 255
            # from_numpy keeps the explicit float32 dtype (the legacy
            # torch.Tensor constructor follows the global default dtype);
            # ``array`` is a fresh temporary, so sharing its memory is safe.
            # (H, W, C) -> (1, C, H, W)
            tensors.append(torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0))
        return torch.concat(tensors, dim=0)

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        """Run every ControlNet on its conditioning and sum the scaled outputs.

        Unit ``i`` is called as ``model(sample, timestep, encoder_hidden_states,
        conditionings[i], tiled=..., tile_size=..., tile_stride=...)``. Every
        element of the returned sequence is multiplied by ``scales[i]``, and
        the per-unit lists are added elementwise.

        Returns:
            A list of summed residuals, or ``None`` when no (conditioning,
            model) pair exists (e.g. no units or empty ``conditionings``).

        Note:
            Pairing is positional from index 0 via ``zip``, so surplus
            conditionings or units are silently ignored. Residual lists of
            different lengths are likewise truncated to the shorter one.
        """
        total = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            outputs = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            scaled = [res * scale for res in outputs]
            if total is None:
                total = scaled
            else:
                total = [acc + res for acc, res in zip(total, scaled)]
        return total
|
|