Thanks for using Compiler Explorer
Sponsors
Jakt
C++
Ada
Algol68
Analysis
Android Java
Android Kotlin
Assembly
C
C3
Carbon
C with Coccinelle
C++ with Coccinelle
C++ (Circle)
CIRCT
Clean
Clojure
CMake
CMakeScript
COBOL
C++ for OpenCL
MLIR
Cppx
Cppx-Blue
Cppx-Gold
Cpp2-cppfront
Crystal
C#
CUDA C++
D
Dart
Elixir
Erlang
Fortran
F#
GLSL
Go
Haskell
HLSL
Helion
Hook
Hylo
IL
ispc
Java
Julia
Kotlin
LLVM IR
LLVM MIR
Modula-2
Mojo
Nim
Numba
Nix
Objective-C
Objective-C++
OCaml
Odin
OpenCL C
Pascal
Pony
PTX
Python
Racket
Raku
Ruby
Rust
Sail
Snowball
Scala
Slang
Solidity
Spice
SPIR-V
Swift
LLVM TableGen
Toit
Triton
TypeScript Native
V
Vala
Visual Basic
Vyper
WASM
Yul (Solidity IR)
Zig
JavaScript
GIMPLE
Ygen
sway
hlsl source #1
Output
Compile to binary object
Link to binary
Execute the code
Intel asm syntax
Demangle identifiers
Verbose demangling
Filters
Unused labels
Library functions
Directives
Comments
Horizontal whitespace
Debug intrinsics
Compiler
AMD RGA 2.10
AMD RGA 2.11
AMD RGA 2.12
AMD RGA 2.13
AMD RGA 2.9.1
Clang (assertions trunk)
Clang (trunk)
DXC (trunk)
DXC 1.6.2112
DXC 1.7.2207
DXC 1.7.2212
DXC 1.7.2308
DXC 1.8.2306-preview
DXC 1.8.2403
DXC 1.8.2403.1
DXC 1.8.2403.2
DXC 1.8.2405
DXC 1.8.2407
DXC 1.8.2502
DXC 1.8.2505
DXC 1.8.2505.1
RGA 2.6.1 (DXC 1.6.2112)
RGA 2.6.1 (DXC 1.7.2207)
RGA 2.6.2 (DXC 1.6.2112)
RGA 2.6.2 (DXC 1.7.2207)
RGA 2.6.2 (DXC trunk)
RGA 2.9.0 (DXC trunk)
Options
Source code
// Minimal type-traits shims (is_arithmetic / enable_if) for HLSL, plus the
// dx::linalg enum vocabulary and the component-type trait template.
// NOTE(review): line structure restored — the pasted source had preprocessor
// directives and macro continuations collapsed onto single lines, which does
// not preprocess; token content is unchanged.
namespace hlsl {

// DXC and Clang disagree on the preferred size type for these declarations.
#ifdef __hlsl_dx_compiler
#define SIZE_TYPE int
#else
#define SIZE_TYPE uint
#endif

// Fallback: a type is not arithmetic unless explicitly mapped below.
template <typename T> struct is_arithmetic {
  static const bool value = false;
};

#define __ARITHMETIC_TYPE(type)                                                \
  template <> struct is_arithmetic<type> {                                     \
    static const bool value = true;                                            \
  };

#if __HLSL_ENABLE_16_BIT
__ARITHMETIC_TYPE(uint16_t)
__ARITHMETIC_TYPE(int16_t)
#endif

__ARITHMETIC_TYPE(uint)
__ARITHMETIC_TYPE(int)
__ARITHMETIC_TYPE(uint64_t)
__ARITHMETIC_TYPE(int64_t)
__ARITHMETIC_TYPE(half)
__ARITHMETIC_TYPE(float)
__ARITHMETIC_TYPE(double)

// SFINAE helper, mirroring std::enable_if.
template <bool B, typename T> struct enable_if {};
template <typename T> struct enable_if<true, T> {
  using type = T;
};

} // namespace hlsl

namespace dx {
namespace linalg {

// Element interpretations for matrix/vector data. Wrapped in structs because
// HLSL here uses plain enums scoped by an enclosing struct rather than
// `enum class`.
struct ComponentType {
  enum ComponentEnum {
    Invalid = 0,
    I1 = 1,
    I8 = 2,
    U8 = 3,
    I16 = 4,
    U16 = 5,
    I32 = 6,
    U32 = 7,
    I64 = 8,
    U64 = 9,
    F16 = 10,
    F32 = 11,
    F64 = 12,
    SNormF16 = 13,
    UNormF16 = 14,
    SNormF32 = 15,
    UNormF32 = 16,
    SNormF64 = 17,
    UNormF64 = 18,
    F8_E4M3 = 19,
    F8_E5M2 = 20,
  };
};
using ComponentEnum = ComponentType::ComponentEnum;

// Role of a matrix operand in a multiply: left (A), right (B), or accumulator.
struct MatrixUse {
  enum MatrixUseEnum {
    A = 0,
    B = 1,
    Accumulator = 2,
  };
};
using MatrixUseEnum = MatrixUse::MatrixUseEnum;

// Execution scope the matrix is distributed across.
struct MatrixScope {
  enum MatrixScopeEnum {
    Thread = 0,
    Wave = 1,
    ThreadGroup = 2,
  };
};
using MatrixScopeEnum = MatrixScope::MatrixScopeEnum;

// Memory layouts accepted by Load/Store/Accumulate.
struct MatrixLayout {
  enum MatrixLayoutEnum {
    RowMajor = 0,
    ColMajor = 1,
    MulOptimal = 2,
    OuterProductOptimal = 3,
  };
};
using MatrixLayoutEnum = MatrixLayout::MatrixLayoutEnum;

namespace __detail {

// Maps a ComponentEnum to its in-register representation. The primary
// template is the packed fallback: elements stored 4-per-uint scalar.
template <ComponentEnum T> struct ComponentTypeTraits {
  using Type = uint;
  static const bool IsNativeScalar = false;
  static const uint ElementsPerScalar = 4;
};

// Specialization generator for component types that map 1:1 onto a native
// scalar type.
#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type)                      \
  template <> struct ComponentTypeTraits<enum_val> {                           \
    using Type = type;                                                         \
    static const bool IsNativeScalar = true;                                   \
    static const uint ElementsPerScalar = 1;                                   \
  };

#if __HLSL_ENABLE_16_BIT
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I16, int16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U16, uint16_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F16, float16_t)
#endif
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I32, int32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U32, uint32_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F32, float)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I64, int64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U64, uint64_t)
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F64, double)

} // namespace __detail

// A reference to a vector stored in a raw buffer, interpreted as DimA
// elements of type ElementType starting at Offset.
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
  ByteAddressBuffer Buf;
  uint Offset;
};

// A vector paired with a reinterpretation tag (DT). Size accounts for
// packed component types: ElementsPerScalar logical elements per stored
// scalar, times N stored scalars.
template <typename T, int N, ComponentEnum DT> struct InterpretedVector {
  vector<T, N> Data;
  static const ComponentEnum Interpretation = DT;
  static const SIZE_TYPE Size =
      __detail::ComponentTypeTraits<DT>::ElementsPerScalar * N;
};

// Convenience factory so callers only spell out the interpretation enum;
// T and N are deduced from the argument.
template <ComponentEnum DT, typename T, int N>
InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
  InterpretedVector<T, N, DT> IV = {Vec};
  return IV;
}

// An MxN matrix of ComponentTy elements with a given use and scope.
// Declarations only — implementations are presumably compiler-provided
// intrinsics (none are defined in this file).
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
          MatrixUseEnum Use, MatrixScopeEnum Scope>
class Matrix {
  using ElementType =
      typename __detail::ComponentTypeTraits<ComponentTy>::Type;

  // If this isn't a native scalar, we have an 8-bit type, so we have 4 elements
  // packed in each scalar value.
static const uint ElementsPerScalar =
      __detail::ComponentTypeTraits<ComponentTy>::ElementsPerScalar;

  // Reinterpret this matrix with a different component type and (optionally)
  // a different use, preserving dimensions and scope.
  template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use>
  Matrix<NewCompTy, M, N, NewUse, Scope> Cast();

  // Fill every element with Val; only enabled for arithmetic T.
  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  Splat(T Val);

  // Load from a (RW)ByteAddressBuffer. Stride is the distance between rows
  // (or columns, depending on Layout) in bytes; Align defaults to the element
  // size.
  static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayoutEnum Layout, uint Align = sizeof(ElementType));
  static Matrix Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayoutEnum Layout, uint Align = sizeof(ElementType));

  // Load from a groupshared array (the qualifier is commented out because it
  // cannot appear in this position in today's compilers).
  template <typename T>
  static typename hlsl::enable_if<hlsl::is_arithmetic<T>::value, Matrix>::type
  Load(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
       MatrixLayoutEnum Layout);

  // Per-element access: number of elements this thread owns, the (row, col)
  // of a given element index, and element get/set.
  uint Length();
  uint2 GetCoordinate(uint);
  ElementType Get(uint);
  void Set(uint, ElementType);

  void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
             MatrixLayoutEnum Layout, uint Align = sizeof(ElementType));

  // Store to a groupshared array; only enabled when the destination is large
  // enough to hold all M*N elements after packing.
  template <typename T, SIZE_TYPE Size>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               (M * N / ElementsPerScalar >= Size),
                           void>::type
  Store(/*groupshared*/ T Arr[Size], uint StartIdx, uint Stride,
        MatrixLayoutEnum Layout);

  // Accumulate methods
  // The UseLocal dance keeps the enable_if dependent so SFINAE applies:
  // these are only callable on Accumulator-use matrices.
  template <MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  Accumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
             MatrixLayoutEnum Layout, uint Align = sizeof(ElementType));

  template <typename T, MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<hlsl::is_arithmetic<T>::value &&
                               Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  Accumulate(/*groupshared*/ T Arr[], uint StartIdx, uint Stride,
             MatrixLayoutEnum Layout);

  // this += A(MxK) * B(KxN); only on Accumulator-use matrices.
  template <ComponentEnum LHSTy, ComponentEnum RHSTy, uint K,
            MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  MultiplyAccumulate(const Matrix<LHSTy, M, K,
MatrixUse::A, Scope>,
                     const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);

  // Element-wise sum variant of the accumulate-from-product operation.
  template <ComponentEnum LHSTy, ComponentEnum RHSTy, uint K,
            MatrixUseEnum UseLocal = Use>
  typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
                           void>::type
  SumAccumulate(const Matrix<LHSTy, M, K, MatrixUse::A, Scope>,
                const Matrix<RHSTy, K, N, MatrixUse::B, Scope>);
};

// Thread-scope Matrices are read-only. Using a template partial specialization
// for this simplifies the SFINAE-foo above.
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
          MatrixUseEnum Use>
class Matrix<ComponentTy, M, N, Use, MatrixScope::Thread> {
  using ElementType =
      typename __detail::ComponentTypeTraits<ComponentTy>::Type;

  static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
                     MatrixLayoutEnum Layout, uint Align = sizeof(ElementType));

  void Accumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
                  MatrixLayoutEnum Layout, uint Align = sizeof(ElementType));
};

// Queries the layout the implementation uses for accumulators.
MatrixUseEnum AccumulatorLayout();

// Wave-scope matrix multiply: A(MxK) * B(KxN) -> Accumulator(MxN).
// First overload allows a distinct output component type; second infers it.
template <ComponentEnum OutTy, ComponentEnum ATy, ComponentEnum BTy,
          SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::Wave>);

template <ComponentEnum T, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<T, M, N, MatrixUse::Accumulator, MatrixScope::Wave>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::Wave>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::Wave>);

// ThreadGroup-scope equivalents of the two overloads above.
template <ComponentEnum OutTy, ComponentEnum ATy, ComponentEnum BTy,
          SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<OutTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup>
Multiply(const Matrix<ATy, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
         const Matrix<BTy, K, N, MatrixUse::B, MatrixScope::ThreadGroup>);

template <ComponentEnum T, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K>
Matrix<T, M, N, MatrixUse::Accumulator,
MatrixScope::ThreadGroup>
Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
         const Matrix<T, K, N, MatrixUse::B, MatrixScope::ThreadGroup>);

// Cooperative Vector Replacement API

// Cooperative Vector operates on per-thread vectors multiplying against B
// matrices with thread scope.
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
          ComponentEnum MatrixDT, MatrixScopeEnum Scope>
vector<OutputElTy, K> Multiply(vector<InputElTy, M>,
                               Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);

// vec(1xM) * B(MxK) + bias(K), bias held in a register vector.
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
          SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT,
          MatrixScopeEnum Scope>
vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
                                  Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
                                  vector<BiasElTy, K>);

// Overload taking an InterpretedVector input; enabled only when its
// unpacked logical size matches the matrix row count M.
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
          typename BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K,
          ComponentEnum MatrixDT, MatrixScopeEnum Scope>
typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
                             M,
                         vector<OutputElTy, K> >::type
MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>,
            Matrix<MatrixDT, M, K, MatrixUse::B, Scope>, vector<BiasElTy, K>);

// Thread-scope-only overloads where the bias lives in memory (VectorRef).
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
          SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
                                  Matrix<MatrixDT, M, K, MatrixUse::B,
                                         MatrixScope::Thread>,
                                  VectorRef<BiasElTy, K>);

template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
          ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K,
          ComponentEnum MatrixDT>
typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
                             M,
                         vector<OutputElTy, K> >::type
MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>,
            Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
            VectorRef<BiasElTy, K>);

// Outer product functions
template <ComponentEnum OutTy, MatrixScopeEnum Scope, typename InputElTy,
SIZE_TYPE M, SIZE_TYPE N>
Matrix<OutTy, M, N, MatrixUse::Accumulator, Scope>
OuterProduct(vector<InputElTy, M>, vector<InputElTy, N>);

} // namespace linalg
} // namespace dx

RWByteAddressBuffer B : register(u0);

// Example: wave-scope 8x32 * 32x16 matrix multiply with per-element
// post-processing of the B operand.
void WaveMatrixExample() {
  using namespace dx::linalg;

  using MatrixATy =
      Matrix<ComponentType::F16, 8, 32, MatrixUse::A, MatrixScope::Wave>;
  using MatrixBTy =
      Matrix<ComponentType::F16, 32, 16, MatrixUse::B, MatrixScope::Wave>;
  using MatrixAccumTy = Matrix<ComponentType::F16, 8, 16,
                               MatrixUse::Accumulator, MatrixScope::Wave>;
  using MatrixAccum32Ty = Matrix<ComponentType::F32, 8, 16,
                                 MatrixUse::Accumulator, MatrixScope::Wave>;

  MatrixATy MatA = MatrixATy::Load(
      B, 0, /* Row stride = number of columns * element size */ 32 * 4,
      MatrixLayout::RowMajor);
  MatrixBTy MatB = MatrixBTy::Load(
      B, 0, /* Row stride = number of columns * element size */ 16 * 4,
      MatrixLayout::RowMajor);

  for (uint I = 0; I < MatB.Length(); ++I) {
    uint2 Pos = MatB.GetCoordinate(I);
    // Run `tanh` on all but the diagonal components for no reasonable reason.
    if (Pos.x != Pos.y) {
      float16_t Val = MatB.Get(I);
      MatB.Set(I, tanh(Val));
    }
  }

  // Same product, accumulated in f16 and then in f32.
  MatrixAccumTy Accum = Multiply(MatA, MatB);
  MatrixAccum32Ty Accum32 = Multiply<ComponentType::F32>(MatA, MatB);
}

ByteAddressBuffer MBuf : register(t0);

// Example: cooperative-vector style chained vector-matrix multiplies with
// register, memory, and interpreted-vector bias variants.
void CoopVec() {
  using namespace dx::linalg;

  using MatrixBTy =
      Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Thread>;

  vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
  MatrixBTy MatB = MatrixBTy::Load(
      MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4,
      MatrixLayout::RowMajor);

  vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB);
  vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
  vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB, NullBias);
  VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf, /*start offset*/ 4096};
  vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB, MemBias);

  // Clang doesn't yet support packed types.
#ifdef __hlsl_dx_compiler
  // DXC-only: packed 8-bit input reinterpreted as F8_E4M3 via
  // MakeInterpretedVector.
  vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
  vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
      MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, MemBias);
  vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
      MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, NullBias);
#endif
}

RWByteAddressBuffer Buf : register(u1);

// Example: outer product of a 16-vector and an 8-vector accumulated into a
// thread-scope 16x8 accumulator, then written back to a buffer.
void OuterProdAccum() {
  using namespace dx::linalg;

  using MatrixAccumTy = Matrix<ComponentType::F16, 16, 8,
                               MatrixUse::Accumulator, MatrixScope::Thread>;

  vector<float16_t, 16> VecA = (vector<float16_t, 16>)0;
  vector<float16_t, 8> VecB = (vector<float16_t, 8>)0;

  MatrixAccumTy MatAcc =
      OuterProduct<ComponentType::F16, MatrixScope::Thread>(VecA, VecB);

  // Stride 0: layout is OuterProductOptimal, so no row stride applies.
  // NOTE(review): assumption — confirm against the linalg spec.
  MatAcc.Accumulate(Buf, 0, 0, MatrixLayout::OuterProductOptimal);
}
Become a Patron
Sponsor on GitHub
Donate via PayPal
Compiler Explorer Shop
Source on GitHub
Mailing list
Installed libraries
Wiki
Report an issue
How it works
Contact the author
CE on Mastodon
CE on Bluesky
Statistics
Changelog
Version tree