| import torch |
| import hydra |
| import numpy as np |
| from ..smpl.body_models import SMPL |
|
|
| class SMPLServer(torch.nn.Module): |
|
|
| def __init__(self, gender='neutral', betas=None, v_template=None): |
| super().__init__() |
|
|
|
|
| self.smpl = SMPL(model_path=hydra.utils.to_absolute_path('lib/smpl/smpl_model'), |
| gender=gender, |
| batch_size=1, |
| use_hands=False, |
| use_feet_keypoints=False, |
| dtype=torch.float32).cuda() |
|
|
| self.bone_parents = self.smpl.bone_parents.astype(int) |
| self.bone_parents[0] = -1 |
| self.bone_ids = [] |
| self.faces = self.smpl.faces |
| for i in range(24): self.bone_ids.append([self.bone_parents[i], i]) |
|
|
| if v_template is not None: |
| self.v_template = torch.tensor(v_template).float().cuda() |
| else: |
| self.v_template = None |
|
|
| if betas is not None: |
| self.betas = torch.tensor(betas).float().cuda() |
| else: |
| self.betas = None |
|
|
| |
| param_canonical = torch.zeros((1, 86),dtype=torch.float32).cuda() |
| param_canonical[0, 0] = 1 |
| param_canonical[0, 9] = np.pi / 6 |
| param_canonical[0, 12] = -np.pi / 6 |
| if self.betas is not None and self.v_template is None: |
| param_canonical[0,-10:] = self.betas |
| self.param_canonical = param_canonical |
|
|
| output = self.forward(*torch.split(self.param_canonical, [1, 3, 72, 10], dim=1), absolute=True) |
| self.verts_c = output['smpl_verts'] |
| self.joints_c = output['smpl_jnts'] |
| self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse() |
|
|
|
|
| def forward(self, scale, transl, thetas, betas, absolute=False): |
| """return SMPL output from params |
| Args: |
| scale : scale factor. shape: [B, 1] |
| transl: translation. shape: [B, 3] |
| thetas: pose. shape: [B, 72] |
| betas: shape. shape: [B, 10] |
| absolute (bool): if true return smpl_tfs wrt thetas=0. else wrt thetas=thetas_canonical. |
| Returns: |
| smpl_verts: vertices. shape: [B, 6893. 3] |
| smpl_tfs: bone transformations. shape: [B, 24, 4, 4] |
| smpl_jnts: joint positions. shape: [B, 25, 3] |
| """ |
|
|
| output = {} |
|
|
| |
| if self.v_template is not None: |
| betas = torch.zeros_like(betas) |
|
|
| |
| smpl_output = self.smpl.forward(betas=betas, |
| transl=torch.zeros_like(transl), |
| body_pose=thetas[:, 3:], |
| global_orient=thetas[:, :3], |
| return_verts=True, |
| return_full_pose=True, |
| v_template=self.v_template) |
|
|
| verts = smpl_output.vertices.clone() |
| output['smpl_verts'] = verts * scale.unsqueeze(1) + transl.unsqueeze(1) * scale.unsqueeze(1) |
|
|
| joints = smpl_output.joints.clone() |
| output['smpl_jnts'] = joints * scale.unsqueeze(1) + transl.unsqueeze(1) * scale.unsqueeze(1) |
|
|
| tf_mats = smpl_output.T.clone() |
| tf_mats[:, :, :3, :] = tf_mats[:, :, :3, :] * scale.unsqueeze(1).unsqueeze(1) |
| tf_mats[:, :, :3, 3] = tf_mats[:, :, :3, 3] + transl.unsqueeze(1) * scale.unsqueeze(1) |
|
|
| if not absolute: |
| tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv) |
| |
| output['smpl_tfs'] = tf_mats |
| output['smpl_weights'] = smpl_output.weights |
| return output |