File size: 7,882 Bytes
4db9215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
UNI feature processors: transform UNI pathology features into multi-scale spatial maps.

- UNIFeatureProcessor: for CLS-token features (4x4 = 16 tokens)
- UNIFeatureProcessorHighRes: for patch-token features (32x32 = 1024 tokens)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNIFeatureProcessor(nn.Module):
    """Process UNI features [B, 16, uni_dim] -> multi-scale spatial feature maps.

    The 16 input tokens are laid out as a 4x4 grid (row-major token order).
    Each token is projected to ``base_channels`` with a Linear layer, and the
    resulting 4x4 map is upsampled with transposed convolutions to match each
    decoder layer resolution.

    Args:
        uni_dim: Feature dimension of each input token (1024 for UNI).
        base_channels: Channel width after projection. The 64/128/256 maps
            use ``base_channels // 2``, ``// 4`` and ``// 8`` channels.
    """

    def __init__(self, uni_dim=1024, base_channels=512):
        super().__init__()
        self.base_channels = base_channels

        # Per-token projection of UNI features to the generator channel dim.
        self.proj = nn.Sequential(
            nn.Linear(uni_dim, base_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Multi-scale upsamplers: 4x4 -> {16, 32, 64, 128, 256}.
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) exactly doubles H and W.
        ch = base_channels

        # 4 -> 8 -> 16 (two doubling stages)
        self.up_16 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 16 -> 32
        self.up_32 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 32 -> 64
        ch_64 = base_channels // 2  # 256 with the default base_channels=512
        self.up_64 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch_64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 64 -> 128
        ch_128 = base_channels // 4  # 128 with the default base_channels=512
        self.up_128 = nn.Sequential(
            nn.ConvTranspose2d(ch_64, ch_128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 128 -> 256
        ch_256 = base_channels // 8  # 64 with the default base_channels=512
        self.up_256 = nn.Sequential(
            nn.ConvTranspose2d(ch_128, ch_256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, uni_features):
        """Project the 16 UNI tokens to a 4x4 map and upsample it to every scale.

        Args:
            uni_features: Tensor of shape [B, 16, uni_dim].

        Returns:
            dict mapping resolution -> feature map (C = base_channels):
                16:  [B, C,      16,  16]
                32:  [B, C,      32,  32]
                64:  [B, C // 2, 64,  64]
                128: [B, C // 4, 128, 128]
                256: [B, C // 8, 256, 256]

        Raises:
            ValueError: if ``uni_features`` is not 3-D with exactly 16 tokens
                (previously this surfaced as an opaque reshape/permute error).
        """
        grid = 4
        if uni_features.dim() != 3 or uni_features.shape[1] != grid * grid:
            raise ValueError(
                f"expected uni_features of shape [B, {grid * grid}, uni_dim], "
                f"got {tuple(uni_features.shape)}"
            )
        B = uni_features.shape[0]

        # [B, 16, uni_dim] -> [B, 16, C] -> [B, C, 16] -> [B, C, 4, 4]
        x = self.proj(uni_features)
        x = x.permute(0, 2, 1).reshape(B, self.base_channels, grid, grid)

        # Multi-scale upsampling; each stage feeds the next.
        feat_16 = self.up_16(x)            # [B, C, 16, 16]
        feat_32 = self.up_32(feat_16)      # [B, C, 32, 32]
        feat_64 = self.up_64(feat_32)      # [B, C // 2, 64, 64]
        feat_128 = self.up_128(feat_64)    # [B, C // 4, 128, 128]
        feat_256 = self.up_256(feat_128)   # [B, C // 8, 256, 256]

        return {
            16: feat_16,
            32: feat_32,
            64: feat_64,
            128: feat_128,
            256: feat_256,
        }


class UNIFeatureProcessorHighRes(nn.Module):
    """Process high-res UNI features [B, S*S, uni_dim] -> multi-scale spatial maps.

    With patch-token extraction, UNI produces 1024 tokens (32x32 spatial grid)
    of 1024-dim, i.e. 64x more spatial positions than the 4x4 grid used by
    the CLS-based processor.

    Since we START at the native token grid, we refine it with Conv2d rather
    than upsampling from a tiny map; every coarser/finer level is derived
    from the real UNI patch-token grid.

    Architecture (default spatial_size S=32, C = base_channels):
        SxS input -> conv refine (+ residual) -> feat_32 (C ch)
        S -> 2S upsample -> conv -> feat_64 (C // 2 ch)
        2S -> 4S upsample -> conv -> feat_128 (C // 4 ch)
        4S -> 8S upsample -> conv -> feat_256 (C // 8 ch)
        S -> S/2 strided conv -> feat_16 (C ch, for bottleneck)
        optional 8S -> 16S upsample -> feat_512 (C // 16 ch), if output_512

    Args:
        uni_dim: Feature dimension of each input token (1024 for UNI).
        base_channels: Channel width after projection.
        spatial_size: Side length S of the token grid; input must have S*S tokens.
        output_512: If True, also build and return the 512 level.
    """

    def __init__(self, uni_dim=1024, base_channels=512, spatial_size=32,
                 output_512=False):
        super().__init__()
        self.base_channels = base_channels
        self.spatial_size = spatial_size
        self.output_512 = output_512
        ch = base_channels

        # Per-token projection of UNI features to the generator channel dim.
        self.proj = nn.Sequential(
            nn.Linear(uni_dim, ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Refinement at the native token grid; forward() adds it residually.
        self.proc_32 = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.InstanceNorm2d(ch),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.InstanceNorm2d(ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # S -> S/2 downsample (for bottleneck conditioning):
        # Conv2d(kernel=4, stride=2, padding=1) halves H and W.
        self.down_16 = nn.Sequential(
            nn.Conv2d(ch, ch, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) below doubles H and W.

        # S -> 2S upsample + refine
        ch_64 = ch // 2  # 256 with the default base_channels=512
        self.up_64 = nn.Sequential(
            nn.ConvTranspose2d(ch, ch_64, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch_64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch_64, ch_64, 3, padding=1),
            nn.InstanceNorm2d(ch_64),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # 2S -> 4S upsample + refine
        ch_128 = ch // 4  # 128 with the default base_channels=512
        self.up_128 = nn.Sequential(
            nn.ConvTranspose2d(ch_64, ch_128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch_128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch_128, ch_128, 3, padding=1),
            nn.InstanceNorm2d(ch_128),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # 4S -> 8S upsample + refine
        ch_256 = ch // 8  # 64 with the default base_channels=512
        self.up_256 = nn.Sequential(
            nn.ConvTranspose2d(ch_128, ch_256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ch_256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch_256, ch_256, 3, padding=1),
            nn.InstanceNorm2d(ch_256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # 8S -> 16S upsample (for 1024 models with SPADE at dec1)
        if output_512:
            ch_512 = ch // 16  # 32 with the default base_channels=512
            self.up_512 = nn.Sequential(
                nn.ConvTranspose2d(ch_256, ch_512, 4, stride=2, padding=1),
                nn.InstanceNorm2d(ch_512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ch_512, ch_512, 3, padding=1),
                nn.InstanceNorm2d(ch_512),
                nn.LeakyReLU(0.2, inplace=True),
            )

    def forward(self, uni_features):
        """Turn the S*S UNI patch tokens into a pyramid of spatial feature maps.

        Args:
            uni_features: Tensor of shape [B, S*S, uni_dim], where
                S = self.spatial_size (default 32).

        Returns:
            dict of spatial feature maps keyed by {16, 32, 64, 128, 256},
            plus 512 when ``output_512`` is True. With C = base_channels:
                16:  [B, C,       S/2, S/2]
                32:  [B, C,       S,   S]
                64:  [B, C // 2,  2S,  2S]
                128: [B, C // 4,  4S,  4S]
                256: [B, C // 8,  8S,  8S]
                512: [B, C // 16, 16S, 16S]   (only if output_512)

            NOTE: the keys are fixed labels. They equal the actual map
            resolution only when spatial_size == 32.

        Raises:
            ValueError: if ``uni_features`` is not 3-D with exactly S*S tokens
                (previously this surfaced as an opaque reshape/permute error).
        """
        S = self.spatial_size
        if uni_features.dim() != 3 or uni_features.shape[1] != S * S:
            raise ValueError(
                f"expected uni_features of shape [B, {S * S}, uni_dim] "
                f"(spatial_size={S}), got {tuple(uni_features.shape)}"
            )
        B = uni_features.shape[0]

        # [B, S*S, uni_dim] -> [B, S*S, C] -> [B, C, S*S] -> [B, C, S, S]
        x = self.proj(uni_features)
        x = x.permute(0, 2, 1).reshape(B, self.base_channels, S, S)

        # Refine at the native grid with a residual connection.
        feat_32 = self.proc_32(x) + x

        # Downsample for bottleneck conditioning.
        feat_16 = self.down_16(feat_32)    # [B, C, S/2, S/2]

        # Upsample path, each level derived from the refined token grid.
        feat_64 = self.up_64(feat_32)      # [B, C // 2, 2S, 2S]
        feat_128 = self.up_128(feat_64)    # [B, C // 4, 4S, 4S]
        feat_256 = self.up_256(feat_128)   # [B, C // 8, 8S, 8S]

        out = {
            16: feat_16,
            32: feat_32,
            64: feat_64,
            128: feat_128,
            256: feat_256,
        }

        if self.output_512:
            out[512] = self.up_512(feat_256)  # [B, C // 16, 16S, 16S]

        return out