| import os |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
| import torch |
|
|
def _infer_concat_dim(key, slices):
    """Return the first dimension along which ``slices`` differ in size, or 0.

    Raises:
        ValueError: If the slices do not all have the same number of dims,
            in which case they cannot be concatenated.
    """
    ranks = {s.dim() for s in slices}
    if len(ranks) > 1:
        raise ValueError(
            f"Cannot merge key '{key}': slices have differing ranks {sorted(ranks)}"
        )
    ref_shape = slices[0].shape
    for dim in range(len(ref_shape)):
        if len({s.shape[dim] for s in slices}) > 1:
            return dim
    # All slices have identical shapes: fall back to the leading dimension.
    return 0


def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
    """Merge several safetensors shards into a single safetensors file.

    Every tensor key found across ``sftsr_files`` is collected.
    - A key present in exactly one shard is copied unchanged.
    - A key present in several shards is joined with ``torch.cat``, in the
      order the files are given.

    The concatenation dimension is the first dimension whose size differs
    between the slices, or 0 when all slices have the same shape.
    NOTE(review): equal-shaped slices are therefore always concatenated along
    dim 0, even if they are replicated copies rather than partitions. Confirm
    this matches how the shards were produced.

    Only the metadata of the first shard is written to the output. All
    tensors are held in memory until the output file is written.

    Args:
        sftsr_files: Paths of the shard files, in merge order.
        output_file: Destination path. Missing parent directories are created.

    Raises:
        ValueError: If the slices of one key have differing numbers of
            dimensions.
    """
    slices_dict = {}
    metadata = {}

    for idx, file in enumerate(sftsr_files):
        with safe_open(file, framework="pt") as sf_tsr:
            if idx == 0:
                metadata = sf_tsr.metadata()
            for key in sf_tsr.keys():
                slices_dict.setdefault(key, []).append(sf_tsr.get_tensor(key))

    merged_tensors = {}
    for key, slices in slices_dict.items():
        if len(slices) == 1:
            merged_tensors[key] = slices[0]
            continue
        concat_dim = _infer_concat_dim(key, slices)
        merged_tensors[key] = torch.cat(slices, dim=concat_dim)
        print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")

    # os.path.dirname() returns "" for a bare file name (e.g. the default
    # "model.safetensors"), and os.makedirs("") raises FileNotFoundError.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_file(merged_tensors, output_file, metadata)
    print(f"Merged {len(sftsr_files)} shards into {output_file}")
|
|
def get_safetensor_files(directory):
    """Recursively collect ``.safetensors`` file paths under ``directory``.

    The result is sorted so that the merge order is deterministic. os.walk
    yields entries in arbitrary, filesystem-dependent order, and
    merge_safetensor_files concatenates slices in list order.
    NOTE: sorting is lexicographic on the full path, so shard names need
    zero-padded indices (e.g. ``model-00002-of-00010``) to keep numeric order.

    Args:
        directory: Root directory to search. A missing directory yields [].

    Returns:
        Sorted list of paths (joined onto ``directory``) ending in
        ``.safetensors``.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(directory)
        for name in names
        if name.endswith(".safetensors")
    )
|
|
if __name__ == "__main__":
    # Shards are discovered recursively under this directory.
    shards_dir = "./shards"
    safetensor_files = get_safetensor_files(shards_dir)
    if not safetensor_files:
        # Nothing to merge: stop here instead of writing an empty model file.
        raise SystemExit(f"No .safetensors files found under {shards_dir!r}")
    print(f"The following shards/chunks will be merged: {safetensor_files}")

    default_output = "./output/merged_model.safetensors"
    # An empty answer at the prompt selects the default path.
    user_output = input(f"Enter output file path [{default_output}]: ").strip()
    output_file = user_output or default_output

    merge_safetensor_files(safetensor_files, output_file=output_file)
|
|