Thanks for using Compiler Explorer
Sponsors
Jakt
C++
Ada
Algol68
Analysis
Android Java
Android Kotlin
Assembly
C
C3
Carbon
C with Coccinelle
C++ with Coccinelle
C++ (Circle)
CIRCT
Clean
CMake
CMakeScript
COBOL
C++ for OpenCL
MLIR
Cppx
Cppx-Blue
Cppx-Gold
Cpp2-cppfront
Crystal
C#
CUDA C++
D
Dart
Elixir
Erlang
Fortran
F#
GLSL
Go
Haskell
HLSL
Hook
Hylo
IL
ispc
Java
Julia
Kotlin
LLVM IR
LLVM MIR
Modula-2
Mojo
Nim
Numba
Nix
Objective-C
Objective-C++
OCaml
Odin
OpenCL C
Pascal
Pony
PTX
Python
Racket
Raku
Ruby
Rust
Sail
Snowball
Scala
Slang
Solidity
Spice
SPIR-V
Swift
LLVM TableGen
Toit
Triton
TypeScript Native
V
Vala
Visual Basic
Vyper
WASM
Zig
Javascript
GIMPLE
Ygen
sway
hlsl source #1
Output
Compile to binary object
Link to binary
Execute the code
Intel asm syntax
Demangle identifiers
Verbose demangling
Filters
Unused labels
Library functions
Directives
Comments
Horizontal whitespace
Debug intrinsics
Compiler
AMD RGA 2.10
AMD RGA 2.11
AMD RGA 2.12
AMD RGA 2.13
AMD RGA 2.9.1
Clang (trunk)
DXC (trunk)
DXC 1.6.2112
DXC 1.7.2207
DXC 1.7.2212
DXC 1.7.2308
DXC 1.8.2306-preview
DXC 1.8.2403
DXC 1.8.2403.1
DXC 1.8.2403.2
DXC 1.8.2405
DXC 1.8.2407
DXC 1.8.2502
DXC 1.8.2505
DXC 1.8.2505.1
RGA 2.6.1 (DXC 1.6.2112)
RGA 2.6.1 (DXC 1.7.2207)
RGA 2.6.2 (DXC 1.6.2112)
RGA 2.6.2 (DXC 1.7.2207)
RGA 2.6.2 (DXC trunk)
RGA 2.9.0 (DXC trunk)
Options
Source code
// Minimal type-trait helpers used to constrain the dx::linalg templates below.
namespace hlsl {

/// Primary template: a type is treated as non-arithmetic unless it is
/// explicitly specialized below.
template <typename T> struct is_arithmetic {
  static const bool value = false;
};

// Helper that marks `type` as arithmetic. It is #undef'd right after its last
// use so it does not leak into code that includes this header.
#define __ARITHMETIC_TYPE(type)                                                \
  template <> struct is_arithmetic<type> {                                     \
    static const bool value = true;                                            \
  };

#if __HLSL_ENABLE_16_BIT
__ARITHMETIC_TYPE(uint16_t)
__ARITHMETIC_TYPE(int16_t)
#endif
__ARITHMETIC_TYPE(uint)
__ARITHMETIC_TYPE(int)
__ARITHMETIC_TYPE(uint64_t)
__ARITHMETIC_TYPE(int64_t)
__ARITHMETIC_TYPE(half)
__ARITHMETIC_TYPE(float)
__ARITHMETIC_TYPE(double)

#undef __ARITHMETIC_TYPE

/// SFINAE helper: `enable_if<B, T>::type` names `T` only when `B` is true;
/// otherwise `type` does not exist and the enclosing overload is discarded.
template <bool B, typename T> struct enable_if {};

template <typename T> struct enable_if<true, T> {
  using type = T;
};

} // namespace hlsl

namespace dx {
namespace linalg {

/// Element encodings a Matrix can be declared with. Numeric values are spelled
/// out explicitly; keep them stable.
enum class MatrixComponentType {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,
};

namespace __detail {

/// Maps a MatrixComponentType to the HLSL scalar type used to hold it.
/// Any component type without an explicit mapping below falls back to `uint`
/// storage with IsNativeScalar == false.
template <MatrixComponentType T> struct ComponentTypeTraits {
  using Type = uint;
  static const bool IsNativeScalar = false;
};

// Helper that maps `enum_val` to the native scalar `type`. It is #undef'd
// right after its last use so it does not leak into including code.
#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type)                      \
  template <> struct ComponentTypeTraits<enum_val> {                           \
    using Type = type;                                                         \
    static const bool IsNativeScalar = true;                                   \
  };

// 16-bit component types only get native mappings when 16-bit types are
// enabled; otherwise they use the `uint` fallback above.
#if __HLSL_ENABLE_16_BIT
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::I16, int16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::U16, uint16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::F16, float16_t)
#endif
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::I32, int32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::U32, uint32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::F32, float)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::I64, int64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::U64, uint64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(MatrixComponentType::F64, double)

#undef __MATRIX_SCALAR_COMPONENT_MAPPING

} // namespace __detail

/// Role of a matrix in a multiply: left operand (A), right operand (B), or
/// accumulator/result.
enum class MatrixUse {
  A = 0,
  B = 1,
  Accumulator = 2,
};

/// Scope a matrix is declared at; several members below are only available
/// for MatrixScope::Wave.
enum class MatrixScope {
  Thread = 0,
  Wave = 1,
};

/// Element-wise operation selector for Matrix::ApplyUnaryOperation.
enum class UnaryOperation {
  NOp = 0,
  Negate = 1,
  Abs = 2,
  Sin = 3,
  Cos = 4,
  Tan = 5,
};

/// An M x N matrix of `ComponentTy` elements with the given Use and Scope.
///
/// Members that only make sense for one Use/Scope combination are gated with
/// hlsl::enable_if. The extra `UseLocal == Use` term makes each condition
/// depend on the member template's own parameter. A false condition therefore
/// removes the overload instead of becoming a hard error when the class is
/// instantiated.
///
/// NOTE(review): no access specifiers are written. Under C++ `class` rules
/// every member would be private, yet the examples below call Matrix::Load
/// from outside. This relies on HLSL member accessibility; confirm if the
/// header is ever consumed by a C++ front end.
template <MatrixComponentType ComponentTy, uint M, uint N, MatrixUse Use,
          MatrixScope Scope>
class Matrix {
  using ElementType =
      typename __detail::ComponentTypeTraits<ComponentTy>::Type;
  // If this isn't a native scalar, we have an 8-bit type, so we have 4 elements
  // packed in each scalar value.
  // NOTE(review): the non-native fallback also covers I1, the SNorm/UNorm
  // types, and I16/U16/F16 when __HLSL_ENABLE_16_BIT is off. All of these get
  // ElementsPerScalar == 4 here; confirm that is intended for non-8-bit types.
  static const uint ElementsPerScalar =
      __detail::ComponentTypeTraits<ComponentTy>::IsNativeScalar ? 1 : 4;
  // Number of storage scalars needed for M (resp. N) elements, rounded up.
  static const uint MScalars =
      (M + (ElementsPerScalar - 1)) / ElementsPerScalar;
  static const uint NScalars =
      (N + (ElementsPerScalar - 1)) / ElementsPerScalar;

  /// Converts to another component type and, optionally, another Use.
  template <MatrixComponentType NewCompTy, MatrixUse NewUse = Use>
  Matrix<NewCompTy, M, N, NewUse, Scope> cast();

  // Element-wise operations
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  operator+=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  operator-=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  operator*=(T);
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  operator/=(T);

  // Apply a unary operation to each element.
  template <UnaryOperation Op> Matrix ApplyUnaryOperation();

  /// Builds a matrix from a single arithmetic value.
  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  Splat(T Val);

  /// Loads from a (RW)ByteAddressBuffer. Align defaults to the storage scalar
  /// size (sizeof(ElementType)).
  static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
                     bool ColMajor, uint Align = sizeof(ElementType));
  static Matrix Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
                     bool ColMajor, uint Align = sizeof(ElementType));
  /// Loads from an array of arithmetic elements (intended for groupshared
  /// memory, per the inline comment).
  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  Load(/*groupshared*/ T Arr[], uint StartIdx, uint Stride, bool ColMajor);

  /// Wave-scope A matrices: build from an MScalars-wide per-thread vector.
  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::A && Scope == MatrixScope::Wave &&
                               UseLocal == Use,
                           Matrix>::type
  FromThreadVectors(vector<ElementType, MScalars>);
  /// Wave-scope B matrices: build from an NScalars-wide per-thread vector.
  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::B && Scope == MatrixScope::Wave &&
                               UseLocal == Use,
                           Matrix>::type
  FromThreadVectors(vector<ElementType, NScalars>);

  /// Stores to an RWByteAddressBuffer; Align defaults to sizeof(ElementType).
  void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
             bool ColMajor, uint Align = sizeof(ElementType));
  /// Stores to an array of arithmetic elements (intended for groupshared
  /// memory, per the inline comment).
  template <typename T>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, void>::type
  Store(/*groupshared*/ T Arr[], uint StartIdx, uint Stride, bool ColMajor);

  // Extract the thread-specific vector.
  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::A && Scope == MatrixScope::Wave &&
                               UseLocal == Use,
                           vector<ElementType, MScalars>>::type
  GetThreadVector(uint Index = 0);
  template <MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::B && Scope == MatrixScope::Wave &&
                               UseLocal == Use,
                           vector<ElementType, NScalars>>::type
  GetThreadVector(uint Index = 0);

  /// Wave-scope accumulators only: combine (M x K) A and (K x N) B operands
  /// into this M x N accumulator.
  template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
            MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                               Scope == MatrixScope::Wave && UseLocal == Use,
                           void>::type
  MultiplyAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                     const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);
  template <MatrixComponentType LHSTy, MatrixComponentType RHSTy, uint K,
            MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator &&
                               Scope == MatrixScope::Wave && UseLocal == Use,
                           void>::type
  SumAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);

  // Cooperative Vector outer product accumulate.
  // Available to accumulators at any Scope (no Scope condition here).
  template <typename T, MatrixUse UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  OuterProductAccumulate(const vector<T, M>, const vector<T, N>);
};

MatrixUse AccumulatorLayout();

/// Wave-scope A (M x K) times B (K x N) producing an M x N accumulator.
/// OutTy appears only in the return type, so it cannot be deduced and must be
/// given explicitly, e.g. Multiply<MatrixComponentType::F32>(A, B). This lets
/// the result component type differ from the operand types.
template <MatrixComponentType OutTy, MatrixComponentType ATy,
          MatrixComponentType BTy, uint M, uint N, uint K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::Wave>);

/// Same as above when A, B and the result all share component type T, which
/// is deduced from the operands.
template <MatrixComponentType T, uint M, uint N, uint K>
Matrix<T, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::Wave>);

// Cooperative Vector Replacement API
// Cooperative Vector operates on per-thread vectors multiplying against B
// matrices.
/// An M-element vector times an M x K B matrix yields a K-element vector.
/// OutputElTy must be given explicitly; the rest is deduced.
template <typename OutputElTy, typename InputElTy, uint M, uint K,
          MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K> Multiply(vector<InputElTy, M>,
                               Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);

/// Same as the vector Multiply above, with an extra K-element bias vector
/// argument.
template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M,
          uint K, MatrixComponentType MatrixDT, MatrixScope Scope>
vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
                                  Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
                                  vector<BiasElTy, K>);

} // namespace linalg
} // namespace dx

RWByteAddressBuffer B : register(u0);

// Example: wave-scope matrix-matrix multiply.
void WaveMatrixExample() {
  using namespace dx::linalg;
  using MatrixATy =
      Matrix<MatrixComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
  using MatrixAccumTy = Matrix<MatrixComponentType::F16, 8, 16,
                               MatrixUse::Accumulator, MatrixScope::Wave>;
  using MatrixAccum32Ty = Matrix<MatrixComponentType::F32, 8, 16,
                                 MatrixUse::Accumulator, MatrixScope::Wave>;

  // NOTE(review): both operands load from offset 0, and the strides (8 * 4,
  // 32 * 4) do not obviously match an F16 row-major row size. Confirm the
  // intended Stride units before copying this example.
  MatrixATy MatA = MatrixATy::Load(B, 0, 8 * 4, false);
  MatrixBTy MatB = MatrixBTy::Load(B, 0, 32 * 4, false);

  // Same component type everywhere: resolves to the deduced-T overload.
  MatrixAccumTy Accum = Multiply(MatA, MatB);
  // Explicit F32 result: resolves to the explicit-OutTy overload.
  MatrixAccum32Ty Accum32 = Multiply<MatrixComponentType::F32>(MatA, MatB);
}

// Example: per-thread vector times wave-scope B matrix (cooperative-vector
// style).
void CoopVec() {
  using namespace dx::linalg;
  using MatrixBTy =
      Matrix<MatrixComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;

  vector<float16_t, 32> Vec = (vector<float16_t, 32>)0;
  MatrixBTy MatB = MatrixBTy::Load(B, 0, 32 * 4, false);
  // 32-element input times a 32 x 16 matrix gives a 16-element result.
  vector<float16_t, 16> Accum = Multiply<float16_t>(Vec, MatB);
}
Become a Patron
Sponsor on GitHub
Donate via PayPal
Compiler Explorer Shop
Source on GitHub
Mailing list
Installed libraries
Wiki
Report an issue
How it works
Contact the author
CE on Mastodon
CE on Bluesky
Statistics
Changelog
Version tree