| """ |
| Long Short-Term Memory (LSTM) <link https://ieeexplore.ieee.org/abstract/document/6795963 link> is a kind of recurrent neural network that can capture both long-term and short-term dependencies in sequential data. |
| This document mainly includes: |
| - A PyTorch implementation of LSTM. |
| - An example that tests the LSTM. |
| For beginners, you can refer to <link https://zhuanlan.zhihu.com/p/32085405 link> to learn the basics of how LSTM works. |
| """ |
| from typing import Optional, Union, Tuple, List, Dict |
| import math |
| import torch |
| import torch.nn as nn |
| from ding.torch_utils import build_normalization |
|
|
|
|
class LSTM(nn.Module):
    """
    **Overview:**
        Multi-layer LSTM whose gate pre-activations are normalized.
        In every layer the input projection and the hidden-state projection are each passed through
        their own normalization module, summed, shifted by a bias and split into the four gates
        (input ``i``, forget ``f``, output ``o`` and candidate ``z``).
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = 'LN',
            dropout: float = 0.
    ) -> None:
        """
        Overview:
            Create the parameters, normalization modules and optional dropout of the LSTM.
        Arguments:
            - input_size (:obj:`int`): Size of the last dimension of the input.
            - hidden_size (:obj:`int`): Size of the hidden state and of the cell state.
            - num_layers (:obj:`int`): Number of stacked LSTM layers.
            - norm_type (:obj:`Optional[str]`): Name forwarded to ``build_normalization`` to pick the \
                normalization applied to the gate pre-activations (which names are accepted is decided by \
                that helper). ``None`` disables normalization.
            - dropout (:obj:`float`): Dropout probability applied to the output of every layer except the \
                last one. Dropout is only created when this value is greater than 0.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # ``None`` means "no normalization": nn.Identity ignores its constructor arguments, so it can be
        # instantiated exactly like the factory returned by ``build_normalization``.
        norm_func = nn.Identity if norm_type is None else build_normalization(norm_type)
        # Two normalization modules per layer: index ``2 * l`` for the input projection and
        # index ``2 * l + 1`` for the hidden-state projection. Both act on the 4 * hidden_size gate vector.
        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])

        # wx[l]: (input dim of layer l, 4 * hidden_size), wh[l]: (hidden_size, 4 * hidden_size).
        # The first layer consumes ``input_size`` features, all later layers consume ``hidden_size``.
        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        dims = [input_size] + [hidden_size] * num_layers
        for layer in range(num_layers):
            self.wx.append(nn.Parameter(torch.zeros(dims[layer], dims[layer + 1] * 4)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
        # One bias row per layer, shared by the four gates of that layer.
        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))

        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _before_forward(
            self,
            inputs: torch.Tensor,
            prev_state: Optional[List[Dict[str, torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Convert the per-sample state list into batched hidden and cell state tensors.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input of ``forward``; its second dimension is the batch size and \
                its dtype and device are used for the zero state.
            - prev_state (:obj:`Optional[List[Dict[str, torch.Tensor]]]`): ``None`` to start from zeros, or one \
                dict per sample holding the keys ``'h'`` and ``'c'``.
        Returns:
            - state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): ``(H, C)``, each of shape \
                (num_layers, batch_size, hidden_size).
        Raises:
            - AssertionError: If the number of state dicts differs from the batch size.
            - KeyError: If a state dict lacks the ``'h'`` or ``'c'`` key.
        """
        batch_size = inputs.shape[1]
        if prev_state is None:
            # The same zero tensor serves as H and C; it is only read, never modified in place.
            zeros = torch.zeros(
                self.num_layers, batch_size, self.hidden_size, dtype=inputs.dtype, device=inputs.device
            )
            return zeros, zeros

        assert len(prev_state) == batch_size, \
            "expected {} states, got {}".format(batch_size, len(prev_state))
        # Look the states up by key rather than by dict order, so that a dict built as
        # ``{'c': ..., 'h': ...}`` cannot silently swap the hidden and the cell state.
        h = torch.cat([state['h'] for state in prev_state], dim=1)
        c = torch.cat([state['c'] for state in prev_state], dim=1)
        return h, c

    def _init(self):
        """
        Overview:
            Initialize all weights and biases uniformly in \
            ``[-sqrt(1 / hidden_size), sqrt(1 / hidden_size)]``.
        """
        gain = math.sqrt(1. / self.hidden_size)
        for layer in range(self.num_layers):
            torch.nn.init.uniform_(self.wx[layer], -gain, gain)
            torch.nn.init.uniform_(self.wh[layer], -gain, gain)
            torch.nn.init.uniform_(self.bias[layer], -gain, gain)

    def forward(
            self,
            inputs: torch.Tensor,
            prev_state: Optional[List[Dict[str, torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
        """
        Overview:
            Run the stacked LSTM over a whole sequence.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input sequence of shape (seq_len, batch_size, input_size).
            - prev_state (:obj:`Optional[List[Dict[str, torch.Tensor]]]`): ``None`` to start from a zero state, \
                or a list with one dict per sample. Each dict holds ``'h'`` and ``'c'`` tensors of shape \
                (num_layers, 1, hidden_size), i.e. the format returned by this method.
        Returns:
            - output (:obj:`torch.Tensor`): Hidden states of the last layer for every step, of shape \
                (seq_len, batch_size, hidden_size).
            - next_state (:obj:`List[Dict[str, torch.Tensor]]`): One ``{'h': ..., 'c': ...}`` dict per sample \
                with the final state of every layer, each tensor of shape (num_layers, 1, hidden_size).
        """
        seq_len, batch_size = inputs.shape[:2]
        H, C = self._before_forward(inputs, prev_state)

        x = inputs
        final_states = []
        for layer in range(self.num_layers):
            h, c = H[layer], C[layer]
            norm_x, norm_h = self.norm[layer * 2], self.norm[layer * 2 + 1]
            wx, wh, bias = self.wx[layer], self.wh[layer], self.bias[layer]
            layer_outputs = []
            for step in range(seq_len):
                # Pre-activation of all four gates at once, shape (batch_size, 4 * hidden_size).
                gate = norm_x(torch.matmul(x[step], wx)) + norm_h(torch.matmul(h, wh)) + bias
                i, f, o, z = torch.chunk(gate, 4, dim=1)
                i = torch.sigmoid(i)  # input gate
                f = torch.sigmoid(f)  # forget gate
                o = torch.sigmoid(o)  # output gate
                z = torch.tanh(z)  # candidate cell value
                c = f * c + i * z
                h = o * torch.tanh(c)
                layer_outputs.append(h)
            final_states.append((h, c))
            # The hidden states of this layer are the input sequence of the next one.
            x = torch.stack(layer_outputs, dim=0)
            # Dropout is applied between layers only, never on the output of the last layer.
            if self.use_dropout and layer != self.num_layers - 1:
                x = self.dropout(x)

        # Stack over layers -> (num_layers, batch_size, hidden_size), then split along the batch dimension
        # so that every sample owns an independent state dict.
        h = torch.stack([state[0] for state in final_states], dim=0)
        c = torch.stack([state[1] for state in final_states], dim=0)
        next_state = [
            {'h': h_item, 'c': c_item}
            for h_item, c_item in zip(torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1))
        ]
        return x, next_state
|
|
|
|
def pack_data(data: List[torch.Tensor], traj_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Pack variable-length sequences into one regular tensor and return it with the corresponding mask.
        If ``len(data_i) < traj_len``, the sequence is padded with zeros at the end and the padded steps
        are marked with 0 in the mask. Otherwise the sequence is split into consecutive trajectories of
        ``traj_len`` steps; when the last piece is too short it is replaced by the last ``traj_len`` steps
        of the sequence, so it overlaps the previous trajectory and its mask is all ones.
    Arguments:
        - data (:obj:`List[torch.Tensor]`): Sequences, each a 2-dim tensor of shape (D_i, N).
        - traj_len (:obj:`int`): Length of every packed trajectory, must be positive.
    Returns:
        - tensor (:obj:`torch.Tensor`): Packed data of shape (traj_len, B, N), where B is the total number \
            of trajectories. Padding keeps the dtype and device of the padded sequence.
        - mask (:obj:`torch.Tensor`): dtype (torch.float32), shape (traj_len, B), 1 for real steps and 0 \
            for padded steps.
    Raises:
        - ValueError: If ``traj_len`` is not positive.
    """
    if traj_len <= 0:
        raise ValueError("traj_len must be positive, got {}".format(traj_len))

    new_data = []
    mask = []
    for item in data:
        D, N = item.shape
        if D < traj_len:
            # ``new_zeros`` creates the padding with the dtype and device of ``item``, so concatenation
            # neither fails on a device mismatch nor changes the dtype of the data.
            null_padding = item.new_zeros((traj_len - D, N))
            new_data.append(torch.cat([item, null_padding]))
            item_mask = torch.ones(traj_len)
            item_mask[D:].zero_()
            mask.append(item_mask)
        else:
            for start in range(0, D, traj_len):
                new_item = item[start:start + traj_len]
                if new_item.shape[0] < traj_len:
                    # Short tail: reuse the last ``traj_len`` steps instead of padding.
                    new_item = item[-traj_len:]
                new_data.append(new_item)
                mask.append(torch.ones(traj_len))
    # Stack along dim 1 so that the trajectory index becomes the batch dimension.
    new_data = torch.stack(new_data, dim=1)
    mask = torch.stack(mask, dim=1)
    return new_data, mask
|
|
|
|
def test_lstm():
    """
    Overview:
        Smoke test of ``pack_data`` and ``LSTM``: pack five random sequences of different lengths into
        trajectories, feed them to the LSTM one step at a time while carrying the state, then check the
        output shapes and that every parameter receives a gradient.
    """
    # Test configuration.
    lengths = [32, 49, 24, 78, 45]
    traj_len, N = 32, 10
    hidden_size, num_layers = 32, 2

    # Pack the variable-length sequences into a (traj_len, batch_size, N) tensor plus its mask.
    sequences = [torch.rand(length, N) for length in lengths]
    input_, mask = pack_data(sequences, traj_len)
    assert isinstance(input_, torch.Tensor), type(input_)
    batch_size = input_.shape[1]
    assert batch_size == 9, "packed data must have 9 trajectories"
    lstm = LSTM(N, hidden_size=hidden_size, num_layers=num_layers, norm_type='LN', dropout=0.1)

    # Step through the trajectories one time step at a time, threading the recurrent state through.
    prev_state = None
    for input_step in torch.split(input_, 1, dim=0):
        output, prev_state = lstm(input_step, prev_state)

    # Shapes of the last output and of the per-sample state.
    assert output.shape == (1, batch_size, hidden_size)
    assert len(prev_state) == batch_size
    assert prev_state[0]['h'].shape == (num_layers, 1, hidden_size)

    # Backpropagate a masked mean and make sure every parameter received a gradient.
    masked_output = output * mask.unsqueeze(-1)
    masked_output.mean().backward()
    for param in lstm.parameters():
        assert isinstance(param.grad, torch.Tensor)
    print('finished')
|
|
|
|
# Run the LSTM smoke test only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    test_lstm()
|
|