diff --git a/include/nbl/builtin/hlsl/algorithm.hlsl b/include/nbl/builtin/hlsl/algorithm.hlsl
index 7a05101437..4c14814714 100644
--- a/include/nbl/builtin/hlsl/algorithm.hlsl
+++ b/include/nbl/builtin/hlsl/algorithm.hlsl
@@ -109,15 +109,6 @@ uint upper_bound(inout Accessor accessor, const uint begin, const uint end, cons
 namespace impl
 {
-template<typename T>
-struct comparator_lt_t
-{
-    bool operator()(const T lhs, const T rhs)
-    {
-        return lhs<rhs;
-    }
-};
-
 }
 
 template<class Accessor, typename T>
 uint lower_bound(inout Accessor accessor, const uint begin, const uint end, const T value)
diff --git a/include/nbl/builtin/hlsl/binops.hlsl b/include/nbl/builtin/hlsl/binops.hlsl
deleted file mode 100644
index 7cfe7af00b..0000000000
--- a/include/nbl/builtin/hlsl/binops.hlsl
+++ /dev/null
@@ -1,153 +0,0 @@
-// Copyright (C) 2022 - DevSH Graphics Programming Sp. z O.O.
-// This file is part of the "Nabla Engine".
-// For conditions of distribution and use, see copyright notice in nabla.h
-#ifndef _NBL_BUILTIN_HLSL_BINOPS_INCLUDED_
-#define _NBL_BUILTIN_HLSL_BINOPS_INCLUDED_
-
-namespace nbl
-{
-namespace hlsl
-{
-namespace binops
-{
-template<typename T>
-struct bitwise_and
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        return lhs&rhs;
-    }
-
-    static T identity()
-    {
-        return ~0;
-    }
-};
-
-template<typename T>
-struct bitwise_or
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        return lhs|rhs;
-    }
-
-    static T identity()
-    {
-        return 0;
-    }
-};
-
-template<typename T>
-struct bitwise_xor
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        return lhs^rhs;
-    }
-
-    static T identity()
-    {
-        return 0;
-    }
-};
-
-template<typename T>
-struct add
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        return lhs+rhs;
-    }
-
-    static T identity()
-    {
-        return 0;
-    }
-};
-
-template<typename T>
-struct mul
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        return lhs*rhs;
-    }
-
-    static T identity()
-    {
-        return 1;
-    }
-};
-
-template<typename T>
-struct comparator_lt_t
-{
-    bool operator()(const T lhs, const T rhs)
-    {
-        return lhs<rhs;
-    }
-};
-
-template<typename T>
-struct comparator_gt_t
-{
-    bool operator()(const T lhs, const T rhs)
-    {
-        return lhs>rhs;
-    }
-};
-
-template<typename T>
-struct comparator_lte_t
-{
-    bool operator()(const T lhs, const T rhs)
-    {
-        return lhs<=rhs;
-    }
-};
-
-template<typename T>
-struct comparator_gte_t
-{
-    bool operator()(const T lhs, const T rhs)
-    {
-        return lhs>=rhs;
-    }
-};
-
-template<typename T>
-struct min
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        comparator_lt_t<T> comp;
-        return comp(lhs, rhs) ? lhs : rhs;
-    }
-
-    static T identity()
-    {
-        return ~0;
-    }
-};
-
-template<typename T>
-struct max
-{
-    T operator()(const T lhs, const T rhs)
-    {
-        comparator_gt_t<T> comp;
-        return comp(lhs, rhs) ? lhs : rhs;
-    }
-
-    static T identity()
-    {
-        return 0;
-    }
-};
-
-}
-}
-}
-
-#endif
\ No newline at end of file
diff --git a/include/nbl/builtin/hlsl/bxdf/common.hlsl b/include/nbl/builtin/hlsl/bxdf/common.hlsl
index 643a1111f5..c6e8679d3b 100644
--- a/include/nbl/builtin/hlsl/bxdf/common.hlsl
+++ b/include/nbl/builtin/hlsl/bxdf/common.hlsl
@@ -4,9 +4,9 @@
 #ifndef _NBL_BUILTIN_HLSL_BXDF_COMMON_INCLUDED_
 #define _NBL_BUILTIN_HLSL_BXDF_COMMON_INCLUDED_
 
-#include <nbl/builtin/hlsl/limits.hlsl>
-#include <nbl/builtin/hlsl/numbers.hlsl>
-#include <nbl/builtin/hlsl/math/functions.glsl>
+#include "nbl/builtin/hlsl/limits.hlsl"
+#include "nbl/builtin/hlsl/numbers.hlsl"
+#include "nbl/builtin/hlsl/math/functions.glsl"
 
 namespace nbl
 {
diff --git a/include/nbl/builtin/hlsl/functional.hlsl b/include/nbl/builtin/hlsl/functional.hlsl
new file mode 100644
index 0000000000..8e9a03feeb
--- /dev/null
+++ b/include/nbl/builtin/hlsl/functional.hlsl
@@ -0,0 +1,166 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_FUNCTIONAL_INCLUDED_
+#define _NBL_BUILTIN_HLSL_FUNCTIONAL_INCLUDED_
+
+namespace nbl
+{
+namespace hlsl
+{
+#ifndef __HLSL_VERSION // CPP
+
+template<typename T> struct bit_and : std::bit_and<T>
+{
+    NBL_CONSTEXPR_STATIC_INLINE T identity = ~0;
+};
+
+template<typename T> struct bit_or : std::bit_or<T>
+{
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+template<typename T> struct bit_xor : std::bit_xor<T>
+{
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+template<typename T> struct plus : std::plus<T>
+{
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+template<typename T> struct multiplies : std::multiplies<T>
+{
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 1;
+};
+
+template<typename T> struct greater : std::greater<T>;
+template<typename T> struct less : std::less<T>;
+template<typename T> struct greater_equal : std::greater_equal<T>;
+template<typename T> struct less_equal : std::less_equal<T>;
+
+#else // HLSL
+
+template<typename T>
+struct bit_and
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return lhs & rhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = ~0;
+};
+
+template<typename T>
+struct bit_or
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return lhs | rhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+template<typename T>
+struct bit_xor
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return lhs ^ rhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+template<typename T>
+struct plus
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return lhs + rhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+template<typename T>
+struct multiplies
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return lhs * rhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 1;
+};
+
+template<typename T>
+struct greater
+{
+    bool operator()(const T lhs, const T rhs)
+    {
+        return lhs > rhs;
+    }
+};
+
+template<typename T>
+struct less
+{
+    bool operator()(const T lhs, const T rhs)
+    {
+        return lhs < rhs;
+    }
+};
+
+template<typename T>
+struct greater_equal
+{
+    bool operator()(const T lhs, const T rhs)
+    {
+        return lhs >= rhs;
+    }
+};
+
+template<typename T>
+struct less_equal
+{
+    bool operator()(const T lhs, const T rhs)
+    {
+        return lhs <= rhs;
+    }
+};
+
+#endif
+
+// Min and Max are outside of the HLSL/C++ directives because we want these to be available in both contexts
+// TODO: implement as mix(rhs<lhs, rhs, lhs)
+template<typename T>
+struct minimum
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return (rhs < lhs) ? rhs : lhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = ~0;
+};
+
+template<typename T>
+struct maximum
+{
+    T operator()(const T lhs, const T rhs)
+    {
+        return (lhs < rhs) ? rhs : lhs;
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE T identity = 0;
+};
+
+}
+}
+
+#endif
\ No newline at end of file
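The functors above are drop-in replacements for the deleted `binops` ones (`bitwise_and`→`bit_and`, `add`→`plus`, `mul`→`multiplies`, `min`/`max`→`minimum`/`maximum`), with `identity` now a static member constant rather than a function. A minimal usage sketch, purely illustrative and not engine code:

// Hypothetical: fold three values with one of the new functors.
uint fold3(uint a, uint b, uint c)
{
    nbl::hlsl::plus<uint> op;
    // op(x, nbl::hlsl::plus<uint>::identity) == x, since identity is a member constant
    return op(op(a, b), c);
}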
diff --git a/include/nbl/builtin/hlsl/glsl_compat/core.hlsl b/include/nbl/builtin/hlsl/glsl_compat/core.hlsl
index 9cfe40cda7..9e85b8e35d 100644
--- a/include/nbl/builtin/hlsl/glsl_compat/core.hlsl
+++ b/include/nbl/builtin/hlsl/glsl_compat/core.hlsl
@@ -55,12 +55,22 @@ T atomicCompSwap(NBL_REF_ARG(T) ptr, T comparator, T value)
     return spirv::atomicCompSwap(ptr, 1, 0, 0, value, comparator);
 }
 
+/**
+ * For Compute Shaders
+ */
 void barrier() {
-    spirv::controlBarrier(2, 2, 0x8 | 0x100);
+    spirv::controlBarrier(spv::ScopeWorkgroup, spv::ScopeWorkgroup, spv::MemorySemanticsAcquireReleaseMask | spv::MemorySemanticsWorkgroupMemoryMask);
+}
+
+/**
+ * For Tessellation Control Shaders
+ */
+void tess_ctrl_barrier() {
+    spirv::controlBarrier(spv::ScopeWorkgroup, spv::ScopeInvocation, 0);
 }
 
 void memoryBarrierShared() {
-    spirv::memoryBarrier(1, 0x8 | 0x100);
+    spirv::memoryBarrier(spv::ScopeDevice, spv::MemorySemanticsAcquireReleaseMask | spv::MemorySemanticsWorkgroupMemoryMask);
 }
 
 }
diff --git a/include/nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl b/include/nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl
index ff379f5aee..bc47b81d61 100644
--- a/include/nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl
+++ b/include/nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl
@@ -53,17 +53,27 @@ uint4 gl_SubgroupLtMask() {
 }
 
 bool subgroupElect() {
-    return spirv::subgroupElect(/*subgroup execution scope*/ 3);
+    return spirv::subgroupElect(spv::ScopeSubgroup);
 }
 
-// Memory Semantics: AcquireRelease, UniformMemory, WorkgroupMemory, AtomicCounterMemory, ImageMemory
 void subgroupBarrier() {
-    // REVIEW-519: barrier with subgroup scope is not supported so leave commented out for now
-    //spirv::controlBarrier(3, 3, 0x800 | 0x400 | 0x100 | 0x40 | 0x8);
+    spirv::controlBarrier(spv::ScopeSubgroup, spv::ScopeSubgroup, spv::MemorySemanticsImageMemoryMask | spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsUniformMemoryMask | spv::MemorySemanticsAcquireReleaseMask);
+}
+
+void subgroupMemoryBarrier() {
+    spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsImageMemoryMask | spv::MemorySemanticsWorkgroupMemoryMask | spv::MemorySemanticsUniformMemoryMask | spv::MemorySemanticsAcquireReleaseMask);
+}
+
+void subgroupMemoryBarrierBuffer() {
+    spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsAcquireReleaseMask | spv::MemorySemanticsUniformMemoryMask);
 }
 
 void subgroupMemoryBarrierShared() {
-    spirv::memoryBarrier(3, 0x800 | 0x400 | 0x100 | 0x40 | 0x8);
+    spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsAcquireReleaseMask | spv::MemorySemanticsWorkgroupMemoryMask);
+}
+
+void subgroupMemoryBarrierImage() {
+    spirv::memoryBarrier(spv::ScopeSubgroup, spv::MemorySemanticsAcquireReleaseMask | spv::MemorySemanticsImageMemoryMask);
 }
 
 }
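For reference, the magic numbers these enums replace: scopes 1/2/3 are spv::ScopeDevice/ScopeWorkgroup/ScopeSubgroup, and 0x8/0x40/0x100/0x400/0x800 are the AcquireRelease/UniformMemory/WorkgroupMemory/AtomicCounterMemory/ImageMemory semantics masks, so the new constants are bit-identical to the old literals. A small sketch of the usual shared-memory handshake through the compat layer (names and sizes are illustrative):

groupshared uint sdata[256];

uint exchangeWithNeighbour(uint tid, uint val)
{
    sdata[tid] = val;
    nbl::hlsl::glsl::barrier(); // workgroup control barrier + AcqRel|WorkgroupMemory, like GLSL's barrier()
    return sdata[tid ^ 1u];
}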
diff --git a/include/nbl/builtin/hlsl/memory_accessor.hlsl b/include/nbl/builtin/hlsl/memory_accessor.hlsl
new file mode 100644
index 0000000000..8066e764ce
--- /dev/null
+++ b/include/nbl/builtin/hlsl/memory_accessor.hlsl
@@ -0,0 +1,136 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_MEMORY_ACCESSOR_INCLUDED_
+#define _NBL_BUILTIN_HLSL_MEMORY_ACCESSOR_INCLUDED_
+
+#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+template<class NumberMemoryAccessor>
+struct MemoryAdaptor
+{
+    NumberMemoryAccessor accessor;
+
+    uint get(const uint ix) { return accessor.get(ix); }
+    void get(const uint ix, NBL_REF_ARG(uint) value) { value = accessor.get(ix);}
+    void get(const uint ix, NBL_REF_ARG(uint2) value) { value = uint2(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_));}
+    void get(const uint ix, NBL_REF_ARG(uint3) value) { value = uint3(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_));}
+    void get(const uint ix, NBL_REF_ARG(uint4) value) { value = uint4(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_));}
+
+    void get(const uint ix, NBL_REF_ARG(int) value) { value = asint(accessor.get(ix));}
+    void get(const uint ix, NBL_REF_ARG(int2) value) { value = asint(uint2(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_)));}
+    void get(const uint ix, NBL_REF_ARG(int3) value) { value = asint(uint3(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_)));}
+    void get(const uint ix, NBL_REF_ARG(int4) value) { value = asint(uint4(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_)));}
+
+    void get(const uint ix, NBL_REF_ARG(float) value) { value = asfloat(accessor.get(ix));}
+    void get(const uint ix, NBL_REF_ARG(float2) value) { value = asfloat(uint2(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_)));}
+    void get(const uint ix, NBL_REF_ARG(float3) value) { value = asfloat(uint3(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_)));}
+    void get(const uint ix, NBL_REF_ARG(float4) value) { value = asfloat(uint4(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_)));}
+
+    void set(const uint ix, const uint value) {accessor.set(ix, value);}
+    void set(const uint ix, const uint2 value) {
+        accessor.set(ix, value.x);
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);
+    }
+    void set(const uint ix, const uint3 value) {
+        accessor.set(ix, value.x);
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);
+        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z);
+    }
+    void set(const uint ix, const uint4 value) {
+        accessor.set(ix, value.x);
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);
+        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z);
+        accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, value.w);
+    }
+
+    void set(const uint ix, const int value) {accessor.set(ix, asuint(value));}
+    void set(const uint ix, const int2 value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
+    }
+    void set(const uint ix, const int3 value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
+        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
+    }
+    void set(const uint ix, const int4 value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
+        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
+        accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.w));
+    }
+
+    void set(const uint ix, const float value) {accessor.set(ix, asuint(value));}
+    void set(const uint ix, const float2 value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
+    }
+    void set(const uint ix, const float3 value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
+        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
+    }
+    void set(const uint ix, const float4 value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
+        accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
+        accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.w));
+    }
+
+    void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicAnd(ix, value);
+    }
+    void atomicAnd(const uint ix, const int value, NBL_REF_ARG(int) orig) {
+        orig = asint(accessor.atomicAnd(ix, asuint(value)));
+    }
+    void atomicAnd(const uint ix, const float value, NBL_REF_ARG(float) orig) {
+        orig = asfloat(accessor.atomicAnd(ix, asuint(value)));
+    }
+    void atomicOr(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicOr(ix, value);
+    }
+    void atomicOr(const uint ix, const int value, NBL_REF_ARG(int) orig) {
+        orig = asint(accessor.atomicOr(ix, asuint(value)));
+    }
+    void atomicOr(const uint ix, const float value, NBL_REF_ARG(float) orig) {
+        orig = asfloat(accessor.atomicOr(ix, asuint(value)));
+    }
+    void atomicXor(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicXor(ix, value);
+    }
+    void atomicXor(const uint ix, const int value, NBL_REF_ARG(int) orig) {
+        orig = asint(accessor.atomicXor(ix, asuint(value)));
+    }
+    void atomicXor(const uint ix, const float value, NBL_REF_ARG(float) orig) {
+        orig = asfloat(accessor.atomicXor(ix, asuint(value)));
+    }
+    void atomicAdd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicAdd(ix, value);
+    }
+    void atomicMin(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicMin(ix, value);
+    }
+    void atomicMax(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicMax(ix, value);
+    }
+    void atomicExchange(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicExchange(ix, value);
+    }
+    void atomicCompSwap(const uint ix, const uint value, const uint comp, NBL_REF_ARG(uint) orig) {
+        orig = accessor.atomicCompSwap(ix, comp, value);
+    }
+
+    void workgroupExecutionAndMemoryBarrier() {
+        accessor.workgroupExecutionAndMemoryBarrier();
+    }
+};
+
+}
+}
+
+#endif
\ No newline at end of file
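MemoryAdaptor only needs the wrapped accessor to trade in raw uints; it performs the asint/asfloat reinterpretation itself and strides vector components by _NBL_HLSL_WORKGROUP_SIZE_. A minimal sketch of a compatible accessor over groupshared memory (the shape is inferred from the calls above; the atomic* methods would only be needed if the atomic adaptor paths are used):

groupshared uint gScratch[_NBL_HLSL_WORKGROUP_SIZE_ * 4]; // 4 uints per invocation, enough for float4 round-trips

struct GroupsharedAccessor
{
    uint get(const uint ix) { return gScratch[ix]; }
    void set(const uint ix, const uint value) { gScratch[ix] = value; }
    void workgroupExecutionAndMemoryBarrier() { nbl::hlsl::glsl::barrier(); }
};

// usage sketch: nbl::hlsl::MemoryAdaptor<GroupsharedAccessor> adaptor;
// adaptor.set(tid, someFloat4); /* barrier */ adaptor.get(tid, someFloat4);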
diff --git a/include/nbl/builtin/hlsl/scan/declarations.hlsl b/include/nbl/builtin/hlsl/scan/declarations.hlsl
new file mode 100644
index 0000000000..2d2e66e66d
--- /dev/null
+++ b/include/nbl/builtin/hlsl/scan/declarations.hlsl
@@ -0,0 +1,66 @@
+#ifndef _NBL_HLSL_SCAN_DECLARATIONS_INCLUDED_
+#define _NBL_HLSL_SCAN_DECLARATIONS_INCLUDED_
+
+// REVIEW: Not sure if this file is needed in HLSL implementation
+
+#include "nbl/builtin/hlsl/scan/parameters_struct.hlsl"
+
+
+#ifndef _NBL_HLSL_SCAN_GET_PARAMETERS_DECLARED_
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+    Parameters_t getParameters();
+}
+}
+}
+#define _NBL_HLSL_SCAN_GET_PARAMETERS_DECLARED_
+#endif
+
+#ifndef _NBL_HLSL_SCAN_GET_PADDED_DATA_DECLARED_
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+    template<typename Storage_t>
+    void getData(
+        inout Storage_t data,
+        in uint levelInvocationIndex,
+        in uint localWorkgroupIndex,
+        in uint treeLevel,
+        in uint pseudoLevel
+    );
+}
+}
+}
+#define _NBL_HLSL_SCAN_GET_PADDED_DATA_DECLARED_
+#endif
+
+#ifndef _NBL_HLSL_SCAN_SET_DATA_DECLARED_
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+    template<typename Storage_t>
+    void setData(
+        in Storage_t data,
+        in uint levelInvocationIndex,
+        in uint localWorkgroupIndex,
+        in uint treeLevel,
+        in uint pseudoLevel,
+        in bool inRange
+    );
+}
+}
+}
+#define _NBL_HLSL_SCAN_SET_DATA_DECLARED_
+#endif
+
+#endif
\ No newline at end of file
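These declarations are the scan's customization points: the shader that assembles the pipeline supplies getParameters(), getData() and setData() bodies before the virtual-workgroup code runs. A hypothetical minimal pair over a single scratch buffer (the binding and addressing are illustrative only, not the engine's layout):

RWStructuredBuffer<uint> scratchBuf; // hypothetical binding

template<typename Storage_t>
void getData(inout Storage_t data, in uint levelInvocationIndex, in uint localWorkgroupIndex, in uint treeLevel, in uint pseudoLevel)
{
    // real code would offset by level via temporaryStorageOffset
    data = scratchBuf[levelInvocationIndex];
}

template<typename Storage_t>
void setData(in Storage_t data, in uint levelInvocationIndex, in uint localWorkgroupIndex, in uint treeLevel, in uint pseudoLevel, in bool inRange)
{
    if (inRange)
        scratchBuf[levelInvocationIndex] = data;
}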
diff --git a/include/nbl/builtin/hlsl/scan/default_scheduler.hlsl b/include/nbl/builtin/hlsl/scan/default_scheduler.hlsl
new file mode 100644
index 0000000000..450368475d
--- /dev/null
+++ b/include/nbl/builtin/hlsl/scan/default_scheduler.hlsl
@@ -0,0 +1,221 @@
+#ifndef _NBL_HLSL_SCAN_DEFAULT_SCHEDULER_INCLUDED_
+#define _NBL_HLSL_SCAN_DEFAULT_SCHEDULER_INCLUDED_
+
+#include "nbl/builtin/hlsl/scan/parameters_struct.hlsl"
+
+#ifdef __cplusplus
+#define uint uint32_t
+#endif
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+    struct DefaultSchedulerParameters_t
+    {
+        uint finishedFlagOffset[NBL_BUILTIN_MAX_SCAN_LEVELS-1];
+        uint cumulativeWorkgroupCount[NBL_BUILTIN_MAX_SCAN_LEVELS];
+
+    };
+}
+}
+}
+
+#ifdef __cplusplus
+#undef uint
+#else
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+namespace scheduler
+{
+    /**
+     * The CScanner.h parameter computation calculates the number of virtual workgroups that will have to be launched for the Scan operation
+     * (always based on the elementCount) as well as different offsets for the results of each step of the Scan operation, flag positions
+     * that are used for synchronization etc.
+     * Remember that CScanner does a Blelloch Scan which works in levels. In each level of the Blelloch scan the array of elements is
+     * broken down into sets of size=WorkgroupSize and each set is scanned using Hillis & Steele (aka Kogge-Stone adder). The result of
+     * the scan is provided as an array element for the next level of the Blelloch Scan. This means that if we have 10000 elements and
+     * WorkgroupSize=250, we will break the array into 40 sets and take their reduction results. The next level of the Blelloch Scan will
+     * have an array of size 40. Only a single workgroup will be needed to work on that. After that array is scanned, we use the results
+     * in the downsweep phase of Blelloch Scan.
+     * Keep in mind that each virtual workgroup executes a single step of the whole algorithm, which is why we have the cumulativeWorkgroupCount.
+     * The first virtual workgroups will work on the upsweep phase, the next on the downsweep phase.
+     * The intermediate results are stored in a scratch buffer. That buffer's size is the sum of the element-array size for all the
+     * Blelloch levels. Using the previous example, the scratch size should be 10000 + 40.
+     *
+     * Parameter meaning:
+     * |> lastElement - the index of the last element of each Blelloch level in the scratch buffer
+     * |> topLevel - the top level the Blelloch Scan will have (this depends on the elementCount and the WorkgroupSize)
+     * |> temporaryStorageOffset - an offset array for each level of the Blelloch Scan. It is used when storing the REDUCTION result of each workgroup scan
+     * |> cumulativeWorkgroupCount - the sum-scan of all the workgroups that will need to be launched for each level of the Blelloch Scan (both upsweep and downsweep)
+     * |> finishedFlagOffset - an index in the scratch buffer where each virtual workgroup indicates that ALL its invocations have finished their work. This helps
+     * synchronizing between workgroups with while-loop spinning.
+     */
+    void computeParameters(in uint elementCount, out Parameters_t _scanParams, out DefaultSchedulerParameters_t _schedulerParams)
+    {
+#define WorkgroupCount(Level) (_scanParams.lastElement[Level+1]+1u)
+        _scanParams.lastElement[0] = elementCount-1u;
+        _scanParams.topLevel = firstbithigh(_scanParams.lastElement[0])/_NBL_HLSL_WORKGROUP_SIZE_LOG2_;
+        // REVIEW: _NBL_HLSL_WORKGROUP_SIZE_LOG2_ is defined in files that include THIS file. Why not query the API for workgroup size at runtime?
+
+        for (uint i=0; i<_scanParams.topLevel; )
+        {
+            const uint next = i+1u;
+            _scanParams.lastElement[next] = _scanParams.lastElement[i]>>_NBL_HLSL_WORKGROUP_SIZE_LOG2_;
+            i = next;
+        }
+        _schedulerParams.cumulativeWorkgroupCount[0] = WorkgroupCount(0);
+        _schedulerParams.finishedFlagOffset[0] = 0u;
+        switch(_scanParams.topLevel)
+        {
+            case 1u:
+                _schedulerParams.cumulativeWorkgroupCount[1] = _schedulerParams.cumulativeWorkgroupCount[0]+1u;
+                _schedulerParams.cumulativeWorkgroupCount[2] = _schedulerParams.cumulativeWorkgroupCount[1]+WorkgroupCount(0);
+                // climb up
+                _schedulerParams.finishedFlagOffset[1] = 1u;
+
+                _scanParams.temporaryStorageOffset[0] = 2u;
+                break;
+            case 2u:
+                _schedulerParams.cumulativeWorkgroupCount[1] = _schedulerParams.cumulativeWorkgroupCount[0]+WorkgroupCount(1);
+                _schedulerParams.cumulativeWorkgroupCount[2] = _schedulerParams.cumulativeWorkgroupCount[1]+1u;
+                _schedulerParams.cumulativeWorkgroupCount[3] = _schedulerParams.cumulativeWorkgroupCount[2]+WorkgroupCount(1);
+                _schedulerParams.cumulativeWorkgroupCount[4] = _schedulerParams.cumulativeWorkgroupCount[3]+WorkgroupCount(0);
+                // climb up
+                _schedulerParams.finishedFlagOffset[1] = WorkgroupCount(1);
+                _schedulerParams.finishedFlagOffset[2] = _schedulerParams.finishedFlagOffset[1]+1u;
+                // climb down
+                _schedulerParams.finishedFlagOffset[3] = _schedulerParams.finishedFlagOffset[1]+2u;
+
+                _scanParams.temporaryStorageOffset[0] = _schedulerParams.finishedFlagOffset[3]+WorkgroupCount(1);
+                _scanParams.temporaryStorageOffset[1] = _scanParams.temporaryStorageOffset[0]+WorkgroupCount(0);
+                break;
+            case 3u:
+                _schedulerParams.cumulativeWorkgroupCount[1] = _schedulerParams.cumulativeWorkgroupCount[0]+WorkgroupCount(1);
+                _schedulerParams.cumulativeWorkgroupCount[2] = _schedulerParams.cumulativeWorkgroupCount[1]+WorkgroupCount(2);
+                _schedulerParams.cumulativeWorkgroupCount[3] = _schedulerParams.cumulativeWorkgroupCount[2]+1u;
+                _schedulerParams.cumulativeWorkgroupCount[4] = _schedulerParams.cumulativeWorkgroupCount[3]+WorkgroupCount(2);
+                _schedulerParams.cumulativeWorkgroupCount[5] = _schedulerParams.cumulativeWorkgroupCount[4]+WorkgroupCount(1);
+                _schedulerParams.cumulativeWorkgroupCount[6] = _schedulerParams.cumulativeWorkgroupCount[5]+WorkgroupCount(0);
+                // climb up
+                _schedulerParams.finishedFlagOffset[1] = WorkgroupCount(1);
+                _schedulerParams.finishedFlagOffset[2] = _schedulerParams.finishedFlagOffset[1]+WorkgroupCount(2);
+                _schedulerParams.finishedFlagOffset[3] = _schedulerParams.finishedFlagOffset[2]+1u;
+                // climb down
+                _schedulerParams.finishedFlagOffset[4] = _schedulerParams.finishedFlagOffset[2]+2u;
+                _schedulerParams.finishedFlagOffset[5] = _schedulerParams.finishedFlagOffset[4]+WorkgroupCount(2);
+
+                _scanParams.temporaryStorageOffset[0] = _schedulerParams.finishedFlagOffset[5]+WorkgroupCount(1);
+                _scanParams.temporaryStorageOffset[1] = _scanParams.temporaryStorageOffset[0]+WorkgroupCount(0);
+                _scanParams.temporaryStorageOffset[2] = _scanParams.temporaryStorageOffset[1]+WorkgroupCount(1);
+                break;
+            default:
+                break;
+#if NBL_BUILTIN_MAX_SCAN_LEVELS>7
+#error "Switch needs more cases"
+#endif
+        }
+#undef WorkgroupCount
+    }
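Concretely, for elementCount = 10000 and a workgroup size of 256 (log2 = 8): lastElement[0] = 9999, topLevel = firstbithigh(9999)/8 = 13/8 = 1, and lastElement[1] = 9999>>8 = 39, so WorkgroupCount(0) = 40. The case 1u branch then yields cumulativeWorkgroupCount = {40, 41, 81}: 40 upsweep workgroups, 1 workgroup scanning the 40 partial reductions, and 40 downsweep workgroups, with the 10000 + 40 scratch sizing described in the comment above.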
    /**
     * treeLevel - the current level in the Blelloch Scan
     * localWorkgroupIndex - the workgroup index the current invocation is a part of in the specific virtual dispatch.
     * For example, if we have dispatched 10 workgroups and the virtual workgroup number is 35, then the localWorkgroupIndex should be 5.
     */
    template<class ScratchAccessor>
    bool getWork(in DefaultSchedulerParameters_t params, in uint topLevel, out uint treeLevel, out uint localWorkgroupIndex)
    {
        ScratchAccessor sharedScratch;
        if(SubgroupContiguousIndex() == 0u)
        {
            uint64_t original;
            InterlockedAdd(scanScratch.workgroupsStarted, 1u, original); // REVIEW: Refactor InterlockedAdd with GLSL terminology? // TODO (PentaKon): Refactor this when the ScanScratch descriptor set is declared
            sharedScratch.set(SubgroupContiguousIndex(), original);
        }
        else if (SubgroupContiguousIndex() == 1u)
        {
            sharedScratch.set(SubgroupContiguousIndex(), 0u);
        }
        GroupMemoryBarrierWithGroupSync(); // REVIEW: refactor this somewhere with GLSL terminology?

        const uint globalWorkgroupIndex; // does every thread need to know?
        sharedScratch.get(0u, globalWorkgroupIndex);
        const uint lastLevel = topLevel<<1u;
        if (SubgroupContiguousIndex()<=lastLevel && globalWorkgroupIndex>=params.cumulativeWorkgroupCount[SubgroupContiguousIndex()])
        {
            InterlockedAdd(sharedScratch.get(1u, ?), 1u); // REVIEW: The way scratchaccessoradaptor is implemented (e.g. under subgroup/arithmetic_portability) doesn't allow for atomic ops on the scratch buffer. Should we ask for another implementation that overrides the [] operator ?
        }
        GroupMemoryBarrierWithGroupSync(); // TODO (PentaKon): Possibly refactor?

        sharedScratch.get(1u, treeLevel);
        if(treeLevel>lastLevel)
            return true;

        localWorkgroupIndex = globalWorkgroupIndex;
        const bool dependentLevel = treeLevel != 0u;
        if(dependentLevel)
        {
            const uint prevLevel = treeLevel - 1u;
            localWorkgroupIndex -= params.cumulativeWorkgroupCount[prevLevel];
            if(SubgroupContiguousIndex() == 0u)
            {
                uint dependentsCount = 1u;
                if(treeLevel <= topLevel)
                {
                    dependentsCount = _NBL_HLSL_WORKGROUP_SIZE_; // REVIEW: Defined in the files that include this file?
                    const bool lastWorkgroup = (globalWorkgroupIndex+1u)==params.cumulativeWorkgroupCount[treeLevel];
                    if (lastWorkgroup)
                    {
                        const Parameters_t scanParams = getParameters(); // TODO (PentaKon): Undeclared as of now, this should return the Parameters_t from the push constants of (in)direct shader
                        dependentsCount = scanParams.lastElement[treeLevel]+1u;
                        if (treeLevel<topLevel) // !(prevLevel
 globallycoherent
\ No newline at end of file
diff --git a/include/nbl/builtin/hlsl/scan/direct.hlsl b/include/nbl/builtin/hlsl/scan/direct.hlsl
new file mode 100644
index 0000000000..325a08e3f0
--- /dev/null
+++ b/include/nbl/builtin/hlsl/scan/direct.hlsl
@@ -0,0 +1,50 @@
+#ifndef _NBL_HLSL_WORKGROUP_SIZE_
+#define _NBL_HLSL_WORKGROUP_SIZE_ 256
+#endif
+
+#include "nbl/builtin/hlsl/scan/descriptors.hlsl"
+#include "nbl/builtin/hlsl/scan/virtual_workgroup.hlsl"
+#include "nbl/builtin/hlsl/scan/default_scheduler.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+#ifndef _NBL_HLSL_SCAN_PUSH_CONSTANTS_DEFINED_
+    cbuffer PC // REVIEW: register and packoffset selection
+    {
+        Parameters_t scanParams;
+        DefaultSchedulerParameters_t schedulerParams;
+    };
+#define _NBL_HLSL_SCAN_PUSH_CONSTANTS_DEFINED_
+#endif
+
+#ifndef _NBL_HLSL_SCAN_GET_PARAMETERS_DEFINED_
+Parameters_t getParameters()
+{
+    return pc.scanParams;
+}
+#define _NBL_HLSL_SCAN_GET_PARAMETERS_DEFINED_
+#endif
+
+#ifndef _NBL_HLSL_SCAN_GET_SCHEDULER_PARAMETERS_DEFINED_
+DefaultSchedulerParameters_t getSchedulerParameters()
+{
+    return pc.schedulerParams;
+}
+#define _NBL_HLSL_SCAN_GET_SCHEDULER_PARAMETERS_DEFINED_
+#endif
+}
+}
+}
+
+#ifndef _NBL_HLSL_MAIN_DEFINED_
+[numthreads(_NBL_HLSL_WORKGROUP_SIZE_, 1, 1)]
+void CSMain()
+{
+    nbl::hlsl::scan::main();
+}
+#define _NBL_HLSL_MAIN_DEFINED_
+#endif
\ No newline at end of file
diff --git a/include/nbl/builtin/hlsl/scan/indirect.hlsl b/include/nbl/builtin/hlsl/scan/indirect.hlsl
new file mode 100644
index 0000000000..1191731f65
--- /dev/null
+++ b/include/nbl/builtin/hlsl/scan/indirect.hlsl
@@ -0,0 +1,48 @@
+#ifndef _NBL_HLSL_WORKGROUP_SIZE_
+#define _NBL_HLSL_WORKGROUP_SIZE_ 256
+#define _NBL_HLSL_WORKGROUP_SIZE_LOG2_ 8
+#endif
+
+#include "nbl/builtin/hlsl/scan/descriptors.hlsl"
+#include "nbl/builtin/hlsl/scan/virtual_workgroup.hlsl"
+#include "nbl/builtin/hlsl/scan/default_scheduler.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+#ifndef _NBL_HLSL_SCAN_GET_PARAMETERS_DEFINED_
+Parameters_t scanParams;
+Parameters_t getParameters()
+{
+    return scanParams;
+}
+#define _NBL_HLSL_SCAN_GET_PARAMETERS_DEFINED_
+#endif
+
+uint getIndirectElementCount();
+
+#ifndef _NBL_HLSL_SCAN_GET_SCHEDULER_PARAMETERS_DEFINED_
+DefaultSchedulerParameters_t schedulerParams;
+DefaultSchedulerParameters_t getSchedulerParameters()
+{
+    scheduler::computeParameters(getIndirectElementCount(),scanParams,schedulerParams);
+    return schedulerParams;
+}
+#define _NBL_HLSL_SCAN_GET_SCHEDULER_PARAMETERS_DEFINED_
+#endif
+}
+}
+}
+
+#ifndef _NBL_HLSL_MAIN_DEFINED_
+[numthreads(_NBL_HLSL_WORKGROUP_SIZE_, 1, 1)]
+void CSMain()
+{
+    if (bool(nbl::hlsl::scan::getIndirectElementCount()))
+        nbl::hlsl::scan::main();
+}
+#define _NBL_HLSL_MAIN_DEFINED_
+#endif
\ No newline at end of file
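Both entry-point headers are override-friendly: pre-defining any of the _NBL_HLSL_SCAN_*_DEFINED_ or _NBL_HLSL_MAIN_DEFINED_ guards replaces the corresponding default. A hypothetical user shader built on the direct variant (the workgroup size choice is illustrative):

#define _NBL_HLSL_WORKGROUP_SIZE_ 512
// user definitions of getData()/setData() for the chosen Storage_t would go here
#include "nbl/builtin/hlsl/scan/direct.hlsl"
// CSMain is emitted by the header since _NBL_HLSL_MAIN_DEFINED_ was not pre-defined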
diff --git a/include/nbl/builtin/hlsl/scan/parameters_struct.hlsl b/include/nbl/builtin/hlsl/scan/parameters_struct.hlsl
new file mode 100644
index 0000000000..bfeba13be2
--- /dev/null
+++ b/include/nbl/builtin/hlsl/scan/parameters_struct.hlsl
@@ -0,0 +1,30 @@
+#ifndef _NBL_HLSL_SCAN_PARAMETERS_STRUCT_INCLUDED_
+#define _NBL_HLSL_SCAN_PARAMETERS_STRUCT_INCLUDED_
+
+#define NBL_BUILTIN_MAX_SCAN_LEVELS 7
+
+#ifdef __cplusplus
+#define uint uint32_t
+#endif
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace scan
+{
+    // REVIEW: Putting topLevel second allows better alignment for packing of constant variables, assuming lastElement has length 4. (https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules)
+    struct Parameters_t {
+        uint lastElement[NBL_BUILTIN_MAX_SCAN_LEVELS/2+1];
+        uint topLevel;
+        uint temporaryStorageOffset[NBL_BUILTIN_MAX_SCAN_LEVELS/2];
+    };
+}
+}
+}
+
+#ifdef __cplusplus
+#undef uint
+#endif
+
+#endif
\ No newline at end of file
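With NBL_BUILTIN_MAX_SCAN_LEVELS = 7 this works out to lastElement[7/2+1] = lastElement[4], followed by topLevel, followed by temporaryStorageOffset[7/2] = temporaryStorageOffset[3] — eight 32-bit values in total — and the uint/uint32_t re-define above keeps the declaration usable verbatim from the C++ side.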
+#include "nbl/builtin/hlsl/scan/default_scheduler.hlsl" +namespace nbl +{ +namespace hlsl +{ +namespace scan +{ + DefaultSchedulerParameters_t getSchedulerParameters(); // this is defined in the final shader that assembles all the SCAN operation components + void main() + { + const DefaultSchedulerParameters_t schedulerParams = getSchedulerParameters(); + const uint topLevel = getParameters().topLevel; + // persistent workgroups + while (true) + { + uint treeLevel,localWorkgroupIndex; + if (scheduler::getWork(schedulerParams,topLevel,treeLevel,localWorkgroupIndex)) + { + return; + } + + virtualWorkgroup(treeLevel,localWorkgroupIndex); + + scheduler::markComplete(schedulerParams,topLevel,treeLevel,localWorkgroupIndex); + } + } +} +} +} +#endif + +#define _NBL_HLSL_SCAN_MAIN_DEFINED_ +#endif + +#endif \ No newline at end of file diff --git a/include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl b/include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl index ba9754dd14..6b92b21c5a 100644 --- a/include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl +++ b/include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl @@ -6,7 +6,7 @@ #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl" #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" -#include "nbl/builtin/hlsl/binops.hlsl" +#include "nbl/builtin/hlsl/functional.hlsl" #include "nbl/builtin/hlsl/subgroup/ballot.hlsl" namespace nbl @@ -29,7 +29,7 @@ struct inclusive_scan; // *** AND *** template -struct reduction > +struct reduction > { T operator()(const T x) { @@ -38,7 +38,7 @@ struct reduction > }; template -struct inclusive_scan > +struct inclusive_scan > { T operator()(const T x) { @@ -47,7 +47,7 @@ struct inclusive_scan > }; template -struct exclusive_scan > +struct exclusive_scan > { T operator()(const T x) { @@ -57,7 +57,7 @@ struct exclusive_scan > // *** OR *** template -struct reduction > +struct reduction > { T operator()(const T x) { @@ -66,7 +66,7 @@ struct reduction > }; template -struct inclusive_scan > +struct inclusive_scan > { T operator()(const T x) { @@ -75,7 +75,7 @@ struct inclusive_scan > }; template -struct exclusive_scan > +struct exclusive_scan > { T operator()(const T x) { @@ -85,7 +85,7 @@ struct exclusive_scan > // *** XOR *** template -struct reduction > +struct reduction > { T operator()(const T x) { @@ -94,7 +94,7 @@ struct reduction > }; template -struct inclusive_scan > +struct inclusive_scan > { T operator()(const T x) { @@ -103,7 +103,7 @@ struct inclusive_scan > }; template -struct exclusive_scan > +struct exclusive_scan > { T operator()(const T x) { @@ -113,7 +113,7 @@ struct exclusive_scan > // *** ADD *** template -struct reduction > +struct reduction > { T operator()(const T x) { @@ -121,7 +121,7 @@ struct reduction > } }; template -struct inclusive_scan > +struct inclusive_scan > { T operator()(const T x) { @@ -129,7 +129,7 @@ struct inclusive_scan > } }; template -struct exclusive_scan > +struct exclusive_scan > { T operator()(const T x) { @@ -139,7 +139,7 @@ struct exclusive_scan > // *** MUL *** template -struct reduction > +struct reduction > { T operator()(const T x) { @@ -147,7 +147,7 @@ struct reduction > } }; template -struct exclusive_scan > +struct exclusive_scan > { T operator()(const T x) { @@ -155,7 +155,7 @@ struct exclusive_scan > } }; template -struct inclusive_scan > +struct inclusive_scan > { T operator()(const T x) { @@ -165,7 +165,7 @@ struct inclusive_scan > // *** MIN *** template -struct reduction > +struct reduction 
> { T operator()(const T x) { @@ -174,7 +174,7 @@ struct reduction > }; template<> -struct inclusive_scan > +struct inclusive_scan > { int operator()(const int x) { @@ -183,7 +183,7 @@ struct inclusive_scan > }; template<> -struct inclusive_scan > +struct inclusive_scan > { uint operator()(const uint x) { @@ -192,7 +192,7 @@ struct inclusive_scan > }; template<> -struct inclusive_scan > +struct inclusive_scan > { float operator()(const float x) { @@ -201,7 +201,7 @@ struct inclusive_scan > }; template<> -struct exclusive_scan > +struct exclusive_scan > { int operator()(const int x) { @@ -210,7 +210,7 @@ struct exclusive_scan > }; template<> -struct exclusive_scan > +struct exclusive_scan > { uint operator()(const uint x) { @@ -219,7 +219,7 @@ struct exclusive_scan > }; template<> -struct exclusive_scan > +struct exclusive_scan > { float operator()(const float x) { @@ -229,7 +229,7 @@ struct exclusive_scan > // *** MAX *** template -struct reduction > +struct reduction > { T operator()(const T x) { @@ -238,7 +238,7 @@ struct reduction > }; template<> -struct inclusive_scan > +struct inclusive_scan > { int operator()(const int x) { @@ -247,7 +247,7 @@ struct inclusive_scan > }; template<> -struct inclusive_scan > +struct inclusive_scan > { uint operator()(const uint x) { @@ -256,7 +256,7 @@ struct inclusive_scan > }; template<> -struct inclusive_scan > +struct inclusive_scan > { float operator()(const float x) { @@ -265,7 +265,7 @@ struct inclusive_scan > }; template<> -struct exclusive_scan > +struct exclusive_scan > { int operator()(const int x) { @@ -274,7 +274,7 @@ struct exclusive_scan > }; template<> -struct exclusive_scan > +struct exclusive_scan > { uint operator()(const uint x) { @@ -283,7 +283,7 @@ struct exclusive_scan > }; template<> -struct exclusive_scan > +struct exclusive_scan > { float operator()(const float x) { diff --git a/include/nbl/builtin/hlsl/subgroup/ballot.hlsl b/include/nbl/builtin/hlsl/subgroup/ballot.hlsl index df654a7697..0b6b4bcdd1 100644 --- a/include/nbl/builtin/hlsl/subgroup/ballot.hlsl +++ b/include/nbl/builtin/hlsl/subgroup/ballot.hlsl @@ -7,6 +7,7 @@ #include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" #include "nbl/builtin/hlsl/glsl_compat/subgroup_ballot.hlsl" #include "nbl/builtin/hlsl/subgroup/basic.hlsl" +#include "nbl/builtin/hlsl/workgroup/basic.hlsl" namespace nbl { @@ -20,7 +21,7 @@ uint ElectedSubgroupInvocationID() { } uint ElectedLocalInvocationID() { - return glsl::subgroupBroadcastFirst(gl_LocalInvocationIndex); + return glsl::subgroupBroadcastFirst(workgroup::SubgroupContiguousIndex()); } } diff --git a/include/nbl/builtin/hlsl/type_traits.hlsl b/include/nbl/builtin/hlsl/type_traits.hlsl index 55f5674ab4..5ed799a626 100644 --- a/include/nbl/builtin/hlsl/type_traits.hlsl +++ b/include/nbl/builtin/hlsl/type_traits.hlsl @@ -7,12 +7,8 @@ // C++ headers #ifndef __HLSL_VERSION #include -#endif - - -#include #include - +#endif // Since HLSL currently doesnt allow type aliases we declare them as seperate structs thus they are (WORKAROUND)s /* @@ -149,6 +145,8 @@ template struct negation; */ +#else +#include namespace nbl { diff --git a/include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl new file mode 100644 index 0000000000..829f698d1b --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl @@ -0,0 +1,128 @@ +// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". 
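With the rename, picking a subgroup operation is just a matter of instantiating one of the specializations above with a functional.hlsl binop. A hedged sketch (the wrapper function and exact include are illustrative; the portability header is listed in the CMake resources below):

#include "nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl"

uint subgroupPrefixSum(uint v)
{
    nbl::hlsl::subgroup::inclusive_scan<uint, nbl::hlsl::plus<uint> > op;
    return op(v);
}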
diff --git a/include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl
new file mode 100644
index 0000000000..829f698d1b
--- /dev/null
+++ b/include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl
@@ -0,0 +1,128 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_ARITHMETIC_INCLUDED_
+#define _NBL_BUILTIN_HLSL_WORKGROUP_ARITHMETIC_INCLUDED_
+
+#include "nbl/builtin/hlsl/cpp_compat.hlsl"
+#include "nbl/builtin/hlsl/functional.hlsl"
+#include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
+#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
+#include "nbl/builtin/hlsl/workgroup/shared_scan.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace workgroup
+{
+
+#define REDUCE Reduce<T, subgroup::reduction<T, Binop>, SharedAccessor, _NBL_HLSL_WORKGROUP_SIZE_>
+#define SCAN(isExclusive) Scan<T, Binop, subgroup::inclusive_scan<T, Binop>, SharedAccessor, _NBL_HLSL_WORKGROUP_SIZE_, isExclusive>
+template<typename T, class Binop, class SharedAccessor>
+T reduction(T value, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    REDUCE reduce = REDUCE::create();
+    reduce(value, accessor);
+    accessor.main.workgroupExecutionAndMemoryBarrier();
+    T retVal = Broadcast(reduce.lastLevelScan, accessor, reduce.lastInvocationInLevel);
+    return retVal;
+}
+
+template<typename T, class Binop, class SharedAccessor>
+T inclusive_scan(T value, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    SCAN(false) incl_scan = SCAN(false)::create();
+    T retVal = incl_scan(value, accessor);
+    return retVal;
+}
+
+template<typename T, class Binop, class SharedAccessor>
+T exclusive_scan(T value, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    SCAN(true) excl_scan = SCAN(true)::create();
+    T retVal = excl_scan(value, accessor);
+    return retVal;
+}
+
+#undef REDUCE
+#undef SCAN
+
+#define REDUCE Reduce<uint, subgroup::reduction<uint, plus<uint> >, SharedAccessor, impl::uballotBitfieldCount>
+#define SCAN Scan<uint, plus<uint>, subgroup::inclusive_scan<uint, plus<uint> >, SharedAccessor, impl::uballotBitfieldCount, true>
+/**
+ * Gives us the sum (reduction) of all ballots for the workgroup.
+ *
+ * Only the first few invocations are used for performing the sum
+ * since we only have `uballotBitfieldCount` amount of uints that we need
+ * to add together.
+ *
+ * We add them all in the shared array index after the last DWORD
+ * that is used for the ballots. For example, if we have 128 workgroup size,
+ * then the array index in which we accumulate the sum is `4` since
+ * indexes 0..3 are used for ballots.
+ */
+template<class SharedAccessor>
+uint ballotBitCount(NBL_REF_ARG(SharedAccessor) accessor)
+{
+    uint participatingBitfield = 0;
+    if(SubgroupContiguousIndex() < impl::uballotBitfieldCount)
+    {
+        participatingBitfield = accessor.ballot.get(SubgroupContiguousIndex());
+    }
+    accessor.ballot.workgroupExecutionAndMemoryBarrier();
+    REDUCE reduce = REDUCE::create();
+    reduce(countbits(participatingBitfield), accessor);
+    accessor.main.workgroupExecutionAndMemoryBarrier();
+    return Broadcast(reduce.lastLevelScan, accessor, reduce.lastInvocationInLevel);
+}
+
+template<class SharedAccessor>
+uint ballotScanBitCount(const bool exclusive, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    const uint _dword = impl::getDWORD(SubgroupContiguousIndex());
+    const uint localBitfield = accessor.ballot.get(_dword);
+    uint globalCount;
+    {
+        uint participatingBitfield;
+        if(SubgroupContiguousIndex() < impl::uballotBitfieldCount)
+        {
+            participatingBitfield = accessor.ballot.get(SubgroupContiguousIndex());
+        }
+        // scan hierarchically, invocations with `SubgroupContiguousIndex() >= uballotBitfieldCount` will have garbage here
+        accessor.ballot.workgroupExecutionAndMemoryBarrier();
+
+        SCAN scan = SCAN::create();
+        uint bitscan = scan(countbits(participatingBitfield), accessor);
+
+        accessor.main.set(SubgroupContiguousIndex(), bitscan);
+        accessor.main.workgroupExecutionAndMemoryBarrier();
+
+        // fix it (abuse the fact memory is left over)
+        globalCount = _dword != 0u ? accessor.main.get(_dword) : 0u;
+        accessor.main.workgroupExecutionAndMemoryBarrier();
+    }
+    const uint mask = (exclusive ? 0x7fFFffFFu:0xFFffFFffu)>>(31u-(SubgroupContiguousIndex()&31u));
+    return globalCount + countbits(localBitfield & mask);
+}
+
+template<class SharedAccessor>
+uint ballotInclusiveBitCount(NBL_REF_ARG(SharedAccessor) accessor)
+{
+    return ballotScanBitCount(false, accessor);
+}
+
+template<class SharedAccessor>
+uint ballotExclusiveBitCount(NBL_REF_ARG(SharedAccessor) accessor)
+{
+    return ballotScanBitCount(true, accessor);
+}
+
+#undef REDUCE
+#undef SCAN
+
+}
+}
+}
+
+#endif
\ No newline at end of file
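The SharedAccessor these functions expect is user-supplied; judging from the calls above it needs nested `main` and `ballot` sub-accessors, each with get/set and a workgroupExecutionAndMemoryBarrier. A minimal sketch under those assumptions (scratch sizing is illustrative):

groupshared uint wgScratch[_NBL_HLSL_WORKGROUP_SIZE_];

struct WGScratchProxy
{
    uint get(const uint ix) { return wgScratch[ix]; }
    void set(const uint ix, const uint value) { wgScratch[ix] = value; }
    void workgroupExecutionAndMemoryBarrier() { nbl::hlsl::glsl::barrier(); }
};

struct SharedAccessor
{
    WGScratchProxy main;
    WGScratchProxy ballot; // the ballot helpers additionally need atomicOr on this one
};

// usage sketch:
// SharedAccessor accessor;
// uint sum = nbl::hlsl::workgroup::reduction<uint, nbl::hlsl::plus<uint>, SharedAccessor>(value, accessor);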
diff --git a/include/nbl/builtin/hlsl/workgroup/ballot.hlsl b/include/nbl/builtin/hlsl/workgroup/ballot.hlsl
new file mode 100644
index 0000000000..b415bd0372
--- /dev/null
+++ b/include/nbl/builtin/hlsl/workgroup/ballot.hlsl
@@ -0,0 +1,78 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_BALLOT_INCLUDED_
+#define _NBL_BUILTIN_HLSL_WORKGROUP_BALLOT_INCLUDED_
+
+#include "nbl/builtin/hlsl/cpp_compat.hlsl"
+#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
+#include "nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace workgroup
+{
+namespace impl
+{
+uint getDWORD(uint invocation)
+{
+    return invocation >> 5;
+}
+
+// uballotBitfieldCount essentially means 'how many DWORDs are needed to store ballots in bitfields, for each invocation of the workgroup'
+// can't use getDWORD because we want the static const to be treated as 'constexpr'
+static const uint uballotBitfieldCount = (_NBL_HLSL_WORKGROUP_SIZE_+31) >> 5; // in case WGSZ is not a multiple of 32 we might miscalculate the DWORDs after the right-shift by 5 which is why we add 31
+
+}
+/**
+ * Simple ballot function.
+ *
+ * Each invocation provides a boolean value. Each value is represented by a
+ * single bit of a Uint. For example, if invocation index 5 supplies `value = true`
+ * then the Uint will be ...00100000
+ * This way we can encode 32 invocations into a single Uint.
+ *
+ * All Uints are kept in contiguous accessor memory in a shared array.
+ * The size of that array is based on the WORKGROUP SIZE. In this case we use uballotBitfieldCount.
+ *
+ * For each group of 32 invocations, a DWORD is assigned to the array (i.e. a 32-bit value, in this case Uint).
+ * For example, for a workgroup size 128, 4 DWORDs are needed.
+ * For each invocation index, we can find its respective DWORD index in the accessor array
+ * by calling the getDWORD function.
+ */
+template<class SharedAccessor>
+void ballot(const bool value, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    uint initialize = SubgroupContiguousIndex() < impl::uballotBitfieldCount;
+    if(initialize) {
+        accessor.ballot.set(SubgroupContiguousIndex(), 0u);
+    }
+    accessor.ballot.workgroupExecutionAndMemoryBarrier();
+    if(value) {
+        uint dummy;
+        accessor.ballot.atomicOr(impl::getDWORD(SubgroupContiguousIndex()), 1u<<(SubgroupContiguousIndex()&31u), dummy);
+    }
+}
+
+template<class SharedAccessor>
+bool ballotBitExtract(const uint index, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    return (accessor.ballot.get(impl::getDWORD(index)) & (1u << (index & 31u))) != 0u;
+}
+
+/**
+ * Once we have assigned ballots in the shared array, we can
+ * extract any invocation's ballot value using this function.
+ */
+template<class SharedAccessor>
+bool inverseBallot(NBL_REF_ARG(SharedAccessor) accessor)
+{
+    return ballotBitExtract(SubgroupContiguousIndex(), accessor);
+}
+
+}
+}
+}
+#endif
\ No newline at end of file
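Putting the ballot helpers together, a hedged sketch of counting votes across the workgroup (the accessor shape is assumed as in arithmetic.hlsl above):

template<class SharedAccessor>
uint countTrueInvocations(bool vote, NBL_REF_ARG(SharedAccessor) accessor)
{
    nbl::hlsl::workgroup::ballot<SharedAccessor>(vote, accessor);
    accessor.ballot.workgroupExecutionAndMemoryBarrier();
    return nbl::hlsl::workgroup::ballotBitCount<SharedAccessor>(accessor);
}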
diff --git a/include/nbl/builtin/hlsl/workgroup/basic.hlsl b/include/nbl/builtin/hlsl/workgroup/basic.hlsl
new file mode 100644
index 0000000000..2c41c86f0d
--- /dev/null
+++ b/include/nbl/builtin/hlsl/workgroup/basic.hlsl
@@ -0,0 +1,31 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_BASIC_INCLUDED_
+#define _NBL_BUILTIN_HLSL_WORKGROUP_BASIC_INCLUDED_
+
+#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
+
+//! all functions must be called in uniform control flow (all workgroup invocations active)
+namespace nbl
+{
+namespace hlsl
+{
+namespace workgroup
+{
+    static const uint MaxWorkgroupSizeLog2 = 11;
+    static const uint MaxWorkgroupSize = 0x1u << MaxWorkgroupSizeLog2;
+
+    uint SubgroupContiguousIndex()
+    {
+        return glsl::gl_SubgroupID() * glsl::gl_SubgroupSize() + glsl::gl_SubgroupInvocationID();
+    }
+
+    bool Elect()
+    {
+        return SubgroupContiguousIndex()==0u;
+    }
+}
+}
+}
+#endif
diff --git a/include/nbl/builtin/hlsl/workgroup/broadcast.hlsl b/include/nbl/builtin/hlsl/workgroup/broadcast.hlsl
new file mode 100644
index 0000000000..c9c05039c5
--- /dev/null
+++ b/include/nbl/builtin/hlsl/workgroup/broadcast.hlsl
@@ -0,0 +1,50 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_BROADCAST_INCLUDED_
+#define _NBL_BUILTIN_HLSL_WORKGROUP_BROADCAST_INCLUDED_
+
+#include "nbl/builtin/hlsl/cpp_compat.hlsl"
+#include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace workgroup
+{
+
+/**
+ * Broadcasts the value `val` of invocation index `id`
+ * to all other invocations.
+ *
+ * We save the value in the shared array in the uballotBitfieldCount index
+ * and then all invocations access that index.
+ */
+template<typename T, class SharedAccessor>
+T Broadcast(const T val, NBL_REF_ARG(SharedAccessor) accessor, const uint id)
+{
+    if(SubgroupContiguousIndex() == id) {
+        accessor.broadcast.set(impl::uballotBitfieldCount, val);
+    }
+
+    accessor.broadcast.workgroupExecutionAndMemoryBarrier();
+
+    return accessor.broadcast.get(impl::uballotBitfieldCount);
+}
+
+template<typename T, class SharedAccessor>
+T BroadcastFirst(const T val, NBL_REF_ARG(SharedAccessor) accessor)
+{
+    if (Elect())
+        accessor.broadcast.set(impl::uballotBitfieldCount, val);
+
+    accessor.broadcast.workgroupExecutionAndMemoryBarrier();
+
+    return accessor.broadcast.get(impl::uballotBitfieldCount);
+}
+
+}
+}
+}
+#endif
\ No newline at end of file
diff --git a/include/nbl/builtin/hlsl/workgroup/scratch_sz.hlsl b/include/nbl/builtin/hlsl/workgroup/scratch_sz.hlsl
new file mode 100644
index 0000000000..beedffe440
--- /dev/null
+++ b/include/nbl/builtin/hlsl/workgroup/scratch_sz.hlsl
@@ -0,0 +1,34 @@
+// Copyright (C) 2022 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_SCRATCH_SZ_INCLUDED_
+#define _NBL_BUILTIN_HLSL_SCRATCH_SZ_INCLUDED_
+
+
+#include "nbl/builtin/hlsl/type_traits.hlsl"
+
+// REVIEW-519: Review this whole header and content (whether it should be here or somewhere else)
+namespace nbl
+{
+namespace hlsl
+{
+namespace workgroup
+{
+namespace impl
+{
+template
+struct ceil_div : integral_constant {};
+
+template
+struct trunc_geom_series;
+
+template
+struct trunc_geom_series : integral_constant::value+trunc_geom_series::value> {};
+
+template
+struct trunc_geom_series : integral_constant {};
+}
+}
+}
+}
+#endif
\ No newline at end of file
diff --git a/include/nbl/builtin/hlsl/workgroup/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup/shared_scan.hlsl
new file mode 100644
index 0000000000..000fc682b3
--- /dev/null
+++ b/include/nbl/builtin/hlsl/workgroup/shared_scan.hlsl
@@ -0,0 +1,174 @@
+// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_SHARED_SCAN_INCLUDED_
+#define _NBL_BUILTIN_HLSL_WORKGROUP_SHARED_SCAN_INCLUDED_
+
+#include "nbl/builtin/hlsl/cpp_compat.hlsl"
+#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
+#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
+#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace workgroup
+{
+
+template<typename T, class SubgroupOp, class SharedAccessor, uint itemCount>
+struct Reduce
+{
+    T firstLevelScan;
+    T lastLevelScan;
+    uint lastInvocation;
+    uint lastInvocationInLevel;
+    uint scanLoadIndex;
+    bool participate;
+
+    static Reduce create()
+    {
+        Reduce wsh;
+        wsh.lastInvocation = itemCount - 1u;
+        return wsh;
+    }
+
+    void operator()(T value, NBL_REF_ARG(SharedAccessor) sharedAccessor)
+    {
+        const uint subgroupMask = glsl::gl_SubgroupSize() - 1u;
+        lastInvocationInLevel = lastInvocation;
+
+        SubgroupOp subgroupOp;
+        firstLevelScan = subgroupOp(value);
+        T scan = firstLevelScan;
+
+        const bool isLastSubgroupInvocation = glsl::gl_SubgroupInvocationID() == glsl::gl_SubgroupSize() - 1u;
+
+        // Since we are scanning the RESULT of the initial scan (which paired one input per subgroup invocation)
+        // every group of gl_SubgroupSz invocations has been coalesced into 1 result value. This means that the results of
+        // the first gl_SubgroupSz^2 invocations will be processed by the first subgroup and so on.
+        // Consequently, those first gl_SubgroupSz^2 invocations will store their results on gl_SubgroupSz scratch slots
+        // and the next level will follow the same + the previous as an `offset`.
+
+        scanLoadIndex = SubgroupContiguousIndex();
+        const uint loadStoreIndexDiff = scanLoadIndex - glsl::gl_SubgroupID();
+
+        participate = SubgroupContiguousIndex() <= lastInvocationInLevel;
+        // to cancel out the index shift on the first iteration
+        if (lastInvocationInLevel >= glsl::gl_SubgroupSize())
+            scanLoadIndex -= lastInvocationInLevel-1u;
+        // TODO: later [unroll(scan_levels::value-1)]
+        [unroll(1)]
+        while(lastInvocationInLevel >= glsl::gl_SubgroupSize())
+        {
+            scanLoadIndex += lastInvocationInLevel+1u;
+            // only invocations that have the final value of the subgroupOp (inclusive scan) store their results
+            if (participate && (SubgroupContiguousIndex()==lastInvocationInLevel || isLastSubgroupInvocation))
+                sharedAccessor.main.set(scanLoadIndex - loadStoreIndexDiff, scan); // For subgroupSz = 32, first 512 invocations store index is [0,15], 512-1023 [16,31] etc.
+            sharedAccessor.main.workgroupExecutionAndMemoryBarrier();
+            participate = SubgroupContiguousIndex() <= (lastInvocationInLevel >>= glsl::gl_SubgroupSizeLog2());
+            if(participate)
+            {
+                const uint prevLevelScan = sharedAccessor.main.get(scanLoadIndex);
+                scan = subgroupOp(prevLevelScan);
+            }
+        }
+        lastLevelScan = scan; // only invocations of SubgroupContiguousIndex() < gl_SubgroupSize will have correct values, rest will have garbage
+    }
+};
+
+template<typename T, class Binop, class SubgroupOp, class SharedAccessor, uint itemCount, bool isExclusive>
+struct Scan
+{
+    Reduce<T, SubgroupOp, SharedAccessor, itemCount> reduce;
+
+    static Scan create()
+    {
+        Scan scan;
+        scan.reduce = Reduce<T, SubgroupOp, SharedAccessor, itemCount>::create();
+        return scan;
+    }
+
+    T operator()(T value, NBL_REF_ARG(SharedAccessor) sharedAccessor)
+    {
+        reduce(value, sharedAccessor);
+
+        Binop binop;
+        uint lastInvocation = reduce.lastInvocation;
+        uint firstLevelScan = reduce.firstLevelScan;
+        const uint subgroupId = glsl::gl_SubgroupID();
+
+        // abuse integer wraparound to map 0 to 0xffFFffFFu
+        const uint32_t prevSubgroupID = uint32_t(glsl::gl_SubgroupID())-1u;
+
+        // important check to prevent weird `firstbithigh` overflows
+        if(lastInvocation >= glsl::gl_SubgroupSize())
+        {
+            // different than Upsweep cause we need to translate high level inclusive scans into exclusive on the fly, so we get the value of the subgroup behind our own in each level
+            const uint32_t storeLoadIndexDiff = uint32_t(SubgroupContiguousIndex()) - prevSubgroupID;
+
+            // because DXC doesn't do references and I need my "frozen" registers
+            #define scanStoreIndex reduce.scanLoadIndex
+            // we loop over levels from highest to penultimate
+            // as we iterate some previously active (higher level) invocations hold their exclusive prefix sum in `lastLevelScan`
+            const uint32_t temp = firstbithigh(lastInvocation) / glsl::gl_SubgroupSizeLog2(); // doing division then multiplication might be optimized away by the compiler
+            const uint32_t initialLogShift = temp * glsl::gl_SubgroupSizeLog2();
+            // TODO: later [unroll(scan_levels::value-1)]
+            [unroll(1)]
+            for(uint32_t logShift=initialLogShift; bool(logShift); logShift-=glsl::gl_SubgroupSizeLog2())
+            {
+                // on the first iteration gl_SubgroupID==0 will participate but not afterwards because binop operand is identity
+                if (reduce.participate)
+                {
+                    // we need to add the higher level invocation exclusive prefix sum to current value
+                    if (logShift!=initialLogShift) // but the top level doesn't have any level above itself
+                    {
+                        // this is fine if on the way up you also += under `if (participate)`
+                        scanStoreIndex -= reduce.lastInvocationInLevel+1;
+                        reduce.lastLevelScan = binop(reduce.lastLevelScan,sharedAccessor.main.get(scanStoreIndex));
+                    }
+                    // now `lastLevelScan` has current level's inclusive prefix sum computed properly
+                    // note we're overwriting the same location with same invocation so no barrier needed
+                    // we store everything even though we'll never use the last entry due to shuffleup on read
+                    sharedAccessor.main.set(scanStoreIndex,reduce.lastLevelScan);
+                }
+                sharedAccessor.main.workgroupExecutionAndMemoryBarrier();
+                // we're sneaky and exclude `gl_SubgroupID==0` from participation by abusing integer underflow
+                reduce.participate = prevSubgroupID < (lastInvocation >> logShift);
+            }
+            #undef scanStoreIndex
+
+            //assert((lastInvocation>>glsl::gl_SubgroupSizeLog2())==reduce.lastInvocationInLevel);
+
+            // the very first prefix sum we did is in a register, not Accessor scratch mem hence the special path
+            if ( prevSubgroupID < reduce.lastInvocationInLevel)
+                firstLevelScan = binop(reduce.lastLevelScan,firstLevelScan);
+        }
+
+        if(isExclusive)
+        {
+            firstLevelScan = glsl::subgroupShuffleUp(firstLevelScan, 1u);
+            if(glsl::subgroupElect())
+            {   // shuffle doesn't work between subgroups but the value for each elected subgroup invocation is just the previous higherLevelExclusive
+                // note that we assume we might have to do scans with itemCount <= gl_WorkgroupSize
+                firstLevelScan = bool(subgroupId) ? reduce.lastLevelScan : Binop::identity();
+            }
+            return firstLevelScan;
+        }
+        else
+        {
+            return firstLevelScan;
+        }
+    }
+};
+}
+}
+}
+
+#endif
\ No newline at end of file
internal & Nabla only, must be added with the macro to properly propagate scope ADD_CUSTOM_BUILTIN_RESOURCES("${_TARGET_}" NBL_RESOURCES_TO_EMBED "${NBL_ROOT_PATH}/include" "nbl/builtin" "nbl::builtin" "${NBL_ROOT_PATH_BINARY}/include" "${NBL_ROOT_PATH_BINARY}/src" "STATIC" "INTERNAL")