from typing import Optional

import torch
from torch import nn


class ClsToken(nn.Module):
    """Prepends a learnable CLS token, plus optional 'register' tokens, to a token sequence."""

    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: Optional[int] = None,
                 num_registers: Optional[int] = None,
                 ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                # An explicit register count takes precedence.
                self.num_registers = num_registers
            elif register_multiple:
                # Otherwise pad the token count up to the next multiple of
                # `register_multiple`; note that when `num_tokens` is already an
                # exact multiple, this still adds `register_multiple` registers.
                self.num_registers = register_multiple - (num_tokens % register_multiple)

            # Learnable tokens, initialized at a scale of 1/sqrt(ndim).
            scale = ndim ** -0.5
            self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
        else:
            self.token = None

        # Total number of prepended tokens (CLS + registers).
        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        self.token = None
        self.enabled = False

    def forward(self, x: torch.Tensor):
        # Pass the sequence through unchanged when the CLS token is disabled.
        if self.token is None:
            return x

        # Broadcast the learned tokens across the batch and prepend them along
        # the sequence dimension: (B, N, C) -> (B, num_patches + N, C).
        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat([
            token,
            x,
        ], dim=1)

        return x

    def no_weight_decay(self):
        # Learned tokens are conventionally excluded from weight decay.
        return [
            'token',
        ]
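

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The shapes are
    # illustrative assumptions: a batch of 2 sequences of 196 patch embeddings
    # of width 768, as a hypothetical ViT patch embedding might produce.
    cls_token = ClsToken(ndim=768, num_tokens=1, register_multiple=8)

    x = torch.randn(2, 196, 768)  # (batch, sequence, ndim)
    y = cls_token(x)

    # With num_tokens=1 and register_multiple=8, 7 registers are added,
    # so 8 tokens are prepended in total.
    assert cls_token.num_patches == 8
    assert y.shape == (2, 196 + cls_token.num_patches, 768)

    # Disabling turns the module into a no-op.
    cls_token.disable()
    assert cls_token(x).shape == x.shape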