IrohXu committed
Commit de06036 · 1 Parent(s): 487b94e

Update Weight

Files changed (4)
  1. config.json +53 -0
  2. configuration_sapiens.py +127 -0
  3. model.safetensors +3 -0
  4. modeling_sapiens.py +621 -0
config.json ADDED
@@ -0,0 +1,53 @@
+ {
+   "architectures": [
+     "SapiensGaitModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_sapiens.SapiensGaitConfig",
+     "AutoModel": "modeling_sapiens.SapiensGaitModel"
+   },
+   "model_type": "sapiens_gait",
+   "image_size": [
+     1024,
+     768
+   ],
+   "patch_size": 16,
+   "patch_padding": 2,
+   "image_mean": [
+     123.675,
+     116.28,
+     103.53
+   ],
+   "image_std": [
+     58.395,
+     57.12,
+     57.375
+   ],
+   "in_channels": 3,
+   "embed_dims": 1024,
+   "num_layers": 24,
+   "num_heads": 16,
+   "feedforward_channels": 4096,
+   "drop_rate": 0.0,
+   "drop_path_rate": 0.0,
+   "qkv_bias": true,
+   "num_keypoints": 133,
+   "deconv_out_channels": [
+     768,
+     768
+   ],
+   "deconv_kernel_sizes": [
+     4,
+     4
+   ],
+   "conv_out_channels": [
+     768,
+     768
+   ],
+   "conv_kernel_sizes": [
+     1,
+     1
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.0"
+ }
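Because config.json wires AutoConfig and AutoModel to the custom classes via auto_map, the checkpoint can be loaded directly with trust_remote_code. A minimal sketch, assuming the repo id is a placeholder for the actual Hub path:

    from transformers import AutoConfig, AutoModel

    repo = "IrohXu/sapiens-gait"  # placeholder; substitute the real repo id or local path
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    print(config.embed_dims, config.num_layers)  # 1024, 24 for this checkpoint (sapiens_0.3b)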
configuration_sapiens.py ADDED
@@ -0,0 +1,127 @@
+ """Sapiens Gait Model Configuration (Pure HuggingFace, no OpenMMLab dependency)."""
+
+ from transformers import PretrainedConfig
+
+ # Pre-defined architecture variants
+ SAPIENS_ARCH_ZOO = {
+     "sapiens_0.3b": {
+         "embed_dims": 1024,
+         "num_layers": 24,
+         "num_heads": 16,
+         "feedforward_channels": 4096,
+     },
+     "sapiens_0.6b": {
+         "embed_dims": 1280,
+         "num_layers": 32,
+         "num_heads": 16,
+         "feedforward_channels": 5120,
+     },
+     "sapiens_1b": {
+         "embed_dims": 1536,
+         "num_layers": 40,
+         "num_heads": 24,
+         "feedforward_channels": 6144,
+     },
+     "sapiens_2b": {
+         "embed_dims": 1920,
+         "num_layers": 48,
+         "num_heads": 32,
+         "feedforward_channels": 7680,
+     },
+ }
+
+
+ class SapiensGaitConfig(PretrainedConfig):
+     """Configuration class for the Sapiens Gait pose estimation model.
+
+     This configuration stores all architecture parameters needed to build
+     the Sapiens model natively in PyTorch/HuggingFace without any OpenMMLab
+     dependency.
+
+     Args:
+         arch (str, optional): Architecture variant name. One of
+             "sapiens_0.3b", "sapiens_0.6b", "sapiens_1b", "sapiens_2b".
+             If provided, overrides embed_dims/num_layers/num_heads/feedforward_channels.
+         image_size (list[int]): Input image size as [height, width].
+             Defaults to [1024, 768].
+         patch_size (int): Patch size for the ViT backbone. Defaults to 16.
+         in_channels (int): Number of input image channels. Defaults to 3.
+         embed_dims (int): Embedding dimension. Defaults to 1920 (sapiens_2b).
+         num_layers (int): Number of transformer layers. Defaults to 48.
+         num_heads (int): Number of attention heads. Defaults to 32.
+         feedforward_channels (int): Hidden dim of the FFN. Defaults to 7680.
+         drop_rate (float): Dropout rate. Defaults to 0.0.
+         drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
+         qkv_bias (bool): Whether to use bias in the QKV projection. Defaults to True.
+         patch_padding (int): Padding for the patch embedding conv. Defaults to 2.
+         num_keypoints (int): Number of output keypoints. Defaults to 17.
+         deconv_out_channels (list[int]): Output channels for deconv layers.
+         deconv_kernel_sizes (list[int]): Kernel sizes for deconv layers.
+         conv_out_channels (list[int]): Output channels for conv layers in the head.
+         conv_kernel_sizes (list[int]): Kernel sizes for conv layers in the head.
+         image_mean (list[float]): Normalization mean (RGB).
+         image_std (list[float]): Normalization std (RGB).
+     """
+
+     model_type = "sapiens_gait"
+
+     def __init__(
+         self,
+         arch=None,
+         image_size=None,
+         patch_size=16,
+         in_channels=3,
+         embed_dims=1920,
+         num_layers=48,
+         num_heads=32,
+         feedforward_channels=7680,
+         drop_rate=0.0,
+         drop_path_rate=0.0,
+         qkv_bias=True,
+         patch_padding=2,
+         num_keypoints=17,
+         deconv_out_channels=None,
+         deconv_kernel_sizes=None,
+         conv_out_channels=None,
+         conv_kernel_sizes=None,
+         image_mean=None,
+         image_std=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # Apply architecture preset if specified
+         if arch is not None:
+             if arch not in SAPIENS_ARCH_ZOO:
+                 raise ValueError(
+                     f"Unknown arch '{arch}'. Choose from: {list(SAPIENS_ARCH_ZOO.keys())}"
+                 )
+             preset = SAPIENS_ARCH_ZOO[arch]
+             embed_dims = preset["embed_dims"]
+             num_layers = preset["num_layers"]
+             num_heads = preset["num_heads"]
+             feedforward_channels = preset["feedforward_channels"]
+
+         # Backbone (ViT) parameters
+         self.image_size = image_size if image_size is not None else [1024, 768]
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.embed_dims = embed_dims
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.feedforward_channels = feedforward_channels
+         self.drop_rate = drop_rate
+         self.drop_path_rate = drop_path_rate
+         self.qkv_bias = qkv_bias
+         self.patch_padding = patch_padding
+
+         # Head parameters
+         self.num_keypoints = num_keypoints
+         self.deconv_out_channels = list(deconv_out_channels) if deconv_out_channels is not None else [768, 768]
+         self.deconv_kernel_sizes = list(deconv_kernel_sizes) if deconv_kernel_sizes is not None else [4, 4]
+         self.conv_out_channels = list(conv_out_channels) if conv_out_channels is not None else [768, 768]
+         self.conv_kernel_sizes = list(conv_kernel_sizes) if conv_kernel_sizes is not None else [1, 1]
+
+         # Preprocessing (for reference; user applies externally)
+         self.image_mean = image_mean if image_mean is not None else [123.675, 116.28, 103.53]
+         self.image_std = image_std if image_std is not None else [58.395, 57.12, 57.375]
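The arch presets let you switch variants without retyping the backbone dimensions. A quick sketch of the override behavior, grounded in the code above:

    from configuration_sapiens import SapiensGaitConfig

    # When arch is set, the preset overrides embed_dims / num_layers /
    # num_heads / feedforward_channels; other arguments pass through.
    cfg = SapiensGaitConfig(arch="sapiens_0.3b", num_keypoints=133)
    assert (cfg.embed_dims, cfg.num_layers, cfg.num_heads) == (1024, 24, 16)

    cfg_big = SapiensGaitConfig(arch="sapiens_2b")
    assert cfg_big.feedforward_channels == 7680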
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3276793c19e48030c1770ff356e2b2361ec4c10b6e6b4f26d1210ca55dc1044
+ size 1318227164
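This is a Git LFS pointer, not the weights themselves; the oid is the SHA-256 of the real file. After pulling the LFS object, an integrity check could look like this sketch (the local path is an assumption):

    import hashlib

    path = "model.safetensors"  # hypothetical local path to the pulled LFS object
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    print(h.hexdigest() == "a3276793c19e48030c1770ff356e2b2361ec4c10b6e6b4f26d1210ca55dc1044")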
modeling_sapiens.py ADDED
@@ -0,0 +1,621 @@
+ """
+ Sapiens Gait Model — Pure HuggingFace Transformers implementation.
+
+ No dependency on mmengine, mmcv, mmpose, or mmpretrain.
+ Weight key names are designed to exactly match the original OpenMMLab
+ checkpoint layout so existing safetensors can be loaded directly.
+
+ Architecture:
+     SapiensGaitModel (PreTrainedModel)
+     └── backbone (SapiensTopdownPoseEstimator)
+         ├── backbone (SapiensVisionTransformer)
+         │   ├── patch_embed.projection (Conv2d)
+         │   ├── pos_embed (Parameter)
+         │   ├── layers[i].ln1 (LayerNorm)
+         │   ├── layers[i].attn.qkv (Linear)
+         │   ├── layers[i].attn.proj (Linear)
+         │   ├── layers[i].ln2 (LayerNorm)
+         │   ├── layers[i].ffn.layers (Sequential)
+         │   └── ln1 (final LayerNorm)
+         └── head (SapiensHeatmapHead)
+             ├── deconv_layers (Sequential[ConvTranspose2d, InstanceNorm2d, SiLU, ...])
+             ├── conv_layers (Sequential[Conv2d, InstanceNorm2d, SiLU, ...])
+             └── final_layer (Conv2d 1x1)
+ """
+
+ import math
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import ModelOutput
+
+ try:
+     from .configuration_sapiens import SapiensGaitConfig
+ except ImportError:
+     from configuration_sapiens import SapiensGaitConfig
+
+
+ # ---------------------------------------------------------------------------
+ # Utility: Stochastic Depth (DropPath)
+ # ---------------------------------------------------------------------------
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample."""
+
+     def __init__(self, drop_prob: float = 0.0):
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if not self.training or self.drop_prob == 0.0:
+             return x
+         keep_prob = 1.0 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = torch.rand(shape, dtype=x.dtype, device=x.device)
+         random_tensor = torch.floor_(random_tensor + keep_prob)
+         return x.div(keep_prob) * random_tensor
+
+     def extra_repr(self) -> str:
+         return f"drop_prob={self.drop_prob}"
+
+
+ # ---------------------------------------------------------------------------
+ # Utility: Resize positional embedding via interpolation
+ # ---------------------------------------------------------------------------
+ def resize_pos_embed(
+     pos_embed: torch.Tensor,
+     src_shape: Tuple[int, int],
+     dst_shape: Tuple[int, int],
+     mode: str = "bicubic",
+     num_extra_tokens: int = 0,
+ ) -> torch.Tensor:
+     """Resize positional embedding from *src_shape* to *dst_shape*.
+
+     Works on a (1, N, C) tensor where N = num_extra_tokens + H*W.
+     """
+     if src_shape == dst_shape:
+         return pos_embed
+
+     extra_tokens = pos_embed[:, :num_extra_tokens] if num_extra_tokens > 0 else None
+     patch_pos_embed = pos_embed[:, num_extra_tokens:]  # (1, H_s*W_s, C)
+
+     src_h, src_w = src_shape
+     dst_h, dst_w = dst_shape
+     C = patch_pos_embed.shape[-1]
+
+     # (1, H_s*W_s, C) -> (1, C, H_s, W_s)
+     patch_pos_embed = patch_pos_embed.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2).float()
+     # Interpolate to (1, C, H_d, W_d)
+     patch_pos_embed = F.interpolate(patch_pos_embed, size=(dst_h, dst_w), mode=mode, align_corners=False)
+     # (1, C, H_d, W_d) -> (1, H_d*W_d, C)
+     patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, C)
+
+     if extra_tokens is not None:
+         patch_pos_embed = torch.cat([extra_tokens.float(), patch_pos_embed], dim=1)
+
+     return patch_pos_embed.to(pos_embed.dtype)
+
+
+ # ---------------------------------------------------------------------------
+ # Patch Embedding
+ # ---------------------------------------------------------------------------
+ class SapiensPatchEmbed(nn.Module):
+     """Image-to-patch embedding using a single Conv2d.
+
+     Matches the mmcv ``PatchEmbed`` weight layout::
+
+         patch_embed.projection.weight (embed_dims, in_channels, kH, kW)
+         patch_embed.projection.bias (embed_dims,)
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         embed_dims: int,
+         kernel_size: int,
+         stride: int,
+         padding: int,
+         input_size: Tuple[int, int],
+     ):
+         super().__init__()
+         self.projection = nn.Conv2d(
+             in_channels,
+             embed_dims,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+         )
+         # Pre-compute the initial output resolution (used for pos_embed sizing)
+         h, w = input_size
+         out_h = (h + 2 * padding - kernel_size) // stride + 1
+         out_w = (w + 2 * padding - kernel_size) // stride + 1
+         self.init_out_size = (out_h, out_w)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
+         x = self.projection(x)  # (B, C, H_out, W_out)
+         out_size = (x.shape[2], x.shape[3])
+         x = x.flatten(2).transpose(1, 2)  # (B, N, C)
+         return x, out_size
+
+
+ # ---------------------------------------------------------------------------
+ # Multi-Head Self-Attention (with fused QKV)
+ # ---------------------------------------------------------------------------
+ class SapiensAttention(nn.Module):
+     """Multi-head self-attention with fused QKV linear.
+
+     Weight layout::
+
+         attn.qkv.weight (3*embed_dims, embed_dims)
+         attn.qkv.bias (3*embed_dims,)
+         attn.proj.weight (embed_dims, embed_dims)
+         attn.proj.bias (embed_dims,)
+     """
+
+     def __init__(
+         self,
+         embed_dims: int,
+         num_heads: int,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         drop_path_rate: float = 0.0,
+         qkv_bias: bool = True,
+     ):
+         super().__init__()
+         self.embed_dims = embed_dims
+         self.num_heads = num_heads
+         self.head_dims = embed_dims // num_heads
+         self.scale = self.head_dims ** -0.5
+
+         self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(embed_dims, embed_dims)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.out_drop = DropPath(drop_path_rate)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, self.head_dims)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv.unbind(0)  # each (B, heads, N, head_dim)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         x = self.out_drop(x)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # Feed-Forward Network (matches mmcv FFN sequential layout)
+ # ---------------------------------------------------------------------------
+ class SapiensFFN(nn.Module):
+     """Two-layer MLP with GELU activation.
+
+     The internal ``self.layers`` is structured as::
+
+         Sequential(
+             Sequential(Linear, GELU, Dropout),  # index 0
+             Linear,                             # index 1
+             Dropout,                            # index 2
+         )
+
+     This ensures the weight keys are::
+
+         ffn.layers.0.0.weight (fc1 weight)
+         ffn.layers.0.0.bias (fc1 bias)
+         ffn.layers.1.weight (fc2 weight)
+         ffn.layers.1.bias (fc2 bias)
+     """
+
+     def __init__(
+         self,
+         embed_dims: int,
+         feedforward_channels: int,
+         ffn_drop: float = 0.0,
+         drop_path_rate: float = 0.0,
+     ):
+         super().__init__()
+         self.layers = nn.Sequential(
+             nn.Sequential(
+                 nn.Linear(embed_dims, feedforward_channels),
+                 nn.GELU(),
+                 nn.Dropout(ffn_drop),
+             ),
+             nn.Linear(feedforward_channels, embed_dims),
+             nn.Dropout(ffn_drop),
+         )
+         self.dropout_layer = DropPath(drop_path_rate)
+
+     def forward(self, x: torch.Tensor, identity: Optional[torch.Tensor] = None) -> torch.Tensor:
+         out = self.layers(x)
+         if identity is None:
+             identity = x
+         return identity + self.dropout_layer(out)
+
+
+ # ---------------------------------------------------------------------------
+ # Transformer Encoder Layer
+ # ---------------------------------------------------------------------------
+ class SapiensTransformerLayer(nn.Module):
+     """Pre-norm Transformer encoder layer.
+
+     Architecture::
+
+         x = x + attn(ln1(x))
+         x = ffn(ln2(x), identity=x)  # residual handled inside FFN
+     """
+
+     def __init__(
+         self,
+         embed_dims: int,
+         num_heads: int,
+         feedforward_channels: int,
+         drop_rate: float = 0.0,
+         attn_drop_rate: float = 0.0,
+         drop_path_rate: float = 0.0,
+         qkv_bias: bool = True,
+     ):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(embed_dims, eps=1e-6)
+         self.attn = SapiensAttention(
+             embed_dims=embed_dims,
+             num_heads=num_heads,
+             attn_drop=attn_drop_rate,
+             proj_drop=drop_rate,
+             drop_path_rate=drop_path_rate,
+             qkv_bias=qkv_bias,
+         )
+         self.ln2 = nn.LayerNorm(embed_dims, eps=1e-6)
+         self.ffn = SapiensFFN(
+             embed_dims=embed_dims,
+             feedforward_channels=feedforward_channels,
+             ffn_drop=drop_rate,
+             drop_path_rate=drop_path_rate,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.attn(self.ln1(x))
+         x = self.ffn(self.ln2(x), identity=x)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # Vision Transformer Backbone
+ # ---------------------------------------------------------------------------
+ class SapiensVisionTransformer(nn.Module):
+     """Sapiens Vision Transformer backbone (no CLS token, feature-map output).
+
+     Key weight names::
+
+         patch_embed.projection.{weight,bias}
+         pos_embed
+         layers.{i}.ln1.{weight,bias}
+         layers.{i}.attn.qkv.{weight,bias}
+         layers.{i}.attn.proj.{weight,bias}
+         layers.{i}.ln2.{weight,bias}
+         layers.{i}.ffn.layers.0.0.{weight,bias}
+         layers.{i}.ffn.layers.1.{weight,bias}
+         ln1.{weight,bias}
+     """
+
+     def __init__(self, config: SapiensGaitConfig):
+         super().__init__()
+         self.embed_dims = config.embed_dims
+         self.num_layers = config.num_layers
+
+         # Patch embedding
+         self.patch_embed = SapiensPatchEmbed(
+             in_channels=config.in_channels,
+             embed_dims=config.embed_dims,
+             kernel_size=config.patch_size,
+             stride=config.patch_size,
+             padding=config.patch_padding,
+             input_size=tuple(config.image_size),
+         )
+         self.patch_resolution = self.patch_embed.init_out_size
+         num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+         # Positional embedding (no CLS token)
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, config.embed_dims))
+         self.drop_after_pos = nn.Dropout(p=config.drop_rate)
+
+         # Stochastic depth schedule
+         dpr = np.linspace(0, config.drop_path_rate, config.num_layers).tolist()
+
+         # Transformer encoder layers
+         self.layers = nn.ModuleList(
+             [
+                 SapiensTransformerLayer(
+                     embed_dims=config.embed_dims,
+                     num_heads=config.num_heads,
+                     feedforward_channels=config.feedforward_channels,
+                     drop_rate=config.drop_rate,
+                     attn_drop_rate=0.0,
+                     drop_path_rate=dpr[i],
+                     qkv_bias=config.qkv_bias,
+                 )
+                 for i in range(config.num_layers)
+             ]
+         )
+
+         # Final LayerNorm
+         self.ln1 = nn.LayerNorm(config.embed_dims, eps=1e-6)
+
+     # ---- Load hook: resize / strip CLS from saved pos_embed if needed -----
+     def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
+         name = prefix + "pos_embed"
+         if name not in state_dict:
+             return
+
+         ckpt_pe = state_dict[name]
+         model_pe = self.pos_embed
+
+         # If the checkpoint has one extra token (CLS) but the model doesn't -> strip it
+         if ckpt_pe.shape[1] == model_pe.shape[1] + 1:
+             state_dict[name] = ckpt_pe[:, 1:]
+             ckpt_pe = state_dict[name]
+         elif ckpt_pe.shape[1] != model_pe.shape[1] and ckpt_pe.shape[1] % 2 == 1:
+             # An odd number of tokens likely means a CLS token is present
+             state_dict[name] = ckpt_pe[:, 1:]
+             ckpt_pe = state_dict[name]
+
+         # If the spatial resolution differs -> interpolate
+         if ckpt_pe.shape != model_pe.shape:
+             num_ckpt_patches = ckpt_pe.shape[1]
+             ckpt_h = ckpt_w = int(math.sqrt(num_ckpt_patches))
+             if ckpt_h * ckpt_w != num_ckpt_patches:
+                 # Non-square: try to infer the grid from the aspect ratio.
+                 # Fallback: assume the same aspect ratio as the model.
+                 ratio = self.patch_resolution[0] / self.patch_resolution[1]
+                 ckpt_h = int(math.sqrt(num_ckpt_patches * ratio))
+                 ckpt_w = num_ckpt_patches // ckpt_h
+             state_dict[name] = resize_pos_embed(
+                 ckpt_pe,
+                 (ckpt_h, ckpt_w),
+                 self.patch_resolution,
+                 mode="bicubic",
+                 num_extra_tokens=0,
+             )
+
+     def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+         B = x.shape[0]
+
+         # Patch embedding
+         x, patch_resolution = self.patch_embed(x)
+
+         # Add the (possibly interpolated) positional embedding
+         x = x + resize_pos_embed(
+             self.pos_embed,
+             self.patch_resolution,
+             patch_resolution,
+             mode="bicubic",
+             num_extra_tokens=0,
+         )
+         x = self.drop_after_pos(x)
+
+         hidden_states = []
+         # Transformer layers
+         for layer in self.layers:
+             x = layer(x)
+             hidden_states.append(x)  # save each layer's output for downstream use
+
+         # Final norm
+         x = self.ln1(x)
+
+         # Reshape to a 2-D feature map: (B, N, C) -> (B, C, H, W)
+         x_mapping = x.reshape(B, *patch_resolution, -1).permute(0, 3, 1, 2)
+
+         outputs = {
+             "feat_1d_tokens": x,  # (B, N, C)
+             "feat_2d_tokens": x_mapping,  # (B, C, H, W)
+             "hidden_states": hidden_states,  # list of (B, N, C), one per layer
+         }
+
+         return outputs
+
+
+ # ---------------------------------------------------------------------------
+ # Heatmap Head (deconv upsampler)
+ # ---------------------------------------------------------------------------
+ class SapiensHeatmapHead(nn.Module):
+     """Simple-Baselines-style heatmap head with deconv + conv layers.
+
+     Uses ``InstanceNorm2d + SiLU`` (matching the Sapiens default ``use_silu=True``).
+
+     Key weight names::
+
+         deconv_layers.{0,3,...}.weight (ConvTranspose2d, no bias)
+         conv_layers.{0,3,...}.{weight,bias} (Conv2d)
+         final_layer.{weight,bias} (Conv2d 1×1)
+     """
+
+     def __init__(self, config: SapiensGaitConfig):
+         super().__init__()
+
+         in_channels = config.embed_dims
+
+         # --- Deconv (transposed-conv) upsampling layers ---
+         deconv_layers: List[nn.Module] = []
+         for out_ch, ks in zip(config.deconv_out_channels, config.deconv_kernel_sizes):
+             if ks == 4:
+                 pad, opad = 1, 0
+             elif ks == 3:
+                 pad, opad = 1, 1
+             elif ks == 2:
+                 pad, opad = 0, 0
+             else:
+                 raise ValueError(f"Unsupported deconv kernel size {ks}")
+             deconv_layers.append(
+                 nn.ConvTranspose2d(in_channels, out_ch, kernel_size=ks, stride=2, padding=pad, output_padding=opad, bias=False)
+             )
+             deconv_layers.append(nn.InstanceNorm2d(out_ch))
+             deconv_layers.append(nn.SiLU(inplace=True))
+             in_channels = out_ch
+         self.deconv_layers = nn.Sequential(*deconv_layers)
+
+         # --- 1×1 (or N×N) conv refinement layers ---
+         conv_layers: List[nn.Module] = []
+         for out_ch, ks in zip(config.conv_out_channels, config.conv_kernel_sizes):
+             pad = (ks - 1) // 2
+             conv_layers.append(
+                 nn.Conv2d(in_channels, out_ch, kernel_size=ks, stride=1, padding=pad)
+             )
+             conv_layers.append(nn.InstanceNorm2d(out_ch))
+             conv_layers.append(nn.SiLU(inplace=True))
+             in_channels = out_ch
+         self.conv_layers = nn.Sequential(*conv_layers)
+
+         # --- Final projection to keypoint heatmaps ---
+         self.final_layer = nn.Conv2d(in_channels, config.num_keypoints, kernel_size=1)
+
+     def forward(self, feats: Union[torch.Tensor, Tuple[torch.Tensor]]) -> torch.Tensor:
+         x = feats[-1] if isinstance(feats, (list, tuple)) else feats  # last feature map, or the tensor itself
+         x = self.deconv_layers(x)
+         x = self.conv_layers(x)
+         x = self.final_layer(x)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # Top-Down Pose Estimator (backbone + head wrapper)
+ # ---------------------------------------------------------------------------
+ class SapiensTopdownPoseEstimator(nn.Module):
+     """Wraps the ViT backbone and heatmap head.
+
+     Named ``backbone`` and ``head`` to match the original OpenMMLab key prefix:
+         backbone.backbone.… -> self.backbone.…
+         backbone.head.…     -> self.head.…
+     """
+
+     def __init__(self, config: SapiensGaitConfig):
+         super().__init__()
+         self.backbone = SapiensVisionTransformer(config)
+         self.head = SapiensHeatmapHead(config)
+
+     def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+         outputs = self.backbone(x)
+         feats_2d_tokens = outputs["feat_2d_tokens"]
+         feats_1d_tokens = outputs["feat_1d_tokens"]
+         hidden_states = outputs["hidden_states"]
+         heatmaps = self.head(feats_2d_tokens)
+
+         outputs = {
+             "heatmaps": heatmaps,
+             "last_hidden_state": feats_1d_tokens,
+             "hidden_states": hidden_states,
+         }
+
+         return outputs
+
+
+ # ---------------------------------------------------------------------------
+ # HuggingFace PreTrainedModel
+ # ---------------------------------------------------------------------------
+ class SapiensGaitModel(PreTrainedModel):
+     """Sapiens pose-estimation model as a HuggingFace ``PreTrainedModel``.
+
+     This implementation is **completely independent** of OpenMMLab and only
+     requires ``torch`` and ``transformers``.
+
+     Usage::
+
+         from transformers import AutoModel, AutoConfig
+
+         config = AutoConfig.from_pretrained("path/to/sapiens_gait_fixed", trust_remote_code=True)
+         model = AutoModel.from_pretrained("path/to/sapiens_gait_fixed", trust_remote_code=True)
+
+         # pixel_values: (B, 3, H, W), already normalised with config.image_mean / image_std
+         out = model(pixel_values)
+         keypoints = out["keypoints"]  # (B, K, 2) normalised to [0, 1]
+         scores = out["scores"]        # (B, K)
+         heatmaps = out["heatmaps"]    # (B, K, Hm, Wm)
+
+     .. note::
+         Input images should be RGB and normalised with::
+
+             pixel = (pixel - mean) / std
+
+         using ``config.image_mean`` and ``config.image_std`` (in 0-255 scale).
+     """
+
+     config_class = SapiensGaitConfig
+
+     def __init__(self, config: SapiensGaitConfig):
+         super().__init__(config)
+         self.backbone = SapiensTopdownPoseEstimator(config)
+
+         # Register the pos_embed resize hook so from_pretrained handles
+         # checkpoints with a different spatial resolution gracefully.
+         self.backbone.backbone._register_load_state_dict_pre_hook(
+             self.backbone.backbone._prepare_pos_embed
+         )
+
+         # Initialize weights (only for freshly created models; from_pretrained
+         # overwrites with checkpoint values).
+         self.post_init()
+
+     def _init_weights(self, module: nn.Module):
+         """Initialize weights following the original Sapiens convention."""
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Conv2d):
+             nn.init.normal_(module.weight, std=0.001)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.ConvTranspose2d):
+             nn.init.normal_(module.weight, std=0.001)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.InstanceNorm2d):
+             if module.weight is not None:
+                 nn.init.ones_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         return_heatmaps: bool = True,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Args:
+             pixel_values: (B, 3, H, W) normalised input images.
+             return_heatmaps: whether to include raw heatmaps in the output.
+
+         Returns:
+             dict with keys ``keypoints`` (B, K, 2), ``scores`` (B, K),
+             ``last_hidden_state`` (B, N, C), and optionally ``heatmaps`` (B, K, Hm, Wm).
+         """
+         outputs = self.backbone(pixel_values)
+         heatmaps = outputs["heatmaps"]
+         feats_1d_tokens = outputs["last_hidden_state"]
+
+         # Decode keypoints from the heatmaps via per-channel argmax
+         B, K, H, W = heatmaps.shape
+         heatmaps_flat = heatmaps.view(B, K, -1)
+         max_scores, idx = torch.max(heatmaps_flat, dim=-1)
+
+         preds_x = (idx % W).float() / W
+         preds_y = (idx // W).float() / H
+
+         keypoints = torch.stack([preds_x, preds_y], dim=-1)
+
+         out = {"keypoints": keypoints, "scores": max_scores, "last_hidden_state": feats_1d_tokens}
+         if return_heatmaps:
+             out["heatmaps"] = heatmaps
+
+         return out
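Putting the pieces together, here is a minimal end-to-end inference sketch. The checkpoint path and test image are placeholders, and the use of PIL for loading is an assumption; the preprocessing follows config.image_mean / config.image_std in 0-255 scale, as the model docstring specifies:

    import numpy as np
    import torch
    from PIL import Image
    from transformers import AutoConfig, AutoModel

    path = "path/to/sapiens_gait_fixed"  # placeholder checkpoint directory
    config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    model = AutoModel.from_pretrained(path, trust_remote_code=True).eval()

    # Resize to (H, W) = config.image_size, then normalise in 0-255 scale.
    h, w = config.image_size
    img = Image.open("person.jpg").convert("RGB").resize((w, h))
    arr = np.asarray(img, dtype=np.float32)
    arr = (arr - np.array(config.image_mean)) / np.array(config.image_std)
    pixel_values = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()

    with torch.no_grad():
        out = model(pixel_values)

    # Keypoints are normalised to [0, 1]; scale back to input-pixel coordinates.
    kpts_px = out["keypoints"] * torch.tensor([w, h])
    print(kpts_px.shape, out["scores"].shape)  # (1, 133, 2), (1, 133) for this checkpoint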