| import numpy as np |
| import torch |
|
|
|
|
def get_agent_id_feature(agent_id, agent_num):
    """Return a one-hot identity vector for an agent.

    Args:
        agent_id: Index written into the vector (normal tensor indexing rules
            apply, so negative indices count from the end).
        agent_num: Length of the vector.

    Returns:
        A 1-D tensor from ``torch.zeros(agent_num)`` (default float dtype)
        holding 1 at ``agent_id`` and 0 elsewhere.
    """
    one_hot = torch.zeros(agent_num)
    one_hot[agent_id] = 1
    return one_hot
|
|
|
|
def get_movement_feature():
    """Return a placeholder movement feature.

    Returns:
        An int64 tensor of 8 entries, each drawn uniformly from {0, 1}.
    """
    return torch.randint(0, 2, size=(8,))
|
|
|
|
def get_own_feature():
    """Return a placeholder own-unit feature: 10 standard-normal samples."""
    own_feature = torch.randn((10,))
    return own_feature
|
|
|
|
def get_ally_visible_feature():
    """Return a placeholder feature for one ally.

    With probability 0.5 (``np.random.random() > 0.5``) the ally counts as
    visible and gets 4 standard-normal values; otherwise it is masked out
    with 4 zeros.
    """
    # Draw the visibility coin first; torch's RNG is only touched when visible.
    is_visible = np.random.random() > 0.5
    return torch.randn(4) if is_visible else torch.zeros(4)
|
|
|
|
def get_enemy_visible_feature():
    """Return a placeholder feature for one enemy.

    The enemy is visible only when ``np.random.random() > 0.8`` (about 20% of
    calls), yielding 4 standard-normal values; otherwise 4 zeros are returned.
    """
    # torch's RNG is consumed only on the visible branch.
    is_visible = np.random.random() > 0.8
    return torch.randn(4) if is_visible else torch.zeros(4)
|
|
|
|
def get_ind_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Build the independent global state for one agent: its local observation.

    NOTE(review): "independent" is read here as "the agent's own local
    observation"; confirm this matches the intended definition.

    The state concatenates, in order, the outputs of this module's feature
    helpers:

    * movement feature (8 values, cast to float),
    * one visible-feature vector (4 values) per enemy,
    * one visible-feature vector (4 values) per *other* ally,
    * own feature (10 values),
    * one-hot agent id (``ally_agent_num`` values).

    Args:
        agent_id: Index of this agent among the allies (used for the one-hot id).
        ally_agent_num: Number of allied agents, including this one.
        enemy_agent_num: Number of enemy agents.

    Returns:
        A 1-D float tensor of length
        ``8 + 4 * enemy_agent_num + 4 * (ally_agent_num - 1) + 10 + ally_agent_num``.

    Raises:
        IndexError: if ``agent_id`` is out of range for ``ally_agent_num``
            (raised by the one-hot indexing).
    """
    # randint produces int64; cast so every concatenated part is floating point.
    movement_feature = get_movement_feature().float()
    enemy_features = [get_enemy_visible_feature() for _ in range(enemy_agent_num)]
    # Only the *other* allies are observed; the agent itself is covered by own_feature.
    ally_features = [get_ally_visible_feature() for _ in range(ally_agent_num - 1)]
    own_feature = get_own_feature()
    agent_id_feature = get_agent_id_feature(agent_id, ally_agent_num)
    return torch.cat(
        [movement_feature, *enemy_features, *ally_features, own_feature, agent_id_feature]
    )


def get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Build the environment-provided global state.

    The arguments are accepted for a signature parallel to the other builders
    but are not used: the result depends on neither the agent nor the team
    sizes. Its contents are random placeholders for an ally center feature and
    an enemy center feature.

    Returns:
        A 1-D float tensor of length 16 (8 ally-center values followed by
        8 enemy-center values).
    """
    ally_center_feature = torch.randn(8)
    enemy_center_feature = torch.randn(8)
    return torch.cat([ally_center_feature, enemy_center_feature])


def get_as_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Build the agent-specific global state for one agent.

    Concatenates the agent's independent state (its local observation) with
    the environment-provided global state.

    NOTE(review): this follows the common "agent-specific = local observation
    + environment state" construction; confirm it is the intended variant.

    Args:
        agent_id: Index of this agent among the allies.
        ally_agent_num: Number of allied agents, including this one.
        enemy_agent_num: Number of enemy agents.

    Returns:
        A 1-D float tensor: ``get_ind_global_state(...)`` followed by
        ``get_ep_global_state(...)`` (16 more values).
    """
    ind_global_state = get_ind_global_state(agent_id, ally_agent_num, enemy_agent_num)
    ep_global_state = get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num)
    return torch.cat([ind_global_state, ep_global_state])
|
|
|
|
def test_global_state():
    """Smoke-test every global-state builder for each allied agent.

    The builders run one after another: every agent for the first builder,
    then every agent for the next. Each result is only checked to be a tensor.
    """
    ally_agent_num, enemy_agent_num = 3, 5
    builders = (get_ind_global_state, get_ep_global_state, get_as_global_state)
    for build in builders:
        for agent_id in range(ally_agent_num):
            state = build(agent_id, ally_agent_num, enemy_agent_num)
            assert isinstance(state, torch.Tensor)
|
|
|
|
if __name__ == "__main__":
    # Running the module as a script executes its smoke test.
    test_global_state()
|
|