| | #ifdef TEST_ON_CUDA |
| | #include <mma.h> |
| |
|
| | #include <cuda_fp16.h> |
| | #include <cuda_fp8.h> |
| |
|
| | namespace wmma = nvcuda::wmma; |
| |
|
| | #define LIB_CALL(call) \ |
| | do { \ |
| | cudaError_t err = call; \ |
| | if (err != cudaSuccess) { \ |
| | abort(); \ |
| | } \ |
| | } while (0) |
| |
|
| | #define HOST_TYPE(x) cuda##x |
| |
|
| | #else |
| |
|
| | #ifndef HIP_HEADERS__ |
| | #include <hip/hip_runtime.h> |
| | #include <hip/hip_fp8.h> |
| | #include <hip/hip_fp16.h> |
| | #include <rocwmma/rocwmma.hpp> |
| | #define HIP_HEADERS__ |
| | #endif |
| |
|
| | namespace wmma = rocwmma; |
| |
|
| | #define LIB_CALL(call) \ |
| | do { \ |
| | hipError_t err = call; \ |
| | if (err != hipSuccess) { \ |
| | abort(); \ |
| | } \ |
| | } while (0) |
| |
|
| | #define HOST_TYPE(x) hip##x |
| |
|
| | #endif |
| |
|
| |
|