| import torch |
| import torch.nn as nn |
| from omini.rotation.layer import Linear, Rotation |
|
|
def test_rotation_merge():
    """
    Test that merging the rotation adapter is output-preserving.

    Checks three things:
      1. The merged layer produces the same output as the unmerged layer.
      2. Unmerging restores the unmerged output.
      3. Unmerging restores the base layer's original weights.

    Returns:
        bool: True if all three checks pass within tolerance.
    """
    print("=" * 60)
    print("Testing Rotation Layer Merge")
    print("=" * 60)

    # Fixed seed so initialization and the test input are reproducible.
    torch.manual_seed(42)

    # Layer / adapter configuration.
    in_features = 512
    out_features = 1024
    r = 4
    num_rotations = 4
    T = 1.0
    batch_size = 8
    seq_len = 16

    base_layer = nn.Linear(in_features, out_features, bias=True)

    # Wrap the base layer with the rotation adapter under test.
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=T,
        num_rotations=num_rotations,
    )

    x = torch.randn(batch_size, seq_len, in_features)

    print("\n" + "-" * 60)
    print("Test 1: Computing output BEFORE merge")
    print("-" * 60)
    rotation_layer.eval()
    with torch.no_grad():
        output_before = rotation_layer(x)

    print(f"Output shape: {output_before.shape}")
    print(f"Output mean: {output_before.mean().item():.6f}")
    print(f"Output std: {output_before.std().item():.6f}")
    print(f"Output min: {output_before.min().item():.6f}")
    print(f"Output max: {output_before.max().item():.6f}")

    # Snapshot the base weights so we can verify unmerge restores them.
    original_weight = base_layer.weight.data.clone()

    print("\n" + "-" * 60)
    print("Test 2: Merging adapter")
    print("-" * 60)
    rotation_layer.merge(safe_merge=True, adapter_names=["default"])
    print("[OK] Adapter merged successfully")
    print(f"[OK] Merged adapters: {rotation_layer.merged_adapters}")

    # Merging should change the base weights (adapter folded in).
    weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight change: {weight_diff:.6e}")

    print("\n" + "-" * 60)
    print("Test 3: Computing output AFTER merge")
    print("-" * 60)
    with torch.no_grad():
        output_after = rotation_layer(x)

    print(f"Output shape: {output_after.shape}")
    print(f"Output mean: {output_after.mean().item():.6f}")
    print(f"Output std: {output_after.std().item():.6f}")
    print(f"Output min: {output_after.min().item():.6f}")
    print(f"Output max: {output_after.max().item():.6f}")

    print("\n" + "-" * 60)
    print("Test 4: Comparing outputs")
    print("-" * 60)

    abs_diff = (output_after - output_before).abs()
    # Small epsilon avoids division by zero for near-zero outputs.
    rel_diff = abs_diff / (output_before.abs() + 1e-8)

    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    max_rel_diff = rel_diff.max().item()
    mean_rel_diff = rel_diff.mean().item()

    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")

    # Tolerances for float32 numerics.
    atol = 1e-4
    rtol = 1e-3

    are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

    if are_close:
        print(f"\n[PASS] Outputs are identical (within atol={atol}, rtol={rtol})")
    else:
        print("\n[FAIL] Outputs differ significantly")
        print(f"  Expected: atol < {atol}, rtol < {rtol}")
        print(f"  Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")

    print("\n" + "-" * 60)
    print("Test 5: Testing unmerge")
    print("-" * 60)
    rotation_layer.unmerge()
    print("[OK] Adapter unmerged")
    print(f"[OK] Merged adapters: {rotation_layer.merged_adapters}")

    with torch.no_grad():
        output_unmerged = rotation_layer(x)

    unmerge_diff = (output_unmerged - output_before).abs().max().item()
    print(f"Max difference after unmerge: {unmerge_diff:.6e}")

    unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
    if unmerge_close:
        print("[PASS] Unmerge restored original behavior")
    else:
        print("[FAIL] Unmerge did not restore original behavior")

    # Weights themselves should also be restored, not just the outputs.
    weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")

    weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
    if weight_restored:
        print("[PASS] Original weights restored")
    else:
        print("[FAIL] Original weights not fully restored")

    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)
    return are_close and unmerge_close and weight_restored
|
|
|
|
def test_multiple_merges():
    """
    Test that repeated merge/unmerge cycles remain output-preserving.

    Runs three merge/unmerge cycles and compares each cycle's outputs
    against the never-merged reference output.

    Returns:
        bool: True if every merge and unmerge matched the reference
        output within tolerance.
    """
    print("\n" + "=" * 60)
    print("Testing Multiple Merge/Unmerge Cycles")
    print("=" * 60)

    # Fixed seed so initialization and the test input are reproducible.
    torch.manual_seed(42)

    in_features = 256
    out_features = 512
    r = 4
    num_rotations = 4

    base_layer = nn.Linear(in_features, out_features, bias=True)
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=1.0,
        num_rotations=num_rotations,
    )

    x = torch.randn(4, 8, in_features)
    rotation_layer.eval()

    # Reference output before any merge has ever happened.
    with torch.no_grad():
        original_output = rotation_layer(x)

    all_passed = True
    for cycle in range(3):
        print(f"\nCycle {cycle + 1}:")

        # Merge: output must stay equal to the reference.
        rotation_layer.merge(safe_merge=True)
        with torch.no_grad():
            merged_output = rotation_layer(x)

        merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
        print(f"  Merge: {'PASS' if merge_close else 'FAIL'}")

        # Unmerge: output must return to the reference.
        rotation_layer.unmerge()
        with torch.no_grad():
            unmerged_output = rotation_layer(x)

        unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
        print(f"  Unmerge: {'PASS' if unmerge_close else 'FAIL'}")

        all_passed = all_passed and merge_close and unmerge_close

    return all_passed
|
|
|
|
def test_with_different_dtypes():
    """
    Test that merging is output-preserving across data types.

    Runs the merge check for float32, float16 and bfloat16, with looser
    tolerances for the half-precision types.

    Returns:
        bool: True if the merged output matched the unmerged output for
        every dtype.
    """
    print("\n" + "=" * 60)
    print("Testing Different Data Types")
    print("=" * 60)

    # Fixed seed so initialization and the test input are reproducible.
    torch.manual_seed(42)

    dtypes = [torch.float32, torch.float16, torch.bfloat16]
    all_passed = True

    for dtype in dtypes:
        print(f"\nTesting with dtype: {dtype}")

        in_features = 256
        out_features = 512
        r = 4
        num_rotations = 4

        # Build a fresh base layer and adapter in the target dtype.
        base_layer = nn.Linear(in_features, out_features, bias=True)
        base_layer = base_layer.to(dtype)

        rotation_layer = Linear(
            base_layer=base_layer,
            adapter_name="default",
            r=r,
            T=1.0,
            num_rotations=num_rotations,
        )
        rotation_layer = rotation_layer.to(dtype)

        x = torch.randn(4, 8, in_features, dtype=dtype)
        rotation_layer.eval()

        with torch.no_grad():
            output_before = rotation_layer(x)
            rotation_layer.merge(safe_merge=True)
            output_after = rotation_layer(x)

        # Reduced-precision dtypes need looser tolerances.
        if dtype == torch.float32:
            atol, rtol = 1e-5, 1e-4
        else:
            # float16 and bfloat16 share the same loose tolerance.
            atol, rtol = 1e-2, 1e-2

        are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

        if are_close:
            print("  PASS")
        else:
            max_diff = (output_after - output_before).abs().max().item()
            print(f"  FAIL (max diff: {max_diff:.6e})")

        all_passed = all_passed and are_close

    return all_passed
|
|
|
|
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("ROTATION LAYER MERGE TEST SUITE")
    print("=" * 60)

    # Run each test and collect its boolean pass/fail result.
    results = {}
    results["basic_merge"] = test_rotation_merge()
    results["multiple_cycles"] = test_multiple_merges()
    results["different_dtypes"] = test_with_different_dtypes()

    print("\n" + "=" * 60)
    print("FINAL SUMMARY")
    print("=" * 60)

    for test_name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"{test_name}: {status}")

    all_passed = all(results.values())
    print("\n" + "=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
    else:
        print("WARNING: SOME TESTS FAILED")
    print("=" * 60)