| | from typing import Optional, Tuple, List |
| | import torch |
| | import torch.nn as nn |
| | import treetensor.torch as ttorch |
| |
|
| |
|
class PPOFModel(nn.Module):
    """Actor-critic network for PPO with image observations and a
    continuous (Gaussian) action space.

    A shared convolutional encoder (Atari-style 8/4/3 kernels with
    4/2/1 strides) feeds two heads:
      - actor: MLP followed by a linear ``mu`` layer; the log standard
        deviation is a single state-independent learnable parameter.
      - critic: MLP ending in a scalar state-value output.
    """
    # Names accepted by ``forward(inputs, mode=...)``; dispatch is by getattr.
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Tuple[int],
            action_shape: int,
            encoder_hidden_size_list: Optional[List] = None,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = None,
    ) -> None:
        """Build encoder, critic head and actor head, then initialise weights.

        Arguments:
            obs_shape: channel-first image shape, e.g. ``(4, 84, 84)``.
            action_shape: dimensionality of the continuous action.
            encoder_hidden_size_list: output channels of each conv layer
                (at most 3, matching the fixed kernel/stride schedule);
                defaults to ``[128, 128, 64]``.
            actor_head_hidden_size / actor_head_layer_num: actor MLP config.
            critic_head_hidden_size / critic_head_layer_num: critic MLP config.
            activation: activation module shared by all layers
                (defaults to ``nn.ReLU()``; ReLU is stateless, so sharing
                one instance across layers is safe).
        """
        super(PPOFModel, self).__init__()
        self.obs_shape, self.action_shape = obs_shape, action_shape
        # Avoid mutable/shared default arguments.
        if encoder_hidden_size_list is None:
            encoder_hidden_size_list = [128, 128, 64]
        if activation is None:
            activation = nn.ReLU()

        # Shared conv encoder. The kernel/stride schedules cap the number
        # of conv layers; fail loudly instead of IndexError-ing mid-loop.
        kernel_size_list = [8, 4, 3]
        stride_list = [4, 2, 1]
        if len(encoder_hidden_size_list) > len(kernel_size_list):
            raise ValueError(
                "encoder_hidden_size_list supports at most {} conv layers, got {}".format(
                    len(kernel_size_list), len(encoder_hidden_size_list)
                )
            )
        layers = []
        input_size = obs_shape[0]
        for output_size, kernel_size, stride in zip(encoder_hidden_size_list, kernel_size_list, stride_list):
            layers.append(nn.Conv2d(input_size, output_size, kernel_size, stride))
            layers.append(activation)
            input_size = output_size
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)

        # Critic head: MLP -> scalar state value.
        flatten_size = input_size = self.get_flatten_size()
        layers = []
        for _ in range(critic_head_layer_num):
            layers.append(nn.Linear(input_size, critic_head_hidden_size))
            layers.append(activation)
            input_size = critic_head_hidden_size
        layers.append(nn.Linear(critic_head_hidden_size, 1))
        self.critic = nn.Sequential(*layers)

        # Actor head: MLP -> Gaussian mean; log-sigma is a free parameter
        # shared across states (state-independent exploration noise).
        layers = []
        input_size = flatten_size
        for _ in range(actor_head_layer_num):
            layers.append(nn.Linear(input_size, actor_head_hidden_size))
            layers.append(activation)
            input_size = actor_head_hidden_size
        self.actor = nn.Sequential(*layers)
        self.mu = nn.Linear(actor_head_hidden_size, action_shape)
        self.log_sigma = nn.Parameter(torch.zeros(1, action_shape))

        self.init_weights()

    def init_weights(self) -> None:
        """Orthogonally initialise all conv/linear weights and zero the biases.

        BUG FIX: this method previously raised ``NotImplementedError``,
        yet ``__init__`` calls it unconditionally — the model could never
        be constructed. Orthogonal init is the common PPO choice.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def get_flatten_size(self) -> int:
        """Return the encoder's flattened output width, found by a dry run
        with a single dummy observation (no gradients recorded)."""
        test_data = torch.randn(1, *self.obs_shape)
        with torch.no_grad():
            output = self.encoder(test_data)
        return output.shape[1]

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        """Dispatch to one of the ``compute_*`` methods named in ``self.mode``."""
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """Return the Gaussian policy parameters ``{'mu', 'sigma'}`` for ``x``."""
        x = self.encoder(x)
        x = self.actor(x)
        mu = self.mu(x)
        # Broadcast the (1, action_shape) log-sigma to the batch shape of mu.
        log_sigma = self.log_sigma + torch.zeros_like(mu)
        sigma = torch.exp(log_sigma)
        return ttorch.as_tensor({'mu': mu, 'sigma': sigma})

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """Return the state value, shape ``(batch, 1)``."""
        x = self.encoder(x)
        value = self.critic(x)
        return value

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """Return ``{'logit': {'mu', 'sigma'}, 'value'}`` sharing one encoder pass."""
        x = self.encoder(x)
        value = self.critic(x)
        x = self.actor(x)
        mu = self.mu(x)
        log_sigma = self.log_sigma + torch.zeros_like(mu)
        sigma = torch.exp(log_sigma)
        return ttorch.as_tensor({'logit': {'mu': mu, 'sigma': sigma}, 'value': value})
| |
|
| |
|
def test_ppof_model() -> None:
    """Smoke test: verify output shapes for every forward mode."""
    net = PPOFModel((4, 84, 84), 5)
    print(net)
    obs = torch.randn(3, 4, 84, 84)

    # Value head alone.
    value = net(obs, mode='compute_critic')
    assert value.shape == (3, 1)

    # Policy head alone.
    dist = net(obs, mode='compute_actor')
    assert dist.mu.shape == (3, 5)
    assert dist.sigma.shape == (3, 5)

    # Both heads in one encoder pass.
    joint = net(obs, mode='compute_actor_critic')
    assert joint.value.shape == (3, 1)
    assert joint.logit.mu.shape == (3, 5)
    assert joint.logit.sigma.shape == (3, 5)
    print('End...')
| |
|
| |
|
# Run the smoke test when this file is executed directly.
if __name__ == "__main__":
    test_ppof_model()
| |
|