import os
import random
import unittest

import numpy as np
import torch
from torch.nn.functional import scaled_dot_product_attention

from finetrainers.models.attention_dispatch import (
    AttentionProvider,
    _AttentionProviderRegistry,
    _set_context_parallel_options,
    attention_dispatch,
    attention_provider,
    flash_attn_flash_attention,
    native_cudnn_attention,
    native_efficient_attention,
    native_flash_attention,
)
from finetrainers.parallel.ptd import _EquipartitionSharder


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return int(os.environ.get("WORLD_SIZE", 1))
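

# Single-GPU dispatch tests: each provider is run through attention_dispatch,
# and its output is compared against the first (math-backend) provider's output
# within a provider-specific absolute tolerance.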
class AttentionDispatchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        set_seed(0)

    def test_forward(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        cuda_capability = torch.cuda.get_device_capability()
        query, key, value = self._create_dummy_inputs()
        all_providers = [
            (AttentionProvider._NATIVE_MATH, 0),
            (AttentionProvider.NATIVE, 5e-3),
            (AttentionProvider.FLASH, 5e-3),
            (AttentionProvider.FLASH_VARLEN, 5e-3),
            (AttentionProvider.FLEX, 2e-2),
            (AttentionProvider._NATIVE_CUDNN, 5e-3),
            (AttentionProvider._NATIVE_EFFICIENT, 5e-3),
            (AttentionProvider._NATIVE_FLASH, 5e-3),
            (AttentionProvider.SAGE, 1e-1),
            (AttentionProvider.SAGE_VARLEN, 2e-0),
            (AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, 2e-0),  # TODO: look into the high difference threshold
            (AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, 2e-0),
            (AttentionProvider.XFORMERS, 5e-3),
        ]
        if cuda_capability >= (8, 9):
            all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, 2e-0))
        if cuda_capability >= (9, 0):
            all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA_SM90, 2e-0))
        ref_output = None
        for i, (provider, threshold) in enumerate(all_providers):
            try:
                output = self._check_forward_pass(provider, query, key, value)
                if i == 0:
                    ref_output = output.detach().clone()
                else:
                    self.assertTrue(
                        torch.allclose(output, ref_output, atol=threshold), f"Forward pass mismatch for {provider}"
                    )
            except Exception as e:
                print(f"Warning: Forward pass test failed for {provider} with error: {e}")

    def test_backward(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        query, key, value = self._create_dummy_inputs()
        selected_providers = [
            AttentionProvider.FLASH,
            AttentionProvider.FLASH_VARLEN,
            AttentionProvider.FLEX,
            AttentionProvider.NATIVE,
            AttentionProvider.XFORMERS,
        ]
        ref_output = None
        for i, provider in enumerate(selected_providers):
            try:
                output = self._check_backward_pass(provider, query, key, value)
                if i == 0:
                    ref_output = output.detach().clone()
                else:
                    threshold = 1e-2 if provider == AttentionProvider.FLEX else 1e-3
                    self.assertTrue(
                        torch.allclose(output, ref_output, atol=threshold), f"Backward pass mismatch for {provider}"
                    )
            except Exception as e:
                print(f"Warning: Backward pass test failed for {provider} with error: {e}")

    def _create_dummy_inputs(
        self, batch_size=2, num_heads=8, seq_len=256, head_dim=64, dtype=torch.bfloat16, device="cuda"
    ):
        torch.manual_seed(0)
        query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        return query, key, value

    def _check_forward_pass(self, provider: AttentionProvider, query, key, value):
        kwargs = {}
        if provider == AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA:
            kwargs["pv_accum_dtype"] = "fp32"
        with attention_provider(provider):
            output = attention_dispatch(query, key, value, attention_kwargs=kwargs)
        self.assertIsNotNone(output)
        self.assertEqual(output.shape, query.shape)
        return output

    def _check_backward_pass(self, provider: AttentionProvider, query, key, value):
        query.requires_grad_(True)
        key.requires_grad_(True)
        value.requires_grad_(True)
        with attention_provider(provider):
            output = attention_dispatch(query, key, value)
        loss = output.mean()
        loss.backward()
        self.assertTrue(query.grad is not None)
        self.assertTrue(key.grad is not None)
        self.assertTrue(value.grad is not None)
        query.grad.zero_()
        key.grad.zero_()
        value.grad.zero_()
        return output
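

# Context-parallel (ring attention) tests. These presumably run under a
# multi-process launcher such as torchrun, so that RANK/WORLD_SIZE and NCCL are
# available. Each rank holds an equal shard of the sequence dimension, and
# outputs/gradients are checked against a single-process math-backend SDPA
# reference computed over the full sequence.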
class RingAttentionTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        torch.distributed.init_process_group(backend="nccl")
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
        cls.rank = rank
        cls.world_size = world_size
        torch.cuda.set_device(rank)
        cls.mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))
        set_seed(0)

        cls.batch_size = 2
        cls.num_heads = 8
        cls.seq_len = 256
        cls.head_dim = 64
        cls.dtype = torch.bfloat16
        cls.device = "cuda"

        _AttentionProviderRegistry._set_context_parallel(
            mesh=cls.mesh, convert_to_fp32=True, rotate_method="allgather"
        )
        _set_context_parallel_options(is_causal=False)

        cls.full_query = torch.randn(
            cls.batch_size,
            cls.num_heads,
            cls.seq_len * cls.world_size,
            cls.head_dim,
            dtype=cls.dtype,
            device=cls.device,
            requires_grad=True,
        )
        cls.full_key = torch.randn(
            cls.batch_size,
            cls.num_heads,
            cls.seq_len * cls.world_size,
            cls.head_dim,
            dtype=cls.dtype,
            device=cls.device,
            requires_grad=True,
        )
        cls.full_value = torch.randn(
            cls.batch_size,
            cls.num_heads,
            cls.seq_len * cls.world_size,
            cls.head_dim,
            dtype=cls.dtype,
            device=cls.device,
            requires_grad=True,
        )

        # Ensure all ranks have the same data
        with torch.no_grad():
            torch.distributed.broadcast(cls.full_query, src=0)
            torch.distributed.broadcast(cls.full_key, src=0)
            torch.distributed.broadcast(cls.full_value, src=0)

        # Single-device reference output and gradients over the full sequence
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            reference_output = scaled_dot_product_attention(cls.full_query, cls.full_key, cls.full_value)
        cls.reference_output = reference_output.detach().clone()
        reference_output.sum().backward()

        # Each rank keeps only its shard of the sequence dimension
        cls.query, cls.key, cls.value = (
            _EquipartitionSharder.shard(x, dim=2, mesh=cls.mesh).detach().clone()
            for x in (cls.full_query, cls.full_key, cls.full_value)
        )

    @classmethod
    def tearDownClass(cls):
        torch.distributed.destroy_process_group()

    def _test_forward_native_cudnn_attention(self, atol: float = 1e-3):
        output = native_cudnn_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_forward_native_efficient_attention(self, atol: float = 1e-3):
        output = native_efficient_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_forward_native_flash_attention(self, atol: float = 1e-3):
        output = native_flash_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_forward_flash_attn_flash_attention(self, atol: float = 1e-3):
        output = flash_attn_flash_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_backward_native_cudnn_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = native_cudnn_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
        self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
        self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
        self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

    def _test_backward_native_efficient_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = native_efficient_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
        self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
        self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
        self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

    def _test_backward_native_flash_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = native_flash_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
        self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
        self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
        self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

    def _test_backward_flash_attn_flash_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = flash_attn_flash_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
        self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
        self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
        self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))
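

# unittest only collects methods whose names start with "test_", so the
# _test_* helpers above are not run on RingAttentionTest directly; the mixin
# below exposes the public test_* entry points with per-backend tolerances.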
class RingAttentionCPTesterMixin:
    def test_forward_native_cudnn_attention(self):
        self._test_forward_native_cudnn_attention(atol=1e-2)

    def test_forward_native_efficient_attention(self):
        self._test_forward_native_efficient_attention(atol=1e-2)

    def test_forward_native_flash_attention(self):
        self._test_forward_native_flash_attention(atol=1e-2)

    def test_forward_flash_attn_flash_attention(self):
        self._test_forward_flash_attn_flash_attention(atol=1e-2)

    def test_backward_native_cudnn_attention(self):
        atol = 1e-2 * self.world_size  # TODO: make bounds more strict
        self._test_backward_native_cudnn_attention(atol=atol)

    def test_backward_native_efficient_attention(self):
        atol = 1e-2 * self.world_size  # TODO: make bounds more strict
        self._test_backward_native_efficient_attention(atol=atol)

    def test_backward_native_flash_attention(self):
        atol = 1e-2 * self.world_size  # TODO: make bounds more strict
        self._test_backward_native_flash_attention(atol=atol)

    def test_backward_flash_attn_flash_attention(self):
        # Seems to require a much higher bound for some reason
        atol = 1.5e-1 * self.world_size  # TODO: make bounds more strict
        self._test_backward_flash_attn_flash_attention(atol=atol)


class RingAttentionCP2Test(RingAttentionTest, RingAttentionCPTesterMixin):
    pass


class RingAttentionCP4Test(RingAttentionTest, RingAttentionCPTesterMixin):
    pass
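

# The two classes above are identical by construction; presumably the launcher
# picks one to match the world size. A hypothetical invocation (test file path
# assumed) might look like:
#   torchrun --nproc_per_node=2 -m pytest tests/test_attention_dispatch.py -k CP2
#   torchrun --nproc_per_node=4 -m pytest tests/test_attention_dispatch.py -k CP4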