| | """ |
| | Author: Minh Pham-Dinh |
| | Created: Jan 26th, 2024 |
| | Last Modified: Feb 5th, 2024 |
| | Email: mhpham26@colby.edu |
| | |
| | Description: |
| | File containing the ReplayBuffer that will be used in Dreamer. |
| | |
| | The implementation is based on: |
| | Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019. |
| | [Online]. Available: https://arxiv.org/abs/1912.01603 |
| | """ |
| |
|
import numpy as np
import torch
from addict import Dict


class ReplayBuffer:
    def __init__(self, capacity, obs_size, action_size):
        """
        Args:
            capacity (int): maximum number of transitions stored in the buffer.
            obs_size (tuple): shape of a single observation.
            action_size (tuple): shape of a single action.
        """
        self.obs_size = obs_size
        self.action_size = action_size

        # Store image observations (3 or more dims) as uint8 to save memory;
        # keep low-dimensional vector observations as float32.
        state_type = np.uint8 if len(self.obs_size) >= 3 else np.float32

        self.observation = np.zeros((capacity,) + self.obs_size, dtype=state_type)
        self.actions = np.zeros((capacity,) + self.action_size, dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

        # Circular write index and flag marking whether the buffer has wrapped at least once.
        self.pointer = 0
        self.full = False

        print(f'''
        -----------initialized memory----------

        obs_buffer_shape: {self.observation.shape}
        actions_buffer_shape: {self.actions.shape}
        rewards_buffer_shape: {self.rewards.shape}
        dones_buffer_shape: {self.dones.shape}

        ----------------------------------------
        ''')

    def add(self, obs, action, reward, done):
        """Add a single transition to the buffer.

        Args:
            obs (np.array): current observation
            action (np.array): action taken
            reward (float): reward received after action
            done (bool): boolean value of termination or truncation
        """
        self.observation[self.pointer] = obs
        self.actions[self.pointer] = action
        self.rewards[self.pointer] = reward
        self.dones[self.pointer] = done
        # Advance the circular write index; once it wraps back to 0 the buffer is full.
        self.pointer = (self.pointer + 1) % self.observation.shape[0]
        if self.pointer == 0:
            self.full = True

    def sample(self, batch_size, seq_len, device):
        """
        Samples batches of fixed-length experience sequences from the replay buffer,
        taking the circular nature of the buffer into account so that no sequence
        crosses the write pointer when the buffer is full.

        Sampled sequences may wrap around the end of the underlying array (the buffer
        is circular), but they never cross the write pointer, which marks the boundary
        between the newest and the oldest data. This preserves the temporal continuity
        of every sampled sequence.

        Args:
            batch_size (int): The number of sequences to sample.
            seq_len (int): The length of each sequence to sample.
            device (torch.device): The device on which the sampled data will be loaded.

        Raises:
            Exception: If there is not enough data in the buffer to sample a full sequence.

        Returns:
            Dict: A dictionary containing the sampled sequences of observations, actions,
            rewards, and dones. Each item is a tensor of shape
            (batch_size, seq_len, feature_dimension); 'rewards' and 'dones' have shape
            (batch_size, seq_len, 1).

        Notes:
            - When the buffer is not full, valid start indices run from 0 up to the last
              index from which a full sequence fits before the current pointer.
            - When the buffer is full, start indices are chosen so that sequences never
              cross the pointer, i.e. never jump from the newest data to the oldest.
        """
        if self.pointer < seq_len and not self.full:
            raise Exception('not enough data to sample')

        # Determine valid start indices so that no sampled sequence crosses the
        # write pointer (the boundary between the newest and the oldest data).
        if self.full:
            if self.pointer - seq_len < 0:
                # Starts just before the pointer would cross it, so valid starts run
                # from the pointer up to capacity - (seq_len - pointer).
                valid_range = np.arange(self.pointer, self.observation.shape[0] - (seq_len - self.pointer) + 1)
            else:
                # Valid starts: everything up to pointer - seq_len, plus everything from
                # the pointer onward (wrapping past the array end is allowed).
                range_1 = np.arange(0, self.pointer - seq_len + 1)
                range_2 = np.arange(self.pointer, self.observation.shape[0])
                valid_range = np.concatenate((range_1, range_2), -1)
        else:
            valid_range = np.arange(0, self.pointer - seq_len + 1)

        # Draw batch_size start indices, then build (batch_size, seq_len) index grids,
        # wrapping around the array end with the modulo.
        start_index = np.random.choice(valid_range, (batch_size, 1))
        offsets = np.arange(seq_len)
        sample_idcs = (start_index + offsets) % self.observation.shape[0]

        batch = Dict()
        batch.obs = torch.from_numpy(self.observation[sample_idcs]).to(device)
        batch.actions = torch.from_numpy(self.actions[sample_idcs]).to(device)
        batch.rewards = torch.from_numpy(self.rewards[sample_idcs]).to(device)
        batch.dones = torch.from_numpy(self.dones[sample_idcs]).to(device)

        return batch

    def clear(self):
        self.pointer = 0
        self.full = False

    def __len__(self):
        # Number of transitions currently stored (the full capacity once the buffer has wrapped).
        return self.observation.shape[0] if self.full else self.pointer
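

# Minimal usage sketch (illustrative only; the shapes and sizes below are assumptions,
# not part of the original training pipeline): fill the buffer with random transitions
# for a 4-dimensional vector observation and a 2-dimensional continuous action, then
# draw a batch of length-8 sequences on the CPU.
if __name__ == '__main__':
    buffer = ReplayBuffer(capacity=100, obs_size=(4,), action_size=(2,))

    for _ in range(50):
        obs = np.random.randn(4).astype(np.float32)
        action = np.random.randn(2).astype(np.float32)
        reward = float(np.random.randn())
        done = False
        buffer.add(obs, action, reward, done)

    batch = buffer.sample(batch_size=16, seq_len=8, device=torch.device('cpu'))
    print(batch.obs.shape)      # torch.Size([16, 8, 4])
    print(batch.actions.shape)  # torch.Size([16, 8, 2])
    print(batch.rewards.shape)  # torch.Size([16, 8, 1])
    print(batch.dones.shape)    # torch.Size([16, 8, 1])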