| | import torchvision |
| | import torch |
| | from torch import nn |
| | import numpy as np |
| | import kornia |
| | import cv2 |
| | from core.utils import load_or_fail |
| | |
| | from .effnet import EfficientNetEncoder |
| | from .cnet_modules.pidinet import PidiNetDetector |
| | from .cnet_modules.inpainting.saliency_model import MicroResNet |
| | |
| | from .common import LayerNorm2d |
| |
|
| |
|
class CNetResBlock(nn.Module):
    """Residual block made of two LayerNorm2d -> GELU -> 3x3 conv stages.

    The channel count ``c`` is kept unchanged by every layer, so the stacked
    stages can be added back onto the block input.
    """

    def __init__(self, c):
        super().__init__()
        stages = []
        # Two identical stages; layer order (and therefore the ``blocks.N``
        # parameter names) matches a hand-written six-layer Sequential.
        for _ in range(2):
            stages.extend([
                LayerNorm2d(c),
                nn.GELU(),
                nn.Conv2d(c, c, kernel_size=3, padding=1),
            ])
        self.blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the stacked stages applied to ``x``."""
        residual = self.blocks(x)
        return x + residual
| |
|
| |
|
class ControlNet(nn.Module):
    """Backbone followed by one zero-initialised 1x1 projection head per target block.

    Args:
        c_in: Number of channels of the conditioning input.
        c_proj: Number of output channels of every projection head.
        proj_blocks: Indices (into the list returned by ``forward``) that
            receive a projection. ``None`` means "no projection heads".
        bottleneck_mode: ``'effnet'`` (default when ``None``), ``'simple'`` or
            ``'large'``; selects how the backbone is built.

    Raises:
        ValueError: If ``bottleneck_mode`` is not one of the known modes.
    """

    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None):
        super().__init__()
        if bottleneck_mode is None:
            bottleneck_mode = 'effnet'
        # The default ``proj_blocks=None`` used to crash on ``len(None)``;
        # treat it as an empty list of heads instead.
        self.proj_blocks = [] if proj_blocks is None else proj_blocks

        if bottleneck_mode == 'effnet':
            embd_channels = 1280
            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
            if c_in != 3:
                # Swap the stem conv for one with ``c_in`` input channels while
                # re-using as many of the original kernels as fit.
                in_weights = self.backbone[0][0].weight.data
                # NOTE(review): this replacement conv is built without padding;
                # torchvision's stock stem conv appears to use padding=1 --
                # confirm the resulting spatial-size change is intended.
                self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False)
                if c_in > 3:
                    # Extra channels start at zero; the first three keep the original kernels.
                    nn.init.constant_(self.backbone[0][0].weight, 0)
                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
                else:
                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
        elif bottleneck_mode == 'simple':
            embd_channels = c_in
            self.backbone = nn.Sequential(
                nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1),
            )
        elif bottleneck_mode == 'large':
            self.backbone = nn.Sequential(
                nn.Conv2d(c_in, 4096 * 4, kernel_size=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(4096 * 4, 1024, kernel_size=1),
                *[CNetResBlock(1024) for _ in range(8)],
                nn.Conv2d(1024, 1280, kernel_size=1),
            )
            embd_channels = 1280
        else:
            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')

        self.projections = nn.ModuleList()
        for _ in range(len(self.proj_blocks)):
            self.projections.append(nn.Sequential(
                nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False),
            ))
            # Zero the last conv so a freshly built head outputs zeros.
            nn.init.constant_(self.projections[-1][-1].weight, 0)

    def forward(self, x):
        """Run the backbone and return a list indexed by block id.

        Entry ``idx`` holds the matching head's output for every ``idx`` in
        ``proj_blocks`` and ``None`` elsewhere. The list is empty when there
        are no projection heads.
        """
        x = self.backbone(x)
        # ``default=-1`` keeps an empty ``proj_blocks`` from raising in ``max``.
        proj_outputs = [None] * (max(self.proj_blocks, default=-1) + 1)
        for projection, idx in zip(self.projections, self.proj_blocks):
            proj_outputs[idx] = projection(x)
        return proj_outputs
| |
|
| |
|
class ControlNetDeliverer:
    """Hands out the stored projections one per call, in order.

    Once every entry has been delivered, further calls return ``None`` (the
    internal counter keeps advancing). ``restart`` rewinds to the first entry.
    """

    def __init__(self, controlnet_projections):
        self.controlnet_projections = controlnet_projections
        self.restart()

    def restart(self):
        """Rewind to the first projection and return ``self`` for chaining."""
        self.idx = 0
        return self

    def __call__(self):
        """Return the next projection, or ``None`` when the list is exhausted."""
        projections = self.controlnet_projections
        output = projections[self.idx] if self.idx < len(projections) else None
        self.idx += 1
        return output
| |
|
| |
|
| | |
| |
|
class BaseFilter:
    """Identity filter that also serves as the common base for the other filters.

    Stores the target ``device``; subclasses override ``num_channels`` and
    ``__call__``.
    """

    def __init__(self, device):
        self.device = device

    def num_channels(self):
        """Return the channel count this filter reports (3 for the base filter)."""
        return 3

    def __call__(self, x):
        """Return ``x`` unchanged."""
        return x
| |
|
| |
|
class CannyFilter(BaseFilter):
    """Single-channel Canny edge map computed with OpenCV.

    Args:
        device: Stored by the base class (the edge detection itself runs on CPU).
        resize: Side length the input is resized to before edge detection, or
            ``None`` to work at the input resolution. The result is resized
            back to the input size.
    """

    def __init__(self, device, resize=224):
        super().__init__(device)
        self.resize = resize

    def num_channels(self):
        """Return 1: the edge map has a single channel."""
        return 1

    def __call__(self, x):
        """Return a CPU tensor of edge maps, one per batch element.

        NOTE(review): assumes ``x`` is a (batch, channels, H, W) float tensor
        scaled to [0, 1] -- confirm against callers.
        """
        orig_size = x.shape[-2:]
        if self.resize is not None:
            x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
        # Clamp before the uint8 cast: out-of-range floats would otherwise wrap
        # around (or be undefined) when converted to 8-bit pixels.
        frames = x.detach().mul(255).clamp(0, 255).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        edges = [cv2.Canny(np.ascontiguousarray(frame), 100, 200) for frame in frames]
        # cv2.Canny yields 0/255 maps; rescale to 0/1 and add the channel axis.
        edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0)
        if self.resize is not None:
            edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear')
        return edges
| |
|
| |
|
class QRFilter(BaseFilter):
    """Three-level brightness mask derived from the HSV value channel.

    Pixels in the darkest third of each image's value range map to 0, those in
    the brightest third to 1, and everything in between to 0.5.

    Args:
        device: Device the computation runs on.
        resize: Working side length, or ``None`` to keep the input resolution.
        blobify: When true, blur and then dilate/erode the value channel with a
            randomly chosen kernel before thresholding.
        dilation_kernels: Candidate dilation/erosion kernel sizes
            (defaults to ``[3, 5, 7]``).
        blur_kernels: Candidate Gaussian blur kernel sizes (defaults to ``[15]``).
    """

    def __init__(self, device, resize=224, blobify=True, dilation_kernels=None, blur_kernels=None):
        super().__init__(device)
        self.resize = resize
        self.blobify = blobify
        # ``None`` sentinels avoid sharing one mutable default list between instances.
        self.dilation_kernels = [3, 5, 7] if dilation_kernels is None else dilation_kernels
        self.blur_kernels = [15] if blur_kernels is None else blur_kernels

    def num_channels(self):
        """Return 1: the mask has a single channel."""
        return 1

    def __call__(self, x):
        """Return the brightness mask for ``x`` as a CPU tensor at the input size."""
        x = x.to(self.device)
        orig_size = x.shape[-2:]
        if self.resize is not None:
            x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')

        # Keep only the last HSV channel (value / brightness).
        x = kornia.color.rgb_to_hsv(x)[:, -1:]

        if self.blobify:
            d_kernel = int(np.random.choice(self.dilation_kernels))
            d_blur = int(np.random.choice(self.blur_kernels))
            if d_blur > 0:
                x = torchvision.transforms.GaussianBlur(d_blur)(x)
            if d_kernel > 0:
                # Roughly disc-shaped structuring element (corners are excluded).
                axis = torch.linspace(-0.5, 0.5, d_kernel).pow(2)
                blob_mask = ((axis[None] + axis[:, None]) < 0.3).float().to(self.device)
                x = kornia.morphology.dilation(x, blob_mask)
                x = kornia.morphology.erosion(x, blob_mask)

        # Per-sample extrema. ``amax``/``amin`` return plain tensors, so the
        # former trailing ``[0]`` selected batch element 0 and applied its
        # min/max to every sample in the batch.
        vmax = x.amax(dim=[2, 3], keepdim=True)
        vmin = x.amin(dim=[2, 3], keepdim=True)
        th = (vmax - vmin) * 0.33
        high_brightness = (x > (vmax - th)).float()
        low_brightness = (x < (vmin + th)).float()
        mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5

        if self.resize is not None:
            mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear')
        return mask.cpu()
| |
|
| |
|
class PidiFilter(BaseFilter):
    """Edge map produced by ``PidiNetDetector``, optionally dilated and binarised.

    Args:
        device: Device the detector and the post-processing run on.
        resize: Working side length, or ``None`` to keep the input resolution.
        dilation_kernels: Candidate dilation kernel sizes, one of which is
            drawn per call; 0 disables dilation (defaults to ``[0, 3, 5, 7, 9]``).
        binarize: When true, threshold the map at a value drawn uniformly
            from [0.05, 0.7).
    """

    def __init__(self, device, resize=224, dilation_kernels=None, binarize=True):
        super().__init__(device)
        self.resize = resize
        self.model = PidiNetDetector(device)
        # ``None`` sentinel avoids sharing one mutable default list between instances.
        self.dilation_kernels = [0, 3, 5, 7, 9] if dilation_kernels is None else dilation_kernels
        self.binarize = binarize

    def num_channels(self):
        """Return 1: the filter reports a single-channel output."""
        return 1

    def __call__(self, x):
        """Return the processed detector output as a CPU tensor at the input size."""
        x = x.to(self.device)
        orig_size = x.shape[-2:]
        if self.resize is not None:
            x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')

        x = self.model(x)
        d_kernel = int(np.random.choice(self.dilation_kernels))
        if d_kernel > 0:
            # Roughly disc-shaped structuring element (corners are excluded).
            axis = torch.linspace(-0.5, 0.5, d_kernel).pow(2)
            blob_mask = ((axis[None] + axis[:, None]) < 0.3).float().to(self.device)
            x = kornia.morphology.dilation(x, blob_mask)
        if self.binarize:
            th = np.random.uniform(0.05, 0.7)
            x = (x > th).float()

        if self.resize is not None:
            x = nn.functional.interpolate(x, size=orig_size, mode='bilinear')
        return x.cpu()
| |
|
| |
|
class SRFilter(BaseFilter):
    """Degrades an image by nearest-neighbour down- then up-sampling.

    ``scale_factor`` is the down-sampling factor; the result is scaled back up
    by its reciprocal.
    """

    def __init__(self, device, scale_factor=1 / 4):
        super().__init__(device)
        self.scale_factor = scale_factor

    def num_channels(self):
        """Return 3, the channel count this filter reports."""
        return 3

    def __call__(self, x):
        """Return ``x`` after the down/up nearest-neighbour round trip."""
        resample = torch.nn.functional.interpolate
        shrunk = resample(x.clone(), scale_factor=self.scale_factor, mode="nearest")
        restore_factor = 1 / self.scale_factor
        return resample(shrunk, scale_factor=restore_factor, mode="nearest")
| |
|
| |
|
class SREffnetFilter(BaseFilter):
    """Down-samples an image, encodes it with ``EfficientNetEncoder`` and up-samples both.

    The encoder weights are loaded from ``models/effnet_encoder.safetensors``
    and frozen. ``scale_factor`` is the nearest-neighbour down-sampling factor
    applied before encoding.
    """

    def __init__(self, device, scale_factor=1/2):
        super().__init__(device)
        self.scale_factor = scale_factor

        # Channel-wise normalisation applied to the image before encoding.
        normalize = torchvision.transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        )
        self.effnet_preprocess = torchvision.transforms.Compose([normalize])

        encoder = EfficientNetEncoder().to(self.device)
        encoder.load_state_dict(load_or_fail("models/effnet_encoder.safetensors"))
        encoder.eval().requires_grad_(False)
        self.effnet = encoder

    def num_channels(self):
        """Return 16, the channel count this filter reports."""
        return 16

    def __call__(self, x):
        """Return ``(effnet_embedding, upscaled_image)``.

        The embedding is computed on ``self.device`` without gradients and
        moved to CPU; both outputs are up-sampled by ``1 / scale_factor`` with
        nearest-neighbour interpolation.
        """
        resample = torch.nn.functional.interpolate
        restore_factor = 1 / self.scale_factor

        x = resample(x.clone(), scale_factor=self.scale_factor, mode="nearest")
        with torch.no_grad():
            encoder_input = self.effnet_preprocess(x.to(self.device))
            effnet_embedding = self.effnet(encoder_input).cpu()
        effnet_embedding = resample(effnet_embedding, scale_factor=restore_factor, mode="nearest")
        upscaled_image = resample(x, scale_factor=restore_factor, mode="nearest")
        return effnet_embedding, upscaled_image
| |
|
| |
|
class InpaintFilter(BaseFilter):
    """Builds an inpainting condition: the image with the masked area set to 1, plus the mask.

    Args:
        device: Device the saliency model and the masking run on.
        thresold: ``[low, high]`` range a saliency threshold is drawn from
            when none is given (defaults to ``[0.04, 0.4]``). The spelling is
            part of the public interface and is kept as is.
        p_outpaint: Probability of inverting the saliency mask when
            ``outpaint`` is not specified.
    """

    def __init__(self, device, thresold=None, p_outpaint=0.4):
        super().__init__(device)
        self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device)
        self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt"))
        # ``None`` sentinel avoids sharing one mutable default list between instances.
        self.thresold = [0.04, 0.4] if thresold is None else thresold
        self.p_outpaint = p_outpaint

    def num_channels(self):
        """Return 4, the channel count this filter reports."""
        return 4

    def __call__(self, x, mask=None, threshold=None, outpaint=None):
        """Return the masked image concatenated with its mask, as a CPU tensor.

        Args:
            x: Image batch; moved to ``self.device``.
            mask: Optional mask. A boolean mask is used as is; any other dtype
                is binarised with ``> 0.5``. When omitted, the mask is derived
                from the saliency model.
            threshold: Saliency threshold; drawn at random when ``None``.
            outpaint: Whether to invert the saliency mask; decided at random
                with probability ``p_outpaint`` when ``None``.
        """
        x = x.to(self.device)
        # Drawn unconditionally so NumPy's random stream does not depend on
        # whether a mask was supplied.
        if threshold is None:
            threshold = np.random.uniform(self.thresold[0], self.thresold[1])

        if mask is None:
            # The resized copy is only needed by the saliency model.
            resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True)
            saliency_map = self.saliency_model(resized_x) > threshold
            if outpaint is None:
                outpaint = np.random.rand() < self.p_outpaint
            if outpaint:
                saliency_map = ~saliency_map
            interpolated_saliency_map = torch.nn.functional.interpolate(
                saliency_map.float(), size=x.shape[2:], mode="nearest"
            )
            # Heavy blur followed by re-thresholding smooths the mask outline.
            saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5
            inpainted_images = torch.where(saliency_map, torch.ones_like(x), x)
            # The map already has the image's spatial size, so no further resize is needed.
            mask = saliency_map.float()
        else:
            mask = mask.to(self.device)
            if mask.dtype != torch.bool:
                # ``torch.where`` needs a boolean condition.
                mask = mask > 0.5
            inpainted_images = torch.where(mask, torch.ones_like(x), x)
            # Concatenate with an explicit dtype instead of relying on promotion.
            mask = mask.to(x.dtype)

        c_inpaint = torch.cat([inpainted_images, mask], dim=1)
        return c_inpaint.cpu()
| |
|
| |
|
| | |
| | ''' |
| | class IdentityFilter(BaseFilter): |
| | def __init__(self, device, max_faces=4, p_drop=0.05, p_full=0.3): |
| | detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx' |
| | recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx' |
| | |
| | super().__init__(device) |
| | self.max_faces = max_faces |
| | self.p_drop = p_drop |
| | self.p_full = p_full |
| | |
| | self.detector = FaceDetector(detector_path, device=device) |
| | self.recognizer = ArcFaceRecognizer(recognizer_path, device=device) |
| | |
| | self.id_colors = torch.tensor([ |
| | [1.0, 0.0, 0.0], # RED |
| | [0.0, 1.0, 0.0], # GREEN |
| | [0.0, 0.0, 1.0], # BLUE |
| | [1.0, 0.0, 1.0], # PURPLE |
| | [0.0, 1.0, 1.0], # CYAN |
| | [1.0, 1.0, 0.0], # YELLOW |
| | [0.5, 0.0, 0.0], # DARK RED |
| | [0.0, 0.5, 0.0], # DARK GREEN |
| | [0.0, 0.0, 0.5], # DARK BLUE |
| | [0.5, 0.0, 0.5], # DARK PURPLE |
| | [0.0, 0.5, 0.5], # DARK CYAN |
| | [0.5, 0.5, 0.0], # DARK YELLOW |
| | ]) |
| | |
| | def num_channels(self): |
| | return 512 |
| | |
| | def get_faces(self, image): |
| | npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy() |
| | bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR) |
| | bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces) |
| | N = len(bboxes) |
| | ids = torch.zeros((N, 512), dtype=torch.float32) |
| | for i in range(N): |
| | face = Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4]) |
| | ids[i, :] = self.recognizer.get(bgr, face) |
| | tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int) |
| | |
| | ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True) |
| | return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512) |
| | |
| | def __call__(self, x): |
| | visual_aid = x.clone().cpu() |
| | face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32) |
| | |
| | for i in range(x.size(0)): |
| | bounding_boxes, ids = self.get_faces(x[i]) |
| | for j in range(bounding_boxes.size(0)): |
| | if np.random.rand() > self.p_drop: |
| | sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist() |
| | ex, ey = max(ex, sx + 1), max(ey, sy + 1) |
| | if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full: |
| | sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32 |
| | face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None] |
| | visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :, |
| | None, None] |
| | visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5 |
| | |
| | return face_mtx.to(x.device), visual_aid.to(x.device) |
| | ''' |
| |
|