Thanks for using Compiler Explorer
Sponsors
Jakt
C++
Ada
Algol68
Analysis
Android Java
Android Kotlin
Assembly
C
C3
Carbon
C with Coccinelle
C++ with Coccinelle
C++ (Circle)
CIRCT
Clean
CMake
CMakeScript
COBOL
C++ for OpenCL
MLIR
Cppx
Cppx-Blue
Cppx-Gold
Cpp2-cppfront
Crystal
C#
CUDA C++
D
Dart
Elixir
Erlang
Fortran
F#
GLSL
Go
Haskell
HLSL
Hook
Hylo
IL
ispc
Java
Julia
Kotlin
LLVM IR
LLVM MIR
Modula-2
Mojo
Nim
Numba
Nix
Objective-C
Objective-C++
OCaml
Odin
OpenCL C
Pascal
Pony
PTX
Python
Racket
Raku
Ruby
Rust
Sail
Snowball
Scala
Slang
Solidity
Spice
SPIR-V
Swift
LLVM TableGen
Toit
Triton
TypeScript Native
V
Vala
Visual Basic
Vyper
WASM
Zig
Javascript
GIMPLE
Ygen
sway
cuda source #1
Output
Compile to binary object
Link to binary
Execute the code
Intel asm syntax
Demangle identifiers
Verbose demangling
Filters
Unused labels
Library functions
Directives
Comments
Horizontal whitespace
Debug intrinsics
Compiler
10.0.0 sm_75 CUDA-10.2
10.0.1 sm_75 CUDA-10.2
11.0.0 sm_75 CUDA-10.2
16.0.0 sm_90 CUDA-11.8
17.0.1(libc++) sm_90 CUDA-12.1
18.1.0(libc++) sm_90 CUDA-12.3.1
19.1.0 sm_90 CUDA-12.5.1
20.1.0 sm_90 CUDA-12.5.1
20.1.0 sm_90 CUDA-12.6.1
20.1.0 sm_90 CUDA-12.6.2
NVCC 10.0.130
NVCC 10.1.105
NVCC 10.1.168
NVCC 10.1.243
NVCC 10.2.89
NVCC 11.0.2
NVCC 11.0.3
NVCC 11.1.0
NVCC 11.1.1
NVCC 11.2.0
NVCC 11.2.1
NVCC 11.2.2
NVCC 11.3.0
NVCC 11.3.1
NVCC 11.4.0
NVCC 11.4.1
NVCC 11.4.2
NVCC 11.4.3
NVCC 11.4.4
NVCC 11.5.0
NVCC 11.5.1
NVCC 11.5.2
NVCC 11.6.0
NVCC 11.6.1
NVCC 11.6.2
NVCC 11.7.0
NVCC 11.7.1
NVCC 11.8.0
NVCC 12.0.0
NVCC 12.0.1
NVCC 12.1.0
NVCC 12.2.1
NVCC 12.3.1
NVCC 12.4.1
NVCC 12.5.1
NVCC 12.6.1
NVCC 12.6.2
NVCC 12.8.1
NVCC 12.9.0
NVCC 12.9.1
NVCC 13.0.0
NVCC 9.1.85
NVCC 9.2.88
NVRTC 11.0.2
NVRTC 11.0.3
NVRTC 11.1.0
NVRTC 11.1.1
NVRTC 11.2.0
NVRTC 11.2.1
NVRTC 11.2.2
NVRTC 11.3.0
NVRTC 11.3.1
NVRTC 11.4.0
NVRTC 11.4.1
NVRTC 11.5.0
NVRTC 11.5.1
NVRTC 11.5.2
NVRTC 11.6.0
NVRTC 11.6.1
NVRTC 11.6.2
NVRTC 11.7.0
NVRTC 11.7.1
NVRTC 11.8.0
NVRTC 12.0.0
NVRTC 12.0.1
NVRTC 12.1.0
NVRTC 12.2.1
NVRTC 12.3.1
NVRTC 12.4.1
NVRTC 12.5.1
NVRTC 12.6.1
NVRTC 12.6.2
NVRTC 12.8.1
NVRTC 12.9.0
NVRTC 12.9.1
NVRTC 13.0.0
clang 7.0.0 sm_70 CUDA-9.1
clang 8.0.0 sm_75 CUDA-10.0
clang 9.0.0 sm_75 CUDA-10.1
clang rocm-4.5.2
clang rocm-5.0.2
clang rocm-5.1.3
clang rocm-5.2.3
clang rocm-5.3.2
clang rocm-5.7.0
clang rocm-6.0.2
clang rocm-6.1.2
clang rocm-6.2.4
clang rocm-6.3.3
clang rocm-6.4.0
clang staging rocm-6.1.2
clang staging rocm-6.2.4
clang staging rocm-6.3.3
clang staging rocm-6.4.0
clang trunk rocm-6.1.2
clang trunk rocm-6.2.4
clang trunk rocm-6.3.3
clang trunk rocm-6.4.0
trunk sm_100a CUDA-12.8.1
Options
Source code
// warp-level reduction for finding the maximum value.
// Uses __shfl_down_sync with a full-warp mask (0xFFFFFFFF), so all 32 lanes of
// the warp must be alive and call this function together. After the loop only
// lane 0 is guaranteed to hold the maximum over the whole warp; the other lanes
// hold partial results.
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

// warp-level reduction for summing values.
// Same contract as warpReduceMax: full-warp participation required, and the
// complete sum is only guaranteed to be in lane 0.
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

// Row-wise softmax: out[row, :] = exp(inp[row, :] - max) / sum(exp(inp[row, :] - max)).
// out is (N, C) just like inp. Each block handles one row (row = blockIdx.x).
//
// Same as kernel4, but optimised for very large Cs with advanced unrolling.
// The trick is to read into a register array (all indices known at compile time)
// and always read UNROLL_FACTOR values to maximise memory level parallelism.
// Even if we would be out of bounds, we set the index to min(C-1, idx),
// so we just do some unnecessary reads (obviously bad for small C).
// The writes are in a separate loop with a conditional check for out of bounds;
// making it separate is necessary to convince the compiler to do the right thing.
//
// Launch contract (not checked here):
//  - blockDim.x must be a multiple of 32, because the warp reductions use a
//    full-warp shuffle mask.
//  - dynamic shared memory must hold 2 * (blockDim.x / 32) floats.
__global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int C) {
    const int UNROLL_FACTOR = 8;
    const int warpsPerBlock = blockDim.x / 32;

    extern __align__(16) __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    int warpId = threadIdx.x / 32; // warp index within a block
    int laneId = threadIdx.x % 32; // thread index within a warp

    // Block-uniform guard: every thread of the block takes the same branch,
    // so returning here cannot desynchronise the __syncthreads() calls below.
    if (idx >= N) {
        return;
    }

    // shared[] holds 2 * warpsPerBlock elements:
    // the first half for per-warp max values, the second half for per-warp sums.
    float* maxvals = shared;
    float* sumvals = &shared[warpsPerBlock];

    // NOTE: threads with tid >= C must NOT return early. Every lane has to take
    // part in the full-mask warp shuffles and in __syncthreads(). Such threads
    // simply skip the strided loops below (their first i is already >= C) and
    // contribute the reduction identities (-INFINITY for max, 0 for sum).

    // Row offset computed in size_t so that N * C > INT_MAX does not overflow.
    const size_t rowOffset = static_cast<size_t>(idx) * C;
    const float* x = inp + rowOffset; // input row
    float* y = out + rowOffset;       // output row

    // first, thread coarsening by directly accessing global memory in series
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            // clamped index: duplicate reads of x[C-1] do not change the max
            maxval = fmaxf(maxval, x[min(C - 1, i + u*blockDim.x)]);
        }
    }
    // now within-warp reductions for maxval
    maxval = warpReduceMax(maxval);
    // the 0th thread of each warp writes the maxval of that warp to shared memory
    if (laneId == 0) maxvals[warpId] = maxval;
    __syncthreads();
    // now the 0th thread reduces the maxvals in shared memory, i.e. across warps
    if (tid == 0) {
        float val = maxvals[tid];
        #pragma unroll
        for (int i = 1; i < warpsPerBlock; i++) {
            val = fmaxf(val, maxvals[i]);
        }
        // store the final max in the first position
        maxvals[0] = val;
    }
    __syncthreads();
    // broadcast the max to all threads
    float offset = maxvals[0];

    // compute expf and write the result to global memory
    // + thread coarsening for sum
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        float reg_array[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]);
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            if (i + u*blockDim.x < C) {
                float output = expf(reg_array[u] - offset);
                y[min(C - 1, i + u*blockDim.x)] = output; // compiler likes redundant min()?!
                sumval += output; // combined into the same loop unlike kernel3
            }
        }
    }

    // okay now we calculated exp(x - max(x))
    // step 2: sum all the values and divide by the sum

    // within-warp reduction for sumval
    sumval = warpReduceSum(sumval);
    // write sumval to shared memory
    if (laneId == 0) sumvals[warpId] = sumval;
    __syncthreads();
    // inter-warp reduction of sum, done by thread 0
    if (tid == 0) {
        float val = sumvals[tid];
        #pragma unroll
        for (int i = 1; i < warpsPerBlock; ++i) {
            val += sumvals[i];
        }
        sumvals[0] = val;
    }
    __syncthreads();
    // broadcast the sum to all threads
    float sum = sumvals[0];

    // divide the whole row by the sum
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        float reg_array[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            reg_array[u] = y[min(C - 1, i + u*blockDim.x)];
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            if (i + u*blockDim.x < C) {
                y[i + u*blockDim.x] = reg_array[u] / sum;
            }
        }
    }
}
Become a Patron
Sponsor on GitHub
Donate via PayPal
Compiler Explorer Shop
Source on GitHub
Mailing list
Installed libraries
Wiki
Report an issue
How it works
Contact the author
CE on Mastodon
CE on Bluesky
Statistics
Changelog
Version tree