RAE repo fails when using google/siglip2-so400m-patch14-224 as encoder

#2
by szlgallen - opened

Hello, I’m trying to use this model in the original RAE repo, but I run into an error when I simply swap the encoder to google/siglip2-so400m-patch14-224.

My current config looks like this:
stage_1:
  target: stage1.RAE
  params:
    encoder_cls: 'SigLIP2wNorm'
    encoder_config_path: 'google/siglip2-so400m-patch14-224'
    encoder_input_size: 224
    encoder_params: {'model_name': 'google/siglip2-so400m-patch14-224'}
    decoder_config_path: 'configs/decoder/ViTXL'
    pretrained_decoder_path: '/root/paddlejob/workspace/shenzhelun/nyu-visionx/siglip2_decoder/model.pt'
    noise_tau: 0.
    reshape_to_2d: True
    normalization_stat_path: 'models/stats/siglip2/base_p16_i256/ImageNet1k/stat.pt'

From my understanding, if I switch to a different SigLIP2 encoder, I probably also need to update decoder_config_path to a decoder config that matches the new encoder’s architecture (e.g., hidden size / patch size). Is that correct?

The reported error is:
RuntimeError: Error(s) in loading state_dict for GeneralDecoder:
size mismatch for decoder_pred.weight: copying a param with shape torch.Size([588, 1152]) from checkpoint, the shape in current model is torch.Size([768, 1152]).
size mismatch for decoder_pred.bias: copying a param with shape torch.Size([588]) from checkpoint, the shape in current model is torch.Size([768]).
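To make the mismatch concrete, here is a quick sanity check I put together (the patch size and hidden size come from the model card and the error message; the helper itself is my own sketch, not RAE code):

```python
# Sketch (not RAE code): check what shape the decoder's pixel-prediction head
# should have for a given encoder configuration.
def expected_pred_shape(patch_size: int, hidden_size: int, channels: int = 3):
    # The pixel head maps hidden_size -> patch_size^2 * channels values per token.
    return (patch_size * patch_size * channels, hidden_size)

# siglip2-so400m-patch14-224: patch_size=14, hidden_size=1152
print(expected_pred_shape(14, 1152))  # (588, 1152) -- what the checkpoint holds
print(expected_pred_shape(16, 1152))  # (768, 1152) -- what my model builds
```

So the checkpoint was clearly trained with a patch-14 head while my instantiated decoder uses patch 16.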

VISIONx @ NYU org

Hi,

Yes, we provide the config.json for the decoder in this repo.

Hello, I tried updating my config to use the config provided in this repo, as follows:
stage_1:
  target: stage1.RAE
  params:
    encoder_cls: 'SigLIP2wNorm'
    encoder_config_path: 'google/siglip2-so400m-patch14-224'
    encoder_input_size: 224
    encoder_params: {'model_name': 'google/siglip2-so400m-patch14-224'}
    decoder_config_path: '/root/paddlejob/workspace/shenzhelun/nyu-visionx/siglip2_decoder/config.json'
    pretrained_decoder_path: '/root/paddlejob/workspace/shenzhelun/nyu-visionx/siglip2_decoder/model.pt'
    noise_tau: 0.
    reshape_to_2d: True
    normalization_stat_path: 'models/stats/siglip2/base_p16_i256/ImageNet1k/stat.pt'

but I still get the same parameter shape mismatch error:
File "/root/miniconda3/envs/vace_jt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2624, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for GeneralDecoder:
size mismatch for decoder_pred.weight: copying a param with shape torch.Size([588, 1152]) from checkpoint, the shape in current model is torch.Size([768, 1152]).
size mismatch for decoder_pred.bias: copying a param with shape torch.Size([588]) from checkpoint, the shape in current model is torch.Size([768]).

Is there anything else I need to revise?

VISIONx @ NYU org

Hi @szlgallen ,
The issue is that you're still using a normalization_stat_path that doesn't match the SigLIP2-so400m model.
Looking at the error:
decoder_pred.weight: copying a param with shape torch.Size([588, 1152]) from checkpoint, the shape in current model is torch.Size([768, 1152])
The decoder in our checkpoint outputs 588 dimensions because it's configured for patch_size=14:
14 × 14 × 3 = 588
But your model is expecting 768, which corresponds to patch_size=16:
16 × 16 × 3 = 768
This suggests that somewhere in your config, the decoder is being initialized with patch_size=16 instead of patch_size=14.
Your normalization_stat_path points to:
models/stats/siglip2/base_p16_i256/ImageNet1k/stat.pt
This path contains p16 (patch size 16) and i256 (image size 256). You need to use the correct normalization stats for the SigLIP2-so400m model.
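As a general debugging step (a sketch, not part of the repo), you can diff the parameter shapes between the checkpoint and the freshly built model before calling load_state_dict, which pinpoints exactly which layers disagree:

```python
def shape_mismatches(ckpt_shapes, model_shapes):
    """Return {param_name: (checkpoint_shape, model_shape)} for every
    shared parameter whose shapes disagree."""
    return {
        name: (ckpt_shapes[name], model_shapes[name])
        for name in ckpt_shapes
        if name in model_shapes and ckpt_shapes[name] != model_shapes[name]
    }

# Shapes taken from the error message above:
ckpt = {"decoder_pred.weight": (588, 1152), "decoder_pred.bias": (588,)}
model = {"decoder_pred.weight": (768, 1152), "decoder_pred.bias": (768,)}
print(shape_mismatches(ckpt, model))
```

With real models you would build the two dictionaries as `{k: tuple(v.shape) for k, v in state_dict.items()}` from the loaded checkpoint and from `model.state_dict()`.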
If you're using our Scale-RAE repo, you can download the decoder directly from HuggingFace using the CLI:
python cli.py t2i --prompt "a cat" --decoder-repo nyu-visionx/siglip2_decoder
This will automatically download and configure the decoder correctly.
