| import torch |
| import megablocks |
|
|
| import unittest |
| from absl.testing import parameterized |
|
|
| |
| |
|
|
|
|
def allclose(x, y, pct=2.0):
    """Return True if at most `pct` percent of elements of x and y differ.

    Elementwise closeness uses torch.isclose with rtol=1e-5. On failure the
    mismatched values and the mismatch percentage are printed for debugging.
    """
    close = torch.isclose(x, y, rtol=1e-5)
    mismatched = ~close
    mismatch_pct = mismatched.sum() / close.numel() * 100
    if mismatch_pct > pct:
        print(x[mismatched], y[mismatched])
        print("{:.2f}% of values not close.".format(mismatch_pct))
        return False
    return True
|
|
|
|
def add_flags(x):
    """Expand each problem tuple with every (trans_b, batch_sizes_on_device) combination.

    Args:
        x: iterable of problem tuples, e.g. (z, m, k, n).

    Returns:
        A list containing, for each input tuple `y`, the four tuples
        `y + (trans_b, batch_sizes_on_device)` over both boolean flags.
    """
    out = []
    for y in x:
        for trans_b in (False, True):
            # Bug fix: batch_sizes_on_device was hard-coded to False, so the
            # on-device batch_sizes branch in the tests was never exercised.
            for batch_sizes_on_device in (False, True):
                out.append(y + (trans_b, batch_sizes_on_device))
    return out
|
|
|
|
# Problem shapes as (z, m, k, n): number of groups/experts, rows per group,
# inner (reduction) dimension, and output columns. add_flags appends the
# transpose/device flag values consumed by the parameterized tests below.
_TEST_PROBLEMS = add_flags((
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    (16, 128, 256, 512),
))
|
|
|
|
def randn(bs, x, y):
    """Return a random bfloat16 CUDA tensor of shape (bs, x, y).

    Values are drawn uniformly from (-1, 1), centered at zero, and scaled
    down by the matrix area so grouped-matmul outputs stay small.
    """
    # Bug fix: the original `rand - 0.5 * 2` subtracted 1.0 (operator
    # precedence), yielding values in (-1, 0) instead of the intended
    # zero-centered (rand - 0.5) * 2.
    out = ((torch.rand(bs, x, y) - 0.5) * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)
|
|
|
|
def gmm(a, b, batch_sizes, trans_b=False):
    """Reference grouped matmul: multiply row-groups of `a` by per-group matrices.

    Rows of `a` are partitioned into consecutive groups of `batch_sizes[i]`
    rows; group i is multiplied by `b[i]` (or `b[i].T` when trans_b) and the
    per-group results are concatenated along dim 0.
    """
    sizes = [int(s) for s in batch_sizes.cpu().numpy()]
    results = []
    for group, lhs in enumerate(torch.split(a, sizes, dim=0)):
        rhs = b[group].t() if trans_b else b[group]
        results.append(lhs @ rhs)
    return torch.cat(results)
|
|
|
|
@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):
    """Compares megablocks grouped GEMM against the reference `gmm` above."""

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        """Every one of the z groups receives exactly m rows."""
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients w.r.t. both operands against the reference.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        """Groups receive random, unequal row counts that still sum to m * z."""
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        # Random size split; any rounding error is dumped into the last
        # group so the sizes always sum to the total row count of `a`.
        dist = torch.rand(z, )
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        # Bug fix: the b.grad comparison was missing here, unlike the
        # fixed-sizes test above, leaving the weight gradient unverified.
        self.assertTrue(allclose(b.grad, b_ref.grad))
|
|
| def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device): |
| torch.manual_seed(0) |
| a = randn(z, m, k).view(-1, k) |
| b = randn(z, n, k) if trans_b else randn(z, k, n) |
|
|
| dist = torch.rand(z, ) |
| dist /= dist.sum() |
| batch_sizes = (dist * m).to(torch.long) |
| error = m * z - batch_sizes.sum() |
| batch_sizes[-1] += error |
| assert batch_sizes.sum() == (m * z) |
| if batch_sizes_on_device: |
| batch_sizes = batch_sizes.cuda() |
|
|
| a.requires_grad_(True) |
| b.requires_grad_(True) |
| a_ref = a.detach().clone().requires_grad_(True) |
| b_ref = b.detach().clone().requires_grad_(True) |
|
|
| out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b) |
| expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b) |
| self.assertTrue(allclose(out, expected_out)) |
|
|
| |
| out.sum().backward() |
| expected_out.sum().backward() |
| self.assertTrue(allclose(a.grad, a_ref.grad)) |
|
|
| |
| |
|
|
|
|
| |
# Bug fix: the parameters were (False, False) — a duplicated value that never
# tested batch_sizes_on_device=True. Also use parameterized.TestCase (as
# OpsTest does): absl's class-level @parameterized.parameters requires it.
@parameterized.parameters(False, True)
class EdgeCasesTest(parameterized.TestCase):
    """Degenerate-shape tests: empty groups and trans_a accumulation."""

    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
        """Some groups receive zero rows; forward and backward must still match."""
        torch.manual_seed(0)
        m = 16384
        k = 4096
        n = 14336
        num_experts = 8

        a = randn(num_experts, m // num_experts, k).view(-1, k)
        b = randn(num_experts, k, n)
        # Hand-picked sizes summing to m, including an empty final group.
        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
        expected_out = gmm(a_ref, b_ref, batch_sizes)
        self.assertTrue(allclose(out, expected_out))

        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
        """Zero-row groups under trans_a must produce zero blocks, not garbage."""
        sz = 128
        total_tokens = 192

        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        # With all-ones inputs each output element should equal the number of
        # rows in its group (and 0 for empty groups, overwriting the ones in c).
        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
        self.assertTrue((c[0] == 0).all())
        self.assertTrue((c[1] == 128).all())
        self.assertTrue((c[2] == 0).all())
        self.assertTrue((c[3] == 64).all())
|
|
|
|
# Run all tests when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
|