File size: 5,209 Bytes
a09cac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3
"""Convert GLAP checkpoint to HuggingFace safetensors format.

Usage:
    python convert_checkpoint.py <input_checkpoint.pt> [output_dir]

The output_dir defaults to the current directory.
Produces: model.safetensors + config.json
"""

import argparse
import json
import sys
from pathlib import Path

import torch


def convert_state_dict(old_state_dict: dict) -> dict:
    """Translate original GLAP checkpoint keys into the HuggingFace layout.

    The original model nests each encoder inside a wrapper, adding a
    ``.model`` level that the HuggingFace classes do not have:

      * ``audio_encoder.model.*`` -> ``audio_encoder.*``
      * ``text_encoder.model.*``  -> ``text_encoder.*``

    Projection heads (``audio_proj.*`` / ``text_proj.*``) pass through
    unchanged, ``outputlayer`` entries (a parameter-free Identity) are
    dropped, and any unrecognized key is kept verbatim after printing a
    warning to stderr.
    """
    renames = (
        ("audio_encoder.model.", "audio_encoder."),
        ("text_encoder.model.", "text_encoder."),
    )
    passthrough_prefixes = ("audio_proj.", "text_proj.")

    converted = {}
    for name, tensor in old_state_dict.items():
        # Identity output layer carries no learnable parameters — drop it.
        if "outputlayer" in name:
            continue
        for old_prefix, new_prefix in renames:
            if name.startswith(old_prefix):
                converted[new_prefix + name[len(old_prefix):]] = tensor
                break
        else:
            if not name.startswith(passthrough_prefixes):
                # Unknown key: keep it, but flag for manual inspection.
                print(f"  Warning: unrecognized key: {name}", file=sys.stderr)
            converted[name] = tensor

    return converted


def extract_config(old_config: dict) -> dict:
    """Build a HuggingFace ``config.json`` dict from the training config.

    Every architectural constant (encoder depths, vocab size, patch
    geometry, ...) is fixed by the pretrained checkpoint and hard-coded
    here; only ``sample_rate`` and ``model_args.embed_size`` are read from
    *old_config*, falling back to the pretrained defaults when absent.
    """
    # The only two values that may vary between training runs.
    sample_rate = old_config.get("sample_rate", 16000)
    embed_size = old_config.get("model_args", {}).get("embed_size", 1024)

    # Static architecture description matching the pretrained GLAP model.
    return {
        "architectures": ["GlapModel"],
        "auto_map": {
            "AutoConfig": "configuration_glap.GlapConfig",
            "AutoModel": "modeling_glap.GlapModel",
        },
        "model_type": "glap",
        "audio_embed_dim": 768,
        "audio_depth": 12,
        "audio_num_heads": 12,
        "patch_size": [64, 4],
        "patch_stride": [64, 4],
        "target_length": 1008,
        "sample_rate": sample_rate,
        "text_vocab_size": 256206,
        "text_model_dim": 1024,
        "text_num_layers": 24,
        "text_num_heads": 16,
        "text_ffn_inner_dim": 8192,
        "text_max_seq_len": 514,
        "text_pad_idx": 0,
        "text_dropout_p": 0.1,
        "embed_size": embed_size,
    }


def main():
    """CLI entry point: load a GLAP checkpoint, remap keys, write outputs.

    Writes ``model.safetensors`` (or ``pytorch_model.bin`` when the
    safetensors package is unavailable) plus ``config.json`` into the
    requested output directory, then prints a size summary of the files.
    """
    parser = argparse.ArgumentParser(description="Convert GLAP checkpoint to HuggingFace format")
    parser.add_argument("input", help="Path to original glap_checkpoint.pt")
    parser.add_argument(
        "-o", "--output-dir", default=".", help="Output directory (default: current dir)"
    )
    cli = parser.parse_args()

    src = Path(cli.input)
    dst_dir = Path(cli.output_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint from {src}...")
    # weights_only=False: the checkpoint embeds non-tensor objects (config).
    checkpoint = torch.load(str(src), map_location="cpu", weights_only=False)

    if "model" not in checkpoint:
        print("Error: checkpoint does not contain 'model' key", file=sys.stderr)
        sys.exit(1)

    print("Converting state dict...")
    source_weights = checkpoint["model"]
    converted_weights = convert_state_dict(source_weights)
    print(f"  Original keys: {len(source_weights)}")
    print(f"  Converted keys: {len(converted_weights)}")

    # Prefer safetensors; fall back to a plain PyTorch pickle if missing.
    try:
        from safetensors.torch import save_file
    except ImportError:
        save_file = None

    if save_file is not None:
        weights_path = dst_dir / "model.safetensors"
        print(f"Saving safetensors to {weights_path}...")
        save_file(converted_weights, str(weights_path))
        print("  Done.")
    else:
        weights_path = dst_dir / "pytorch_model.bin"
        print(f"safetensors not installed, saving as {weights_path}...")
        torch.save(converted_weights, str(weights_path))
        print("  Done. Install safetensors for HuggingFace compatibility: pip install safetensors")

    # Derive config.json from the stored training config when available.
    if "config" in checkpoint:
        hf_config = extract_config(checkpoint["config"])
    else:
        print("Warning: no config in checkpoint, using defaults", file=sys.stderr)
        hf_config = extract_config({})

    config_path = dst_dir / "config.json"
    print(f"Saving config to {config_path}...")
    with open(config_path, "w") as f:
        json.dump(hf_config, f, indent=2)

    print("Conversion complete!")
    print(f"Files in {dst_dir}:")
    for entry in sorted(dst_dir.iterdir()):
        if entry.suffix in (".safetensors", ".bin", ".json"):
            print(f"  {entry.name}: {entry.stat().st_size / 1024 / 1024:.1f} MB")


if __name__ == "__main__":
    main()