AlgoX committed on
Commit fd1ab91 · 1 Parent(s): 73cac5c

feat: add hawk model

Files changed (1)
  1. model/hawk.py +106 -0
model/hawk.py ADDED
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_model_device(model):
+    # Device of the model's first parameter; used as the default device
+    # for freshly initialised recurrent states.
+    return next(iter(model.parameters())).device
+
+
+class RGLRU(nn.Module):
+    """Real-Gated Linear Recurrent Unit (RG-LRU) from the Hawk/Griffin
+    paper (De et al., 2024). Single-step recurrence:
+
+        a_t = a ** (c * r_t)
+        h_t = a_t * h_{t-1} + sqrt(1 - a_t**2) * (i_t * x_t)
+
+    where i_t and r_t are sigmoid input and recurrence gates."""
+
+    def __init__(self, hidden_size: int, c: float = 8.0):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.c = c
+
+        self.input_gate = nn.Linear(hidden_size, hidden_size, bias=False)
+        self.recurrence_gate = nn.Linear(hidden_size, hidden_size, bias=False)
+        # Learnable recurrence base. Initialise in (0, 1) so a ** (c * r_t)
+        # stays in (0, 1) and sqrt(1 - a_t**2) is real; torch.empty alone
+        # would leave the parameter with uninitialised memory.
+        self.a = nn.Parameter(torch.empty(hidden_size))
+        nn.init.uniform_(self.a, 0.9, 0.999)
+
+    def forward(self, x_t: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        batch_size, hidden_size = x_t.shape
+        assert hidden_size == self.hidden_size
+        assert state.shape[0] == batch_size
+
+        i_t = torch.sigmoid(self.input_gate(x_t))
+        r_t = torch.sigmoid(self.recurrence_gate(x_t))
+
+        # Compute the gated recurrence; sqrt(1 - a_t**2) rescales the
+        # gated input so the state keeps a roughly constant magnitude.
+        a_t = self.a ** (self.c * r_t)
+        multiplier = torch.sqrt(1 - a_t**2)
+        new_state = (state * a_t) + (multiplier * (i_t * x_t))
+
+        return new_state
+
+    def init_state(self, batch_size: int, device: torch.device | None = None):
+        if device is None:
+            device = get_model_device(self)
+        return torch.zeros(batch_size, self.hidden_size, device=device)
+
+
+class CausalConv1d(nn.Module):
+    # Depthwise 1D convolution applied one timestep at a time. The state
+    # caches the previous kernel_size - 1 inputs, so the output depends
+    # only on current and past positions.
+
+    def __init__(self, hidden_size, kernel_size):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.kernel_size = kernel_size
+        self.conv = nn.Conv1d(
+            hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
+        )
+
+    def init_state(self, batch_size: int, device: torch.device | None = None):
+        if device is None:
+            device = get_model_device(self)
+        return torch.zeros(
+            batch_size, self.hidden_size, self.kernel_size - 1, device=device
+        )
+
+    def forward(self, x: torch.Tensor, state: torch.Tensor):
+        # Window = cached steps followed by the current step; the new
+        # state drops the oldest step.
+        x_with_state = torch.cat([state, x[:, :, None]], dim=-1)
+        out = self.conv(x_with_state)
+        new_state = x_with_state[:, :, 1:]
+        return out.squeeze(-1), new_state
+
+
+class Hawk(nn.Module):
+    # Hawk recurrent block: a GeLU-gated branch multiplies the output of
+    # a causal-conv + RG-LRU branch, followed by an output projection.
+
+    def __init__(self, hidden_size: int, conv_kernel_size: int = 4):
+        super().__init__()
+
+        self.conv_kernel_size = conv_kernel_size
+        self.hidden_size = hidden_size
+
+        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+        self.recurrent_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+        self.conv = CausalConv1d(hidden_size, conv_kernel_size)
+        self.rglru = RGLRU(hidden_size)
+        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+
+    def forward(
+        self, x: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        conv_state, rglru_state = state
+
+        batch_size, hidden_size = x.shape
+
+        assert batch_size == conv_state.shape[0] == rglru_state.shape[0]
+        assert self.hidden_size == hidden_size == rglru_state.shape[1]
+
+        gate = F.gelu(self.gate_proj(x))
+        x = self.recurrent_proj(x)
+
+        x, new_conv_state = self.conv(x, conv_state)
+        new_rglru_state = self.rglru(x, rglru_state)
+
+        gated = gate * new_rglru_state
+        out = self.out_proj(gated)
+
+        new_state = (new_conv_state, new_rglru_state)
+        return out, new_state
+
+    def init_state(
+        self, batch_size: int, device: torch.device | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return (
+            self.conv.init_state(batch_size, device),
+            self.rglru.init_state(batch_size, device),
+        )
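
The block exposes a single-step API: each call consumes one (batch, hidden) slice and the previous state, and returns the output plus the updated state. A minimal usage sketch, not part of the commit, assuming the file is importable as model.hawk and using illustrative sizes:

import torch
from model.hawk import Hawk

batch_size, seq_len, hidden_size = 2, 16, 64  # illustrative sizes
block = Hawk(hidden_size)

x = torch.randn(batch_size, seq_len, hidden_size)
state = block.init_state(batch_size)

outputs = []
for t in range(seq_len):
    out, state = block(x[:, t], state)  # one (batch, hidden) step
    outputs.append(out)

y = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden)
print(y.shape)  # torch.Size([2, 16, 64])

Because the state is threaded through explicitly, the same loop serves both unrolling over a training sequence and incremental decoding one token at a time.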