| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import print_function |
| from __future__ import division |
|
|
| import os |
| import os.path as osp |
|
|
|
|
| import pickle |
|
|
| import numpy as np |
|
|
| from collections import namedtuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .lbs import ( |
| lbs, vertices2joints, blend_shapes) |
|
|
| from .vertex_ids import vertex_ids as VERTEX_IDS |
| from .utils import Struct, to_np, to_tensor |
| from .vertex_joint_selector import VertexJointSelector |
|
|
|
|
# Record type returned by the model's forward pass. Field order matters for
# positional construction and unpacking, so it is kept exactly as declared.
ModelOutput = namedtuple(
    'ModelOutput',
    'vertices faces joints full_pose betas '
    'global_orient body_pose expression '
    'left_hand_pose right_hand_pose '
    'jaw_pose T T_weighted weights')
# Give every field a default of None so that callers can fill in only the
# entries they actually compute.
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
|
|
class SMPL(nn.Module):
    ''' SMPL body model wrapped as a torch module.

    The model data (template vertices, shape/pose blend shapes, joint
    regressor, kinematic tree and skinning weights) is registered as buffers;
    `betas`, `global_orient`, `body_pose` and `transl` are optionally
    registered as learnable parameters.
    '''

    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    NUM_BETAS = 10

    def __init__(self, model_path, data_struct=None,
                 create_betas=True,
                 betas=None,
                 create_global_orient=True,
                 global_orient=None,
                 create_body_pose=True,
                 body_pose=None,
                 create_transl=True,
                 transl=None,
                 dtype=torch.float32,
                 batch_size=1,
                 joint_mapper=None, gender='neutral',
                 vertex_ids=None,
                 pose_blend=True,
                 **kwargs):
        ''' SMPL model constructor

            Parameters
            ----------
            model_path: str
                The path to the folder or to the file where the model
                parameters are stored
            data_struct: Struct
                A struct object. If given, then the parameters of the model are
                read from the object. Otherwise, the model tries to read the
                parameters from the given `model_path`. (default = None)
            create_global_orient: bool, optional
                Flag for creating a member variable for the global orientation
                of the body. (default = True)
            global_orient: torch.tensor, optional, Bx3
                The default value for the global orientation variable.
                (default = None)
            create_body_pose: bool, optional
                Flag for creating a member variable for the pose of the body.
                (default = True)
            body_pose: torch.tensor, optional, Bx(Body Joints * 3)
                The default value for the body pose variable.
                (default = None)
            create_betas: bool, optional
                Flag for creating a member variable for the shape space
                (default = True).
            betas: torch.tensor, optional, Bx10
                The default value for the shape member variable.
                (default = None)
            create_transl: bool, optional
                Flag for creating a member variable for the translation
                of the body. (default = True)
            transl: torch.tensor, optional, Bx3
                The default value for the transl variable.
                (default = None)
            dtype: torch.dtype, optional
                The data type for the created variables
            batch_size: int, optional
                The batch size used for creating the member variables
            joint_mapper: object, optional
                An object that re-maps the joints. Useful if one wants to
                re-order the SMPL joints to some other convention (e.g. MSCOCO)
                (default = None)
            gender: str, optional
                Which gender to load
            vertex_ids: dict, optional
                A dictionary containing the indices of the extra vertices that
                will be selected. Falls back to `VERTEX_IDS['smplh']`.
            pose_blend: bool, optional
                Stored and forwarded to `lbs` as its `pose_blend` argument on
                every forward pass. (default = True)
            **kwargs:
                Forwarded to `VertexJointSelector`.
        '''
        # Initialise the module machinery before touching any attribute.
        super(SMPL, self).__init__()

        self.gender = gender
        self.pose_blend = pose_blend
        self.batch_size = batch_size
        self.dtype = dtype
        self.joint_mapper = joint_mapper

        if data_struct is None:
            if osp.isdir(model_path):
                model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
                smpl_path = os.path.join(model_path, model_fn)
            else:
                smpl_path = model_path
            # Kept as an assert so callers that catch AssertionError keep
            # working; under `python -O` the open() below still fails loudly.
            assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
                smpl_path)

            # SECURITY NOTE: pickle can execute arbitrary code while loading.
            # Only point `model_path` at model files from a trusted source.
            with open(smpl_path, 'rb') as smpl_file:
                data_struct = Struct(
                    **pickle.load(smpl_file, encoding='latin1'))

        if vertex_ids is None:
            vertex_ids = VERTEX_IDS['smplh']

        self.vertex_joint_selector = VertexJointSelector(
            vertex_ids=vertex_ids, **kwargs)

        self.faces = data_struct.f
        self.register_buffer('faces_tensor',
                             to_tensor(to_np(self.faces, dtype=np.int64),
                                       dtype=torch.long))

        # Optional learnable parameters. All four share the same
        # initialisation rules (see `_default_param`).
        if create_betas:
            default_betas = self._default_param(
                betas, [batch_size, self.NUM_BETAS], dtype)
            self.register_parameter(
                'betas', nn.Parameter(default_betas, requires_grad=True))

        if create_global_orient:
            default_global_orient = self._default_param(
                global_orient, [batch_size, 3], dtype)
            self.register_parameter(
                'global_orient',
                nn.Parameter(default_global_orient, requires_grad=True))

        if create_body_pose:
            default_body_pose = self._default_param(
                body_pose, [batch_size, self.NUM_BODY_JOINTS * 3], dtype)
            self.register_parameter(
                'body_pose',
                nn.Parameter(default_body_pose, requires_grad=True))

        if create_transl:
            default_transl = self._default_param(
                transl, [batch_size, 3], dtype)
            self.register_parameter(
                'transl', nn.Parameter(default_transl, requires_grad=True))

        # Template mesh vertices.
        self.register_buffer('v_template',
                             to_tensor(to_np(data_struct.v_template),
                                       dtype=dtype))

        # Shape blend shapes, truncated to the first NUM_BETAS components.
        shapedirs = data_struct.shapedirs[:, :, :self.NUM_BETAS]
        self.register_buffer(
            'shapedirs',
            to_tensor(to_np(shapedirs), dtype=dtype))

        j_regressor = to_tensor(to_np(
            data_struct.J_regressor), dtype=dtype)
        self.register_buffer('J_regressor', j_regressor)

        # Pose blend shapes: flatten all leading axes and transpose so the
        # pose-basis axis comes first.
        num_pose_basis = data_struct.posedirs.shape[-1]
        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs',
                             to_tensor(to_np(posedirs), dtype=dtype))

        # Kinematic tree: first row of kintree_table holds each joint's
        # parent; the root entry is overwritten with -1.
        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)

        # Plain (non-buffer) copy of the same kintree_table row, as returned
        # by `to_np`.
        self.bone_parents = to_np(data_struct.kintree_table[0])

        self.register_buffer('lbs_weights',
                             to_tensor(to_np(data_struct.weights),
                                       dtype=dtype))

    @staticmethod
    def _default_param(value, shape, dtype):
        ''' Build the initial tensor for one learnable parameter.

            Returns zeros of `shape` when `value` is None, a detached copy
            cast to `dtype` when `value` is a tensor (this includes
            `nn.Parameter`), and `torch.tensor(value, dtype=dtype)` for any
            other array-like input.
        '''
        if value is None:
            return torch.zeros(shape, dtype=dtype)
        if isinstance(value, torch.Tensor):
            # detach + clone avoids the copy-construct warning raised by
            # torch.tensor(tensor) and keeps the caller's tensor untouched.
            return value.detach().clone().to(dtype=dtype)
        return torch.tensor(value, dtype=dtype)

    def create_mean_pose(self, data_struct):
        ''' Placeholder; this model defines no mean pose. Returns None. '''
        pass

    @torch.no_grad()
    def reset_params(self, **params_dict):
        ''' Reset the registered parameters in place.

            Parameters named in `params_dict` are overwritten with the given
            value; every other parameter is filled with zeros.
        '''
        for param_name, param in self.named_parameters():
            if param_name in params_dict:
                # as_tensor accepts lists, arrays and tensors alike without
                # the copy-construct warning torch.tensor() emits on tensors.
                param[:] = torch.as_tensor(params_dict[param_name],
                                           dtype=param.dtype,
                                           device=param.device)
            else:
                param.fill_(0)

    def get_T_hip(self, betas=None):
        ''' Return entry [0, 0] of the joints regressed from the shaped
            template.

            Parameters
            ----------
            betas: torch.tensor, optional
                Shape coefficients. When omitted, the member variable
                `betas` is used.

            NOTE(review): presumably this is the root (hip) joint of the
            first batch element; confirm against `vertices2joints`.
        '''
        if betas is None:
            betas = self.betas
        v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)
        J = vertices2joints(self.J_regressor, v_shaped)
        return J[0, 0]

    def get_num_verts(self):
        ''' Number of rows in the template vertex buffer. '''
        return self.v_template.shape[0]

    def get_num_faces(self):
        ''' Number of rows in the face array. '''
        return self.faces.shape[0]

    def extra_repr(self):
        return 'Number of betas: {}'.format(self.NUM_BETAS)

    def forward(self, betas=None, body_pose=None, global_orient=None,
                transl=None, return_verts=True, return_full_pose=False,
                displacement=None, v_template=None,
                **kwargs):
        ''' Forward pass for the SMPL model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3
                If given, ignore the member variable and use it as the global
                rotation of the body. Useful if someone wishes to predict
                this with an external model. (default=None)
            betas: torch.tensor, optional, shape Bx10
                If given, ignore the member variable `betas` and use it
                instead. For example, it can be used if shape parameters
                `betas` are predicted from some external model.
                (default=None)
            body_pose: torch.tensor, optional, shape Bx(J*3)
                If given, ignore the member variable `body_pose` and use it
                instead. For example, it can be used if the pose of the body
                joints is predicted from some external model.
                It should be a tensor that contains joint rotations in
                axis-angle format. (default=None)
            transl: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `transl` and use it
                instead. For example, it can be used if the translation
                `transl` is predicted from some external model.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full axis-angle pose vector (default=False)
            displacement: torch.tensor, optional
                If given, it is added to the template before the template is
                passed to `lbs`. (default=None)
            v_template: torch.tensor, optional
                If given, used instead of the member buffer `v_template` for
                this call. (default=None)
            **kwargs:
                Accepted for call compatibility and ignored.

            Returns
            -------
            ModelOutput
                `vertices` (None unless `return_verts`), `faces`, `joints`,
                `betas`, `global_orient`, `body_pose`, `full_pose` (None
                unless `return_full_pose`) and the `T`, `T_weighted` and
                `weights` values returned by `lbs`.
        '''
        # Explicit arguments win over the member variables.
        if global_orient is None:
            global_orient = self.global_orient
        if body_pose is None:
            body_pose = self.body_pose
        if betas is None:
            betas = self.betas
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl
        apply_trans = transl is not None

        full_pose = torch.cat([global_orient, body_pose], dim=1)

        if v_template is None:
            v_template = self.v_template
        if displacement is not None:
            v_template = v_template + displacement

        vertices, joints_smpl, T_weighted, W, T = lbs(
            betas, full_pose, v_template,
            self.shapedirs, self.posedirs,
            self.J_regressor, self.parents,
            self.lbs_weights, dtype=self.dtype,
            pose_blend=self.pose_blend)

        # NOTE(review): the selected/re-mapped joints are computed but the
        # output below reports `joints_smpl`. Kept as-is so the returned
        # joints do not change for existing callers; confirm the intent.
        joints = self.vertex_joint_selector(vertices, joints_smpl)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints_smpl = joints_smpl + offset
            joints = joints + offset
            vertices = vertices + offset

        output = ModelOutput(vertices=vertices if return_verts else None,
                             faces=self.faces,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             joints=joints_smpl,
                             # Report the betas actually used for this pass,
                             # not unconditionally the member variable.
                             betas=betas,
                             full_pose=full_pose if return_full_pose else None,
                             T=T, T_weighted=T_weighted, weights=W)
        return output