From 57c2948801a3790dcc61a6bbc1f95435b6320f93 Mon Sep 17 00:00:00 2001 From: Steven Perron <stevenperron@google.com> Date: Fri, 2 Aug 2024 11:30:33 -0400 Subject: [PATCH] Add specific handling for inline spirv pointer types --- tools/clang/lib/SPIRV/LowerTypeVisitor.cpp | 45 +++++++++++++++++++ tools/clang/lib/SPIRV/LowerTypeVisitor.h | 7 +++ .../workgroupspirvpointer.varpointer.hlsl | 23 ++++++++++ 3 files changed, 75 insertions(+) create mode 100644 tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index 874fcb7aef..3b0ac7badc 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -732,6 +732,15 @@ const SpirvType *LowerTypeVisitor::lowerInlineSpirvType( auto args = specDecl->getTemplateArgs()[operandsIndex].getPackAsArray(); + if (operandsIndex == 1 && args.size() == 2 && + static_cast<spv::Op>(opcode) == spv::Op::OpTypePointer) { + const SpirvType *result = + getSpirvPointerFromInlineSpirvType(args, rule, isRowMajor, srcLoc); + if (result) { + return result; + } + } + for (TemplateArgument arg : args) { switch (arg.getKind()) { case TemplateArgument::ArgKind::Type: { @@ -1363,5 +1372,41 @@ LowerTypeVisitor::populateLayoutInformation( return result; } +const SpirvType *LowerTypeVisitor::getSpirvPointerFromInlineSpirvType( + ArrayRef<TemplateArgument> args, SpirvLayoutRule rule, + Optional<bool> isRowMajor, SourceLocation location) { + + assert(args.size() == 2 && "OpTypePoint requires exactly 2 arguments."); + QualType scLiteralType = args[0].getAsType(); + SpirvConstant *constant = nullptr; + if (!getVkIntegralConstantValue(scLiteralType, constant, location) || + !constant) { + return nullptr; + } + if (!constant->isLiteral()) + return nullptr; + + auto *intConstant = dyn_cast<SpirvConstantInteger>(constant); + if (!intConstant) { + return nullptr; + } + + visitInstruction(constant); + spv::StorageClass storageClass = + static_cast<spv::StorageClass>(intConstant->getValue().getLimitedValue()); + + QualType pointeeType; + if (args[1].getKind() == TemplateArgument::ArgKind::Type) { + pointeeType = args[1].getAsType(); + } else { + TemplateName templateName = args[1].getAsTemplate(); + pointeeType = createASTTypeFromTemplateName(templateName); + } + + const SpirvType *pointeeSpirvType = + lowerType(pointeeType, rule, isRowMajor, location); + return spvContext.getPointerType(pointeeSpirvType, storageClass); +} + } // namespace spirv } // namespace clang diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.h b/tools/clang/lib/SPIRV/LowerTypeVisitor.h index 895e0e6cfe..96235d1508 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.h +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.h @@ -124,6 +124,13 @@ class LowerTypeVisitor : public Visitor { SpirvLayoutRule rule, const uint32_t fieldIndex); + /// Get a lowered SpirvPointer from the args to a SpirvOpaqueType. + /// The pointer will use the given layout rule. `isRowMajor` is used to + /// lower the pointee type. + const SpirvType *getSpirvPointerFromInlineSpirvType( + ArrayRef<TemplateArgument> args, SpirvLayoutRule rule, + Optional<bool> isRowMajor, SourceLocation location); + private: ASTContext &astContext; /// AST context SpirvContext &spvContext; /// SPIR-V context diff --git a/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl new file mode 100644 index 0000000000..30abf2ff31 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl @@ -0,0 +1,23 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s 2>&1 | FileCheck %s + +#include "vk/spirv.h" + +// CHECK: OpCapability VariablePointers + +RWStructuredBuffer<int> data; + +groupshared int shared_data[64]; + +[[vk::ext_instruction(/* OpLoad */ 61)]] int +Load(vk::WorkgroupSpirvPointer<int> p); + +[[noinline]] +int foo(vk::WorkgroupSpirvPointer<int> param) { + return Load(param); +} + +[[vk::ext_capability(/* VariablePointersCapability */ 4442)]] +[numthreads(64, 1, 1)] void main() { + vk::WorkgroupSpirvPointer<int> p = vk::GetGroupSharedAddress(shared_data[0]); + data[0] = foo(p); +}