| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import einops |
| import math |
|
|
class FuzzyEmbedding(nn.Module):
    """Learned 2-D positional embedding resampled to arbitrary grid shapes.

    A square table of ``grid_size`` learned vectors (``side * side`` entries of
    dimension ``width``) is bilinearly resampled with ``F.grid_sample`` to the
    ``grid_height x grid_width`` layout requested at call time. A separately
    learned class-token embedding is prepended to the resampled positions.

    Args:
        grid_size: Number of entries in the learned table. Must be a positive
            perfect square, because the table is viewed as a square grid.
        scale: Multiplier applied to the standard-normal initial values.
        width: Embedding dimension.
        apply_fuzzy: If true, ``forward(..., train=True)`` perturbs the
            sampling coordinates with small uniform noise.

    Raises:
        ValueError: If ``grid_size`` is not a positive perfect square.
    """

    def __init__(self, grid_size: int, scale: float, width: int,
                 apply_fuzzy: bool = False) -> None:
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``.
        if grid_size < 1 or math.isqrt(grid_size) ** 2 != grid_size:
            raise ValueError(
                f"grid_size must be a positive perfect square, got {grid_size}"
            )

        self.grid_size = grid_size
        self.scale = scale
        self.width = width
        self.apply_fuzzy = apply_fuzzy

        # Creation order is kept (grid table first, class token second) so a
        # fixed RNG seed yields the same initial values as before.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(grid_size, width))
        self.class_positional_embedding = nn.Parameter(
            scale * torch.randn(1, width))

    @staticmethod
    def _axis_coords(size: int, device: torch.device,
                     dtype: torch.dtype) -> torch.Tensor:
        """Return ``size`` coordinates spread evenly over ``[-1, 1]``."""
        coords = torch.arange(size, device=device).to(dtype)
        if size > 1:
            return 2 * (coords / (size - 1)) - 1
        # A single row/column would divide 0 by 0 above and feed NaNs to
        # grid_sample; sample the centre of the table instead.
        return torch.zeros_like(coords)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, grid_height: int, grid_width: int, train: bool = True,
                dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Resample the learned table to a ``grid_height x grid_width`` grid.

        Args:
            grid_height: Number of rows in the output grid.
            grid_width: Number of columns in the output grid.
            train: Only consulted when ``apply_fuzzy`` is set; if true, the
                sampling coordinates are jittered.
            dtype: Dtype used for the sampling grid and the sampled table.

        Returns:
            Tensor of shape ``(1 + grid_height * grid_width, width)``: the
            class embedding followed by the positions in row-major order.
            The class embedding is not cast to ``dtype``, so the result dtype
            follows ``torch.cat``'s handling of the two inputs.
        """
        device = self.positional_embedding.device
        row_axis = self._axis_coords(grid_height, device, dtype)
        col_axis = self._axis_coords(grid_width, device, dtype)
        # Explicit "ij" keeps the (height, width) layout and avoids the
        # deprecation warning emitted when ``indexing`` is omitted.
        row_coords, col_coords = torch.meshgrid(row_axis, col_axis, indexing="ij")

        if self.apply_fuzzy and train:
            # Uniform jitter in [-0.0004, 0.0004); rows are drawn before
            # columns to keep the random stream order stable.
            row_coords = row_coords + (torch.rand_like(row_coords) * 0.0008 - 0.0004)
            col_coords = col_coords + (torch.rand_like(col_coords) * 0.0008 - 0.0004)

        # grid_sample expects (x, y) in the last dim, i.e. (column, row).
        grid = torch.stack((col_coords, row_coords), 2).unsqueeze(0)

        side = math.isqrt(self.grid_size)
        # (side * side, D) -> (1, D, side, side)
        table = (
            self.positional_embedding.to(dtype)
            .reshape(side, side, -1)
            .permute(2, 0, 1)
            .unsqueeze(0)
        )

        # Default bilinear mode and zero padding; align_corners=False means
        # coordinates of exactly +/-1 lie on the outer edge of the border cells.
        sampled = F.grid_sample(table, grid, align_corners=False)
        # (1, D, H, W) -> (H * W, D)
        sampled = sampled.squeeze(0).flatten(1).transpose(0, 1)

        return torch.cat([self.class_positional_embedding, sampled], dim=0)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: resample the learned table to a non-square 16 x 32 grid.
    # grid_size is 1024 (a 32 x 32 table), the size the module is written for.
    grid_height = 16
    grid_width = 32
    module = FuzzyEmbedding(1024, 1.0, 1024)
    embedding = module(grid_height, grid_width, dtype=torch.bfloat16)
    # One class-token row plus 16 * 32 positions -> torch.Size([513, 1024]).
    print(embedding.shape)