#!/usr/bin/env python3
"""Convert GLAP checkpoint to HuggingFace safetensors format.

Usage:
    python convert_checkpoint.py <checkpoint.pt> [-o output_dir]

The output_dir defaults to the current directory.
Produces: model.safetensors (or pytorch_model.bin as fallback) + config.json
"""

import argparse
import json
import sys
from pathlib import Path

import torch

# Key prefixes that pass through unchanged from the original state dict.
_PASSTHROUGH_PREFIXES = ("audio_proj.", "text_proj.")

# Wrapper prefixes whose inner ".model." level is stripped:
#   audio_encoder.model.X -> audio_encoder.X
#   text_encoder.model.X  -> text_encoder.X
_UNWRAP_PREFIXES = ("audio_encoder.model.", "text_encoder.model.")


def convert_state_dict(old_state_dict: dict) -> dict:
    """Map original GLAP state dict keys to HuggingFace format.

    Original:    audio_encoder.model.* (DashengWrapper wrapping dasheng_base)
    HuggingFace: audio_encoder.*       (DashengAudioEncoder directly)

    Original:    text_encoder.model.*  (TextEncoderSonarWrapper wrapping SonarTextEncoder)
    HuggingFace: text_encoder.*        (SonarTextEncoder directly)

    ``outputlayer`` entries (Identity layers, no learnable params) are dropped.
    Unrecognized keys are kept as-is with a warning on stderr.

    Args:
        old_state_dict: state dict loaded from the original checkpoint.

    Returns:
        New dict with remapped keys; tensor values are passed through untouched.
    """
    new_state_dict = {}
    for key, value in old_state_dict.items():
        # Skip outputlayer (Identity layer, no learnable params)
        if "outputlayer" in key:
            continue
        for prefix in _UNWRAP_PREFIXES:
            if key.startswith(prefix):
                # Drop the inner ".model" wrapper level, keep the encoder name.
                # e.g. "audio_encoder.model." -> "audio_encoder."
                outer = prefix[: -len("model.")]
                new_state_dict[outer + key[len(prefix):]] = value
                break
        else:
            if key.startswith(_PASSTHROUGH_PREFIXES):
                # Projection heads keep their names unchanged.
                new_state_dict[key] = value
            else:
                # Unknown key, keep as-is with warning
                print(f"  Warning: unrecognized key: {key}", file=sys.stderr)
                new_state_dict[key] = value
    return new_state_dict


def extract_config(old_config: dict) -> dict:
    """Extract HuggingFace config from original GLAP training config.

    Args:
        old_config: the ``config`` dict stored in the checkpoint (may be empty).

    Returns:
        A config dict for ``config.json``; fields not present in ``old_config``
        fall back to defaults matching the pretrained model.
    """
    model_args = old_config.get("model_args", {})
    # Default values matching the pretrained model
    config = {
        "architectures": ["GlapModel"],
        "auto_map": {
            "AutoConfig": "configuration_glap.GlapConfig",
            "AutoModel": "modeling_glap.GlapModel",
        },
        "model_type": "glap",
        "audio_embed_dim": 768,
        "audio_depth": 12,
        "audio_num_heads": 12,
        "patch_size": [64, 4],
        "patch_stride": [64, 4],
        "target_length": 1008,
        "sample_rate": old_config.get("sample_rate", 16000),
        "text_vocab_size": 256206,
        "text_model_dim": 1024,
        "text_num_layers": 24,
        "text_num_heads": 16,
        "text_ffn_inner_dim": 8192,
        "text_max_seq_len": 514,
        "text_pad_idx": 0,
        "text_dropout_p": 0.1,
        "embed_size": model_args.get("embed_size", 1024),
    }
    return config


def main():
    """CLI entry point: load checkpoint, remap keys, write weights + config."""
    parser = argparse.ArgumentParser(description="Convert GLAP checkpoint to HuggingFace format")
    parser.add_argument("input", help="Path to original glap_checkpoint.pt")
    parser.add_argument(
        "-o", "--output-dir", default=".", help="Output directory (default: current dir)"
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint from {input_path}...")
    # SECURITY: weights_only=False runs arbitrary pickled code. Only use this
    # script on checkpoints from a trusted source.
    checkpoint = torch.load(str(input_path), map_location="cpu", weights_only=False)

    if "model" not in checkpoint:
        print("Error: checkpoint does not contain 'model' key", file=sys.stderr)
        sys.exit(1)

    print("Converting state dict...")
    old_state_dict = checkpoint["model"]
    new_state_dict = convert_state_dict(old_state_dict)
    print(f"  Original keys: {len(old_state_dict)}")
    print(f"  Converted keys: {len(new_state_dict)}")

    # Save as safetensors; fall back to pytorch format if not installed.
    try:
        from safetensors.torch import save_file

        safetensors_path = output_dir / "model.safetensors"
        print(f"Saving safetensors to {safetensors_path}...")
        save_file(new_state_dict, str(safetensors_path))
        print("  Done.")
    except ImportError:
        # Fall back to pytorch format
        pt_path = output_dir / "pytorch_model.bin"
        print(f"safetensors not installed, saving as {pt_path}...")
        torch.save(new_state_dict, str(pt_path))
        print("  Done. Install safetensors for HuggingFace compatibility: pip install safetensors")

    # Save config (defaults only if the checkpoint carried no config).
    if "config" in checkpoint:
        config = extract_config(checkpoint["config"])
    else:
        print("Warning: no config in checkpoint, using defaults", file=sys.stderr)
        config = extract_config({})

    config_path = output_dir / "config.json"
    print(f"Saving config to {config_path}...")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
        f.write("\n")

    print("Conversion complete!")
    print(f"Files in {output_dir}:")
    for p in sorted(output_dir.iterdir()):
        if p.suffix in (".safetensors", ".bin", ".json"):
            size = p.stat().st_size
            print(f"  {p.name}: {size / 1024 / 1024:.1f} MB")


if __name__ == "__main__":
    main()