diff --git a/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.hlsli b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.hlsli index 7bd026a718..8a50f0690a 100644 --- a/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.hlsli +++ b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.hlsli @@ -15,7 +15,6 @@ #ifndef VULKAN_HLSL_SPV_KHR_COOPERATIVE_MATRIX_H_ #define VULKAN_HLSL_SPV_KHR_COOPERATIVE_MATRIX_H_ - // TODO: Add a macro to HLSL to be able to check the Vulkan version being // targeted. @@ -24,63 +23,69 @@ namespace vk { // TODO: Move these defines to a new header file for defines. enum CooperativeMatrixUse { - CooperativeMatrixUseMatrixAKHR = 0, - CooperativeMatrixUseMatrixBKHR = 1, - CooperativeMatrixUseMatrixAccumulatorKHR = 2, - CooperativeMatrixUseMax = 0x7fffffff, + CooperativeMatrixUseMatrixAKHR = 0, + CooperativeMatrixUseMatrixBKHR = 1, + CooperativeMatrixUseMatrixAccumulatorKHR = 2, + CooperativeMatrixUseMax = 0x7fffffff, }; enum CooperativeMatrixLayout { - CooperativeMatrixLayoutRowMajorKHR = 0, - CooperativeMatrixLayoutColumnMajorKHR = 1, - CooperativeMatrixLayoutRowBlockedInterleavedARM = 4202, - CooperativeMatrixLayoutColumnBlockedInterleavedARM = 4203, - CooperativeMatrixLayoutMax = 0x7fffffff, + CooperativeMatrixLayoutRowMajorKHR = 0, + CooperativeMatrixLayoutColumnMajorKHR = 1, + CooperativeMatrixLayoutRowBlockedInterleavedARM = 4202, + CooperativeMatrixLayoutColumnBlockedInterleavedARM = 4203, + CooperativeMatrixLayoutMax = 0x7fffffff, }; enum CooperativeMatrixOperandsMask { - CooperativeMatrixOperandsMaskNone = 0, - CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001, - CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002, - CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004, - CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008, - CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010, + CooperativeMatrixOperandsMaskNone = 0, + CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001, + CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002, + CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004, + CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008, + CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010, }; enum MemoryAccessMask { - MemoryAccessMaskNone = 0, - MemoryAccessVolatileMask = 0x00000001, - MemoryAccessAlignedMask = 0x00000002, - MemoryAccessNontemporalMask = 0x00000004, - MemoryAccessMakePointerAvailableMask = 0x00000008, - MemoryAccessMakePointerAvailableKHRMask = 0x00000008, - MemoryAccessMakePointerVisibleMask = 0x00000010, - MemoryAccessMakePointerVisibleKHRMask = 0x00000010, - MemoryAccessNonPrivatePointerMask = 0x00000020, - MemoryAccessNonPrivatePointerKHRMask = 0x00000020, - MemoryAccessAliasScopeINTELMaskMask = 0x00010000, - MemoryAccessNoAliasINTELMaskMask = 0x00020000, + MemoryAccessMaskNone = 0, + MemoryAccessVolatileMask = 0x00000001, + MemoryAccessAlignedMask = 0x00000002, + MemoryAccessNontemporalMask = 0x00000004, + MemoryAccessMakePointerAvailableMask = 0x00000008, + MemoryAccessMakePointerAvailableKHRMask = 0x00000008, + MemoryAccessMakePointerVisibleMask = 0x00000010, + MemoryAccessMakePointerVisibleKHRMask = 0x00000010, + MemoryAccessNonPrivatePointerMask = 0x00000020, + MemoryAccessNonPrivatePointerKHRMask = 0x00000020, + MemoryAccessAliasScopeINTELMaskMask = 0x00010000, + MemoryAccessNoAliasINTELMaskMask = 0x00020000, }; enum Scope { - ScopeCrossDevice = 0, - ScopeDevice = 1, - ScopeWorkgroup = 2, - ScopeSubgroup = 3, - ScopeInvocation = 4, - ScopeQueueFamily = 5, - ScopeQueueFamilyKHR = 5, - ScopeShaderCallKHR = 6, - ScopeMax = 0x7fffffff, + ScopeCrossDevice = 0, + ScopeDevice = 1, + ScopeWorkgroup = 2, + ScopeSubgroup = 3, + ScopeInvocation = 4, + ScopeQueueFamily = 5, + ScopeQueueFamilyKHR = 5, + ScopeShaderCallKHR = 6, + ScopeMax = 0x7fffffff, }; namespace khr { +#define SPV_KHR_CooperativeMatrix \ + vk::ext_extension("SPV_KHR_cooperative_matrix"), \ + vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022) + template class CooperativeMatrix { CooperativeMatrix negate(); CooperativeMatrix operator+(CooperativeMatrix other); CooperativeMatrix operator-(CooperativeMatrix other); + CooperativeMatrix operator*(CooperativeMatrix other); + CooperativeMatrix operator/(CooperativeMatrix other); CooperativeMatrix operator*(ComponentType scalar); void StoreRowMajor(RWStructuredBuffer data, uint32_t index); @@ -93,6 +98,15 @@ class CooperativeMatrix { static CooperativeMatrix LoadColumnMajor(BufferType buffer, uint32_t index); static uint32_t GetLength(); + + static const bool hasSignedIntegerComponentType = + (ComponentType(0) - ComponentType(1) < ComponentType(0)); + using SpirvMatrixType = vk::SpirvOpaqueType< + /* OpTypeCooperativeMatrixKHR */ 4456, ComponentType, + vk::integral_constant, vk::integral_constant, + vk::integral_constant, vk::integral_constant>; + + SpirvMatrixType _matrix; }; template diff --git a/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl index 18790e44ab..c0c680a2ba 100644 --- a/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl +++ b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl @@ -12,52 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define SPV_KHR_CooperativeMatrix \ - vk::ext_extension("SPV_KHR_cooperative_matrix"), \ - vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022) - -#define DECLARE_UNARY_OP(name, opcode) \ - template \ - [[vk::ext_instruction(opcode), \ - SPV_KHR_CooperativeMatrix]] ResultType __builtin_spv_##name(ResultType a) - -DECLARE_UNARY_OP(SNegate, 126); -DECLARE_UNARY_OP(FNegate, 127); - -#undef DECLARY_UNARY_OP - -#define DECLARE_BINOP(name, opcode) \ - template \ - [[vk::ext_instruction(opcode), \ - SPV_KHR_CooperativeMatrix]] ResultType __builtin_spv_##name(ResultType a, \ - ResultType b) - -DECLARE_BINOP(IAdd, 128); -DECLARE_BINOP(FAdd, 129); -DECLARE_BINOP(ISub, 130); -DECLARE_BINOP(FSub, 131); -DECLARE_BINOP(IMul, 132); -DECLARE_BINOP(FMul, 133); -DECLARE_BINOP(UDiv, 134); -DECLARE_BINOP(SDiv, 135); -DECLARE_BINOP(FDiv, 136); - -#undef DECLARE_BINOP +#include "arithmetic_selector.hlsli" + template -[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143), - SPV_KHR_CooperativeMatrix]] ResultType +[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType __builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b); -// Type-Declaration Instructions - -// TODO: make sure scope, rows, and cols can be specialization constants -template -using __builtin_spv_CooperativeMatrixKHR = vk::SpirvOpaqueType< - /* OpTypeCooperativeMatrixKHR */ 4456, ComponentType, - vk::integral_constant, vk::integral_constant, - vk::integral_constant, vk::integral_constant >; - // Define the load and store instructions template [[vk::ext_instruction( @@ -101,92 +62,101 @@ __builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b, [[vk::ext_literal]] int operands); namespace vk { namespace khr { -#define COOPERATIVE_MATRIX(ComponentType, Negate, Add, Sub, \ - SIGNED_INTEGER_TYPE) \ - template \ - class CooperativeMatrix { \ - using SpirvMatrixType = \ - __builtin_spv_CooperativeMatrixKHR; \ - \ - CooperativeMatrix negate() { \ - CooperativeMatrix result; \ - result._matrix = Negate(_matrix); \ - return result; \ - } \ - CooperativeMatrix operator+(CooperativeMatrix other) { \ - CooperativeMatrix result; \ - result._matrix = Add(_matrix, other._matrix); \ - return result; \ - } \ - CooperativeMatrix operator-(CooperativeMatrix other) { \ - CooperativeMatrix result; \ - result._matrix = Sub(_matrix, other._matrix); \ - return result; \ - } \ - CooperativeMatrix operator*(ComponentType scalar) { \ - CooperativeMatrix result; \ - result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar); \ - return result; \ - } \ - void StoreRowMajor(RWStructuredBuffer data, \ - uint32_t index) { \ - __builtin_spv_CooperativeMatrixStoreKHR( \ - data[index], _matrix, CooperativeMatrixLayoutRowMajorKHR); \ - } \ - void StoreColumnMajor(RWStructuredBuffer data, \ - uint32_t index) { \ - __builtin_spv_CooperativeMatrixStoreKHR( \ - data[index], _matrix, CooperativeMatrixLayoutColumnMajorKHR); \ - } \ - \ - template \ - static CooperativeMatrix LoadRowMajor(BufferType buffer, uint32_t index) { \ - CooperativeMatrix result; \ - result._matrix = \ - __builtin_spv_CooperativeMatrixLoadKHR( \ - buffer[index], CooperativeMatrixLayoutRowMajorKHR); \ - return result; \ - } \ - \ - template \ - static CooperativeMatrix LoadColumnMajor(BufferType buffer, \ - uint32_t index) { \ - CooperativeMatrix result; \ - result._matrix = \ - __builtin_spv_CooperativeMatrixLoadKHR( \ - buffer[index], CooperativeMatrixLayoutColumnMajorKHR); \ - return result; \ - } \ - \ - static uint GetLength() { \ - return __builtin_spv_CooperativeMatrixLengthKHR(); \ - } \ - \ - static const bool isSignedInteger = SIGNED_INTEGER_TYPE; \ - SpirvMatrixType _matrix; \ - }; - -COOPERATIVE_MATRIX(half, __builtin_spv_FNegate, __builtin_spv_FAdd, - __builtin_spv_FSub, false); -COOPERATIVE_MATRIX(float, __builtin_spv_FNegate, __builtin_spv_FAdd, - __builtin_spv_FSub, false); -COOPERATIVE_MATRIX(double, __builtin_spv_FNegate, __builtin_spv_FAdd, - __builtin_spv_FSub, false); - -// TODO: Need to check for a 16-bit-type enabled -// COOPERATIVE_MATRIX(int16_t, __builtin_spv_SNegate, __builtin_spv_IAdd, -// __builtin_spv_ISub, true); -COOPERATIVE_MATRIX(int32_t, __builtin_spv_SNegate, __builtin_spv_IAdd, - __builtin_spv_ISub, true); -COOPERATIVE_MATRIX(int64_t, __builtin_spv_SNegate, __builtin_spv_IAdd, - __builtin_spv_ISub, true); -// COOPERATIVE_MATRIX(uint16_t, __builtin_spv_SNegate, __builtin_spv_IAdd, -// __builtin_spv_ISub, false); -COOPERATIVE_MATRIX(uint32_t, __builtin_spv_SNegate, __builtin_spv_IAdd, - __builtin_spv_ISub, false); -COOPERATIVE_MATRIX(uint64_t, __builtin_spv_SNegate, __builtin_spv_IAdd, - __builtin_spv_ISub, false); + +template +CooperativeMatrix +CooperativeMatrix::negate() { + CooperativeMatrix result; + result._matrix = util::ArithmeticSelector::Negate(_matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator+( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = + util::ArithmeticSelector::Add(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator-( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = util::ArithmeticSelector::Sub(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator*( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = util::ArithmeticSelector::Mul(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator/( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = util::ArithmeticSelector::Div(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator*( + ComponentType scalar) { + CooperativeMatrix result; + result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar); + return result; +} + +template +void CooperativeMatrix::StoreRowMajor( + RWStructuredBuffer data, uint32_t index) { + __builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix, + CooperativeMatrixLayoutRowMajorKHR); +} + +template +void CooperativeMatrix:: + StoreColumnMajor(RWStructuredBuffer data, uint32_t index) { + __builtin_spv_CooperativeMatrixStoreKHR( + data[index], _matrix, CooperativeMatrixLayoutColumnMajorKHR); +} + +template +template +CooperativeMatrix +CooperativeMatrix::LoadRowMajor( + BufferType buffer, uint32_t index) { + CooperativeMatrix result; + result._matrix = __builtin_spv_CooperativeMatrixLoadKHR( + buffer[index], CooperativeMatrixLayoutRowMajorKHR); + return result; +} + +template +template +CooperativeMatrix +CooperativeMatrix::LoadColumnMajor( + BufferType buffer, uint32_t index) { + CooperativeMatrix result; + result._matrix = __builtin_spv_CooperativeMatrixLoadKHR( + buffer[index], CooperativeMatrixLayoutColumnMajorKHR); + return result; +} + +template +uint CooperativeMatrix::GetLength() { + return __builtin_spv_CooperativeMatrixLengthKHR(); +} template CooperativeMatrixAccumulator @@ -203,8 +173,9 @@ cooperativeMatrixMultiplyAdd( const vk::CooperativeMatrixOperandsMask operands = (vk::CooperativeMatrixOperandsMask)( - a.isSignedInteger ? allSignedComponents - : vk::CooperativeMatrixOperandsMaskNone); + a.hasSignedIntegerComponentType + ? allSignedComponents + : vk::CooperativeMatrixOperandsMaskNone); CooperativeMatrixAccumulator result; result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR< @@ -230,7 +201,7 @@ cooperativeMatrixSaturatingMultiplyAdd( const vk::CooperativeMatrixOperandsMask operands = (vk::CooperativeMatrixOperandsMask)( - a.isSignedInteger + a.hasSignedIntegerComponentType ? allSignedComponents : vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask); CooperativeMatrixAccumulator result;