diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index 37cbd8e710..1d7ebd8c9f 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -733,6 +733,15 @@ const SpirvType *LowerTypeVisitor::lowerInlineSpirvType( auto args = specDecl->getTemplateArgs()[operandsIndex].getPackAsArray(); + if (operandsIndex == 1 && args.size() == 2 && + static_cast(opcode) == spv::Op::OpTypePointer) { + const SpirvType *result = + getSpirvPointerFromInlineSpirvType(args, rule, isRowMajor, srcLoc); + if (result) { + return result; + } + } + for (TemplateArgument arg : args) { switch (arg.getKind()) { case TemplateArgument::ArgKind::Type: { @@ -1364,5 +1373,41 @@ LowerTypeVisitor::populateLayoutInformation( return result; } +const SpirvType *LowerTypeVisitor::getSpirvPointerFromInlineSpirvType( + ArrayRef args, SpirvLayoutRule rule, + Optional isRowMajor, SourceLocation location) { + + assert(args.size() == 2 && "OpTypePoint requires exactly 2 arguments."); + QualType scLiteralType = args[0].getAsType(); + SpirvConstant *constant = nullptr; + if (!getVkIntegralConstantValue(scLiteralType, constant, location) || + !constant) { + return nullptr; + } + if (!constant->isLiteral()) + return nullptr; + + auto *intConstant = dyn_cast(constant); + if (!intConstant) { + return nullptr; + } + + visitInstruction(constant); + spv::StorageClass storageClass = + static_cast(intConstant->getValue().getLimitedValue()); + + QualType pointeeType; + if (args[1].getKind() == TemplateArgument::ArgKind::Type) { + pointeeType = args[1].getAsType(); + } else { + TemplateName templateName = args[1].getAsTemplate(); + pointeeType = createASTTypeFromTemplateName(templateName); + } + + const SpirvType *pointeeSpirvType = + lowerType(pointeeType, rule, isRowMajor, location); + return spvContext.getPointerType(pointeeSpirvType, storageClass); +} + } // namespace spirv } // namespace clang diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.h b/tools/clang/lib/SPIRV/LowerTypeVisitor.h index 895e0e6cfe..96235d1508 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.h +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.h @@ -124,6 +124,13 @@ class LowerTypeVisitor : public Visitor { SpirvLayoutRule rule, const uint32_t fieldIndex); + /// Get a lowered SpirvPointer from the args to a SpirvOpaqueType. + /// The pointer will use the given layout rule. `isRowMajor` is used to + /// lower the pointee type. + const SpirvType *getSpirvPointerFromInlineSpirvType( + ArrayRef args, SpirvLayoutRule rule, + Optional isRowMajor, SourceLocation location); + private: ASTContext &astContext; /// AST context SpirvContext &spvContext; /// SPIR-V context diff --git a/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl b/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl index d99f22523e..318d10e155 100644 --- a/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl +++ b/tools/clang/test/CodeGenSPIRV/coopmatrix.element.access.hlsl @@ -21,7 +21,7 @@ int stride; uint32_t length = a.GetLength(); // CHECK: OpLoopMerge [[mbb:%[0-9]+]] for (int i = 0; i < length; ++i) { - // CHECK: [[ac:%[0-9]+]] = OpAccessChain %spirvIntrinsicType_0 [[a]] + // CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Function_int [[a]] // CHECK: [[get:%[0-9]+]] = OpLoad %int [[ac]] // CHECK: [[add:%[0-9]+]] = OpIAdd %int [[get]] %int_1 // CHECK: OpStore [[ac]] [[add]] diff --git a/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl new file mode 100644 index 0000000000..30abf2ff31 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl @@ -0,0 +1,23 @@ +// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s 2>&1 | FileCheck %s + +#include "vk/spirv.h" + +// CHECK: OpCapability VariablePointers + +RWStructuredBuffer data; + +groupshared int shared_data[64]; + +[[vk::ext_instruction(/* OpLoad */ 61)]] int +Load(vk::WorkgroupSpirvPointer p); + +[[noinline]] +int foo(vk::WorkgroupSpirvPointer param) { + return Load(param); +} + +[[vk::ext_capability(/* VariablePointersCapability */ 4442)]] +[numthreads(64, 1, 1)] void main() { + vk::WorkgroupSpirvPointer p = vk::GetGroupSharedAddress(shared_data[0]); + data[0] = foo(p); +}