From 57c2948801a3790dcc61a6bbc1f95435b6320f93 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron@google.com>
Date: Fri, 2 Aug 2024 11:30:33 -0400
Subject: [PATCH] Add specific handling for inline spirv pointer types

---
 tools/clang/lib/SPIRV/LowerTypeVisitor.cpp    | 45 +++++++++++++++++++
 tools/clang/lib/SPIRV/LowerTypeVisitor.h      |  7 +++
 .../workgroupspirvpointer.varpointer.hlsl     | 23 ++++++++++
 3 files changed, 75 insertions(+)
 create mode 100644 tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl

diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
index 874fcb7aef..3b0ac7badc 100644
--- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
+++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
@@ -732,6 +732,15 @@ const SpirvType *LowerTypeVisitor::lowerInlineSpirvType(
 
   auto args = specDecl->getTemplateArgs()[operandsIndex].getPackAsArray();
 
+  if (operandsIndex == 1 && args.size() == 2 &&
+      static_cast<spv::Op>(opcode) == spv::Op::OpTypePointer) {
+    const SpirvType *result =
+        getSpirvPointerFromInlineSpirvType(args, rule, isRowMajor, srcLoc);
+    if (result) {
+      return result;
+    }
+  }
+
   for (TemplateArgument arg : args) {
     switch (arg.getKind()) {
     case TemplateArgument::ArgKind::Type: {
@@ -1363,5 +1372,41 @@ LowerTypeVisitor::populateLayoutInformation(
   return result;
 }
 
+const SpirvType *LowerTypeVisitor::getSpirvPointerFromInlineSpirvType(
+    ArrayRef<TemplateArgument> args, SpirvLayoutRule rule,
+    Optional<bool> isRowMajor, SourceLocation location) {
+
+  assert(args.size() == 2 && "OpTypePoint requires exactly 2 arguments.");
+  QualType scLiteralType = args[0].getAsType();
+  SpirvConstant *constant = nullptr;
+  if (!getVkIntegralConstantValue(scLiteralType, constant, location) ||
+      !constant) {
+    return nullptr;
+  }
+  if (!constant->isLiteral())
+    return nullptr;
+
+  auto *intConstant = dyn_cast<SpirvConstantInteger>(constant);
+  if (!intConstant) {
+    return nullptr;
+  }
+
+  visitInstruction(constant);
+  spv::StorageClass storageClass =
+      static_cast<spv::StorageClass>(intConstant->getValue().getLimitedValue());
+
+  QualType pointeeType;
+  if (args[1].getKind() == TemplateArgument::ArgKind::Type) {
+    pointeeType = args[1].getAsType();
+  } else {
+    TemplateName templateName = args[1].getAsTemplate();
+    pointeeType = createASTTypeFromTemplateName(templateName);
+  }
+
+  const SpirvType *pointeeSpirvType =
+      lowerType(pointeeType, rule, isRowMajor, location);
+  return spvContext.getPointerType(pointeeSpirvType, storageClass);
+}
+
 } // namespace spirv
 } // namespace clang
diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.h b/tools/clang/lib/SPIRV/LowerTypeVisitor.h
index 895e0e6cfe..96235d1508 100644
--- a/tools/clang/lib/SPIRV/LowerTypeVisitor.h
+++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.h
@@ -124,6 +124,13 @@ class LowerTypeVisitor : public Visitor {
                                    SpirvLayoutRule rule,
                                    const uint32_t fieldIndex);
 
+  /// Get a lowered SpirvPointer from the args to a SpirvOpaqueType.
+  /// The pointer will use the given layout rule. `isRowMajor` is used to
+  /// lower the pointee type.
+  const SpirvType *getSpirvPointerFromInlineSpirvType(
+      ArrayRef<TemplateArgument> args, SpirvLayoutRule rule,
+      Optional<bool> isRowMajor, SourceLocation location);
+
 private:
   ASTContext &astContext;                /// AST context
   SpirvContext &spvContext;              /// SPIR-V context
diff --git a/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl
new file mode 100644
index 0000000000..30abf2ff31
--- /dev/null
+++ b/tools/clang/test/CodeGenSPIRV/workgroupspirvpointer.varpointer.hlsl
@@ -0,0 +1,23 @@
+// RUN: dxc -fspv-target-env=vulkan1.3 -T cs_6_0 -E main -spirv -HV 2021 -I %hlsl_headers %s 2>&1 | FileCheck %s
+
+#include "vk/spirv.h"
+
+// CHECK: OpCapability VariablePointers
+
+RWStructuredBuffer<int> data;
+
+groupshared int shared_data[64];
+
+[[vk::ext_instruction(/* OpLoad */ 61)]] int
+Load(vk::WorkgroupSpirvPointer<int> p);
+
+[[noinline]]
+int foo(vk::WorkgroupSpirvPointer<int> param) {
+  return Load(param);
+}
+
+[[vk::ext_capability(/* VariablePointersCapability */ 4442)]]
+[numthreads(64, 1, 1)] void main() {
+  vk::WorkgroupSpirvPointer<int> p = vk::GetGroupSharedAddress(shared_data[0]);
+  data[0] = foo(p);
+}