| import pytorch_lightning as pl |
| import torch.optim as optim |
| from lib.model.v2a import V2A |
| from lib.model.body_model_params import BodyModelParams |
| from lib.model.deformer import SMPLDeformer |
| import cv2 |
| import torch |
| from lib.model.loss import Loss |
| import hydra |
| import os |
| import numpy as np |
| from lib.utils.meshing import generate_mesh |
| from kaolin.ops.mesh import index_vertices_by_faces |
| import trimesh |
| from lib.model.deformer import skinning |
| from lib.utils import utils |
class V2AModel(pl.LightningModule):
    """Lightning module training a V2A network jointly with per-frame SMPL parameters.

    Besides the V2A network (``self.model``), the module owns a
    ``BodyModelParams`` container (``self.body_model_params``). Its shape
    (``betas``), per-frame pose (``global_orient``/``body_pose``) and
    translation (``transl``) are initialised from ``mean_shape.npy``,
    ``poses.npy`` and ``normalize_trans.npy`` in the dataset directory. They are
    then optimised alongside the network (see ``configure_optimizers``).
    """

    # The cached canonical mesh on ``self.model`` is re-extracted every this
    # many training epochs (epoch 0 excluded).
    CANONICAL_MESH_REFRESH_EPOCHS = 20

    def __init__(self, opt) -> None:
        """Build the V2A network, the per-frame body-model parameters and the loss.

        Args:
            opt: Hydra config. Reads ``opt.dataset.metainfo`` (``start_frame``,
                ``end_frame``, ``data_dir``, ``gender``) and ``opt.model``
                (network options and ``opt.model.loss``).
        """
        super().__init__()

        self.opt = opt
        self.opt_0 = opt
        metainfo = opt.dataset.metainfo
        num_training_frames = metainfo.end_frame - metainfo.start_frame
        self.betas_path = os.path.join(
            hydra.utils.to_absolute_path('..'), 'data', metainfo.data_dir, 'mean_shape.npy')
        self.gender = metainfo.gender
        self.model = V2A(opt.model, self.betas_path, self.gender, num_training_frames)
        self.start_frame = metainfo.start_frame
        self.end_frame = metainfo.end_frame
        self.training_modules = ["model"]

        # Per-frame SMPL parameters: loaded frozen, then made trainable so they
        # are refined together with the network.
        self.training_indices = list(range(self.start_frame, self.end_frame))
        self.body_model_params = BodyModelParams(num_training_frames, model_type='smpl')
        self.load_body_model_params()
        for param_name in self.body_model_params.param_names:
            self.body_model_params.set_requires_grad(param_name, requires_grad=True)
        self.training_modules += ['body_model_params']

        self.loss = Loss(opt.model.loss)

    def load_body_model_params(self):
        """Initialise ``self.body_model_params`` from the preprocessed SMPL files.

        Reads from ``../data/<data_dir>`` (resolved with
        ``hydra.utils.to_absolute_path``):

        * ``mean_shape.npy`` -> ``betas`` (given a leading batch axis),
        * ``poses.npy`` rows in ``self.training_indices`` -> first 3 columns as
          ``global_orient``, remaining columns as ``body_pose``,
        * ``normalize_trans.npy`` rows in ``self.training_indices`` -> ``transl``.

        All parameters are registered with ``requires_grad=False``.
        """
        body_model_params = {param_name: [] for param_name in self.body_model_params.param_names}
        data_root = hydra.utils.to_absolute_path(
            os.path.join('../data', self.opt.dataset.metainfo.data_dir))

        betas = np.load(os.path.join(data_root, 'mean_shape.npy'))[None]
        # Load poses.npy once and slice it, rather than reading the file twice.
        poses = np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices]
        transl = np.load(os.path.join(data_root, 'normalize_trans.npy'))[self.training_indices]

        body_model_params['betas'] = torch.tensor(betas, dtype=torch.float32)
        body_model_params['global_orient'] = torch.tensor(poses[:, :3], dtype=torch.float32)
        body_model_params['body_pose'] = torch.tensor(poses[:, 3:], dtype=torch.float32)
        body_model_params['transl'] = torch.tensor(transl, dtype=torch.float32)

        for param_name, value in body_model_params.items():
            self.body_model_params.init_parameters(param_name, value, requires_grad=False)

    def configure_optimizers(self):
        """Create the Adam optimiser and the MultiStepLR scheduler.

        The network parameters use ``opt.model.learning_rate``; the body-model
        parameters use a tenth of it. The scheduler multiplies the learning
        rates by ``opt.model.sched_factor`` at each of
        ``opt.model.sched_milestones``.

        Returns:
            ``([optimizer], [scheduler])`` as expected by Lightning.
        """
        base_lr = self.opt.model.learning_rate
        param_groups = [
            {'params': self.model.parameters(), 'lr': base_lr},
            {'params': self.body_model_params.parameters(), 'lr': base_lr * 0.1},
        ]
        self.optimizer = optim.Adam(param_groups, lr=base_lr, eps=1e-8)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.opt.model.sched_milestones,
            gamma=self.opt.model.sched_factor,
        )
        return [self.optimizer], [self.scheduler]

    def _attach_smpl_params(self, inputs, frame_idx):
        """Write the current SMPL pose/shape/translation for ``frame_idx`` into ``inputs``.

        ``smpl_pose`` is ``global_orient`` and ``body_pose`` concatenated along
        dim 1. ``smpl_shape`` and ``smpl_trans`` are the raw ``betas`` and
        ``transl`` outputs of ``self.body_model_params``.
        """
        body_model_params = self.body_model_params(frame_idx)
        inputs['smpl_pose'] = torch.cat(
            (body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
        inputs['smpl_shape'] = body_model_params['betas']
        inputs['smpl_trans'] = body_model_params['transl']

    def training_step(self, batch):
        """Run the model on one batch, log every loss term and return the total loss.

        Args:
            batch: ``(inputs, targets)``; ``inputs["idx"]`` selects the
                per-frame body-model parameters.

        Returns:
            ``loss_output["loss"]``, the tensor Lightning backpropagates.
        """
        inputs, targets = batch
        self._attach_smpl_params(inputs, inputs["idx"])
        inputs['current_epoch'] = self.current_epoch

        model_outputs = self.model(inputs)
        loss_output = self.loss(model_outputs, targets)

        # The total loss and each component are logged identically.
        for name, value in loss_output.items():
            self.log(name, value.item(), prog_bar=True, on_step=True)
        return loss_output["loss"]

    def training_epoch_end(self, outputs) -> None:
        """Periodically re-extract the canonical mesh cached on ``self.model``.

        Every ``CANONICAL_MESH_REFRESH_EPOCHS`` epochs (not at epoch 0), the
        canonical SDF is meshed with an all-zero pose condition. The resulting
        vertices/faces are stored on the model as ``mesh_v_cano``,
        ``mesh_f_cano`` and ``mesh_face_vertices``.
        """
        epoch = self.current_epoch
        if epoch != 0 and epoch % self.CANONICAL_MESH_REFRESH_EPOCHS == 0:
            # Zero condition with the width of ``smpl_pose[:, 3:]`` (69 values).
            # It is allocated on this module's device instead of hard-coding CUDA.
            cond = {'smpl': torch.zeros(1, 69, dtype=torch.float32, device=self.device)}
            mesh_canonical = generate_mesh(
                lambda x: self.query_oc(x, cond),
                self.model.smpl_server.verts_c[0],
                point_batch=10000,
                res_up=2,
            )
            device = self.model.smpl_v_cano.device
            self.model.mesh_v_cano = torch.tensor(
                mesh_canonical.vertices[None], device=device).float()
            self.model.mesh_f_cano = torch.tensor(
                mesh_canonical.faces.astype(np.int64), device=device)
            self.model.mesh_face_vertices = index_vertices_by_faces(
                self.model.mesh_v_cano, self.model.mesh_f_cano)

        return super().training_epoch_end(outputs)

    def query_oc(self, x, cond):
        """Evaluate the canonical SDF at points ``x``.

        Args:
            x: points, flattened to ``(-1, 3)``.
            cond: conditioning dict passed to ``implicit_network``.

        Returns:
            ``{'sdf': tensor of shape (N, 1)}``, channel 0 of the network output.
        """
        x = x.reshape(-1, 3)
        mnfld_pred = self.model.implicit_network(x, cond)[:, :, 0].reshape(-1, 1)
        return {'sdf': mnfld_pred}

    def query_wc(self, x):
        """Return the deformer's skinning weights for points ``x`` (flattened to ``(-1, 3)``)."""
        x = x.reshape(-1, 3)
        return self.model.deformer.query_weights(x)

    def query_od(self, x, cond, smpl_tfs, smpl_verts):
        """Evaluate the SDF at deformed-space points ``x``.

        Points are mapped back to canonical space with the deformer's inverse
        transform, then fed to ``implicit_network``.

        Returns:
            ``{'sdf': output[:, 0:1]}`` of the first network output.
        """
        x = x.reshape(-1, 3)
        x_c, _ = self.model.deformer.forward(
            x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
        output = self.model.implicit_network(x_c, cond)[0]
        sdf = output[:, 0:1]
        return {'sdf': sdf}

    def get_deformed_mesh_fast_mode(self, verts, smpl_tfs):
        """Skin canonical vertices with the deformer's weights and ``smpl_tfs``.

        Args:
            verts: canonical vertex positions (array-like, e.g. trimesh vertices).
            smpl_tfs: transforms returned by ``smpl_server``.

        Returns:
            NumPy array with the deformed vertices of the first batch element.
        """
        # Place the vertices on the same device as the transforms, since
        # ``skinning`` combines them, instead of hard-coding ``.cuda()``.
        verts = torch.as_tensor(np.asarray(verts), dtype=torch.float32, device=smpl_tfs.device)
        weights = self.model.deformer.query_weights(verts)
        verts_deformed = skinning(verts.unsqueeze(0), weights, smpl_tfs)
        return verts_deformed.detach().cpu().numpy()[0]

    def validation_step(self, batch, *args, **kwargs):
        """Extract the canonical mesh and render one validation frame.

        The image is rendered in pixel chunks (``utils.split_input``) and
        merged back with ``utils.merge_output``.

        Returns:
            Dict with ``canonical_mesh`` (``trimesh.Trimesh``), the merged
            ``rgb_values``/``normal_values``/``fg_rgb_values`` predictions and
            every entry of ``targets``.
        """
        output = {}
        inputs, targets = batch
        inputs['current_epoch'] = self.current_epoch
        self.model.eval()

        self._attach_smpl_params(inputs, inputs['image_id'])

        # Body pose (without global orientation) scaled by 1/pi conditions the SDF.
        cond = {'smpl': inputs["smpl_pose"][:, 3:] / np.pi}
        mesh_canonical = generate_mesh(
            lambda x: self.query_oc(x, cond),
            self.model.smpl_server.verts_c[0],
            point_batch=10000,
            res_up=3,
        )
        output['canonical_mesh'] = trimesh.Trimesh(mesh_canonical.vertices, mesh_canonical.faces)

        total_pixels = targets["total_pixels"][0]
        n_pixels = min(targets['pixel_per_batch'], targets["img_size"][0] * targets["img_size"][1])
        chunk_results = []
        for chunk_inputs in utils.split_input(inputs, total_pixels, n_pixels=n_pixels):
            chunk_outputs = self.model(chunk_inputs)
            # Keep only the detached image outputs. The other model outputs are discarded.
            chunk_results.append({
                key: chunk_outputs[key].detach()
                for key in ('rgb_values', 'normal_values', 'fg_rgb_values')
            })
        batch_size = targets['rgb'].shape[0]
        model_outputs = utils.merge_output(chunk_results, total_pixels, batch_size)

        output.update({
            "rgb_values": model_outputs["rgb_values"].detach().clone(),
            "normal_values": model_outputs["normal_values"].detach().clone(),
            "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
            **targets,
        })
        return output

    def validation_step_end(self, batch_parts):
        """Return the step output unchanged."""
        return batch_parts

    @staticmethod
    def _stack_pixels(chunks, key, img_size, dim=0):
        """Concatenate ``chunk[key]`` over ``chunks`` along ``dim`` and reshape to ``(*img_size, -1)``."""
        return torch.cat([chunk[key] for chunk in chunks], dim=dim).reshape(*img_size, -1)

    @staticmethod
    def _to_uint8(image):
        """Scale a float image by 255 and convert it to ``uint8``.

        Values are clipped to ``[0, 255]`` first. NumPy's float->uint8 cast is
        not well defined for out-of-range values, which otherwise wrap into
        speckle artefacts.
        """
        return np.clip(image * 255, 0, 255).astype(np.uint8)

    def validation_epoch_end(self, outputs) -> None:
        """Assemble the validation renders and write them to disk.

        Writes ``rendering/<epoch>.ply`` (canonical mesh) and
        ``rendering/<epoch>.png``, which contains ground-truth RGB above the
        prediction. It also writes ``normal/<epoch>.png``, with the ground-truth
        normals above the prediction when available, and
        ``fg_rendering/<epoch>.png``.
        """
        img_size = outputs[0]["img_size"]

        rgb_pred = self._stack_pixels(outputs, "rgb_values", img_size)
        fg_rgb_pred = self._stack_pixels(outputs, "fg_rgb_values", img_size)
        # Normals are mapped from [-1, 1] to [0, 1] for visualisation.
        normal_pred = (self._stack_pixels(outputs, "normal_values", img_size) + 1) / 2

        rgb_gt = self._stack_pixels(outputs, "rgb", img_size, dim=1)
        if 'normal' in outputs[0]:
            normal_gt = (self._stack_pixels(outputs, "normal", img_size, dim=1) + 1) / 2
            normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
        else:
            normal = normal_pred.cpu().numpy()

        rgb = self._to_uint8(torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy())
        fg_rgb = self._to_uint8(fg_rgb_pred.cpu().numpy())
        normal = self._to_uint8(normal)

        for directory in ("rendering", "normal", "fg_rendering"):
            os.makedirs(directory, exist_ok=True)

        epoch = self.current_epoch
        outputs[0]['canonical_mesh'].export(f"rendering/{epoch}.ply")

        # Reverse the channel order (RGB -> BGR) for cv2.imwrite.
        cv2.imwrite(f"rendering/{epoch}.png", rgb[:, :, ::-1])
        cv2.imwrite(f"normal/{epoch}.png", normal[:, :, ::-1])
        cv2.imwrite(f"fg_rendering/{epoch}.png", fg_rgb[:, :, ::-1])

    def test_step(self, batch, *args, **kwargs):
        """Export canonical/deformed meshes and renders for one test frame.

        Args:
            batch: ``(inputs, targets, pixel_per_batch, total_pixels, idx)``.
                Only the first column (scale) of ``inputs["smpl_params"]`` is
                used; pose, shape and translation come from the optimised
                ``self.body_model_params``.

        Writes ``test_mesh/<idx>_canonical.ply`` and
        ``test_mesh/<idx>_deformed.ply``. It also writes
        ``test_mask/<idx>.png``, ``test_rendering/<idx>.png`` (ground-truth RGB
        above the prediction when ground truth is present),
        ``test_normal/<idx>.png`` and ``test_fg_rendering/<idx>.png``, with
        ``<idx>`` zero-padded to 4 digits.
        """
        inputs, targets, pixel_per_batch, total_pixels, idx = batch
        pixel_per_batch = int(pixel_per_batch)
        total_pixels = int(total_pixels)
        frame_name = f"{int(idx.item()):04d}"

        # Only ``scale`` is taken from smpl_params; the split also checks its width.
        scale, _, _, _ = torch.split(inputs["smpl_params"], [1, 3, 72, 10], dim=1)

        # Body-model parameters are computed once and reused for every pixel chunk.
        body_model_params = self.body_model_params(inputs['idx'])
        betas = body_model_params['betas']
        smpl_shape = betas if betas.dim() == 2 else betas.unsqueeze(0)
        smpl_trans = body_model_params['transl']
        smpl_pose = torch.cat(
            (body_model_params['global_orient'], body_model_params['body_pose']), dim=1)

        smpl_outputs = self.model.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)
        smpl_tfs = smpl_outputs['smpl_tfs']
        cond = {'smpl': smpl_pose[:, 3:] / np.pi}

        mesh_canonical = generate_mesh(
            lambda x: self.query_oc(x, cond),
            self.model.smpl_server.verts_c[0],
            point_batch=10000,
            res_up=4,
        )
        # Temporarily swap in a deformer built with K=7 to skin the extracted
        # mesh. The default deformer is restored even if skinning raises.
        # NOTE(review): K is presumably the neighbour count used for the
        # skinning-weight lookup; confirm in SMPLDeformer.
        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender, K=7)
        try:
            verts_deformed = self.get_deformed_mesh_fast_mode(mesh_canonical.vertices, smpl_tfs)
        finally:
            self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender)
        mesh_deformed = trimesh.Trimesh(
            vertices=verts_deformed, faces=mesh_canonical.faces, process=False)

        for directory in ("test_mask", "test_rendering", "test_fg_rendering", "test_normal", "test_mesh"):
            os.makedirs(directory, exist_ok=True)

        mesh_canonical.export(f"test_mesh/{frame_name}_canonical.ply")
        mesh_deformed.export(f"test_mesh/{frame_name}_deformed.ply")

        results = []
        for start in range(0, total_pixels, pixel_per_batch):
            end = min(start + pixel_per_batch, total_pixels)
            batch_inputs = {
                "uv": inputs["uv"][:, start:end],
                "intrinsics": inputs['intrinsics'],
                "pose": inputs['pose'],
                "smpl_params": inputs["smpl_params"],
                "smpl_pose": smpl_pose,
                "smpl_shape": betas,
                "smpl_trans": smpl_trans,
                "idx": inputs.get("idx"),
            }
            batch_targets = {
                "rgb": targets["rgb"][:, start:end].detach().clone() if 'rgb' in targets else None,
                "img_size": targets["img_size"],
            }

            with torch.no_grad():
                model_outputs = self.model(batch_inputs)
            results.append({
                "rgb_values": model_outputs["rgb_values"].detach().clone(),
                "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
                "normal_values": model_outputs["normal_values"].detach().clone(),
                "acc_map": model_outputs["acc_map"].detach().clone(),
                **batch_targets,
            })

        img_size = results[0]["img_size"]
        rgb_pred = self._stack_pixels(results, "rgb_values", img_size)
        fg_rgb_pred = self._stack_pixels(results, "fg_rgb_values", img_size)
        # Normals are mapped from [-1, 1] to [0, 1] for visualisation.
        normal_pred = (self._stack_pixels(results, "normal_values", img_size) + 1) / 2
        pred_mask = self._stack_pixels(results, "acc_map", img_size)

        if results[0]['rgb'] is not None:
            rgb_gt = self._stack_pixels(results, "rgb", img_size, dim=1)
            rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
        else:
            rgb = rgb_pred.cpu().numpy()

        rgb = self._to_uint8(rgb)
        fg_rgb = self._to_uint8(fg_rgb_pred.cpu().numpy())
        # The per-chunk targets carry no ground-truth normals, so only the prediction is saved.
        normal = self._to_uint8(normal_pred.cpu().numpy())

        cv2.imwrite(f"test_mask/{frame_name}.png", pred_mask.cpu().numpy() * 255)
        # Reverse the channel order (RGB -> BGR) for cv2.imwrite.
        cv2.imwrite(f"test_rendering/{frame_name}.png", rgb[:, :, ::-1])
        cv2.imwrite(f"test_normal/{frame_name}.png", normal[:, :, ::-1])
        cv2.imwrite(f"test_fg_rendering/{frame_name}.png", fg_rgb[:, :, ::-1])