# CephVIT / model.py
# Uploaded by farrell236 (commit c320b82, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# =========================
# Simple conv baseline ("SimpleHRNet"; a single-resolution conv stack, not a full multi-resolution HRNet)
# =========================
class SimpleHRNet(nn.Module):
    """Lightweight convolutional baseline producing one heatmap per landmark.

    Despite the name, this is a single-resolution stack: a stem of two
    stride-2 3x3 conv/BN/ReLU layers, three residual-free blocks of two
    3x3 conv/BN/ReLU layers each (64 channels throughout), and a 1x1
    conv head. Each stride-2 conv maps a side length n to ceil(n / 2), so
    the output is roughly 1/4 of the input resolution.

    Args:
        num_landmarks: number of output heatmap channels.
        in_chans: number of input image channels.
    """

    def __init__(self, num_landmarks=29, in_chans=3):
        super().__init__()
        width = 64
        # Layers are built in the same order (and flattened into the same
        # Sequential indices) as the reference layout, so state_dict keys and
        # seeded initialisation are unchanged.
        self.stem = nn.Sequential(
            *self._conv_bn_relu(in_chans, width, stride=2),
            *self._conv_bn_relu(width, width, stride=2),
        )
        self.block1 = self._make_block(width, width)
        self.block2 = self._make_block(width, width)
        self.block3 = self._make_block(width, width)
        self.head = nn.Conv2d(width, num_landmarks, kernel_size=1)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, stride=1):
        """Return [3x3 conv (padding 1), BatchNorm, in-place ReLU] as a list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def _make_block(self, in_ch, out_ch):
        """Build a two-layer, resolution-preserving conv/BN/ReLU block."""
        return nn.Sequential(
            *self._conv_bn_relu(in_ch, out_ch),
            *self._conv_bn_relu(out_ch, out_ch),
        )

    def forward(self, x):
        """Map an image batch to per-landmark heatmap logits."""
        for stage in (self.stem, self.block1, self.block2, self.block3):
            x = stage(x)
        return self.head(x)
# =========================
# ViT + Heatmap Head
# =========================
class ViTHeatmap(nn.Module):
    """Landmark heatmap regressor on top of a timm Vision Transformer.

    Patch tokens from the backbone are reshaped back into their 2-D patch
    grid, projected to 256 channels with a 1x1 conv, and decoded by a small
    conv head containing two x2 bilinear upsamplings. For an input of size
    (H, W) and backbone patch size (p_h, p_w) the output therefore has shape
    (B, num_landmarks, 4 * (H // p_h), 4 * (W // p_w)).

    Args:
        num_landmarks: number of output heatmap channels.
        model_name: timm model identifier for the backbone.
        pretrained: whether timm loads pretrained backbone weights.
        img_size: nominal input size passed to timm. The backbone is also
            created with dynamic_img_size=True (see timm docs for how other
            input sizes are handled).
    """

    def __init__(
        self,
        num_landmarks=29,
        model_name="vit_base_patch16_224",
        pretrained=True,
        img_size=(512, 512),
    ):
        super().__init__()
        # num_classes=0 / global_pool="" ask timm for a headless, unpooled
        # backbone so forward_features yields per-token features.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            img_size=img_size,
            dynamic_img_size=True,
            num_classes=0,
            global_pool="",
        )
        embed_dim = self.backbone.num_features
        self.conv_proj = nn.Conv2d(embed_dim, 256, kernel_size=1)
        self.head = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_landmarks, kernel_size=1),
        )

    def _patch_size(self):
        """Return the backbone's (patch_h, patch_w), falling back to 16x16.

        Read from ``backbone.patch_embed.patch_size`` so non-16 patch models
        (e.g. patch8 / patch14 variants) get the right token grid.
        """
        patch_embed = getattr(self.backbone, "patch_embed", None)
        size = getattr(patch_embed, "patch_size", 16)
        if isinstance(size, int):
            return size, size
        return int(size[0]), int(size[1])

    @staticmethod
    def _tokens_to_grid(tokens, num_prefix_tokens, h, w, input_hw):
        """Drop prefix tokens and fold (B, N, C) patch tokens into (B, C, h, w).

        Args:
            tokens: token sequence indexed as (batch, token, channel).
            num_prefix_tokens: number of leading non-patch tokens to discard.
            h, w: expected patch-grid height and width.
            input_hw: original input (H, W), used only in the error message.

        Raises:
            ValueError: if the remaining token count is not h * w.
        """
        patch_tokens = tokens[:, num_prefix_tokens:, :]
        num_patches = patch_tokens.shape[1]
        if h * w != num_patches:
            raise ValueError(
                f"Patch grid mismatch: input {input_hw}, "
                f"expected {h}x{w}={h*w} patches, got {num_patches}"
            )
        # Tokens are laid out row-major over the grid, so a channel-first
        # transpose followed by reshape recovers the spatial map.
        return patch_tokens.transpose(1, 2).reshape(tokens.shape[0], -1, h, w)

    def forward(self, x):
        """Predict landmark heatmaps for an image batch.

        Args:
            x: image batch, indexed below as (B, C, H, W).

        Returns:
            Heatmap logits of shape (B, num_landmarks, 4 * h, 4 * w), where
            (h, w) = (H // patch_h, W // patch_w).

        Raises:
            ValueError: if the backbone's patch-token count does not match the
                grid implied by the input size and patch size.
        """
        tokens = self.backbone.forward_features(x)
        if isinstance(tokens, (list, tuple)):
            # Defensive: if a sequence of outputs comes back, use the last.
            tokens = tokens[-1]
        patch_h, patch_w = self._patch_size()
        h = x.shape[2] // patch_h
        w = x.shape[3] // patch_w
        # timm ViTs expose how many CLS/register tokens precede the patch
        # tokens; fall back to a single CLS token if the attribute is absent.
        num_prefix = getattr(self.backbone, "num_prefix_tokens", 1)
        feat = self._tokens_to_grid(
            tokens, num_prefix, h, w, (x.shape[2], x.shape[3])
        )
        feat = self.conv_proj(feat)
        return self.head(feat)
# =========================
# model test
# =========================
if __name__ == "__main__":
    # Shape smoke test only: random weights are enough, so skip the
    # pretrained-weight download and autograd bookkeeping.
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        model1 = SimpleHRNet(num_landmarks=29)
        out1 = model1(x)
        print("HRNet output:", out1.shape)
        model2 = ViTHeatmap(num_landmarks=29, pretrained=False)
        out2 = model2(x)
        print("ViT output:", out2.shape)