Skip to content

Commit

Permalink
Move instruction selection into its own class.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Jun 27, 2024
1 parent 2c11ae2 commit 59263a0
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 168 deletions.
88 changes: 51 additions & 37 deletions tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.hlsli
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#ifndef VULKAN_HLSL_SPV_KHR_COOPERATIVE_MATRIX_H_
#define VULKAN_HLSL_SPV_KHR_COOPERATIVE_MATRIX_H_


// TODO: Add a macro to HLSL to be able to check the Vulkan version being
// targeted.

Expand All @@ -24,63 +23,69 @@ namespace vk {
// TODO: Move these defines to a new header file for defines.

enum CooperativeMatrixUse {
CooperativeMatrixUseMatrixAKHR = 0,
CooperativeMatrixUseMatrixBKHR = 1,
CooperativeMatrixUseMatrixAccumulatorKHR = 2,
CooperativeMatrixUseMax = 0x7fffffff,
CooperativeMatrixUseMatrixAKHR = 0,
CooperativeMatrixUseMatrixBKHR = 1,
CooperativeMatrixUseMatrixAccumulatorKHR = 2,
CooperativeMatrixUseMax = 0x7fffffff,
};

enum CooperativeMatrixLayout {
CooperativeMatrixLayoutRowMajorKHR = 0,
CooperativeMatrixLayoutColumnMajorKHR = 1,
CooperativeMatrixLayoutRowBlockedInterleavedARM = 4202,
CooperativeMatrixLayoutColumnBlockedInterleavedARM = 4203,
CooperativeMatrixLayoutMax = 0x7fffffff,
CooperativeMatrixLayoutRowMajorKHR = 0,
CooperativeMatrixLayoutColumnMajorKHR = 1,
CooperativeMatrixLayoutRowBlockedInterleavedARM = 4202,
CooperativeMatrixLayoutColumnBlockedInterleavedARM = 4203,
CooperativeMatrixLayoutMax = 0x7fffffff,
};

enum CooperativeMatrixOperandsMask {
CooperativeMatrixOperandsMaskNone = 0,
CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001,
CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002,
CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004,
CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008,
CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010,
CooperativeMatrixOperandsMaskNone = 0,
CooperativeMatrixOperandsMatrixASignedComponentsKHRMask = 0x00000001,
CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask = 0x00000002,
CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask = 0x00000004,
CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask = 0x00000008,
CooperativeMatrixOperandsSaturatingAccumulationKHRMask = 0x00000010,
};

enum MemoryAccessMask {
MemoryAccessMaskNone = 0,
MemoryAccessVolatileMask = 0x00000001,
MemoryAccessAlignedMask = 0x00000002,
MemoryAccessNontemporalMask = 0x00000004,
MemoryAccessMakePointerAvailableMask = 0x00000008,
MemoryAccessMakePointerAvailableKHRMask = 0x00000008,
MemoryAccessMakePointerVisibleMask = 0x00000010,
MemoryAccessMakePointerVisibleKHRMask = 0x00000010,
MemoryAccessNonPrivatePointerMask = 0x00000020,
MemoryAccessNonPrivatePointerKHRMask = 0x00000020,
MemoryAccessAliasScopeINTELMaskMask = 0x00010000,
MemoryAccessNoAliasINTELMaskMask = 0x00020000,
MemoryAccessMaskNone = 0,
MemoryAccessVolatileMask = 0x00000001,
MemoryAccessAlignedMask = 0x00000002,
MemoryAccessNontemporalMask = 0x00000004,
MemoryAccessMakePointerAvailableMask = 0x00000008,
MemoryAccessMakePointerAvailableKHRMask = 0x00000008,
MemoryAccessMakePointerVisibleMask = 0x00000010,
MemoryAccessMakePointerVisibleKHRMask = 0x00000010,
MemoryAccessNonPrivatePointerMask = 0x00000020,
MemoryAccessNonPrivatePointerKHRMask = 0x00000020,
MemoryAccessAliasScopeINTELMaskMask = 0x00010000,
MemoryAccessNoAliasINTELMaskMask = 0x00020000,
};

enum Scope {
ScopeCrossDevice = 0,
ScopeDevice = 1,
ScopeWorkgroup = 2,
ScopeSubgroup = 3,
ScopeInvocation = 4,
ScopeQueueFamily = 5,
ScopeQueueFamilyKHR = 5,
ScopeShaderCallKHR = 6,
ScopeMax = 0x7fffffff,
ScopeCrossDevice = 0,
ScopeDevice = 1,
ScopeWorkgroup = 2,
ScopeSubgroup = 3,
ScopeInvocation = 4,
ScopeQueueFamily = 5,
ScopeQueueFamilyKHR = 5,
ScopeShaderCallKHR = 6,
ScopeMax = 0x7fffffff,
};

namespace khr {

#define SPV_KHR_CooperativeMatrix \
vk::ext_extension("SPV_KHR_cooperative_matrix"), \
vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)

template <typename ComponentType, uint scope, uint rows, uint columns, uint use>
class CooperativeMatrix {
CooperativeMatrix negate();
CooperativeMatrix operator+(CooperativeMatrix other);
CooperativeMatrix operator-(CooperativeMatrix other);
CooperativeMatrix operator*(CooperativeMatrix other);
CooperativeMatrix operator/(CooperativeMatrix other);
CooperativeMatrix operator*(ComponentType scalar);

void StoreRowMajor(RWStructuredBuffer<ComponentType> data, uint32_t index);
Expand All @@ -93,6 +98,15 @@ class CooperativeMatrix {
static CooperativeMatrix LoadColumnMajor(BufferType buffer, uint32_t index);

static uint32_t GetLength();

static const bool hasSignedIntegerComponentType =
(ComponentType(0) - ComponentType(1) < ComponentType(0));
using SpirvMatrixType = vk::SpirvOpaqueType<
/* OpTypeCooperativeMatrixKHR */ 4456, ComponentType,
vk::integral_constant<uint, scope>, vk::integral_constant<uint, rows>,
vk::integral_constant<uint, columns>, vk::integral_constant<uint, use>>;

SpirvMatrixType _matrix;
};

template <typename ComponentType, uint scope, uint rows, uint columns>
Expand Down
233 changes: 102 additions & 131 deletions tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.impl
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#define SPV_KHR_CooperativeMatrix \
vk::ext_extension("SPV_KHR_cooperative_matrix"), \
vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)

#define DECLARE_UNARY_OP(name, opcode) \
template <typename ResultType> \
[[vk::ext_instruction(opcode), \
SPV_KHR_CooperativeMatrix]] ResultType __builtin_spv_##name(ResultType a)

DECLARE_UNARY_OP(SNegate, 126);
DECLARE_UNARY_OP(FNegate, 127);

#undef DECLARY_UNARY_OP

#define DECLARE_BINOP(name, opcode) \
template <typename ResultType> \
[[vk::ext_instruction(opcode), \
SPV_KHR_CooperativeMatrix]] ResultType __builtin_spv_##name(ResultType a, \
ResultType b)

DECLARE_BINOP(IAdd, 128);
DECLARE_BINOP(FAdd, 129);
DECLARE_BINOP(ISub, 130);
DECLARE_BINOP(FSub, 131);
DECLARE_BINOP(IMul, 132);
DECLARE_BINOP(FMul, 133);
DECLARE_BINOP(UDiv, 134);
DECLARE_BINOP(SDiv, 135);
DECLARE_BINOP(FDiv, 136);

#undef DECLARE_BINOP
#include "arithmetic_selector.hlsli"


template <typename ResultType, typename ComponentType>
[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143),
SPV_KHR_CooperativeMatrix]] ResultType
[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType
__builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b);

// Type-Declaration Instructions

// TODO: make sure scope, rows, and cols can be specialization constants
template <typename ComponentType, uint scope, uint rows, uint columns, uint use>
using __builtin_spv_CooperativeMatrixKHR = vk::SpirvOpaqueType<
/* OpTypeCooperativeMatrixKHR */ 4456, ComponentType,
vk::integral_constant<uint, scope>, vk::integral_constant<uint, rows>,
vk::integral_constant<uint, columns>, vk::integral_constant<uint, use> >;

// Define the load and store instructions
template <typename ResultType, typename PointerType>
[[vk::ext_instruction(
Expand Down Expand Up @@ -101,92 +62,101 @@ __builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b,
[[vk::ext_literal]] int operands);
namespace vk {
namespace khr {
#define COOPERATIVE_MATRIX(ComponentType, Negate, Add, Sub, \
SIGNED_INTEGER_TYPE) \
template <uint scope, uint rows, uint columns, uint use> \
class CooperativeMatrix<ComponentType, scope, rows, columns, use> { \
using SpirvMatrixType = \
__builtin_spv_CooperativeMatrixKHR<ComponentType, scope, rows, \
columns, use>; \
\
CooperativeMatrix negate() { \
CooperativeMatrix result; \
result._matrix = Negate(_matrix); \
return result; \
} \
CooperativeMatrix operator+(CooperativeMatrix other) { \
CooperativeMatrix result; \
result._matrix = Add(_matrix, other._matrix); \
return result; \
} \
CooperativeMatrix operator-(CooperativeMatrix other) { \
CooperativeMatrix result; \
result._matrix = Sub(_matrix, other._matrix); \
return result; \
} \
CooperativeMatrix operator*(ComponentType scalar) { \
CooperativeMatrix result; \
result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar); \
return result; \
} \
void StoreRowMajor(RWStructuredBuffer<ComponentType> data, \
uint32_t index) { \
__builtin_spv_CooperativeMatrixStoreKHR( \
data[index], _matrix, CooperativeMatrixLayoutRowMajorKHR); \
} \
void StoreColumnMajor(RWStructuredBuffer<ComponentType> data, \
uint32_t index) { \
__builtin_spv_CooperativeMatrixStoreKHR( \
data[index], _matrix, CooperativeMatrixLayoutColumnMajorKHR); \
} \
\
template <class BufferType> \
static CooperativeMatrix LoadRowMajor(BufferType buffer, uint32_t index) { \
CooperativeMatrix result; \
result._matrix = \
__builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>( \
buffer[index], CooperativeMatrixLayoutRowMajorKHR); \
return result; \
} \
\
template <class BufferType> \
static CooperativeMatrix LoadColumnMajor(BufferType buffer, \
uint32_t index) { \
CooperativeMatrix result; \
result._matrix = \
__builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>( \
buffer[index], CooperativeMatrixLayoutColumnMajorKHR); \
return result; \
} \
\
static uint GetLength() { \
return __builtin_spv_CooperativeMatrixLengthKHR<SpirvMatrixType>(); \
} \
\
static const bool isSignedInteger = SIGNED_INTEGER_TYPE; \
SpirvMatrixType _matrix; \
};

COOPERATIVE_MATRIX(half, __builtin_spv_FNegate, __builtin_spv_FAdd,
__builtin_spv_FSub, false);
COOPERATIVE_MATRIX(float, __builtin_spv_FNegate, __builtin_spv_FAdd,
__builtin_spv_FSub, false);
COOPERATIVE_MATRIX(double, __builtin_spv_FNegate, __builtin_spv_FAdd,
__builtin_spv_FSub, false);

// TODO: Need to check for a 16-bit-type enabled
// COOPERATIVE_MATRIX(int16_t, __builtin_spv_SNegate, __builtin_spv_IAdd,
// __builtin_spv_ISub, true);
COOPERATIVE_MATRIX(int32_t, __builtin_spv_SNegate, __builtin_spv_IAdd,
__builtin_spv_ISub, true);
COOPERATIVE_MATRIX(int64_t, __builtin_spv_SNegate, __builtin_spv_IAdd,
__builtin_spv_ISub, true);
// COOPERATIVE_MATRIX(uint16_t, __builtin_spv_SNegate, __builtin_spv_IAdd,
// __builtin_spv_ISub, false);
COOPERATIVE_MATRIX(uint32_t, __builtin_spv_SNegate, __builtin_spv_IAdd,
__builtin_spv_ISub, false);
COOPERATIVE_MATRIX(uint64_t, __builtin_spv_SNegate, __builtin_spv_IAdd,
__builtin_spv_ISub, false);

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::negate() {
CooperativeMatrix result;
result._matrix = util::ArithmeticSelector<ComponentType>::Negate(_matrix);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator+(
CooperativeMatrix other) {
CooperativeMatrix result;
result._matrix =
util::ArithmeticSelector<ComponentType>::Add(_matrix, other._matrix);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator-(
CooperativeMatrix other) {
CooperativeMatrix result;
result._matrix = util::ArithmeticSelector<ComponentType>::Sub(_matrix, other._matrix);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
CooperativeMatrix other) {
CooperativeMatrix result;
result._matrix = util::ArithmeticSelector<ComponentType>::Mul(_matrix, other._matrix);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator/(
CooperativeMatrix other) {
CooperativeMatrix result;
result._matrix = util::ArithmeticSelector<ComponentType>::Div(_matrix, other._matrix);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
ComponentType scalar) {
CooperativeMatrix result;
result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
void CooperativeMatrix<ComponentType, scope, rows, columns, use>::StoreRowMajor(
RWStructuredBuffer<ComponentType> data, uint32_t index) {
__builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix,
CooperativeMatrixLayoutRowMajorKHR);
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
void CooperativeMatrix<ComponentType, scope, rows, columns, use>::
StoreColumnMajor(RWStructuredBuffer<ComponentType> data, uint32_t index) {
__builtin_spv_CooperativeMatrixStoreKHR(
data[index], _matrix, CooperativeMatrixLayoutColumnMajorKHR);
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
template <class BufferType>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::LoadRowMajor(
BufferType buffer, uint32_t index) {
CooperativeMatrix result;
result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
buffer[index], CooperativeMatrixLayoutRowMajorKHR);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
template <class BufferType>
CooperativeMatrix<ComponentType, scope, rows, columns, use>
CooperativeMatrix<ComponentType, scope, rows, columns, use>::LoadColumnMajor(
BufferType buffer, uint32_t index) {
CooperativeMatrix result;
result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
buffer[index], CooperativeMatrixLayoutColumnMajorKHR);
return result;
}

template <class ComponentType, uint scope, uint rows, uint columns, uint use>
uint CooperativeMatrix<ComponentType, scope, rows, columns, use>::GetLength() {
return __builtin_spv_CooperativeMatrixLengthKHR<SpirvMatrixType>();
}

template <typename ComponentType, uint scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
Expand All @@ -203,8 +173,9 @@ cooperativeMatrixMultiplyAdd(

const vk::CooperativeMatrixOperandsMask operands =
(vk::CooperativeMatrixOperandsMask)(
a.isSignedInteger ? allSignedComponents
: vk::CooperativeMatrixOperandsMaskNone);
a.hasSignedIntegerComponentType
? allSignedComponents
: vk::CooperativeMatrixOperandsMaskNone);

CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
Expand All @@ -230,7 +201,7 @@ cooperativeMatrixSaturatingMultiplyAdd(

const vk::CooperativeMatrixOperandsMask operands =
(vk::CooperativeMatrixOperandsMask)(
a.isSignedInteger
a.hasSignedIntegerComponentType
? allSignedComponents
: vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask);
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
Expand Down

0 comments on commit 59263a0

Please sign in to comment.