import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


# =========================
# Simple HRNet baseline
# =========================
class SimpleHRNet(nn.Module):
    """Lightweight convolutional heatmap baseline.

    Despite the name, this is a single-resolution conv stack (no parallel
    multi-resolution branches): a stride-4 stem, three 64-channel
    conv-BN-ReLU blocks applied sequentially without skip connections, and a
    1x1 head producing one map per landmark. No activation is applied to the
    head output.

    Args:
        num_landmarks: Number of output channels (one map per landmark).
        in_chans: Number of input image channels.

    Shapes:
        input ``(B, in_chans, H, W)`` ->
        output ``(B, num_landmarks, ceil(H / 4), ceil(W / 4))``.
    """

    def __init__(self, num_landmarks=29, in_chans=3):
        super().__init__()
        # Two stride-2 3x3 convs (padding 1): each maps a size S to ceil(S / 2),
        # giving an overall output stride of 4.
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.block1 = self._make_block(64, 64)
        self.block2 = self._make_block(64, 64)
        self.block3 = self._make_block(64, 64)
        self.head = nn.Conv2d(64, num_landmarks, kernel_size=1)

    def _make_block(self, in_ch, out_ch):
        """Return two 3x3 conv-BN-ReLU layers; padding 1 keeps spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map images ``(B, in_chans, H, W)`` to per-landmark maps at 1/4 size."""
        x = self.stem(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.head(x)


# =========================
# ViT + Heatmap Head
# =========================
class ViTHeatmap(nn.Module):
    """timm ViT backbone followed by a small convolutional heatmap decoder.

    The backbone's patch tokens are reshaped into their 2-D grid
    ``(H // ph, W // pw)``, where ``(ph, pw)`` is the backbone's patch size.
    They are then projected to 256 channels and decoded through two 2x
    bilinear upsampling stages. The output spatial size is therefore
    ``(4 * (H // ph), 4 * (W // pw))``. For 16-pixel patches that is 1/4 of
    the input, the same output stride as ``SimpleHRNet``. No activation is
    applied to the final 1x1 conv.

    Args:
        num_landmarks: Number of output channels (one map per landmark).
        model_name: timm model name of a ViT-style backbone.
        pretrained: Whether timm should load pretrained weights.
        img_size: Nominal input size passed to ``timm.create_model``. It is
            passed together with ``dynamic_img_size=True`` so that other input
            sizes are accepted at forward time.
        in_chans: Number of input image channels, forwarded to timm
            (default 3, same as timm's default).
    """

    def __init__(
        self,
        num_landmarks=29,
        model_name="vit_base_patch16_224",
        pretrained=True,
        img_size=(512, 512),
        in_chans=3,
    ):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            in_chans=in_chans,
            dynamic_img_size=True,
            num_classes=0,      # no classifier head
            global_pool="",     # keep per-token features
        )
        embed_dim = self.backbone.num_features

        # Read the token layout from the backbone rather than hard-coding
        # "1 CLS token, 16-px patches". Backbones with register tokens, no
        # class token, or another patch size (patch8/patch14) then reshape
        # correctly. The fallbacks (1 and 16) match the previous behavior.
        self.num_prefix_tokens = int(
            getattr(self.backbone, "num_prefix_tokens", 1)
        )
        patch_size = getattr(
            getattr(self.backbone, "patch_embed", None), "patch_size", 16
        )
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = tuple(int(p) for p in patch_size)

        self.conv_proj = nn.Conv2d(embed_dim, 256, kernel_size=1)
        self.head = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_landmarks, kernel_size=1),
        )

    def forward(self, x):
        """Predict per-landmark maps for images ``x`` of shape ``(B, C, H, W)``.

        Raises:
            ValueError: if the number of patch tokens returned by the backbone
                does not equal ``(H // ph) * (W // pw)``.
        """
        B, _, H, W = x.shape
        tokens = self.backbone.forward_features(x)
        # Defensive: some backbones return a list/tuple of features; use the last.
        if isinstance(tokens, (list, tuple)):
            tokens = tokens[-1]
        # Drop the prefix tokens (CLS and/or register tokens); keep patch tokens.
        tokens = tokens[:, self.num_prefix_tokens:, :]
        num_patches = tokens.shape[1]
        ph, pw = self.patch_size
        h = H // ph
        w = W // pw
        if h * w != num_patches:
            raise ValueError(
                f"Patch grid mismatch: input {(H, W)}, "
                f"expected {h}x{w}={h*w} patches, got {num_patches}"
            )
        # (B, N, D) -> (B, D, N) -> (B, D, h, w); reshape copies the
        # non-contiguous transposed tensor as needed.
        feat = tokens.transpose(1, 2).reshape(B, -1, h, w)
        feat = self.conv_proj(feat)
        return self.head(feat)


# =========================
# model test
# =========================
if __name__ == "__main__":
    x = torch.randn(2, 3, 224, 224)

    model1 = SimpleHRNet(num_landmarks=29)
    out1 = model1(x)
    print("HRNet output:", out1.shape)

    # NOTE: pretrained=True (the default) makes timm fetch weights on first use.
    model2 = ViTHeatmap(num_landmarks=29)
    out2 = model2(x)
    print("ViT output:", out2.shape)