Thanks for using Compiler Explorer
Sponsors
Jakt
C++
Ada
Analysis
Android Java
Android Kotlin
Assembly
C
C3
Carbon
C++ (Circle)
CIRCT
Clean
CMake
CMakeScript
COBOL
C++ for OpenCL
MLIR
Cppx
Cppx-Blue
Cppx-Gold
Cpp2-cppfront
Crystal
C#
CUDA C++
D
Dart
Elixir
Erlang
Fortran
F#
Go
Haskell
HLSL
Hook
Hylo
ispc
Java
Julia
Kotlin
LLVM IR
LLVM MIR
Modula-2
Nim
Objective-C
Objective-C++
OCaml
OpenCL C
Pascal
Pony
Python
Racket
Ruby
Rust
Snowball
Scala
Solidity
Spice
Swift
LLVM TableGen
Toit
TypeScript Native
V
Vala
Visual Basic
WASM
Zig
Javascript
GIMPLE
cuda source #1
Output
Compile to binary object
Link to binary
Execute the code
Intel asm syntax
Demangle identifiers
Verbose demangling
Filters
Unused labels
Library functions
Directives
Comments
Horizontal whitespace
Debug intrinsics
Compiler
10.0.0 sm_75 CUDA-10.2
10.0.1 sm_75 CUDA-10.2
11.0.0 sm_75 CUDA-10.2
NVCC 10.0.130
NVCC 10.1.105
NVCC 10.1.168
NVCC 10.1.243
NVCC 10.2.89
NVCC 11.0.2
NVCC 11.0.3
NVCC 11.1.0
NVCC 11.1.1
NVCC 11.2.0
NVCC 11.2.1
NVCC 11.2.2
NVCC 11.3.0
NVCC 11.3.1
NVCC 11.4.0
NVCC 11.4.1
NVCC 11.4.2
NVCC 11.4.3
NVCC 11.4.4
NVCC 11.5.0
NVCC 11.5.1
NVCC 11.5.2
NVCC 11.6.0
NVCC 11.6.1
NVCC 11.6.2
NVCC 11.7.0
NVCC 11.7.1
NVCC 11.8.0
NVCC 12.0.0
NVCC 12.0.1
NVCC 12.1.0
NVCC 12.2.1
NVCC 12.3.1
NVCC 12.4.1
NVCC 12.5.1
NVCC 9.1.85
NVCC 9.2.88
NVRTC 11.0.2
NVRTC 11.0.3
NVRTC 11.1.0
NVRTC 11.1.1
NVRTC 11.2.0
NVRTC 11.2.1
NVRTC 11.2.2
NVRTC 11.3.0
NVRTC 11.3.1
NVRTC 11.4.0
NVRTC 11.4.1
NVRTC 11.5.0
NVRTC 11.5.1
NVRTC 11.5.2
NVRTC 11.6.0
NVRTC 11.6.1
NVRTC 11.6.2
NVRTC 11.7.0
NVRTC 11.7.1
NVRTC 11.8.0
NVRTC 12.0.0
NVRTC 12.0.1
NVRTC 12.1.0
clang 7.0.0 sm_70 CUDA-9.1
clang 8.0.0 sm_75 CUDA-10.0
clang 9.0.0 sm_75 CUDA-10.1
clang rocm-4.5.2
clang rocm-5.0.2
clang rocm-5.1.3
clang rocm-5.2.3
clang rocm-5.3.2
clang rocm-5.7.0
clang rocm-6.0.2
clang rocm-6.1.2
clang staging rocm-6.1.2
clang trunk rocm-6.1.2
trunk sm_86 CUDA-11.3
Options
Source code
#include <cuda_runtime.h>
#include <cuda_bf16.h>

typedef __nv_bfloat16 floatX;

// A 16-byte (int4-sized) bundle of elements enabling 128-bit vectorized
// loads/stores. E.g. with ElementType = __nv_bfloat16 (2 bytes) one
// Packed128 carries 8 elements. Requires 16-byte-aligned addresses.
template<class ElementType>
struct alignas(16) Packed128 {
    // number of elements that fit into 128 bits
    static constexpr const int size = sizeof(int4) / sizeof(ElementType);

    __device__ Packed128() = default;

    // reinterpret 128 raw bits as the element array
    __device__ explicit Packed128(int4 bits) {
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&payload, &bits, sizeof(bits));
    }

    __device__ ElementType& operator[](int index) { return payload[index]; }
    __device__ const ElementType& operator[](int index) const { return payload[index]; }

    // view the element array as 128 raw bits again
    __device__ int4 get_bits() const {
        int4 bits;
        static_assert(sizeof(bits) == sizeof(payload), "Size mismatch.");
        memcpy(&bits, &payload, sizeof(bits));
        return bits;
    }

    ElementType payload[size];
};

// short-form typedefs
typedef Packed128<float>  f128;
typedef Packed128<floatX> x128;

// load a Packed128 from a 16-byte-aligned memory address
template<class ElementType>
__device__ Packed128<ElementType> load128(const ElementType* address) {
    return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}

// same, but with a streaming ("evict-first") cache hint
template<class ElementType>
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
    return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}

// store a Packed128 to a 16-byte-aligned memory address
template<class ElementType>
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
    *reinterpret_cast<int4*>(target) = value.get_bits();
}

// same, but with a streaming cache hint
template<class ElementType>
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
    __stcs(reinterpret_cast<int4*>(target), value.get_bits());
}

// ----------------
// Simple xorshift RNG
__device__ __host__
unsigned int random_u32(unsigned long long *state) {
    // xorshift* generator: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
    *state ^= *state >> 12;
    *state ^= *state << 25;
    *state ^= *state >> 27;
    return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}

// random float32 in [0,1): keep the top 24 bits and scale by 2^-24
__device__ __host__ float random_f32(unsigned long long *state) {
    return (random_u32(state) >> 8) / 16777216.0f;
}

// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5)
// Hashes an integer position plus a seed into a well-mixed 32-bit value;
// lets us derive per-thread randomness from threadIdx/blockIdx and a single
// seed for the entire GPU.
// todo - possibly overkill and we don't need such high quality random numbers? (tbd)
// http://eiserloh.net/noise/SquirrelNoise5.hpp
__device__ __host__ constexpr unsigned int SquirrelNoise5(int positionX, unsigned int seed) {
    // five rounds of multiply/add/xor-shift bit mangling
    constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
    constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
    constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011
    constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011
    constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101
    unsigned int h = (unsigned int) positionX;
    h *= SQ5_BIT_NOISE1;
    h += seed;
    h ^= (h >> 9);
    h += SQ5_BIT_NOISE2;
    h ^= (h >> 11);
    h *= SQ5_BIT_NOISE3;
    h ^= (h >> 13);
    h += SQ5_BIT_NOISE4;
    h ^= (h >> 15);
    h *= SQ5_BIT_NOISE5;
    h ^= (h >> 17);
    return h;
}

__device__ __host__ constexpr unsigned int Get1dNoiseUint(int positionX, unsigned int seed) {
    return SquirrelNoise5(positionX, seed);
}

__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed) {
    // fold the second coordinate in via a large prime with non-boring bits
    constexpr int PRIME_NUMBER = 198491317;
    return SquirrelNoise5(indexX + (PRIME_NUMBER * indexY), seed);
}

// map 32-bit noise onto [0,1] (divides by the max uint32 value)
__device__ __host__ constexpr float Get1dNoiseZeroToOne(int index, unsigned int seed) {
    constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
    return (float)(ONE_OVER_MAX_UINT * (double) SquirrelNoise5(index, seed));
}

__device__ __host__ constexpr float Get2dNoiseZeroToOne(int indexX, int indexY, unsigned int seed) {
    constexpr double ONE_OVER_MAX_UINT = (1.0 / (double) 0xFFFFFFFF);
    return (float)(ONE_OVER_MAX_UINT * (double) Get2dNoiseUint(indexX, indexY, seed));
}

// Stochastic rounding built on top of Squirrel Noise above (the caller updates
// the seed per step, e.g. via xorshift).
// todo - is this stochastic rounding *too good*? can we cut any corners?
__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) {
    // compare the float bits below bf16 precision against a per-thread random
    // threshold, then force a round up or a truncation accordingly
    unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
    unsigned int threshold = random & 0xFFFF;
    unsigned int float_bits = __float_as_uint(in);
    unsigned int rounded_bits = float_bits & 0x0000FFFF;
    if (rounded_bits > threshold) {
        float_bits |= 0xFFFF;   // saturate low bits -> __float2bfloat16_rn rounds up
    } else {
        float_bits &= ~0xFFFF;  // clear low bits -> rounds down
    }
    *out = __float2bfloat16_rn(__uint_as_float(float_bits));
}

// todo - fp16 variant not implemented yet; plain conversion for now
__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) {
    *out = (float)in;
}

// dummy overload for when floatX is float (FP32 mode): nothing to round
__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) {
    *out = in;
}

// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation).
// Linear interpolation with two fused ops instead of the naive three.
// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda
__device__ inline float lerp(float start, float end, float weight) {
    // fmaf keeps everything in single precision (bare fma risks a double overload on host)
    return fmaf(weight, end, fmaf(-weight, start, start));
}

// Vectorized AdamW update: each thread owns x128::size consecutive parameters.
// Launch with enough threads to cover ceil(num_parameters / x128::size);
// num_parameters is assumed to be a multiple of x128::size and all four arrays
// are assumed 16-byte aligned (preconditions of load128/store128).
//
// Fixes vs. previous revision:
//  - tail guard: an over-sized grid can no longer read/write out of bounds
//  - 64-bit indexing: `int i * x128::size` could overflow past 2^31 parameters
//  - consistent indexing of m/v with params/grads: previously thread i touched
//    params/grads [i*x128::size ...] but m/v [i*f128::size ...]; whenever
//    sizeof(floatX) != sizeof(float) that mismatched moments with parameters
//    and left params [i*8+4 .. i*8+7] un-updated. Now x128::size moments are
//    loaded per thread, as (x128::size / f128::size) f128 vectors.
//  - the per-thread random value is loop-invariant and is hoisted out
__global__ void adamw_kernel4(floatX* params_memory, const floatX* grads_memory, float* m_memory, float* v_memory,
                              size_t num_parameters, float learning_rate, float beta1, float beta2,
                              float beta1_correction, float beta2_correction, float eps, float weight_decay,
                              unsigned int seed) {
    size_t idx = ((size_t)blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
    if (idx >= num_parameters) { return; }

    x128 packed_grads  = load128(grads_memory + idx);
    x128 packed_params = load128(params_memory + idx);

    // x128::size floatX elements span this many f128 (float) vectors
    static_assert(x128::size % f128::size == 0, "param vector must cover whole float vectors");
    constexpr int F_CHUNKS = x128::size / f128::size;
    f128 packed_m[F_CHUNKS];
    f128 packed_v[F_CHUNKS];
    for (int c = 0; c < F_CHUNKS; ++c) {
        packed_m[c] = load128(m_memory + idx + c * f128::size);
        packed_v[c] = load128(v_memory + idx + c * f128::size);
    }

    // identical for every k (the bf16 stochastic_rounding re-hashes it internally)
    unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);

    for (int k = 0; k < x128::size; ++k) {
        int c = k / f128::size;   // which f128 vector
        int j = k % f128::size;   // element within it
        float grad = (float)packed_grads[k];
        // update the first moment (momentum)
        float m = lerp(grad, packed_m[c][j], beta1);
        packed_m[c][j] = m;
        // update the second moment (RMSprop)
        float v = lerp(grad * grad, packed_v[c][j], beta2);
        packed_v[c][j] = v;
        // bias correction (stored moments stay uncorrected, as before)
        m /= beta1_correction;
        v /= beta2_correction;
        // decoupled weight decay + Adam step, then stochastically round to floatX
        float param = (float)packed_params[k];
        param -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * param);
        stochastic_rounding(param, &packed_params[k], random);
    }

    for (int c = 0; c < F_CHUNKS; ++c) {
        store128(m_memory + idx + c * f128::size, packed_m[c]);
        store128(v_memory + idx + c * f128::size, packed_v[c]);
    }
    store128(params_memory + idx, packed_params);
}
Become a Patron
Sponsor on GitHub
Donate via PayPal
Source on GitHub
Mailing list
Installed libraries
Wiki
Report an issue
How it works
Contact the author
CE on Mastodon
About the author
Statistics
Changelog
Version tree