cuda source #1
Source code
// NVFP4 Batched GEMV Kernel - PTX-Optimized for sm_100a (Blackwell B200)
// Target: <25µs geometric mean

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <stdint.h>

//==============================================================================
// FP4 E2M1 Decode (2-bit exponent, 1-bit mantissa)
//==============================================================================

// Lookup table in constant memory for FP4 decode
// Index bits: [sign(1) | exp(2) | mant(1)]
// Values: exact per PTX_KERNEL_SPEC.md section 3.1
__constant__ float FP4_LUT[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,  // Positive (0x0-0x7)
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f   // Negative (0x8-0xF)
};

/**
 * Decode FP4 E2M1 value from packed byte
 * @param byte: 8-bit value containing 2 FP4 values (hi/lo nibbles)
 * @param offset: 0 for first FP4 element, 1 for second FP4 element
 * @return: FP32 decoded value
 *
 * NOTE v404: SWAPPED nibble extraction to test if high nibble comes first
 *   offset=0 → high nibble (bits 4-7) = first element
 *   offset=1 → low nibble (bits 0-3) = second element
 */
__device__ __forceinline__ float decode_fp4(uint8_t byte, int offset) {
    return FP4_LUT[(byte >> ((1 - offset) * 4)) & 0xF];
}

//==============================================================================
// FP8 E4M3fnuz Decode (4-bit exponent, 3-bit mantissa, finite, no neg zero)
//==============================================================================

/**
 * Decode FP8 E4M3fnuz value to FP32
 * CRITICAL: Uses integer bit manipulation, NOT __powf() for accuracy
 * Format: [sign(1) | exp(4) | mant(3)]
 * Reference: PTX_KERNEL_SPEC.md section 3.2
 */
__device__ __forceinline__ float decode_fp8_e4m3(uint8_t bits) {
    uint8_t sign = (bits >> 7) & 0x1;
    uint8_t exp  = (bits >> 3) & 0xF;
    uint8_t mant = bits & 0x7;

    // Special case: zero (exp=0, mant=0)
    if (exp == 0 && mant == 0) {
        return sign ? -0.0f : 0.0f;
    }

    // Special case: NaN (exp=0xF, mant=0x7)
    if (exp == 0xF && mant == 0x7) {
        return 0.0f / 0.0f;  // NaN
    }

    float value;
    if (exp == 0) {
        // Subnormal: 2^(-6) * (mant/8) = 2^(-9) * mant = (1/512) * mant
        value = __int2float_rn(mant) * 0.001953125f;  // 1/512 = 2^(-9)
    } else {
        // Normal: 2^(exp-7) * (1 + mant/8)
        int exp_val = static_cast<int>(exp) - 7;
        float mantissa = 1.0f + __int2float_rn(mant) * 0.125f;  // 1/8

        // Use integer bit shifts for exact 2^exp_val computation
        // Avoids __powf() which has accuracy issues (per V221_HANDOFF.md)
        if (exp_val >= 0) {
            value = mantissa * static_cast<float>(1 << exp_val);
        } else {
            value = mantissa / static_cast<float>(1 << (-exp_val));
        }
    }

    return sign ? -value : value;
}
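//------------------------------------------------------------------------------
// Editor's addition (illustrative only, not part of the original kernel):
// a few values worked by hand from the bit layouts documented above.
//
//   FP4 E2M1:  packed byte 0x3B -> high nibble 0x3 =  1.5, low nibble 0xB = -1.5
//   FP8 E4M3:  0x38 -> sign=0, exp=7, mant=0 -> +2^(7-7) * 1.000 =  1.0
//              0xC4 -> sign=1, exp=8, mant=4 -> -2^(8-7) * 1.500 = -3.0
//
// The tiny kernel below (name decode_sanity_check is an editor's invention)
// prints these cases via device printf; it is a sketch for a throwaway test
// build and is never called by nvfp4_batched_gemv.
//------------------------------------------------------------------------------
#include <cstdio>  // device printf declaration (editor's addition)

__global__ void decode_sanity_check() {
    printf("decode_fp4(0x3B, 0)   = %f (expect  1.5)\n", decode_fp4((uint8_t)0x3B, 0));
    printf("decode_fp4(0x3B, 1)   = %f (expect -1.5)\n", decode_fp4((uint8_t)0x3B, 1));
    printf("decode_fp8_e4m3(0x38) = %f (expect  1.0)\n", decode_fp8_e4m3((uint8_t)0x38));
    printf("decode_fp8_e4m3(0xC4) = %f (expect -3.0)\n", decode_fp8_e4m3((uint8_t)0xC4));
}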
//==============================================================================
// Main Kernel: NVFP4 Batched GEMV
//==============================================================================

/**
 * NVFP4 Batched GEMV Kernel
 *
 * Computes: C[m,0,l] = Σ(k=0..K-1) [decode_fp4(A[m,k,l]) * decode_fp4(B[0,k,l])
 *                                   * SFA[m,k//16,l] * SFB[0,k//16,l]]
 *
 * Algorithm:
 *   - Process K in blocks of 16 FP4 elements
 *   - Scale-outside-inner-loop pattern (critical for accuracy)
 *   - FP32 accumulation, FP16 output
 *   - Shared memory caching for B and SFB vectors
 *
 * Grid:  (ceil_div(M,128), 1, L)
 * Block: 128 threads (matches v130)
 * Shared memory: K FP32 (B decoded) + K/16 FP32 (SFB decoded)
 *
 * @param A:   [M, K//2, L] FP4 packed matrix
 * @param B:   [128, K//2, L] FP4 packed vector (only row 0 used)
 * @param SFA: [M, K//16, L] FP8 E4M3 scale factors for A
 * @param SFB: [128, K//16, L] FP8 E4M3 scale factors for B (only row 0 used)
 * @param C:   [M, 1, L] FP16 output
 * @param M, K, L: problem dimensions (K is logical unpacked size)
 * @param *_stride_*: byte strides for each tensor dimension
 */
extern "C" __global__ void nvfp4_batched_gemv(
    const uint8_t* __restrict__ A,
    const uint8_t* __restrict__ B,
    const uint8_t* __restrict__ SFA,
    const uint8_t* __restrict__ SFB,
    __half* __restrict__ C,
    int M, int K, int L,
    long long a_stride_m, long long a_stride_k, long long a_stride_l,
    long long b_stride_n, long long b_stride_k, long long b_stride_l,
    long long sfa_stride_m, long long sfa_stride_k, long long sfa_stride_l,
    long long sfb_stride_n, long long sfb_stride_k, long long sfb_stride_l,
    long long c_stride_m, long long c_stride_l
) {
    //==========================================================================
    // Shared Memory Layout
    //==========================================================================
    extern __shared__ float smem[];
    float* B_smem   = smem;      // K FP32 values (decoded B vector)
    float* SFB_smem = smem + K;  // K/16 FP32 values (decoded SFB scales)

    //==========================================================================
    // Thread/Block Mapping (v130-compatible)
    //==========================================================================
    const int tid        = threadIdx.x;                 // Thread ID within block [0, 128)
    const int block_size = blockDim.x;                  // 128 threads per block
    const int m_block    = blockIdx.x;                  // M-block index
    const int l          = blockIdx.z;                  // Batch index
    const int m          = m_block * block_size + tid;  // Global row index
    const int K_blocks   = K / 16;                      // Number of 16-element blocks

    //==========================================================================
    // Cooperative Load: B Vector into Shared Memory
    // All threads in block load B[0, :, l] (N=0, only first row)
    //==========================================================================
    for (int k = tid; k < K; k += block_size) {
        int k_byte   = k / 2;  // Byte index (2 FP4 values per byte)
        int k_offset = k % 2;  // Nibble offset (0=high nibble, 1=low; see decode_fp4 NOTE v404)
        // B index: B[0, k_byte, l]
        long long b_idx = 0 * b_stride_n + k_byte * b_stride_k + l * b_stride_l;
        B_smem[k] = decode_fp4(B[b_idx], k_offset);
    }
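    // Editor's note (illustrative, not in the original file): with the v404
    // convention in decode_fp4, consecutive logical elements share a byte as
    //   k=0 -> byte 0, HIGH nibble    k=1 -> byte 0, LOW nibble
    //   k=2 -> byte 1, HIGH nibble    k=3 -> byte 1, LOW nibble
    // For a hypothetical K=4096, the shared buffers above hold 4096 floats
    // (16 KiB) for B_smem plus 256 floats (1 KiB) for SFB_smem.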
    //==========================================================================
    // Cooperative Load: SFB Scale Factors into Shared Memory
    // All threads in block load SFB[0, :, l] (N=0, only first row)
    //==========================================================================
    for (int kb = tid; kb < K_blocks; kb += block_size) {
        // CRITICAL FIX: Scale factors are replicated 16× in memory
        // Access element (kb * 16) to get scale for block kb
        int k_element = kb * 16;
        long long sfb_idx = 0 * sfb_stride_n + k_element * sfb_stride_k + l * sfb_stride_l;
        SFB_smem[kb] = decode_fp8_e4m3(SFB[sfb_idx]);
    }

    __syncthreads();  // Wait for all threads to finish loading shared memory

    //==========================================================================
    // Bounds Check: Skip threads with invalid row index
    //==========================================================================
    if (m >= M) return;

    //==========================================================================
    // Main Computation Loop
    // Each thread computes one output element: C[m, 0, l]
    //==========================================================================
    float acc = 0.0f;  // FP32 accumulator

    // Outer loop: Iterate over K in blocks of 16 elements
    for (int k_block = 0; k_block < K_blocks; k_block++) {
        //======================================================================
        // Load and Decode Scale Factors (once per 16-element block)
        //======================================================================
        // CRITICAL FIX: Scale factors are replicated 16× in memory
        // To get scale for block k_block, access element (k_block * 16)
        // This matches v148f line 114: idx = i * 16; scale = tArSFA[idx]
        int k_element = k_block * 16;  // Convert block index to element index
        long long sfa_idx = m * sfa_stride_m + k_element * sfa_stride_k + l * sfa_stride_l;
        float sfa = decode_fp8_e4m3(SFA[sfa_idx]);

        // Combined scale: SFA[m, k_block, l] * SFB[0, k_block, l]
        float scale = sfa * SFB_smem[k_block];

        //======================================================================
        // Inner Loop: Accumulate 16 A*B Products (scale applied OUTSIDE)
        // Critical pattern: sum all products first, THEN multiply by scale
        // MANUAL UNROLL to avoid compiler bug at k=7 (see V401_ALGORITHM_DEBUG.md)
        //======================================================================
        float block_sum = 0.0f;
        int k_start = k_block * 16;

        // Element 0
        { int k = k_start + 0; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 1
        { int k = k_start + 1; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 2
        { int k = k_start + 2; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 3
        { int k = k_start + 3; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 4
        { int k = k_start + 4; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 5
        { int k = k_start + 5; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 6
        { int k = k_start + 6; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 7 - CRITICAL: Previously had compiler bug (A[k+8] instead of A[k])
        { int k = k_start + 7; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 8
        { int k = k_start + 8; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 9
        { int k = k_start + 9; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 10
        { int k = k_start + 10; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 11
        { int k = k_start + 11; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 12
        { int k = k_start + 12; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 13
        { int k = k_start + 13; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 14
        { int k = k_start + 14; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }
        // Element 15
        { int k = k_start + 15; int k_byte = k / 2; int k_offset = k % 2;
          long long a_idx = m * a_stride_m + k_byte * a_stride_k + l * a_stride_l;
          float a_val = decode_fp4(A[a_idx], k_offset); float b_val = B_smem[k];
          block_sum += a_val * b_val; }

        // Apply scale ONCE per block (scale-outside-loop pattern)
        acc += block_sum * scale;
    }

    //==========================================================================
    // Write Output: Convert FP32 to FP16
    //==========================================================================
    // C index: C[m, 0, l] (N dimension is always 0 for GEMV)
    long long c_idx = m * c_stride_m + l * c_stride_l;
    C[c_idx] = __float2half(acc);
}
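//------------------------------------------------------------------------------
// Editor's note (illustrative only, not part of the original file): the
// manually unrolled inner block above is equivalent to the loop sketched below;
// the original keeps the manual form to work around a miscompilation the
// author reports at element 7 (see the Element 7 comment and
// V401_ALGORITHM_DEBUG.md).
//
//   for (int j = 0; j < 16; j++) {
//       int k = k_start + j;
//       long long a_idx = m * a_stride_m + (k / 2) * a_stride_k + l * a_stride_l;
//       block_sum += decode_fp4(A[a_idx], k % 2) * B_smem[k];
//   }
//------------------------------------------------------------------------------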
//==============================================================================
// Host Launcher Function
//==============================================================================

/**
 * Host-side launcher for nvfp4_batched_gemv kernel
 * Configures grid/block dimensions and shared memory size
 *
 * Grid:  (ceil_div(M,128), 1, L) - matches v130
 * Block: 128 threads
 * Shared memory: (K + K/16) * sizeof(float)
 */
extern "C" void launch_nvfp4_gemv(
    const uint8_t* A,
    const uint8_t* B,
    const uint8_t* SFA,
    const uint8_t* SFB,
    __half* C,
    int M, int K, int L,
    long long a_stride_m, long long a_stride_k, long long a_stride_l,
    long long b_stride_n, long long b_stride_k, long long b_stride_l,
    long long sfa_stride_m, long long sfa_stride_k, long long sfa_stride_l,
    long long sfb_stride_n, long long sfb_stride_k, long long sfb_stride_l,
    long long c_stride_m, long long c_stride_l,
    cudaStream_t stream
) {
    //==========================================================================
    // Configure Launch Parameters (v130-compatible)
    //==========================================================================
    const int threads  = 128;                          // Threads per block (matches v130)
    const int blocks_x = (M + threads - 1) / threads;  // Ceil division
    dim3 grid(blocks_x, 1, L);                         // Grid: (M_blocks, 1, L)
    dim3 block(threads, 1, 1);                         // Block: (128, 1, 1)

    //==========================================================================
    // Compute Shared Memory Size
    //==========================================================================
    const int K_blocks = K / 16;
    const size_t smem_size = (K + K_blocks) * sizeof(float);

    //==========================================================================
    // Launch Kernel
    //==========================================================================
    nvfp4_batched_gemv<<<grid, block, smem_size, stream>>>(
        A, B, SFA, SFB, C,
        M, K, L,
        a_stride_m, a_stride_k, a_stride_l,
        b_stride_n, b_stride_k, b_stride_l,
        sfa_stride_m, sfa_stride_k, sfa_stride_l,
        sfb_stride_n, sfb_stride_k, sfb_stride_l,
        c_stride_m, c_stride_l
    );
}
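A minimal host-side sketch of how launch_nvfp4_gemv might be driven from a separate test file. The problem size (M, K, L), the fully contiguous row-major strides, and the assumption that the scale tensors carry one (replicated) entry per unpacked K element are taken from the in-kernel comments but are assumptions, not part of the original source.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

// Forward declaration matching the definition above (separate translation unit).
extern "C" void launch_nvfp4_gemv(
    const uint8_t*, const uint8_t*, const uint8_t*, const uint8_t*, __half*,
    int, int, int,
    long long, long long, long long,   // A strides (m, k, l)
    long long, long long, long long,   // B strides (n, k, l)
    long long, long long, long long,   // SFA strides (m, k, l)
    long long, long long, long long,   // SFB strides (n, k, l)
    long long, long long,              // C strides (m, l)
    cudaStream_t);

int main() {
    const int M = 1024, K = 4096, L = 8;  // hypothetical problem size

    uint8_t *A, *B, *SFA, *SFB;
    __half  *C;
    cudaMalloc(&A,   (size_t)M * (K / 2) * L);    // FP4: two values per byte
    cudaMalloc(&B,   (size_t)128 * (K / 2) * L);  // only row 0 is read
    cudaMalloc(&SFA, (size_t)M * K * L);          // FP8 scales, replicated 16x along K
    cudaMalloc(&SFB, (size_t)128 * K * L);
    cudaMalloc(&C,   (size_t)M * L * sizeof(__half));
    // ... fill A/B/SFA/SFB with packed FP4 / FP8 data here ...

    // Contiguous row-major strides matching the indexing used inside the kernel.
    launch_nvfp4_gemv(A, B, SFA, SFB, C, M, K, L,
                      /*A*/   (long long)(K / 2) * L, L, 1,
                      /*B*/   (long long)(K / 2) * L, L, 1,
                      /*SFA*/ (long long)K * L,       L, 1,
                      /*SFB*/ (long long)K * L,       L, 1,
                      /*C*/   L, 1,
                      /*stream*/ nullptr);
    cudaDeviceSynchronize();

    cudaFree(A); cudaFree(B); cudaFree(SFA); cudaFree(SFB); cudaFree(C);
    return 0;
}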