| |
|
|
| import argparse |
| import torch |
| from safetensors import safe_open |
|
|
|
|
def _build_key_map(keys, prefix):
    """Map prefix-stripped tensor names to their original names.

    Args:
        keys: Iterable of original tensor names from one file.
        prefix: Prefix removed from the start of each name when present.

    Returns:
        A ``(mapping, clashes)`` tuple. ``mapping`` maps each normalized name
        to the original name that is used for it. ``clashes`` is a list of
        ``(normalized, shadowed_original, kept_original)`` tuples, one for
        every original name that was displaced because another original name
        normalizes to the same string.
    """
    mapping = {}
    clashes = []
    for key in keys:
        normalized = key.removeprefix(prefix)
        if normalized in mapping:
            clashes.append((normalized, mapping[normalized], key))
        # Last key wins, matching what a plain dict comprehension would do.
        mapping[normalized] = key
    return mapping, clashes


def _tensors_match(tensor_a, tensor_b) -> bool:
    """Return True only if both tensors agree in dtype, shape and values."""
    # A differing dtype or shape is itself a difference between the files,
    # so it is checked explicitly before the element-wise comparison.
    return (
        tensor_a.dtype == tensor_b.dtype
        and tensor_a.shape == tensor_b.shape
        and torch.equal(tensor_a, tensor_b)
    )


def compare_safetensors(
    filepath1: str,
    filepath2: str,
    prefix_to_ignore: str = "model.diffusion_model.",
) -> None:
    """
    Compares two .safetensors files, ignoring a specific prefix on layer names,
    and prints a summary of the differences.

    Tensor names are normalized by stripping ``prefix_to_ignore`` from their
    start; tensors are then matched across the two files by normalized name.
    The report lists names present in only one file and names whose tensors
    differ in dtype, shape or content. If two names inside the same file
    normalize to the same string, a warning is printed and only the last of
    them takes part in the comparison.

    Any exception raised while opening, reading or comparing the files is
    caught and reported on stdout; nothing is re-raised.

    Args:
        filepath1 (str): Path to the first .safetensors file.
        filepath2 (str): Path to the second .safetensors file.
        prefix_to_ignore (str): Prefix stripped from tensor names before
            matching. Defaults to ``"model.diffusion_model."``.

    Returns:
        None. All results are printed.
    """
    results = {
        "only_in_file1": [],
        "only_in_file2": [],
        "different_content": [],
    }

    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")

    try:
        with safe_open(filepath1, framework="pt", device="cpu") as f1, \
                safe_open(filepath2, framework="pt", device="cpu") as f2:

            # normalized name -> original name, plus any names that collided.
            map1, clashes1 = _build_key_map(f1.keys(), prefix_to_ignore)
            map2, clashes2 = _build_key_map(f2.keys(), prefix_to_ignore)

            # Surface collisions instead of silently dropping a tensor.
            for label, clashes in (("File 1", clashes1), ("File 2", clashes2)):
                for norm_key, shadowed, kept in clashes:
                    print(
                        f"⚠️ Warning: {label} keys '{shadowed}' and '{kept}' "
                        f"both normalize to '{norm_key}'; only '{kept}' is compared."
                    )

            # Dict key views support set algebra directly.
            results["only_in_file1"] = sorted(map1.keys() - map2.keys())
            results["only_in_file2"] = sorted(map2.keys() - map1.keys())

            common_normalized_keys = sorted(map1.keys() & map2.keys())
            print(f"Comparing {len(common_normalized_keys)} common tensors...")

            # Tensors are fetched one pair at a time, inside the `with`, while
            # both file handles are still open.
            for norm_key in common_normalized_keys:
                tensor1 = f1.get_tensor(map1[norm_key])
                tensor2 = f2.get_tensor(map2[norm_key])
                if not _tensors_match(tensor1, tensor2):
                    results["different_content"].append(norm_key)

        print("\n" + "=" * 60)
        print("🔍 Safetensor Comparison Results")
        print("=" * 60)
        print(f"File 1: {filepath1}")
        print(f"File 2: {filepath2}")
        print("-" * 60)

        if not any(results.values()):
            print("\n✅ The files are identical after normalization. No differences found.")
            print("=" * 60 + "\n")
            return

        if results["different_content"]:
            print(f"\n↔️ Tensors with Different Content ({len(results['different_content'])}):")
            for norm_key in results["different_content"]:
                print(f" - Normalized Key: {norm_key}")
                print(f" (File 1 Original: {map1[norm_key]})")
                print(f" (File 2 Original: {map2[norm_key]})")

        if results["only_in_file1"]:
            print(f"\n→ Tensors Only in File 1 ({len(results['only_in_file1'])}):")
            for norm_key in results["only_in_file1"]:
                print(f" - Normalized Key: {norm_key} (Original: {map1[norm_key]})")

        if results["only_in_file2"]:
            print(f"\n← Tensors Only in File 2 ({len(results['only_in_file2'])}):")
            for norm_key in results["only_in_file2"]:
                print(f" - Normalized Key: {norm_key} (Original: {map2[norm_key]})")

        print("\n" + "=" * 60 + "\n")

    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}")
    except Exception as e:
        # Top-level boundary of a CLI tool: report the failure, do not crash.
        print(f"❌ An error occurred: {e}")
        print("Please ensure both files are valid .safetensors files.")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: take two positional file paths and compare them.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
    )

    # Both positionals share the same shape; only the name and help text differ.
    positional_args = (
        ("file1", "Path to the first .safetensors file."),
        ("file2", "Path to the second .safetensors file."),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=str, help=arg_help)

    parsed = cli.parse_args()

    compare_safetensors(parsed.file1, parsed.file2)