# UNIStainNet: src/models/uni_processor.py
# (uploaded via huggingface_hub by faceless-void; commit 4db9215)
"""
UNI feature processors: transform UNI pathology features into multi-scale
spatial feature maps, returned as a dict keyed by output resolution.
- UNIFeatureProcessor: for CLS-token features (4x4 grid = 16 tokens)
- UNIFeatureProcessorHighRes: for patch-token features (32x32 grid = 1024 tokens)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNIFeatureProcessor(nn.Module):
    """Process UNI features [B, 16, 1024] -> multi-scale spatial feature maps.

    The 16 input tokens are laid out as a row-major 4x4 spatial grid. Each
    token is projected to ``base_channels`` channels, then a chain of
    transposed convolutions (each doubling spatial resolution) produces one
    feature map per decoder resolution.

    Args:
        uni_dim: Feature dimension of each input token (default 1024).
        base_channels: Channel count after projection; also the channel count
            of the 16x16 and 32x32 outputs. The 64/128/256 outputs use
            ``base_channels // 2``, ``// 4`` and ``// 8`` channels.
    """

    # Side length of the square token grid expected by ``forward``
    # (GRID_SIZE ** 2 tokens per sample).
    GRID_SIZE = 4

    def __init__(self, uni_dim=1024, base_channels=512):
        super().__init__()
        self.base_channels = base_channels
        # Project UNI features to the generator channel dim.
        self.proj = nn.Sequential(
            nn.Linear(uni_dim, base_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Multi-scale upsamplers: 4x4 -> {16, 32, 64, 128, 256}.
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles H and W.
        # NOTE: attribute names, Sequential layout and construction order
        # determine state_dict keys / init order; keep them stable so existing
        # checkpoints still load.
        ch = base_channels
        # 4 -> 8 -> 16
        self.up_16 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 16 -> 32
        self.up_32 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 32 -> 64
        ch_64 = base_channels // 2  # 256 when base_channels=512
        self.up_64 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch_64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 64 -> 128
        ch_128 = base_channels // 4  # 128 when base_channels=512
        self.up_128 = nn.Sequential(
            nn.ConvTranspose2d(ch_64, ch_128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 128 -> 256
        ch_256 = base_channels // 8  # 64 when base_channels=512
        self.up_256 = nn.Sequential(
            nn.ConvTranspose2d(ch_128, ch_256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, uni_features):
        """Build multi-scale spatial maps from a 4x4 grid of UNI tokens.

        Args:
            uni_features: Tensor of shape [B, 16, uni_dim].

        Returns:
            dict mapping output resolution (int) to a feature map:
                16:  [B, base_channels,      16,  16]
                32:  [B, base_channels,      32,  32]
                64:  [B, base_channels // 2, 64,  64]
                128: [B, base_channels // 4, 128, 128]
                256: [B, base_channels // 8, 256, 256]

        Raises:
            ValueError: if ``uni_features`` is not 3-D with exactly 16 tokens.
        """
        g = self.GRID_SIZE
        # Fail early with a readable message; otherwise the mismatch surfaces
        # as an opaque reshape/permute error further down.
        if uni_features.dim() != 3 or uni_features.shape[1] != g * g:
            raise ValueError(
                f"UNIFeatureProcessor expects features of shape [B, {g * g}, D], "
                f"got {tuple(uni_features.shape)}"
            )
        B = uni_features.shape[0]
        # Project per token, then lay tokens out as a row-major g x g grid:
        # [B, g*g, C] -> [B, C, g*g] -> [B, C, g, g].
        x = self.proj(uni_features)
        x = x.permute(0, 2, 1).reshape(B, self.base_channels, g, g)
        # Multi-scale upsampling; each stage feeds the next.
        feat_16 = self.up_16(x)
        feat_32 = self.up_32(feat_16)
        feat_64 = self.up_64(feat_32)
        feat_128 = self.up_128(feat_64)
        feat_256 = self.up_256(feat_128)
        return {
            16: feat_16,
            32: feat_32,
            64: feat_64,
            128: feat_128,
            256: feat_256,
        }
class UNIFeatureProcessorHighRes(nn.Module):
    """Process high-res UNI features [B, 1024, 1024] -> multi-scale spatial maps.

    With patch-token extraction, UNI produces 1024 tokens (a 32x32 spatial
    grid) of 1024-dim -- 64x more tokens than the 4x4 CLS-token grid.

    Because the input already forms a 32x32 grid, it is refined with Conv2d
    at native resolution (with a residual connection) instead of being
    upsampled from a tiny seed. The 32x32 map is backed directly by patch
    tokens; the 64x64 and larger maps are learned upsamplings of it.

    Architecture (channel counts shown for base_channels=512):
        32x32 input -> conv process (+residual) -> feat_32  (512ch)
        32->64   upsample -> conv               -> feat_64  (256ch)
        64->128  upsample -> conv               -> feat_128 (128ch)
        128->256 upsample -> conv               -> feat_256 (64ch)
        256->512 upsample -> conv               -> feat_512 (32ch, output_512 only)
        32->16   downsample                     -> feat_16  (512ch, for bottleneck)

    Args:
        uni_dim: Feature dimension of each input token (default 1024).
        base_channels: Channel count after projection (``ch``); deeper levels
            use ``ch // 2``, ``// 4``, ``// 8`` (and ``// 16`` for 512).
        spatial_size: Side length S of the token grid; ``forward`` expects
            S*S tokens. NOTE(review): the output dict keys are hard-coded for
            S=32; for other S the keys no longer match the maps' actual
            resolutions (e.g. key 32 holds an S x S map).
        output_512: If True, also build ``up_512`` and add a 512x512 map to
            the output.
    """

    def __init__(self, uni_dim=1024, base_channels=512, spatial_size=32,
                 output_512=False):
        super().__init__()
        self.base_channels = base_channels
        self.spatial_size = spatial_size
        self.output_512 = output_512
        ch = base_channels
        # NOTE: attribute names, Sequential layout and construction order
        # determine state_dict keys / init order; keep them stable so existing
        # checkpoints still load.
        # Project UNI uni_dim -> ch per token.
        self.proj = nn.Sequential(
            nn.Linear(uni_dim, ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Process at native grid resolution -- refine projected features.
        self.proc_32 = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.InstanceNorm2d(ch),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.InstanceNorm2d(ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 32 -> 16 downsample (for bottleneck conditioning).
        self.down_16 = nn.Sequential(
            nn.Conv2d(ch, ch, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 32 -> 64 upsample + refine
        ch_64 = ch // 2  # 256 when base_channels=512
        self.up_64 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch_64, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch_64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch_64, ch_64, 3, padding=1),
            nn.InstanceNorm2d(ch_64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 64 -> 128 upsample + refine
        ch_128 = ch // 4  # 128 when base_channels=512
        self.up_128 = nn.Sequential(
            nn.ConvTranspose2d(ch_64, ch_128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch_128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch_128, ch_128, 3, padding=1),
            nn.InstanceNorm2d(ch_128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 128 -> 256 upsample + refine
        ch_256 = ch // 8  # 64 when base_channels=512
        self.up_256 = nn.Sequential(
            nn.ConvTranspose2d(ch_128, ch_256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch_256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch_256, ch_256, 3, padding=1),
            nn.InstanceNorm2d(ch_256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 256 -> 512 upsample (for 1024 models with SPADE at dec1).
        if output_512:
            ch_512 = ch // 16  # 32 when base_channels=512
            self.up_512 = nn.Sequential(
                nn.ConvTranspose2d(ch_256, ch_512, 4, stride=2, padding=1),
                nn.InstanceNorm2d(ch_512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ch_512, ch_512, 3, padding=1),
                nn.InstanceNorm2d(ch_512),
                nn.LeakyReLU(0.2, inplace=True),
            )

    def forward(self, uni_features):
        """Build multi-scale spatial maps from an S x S grid of UNI patch tokens.

        Args:
            uni_features: Tensor of shape [B, S*S, uni_dim], where
                S = ``spatial_size`` (default 32).

        Returns:
            dict mapping output resolution (int) to a feature map. For S=32:
                16:  [B, ch,      16,  16]
                32:  [B, ch,      32,  32]
                64:  [B, ch // 2, 64,  64]
                128: [B, ch // 4, 128, 128]
                256: [B, ch // 8, 256, 256]
                512: [B, ch // 16, 512, 512]  (only when output_512=True)
            where ch = ``base_channels``.

        Raises:
            ValueError: if ``uni_features`` is not 3-D with exactly S*S tokens.
        """
        S = self.spatial_size
        # Fail early with a readable message; otherwise the mismatch surfaces
        # as an opaque reshape/permute error further down.
        if uni_features.dim() != 3 or uni_features.shape[1] != S * S:
            raise ValueError(
                f"UNIFeatureProcessorHighRes expects features of shape "
                f"[B, {S * S}, D] (spatial_size={S}), "
                f"got {tuple(uni_features.shape)}"
            )
        B = uni_features.shape[0]
        # Project per token, then lay tokens out as a row-major S x S grid:
        # [B, S*S, ch] -> [B, ch, S*S] -> [B, ch, S, S].
        x = self.proj(uni_features)
        x = x.permute(0, 2, 1).reshape(B, self.base_channels, S, S)
        # Refine at native resolution with a residual connection.
        feat_32 = self.proc_32(x) + x
        # Downsample for bottleneck conditioning.
        feat_16 = self.down_16(feat_32)
        # Upsample path: each level is a learned upsampling of the previous one.
        feat_64 = self.up_64(feat_32)
        feat_128 = self.up_128(feat_64)
        feat_256 = self.up_256(feat_128)
        out = {
            16: feat_16,
            32: feat_32,
            64: feat_64,
            128: feat_128,
            256: feat_256,
        }
        if self.output_512:
            out[512] = self.up_512(feat_256)
        return out