| import subprocess |
| import importlib |
| import sys |
| import logging |
| from transformers import BaseImageProcessorFast |
| import torch |
| import numpy as np |
| from rembg import remove, new_session |
| from functools import partial |
| from torchvision.utils import save_image |
| from PIL import Image |
| from kiui.op import recenter |
| import kiui |
|
|
|
|
| |
| |
|
|
|
|
class LRMImageProcessor(BaseImageProcessorFast):
    """Prepare a single image and a default source camera for the model.

    Calling an instance runs background removal, recentering, white-background
    compositing and a bicubic resize on the input image, and pairs the result
    with a fixed 16-value camera vector (12 extrinsic values followed by the
    normalized ``fx, fy, cx, cy``).
    """

    def __init__(self, source_size=512, *args, **kwargs):
        """Store the output resolution; the rembg session is created lazily.

        Args:
            source_size: Side length, in pixels, of the square image returned
                by :meth:`preprocess_image`.
            *args: Forwarded unchanged to the base class constructor.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self.source_size = source_size
        # Both stay None until the first preprocess_image() call, so that
        # constructing the processor does not create a rembg session.
        self.session = None
        self.rembg_remove = None

    def _initialize_session(self):
        """Create the rembg session and the bound ``remove`` helper once."""
        if self.session is None:
            self.session = new_session("isnet-general-use")
            self.rembg_remove = partial(remove, session=self.session)

    def preprocess_image(self, image):
        """Convert an input image into a normalized, resized tensor.

        Args:
            image: Anything ``np.array`` can turn into an image array
                (presumably a PIL image or an H x W x C array - confirm
                against callers).

        Returns:
            A float tensor with a leading batch dimension of 1, spatial size
            ``(source_size, source_size)`` and values clamped to ``[0, 1]``.
            When the recentered image has 4 channels it is composited onto a
            white background, leaving 3 channels.
        """
        self._initialize_session()

        pixels = np.array(image)
        pixels = self.rembg_remove(pixels)
        # NOTE(review): the mask comes from a second rembg pass over the
        # already cut-out image rather than from its alpha channel. Kept as-is
        # because switching could change the mask; confirm whether the extra
        # pass is intentional.
        mask = self.rembg_remove(pixels, only_mask=True)
        pixels = recenter(pixels, mask, border_ratio=0.20)

        # HWC array -> 1 x C x H x W tensor, scaled by 1/255.
        image = torch.tensor(pixels).permute(2, 0, 1).unsqueeze(0) / 255.0

        if image.shape[1] == 4:
            # Alpha-blend the colour channels over white (background value 1).
            alpha = image[:, 3:, ...]
            image = image[:, :3, ...] * alpha + (1 - alpha)

        image = torch.nn.functional.interpolate(
            image,
            size=(self.source_size, self.source_size),
            mode='bicubic',
            align_corners=True,
        )
        # Bicubic interpolation can overshoot, so bring values back to [0, 1].
        return torch.clamp(image, 0, 1)

    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
        """Divide focal lengths and principal point by the image size.

        Args:
            intrinsics: Tensor indexed as ``[:, row, col]`` whose rows are
                ``[fx, fy]``, ``[cx, cy]`` and ``[width, height]``.

        Returns:
            Tuple ``(fx, fy, cx, cy)``; x-components are divided by width and
            y-components by height, one value per batch entry.
        """
        fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
        cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
        width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
        return fx / width, fy / height, cx / width, cy / height

    def build_camera_principle(self, RT: torch.Tensor, intrinsics: torch.Tensor):
        """Flatten extrinsics and append the normalized intrinsics.

        Args:
            RT: Extrinsics tensor; reshaped to 12 values per batch entry.
            intrinsics: Tensor in the layout expected by
                :meth:`get_normalized_camera_intrinsics`.

        Returns:
            Tensor with 16 values per batch entry: the 12 flattened ``RT``
            values followed by ``fx, fy, cx, cy``.
        """
        fx, fy, cx, cy = self.get_normalized_camera_intrinsics(intrinsics)
        return torch.cat([
            RT.reshape(-1, 12),
            fx.unsqueeze(-1),
            fy.unsqueeze(-1),
            cx.unsqueeze(-1),
            cy.unsqueeze(-1),
        ], dim=-1)

    def _default_intrinsics(self):
        """Return the fixed 3 x 2 intrinsics: focal 384, centre 256, size 512."""
        fx = fy = 384
        cx = cy = 256
        w = h = 512
        return torch.tensor([
            [fx, fy],
            [cx, cy],
            [w, h],
        ], dtype=torch.float32)

    def _default_source_camera(self, batch_size: int = 1):
        """Build the default source-camera vector, repeated per batch entry.

        Args:
            batch_size: Number of identical rows to return.

        Returns:
            Tensor of shape ``(batch_size, 16)``.
        """
        # NOTE(review): the last column (translation) is hard-coded here;
        # confirm the intended camera distance against the model's setup.
        canonical_camera_extrinsics = torch.tensor([[
            [0, 0, 1, 1],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
        ]], dtype=torch.float32)
        canonical_camera_intrinsics = self._default_intrinsics().unsqueeze(0)
        source_camera = self.build_camera_principle(
            canonical_camera_extrinsics,
            canonical_camera_intrinsics,
        )
        return source_camera.repeat(batch_size, 1)

    def __call__(self, image, *args, **kwargs):
        """Preprocess ``image`` and return it with the default source camera.

        Extra positional and keyword arguments are accepted and ignored.

        Returns:
            Tuple ``(processed_image, source_camera)`` where ``source_camera``
            has a single row.
        """
        processed_image = self.preprocess_image(image)
        source_camera = self._default_source_camera(batch_size=1)
        return processed_image, source_camera