cuda source #1
#include <cuda.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <mma.h>

template<unsigned int TILE_ROWS,
         unsigned int TILE_COLS,
         unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpyUnrolledVectorized(
    half* src,
    half* dst,
    const unsigned int src_stride
)
{
    // reinterpret input/output as float4
    float4* src_float4 = reinterpret_cast<float4*>(src);
    float4* dst_float4 = reinterpret_cast<float4*>(dst);
    const unsigned int src_stride_vectorized = src_stride / 8;

    // # of threads is a multiple of # of columns in the tile
    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

    // flatten the 2d grid of threads in order of increasing threadIdx.x
    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

    // assign each thread a row/column in the tile, calculate how many iterations we need
    // to cover the whole tile
    constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
    constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
    const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

    #pragma unroll
    for (unsigned int i = 0; i < NUM_ITERS; i++)
    {
        dst_float4[thread_row * TILE_COLS_VECTORIZED + thread_col] =
            src_float4[thread_row * src_stride_vectorized + thread_col];
        thread_row += ROW_STEP;
    }
}

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer)
{
    uint32_t address;
    asm("{\n\t"
        "  .reg .u64 u64addr;\n\t"
        "  cvta.to.shared.u64 u64addr, %1;\n\t"
        "  cvt.u32.u64 %0, u64addr;\n\t"
        "}"
        : "=r"(address)
        : "l"(pointer));
    return address;
}

__device__ __forceinline__ void stmatrix_m16n8(
    half* dst,
    half (&reg)[4],
    unsigned int dst_stride_bytes
)
{
    const unsigned int laneIdx = threadIdx.x % 32;
    uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(reg);
    uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst);
    dst_stride_bytes /= sizeof(uint32_t);
    unsigned int fragment_row = laneIdx / 4;
    const unsigned int fragment_col = laneIdx % 4;

    // 4 adjacent threads storing 4 bytes each == 16 byte transactions
    dst_ptr[fragment_row * dst_stride_bytes + fragment_col] = reg_[0];
    fragment_row += 8;
    dst_ptr[fragment_row * dst_stride_bytes + fragment_col] = reg_[1];
}

__device__ __forceinline__ void ldmatrix_m16n8_gmem(
    half* src,
    half (&reg)[4],
    unsigned int src_stride_bytes
)
{
    const unsigned int laneIdx = threadIdx.x % 32;
    uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(reg);
    uint32_t* src_ptr = reinterpret_cast<uint32_t*>(src);
    src_stride_bytes /= sizeof(uint32_t);
    unsigned int fragment_row = laneIdx / 4;
    const unsigned int fragment_col = laneIdx % 4;

    // 4 adjacent threads loading 4 bytes each == 16 byte transactions
    reg_[0] = src_ptr[fragment_row * src_stride_bytes + fragment_col];
    fragment_row += 8;
    reg_[1] = src_ptr[fragment_row * src_stride_bytes + fragment_col];
}

// template <unsigned int BM_dim,
//           unsigned int BN_dim,
//           unsigned int BK_dim,
//           unsigned int WM_dim,
//           unsigned int WN_dim,
//           unsigned int WK_dim,
//           unsigned int NUM_THREADS>
__global__ void kernel_2(half* A,
                         half* B,
                         half* C,
                         half* D,
                         const float alpha,
                         const float beta,
                         const unsigned int M,
                         const unsigned int N,
                         unsigned int K)
{
    constexpr unsigned int BM_dim = 256;
    constexpr unsigned int BN_dim = 128;
    constexpr unsigned int BK_dim = 64;

    constexpr unsigned int WM_dim = 64;
    constexpr unsigned int WN_dim = 64;
    constexpr unsigned int WK_dim = 64;
    constexpr unsigned int NUM_THREADS = 256;

    constexpr unsigned int MMA_M_dim = 16;
    constexpr unsigned int MMA_N_dim = 8;
    constexpr unsigned int MMA_K_dim = 8;

    // for convenience/readability in index calculations
    const unsigned int A_stride = K;
    const unsigned int B_stride = N;
    const unsigned int CD_stride = N;

    // loop bounds, constexpr where possible allows for loop unrolling
    constexpr unsigned int mma_tiles_per_warp_k = WK_dim / MMA_K_dim;
    constexpr unsigned int mma_tiles_per_warp_m = WM_dim / MMA_M_dim;
    constexpr unsigned int mma_tiles_per_warp_n = WN_dim / MMA_N_dim;
    constexpr unsigned int warp_tiles_per_block_k = BK_dim / WK_dim;
    const unsigned int num_block_tiles_k = K / BK_dim;

    // calculate block/warp indices
    const unsigned int block_m = blockIdx.y;
    const unsigned int block_n = blockIdx.x;
    const unsigned int warp_m = threadIdx.y;
    const unsigned int warp_n = threadIdx.x / 32;

    extern __shared__ half shmem[];
    half* A_block_smem = shmem;
    half* B_block_smem = &shmem[BM_dim * BK_dim];

    // declare register storage
    // ptx instructions expect uint32_t registers, where each uint32_t is 2 halves packed together
    uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];

    // convenience cast to half for accumulator registers
    half (&acc_register_)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] =
        reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);

    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];

    // accumulators start at 0
    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
    {
        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
        {
            acc_register_[mma_m][mma_n][0] = 0;
            acc_register_[mma_m][mma_n][1] = 0;
            acc_register_[mma_m][mma_n][2] = 0;
            acc_register_[mma_m][mma_n][3] = 0;
        }
    }

    for (unsigned int block_k = 0; block_k < num_block_tiles_k; block_k++)
    {
        half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim);
        half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim);
        tileMemcpyUnrolledVectorized<BM_dim, BK_dim, NUM_THREADS>(A_block_gmem, A_block_smem, K);
        tileMemcpyUnrolledVectorized<BK_dim, BN_dim, NUM_THREADS>(B_block_gmem, B_block_smem, N);
        __syncthreads();

        for (unsigned int warp_k = 0; warp_k < warp_tiles_per_block_k; warp_k++)
        {
            // preload tiles of a into registers
            half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim) + (warp_k * WK_dim);
            half* B_warp_tile = B_block_smem + (warp_k * WK_dim * BN_dim) + (warp_n * WN_dim);
            uint32_t A_warp_tile_byte_offset = cvta_to_shared_u32(A_warp_tile);
            uint32_t B_warp_tile_byte_offset = cvta_to_shared_u32(B_warp_tile);

            for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
            {
                for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++)
                {
                    // byte offset to the top left of the mma tile
                    const unsigned int mma_tile_byte_offset =
                        ((mma_m * MMA_M_dim * BK_dim) + (mma_k * MMA_K_dim)) * sizeof(half);

                    // byte offset to the start of this thread's slice of the mma tile
                    const unsigned int thread_byte_offset =
                        (threadIdx.x % MMA_M_dim) * BK_dim * sizeof(half);

                    // calculate offset in bytes WRT the start of our shared memory allocation
                    const unsigned int thread_offset_bytes =
                        A_warp_tile_byte_offset + mma_tile_byte_offset + thread_byte_offset;

                    asm volatile (
                        "ldmatrix.sync.aligned.m8n8.x2.shared.b16 "
                        "{%0, %1}, [%2];"
                        : "=r"(A_register[mma_m][mma_k][0]), "=r"(A_register[mma_m][mma_k][1])
                        : "r"(thread_offset_bytes)
                    );
                }
            }

            // preload tiles of b into registers
            for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++)
            {
                for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
                {
                    const unsigned int mma_tile_byte_offset =
                        ((mma_k * MMA_K_dim * BN_dim) + (mma_n * MMA_N_dim)) * sizeof(half);
                    const unsigned int thread_byte_offset =
                        (threadIdx.x % MMA_K_dim) * BN_dim * sizeof(half);
                    const unsigned int thread_offset_bytes =
                        B_warp_tile_byte_offset + mma_tile_byte_offset + thread_byte_offset;

                    asm volatile (
                        "ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 "
                        "{%0}, [%1];"
                        : "=r"(B_register[mma_k][mma_n])
                        : "r"(thread_offset_bytes)
                    );
                }
            }

            // outer product between mma tiles
            for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++)
            {
                for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
                {
                    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
                    {
                        asm volatile (
                            "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                            "{%0, %1}, "
                            "{%2, %3}, "
                            "{%4}, "
                            "{%5, %6};"
                            : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
                            : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
                              "r"(B_register[mma_k][mma_n]),
                              "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                        );
                    }
                }
            }
        }
        __syncthreads();
    }

    //////////////
    // epilogue //
    //////////////
    half alpha_ = (half)alpha;
    half beta_ = (half)beta;

    half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];

    // calculate pointers for this warp's C and D tiles
    half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
    half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
    half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
    half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);

    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
    {
        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
        {
            half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
            ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));

            // scale the accumulator by alpha and C by beta
            acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
            acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
            acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
            acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
        }
    }

    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
    {
        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
        {
            half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
            stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
        }
    }
}
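The listing above is a single __global__ half-precision GEMM kernel (D = alpha * A * B + beta * C) plus its device-side helpers. For reference, here is a minimal host-side launch sketch; the launcher name and the assumption that M, N and K divide the tile sizes evenly are not part of the original source, and the grid/block/shared-memory sizes are derived from the constants hard-coded inside kernel_2.

// Hypothetical host-side launcher (not part of the original source).
// Assumes row-major half-precision A (M x K), B (K x N), C/D (M x N) already on the device,
// with M, N, K multiples of the block tile dimensions.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

void launch_kernel_2(half* A, half* B, half* C, half* D,
                     float alpha, float beta,
                     unsigned int M, unsigned int N, unsigned int K)
{
    constexpr unsigned int BM_dim = 256;  // block tile rows   (matches kernel_2)
    constexpr unsigned int BN_dim = 128;  // block tile cols
    constexpr unsigned int BK_dim = 64;   // block tile depth
    constexpr unsigned int WM_dim = 64;   // warp tile rows
    constexpr unsigned int WN_dim = 64;   // warp tile cols

    // One warp per (WM, WN) tile: threadIdx.x covers warps along n (32 threads each),
    // threadIdx.y indexes warps along m, for 64 * 4 = 256 threads = NUM_THREADS.
    dim3 block((BN_dim / WN_dim) * 32, BM_dim / WM_dim);
    dim3 grid(N / BN_dim, M / BM_dim);  // block_n = blockIdx.x, block_m = blockIdx.y

    // Dynamic shared memory holds one A block tile and one B block tile: 48 KiB total.
    size_t smem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * sizeof(half);

    kernel_2<<<grid, block, smem_bytes>>>(A, B, C, D, alpha, beta, M, N, K);
}

Because the 48 KiB of dynamic shared memory sits at the default per-block limit, no opt-in via cudaFuncSetAttribute is needed for this configuration.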