Skip to content

Commit

Permalink
Add specific handling for inline spirv pointer types
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Aug 21, 2024
1 parent 9018970 commit 57c2948
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,15 @@ const SpirvType *LowerTypeVisitor::lowerInlineSpirvType(

auto args = specDecl->getTemplateArgs()[operandsIndex].getPackAsArray();

if (operandsIndex == 1 && args.size() == 2 &&
static_cast<spv::Op>(opcode) == spv::Op::OpTypePointer) {
const SpirvType *result =
getSpirvPointerFromInlineSpirvType(args, rule, isRowMajor, srcLoc);
if (result) {
return result;
}
}

for (TemplateArgument arg : args) {
switch (arg.getKind()) {
case TemplateArgument::ArgKind::Type: {
Expand Down Expand Up @@ -1363,5 +1372,41 @@ LowerTypeVisitor::populateLayoutInformation(
return result;
}

const SpirvType *LowerTypeVisitor::getSpirvPointerFromInlineSpirvType(
ArrayRef<TemplateArgument> args, SpirvLayoutRule rule,
Optional<bool> isRowMajor, SourceLocation location) {

assert(args.size() == 2 && "OpTypePoint requires exactly 2 arguments.");
QualType scLiteralType = args[0].getAsType();
SpirvConstant *constant = nullptr;
if (!getVkIntegralConstantValue(scLiteralType, constant, location) ||
!constant) {
return nullptr;
}
if (!constant->isLiteral())
return nullptr;

auto *intConstant = dyn_cast<SpirvConstantInteger>(constant);
if (!intConstant) {
return nullptr;
}

visitInstruction(constant);
spv::StorageClass storageClass =
static_cast<spv::StorageClass>(intConstant->getValue().getLimitedValue());

QualType pointeeType;
if (args[1].getKind() == TemplateArgument::ArgKind::Type) {
pointeeType = args[1].getAsType();
} else {
TemplateName templateName = args[1].getAsTemplate();
pointeeType = createASTTypeFromTemplateName(templateName);
}

const SpirvType *pointeeSpirvType =
lowerType(pointeeType, rule, isRowMajor, location);
return spvContext.getPointerType(pointeeSpirvType, storageClass);
}

} // namespace spirv
} // namespace clang
7 changes: 7 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ class LowerTypeVisitor : public Visitor {
SpirvLayoutRule rule,
const uint32_t fieldIndex);

/// Get a lowered SpirvPointer from the args to a SpirvOpaqueType.
/// The pointer will use the given layout rule. `isRowMajor` is used to
/// lower the pointee type.
const SpirvType *getSpirvPointerFromInlineSpirvType(
ArrayRef<TemplateArgument> args, SpirvLayoutRule rule,
Optional<bool> isRowMajor, SourceLocation location);

private:
ASTContext &astContext; /// AST context
SpirvContext &spvContext; /// SPIR-V context
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s 2>&1 | FileCheck %s

#include "vk/spirv.h"

// CHECK: OpCapability VariablePointers

RWStructuredBuffer<int> data;

groupshared int shared_data[64];

[[vk::ext_instruction(/* OpLoad */ 61)]] int
Load(vk::WorkgroupSpirvPointer<int> p);

[[noinline]]
int foo(vk::WorkgroupSpirvPointer<int> param) {
return Load(param);
}

[[vk::ext_capability(/* VariablePointersCapability */ 4442)]]
[numthreads(64, 1, 1)] void main() {
vk::WorkgroupSpirvPointer<int> p = vk::GetGroupSharedAddress(shared_data[0]);
data[0] = foo(p);
}

0 comments on commit 57c2948

Please sign in to comment.