# GLAP / convert_checkpoint.py
# Author: Heinrich Dinkel
# Updated GLAP, removed all dependencies to sonar.
# Commit: a09cac7
#!/usr/bin/env python3
"""Convert GLAP checkpoint to HuggingFace safetensors format.
Usage:
python convert_checkpoint.py <input_checkpoint.pt> [output_dir]
The output_dir defaults to the current directory.
Produces: model.safetensors + config.json
"""
import argparse
import json
import sys
from pathlib import Path
import torch
def convert_state_dict(old_state_dict: dict) -> dict:
    """Translate original GLAP checkpoint keys into the HuggingFace layout.

    Original: audio_encoder.model.* (DashengWrapper wrapping dasheng_base)
    HuggingFace: audio_encoder.* (DashengAudioEncoder directly)
    Original: text_encoder.model.* (TextEncoderSonarWrapper wrapping SonarTextEncoder)
    HuggingFace: text_encoder.* (SonarTextEncoder directly)
    """
    # Wrapper prefixes collapse to the bare module name; projection heads
    # (audio_proj.*, text_proj.*) pass through unchanged.
    renames = (
        ("audio_encoder.model.", "audio_encoder."),
        ("text_encoder.model.", "text_encoder."),
    )
    passthrough = ("audio_proj.", "text_proj.")
    converted = {}
    for old_key, tensor in old_state_dict.items():
        # outputlayer is an Identity module with no learnable parameters.
        if "outputlayer" in old_key:
            continue
        for prefix, replacement in renames:
            if old_key.startswith(prefix):
                converted[replacement + old_key[len(prefix):]] = tensor
                break
        else:
            if not old_key.startswith(passthrough):
                # Unknown key: keep it, but let the operator know.
                print(f" Warning: unrecognized key: {old_key}", file=sys.stderr)
            converted[old_key] = tensor
    return converted
def extract_config(old_config: dict) -> dict:
    """Build the HuggingFace config dict from a GLAP training config.

    Anything missing from *old_config* falls back to the defaults of the
    pretrained model.
    """
    model_args = old_config.get("model_args", {})
    hf_config = {
        "architectures": ["GlapModel"],
        "auto_map": {
            "AutoConfig": "configuration_glap.GlapConfig",
            "AutoModel": "modeling_glap.GlapModel",
        },
        "model_type": "glap",
    }
    # Audio tower hyperparameters.
    hf_config.update(
        audio_embed_dim=768,
        audio_depth=12,
        audio_num_heads=12,
        patch_size=[64, 4],
        patch_stride=[64, 4],
        target_length=1008,
        sample_rate=old_config.get("sample_rate", 16000),
    )
    # Text tower hyperparameters and shared embedding size.
    hf_config.update(
        text_vocab_size=256206,
        text_model_dim=1024,
        text_num_layers=24,
        text_num_heads=16,
        text_ffn_inner_dim=8192,
        text_max_seq_len=514,
        text_pad_idx=0,
        text_dropout_p=0.1,
        embed_size=model_args.get("embed_size", 1024),
    )
    return hf_config
def main():
    """CLI entry point: load a GLAP checkpoint, convert it, write weights + config."""
    parser = argparse.ArgumentParser(description="Convert GLAP checkpoint to HuggingFace format")
    parser.add_argument("input", help="Path to original glap_checkpoint.pt")
    parser.add_argument(
        "-o", "--output-dir", default=".", help="Output directory (default: current dir)"
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint from {input_path}...")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only run
    # this on checkpoints from a trusted source.
    checkpoint = torch.load(str(input_path), map_location="cpu", weights_only=False)
    if "model" not in checkpoint:
        print("Error: checkpoint does not contain 'model' key", file=sys.stderr)
        sys.exit(1)

    print("Converting state dict...")
    old_state_dict = checkpoint["model"]
    new_state_dict = convert_state_dict(old_state_dict)
    print(f" Original keys: {len(old_state_dict)}")
    print(f" Converted keys: {len(new_state_dict)}")

    # Prefer safetensors; fall back to a plain PyTorch dump if unavailable.
    try:
        from safetensors.torch import save_file
    except ImportError:
        fallback_path = output_dir / "pytorch_model.bin"
        print(f"safetensors not installed, saving as {fallback_path}...")
        torch.save(new_state_dict, str(fallback_path))
        print(" Done. Install safetensors for HuggingFace compatibility: pip install safetensors")
    else:
        safetensors_path = output_dir / "model.safetensors"
        print(f"Saving safetensors to {safetensors_path}...")
        save_file(new_state_dict, str(safetensors_path))
        print(" Done.")

    # Derive the HF config from the training config when present.
    if "config" in checkpoint:
        config = extract_config(checkpoint["config"])
    else:
        print("Warning: no config in checkpoint, using defaults", file=sys.stderr)
        config = extract_config({})
    config_path = output_dir / "config.json"
    print(f"Saving config to {config_path}...")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    print("Conversion complete!")
    print(f"Files in {output_dir}:")
    for entry in sorted(output_dir.iterdir()):
        if entry.suffix not in (".safetensors", ".bin", ".json"):
            continue
        size = entry.stat().st_size
        print(f" {entry.name}: {size / 1024 / 1024:.1f} MB")
# Run the converter only when executed as a script (not on import).
if __name__ == "__main__":
    main()