Skip to content

Commit

Permalink
Refactor the counting sort logic into a general utility
Browse files Browse the repository at this point in the history
  • Loading branch information
nipunG314 committed May 3, 2024
1 parent c2b0409 commit bc1f092
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 6 deletions.
30 changes: 30 additions & 0 deletions include/nbl/builtin/hlsl/sort/common.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2024 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h

#ifndef _NBL_BUILTIN_HLSL_SORT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_SORT_COMMON_INCLUDED_

namespace nbl
{
namespace hlsl
{
namespace sort
{

// Push-constant payload shared by the counting-sort pipelines (histogram and
// scatter). All *Address members are 64-bit buffer device addresses; the
// counting.hlsl kernels wrap them in Key/Value/Scratch accessors and read or
// write uint32_t elements through them.
struct CountingPushData
{
uint64_t inputKeyAddress;    // BDA of the unsorted key buffer (read)
uint64_t inputValueAddress;  // BDA of the unsorted value buffer (read)
uint64_t scratchAddress;     // BDA of the global per-bucket offset/count scratch
uint64_t outputKeyAddress;   // BDA of the sorted key buffer (written by scatter)
uint64_t outputValueAddress; // BDA of the sorted value buffer (written by scatter)
uint32_t dataElementCount;   // total number of elements across all workgroups
// Smallest value in the input; subtracted from each loaded element to form a
// zero-based bucket index (see `value - minimum` in counting.hlsl).
uint32_t minimum;
// Number of contiguous elements each workgroup thread processes.
uint32_t elementsPerWT;
};

}
}
}
#endif
130 changes: 124 additions & 6 deletions include/nbl/builtin/hlsl/sort/counting.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h

#include "nbl/builtin/hlsl/sort/common.hlsl"
#include "nbl/builtin/hlsl/workgroup/arithmetic.hlsl"

#ifndef _NBL_BUILTIN_HLSL_SORT_COUNTING_INCLUDED_
#define _NBL_BUILTIN_HLSL_SORT_COUNTING_INCLUDED_

Expand Down Expand Up @@ -49,7 +52,7 @@ T load(P pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002
template<typename T, typename P, uint32_t alignment >
[[vk::ext_capability( /*PhysicalStorageBufferAddresses */5347)]]
[[vk::ext_instruction( /*OpStore*/62)]]
void store(P pointer, [[vk::ext_reference]] T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
void store(P pointer, T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

// TODO: atomics for different types
template<typename T, typename P> // integers operate on 2s complement so same op for signed and unsigned
Expand Down Expand Up @@ -148,7 +151,7 @@ struct __ptr
// TODO: assert(addr&uint64_t(alignment-1)==0);
using retval_t = __ref < T, alignment, _restrict>;
retval_t retval;
retval.__init(impl::bitcast<typename retval_t::spv_ptr_t,uint64_t>(addr));
retval.__init(impl::bitcast<typename retval_t::spv_ptr_t>(addr));
return retval;
}
};
Expand All @@ -161,22 +164,137 @@ namespace hlsl
namespace sort
{

template<typename KeyAccessor>
// How many histogram buckets each thread owns during the prefix-scan phase.
// NOTE(review): BucketCount and WorkgroupSize are not defined in this file --
// presumably compile-time configuration injected by the including shader;
// confirm. Also note ceil() here implies BucketsPerThread * WorkgroupSize may
// exceed BucketCount when the division is not exact.
NBL_CONSTEXPR uint32_t BucketsPerThread = ceil((float) BucketCount / WorkgroupSize);

// Workgroup-shared scratch backing the workgroup::exclusive_scan invocations.
groupshared uint32_t prefixScratch[BucketCount];

// Accessor adapter over the `prefixScratch` groupshared array. Exposes the
// get/set/barrier surface that nbl::hlsl::workgroup::exclusive_scan expects
// from its scratch-accessor template parameter.
struct ScratchProxy
{
    // Full workgroup execution + shared-memory barrier between scan passes.
    void workgroupExecutionAndMemoryBarrier()
    {
        nbl::hlsl::glsl::barrier();
    }

    // Write one scratch slot.
    void set(const uint32_t index, const uint32_t val)
    {
        prefixScratch[index] = val;
    }

    // Read one scratch slot.
    uint32_t get(const uint32_t index)
    {
        return prefixScratch[index];
    }
};

// Single file-scope proxy instance handed to every exclusive_scan call.
static ScratchProxy arithmeticAccessor;

// Workgroup-local histogram: one uint32_t counter per bucket.
groupshared uint32_t sdata[BucketCount];

// Counting-sort worker parameterized on buffer-device-address accessors.
// Key/Value/ScratchAccessor wrap a raw 64-bit address and expose
// load()/store()/atomicAdd() on 4-byte-aligned uint32_t elements (see the
// __ptr/__ref utilities earlier in this file).
template<typename KeyAccessor, typename ValueAccessor, typename ScratchAccessor>
struct counting
{
// NOTE(review): the next two lines look like stale residue of the previous
// revision's init() overload (they reference the removed `key_ptr` member and
// leave an unbalanced brace) -- they should not survive in the committed file.
void init(const uint64_t key_addr, uint32_t index) {
key_ptr = KeyAccessor(key_addr + sizeof(uint32_t) * index);
// Copies the push-constant payload into member state; performs no GPU work.
void init(
const CountingPushData data
) {
in_key_addr = data.inputKeyAddress;
out_key_addr = data.outputKeyAddress;
in_value_addr = data.inputValueAddress;
out_value_addr = data.outputValueAddress;
scratch_addr = data.scratchAddress;
data_element_count = data.dataElementCount;
minimum = data.minimum;
elements_per_wt = data.elementsPerWT;
}

// Pass 1: build a per-workgroup bucket histogram in `sdata`, exclusive-scan
// it across the whole workgroup segment by segment, and atomically merge the
// resulting per-bucket offsets into the global scratch buffer.
void histogram()
{
uint32_t tid = nbl::hlsl::workgroup::SubgroupContiguousIndex();

// Zero this thread's slice of the shared histogram.
// NOTE(review): when BucketCount is not a multiple of WorkgroupSize,
// BucketsPerThread * tid + i can index past BucketCount for high tid
// (BucketsPerThread is a ceil) -- confirm sdata is sized to tolerate this.
[unroll]
for (int i = 0; i < BucketsPerThread; i++)
sdata[BucketsPerThread * tid + i] = 0;
// First global element this thread reads; each thread owns a contiguous run
// of `elements_per_wt` elements.
uint32_t index = (nbl::hlsl::glsl::gl_WorkGroupID().x * WorkgroupSize + tid) * elements_per_wt;

nbl::hlsl::glsl::barrier();

// Count occurrences into shared memory.
// NOTE(review): buckets are keyed on (value - minimum), not on the key --
// confirm this is the intended sort key, and that (value - minimum) can
// never reach BucketCount (there is no bounds clamp here).
for (int i = 0; i < elements_per_wt; i++)
{
if (index + i >= data_element_count)
break;
uint32_t value = ValueAccessor(in_value_addr + sizeof(uint32_t) * (index + i)).template deref<4>().load();
nbl::hlsl::glsl::atomicAdd(sdata[value - minimum], (uint32_t) 1);
}

nbl::hlsl::glsl::barrier();

// `sum`: this thread's exclusive prefix within the current segment.
// `scan_sum`: carry -- total count of all buckets in earlier segments.
uint32_t sum = 0;
uint32_t scan_sum = 0;

// Scan the buckets WorkgroupSize at a time (one "segment" per iteration);
// note the transposed indexing (WorkgroupSize * i + tid) versus the zeroing
// loop above -- both enumerate the same bucket set.
for (int i = 0; i < BucketsPerThread; i++)
{
sum = nbl::hlsl::workgroup::exclusive_scan < nbl::hlsl::plus < uint32_t >, WorkgroupSize > ::
template __call <ScratchProxy>
(sdata[WorkgroupSize * i + tid], arithmeticAccessor);

arithmeticAccessor.workgroupExecutionAndMemoryBarrier();

// Publish this thread's exclusive offset to the global scratch.
ScratchAccessor(scratch_addr + sizeof(uint32_t) * (WorkgroupSize * i + tid)).template deref<4>().atomicAdd(sum);
// The first bucket of segment i received 0 from the exclusive scan, so the
// last thread patches the carry from earlier segments straight into global
// scratch for that bucket.
if ((tid == WorkgroupSize - 1) && i > 0)
ScratchAccessor(scratch_addr + sizeof(uint32_t) * (WorkgroupSize * i)).template deref<4>().atomicAdd(scan_sum);

arithmeticAccessor.workgroupExecutionAndMemoryBarrier();

// Last thread computes the segment total (its exclusive prefix plus its own
// count) and seeds it into the next segment's first bucket so the following
// scan iteration propagates the carry to that whole segment.
if ((tid == WorkgroupSize - 1) && i < (BucketsPerThread - 1))
{
scan_sum = sum + sdata[WorkgroupSize * i + tid];
sdata[WorkgroupSize * (i + 1)] += scan_sum;
}

arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
}
}

// Pass 2: re-read the input and write each (key, value) pair to its final
// position using the global per-bucket offsets accumulated in scratch.
void scatter()
{
uint32_t tid = nbl::hlsl::workgroup::SubgroupContiguousIndex();

[unroll]
for (int i = 0; i < BucketsPerThread; i++)
sdata[BucketsPerThread * tid + i] = 0;
uint32_t index = (nbl::hlsl::glsl::gl_WorkGroupID().x * WorkgroupSize + tid) * elements_per_wt;

nbl::hlsl::glsl::barrier();

// NOTE(review): this loop recomputes the histogram, but its counts are never
// read back -- the loop below overwrites sdata slots with scatter offsets.
// Also `key` is loaded but unused here, and [unroll] is applied to a loop
// with a runtime trip count (elements_per_wt); confirm all three.
[unroll]
for (int i = 0; i < elements_per_wt; i++)
{
if (index + i >= data_element_count)
break;
uint32_t key = KeyAccessor(in_key_addr + sizeof(uint32_t) * (index + i)).template deref<4>().load();
uint32_t value = ValueAccessor(in_value_addr + sizeof(uint32_t) * (index + i)).template deref<4>().load();
nbl::hlsl::glsl::atomicAdd(sdata[value - minimum], (uint32_t) 1);
}

// Reserve a unique output slot per element via atomicAdd on global scratch,
// then store the pair there.
// NOTE(review): caching the atomicAdd result in the groupshared slot
// sdata[value - minimum] and immediately reading it back looks racy when
// several threads of the group hit the same bucket between the write and the
// two reads -- a thread-local temporary would be safer; confirm.
[unroll]
for (int i = 0; i < elements_per_wt; i++)
{
if (index + i >= data_element_count)
break;
uint32_t key = KeyAccessor(in_key_addr + sizeof(uint32_t) * (index + i)).template deref<4>().load();
uint32_t value = ValueAccessor(in_value_addr + sizeof(uint32_t) * (index + i)).template deref<4>().load();
sdata[value - minimum] = ScratchAccessor(scratch_addr + sizeof(uint32_t) * (value - minimum)).template deref<4>().atomicAdd(1);
KeyAccessor(out_key_addr + sizeof(uint32_t) * sdata[value - minimum]).template deref<4>().store(key);
ValueAccessor(out_value_addr + sizeof(uint32_t) * sdata[value - minimum]).template deref<4>().store(value);
}

nbl::hlsl::glsl::barrier();
}

// NOTE(review): stale member from the previous revision -- nothing in the
// current code initializes or reads it; candidate for removal.
KeyAccessor key_ptr;
// State copied verbatim from CountingPushData by init().
uint64_t in_key_addr, out_key_addr;
uint64_t in_value_addr, out_value_addr;
uint64_t scratch_addr;
uint32_t data_element_count;
uint32_t minimum;
uint32_t elements_per_wt;
};

}
Expand Down
1 change: 1 addition & 0 deletions src/nbl/builtin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/shapes/line.hlsl")
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/shapes/beziers.hlsl")
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/shapes/util.hlsl")
#sort
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/sort/common.hlsl")
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/sort/counting.hlsl")

#subgroup
Expand Down

0 comments on commit bc1f092

Please sign in to comment.