| import torch |
| import torch.nn as nn |
| import numpy as np |
| from functools import partial |
| from tqdm import tqdm |
|
|
| from mae import MAE |
|
|
| |
# Electrode label -> 3-component coordinate tuple, in the order the channels
# are listed here (dict insertion order is preserved, so do not reorder).
# NOTE(review): the labels look like the 10-20 EEG montage plus the A2 ear
# electrode, and the first component is negative for odd-numbered labels and
# positive for even-numbered ones. The coordinate units and the exact axis
# convention are not stated in this file -- confirm against the data pipeline.
POSITIONS = {
    "Fp1": (-3.09, 11.46, 2.79),
    "Fp2": (2.84, 11.53, 2.77),
    "F3": (-5.18, 8.67, 7.87),
    "F4": (5.03, 8.74, 7.73),
    "F7": (-7.19, 7.31, 2.58),
    "F8": (7.14, 7.45, 2.51),
    "T3": (-8.60, 1.49, 3.12),
    "T4": (8.33, 1.53, 3.10),
    "C3": (-6.71, 2.34, 10.45),
    "C4": (6.53, 2.36, 10.37),
    "T5": (-8.77, 1.29, -0.77),
    "T6": (8.37, 1.17, -0.77),
    "P3": (-5.50, -4.42, 9.99),
    "P4": (5.36, -4.43, 10.05),
    "O1": (-3.16, -8.06, 5.48),
    "O2": (2.77, -8.05, 5.47),
    "Fz": (-0.12, 9.33, 10.26),
    "Cz": (-0.14, 2.76, 14.02),
    "Pz": (-0.17, -4.52, 12.67),
    "A2": (8.39, 0.20, -2.69),
}
|
|
|
|
class MANAS1(nn.Module):
    """Encoder-only wrapper around a pretrained ``MAE`` checkpoint.

    The constructor builds an ``MAE`` with fixed hyper-parameters, loads the
    checkpoint weights into it, and re-exposes the patch embedding,
    positional encoding and encoder as attributes of this module.
    ``forward`` runs only those three parts and returns the encoder output.
    """

    def __init__(self, checkpoint_path, num_classes: int = 2, flat_dim: int = 512):
        """Build the MAE backbone and load its weights from a checkpoint.

        Args:
            checkpoint_path: Passed straight to ``torch.load``. The loaded
                object must be indexable with ``"model_state_dict"``.
            num_classes: Accepted for interface compatibility; not used by
                any code in this class.
            flat_dim: Stored on ``self.flat_dim``; not otherwise used by any
                code in this class.

        Raises:
            KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
            RuntimeError: From ``load_state_dict`` when the stored weights do
                not match the ``MAE`` architecture constructed below.
        """
        super().__init__()

        print(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at checkpoints from a trusted
        # source. If the checkpoint holds nothing but tensors and plain
        # containers, passing ``weights_only=True`` would be the safer call;
        # it is left unchanged here so existing checkpoints keep loading.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        # These hyper-parameters have to describe the same architecture the
        # checkpoint was saved from, otherwise load_state_dict below fails.
        self.mae = MAE(
            fs=200,
            embed_dim=512,
            encoder_depth=12,
            encoder_heads=8,
            decoder_depth=4,
            decoder_heads=8,
            mask_ratio=0.55,
        )
        self.mae.load_state_dict(ckpt["model_state_dict"])

        # Aliases to the pieces of the MAE that forward() actually uses.
        # They are the same objects as the ones held by self.mae.
        self.patch_embed = self.mae.patch_embed
        self.pos_enc = self.mae.pos_enc
        self.encoder = self.mae.encoder
        self.patch_size = self.mae.patch_size
        self.step = self.mae.step

        self.flat_dim = flat_dim

    def prepare_coords(self, xyz: torch.Tensor, num_patches: int) -> torch.Tensor:
        """Attach a patch index to every channel coordinate.

        Args:
            xyz: 3-D tensor ``(B, C, K)`` of per-channel coordinates
                (presumably ``K == 3`` for x/y/z -- confirm with the caller).
            num_patches: Number of patches per channel.

        Returns:
            Tensor of shape ``(B, C * num_patches, K + 1)``. Rows are ordered
            channel-major: all patches of channel 0, then channel 1, and so
            on. The last component is the patch index ``0 .. num_patches - 1``
            as a float.
        """
        B, C, _ = xyz.shape
        device = xyz.device
        time_idx = torch.arange(num_patches, device=device).float()
        # Repeat each channel's coordinates once per patch: (B, C, N, K).
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)
        # Broadcast the patch index over batch and channel: (B, C, N, 1).
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)
        return torch.cat([spat, time], dim=-1).flatten(1, 2)

    def forward(self, x: torch.Tensor, pos: torch.Tensor):
        """Encode a signal and return the encoder's first output.

        Args:
            x: Tensor whose last dimension is cut into windows of
                ``self.patch_size`` with stride ``self.step``. The patch count
                is read from dimension 2 of the unfolded tensor, so ``x`` is
                expected to be 3-D (presumably ``(batch, channels,
                samples)`` -- confirm with the caller).
            pos: Per-channel coordinates, forwarded to ``prepare_coords``.

        Returns:
            The first element of the pair returned by ``self.encoder``.
        """
        # (B, C, T) -> (B, C, N, patch_size)
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # Embed every patch, then merge the channel and patch axes so the
        # token order matches the row order produced by prepare_coords.
        tokens = self.patch_embed.linear(patches).flatten(1, 2)

        coords = self.prepare_coords(pos, num_patches)
        pe = self.pos_enc(coords)

        x_enc = tokens + pe
        # The encoder returns a pair; only the first element is used here.
        latents, _ = self.encoder(x_enc)

        return latents