diff --git a/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.h b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.h new file mode 100644 index 0000000000..e40502bb5f --- /dev/null +++ b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.h @@ -0,0 +1,274 @@ +// Copyright (c) 2024 Google LLC +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_ +#define _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_ + +// TODO: Add a macro to HLSL to be able to check the Vulkan version being +// targeted. + +#include "vk/spirv.h" + +namespace vk { +namespace khr { + +// The base cooperative matrix class. The template arguments correspond to the +// operands in the OpTypeCooperativeMatrixKHR instruction. +template +class CooperativeMatrix { + template + CooperativeMatrix cast(); + + // Apply OpSNegate or OFNegate, depending on ComponentType, in a element by + // element manner. + CooperativeMatrix negate(); + + // Apply OpIAdd or OFAdd, depending on ComponentType, in a element by element + // manner. + CooperativeMatrix operator+(CooperativeMatrix other); + + // Apply OpISub or OFSub, depending on ComponentType, in a element by element + // manner. + CooperativeMatrix operator-(CooperativeMatrix other); + + // Apply OpIMul or OFMul, depending on ComponentType, in a element by element + // manner. + CooperativeMatrix operator*(CooperativeMatrix other); + + // Apply OpSDiv, OpUDiv or OFDiv, depending on ComponentType, in a element by + // element manner. + CooperativeMatrix operator/(CooperativeMatrix other); + + // Apply OpMatrixTimesScalar in a element by element manner. + CooperativeMatrix operator*(ComponentType scalar); + + // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to + // data using the given memory layout, stride, and memory access operands. + // `NonPrivatePointer` and `MakePointerAvailable` with the workgroup scope + // will be added to the memory access operands to make the memory coherent. + // + // This function uses a SPIR-V pointer because HLSL does not allow groupshared + // memory object to be passed by reference. The pointer is a hack to get + // around that. + // + // The layout and stride will be passed to the SPIR-V instruction as is. The + // precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + void Store(WorkgroupSpirvPointer data, uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access + // operands. + template + void Store(WorkgroupSpirvPointer data, uint32_t stride) { + Store(data, stride); + } + + // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to + // data[index] using the given memory layout, stride, and memory access + // operands. The layout and stride will be passed to the SPIR-V instruction as + // is. The precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + void Store(RWStructuredBuffer data, uint32_t index, uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access + // operands. + template + void Store(RWStructuredBuffer data, uint32_t index, uint32_t stride) { + Store(data, index, stride); + } + + // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to + // data[index] using the given memory layout, stride, and memory access + // operands. `NonPrivatePointer` and `MakePointerAvailable` with the + // QueueFamily scope will be added to the memory access operands to make the + // memory coherent. + // + // The layout and stride will be passed to the SPIR-V instruction as is. The + // precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + void CoherentStore(globallycoherent RWStructuredBuffer data, + uint32_t index, uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access operands + // template argument. + template + void CoherentStore(globallycoherent RWStructuredBuffer data, + uint32_t index, uint32_t stride) { + CoherentStore(data, index, stride); + } + + // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from + // data using the given memory layout, stride, and memory access operands. + // `NonPrivatePointer` and `MakePointerVisible` with the workgroup scope + // will be added to the memory access operands to make the memory coherent. + // + // This function uses a SPIR-V pointer because HLSL does not allow groupshared + // memory object to be passed by reference. The pointer is a hack to get + // around that. + // + // The layout and stride will be passed to the SPIR-V instruction as is. The + // precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + static CooperativeMatrix Load(WorkgroupSpirvPointer data, + uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access + // operands. + template + static CooperativeMatrix Load(WorkgroupSpirvPointer data, + uint32_t stride) { + return Load(data, stride); + } + + // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from + // data[index] using the given memory layout, stride, and memory access + // operands. + // + // The layout and stride will be passed to the SPIR-V instruction as is. The + // precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + static CooperativeMatrix Load(RWStructuredBuffer data, uint32_t index, + uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access + // operands. + template + static CooperativeMatrix Load(RWStructuredBuffer data, uint32_t index, + uint32_t stride) { + return Load(data, index, stride); + } + + // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from + // data[index] using the given memory layout, stride, and memory access + // operands. `NonPrivatePointer` and `MakePointerVisible` with the QueueFamily + // scope will be added to the memory access operands to make the memory + // coherent. + // + // + // The layout and stride will be passed to the SPIR-V instruction as is. The + // precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + static CooperativeMatrix + CoherentLoad(globallycoherent RWStructuredBuffer data, uint32_t index, + uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access operands + // template argument. + template + static CooperativeMatrix + CoherentLoad(globallycoherent RWStructuredBuffer data, uint32_t index, + uint32_t stride) { + return CoherentLoad(data, index, stride); + } + + // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from + // data[index] using the given memory layout, stride, and memory access + // operands. No memory access bits are added to the operands. Since the memory + // is readonly, there should be no need. + // + // The layout and stride will be passed to the SPIR-V instruction as is. The + // precise meaning can be found in the specification for + // SPV_KHR_cooperative_matrix. + template + static CooperativeMatrix Load(StructuredBuffer data, uint32_t index, + uint32_t stride); + + // Same as above, but uses MemoryAccessMaskNone for the memory access + // operands. + template + static CooperativeMatrix Load(StructuredBuffer data, uint32_t index, + uint32_t stride) { + return Load(data, index, stride); + } + + // Constructs a cooperative matrix with all values initialized to v. Note that + // all threads in scope must have the same value for v. + static CooperativeMatrix Splat(ComponentType v); + + // Returns the result of OpCooperativeMatrixLengthKHR on the current type. + static uint32_t GetLength(); + + // Functions to access the elements of the cooperative matrix. The index must + // be less than GetLength(). + void Set(ComponentType value, uint32_t index); + ComponentType Get(uint32_t index); + + static const bool hasSignedIntegerComponentType = + (ComponentType(0) - ComponentType(1) < ComponentType(0)); + + // clang-format off + using SpirvMatrixType = vk::SpirvOpaqueType< + /* OpTypeCooperativeMatrixKHR */ 4456, ComponentType, + vk::integral_constant, vk::integral_constant, + vk::integral_constant, vk::integral_constant >; + + [[vk::ext_extension("SPV_KHR_cooperative_matrix")]] + [[vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)]] + [[vk::ext_capability(/* VulkanMemoryModel */ 5345)]] + SpirvMatrixType _matrix; + // clang-format on +}; + +// Cooperative matrix that can be used in the "a" position of a multiply add +// instruction (r = (a * b) + c). +template +using CooperativeMatrixA = + CooperativeMatrix; + +// Cooperative matrix that can be used in the "b" position of a multiply add +// instruction (r = (a * b) + c). +template +using CooperativeMatrixB = + CooperativeMatrix; + +// Cooperative matrix that can be used in the "r" and "c" position of a multiply +// add instruction (r = (a * b) + c). +template +using CooperativeMatrixAccumulator = + CooperativeMatrix; + +// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and +// c. The cooperative matrix operands are inferred, with the +// SaturatingAccumulationKHR bit not set. +template +CooperativeMatrixAccumulator +cooperativeMatrixMultiplyAdd( + CooperativeMatrixA a, + CooperativeMatrixB b, + CooperativeMatrixAccumulator c); + +// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and +// c. The cooperative matrix operands are inferred, with the +// SaturatingAccumulationKHR bit set. +template +CooperativeMatrixAccumulator +cooperativeMatrixSaturatingMultiplyAdd( + CooperativeMatrixA a, + CooperativeMatrixB b, + CooperativeMatrixAccumulator c); + +} // namespace khr +} // namespace vk + +#include "cooperative_matrix.impl" +#endif // _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_ diff --git a/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl new file mode 100644 index 0000000000..2acae8ec96 --- /dev/null +++ b/tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl @@ -0,0 +1,377 @@ +// Copyright (c) 2024 Google LLC +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "vk/opcode_selector.h" + +template +[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType +__builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b); + +template +[[vk::ext_instruction(/* OpCompositeExtract */ 81)]] ComponentType +__builtin_spv_ExtractFromCooperativeMatrix( + typename vk::khr::CooperativeMatrix::SpirvMatrixType matrix, + uint32_t index); + +template +[[vk::ext_instruction(/* OpCompositeConstruct */ 80)]] CoopMatrixType +__builtin_spv_ConstructCooperativeMatrix(ComponentType value); + +template +[[vk::ext_instruction(/* OpAccessChain */ 65)]] ResultPointerType +__builtin_spv_AccessChain([[vk::ext_reference]] BaseType base, uint32_t index); + +template +[[vk::ext_instruction(/* OpLoad */ 61)]] ObjectType +__builtin_spv_LoadPointer(PointerType base); + +template +[[vk::ext_instruction(/* OpLoad */ 62)]] void +__builtin_spv_StorePointer(PointerType base, ObjectType object); + +template +[[vk::ext_instruction(/* OpCompositeInsert */ 82)]] +typename vk::khr::CooperativeMatrix::SpirvMatrixType +__builtin_spv_InsertIntoCooperativeMatrix( + ComponentType value, + typename vk::khr::CooperativeMatrix::SpirvMatrixType matrix, + uint32_t index); + +// Define the load and store instructions +template +[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType +__builtin_spv_CooperativeMatrixLoadKHR( + [[vk::ext_reference]] PointerType pointer, + vk::CooperativeMatrixLayout memory_layout, uint stride, + [[vk::ext_literal]] uint32_t memory_operand); + +template +[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType +__builtin_spv_CooperativeMatrixLoadKHR( + [[vk::ext_reference]] PointerType pointer, + vk::CooperativeMatrixLayout memory_layout, uint stride, + [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope); + +template +[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType +__builtin_spv_CooperativeMatrixWorkgroupLoadKHR( + vk::WorkgroupSpirvPointer pointer, + vk::CooperativeMatrixLayout memory_layout, uint stride, + [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope); + +template +[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void +__builtin_spv_CooperativeMatrixStoreKHR( + [[vk::ext_reference]] PointerType pointer, ObjectType object, + vk::CooperativeMatrixLayout memory_layout, uint stride, + [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope); + +template +[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void +__builtin_spv_CooperativeMatrixStoreKHR( + [[vk::ext_reference]] PointerType pointer, ObjectType object, + vk::CooperativeMatrixLayout memory_layout, uint stride, + [[vk::ext_literal]] uint32_t memory_operand); + +template +[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void +__builtin_spv_CooperativeMatrixWorkgroupStoreKHR( + vk::WorkgroupSpirvPointer pointer, ObjectType object, + vk::CooperativeMatrixLayout memory_layout, uint stride, + [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope); + +// We cannot define `OpCooperativeMatrixLengthKHR` using ext_instruction because +// one of the operands is a type id. This builtin will have specific code in the +// compiler to expand it. +template uint __builtin_spv_CooperativeMatrixLengthKHR(); + +// Arithmetic Instructions +template +[[vk::ext_instruction(/* OpCooperativeMatrixMulAddKHR */ 4459)]] ResultType +__builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b, + MatrixTypeC c, + [[vk::ext_literal]] int operands); +namespace vk { +namespace khr { + +template +template +CooperativeMatrix +CooperativeMatrix::cast() { + using ResultType = + CooperativeMatrix; + ResultType result; + result._matrix = util::ConversionSelector:: + template Convert(_matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::negate() { + CooperativeMatrix result; + result._matrix = util::ArithmeticSelector::Negate(_matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator+( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = + util::ArithmeticSelector::Add(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator-( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = + util::ArithmeticSelector::Sub(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator*( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = + util::ArithmeticSelector::Mul(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator/( + CooperativeMatrix other) { + CooperativeMatrix result; + result._matrix = + util::ArithmeticSelector::Div(_matrix, other._matrix); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::operator*( + ComponentType scalar) { + CooperativeMatrix result; + result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar); + return result; +} + +template +template +void CooperativeMatrix::Store( + WorkgroupSpirvPointer data, uint32_t stride) { + __builtin_spv_CooperativeMatrixWorkgroupStoreKHR( + data, _matrix, layout, stride, + memoryAccessOperands | MemoryAccessNonPrivatePointerMask | + MemoryAccessMakePointerAvailableMask, + ScopeWorkgroup); +} + +template +template +void CooperativeMatrix::Store( + RWStructuredBuffer data, uint32_t index, uint32_t stride) { + __builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix, layout, stride, + memoryAccessOperands); +} + +template +template +void CooperativeMatrix::CoherentStore( + globallycoherent RWStructuredBuffer data, uint32_t index, + uint32_t stride) { + __builtin_spv_CooperativeMatrixStoreKHR( + data[index], _matrix, layout, stride, + memoryAccessOperands | MemoryAccessNonPrivatePointerMask | + MemoryAccessMakePointerAvailableMask, + ScopeQueueFamily); +} + +template +template +CooperativeMatrix +CooperativeMatrix::Load( + vk::WorkgroupSpirvPointer buffer, uint32_t stride) { + CooperativeMatrix result; + result._matrix = + __builtin_spv_CooperativeMatrixWorkgroupLoadKHR( + buffer, layout, stride, + memoryAccessOperands | MemoryAccessNonPrivatePointerMask | + MemoryAccessMakePointerVisibleMask, + ScopeWorkgroup); + return result; +} + +template +template +CooperativeMatrix +CooperativeMatrix::Load( + RWStructuredBuffer buffer, uint32_t index, uint32_t stride) { + CooperativeMatrix result; + result._matrix = __builtin_spv_CooperativeMatrixLoadKHR( + buffer[index], layout, stride, memoryAccessOperands); + return result; +} + +template +template +CooperativeMatrix +CooperativeMatrix::CoherentLoad( + RWStructuredBuffer buffer, uint32_t index, uint32_t stride) { + CooperativeMatrix result; + result._matrix = __builtin_spv_CooperativeMatrixLoadKHR( + buffer[index], layout, stride, + memoryAccessOperands | MemoryAccessNonPrivatePointerMask | + MemoryAccessMakePointerVisibleMask, + ScopeQueueFamily); + return result; +} + +template +template +CooperativeMatrix +CooperativeMatrix::Load( + StructuredBuffer buffer, uint32_t index, uint32_t stride) { + CooperativeMatrix result; + result._matrix = __builtin_spv_CooperativeMatrixLoadKHR( + buffer[index], layout, stride, MemoryAccessMaskNone); + return result; +} + +template +CooperativeMatrix +CooperativeMatrix::Splat( + ComponentType v) { + CooperativeMatrix result; + result._matrix = __builtin_spv_ConstructCooperativeMatrix(v); + return result; +} + +template +uint CooperativeMatrix::GetLength() { + return __builtin_spv_CooperativeMatrixLengthKHR(); +} + +template +ComponentType CooperativeMatrix::Get( + uint32_t index) { + // clang-format off + using ComponentPtr = vk::SpirvOpaqueType< + /* OpTypePointer */ 32, + /* function storage class */ vk::Literal >, + ComponentType>; + // clang-format on + ComponentPtr ptr = __builtin_spv_AccessChain(_matrix, index); + return __builtin_spv_LoadPointer(ptr); +} + +template +void CooperativeMatrix::Set( + ComponentType value, uint32_t index) { + // clang-format off + using ComponentPtr = vk::SpirvOpaqueType< + /* OpTypePointer */ 32, + /* function storage class */ vk::Literal >, + ComponentType>; + // clang-format on + ComponentPtr ptr = __builtin_spv_AccessChain(_matrix, index); + return __builtin_spv_StorePointer(ptr, value); +} + +template +CooperativeMatrixAccumulator +cooperativeMatrixMultiplyAdd( + CooperativeMatrixA a, + CooperativeMatrixB b, + CooperativeMatrixAccumulator c) { + + const vk::CooperativeMatrixOperandsMask allSignedComponents = + vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask | + vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask | + vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask | + vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask; + + const vk::CooperativeMatrixOperandsMask operands = + (vk::CooperativeMatrixOperandsMask)( + a.hasSignedIntegerComponentType + ? allSignedComponents + : vk::CooperativeMatrixOperandsMaskNone); + + CooperativeMatrixAccumulator result; + result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR< + typename CooperativeMatrixAccumulator::SpirvMatrixType>( + a._matrix, b._matrix, c._matrix, operands); + return result; +} + +template +CooperativeMatrixAccumulator +cooperativeMatrixSaturatingMultiplyAdd( + CooperativeMatrixA a, + CooperativeMatrixB b, + CooperativeMatrixAccumulator c) { + + const vk::CooperativeMatrixOperandsMask allSignedComponents = + vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask | + vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask | + vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask | + vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask | + vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask; + + const vk::CooperativeMatrixOperandsMask operands = + (vk::CooperativeMatrixOperandsMask)( + a.hasSignedIntegerComponentType + ? allSignedComponents + : vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask); + CooperativeMatrixAccumulator result; + result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR< + typename CooperativeMatrixAccumulator::SpirvMatrixType>( + a._matrix, b._matrix, c._matrix, operands); + return result; +} + +} // namespace khr +} // namespace vk diff --git a/tools/clang/lib/Headers/hlsl/vk/opcode_selector.h b/tools/clang/lib/Headers/hlsl/vk/opcode_selector.h new file mode 100644 index 0000000000..42e5744239 --- /dev/null +++ b/tools/clang/lib/Headers/hlsl/vk/opcode_selector.h @@ -0,0 +1,227 @@ +// Copyright (c) 2024 Google LLC +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef _HLSL_VK_KHR_OPCODE_SELECTOR_H_ +#define _HLSL_VK_KHR_OPCODE_SELECTOR_H_ + +#define DECLARE_UNARY_OP(name, opcode) \ + template \ + [[vk::ext_instruction(opcode)]] ResultType __builtin_spv_##name( \ + ResultType a) + +DECLARE_UNARY_OP(CopyObj, 83); +DECLARE_UNARY_OP(SNegate, 126); +DECLARE_UNARY_OP(FNegate, 127); + +#define DECLARE_CONVERSION_OP(name, opcode) \ + template \ + [[vk::ext_instruction(opcode)]] ResultType __builtin_spv_##name( \ + OperandType a) + +DECLARE_CONVERSION_OP(ConvertFtoU, 109); +DECLARE_CONVERSION_OP(ConvertFtoS, 110); +DECLARE_CONVERSION_OP(ConvertSToF, 111); +DECLARE_CONVERSION_OP(ConvertUToF, 112); +DECLARE_CONVERSION_OP(UConvert, 113); +DECLARE_CONVERSION_OP(SConvert, 114); +DECLARE_CONVERSION_OP(FConvert, 115); +DECLARE_CONVERSION_OP(Bitcast, 124); + +#undef DECLARY_UNARY_OP + +#define DECLARE_BINOP(name, opcode) \ + template \ + [[vk::ext_instruction(opcode)]] ResultType __builtin_spv_##name( \ + ResultType a, ResultType b) + +DECLARE_BINOP(IAdd, 128); +DECLARE_BINOP(FAdd, 129); +DECLARE_BINOP(ISub, 130); +DECLARE_BINOP(FSub, 131); +DECLARE_BINOP(IMul, 132); +DECLARE_BINOP(FMul, 133); +DECLARE_BINOP(UDiv, 134); +DECLARE_BINOP(SDiv, 135); +DECLARE_BINOP(FDiv, 136); + +#undef DECLARE_BINOP +namespace vk { +namespace util { + +template class ArithmeticSelector; + +#define ARITHMETIC_SELECTOR(BaseType, OpNegate, OpAdd, OpSub, OpMul, OpDiv, \ + SIGNED_INTEGER_TYPE) \ + template <> class ArithmeticSelector { \ + template static T Negate(T a) { return OpNegate(a); } \ + template static T Add(T a, T b) { return OpAdd(a, b); } \ + template static T Sub(T a, T b) { return OpSub(a, b); } \ + template static T Mul(T a, T b) { return OpMul(a, b); } \ + template static T Div(T a, T b) { return OpDiv(a, b); } \ + }; + +ARITHMETIC_SELECTOR(half, __builtin_spv_FNegate, __builtin_spv_FAdd, + __builtin_spv_FSub, __builtin_spv_FMul, __builtin_spv_FDiv, + false); +ARITHMETIC_SELECTOR(float, __builtin_spv_FNegate, __builtin_spv_FAdd, + __builtin_spv_FSub, __builtin_spv_FMul, __builtin_spv_FDiv, + false); +ARITHMETIC_SELECTOR(double, __builtin_spv_FNegate, __builtin_spv_FAdd, + __builtin_spv_FSub, __builtin_spv_FMul, __builtin_spv_FDiv, + false); + +#if __HLSL_ENABLE_16_BIT +ARITHMETIC_SELECTOR(int16_t, __builtin_spv_SNegate, __builtin_spv_IAdd, + __builtin_spv_ISub, __builtin_spv_IMul, __builtin_spv_SDiv, + true); +ARITHMETIC_SELECTOR(uint16_t, __builtin_spv_SNegate, __builtin_spv_IAdd, + __builtin_spv_ISub, __builtin_spv_IMul, __builtin_spv_UDiv, + false); +#endif // __HLSL_ENABLE_16_BIT + +ARITHMETIC_SELECTOR(int32_t, __builtin_spv_SNegate, __builtin_spv_IAdd, + __builtin_spv_ISub, __builtin_spv_IMul, __builtin_spv_SDiv, + true); +ARITHMETIC_SELECTOR(int64_t, __builtin_spv_SNegate, __builtin_spv_IAdd, + __builtin_spv_ISub, __builtin_spv_IMul, __builtin_spv_SDiv, + true); +ARITHMETIC_SELECTOR(uint32_t, __builtin_spv_SNegate, __builtin_spv_IAdd, + __builtin_spv_ISub, __builtin_spv_IMul, __builtin_spv_UDiv, + false); +ARITHMETIC_SELECTOR(uint64_t, __builtin_spv_SNegate, __builtin_spv_IAdd, + __builtin_spv_ISub, __builtin_spv_IMul, __builtin_spv_UDiv, + false); + +// The conversion selector is will be used to convert one type to another +// using the SPIR-V conversion instructions. See +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_conversion_instructions. +// SourceType and TargetType must be integer or floating point scalar type. + +// ConversionSelector::Convert converts an object of type S to an object of type +// T. S must be SourceType, a vector of SourceType, or a cooperative matrix of +// SourceType. T must be TargetType, a vector of TargetType, or a cooperative +// matrix of TargetType. T must have the same number of components as S. T is a +// cooperative matrix if and only if S is a cooperative matrix. +template class ConversionSelector; + +#define CONVERSION_SELECTOR(SourceType, TargetType, OpConvert) \ + template <> class ConversionSelector { \ + template static T Convert(S a) { \ + return OpConvert(a); \ + } \ + }; + +#if __HLSL_ENABLE_16_BIT +CONVERSION_SELECTOR(uint16_t, uint16_t, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(uint16_t, int16_t, __builtin_spv_Bitcast); +CONVERSION_SELECTOR(uint16_t, uint32_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(uint16_t, int32_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(uint16_t, uint64_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(uint16_t, int64_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(uint16_t, half, __builtin_spv_ConvertUToF); +CONVERSION_SELECTOR(uint16_t, float, __builtin_spv_ConvertUToF); +CONVERSION_SELECTOR(uint16_t, double, __builtin_spv_ConvertUToF); + +CONVERSION_SELECTOR(int16_t, uint16_t, __builtin_spv_Bitcast); +CONVERSION_SELECTOR(int16_t, int16_t, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(int16_t, uint32_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(int16_t, int32_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(int16_t, uint64_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(int16_t, int64_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(int16_t, half, __builtin_spv_ConvertSToF); +CONVERSION_SELECTOR(int16_t, float, __builtin_spv_ConvertSToF); +CONVERSION_SELECTOR(int16_t, double, __builtin_spv_ConvertSToF); + +CONVERSION_SELECTOR(uint32_t, uint16_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(uint32_t, int16_t, __builtin_spv_SConvert); + +CONVERSION_SELECTOR(int32_t, uint16_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(int32_t, int16_t, __builtin_spv_SConvert); + +CONVERSION_SELECTOR(uint64_t, uint16_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(uint64_t, int16_t, __builtin_spv_SConvert); + +CONVERSION_SELECTOR(int64_t, uint16_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(int64_t, int16_t, __builtin_spv_SConvert); + +CONVERSION_SELECTOR(half, uint16_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(half, int16_t, __builtin_spv_ConvertFtoS); + +CONVERSION_SELECTOR(float, uint16_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(float, int16_t, __builtin_spv_ConvertFtoS); + +CONVERSION_SELECTOR(double, uint16_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(double, int16_t, __builtin_spv_ConvertFtoS); +#endif + +CONVERSION_SELECTOR(uint32_t, uint32_t, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(uint32_t, int32_t, __builtin_spv_Bitcast); +CONVERSION_SELECTOR(uint32_t, uint64_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(uint32_t, int64_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(uint32_t, half, __builtin_spv_ConvertUToF); +CONVERSION_SELECTOR(uint32_t, float, __builtin_spv_ConvertUToF); +CONVERSION_SELECTOR(uint32_t, double, __builtin_spv_ConvertUToF); + +CONVERSION_SELECTOR(int32_t, uint32_t, __builtin_spv_Bitcast); +CONVERSION_SELECTOR(int32_t, int32_t, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(int32_t, uint64_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(int32_t, int64_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(int32_t, half, __builtin_spv_ConvertSToF); +CONVERSION_SELECTOR(int32_t, float, __builtin_spv_ConvertSToF); +CONVERSION_SELECTOR(int32_t, double, __builtin_spv_ConvertSToF); + +CONVERSION_SELECTOR(uint64_t, uint32_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(uint64_t, int32_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(uint64_t, uint64_t, __builtin_spv_Bitcast); +CONVERSION_SELECTOR(uint64_t, int64_t, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(uint64_t, half, __builtin_spv_ConvertUToF); +CONVERSION_SELECTOR(uint64_t, float, __builtin_spv_ConvertUToF); +CONVERSION_SELECTOR(uint64_t, double, __builtin_spv_ConvertUToF); + +CONVERSION_SELECTOR(int64_t, uint32_t, __builtin_spv_UConvert); +CONVERSION_SELECTOR(int64_t, int32_t, __builtin_spv_SConvert); +CONVERSION_SELECTOR(int64_t, uint64_t, __builtin_spv_Bitcast); +CONVERSION_SELECTOR(int64_t, int64_t, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(int64_t, half, __builtin_spv_ConvertSToF); +CONVERSION_SELECTOR(int64_t, float, __builtin_spv_ConvertSToF); +CONVERSION_SELECTOR(int64_t, double, __builtin_spv_ConvertSToF); + +CONVERSION_SELECTOR(half, uint32_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(half, int32_t, __builtin_spv_ConvertFtoS); +CONVERSION_SELECTOR(half, uint64_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(half, int64_t, __builtin_spv_ConvertFtoS); +CONVERSION_SELECTOR(half, half, __builtin_spv_CopyObj); +#if __HLSL_ENABLE_16_BIT +CONVERSION_SELECTOR(half, float, __builtin_spv_FConvert); +#else +CONVERSION_SELECTOR(half, float, __builtin_spv_CopyObj); +#endif + +CONVERSION_SELECTOR(half, double, __builtin_spv_FConvert); + +CONVERSION_SELECTOR(float, uint32_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(float, int32_t, __builtin_spv_ConvertFtoS); +CONVERSION_SELECTOR(float, uint64_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(float, int64_t, __builtin_spv_ConvertFtoS); +#if __HLSL_ENABLE_16_BIT +CONVERSION_SELECTOR(float, half, __builtin_spv_FConvert); +#else +CONVERSION_SELECTOR(float, half, __builtin_spv_CopyObj); +#endif +CONVERSION_SELECTOR(float, float, __builtin_spv_CopyObj); +CONVERSION_SELECTOR(float, double, __builtin_spv_FConvert); + +CONVERSION_SELECTOR(double, uint32_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(double, int32_t, __builtin_spv_ConvertFtoS); +CONVERSION_SELECTOR(double, uint64_t, __builtin_spv_ConvertFtoU); +CONVERSION_SELECTOR(double, int64_t, __builtin_spv_ConvertFtoS); +CONVERSION_SELECTOR(double, half, __builtin_spv_FConvert); +CONVERSION_SELECTOR(double, float, __builtin_spv_FConvert); +CONVERSION_SELECTOR(double, double, __builtin_spv_CopyObj); +}; // namespace util +} // namespace vk + +#endif // _HLSL_VK_KHR_OPCODE_SELECTOR_H_ diff --git a/tools/clang/lib/Headers/hlsl/vk/spirv.h b/tools/clang/lib/Headers/hlsl/vk/spirv.h index 294af03078..69bb53bddc 100644 --- a/tools/clang/lib/Headers/hlsl/vk/spirv.h +++ b/tools/clang/lib/Headers/hlsl/vk/spirv.h @@ -9,6 +9,57 @@ namespace vk { +enum CooperativeMatrixUse { + CooperativeMatrixUseMatrixAKHR = 0, + CooperativeMatrixUseMatrixBKHR = 1, + CooperativeMatrixUseMatrixAccumulatorKHR = 2, + CooperativeMatrixUseMax = 0x7fffffff, +}; + +enum CooperativeMatrixLayout { + CooperativeMatrixLayoutRowMajorKHR = 0, + CooperativeMatrixLayoutColumnMajorKHR = 1, + CooperativeMatrixLayoutRowBlockedInterleavedARM = 4202, + CooperativeMatrixLayoutColumnBlockedInterleavedARM = 4203, + CooperativeMatrixLayoutMax = 0x7fffffff, +}; + +enum CooperativeMatrixOperandsMask { + CooperativeMatrixOperandsMaskNone = 0, + CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001, + CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002, + CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004, + CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008, + CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010, +}; + +enum MemoryAccessMask { + MemoryAccessMaskNone = 0, + MemoryAccessVolatileMask = 0x00000001, + MemoryAccessAlignedMask = 0x00000002, + MemoryAccessNontemporalMask = 0x00000004, + MemoryAccessMakePointerAvailableMask = 0x00000008, + MemoryAccessMakePointerAvailableKHRMask = 0x00000008, + MemoryAccessMakePointerVisibleMask = 0x00000010, + MemoryAccessMakePointerVisibleKHRMask = 0x00000010, + MemoryAccessNonPrivatePointerMask = 0x00000020, + MemoryAccessNonPrivatePointerKHRMask = 0x00000020, + MemoryAccessAliasScopeINTELMaskMask = 0x00010000, + MemoryAccessNoAliasINTELMaskMask = 0x00020000, +}; + +enum Scope { + ScopeCrossDevice = 0, + ScopeDevice = 1, + ScopeWorkgroup = 2, + ScopeSubgroup = 3, + ScopeInvocation = 4, + ScopeQueueFamily = 5, + ScopeQueueFamilyKHR = 5, + ScopeShaderCallKHR = 6, + ScopeMax = 0x7fffffff, +}; + enum StorageClass { StorageClassWorkgroup = 4, }; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index d09bf76946..2df93bd720 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -413,6 +413,17 @@ void EmitVisitor::emitDebugLine(spv::Op op, const SourceLocation &loc, section->insert(section->end(), curInst.begin(), curInst.end()); } +bool EmitVisitor::emitCooperativeMatrixLength(SpirvUnaryOp *inst) { + initInstruction(inst); + curInst.push_back(inst->getResultTypeId()); + curInst.push_back(getOrAssignResultId(inst)); + const uint32_t operandResultTypeId = + typeHandler.emitType(inst->getOperand()->getResultType()); + curInst.push_back(operandResultTypeId); + finalizeInstruction(&mainBinary); + return true; +} + void EmitVisitor::initInstruction(SpirvInstruction *inst) { // Emit the result type if the instruction has a result type. if (inst->hasResultType()) { @@ -1318,6 +1329,10 @@ bool EmitVisitor::visit(SpirvNullaryOp *inst) { } bool EmitVisitor::visit(SpirvUnaryOp *inst) { + if (inst->getopcode() == spv::Op::OpCooperativeMatrixLengthKHR) { + return emitCooperativeMatrixLength(inst); + } + initInstruction(inst); curInst.push_back(inst->getResultTypeId()); curInst.push_back(getOrAssignResultId(inst)); diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 9feacc58d2..b6da5bc16e 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -399,6 +399,11 @@ class EmitVisitor : public Visitor { emittedSource[fileId] = dbg_src_id; } + // Emits an OpCooperativeMatrixLength instruction into the main binary + // section. It will replace the operand with the id of the type of the + // operand. + bool emitCooperativeMatrixLength(SpirvUnaryOp *inst); + private: /// Emits error to the diagnostic engine associated with this visitor. template diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index b847e1a040..94b9fe23d6 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -223,10 +223,11 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) { // Access chains must have a pointer type. The storage class for the pointer // is the same as the storage class of the access base. case spv::Op::OpAccessChain: { - const auto *pointerType = spvContext.getPointerType( - resultType, - cast(instr)->getBase()->getStorageClass()); - instr->setResultType(pointerType); + if (auto *acInst = dyn_cast(instr)) { + const auto *pointerType = spvContext.getPointerType( + resultType, acInst->getBase()->getStorageClass()); + instr->setResultType(pointerType); + } break; } // OpImageTexelPointer's result type must be a pointer with image storage diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 7f89ec9277..676e97131e 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -532,6 +532,12 @@ bool isVkRawBufferLoadIntrinsic(const clang::FunctionDecl *FD) { return true; } +bool isCooperativeMatrixGetLengthIntrinsic( + const FunctionDecl *functionDeclaration) { + return functionDeclaration->getName().equals( + "__builtin_spv_CooperativeMatrixLengthKHR"); +} + // Takes an AST member type, and determines its index in the equivalent SPIR-V // struct type. This is required as the struct layout might change between the // AST representation and SPIR-V representation. @@ -2905,6 +2911,11 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr, return processRawBufferLoad(callExpr); } + // Handle CooperativeMatrix::GetLength() + if (isCooperativeMatrixGetLengthIntrinsic(funcDecl)) { + return processCooperativeMatrixGetLength(callExpr); + } + // Normal standalone functions return processCall(callExpr); } @@ -14685,6 +14696,35 @@ SpirvEmitter::processRawBufferStore(const CallExpr *callExpr) { callExpr->getLocStart()); } +SpirvInstruction * +SpirvEmitter::processCooperativeMatrixGetLength(const CallExpr *call) { + auto *declaration = dyn_cast(call->getCalleeDecl()); + assert(declaration); + + const auto *templateSpecializationInfo = + declaration->getTemplateSpecializationInfo(); + assert(templateSpecializationInfo); + const clang::TemplateArgumentList &templateArgs = + *templateSpecializationInfo->TemplateArguments; + assert(templateArgs.size() == 1); + const clang::TemplateArgument &arg = templateArgs[0]; + assert(arg.getKind() == clang::TemplateArgument::Type); + const clang::QualType &type = arg.getAsType(); + + // Create an undef for `type`. + SpirvInstruction *undef = spvBuilder.getUndef(type); + + // Create an OpCooperativeMatrixLengthKHR instruction. However, we cannot + // make a type a parameter at this point in the code. We will use an Undef of + // the type that will become the parameter, and then adjust the instruction + // in EmitVisitor. + SpirvInstruction *inst = getSpirvBuilder().createUnaryOp( + spv::Op::OpCooperativeMatrixLengthKHR, call->getType(), undef, + call->getLocStart(), call->getSourceRange()); + inst->setRValue(); + return inst; +} + SpirvInstruction * SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr, bool useIdParams) { diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 152fd3df06..642b48da7c 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -752,6 +752,10 @@ class SpirvEmitter : public ASTConsumer { /// `vk::RawBufferStore()`. uint32_t getRawBufferAlignment(const Expr *expr); + /// Returns a spirv OpCooperativeMatrixLengthKHR instruction generated from a + /// call to __builtin_spv_CooperativeMatrixLengthKHR. + SpirvInstruction *processCooperativeMatrixGetLength(const CallExpr *call); + /// Process vk::ext_execution_mode intrinsic SpirvInstruction *processIntrinsicExecutionMode(const CallExpr *expr, bool useIdParams); diff --git a/tools/clang/test/CodeGenSPIRV/convert.selector.hlsl b/tools/clang/test/CodeGenSPIRV/convert.selector.hlsl new file mode 100644 index 0000000000..3a0b1bb315 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/convert.selector.hlsl @@ -0,0 +1,139 @@ +// Convert to half +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=half -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FCONVERT + +// Convert to float +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=float -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FCONVERT + +// Convert to double +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UTOF +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=STOF +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=double -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FCONVERT + +// int type to int16_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=BITCAST +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT + +// float type to int16_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=int16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS + +// int type to int32_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=BITCAST +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT + +// float type to int32_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=int32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS + +// int type to int64_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=SCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=BITCAST + +// float type to int64_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=int64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOS + +// int type to uint16_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=BITCAST +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT + +// float type to uint16_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=uint16_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU + +// int type to uint32_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=BITCAST +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint64_t -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT + +// float type to uint32_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=uint32_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU + +// int type to uint64_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=int16_t -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int32_t -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint16_t -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=uint32_t -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=UCONVERT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=int64_t -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=BITCAST + +// float type to uint64_t +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=half -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU +// RUN: dxc -fspv-target-env=vulkan1.3 -enable-16bit-types -T cs_6_2 -E main -spirv -HV 2021 -DSOURCE_TYPE=float -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -DSOURCE_TYPE=double -DTARGET_TYPE=uint64_t -I %hlsl_headers %s | FileCheck %s --check-prefix=CHECK --check-prefix=FTOU + +#include "vk/opcode_selector.h" + +#define VEC_TYPE_INT(TYPE) TYPE##4 +#define VEC_TYPE(t) VEC_TYPE_INT(t) + +RWStructuredBuffer source; +RWStructuredBuffer target; + +[numthreads(64, 1, 1)] void main() { +// CHECK: [[ac:%[0-9]+]] = OpAccessChain {{%_ptr_StorageBuffer_.*}} %source %int_0 %uint_0 +// CHECK: [[ld:%[0-9]+]] = OpLoad {{%.*}} [[ac]] +// STOF: [[result:%[0-9]+]] = OpConvertSToF {{%.*}} [[ld]] +// FTOS: [[result:%[0-9]+]] = OpConvertFToS {{%.*}} [[ld]] +// UTOF: [[result:%[0-9]+]] = OpConvertUToF {{%.*}} [[ld]] +// FTOU: [[result:%[0-9]+]] = OpConvertFToU {{%.*}} [[ld]] +// FCONVERT: [[result:%[0-9]+]] = OpFConvert {{%.*}} [[ld]] +// UCONVERT: [[result:%[0-9]+]] = OpUConvert {{%.*}} [[ld]] +// SCONVERT: [[result:%[0-9]+]] = OpSConvert {{%.*}} [[ld]] +// BITCAST: [[result:%[0-9]+]] = OpBitcast {{%.*}} [[ld]] +// CHECK: [[ac:%[0-9]+]] = OpAccessChain {{%_ptr_StorageBuffer_.*}} %target %int_0 %uint_0 +// CHECK: OpStore [[ac]] [[result]] + target[0] = vk::util::ConversionSelector::Convert(source[0]); + +// CHECK: [[ac:%[0-9]+]] = OpAccessChain {{%_ptr_StorageBuffer_.*}} %source %int_0 %uint_0 %int_0 +// CHECK: [[ld:%[0-9]+]] = OpLoad {{%.*}} [[ac]] +// STOF: [[result:%[0-9]+]] = OpConvertSToF {{%.*}} [[ld]] +// FTOS: [[result:%[0-9]+]] = OpConvertFToS {{%.*}} [[ld]] +// UTOF: [[result:%[0-9]+]] = OpConvertUToF {{%.*}} [[ld]] +// FTOU: [[result:%[0-9]+]] = OpConvertFToU {{%.*}} [[ld]] +// FCONVERT: [[result:%[0-9]+]] = OpFConvert {{%.*}} [[ld]] +// UCONVERT: [[result:%[0-9]+]] = OpUConvert {{%.*}} [[ld]] +// SCONVERT: [[result:%[0-9]+]] = OpSConvert {{%.*}} [[ld]] +// BITCAST: [[result:%[0-9]+]] = OpBitcast {{%.*}} [[ld]] +// CHECK: [[ac:%[0-9]+]] = OpAccessChain {{%_ptr_StorageBuffer_.*}} %target %int_0 %uint_0 %int_0 +// CHECK: OpStore [[ac]] [[result]] + target[0].x = vk::util::ConversionSelector::Convert(source[0].x); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.arithmetic.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.arithmetic.hlsl new file mode 100644 index 0000000000..b452acaf16 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.arithmetic.hlsl @@ -0,0 +1,95 @@ +// RUN: dxc -enable-16bit-types -fspv-target-env=vulkan1.3 -T cs_6_2 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=int16_t %s | FileCheck %s --check-prefix=CHECK --check-prefix=INTEGERS --check-prefix=INT16 +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=int %s | FileCheck %s --check-prefix=CHECK --check-prefix=INTEGERS --check-prefix=INT32 +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=int64_t %s | FileCheck %s --check-prefix=CHECK --check-prefix=INTEGERS --check-prefix=INT64 +// RUN: dxc -enable-16bit-types -fspv-target-env=vulkan1.3 -T cs_6_2 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=uint16_t %s | FileCheck %s --check-prefix=CHECK --check-prefix=INTEGERS --check-prefix=UINT16 +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=uint %s | FileCheck %s --check-prefix=CHECK --check-prefix=INTEGERS --check-prefix=UINT32 +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=uint64_t %s | FileCheck %s --check-prefix=CHECK --check-prefix=INTEGERS --check-prefix=UINT64 +// RUN: dxc -enable-16bit-types -fspv-target-env=vulkan1.3 -T cs_6_2 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=half %s | FileCheck %s --check-prefix=CHECK --check-prefix=FLOATS --check-prefix=HALF-ENABLED +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=half %s | FileCheck %s --check-prefix=CHECK --check-prefix=FLOATS --check-prefix=HALF-DISABLED +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=float %s | FileCheck %s --check-prefix=CHECK --check-prefix=FLOATS --check-prefix=FLOAT +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers -DTYPE=double %s | FileCheck %s --check-prefix=CHECK --check-prefix=FLOATS --check-prefix=DOUBLE + +#include "vk/khr/cooperative_matrix.h" + +StructuredBuffer structured_buffer; + +RWStructuredBuffer data; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" + +// Check that the type is correctly created. +// INT16: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %short %uint_3 %uint_16 %uint_8 %uint_0 +// INT32: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_16 %uint_8 %uint_0 +// INT64: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %long %uint_3 %uint_16 %uint_8 %uint_0 +// UINT16: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %ushort %uint_3 %uint_16 %uint_8 %uint_0 +// UINT32: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_16 %uint_8 %uint_0 +// UINT64: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %ulong %uint_3 %uint_16 %uint_8 %uint_0 + +// When 16bit types are not enabled, HALF is a float +// HALF-DISABLED: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_16 %uint_8 %uint_0 +// HALF-ENABLED: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_16 %uint_8 %uint_0 +// FLOAT: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_16 %uint_8 %uint_0 +// DOUBLE: %spirvIntrinsicType = OpTypeCooperativeMatrixKHR %double %uint_3 %uint_16 %uint_8 %uint_0 + +[numthreads(64, 1, 1)] void main() { + using CoopMat = vk::khr::CooperativeMatrixA< + TYPE, vk::ScopeSubgroup, 16, 8>; + + // CHECK: [[ac1:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_{{.*}} %data %int_0 %uint_0 + // CHECK: [[m:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac1]] %int_1 %uint_64 None + CoopMat m = CoopMat::Load(data, 0, 64); + + // CHECK: [[len:%[0-9]+]] = OpCooperativeMatrixLengthKHR %uint %spirvIntrinsicType + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_{{.*}} %structured_buffer %int_0 [[len]] + // CHECK: [[n:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac]] %int_0 %uint_64 None + uint32_t length = CoopMat::GetLength(); + CoopMat n = CoopMat::Load(structured_buffer, length, 64); + + // INTEGERS: [[r:%[0-9]+]] = OpIAdd %spirvIntrinsicType [[m]] [[n]] + // FLOATS: [[r:%[0-9]+]] = OpFAdd %spirvIntrinsicType [[m]] [[n]] + CoopMat r = m + n; + + // INTEGERS: [[n:%[0-9]+]] = OpISub %spirvIntrinsicType [[m]] [[r]] + // FLOATS: [[n:%[0-9]+]] = OpFSub %spirvIntrinsicType [[m]] [[r]] + n = m - r; + + // INTEGERS: [[m:%[0-9]+]] = OpSNegate %spirvIntrinsicType [[n]] + // FLOATS: [[m:%[0-9]+]] = OpFNegate %spirvIntrinsicType [[n]] + m = n.negate(); + + // INT16: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %short_2 + // INT32: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %int_2 + // INT64: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %long_2 + // UINT16: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %ushort_2 + // UINT32: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %uint_2 + // UINT64: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %ulong_2 + // HALF-DISABLED: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %float_2 + // HALF-ENABLED: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %half_0x1p_1 + // FLOAT: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %float_2 + // DOUBLE: [[r:%[0-9]+]] = OpMatrixTimesScalar %spirvIntrinsicType [[m]] %double_2 + r = m * 2.0; + + // INT16: [[n:%[0-9]+]] = OpSDiv %spirvIntrinsicType [[r]] [[m]] + // INT32: [[n:%[0-9]+]] = OpSDiv %spirvIntrinsicType [[r]] [[m]] + // INT64: [[n:%[0-9]+]] = OpSDiv %spirvIntrinsicType [[r]] [[m]] + // UINT16: [[n:%[0-9]+]] = OpUDiv %spirvIntrinsicType [[r]] [[m]] + // UINT32: [[n:%[0-9]+]] = OpUDiv %spirvIntrinsicType [[r]] [[m]] + // UINT64: [[n:%[0-9]+]] = OpUDiv %spirvIntrinsicType [[r]] [[m]] + // HALF-DISABLED: [[n:%[0-9]+]] = OpFDiv %spirvIntrinsicType [[r]] [[m]] + // HALF-ENABLED: [[n:%[0-9]+]] = OpFDiv %spirvIntrinsicType [[r]] [[m]] + // FLOAT: [[n:%[0-9]+]] = OpFDiv %spirvIntrinsicType [[r]] [[m]] + // DOUBLE: [[n:%[0-9]+]] = OpFDiv %spirvIntrinsicType [[r]] [[m]] + n = r / m; + + // INTEGERS: [[r:%[0-9]+]] = OpIMul %spirvIntrinsicType [[n]] [[m]] + // FLOATS: [[r:%[0-9]+]] = OpFMul %spirvIntrinsicType [[n]] [[m]] + r = n * m; + + // CHECK: OpCooperativeMatrixStoreKHR [[ac1]] [[r]] %int_0 %uint_64 None + r.Store(data, 0, 64); + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_{{.*}} %data %int_0 %uint_16 + // CHECK: OpCooperativeMatrixStoreKHR [[ac]] [[r]] %int_1 %uint_64 None + r.Store(data, 16, 64); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.convert.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.convert.hlsl new file mode 100644 index 0000000000..833d16b3b0 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.convert.hlsl @@ -0,0 +1,25 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +RWStructuredBuffer data; +int stride; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" + +[numthreads(64, 1, 1)] void main() { + using IntMatA = vk::khr::CooperativeMatrixA; + using FloatMatA = vk::khr::CooperativeMatrixA; + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_0 + // CHECK: [[ld:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac]] %int_1 + IntMatA int_matrix = IntMatA::Load(data, 0, stride); + + // CHECK: [[result:%[0-9]+]] = OpConvertSToF %spirvIntrinsicType_0 [[ld]] + FloatMatA float_matrix = int_matrix.cast(); + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_64 + // CHECK: OpCooperativeMatrixStoreKHR [[ac]] [[result]] %int_0 + float_matrix.Store(data, 64, stride); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl new file mode 100644 index 0000000000..3bbcdc8b38 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl @@ -0,0 +1,33 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +RWStructuredBuffer data; +int stride; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" + +// CHECK-DAG: [[typeA:%spirvIntrinsicType[_0-9]*]] = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_16 %uint_4 %uint_0 + +[numthreads(64, 1, 1)] void main() { + using IntMatA = vk::khr::CooperativeMatrixA; + + // CHECK: [[a:%[0-9]+]] = OpVariable %_ptr_Function_spirvIntrinsicType Function + // CHECK: [[v:%[0-9]+]] = OpCompositeConstruct %spirvIntrinsicType %int_10 + // CHECK: OpStore [[a]] [[v]] + IntMatA a = IntMatA::Splat(10); + + uint32_t length = a.GetLength(); + // CHECK: OpLoopMerge [[mbb:%[0-9]+]] + for (int i = 0; i < length; ++i) { + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Function_int [[a]] + // CHECK: [[get:%[0-9]+]] = OpLoad %int [[ac]] + // CHECK: [[add:%[0-9]+]] = OpIAdd %int [[get]] %int_1 + // CHECK: OpStore [[ac]] [[add]] + int v = a.Get(i); + a.Set(v + 1, i); + } + // CHECK: [[mbb]] = OpLabel + a.Store(data, 64, stride); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.globallycoherent.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.globallycoherent.hlsl new file mode 100644 index 0000000000..b71213fdd4 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.globallycoherent.hlsl @@ -0,0 +1,19 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +globallycoherent RWStructuredBuffer data; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" +[numthreads(64, 1, 1)] void main() { + using FloatMatA = vk::khr::CooperativeMatrixA; + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_0 + // CHECK: [[ld:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac]] %int_1 %uint_256 MakePointerVisible|NonPrivatePointer %int_5 + FloatMatA m = FloatMatA::CoherentLoad(data, 0, 256); + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_64 + // CHECK: OpCooperativeMatrixStoreKHR [[ac]] [[ld]] %int_0 %uint_8 MakePointerAvailable|NonPrivatePointer %int_5 + m.CoherentStore(data, 64, 8); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.groupshared.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.groupshared.hlsl new file mode 100644 index 0000000000..95dad149c7 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.groupshared.hlsl @@ -0,0 +1,30 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +RWStructuredBuffer data; + +groupshared float shared_data[64]; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" +[numthreads(64, 1, 1)] void main() { + using FloatMatA = vk::khr::CooperativeMatrixA; + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_0 + // CHECK: [[ld:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac]] %int_1 %uint_256 None + FloatMatA m = FloatMatA::Load(data, 0, 256); + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_float %shared_data %int_0 + // CHECK: OpCooperativeMatrixStoreKHR [[ac]] [[ld]] %int_1 %uint_64 MakePointerAvailable|NonPrivatePointer %int_2 + m.Store(vk::GetGroupSharedAddress(shared_data[0]), 64); + + FloatMatA m2; + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_float %shared_data %int_10 + // CHECK: [[ld:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac]] %int_1 %uint_128 MakePointerVisible|NonPrivatePointer %int_2 + m2 = FloatMatA::Load(vk::GetGroupSharedAddress(shared_data[10]), 128); + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_64 + // CHECK: OpCooperativeMatrixStoreKHR [[ac]] [[ld]] %int_0 %uint_8 None + m2.Store(data, 64, 8); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.memory.operand.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.memory.operand.hlsl new file mode 100644 index 0000000000..aa461d22a2 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.memory.operand.hlsl @@ -0,0 +1,22 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +RWStructuredBuffer data; + +groupshared float shared_data[64]; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" +[numthreads(64, 1, 1)] void main() { + using FloatMatA = vk::khr::CooperativeMatrixA; + + FloatMatA m; + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_float %shared_data %int_0 + // CHECK: [[ld:%[0-9]+]] = OpCooperativeMatrixLoadKHR %spirvIntrinsicType [[ac]] %int_1 %uint_128 Nontemporal + m = FloatMatA::Load(vk::GetGroupSharedAddress(shared_data[0]), 128); + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_StorageBuffer_int %data %int_0 %uint_64 + // CHECK: OpCooperativeMatrixStoreKHR [[ac]] [[ld]] %int_0 %uint_8 Nontemporal + m.Store(data, 64, 8); +} diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix_muladd_test.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix_muladd_test.hlsl new file mode 100644 index 0000000000..935d7d615d --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix_muladd_test.hlsl @@ -0,0 +1,41 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +RWStructuredBuffer data; +uint stride; + +// CHECK: OpCapability CooperativeMatrixKHR +// CHECK: OpExtension "SPV_KHR_cooperative_matrix" + +// CHECK-DAG: [[typeA:%spirvIntrinsicType[_0-9]*]] = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_16 %uint_4 %uint_0 +// CHECK-DAG: [[typeB:%spirvIntrinsicType[_0-9]*]] = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_4 %uint_8 %uint_1 +// CHECK-DAG: [[typeAc:%spirvIntrinsicType[_0-9]*]] = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_16 %uint_8 %uint_2 + +// CHECK: [[r:%[0-9]+]] = OpUndef [[typeAc]] +[numthreads(64, 1, 1)] void main() { + using IntMatA = vk::khr::CooperativeMatrixA; + using IntMatB = vk::khr::CooperativeMatrixB; + using IntMatAc = vk::khr::CooperativeMatrixAccumulator; + + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %_Globals %int_0 + // CHECK: [[stride:%[0-9]+]] = OpLoad %uint [[ac]] + + // CHECK: [[a:%[0-9]+]] = OpCooperativeMatrixLoadKHR [[typeA]] {{%[0-9]*}} %int_1 [[stride]] None + IntMatA a = IntMatA::Load(data, 0, stride); + + // CHECK: [[b:%[0-9]+]] = OpCooperativeMatrixLoadKHR [[typeB]] {{%[0-9]*}} %int_0 [[stride]] None + IntMatB b = IntMatB::Load(data, 32, stride); + + // TODO: Is default initialization meaningful? + IntMatAc r; + + // CHECK: [[r2:%[0-9]+]] = OpCooperativeMatrixMulAddKHR [[typeAc]] [[a]] [[b]] [[r]] MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR|MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR + r = cooperativeMatrixMultiplyAdd(a, b, r); + + // CHECK: [[r:%[0-9]+]] = OpCooperativeMatrixMulAddKHR [[typeAc]] [[a]] [[b]] [[r2]] MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR|MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR|SaturatingAccumulationKHR + r = cooperativeMatrixSaturatingMultiplyAdd(a, b, r); + + // CHECK: OpCooperativeMatrixStoreKHR {{.*}} [[r]] %int_0 [[stride]] None + r.Store(data, 64, stride); +} diff --git a/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.const.hlsl b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.const.hlsl new file mode 100644 index 0000000000..18d18b0bf6 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.const.hlsl @@ -0,0 +1,16 @@ +// RUN: not dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s 2>&1 | FileCheck %s + +#include "vk/khr/cooperative_matrix.h" + +RWStructuredBuffer data; + +groupshared int shared_data[64]; + +[numthreads(64, 1, 1)] void main() { + vk::WorkgroupSpirvPointer p = vk::GetGroupSharedAddress(shared_data[0]); + if (data[0] > 10 ) { + p = vk::GetGroupSharedAddress(shared_data[0]); + } +} + +// CHECK: cannot assign to variable 'p' with const-qualified type