Skip to content

Commit

Permalink
Add coop matrix HLSL header file
Browse files Browse the repository at this point in the history
Implements [HLSL spec proposal 0021](https://github.com/microsoft/hlsl-specs/blob/9c42e278961dbf5b77d7d18b1e7ba7f6c0856154/proposals/0021-vk-coop-matrix.md)

This commit adds a header file to the HLSL Standard Header library to
implement cooperative matrices for SPIR-V. See the proposal for the
specifics of the interface, and design decisions.
  • Loading branch information
s-perron committed Sep 16, 2024
1 parent c780db8 commit 3028830
Show file tree
Hide file tree
Showing 18 changed files with 1,418 additions and 4 deletions.
274 changes: 274 additions & 0 deletions tools/clang/lib/Headers/hlsl/vk/khr/cooperative_matrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
// Copyright (c) 2024 Google LLC
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
#define _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_

// TODO: Add a macro to HLSL to be able to check the Vulkan version being
// targeted.

#include "vk/spirv.h"

namespace vk {
namespace khr {

// The base cooperative matrix class. The template arguments correspond to the
// operands in the OpTypeCooperativeMatrixKHR instruction.
template <typename ComponentType, Scope scope, uint rows, uint columns,
CooperativeMatrixUse use>
class CooperativeMatrix {
template <class NewComponentType>
CooperativeMatrix<NewComponentType, scope, rows, columns, use> cast();

// Apply OpSNegate or OFNegate, depending on ComponentType, in a element by
// element manner.
CooperativeMatrix negate();

// Apply OpIAdd or OFAdd, depending on ComponentType, in a element by element
// manner.
CooperativeMatrix operator+(CooperativeMatrix other);

// Apply OpISub or OFSub, depending on ComponentType, in a element by element
// manner.
CooperativeMatrix operator-(CooperativeMatrix other);

// Apply OpIMul or OFMul, depending on ComponentType, in a element by element
// manner.
CooperativeMatrix operator*(CooperativeMatrix other);

// Apply OpSDiv, OpUDiv or OFDiv, depending on ComponentType, in a element by
// element manner.
CooperativeMatrix operator/(CooperativeMatrix other);

// Apply OpMatrixTimesScalar in a element by element manner.
CooperativeMatrix operator*(ComponentType scalar);

// Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
// data using the given memory layout, stride, and memory access operands.
// `NonPrivatePointer` and `MakePointerAvailable` with the workgroup scope
// will be added to the memory access operands to make the memory coherent.
//
// This function uses a SPIR-V pointer because HLSL does not allow groupshared
// memory object to be passed by reference. The pointer is a hack to get
// around that.
//
// The layout and stride will be passed to the SPIR-V instruction as is. The
// precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access
// operands.
template <CooperativeMatrixLayout layout, class Type>
void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride) {
Store<MemoryAccessMaskNone, layout>(data, stride);
}

// Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
// data[index] using the given memory layout, stride, and memory access
// operands. The layout and stride will be passed to the SPIR-V instruction as
// is. The precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access
// operands.
template <CooperativeMatrixLayout layout, class Type>
void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
Store<MemoryAccessMaskNone, layout>(data, index, stride);
}

// Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
// data[index] using the given memory layout, stride, and memory access
// operands. `NonPrivatePointer` and `MakePointerAvailable` with the
// QueueFamily scope will be added to the memory access operands to make the
// memory coherent.
//
// The layout and stride will be passed to the SPIR-V instruction as is. The
// precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
uint32_t index, uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access operands
// template argument.
template <CooperativeMatrixLayout layout, class Type>
void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
uint32_t index, uint32_t stride) {
CoherentStore<MemoryAccessMaskNone, layout>(data, index, stride);
}

// Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
// data using the given memory layout, stride, and memory access operands.
// `NonPrivatePointer` and `MakePointerVisible` with the workgroup scope
// will be added to the memory access operands to make the memory coherent.
//
// This function uses a SPIR-V pointer because HLSL does not allow groupshared
// memory object to be passed by reference. The pointer is a hack to get
// around that.
//
// The layout and stride will be passed to the SPIR-V instruction as is. The
// precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access
// operands.
template <CooperativeMatrixLayout layout, class Type>
static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
uint32_t stride) {
return Load<MemoryAccessMaskNone, layout>(data, stride);
}

// Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
// data[index] using the given memory layout, stride, and memory access
// operands.
//
// The layout and stride will be passed to the SPIR-V instruction as is. The
// precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access
// operands.
template <CooperativeMatrixLayout layout, class Type>
static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
uint32_t stride) {
return Load<MemoryAccessMaskNone, layout>(data, index, stride);
}

// Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
// data[index] using the given memory layout, stride, and memory access
// operands. `NonPrivatePointer` and `MakePointerVisible` with the QueueFamily
// scope will be added to the memory access operands to make the memory
// coherent.
//
//
// The layout and stride will be passed to the SPIR-V instruction as is. The
// precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
static CooperativeMatrix
CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access operands
// template argument.
template <CooperativeMatrixLayout layout, class Type>
static CooperativeMatrix
CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
uint32_t stride) {
return CoherentLoad<MemoryAccessMaskNone, layout>(data, index, stride);
}

// Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
// data[index] using the given memory layout, stride, and memory access
// operands. No memory access bits are added to the operands. Since the memory
// is readonly, there should be no need.
//
// The layout and stride will be passed to the SPIR-V instruction as is. The
// precise meaning can be found in the specification for
// SPV_KHR_cooperative_matrix.
template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
class Type>
static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
uint32_t stride);

// Same as above, but uses MemoryAccessMaskNone for the memory access
// operands.
template <CooperativeMatrixLayout layout, class Type>
static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
uint32_t stride) {
return Load<MemoryAccessMaskNone, layout>(data, index, stride);
}

// Constructs a cooperative matrix with all values initialized to v. Note that
// all threads in scope must have the same value for v.
static CooperativeMatrix Splat(ComponentType v);

// Returns the result of OpCooperativeMatrixLengthKHR on the current type.
static uint32_t GetLength();

// Functions to access the elements of the cooperative matrix. The index must
// be less than GetLength().
void Set(ComponentType value, uint32_t index);
ComponentType Get(uint32_t index);

static const bool hasSignedIntegerComponentType =
(ComponentType(0) - ComponentType(1) < ComponentType(0));

// clang-format off
using SpirvMatrixType = vk::SpirvOpaqueType<
/* OpTypeCooperativeMatrixKHR */ 4456, ComponentType,
vk::integral_constant<uint, scope>, vk::integral_constant<uint, rows>,
vk::integral_constant<uint, columns>, vk::integral_constant<uint, use> >;

[[vk::ext_extension("SPV_KHR_cooperative_matrix")]]
[[vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)]]
[[vk::ext_capability(/* VulkanMemoryModel */ 5345)]]
SpirvMatrixType _matrix;
// clang-format on
};

// Cooperative matrix that can be used in the "a" position of a multiply add
// instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixA =
CooperativeMatrix<ComponentType, scope, rows, columns,
CooperativeMatrixUseMatrixAKHR>;

// Cooperative matrix that can be used in the "b" position of a multiply add
// instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixB =
CooperativeMatrix<ComponentType, scope, rows, columns,
CooperativeMatrixUseMatrixBKHR>;

// Cooperative matrix that can be used in the "r" and "c" position of a multiply
// add instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixAccumulator =
CooperativeMatrix<ComponentType, scope, rows, columns,
CooperativeMatrixUseMatrixAccumulatorKHR>;

// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
// c. The cooperative matrix operands are inferred, with the
// SaturatingAccumulationKHR bit not set.
template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
cooperativeMatrixMultiplyAdd(
CooperativeMatrixA<ComponentType, scope, rows, K> a,
CooperativeMatrixB<ComponentType, scope, K, columns> b,
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);

// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
// c. The cooperative matrix operands are inferred, with the
// SaturatingAccumulationKHR bit set.
template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
cooperativeMatrixSaturatingMultiplyAdd(
CooperativeMatrixA<ComponentType, scope, rows, K> a,
CooperativeMatrixB<ComponentType, scope, K, columns> b,
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);

} // namespace khr
} // namespace vk

#include "cooperative_matrix.impl"
#endif // _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
Loading

0 comments on commit 3028830

Please sign in to comment.