| from typing import List |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from controlnet_aux import OpenposeDetector |
| from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline, |
| UniPCMultistepScheduler) |
| from PIL import Image |
| from util.cache import clear_cuda_and_gc |
| from util.commons import disable_safety_checker, download_image |
|
|
|
|
class ControlNet:
    """Stable Diffusion pipeline wrapper with hot-swappable ControlNet
    conditioning models (canny edges or openpose skeletons).

    Typical flow: ``load(model_dir)`` once, then ``load_canny()`` /
    ``load_pose()`` to switch conditioning, then ``process_canny`` /
    ``process_pose`` to generate images.
    """

    # Name of the currently attached conditioning model ("" = none loaded).
    __current_task_name = ""

    def load(self, model_dir: str):
        """Build the base SD pipeline from ``model_dir`` with the canny
        ControlNet attached by default.

        Args:
            model_dir: Local path or hub repo id of the Stable Diffusion model.
        """
        self.load_canny()

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        # Offload idle submodules to CPU and use xformers attention to
        # keep VRAM usage down.
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)
        self.pipe = pipe

    def __load_controlnet(self, task_name: str, repo_id: str):
        """Load ``repo_id`` onto CUDA and make it the active conditioning
        model. No-op when ``task_name`` is already active.

        Args:
            task_name: Logical name recorded as the current task.
            repo_id: Hub repo id of the ControlNetModel weights.
        """
        if self.__current_task_name == task_name:
            return
        model = ControlNetModel.from_pretrained(
            repo_id, torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = task_name
        self.controlnet = model
        # Keep an already-built pipeline in sync with the new model.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = model
        clear_cuda_and_gc()

    def load_canny(self):
        """Attach the canny-edge ControlNet (idempotent)."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Attach the openpose ControlNet (idempotent)."""
        self.__load_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def cleanup(self):
        """Detach the current conditioning model and release GPU memory."""
        # Guard: cleanup() may be called before load() ever built a pipeline.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the canny edges of ``imageUrl``.

        Args:
            prompt: One prompt per requested image.
            imageUrl: URL of the conditioning image to download.
            seed: Global torch seed for reproducibility.
            steps: Number of denoising steps.
            negative_prompt: One negative prompt per requested image.
            height: Output image height in pixels.
            width: Output image width in pixels.

        Returns:
            The list of generated PIL images.

        Raises:
            RuntimeError: If the canny ControlNet is not the active model.
        """
        if self.__current_task_name != "canny":
            raise RuntimeError("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl)
        init_image = self.__canny_detect_edge(init_image)

        return self.pipe(
            prompt=prompt,
            image=init_image,
            guidance_scale=9,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        ).images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on pre-computed pose maps.

        Args:
            prompt: One prompt per requested image.
            image: Pose conditioning images (e.g. from ``detect_pose``).
            seed: Global torch seed for reproducibility.
            steps: Number of denoising steps.
            negative_prompt: One negative prompt per requested image.
            height: Output image height in pixels.
            width: Output image width in pixels.

        Returns:
            The list of generated PIL images.

        Raises:
            RuntimeError: If the pose ControlNet is not the active model.
        """
        if self.__current_task_name != "pose":
            raise RuntimeError("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        return self.pipe(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download ``imageUrl`` and return its openpose skeleton map."""
        # Cache the detector on first use: from_pretrained re-fetches
        # weights and rebuilds the model, which is expensive per call.
        if not hasattr(self, "_pose_detector"):
            self._pose_detector = OpenposeDetector.from_pretrained(
                "lllyasviel/ControlNet"
            )
        image = download_image(imageUrl)
        return self._pose_detector(image)

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return a 3-channel canny edge map of ``image``.

        Args:
            image: Input image.
            low_threshold: Lower hysteresis threshold for cv2.Canny.
            high_threshold: Upper hysteresis threshold for cv2.Canny.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # cv2.Canny yields a single-channel map; replicate it to RGB so
        # the ControlNet receives a 3-channel conditioning image.
        edges = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(edges)
|
|