| import torch |
| import torch.nn as nn |
|
|
class BodyModelParams(nn.Module):
    """Body-model parameters stored as one ``nn.Embedding`` table per name.

    Each key of ``params_dim`` becomes an attribute holding an
    ``nn.Embedding`` whose row width is the value in ``params_dim``.
    ``betas`` has a single row that every frame reads. All other parameters
    have ``num_frames`` rows, one per frame.

    All tables start at zero and frozen (``requires_grad=False``). Use
    ``init_parameters`` or ``set_requires_grad`` to change that.

    Args:
        num_frames: Number of rows in the per-frame tables.
        model_type: Body-model variant; only ``'smpl'`` is supported.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """

    def __init__(self, num_frames, model_type='smpl'):
        super().__init__()
        self.num_frames = num_frames
        self.model_type = model_type
        self.params_dim = {
            'betas': 10,
            'global_orient': 3,
            'transl': 3,
        }
        if model_type == 'smpl':
            self.params_dim.update({
                'body_pose': 69,
            })
        else:
            # Must be ``raise``: an exception instance is truthy, so an
            # ``assert`` on it would never trigger.
            raise ValueError(f'Unknown model type {model_type}, exiting!')

        self.param_names = self.params_dim.keys()

        for param_name in self.param_names:
            # ``betas`` is looked up with index 0 for every frame in
            # ``forward``, so one row is enough.
            num_rows = 1 if param_name == 'betas' else num_frames
            param = nn.Embedding(num_rows, self.params_dim[param_name])
            param.weight.data.fill_(0)
            param.weight.requires_grad = False
            setattr(self, param_name, param)

    def init_parameters(self, param_name, data, requires_grad=False):
        """Replace the weight of ``param_name`` with values from ``data``.

        Only the first ``params_dim[param_name]`` entries along the last axis
        of ``data`` are kept. The number of rows is taken from ``data`` as-is
        and is not checked against ``num_frames``.

        Args:
            param_name: Name of one of the embedding attributes.
            data: Tensor providing the new weight values.
            requires_grad: Whether the weight should receive gradients
                afterwards.
        """
        embedding = getattr(self, param_name)
        # NOTE: basic slicing returns a view, so the weight now shares storage
        # with ``data``; in-place updates to one are visible in the other.
        embedding.weight.data = data[..., :self.params_dim[param_name]]
        embedding.weight.requires_grad = requires_grad

    def set_requires_grad(self, param_name, requires_grad=True):
        """Enable or disable gradients for the weight of ``param_name``."""
        getattr(self, param_name).weight.requires_grad = requires_grad

    def forward(self, frame_ids):
        """Look up every parameter for the given frame indices.

        Args:
            frame_ids: Integer index tensor of frame ids.

        Returns:
            dict mapping each name in ``param_names`` to the embedding rows
            looked up for ``frame_ids``. ``betas`` always reads row 0, once
            per entry of ``frame_ids``.
        """
        params = {}
        for param_name in self.param_names:
            embedding = getattr(self, param_name)
            if param_name == 'betas':
                params[param_name] = embedding(torch.zeros_like(frame_ids))
            else:
                params[param_name] = embedding(frame_ids)
        return params