Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |
| # ========================= | |
| # Simple HRNet baseline | |
| # ========================= | |
class SimpleHRNet(nn.Module):
    """Plain convolutional heatmap baseline.

    Despite the name, this is not a multi-branch HRNet. It is a single-path
    stack of 3x3 conv -> BatchNorm -> ReLU units. The stem contains two
    stride-2 convolutions, so the heatmaps come out at roughly 1/4 of the
    input height and width (exactly H/4 x W/4 when both are divisible by 4).

    Args:
        num_landmarks: number of output heatmap channels (one per landmark).
        in_chans: number of channels in the input image.
    """

    def __init__(self, num_landmarks=29, in_chans=3):
        super().__init__()
        width = 64

        # Stem: two stride-2 units -> 4x spatial downsampling.
        self.stem = nn.Sequential(
            *self._conv_bn_relu(in_chans, width, stride=2),
            *self._conv_bn_relu(width, width, stride=2),
        )

        # Three resolution-preserving refinement blocks.
        self.block1 = self._make_block(width, width)
        self.block2 = self._make_block(width, width)
        self.block3 = self._make_block(width, width)

        # 1x1 projection to one heatmap channel per landmark.
        self.head = nn.Conv2d(width, num_landmarks, kernel_size=1)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, stride=1):
        """Return the layers of one 3x3 conv -> BN -> ReLU unit as a list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def _make_block(self, in_ch, out_ch):
        """Build two stacked stride-1 conv -> BN -> ReLU units."""
        return nn.Sequential(
            *self._conv_bn_relu(in_ch, out_ch),
            *self._conv_bn_relu(out_ch, out_ch),
        )

    def forward(self, x):
        """Map an image batch to per-landmark heatmaps at ~1/4 resolution."""
        out = self.stem(x)
        for stage in (self.block1, self.block2, self.block3):
            out = stage(out)
        return self.head(out)
| # ========================= | |
| # ViT + Heatmap Head | |
| # ========================= | |
class ViTHeatmap(nn.Module):
    """Heatmap regressor built on a timm Vision Transformer backbone.

    Pipeline:
      1. The backbone's patch tokens are reshaped back into a 2-D feature grid
         with one cell per patch.
      2. A 1x1 conv projects the grid to 256 channels.
      3. A small conv head with two 2x bilinear upsamplings refines it.
      4. A final 1x1 conv emits one heatmap per landmark.

    The output resolution is therefore ``4 / patch_size`` of the input
    (1/4 for a 16-pixel patch).

    Args:
        num_landmarks: number of output heatmap channels.
        model_name: timm model identifier for the backbone.
        pretrained: whether timm should load pretrained backbone weights.
        img_size: nominal input size passed to timm. ``dynamic_img_size=True``
            is also passed, which in timm allows inputs of other sizes at
            runtime.
    """

    def __init__(
        self,
        num_landmarks=29,
        model_name="vit_base_patch16_224",
        pretrained=True,
        img_size=(512, 512),
    ):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            dynamic_img_size=True,
            num_classes=0,  # no classification head
            global_pool="",  # no pooling: keep the full token sequence
        )
        embed_dim = self.backbone.num_features
        self.conv_proj = nn.Conv2d(embed_dim, 256, kernel_size=1)
        self.head = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_landmarks, kernel_size=1),
        )

    def _patch_size(self):
        """Return the backbone's patch size as ``(patch_h, patch_w)``.

        Reads ``backbone.patch_embed.patch_size`` (an int or a pair). Falls
        back to 16 when the backbone does not expose it.
        """
        patch_embed = getattr(self.backbone, "patch_embed", None)
        size = getattr(patch_embed, "patch_size", 16)
        if isinstance(size, int):
            return size, size
        return int(size[0]), int(size[1])

    def _num_prefix_tokens(self):
        """Return how many non-patch tokens precede the patch tokens.

        Examples are class, register and distillation tokens. Reads the
        backbone's ``num_prefix_tokens`` and falls back to 1 (a single CLS
        token) when it is absent.
        """
        return int(getattr(self.backbone, "num_prefix_tokens", 1))

    def forward(self, x):
        """Predict landmark heatmaps for an image batch.

        Args:
            x: image batch of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, num_landmarks, 4 * (H // ph), 4 * (W // pw)),
            where (ph, pw) is the backbone patch size.

        Raises:
            ValueError: if the number of patch tokens does not match the
                (H // ph) x (W // pw) grid.
        """
        batch, _, height, width = x.shape
        tokens = self.backbone.forward_features(x)
        if isinstance(tokens, (list, tuple)):
            # If the backbone returns several outputs, use the last one.
            tokens = tokens[-1]

        # Keep only the patch tokens. The number of leading class/register/
        # distillation tokens depends on the backbone, so it is not
        # hard-coded to 1.
        tokens = tokens[:, self._num_prefix_tokens():, :]
        num_patches = tokens.shape[1]

        # Derive the grid from the backbone's actual patch size
        # (e.g. 8/14/16/32) instead of assuming 16.
        patch_h, patch_w = self._patch_size()
        h = height // patch_h
        w = width // patch_w
        if h * w != num_patches:
            raise ValueError(
                f"Patch grid mismatch: input {(height, width)}, "
                f"expected {h}x{w}={h*w} patches, got {num_patches}"
            )

        # (B, N, C) -> (B, C, N) -> (B, C, h, w). This assumes the backbone
        # emits patches in row-major order (timm's PatchEmbed flattening).
        feat = tokens.transpose(1, 2).reshape(batch, -1, h, w)
        feat = self.conv_proj(feat)
        return self.head(feat)
| # ========================= | |
| # model test | |
| # ========================= | |
if __name__ == "__main__":
    # Shape smoke test. For a 224x224 input, both models should print
    # (2, 29, 56, 56), i.e. heatmaps at 1/4 of the input resolution.
    x = torch.randn(2, 3, 224, 224)

    model1 = SimpleHRNet(num_landmarks=29).eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        out1 = model1(x)
    print("HRNet output:", out1.shape)

    # pretrained=False: only output shapes are checked here, so skip the
    # weight download (which would also make this script fail offline).
    model2 = ViTHeatmap(num_landmarks=29, pretrained=False).eval()
    with torch.no_grad():
        out2 = model2(x)
    print("ViT output:", out2.shape)