#!/usr/bin/env python3
"""Convert GLAP checkpoint to HuggingFace safetensors format.
Usage:
python convert_checkpoint.py <input_checkpoint.pt> [output_dir]
The output_dir defaults to the current directory.
Produces: model.safetensors + config.json
"""
import argparse
import json
import sys
from pathlib import Path
import torch
def convert_state_dict(old_state_dict: dict) -> dict:
    """Map original GLAP state-dict keys onto the HuggingFace layout.

    Original: audio_encoder.model.* (DashengWrapper wrapping dasheng_base)
    HuggingFace: audio_encoder.* (DashengAudioEncoder directly)
    Original: text_encoder.model.* (TextEncoderSonarWrapper wrapping SonarTextEncoder)
    HuggingFace: text_encoder.* (SonarTextEncoder directly)

    Keys containing ``outputlayer`` (an Identity layer with no learnable
    parameters) are dropped; unrecognized keys are kept verbatim with a
    warning on stderr.
    """
    # (old prefix, replacement prefix) pairs; projection heads keep their names.
    prefix_table = (
        ("audio_encoder.model.", "audio_encoder."),
        ("text_encoder.model.", "text_encoder."),
        ("audio_proj.", "audio_proj."),
        ("text_proj.", "text_proj."),
    )
    converted: dict = {}
    for key, tensor in old_state_dict.items():
        # Identity output layer carries no weights worth exporting.
        if "outputlayer" in key:
            continue
        for old_prefix, new_prefix in prefix_table:
            if key.startswith(old_prefix):
                converted[new_prefix + key[len(old_prefix):]] = tensor
                break
        else:
            # Unknown key: keep it so nothing is silently lost, but warn.
            print(f" Warning: unrecognized key: {key}", file=sys.stderr)
            converted[key] = tensor
    return converted
def extract_config(old_config: dict) -> dict:
    """Build the HuggingFace config dict from an original GLAP training config.

    Values not present in *old_config* fall back to defaults that match the
    pretrained model (Dasheng audio tower + SONAR text tower).
    """
    # Only two fields come from the checkpoint; everything else is fixed.
    sample_rate = old_config.get("sample_rate", 16000)
    embed_size = old_config.get("model_args", {}).get("embed_size", 1024)
    return {
        "architectures": ["GlapModel"],
        "auto_map": {
            "AutoConfig": "configuration_glap.GlapConfig",
            "AutoModel": "modeling_glap.GlapModel",
        },
        "model_type": "glap",
        # Audio encoder (Dasheng base) hyperparameters.
        "audio_embed_dim": 768,
        "audio_depth": 12,
        "audio_num_heads": 12,
        "patch_size": [64, 4],
        "patch_stride": [64, 4],
        "target_length": 1008,
        "sample_rate": sample_rate,
        # Text encoder (SONAR) hyperparameters.
        "text_vocab_size": 256206,
        "text_model_dim": 1024,
        "text_num_layers": 24,
        "text_num_heads": 16,
        "text_ffn_inner_dim": 8192,
        "text_max_seq_len": 514,
        "text_pad_idx": 0,
        "text_dropout_p": 0.1,
        "embed_size": embed_size,
    }
def main():
    """CLI entry point: load a GLAP checkpoint, convert, and write outputs."""
    argp = argparse.ArgumentParser(description="Convert GLAP checkpoint to HuggingFace format")
    argp.add_argument("input", help="Path to original glap_checkpoint.pt")
    argp.add_argument(
        "-o", "--output-dir", default=".", help="Output directory (default: current dir)"
    )
    opts = argp.parse_args()

    src = Path(opts.input)
    dest = Path(opts.output_dir)
    dest.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint from {src}...")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    ckpt = torch.load(str(src), map_location="cpu", weights_only=False)
    if "model" not in ckpt:
        print("Error: checkpoint does not contain 'model' key", file=sys.stderr)
        sys.exit(1)

    print("Converting state dict...")
    original_weights = ckpt["model"]
    converted = convert_state_dict(original_weights)
    print(f" Original keys: {len(original_weights)}")
    print(f" Converted keys: {len(converted)}")

    # Prefer safetensors; fall back to a plain PyTorch pickle when missing.
    try:
        from safetensors.torch import save_file
    except ImportError:
        fallback_path = dest / "pytorch_model.bin"
        print(f"safetensors not installed, saving as {fallback_path}...")
        torch.save(converted, str(fallback_path))
        print(" Done. Install safetensors for HuggingFace compatibility: pip install safetensors")
    else:
        safetensors_path = dest / "model.safetensors"
        print(f"Saving safetensors to {safetensors_path}...")
        save_file(converted, str(safetensors_path))
        print(" Done.")

    # Emit config.json, synthesizing defaults when the checkpoint has none.
    if "config" in ckpt:
        cfg = extract_config(ckpt["config"])
    else:
        print("Warning: no config in checkpoint, using defaults", file=sys.stderr)
        cfg = extract_config({})
    cfg_path = dest / "config.json"
    print(f"Saving config to {cfg_path}...")
    with open(cfg_path, "w") as fh:
        json.dump(cfg, fh, indent=2)

    print("Conversion complete!")
    print(f"Files in {dest}:")
    for entry in sorted(dest.iterdir()):
        if entry.suffix in (".safetensors", ".bin", ".json"):
            size = entry.stat().st_size
            print(f" {entry.name}: {size / 1024 / 1024:.1f} MB")


if __name__ == "__main__":
    main()
|