Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Expand RWBuffer load and store from HLSL #122355

Merged
merged 3 commits into from
Jan 17, 2025

Conversation

s-perron
Copy link
Contributor

@s-perron s-perron commented Jan 9, 2025

The code pattern that clang will generate for HLSL has changed from the
original plan. This allows the SPIR-V backend to generate code for the
current code generation.

It looks for patterns of the form:

%1 = @llvm.spv.resource.handlefrombinding
%2 = @llvm.spv.resource.getpointer(%1, index)
load/store %2

These three llvm-ir instruction are treated as a single unit that will

  1. Generate or find the global variable identified by the call to
    resource.handlefrombinding.
  2. Generate an OpLoad of the variable to get the handle to the image.
  3. Generate an OpImageRead or OpImageWrite using that handle with the
    given index.

This will generate the OpLoad in the same BB as the read/write.

Note: Now that resource.handlefrombinding is not processed on its own,
many existing tests had to be removed. We do not have intrinsics that
are able to use handles to sampled images, input attachments, etc., so
we cannot generate the load of the handle. These tests are removed for
now, and will be added when those resource types are fully implemented.

@s-perron s-perron requested a review from Keenuts January 9, 2025 19:57
@s-perron s-perron changed the title rwbuffer access [SPIRV] Expand RWBuffer load and store from HLSL Jan 10, 2025
The code pattern that clang will generate for HLSL has changed from the
original plan. This allows the SPIR-V backend to generate code for the
current code generation.

It looks for patterns of the form:

```
%1 = @llvm.spv.resource.handlefrombinding
%2 = @llvm.spv.resource.getpointer(%1, index)
load/store %2
```

These three llvm-ir instruction are treated as a single unit that will

1. Generate or find the global variable identified by the call to
   `resource.handlefrombinding`.
2. Generate an OpLoad of the variable to get the handle to the image.
3. Generate an OpImageRead or OpImageWrite using that handle with the
   given index.

This will generate the OpLoad in the same BB as the read/write.

Note: Now that `resource.handlefrombinding` is not processed on its own,
many existing tests had to be removed. We do not have intrinsics that
are able to use handles to sampled images, input attachments, etc., so
we cannot generate the load of the handle. These tests are removed for
now, and will be added when those resource types are fully implemented.
@s-perron s-perron marked this pull request as ready for review January 13, 2025 16:41
@llvmbot
Copy link
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Steven Perron (s-perron)

Changes

The code pattern that clang will generate for HLSL has changed from the
original plan. This allows the SPIR-V backend to generate code for the
current code generation.

It looks for patterns of the form:

%1 = @<!-- -->llvm.spv.resource.handlefrombinding
%2 = @<!-- -->llvm.spv.resource.getpointer(%1, index)
load/store %2

These three llvm-ir instruction are treated as a single unit that will

  1. Generate or find the global variable identified by the call to
    resource.handlefrombinding.
  2. Generate an OpLoad of the variable to get the handle to the image.
  3. Generate an OpImageRead or OpImageWrite using that handle with the
    given index.

This will generate the OpLoad in the same BB as the read/write.

Note: Now that resource.handlefrombinding is not processed on its own,
many existing tests had to be removed. We do not have intrinsics that
are able to use handles to sampled images, input attachments, etc., so
we cannot generate the load of the handle. These tests are removed for
now, and will be added when those resource types are fully implemented.


Patch is 58.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122355.diff

23 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+16-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+6-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+134-35)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+4-2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-resources/BufferLoadStore.ll (+60)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/BufferStore.ll (+2-1)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/CombinedSamplerImageDynIdx.ll (-40)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/CombinedSamplerImageNonUniformIdx.ll (-47)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/InputAttachmentImageDynIdx.ll (-39)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/InputAttachmentImageNonUniformIdx.ll (-46)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/SampledImageDynIdx.ll (-65)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/SampledImageNonUniformIdx.ll (-46)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/SamplerArrayDynIdx.ll (-38)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/SamplerArrayNonUniformIdx.ll (-45)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/ScalarResourceType.ll (+8)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/StorageImageDynIdx.ll (+4)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/StorageImageNonUniformIdx.ll (+4)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/StorageTexelBufferDynIdx.ll (-39)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/StorageTexelBufferNonUniformIdx.ll (-46)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/UniformTexelBufferDynIdx.ll (-39)
  • (removed) llvm/test/CodeGen/SPIRV/hlsl-resources/UniformTexelBufferNonUniformIdx.ll (-46)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-resources/UnknownBufferStore.ll (+2-1)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index d2b14d6d058c92..68f5fc9ee5975e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -264,7 +264,14 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
   const auto *II = dyn_cast<IntrinsicInst>(I);
   if (!II)
     return false;
-  return II->getIntrinsicID() == Intrinsic::invariant_start;
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::invariant_start:
+  case Intrinsic::spv_resource_handlefrombinding:
+  case Intrinsic::spv_resource_getpointer:
+    return true;
+  default:
+    return false;
+  }
 }
 
 bool allowEmitFakeUse(const Value *Arg) {
@@ -725,6 +732,14 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       if (Ty)
         break;
     }
+  } else if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    if (II->getIntrinsicID() == Intrinsic::spv_resource_getpointer) {
+      auto *ImageType = cast<TargetExtType>(II->getOperand(0)->getType());
+      assert(ImageType->getTargetExtName() == "spirv.Image");
+      Ty = ImageType->getTypeParameter(0);
+      // TODO: Need to look at the use to see if it needs to be a vector of the
+      // type.
+    }
   } else if (auto *CI = dyn_cast<CallInst>(I)) {
     static StringMap<unsigned> ResTypeByArg = {
         {"to_global", 0},
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a06c62e68d1062..874894ae987268 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1114,9 +1114,12 @@ SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
   return nullptr;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg) {
-  MachineInstr *Instr = getVRegDef(CurMF->getRegInfo(), VReg);
-  return getSPIRVTypeForVReg(Instr->getOperand(1).getReg());
+SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
+                                              MachineFunction *MF) {
+  if (!MF)
+    MF = CurMF;
+  MachineInstr *Instr = getVRegDef(MF->getRegInfo(), VReg);
+  return getSPIRVTypeForVReg(Instr->getOperand(1).getReg(), MF);
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 528baf5f8d9e21..0c94ec4df97f54 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -377,7 +377,7 @@ class SPIRVGlobalRegistry {
                                  const MachineFunction *MF = nullptr) const;
 
   // Return the result type of the instruction defining the register.
-  SPIRVType *getResultType(Register VReg);
+  SPIRVType *getResultType(Register VReg, MachineFunction *MF = nullptr);
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b7b32dd0d626c6..eb8e629a0f7a28 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -276,8 +276,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   bool selectReadImageIntrinsic(Register &ResVReg, const SPIRVType *ResType,
                                 MachineInstr &I) const;
-
   bool selectImageWriteIntrinsic(MachineInstr &I) const;
+  bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType,
+                                MachineInstr &I) const;
 
   // Utilities
   std::pair<Register, bool>
@@ -307,10 +308,15 @@ class SPIRVInstructionSelector : public InstructionSelector {
   SPIRVType *widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const;
   bool extractSubvector(Register &ResVReg, const SPIRVType *ResType,
                         Register &ReadReg, MachineInstr &InsertionPoint) const;
+  bool generateImageRead(Register &ResVReg, const SPIRVType *ResType,
+                         Register ImageReg, Register IdxReg, DebugLoc Loc,
+                         MachineInstr &Pos) const;
   bool BuildCOPY(Register DestReg, Register SrcReg, MachineInstr &I) const;
   bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
                               Register ResVReg, const SPIRVType *ResType,
                               MachineInstr &I) const;
+  bool loadHandleBeforePosition(Register &HandleReg, const SPIRVType *ResType,
+                                GIntrinsic &HandleDef, MachineInstr &Pos) const;
 };
 
 } // end anonymous namespace
@@ -1018,6 +1024,25 @@ bool SPIRVInstructionSelector::selectLoad(Register ResVReg,
                                           MachineInstr &I) const {
   unsigned OpOffset = isa<GIntrinsic>(I) ? 1 : 0;
   Register Ptr = I.getOperand(1 + OpOffset).getReg();
+
+  auto *PtrDef = getVRegDef(*MRI, Ptr);
+  auto *IntPtrDef = dyn_cast<GIntrinsic>(PtrDef);
+  if (IntPtrDef &&
+      IntPtrDef->getIntrinsicID() == Intrinsic::spv_resource_getpointer) {
+    Register ImageReg = IntPtrDef->getOperand(2).getReg();
+    Register NewImageReg =
+        MRI->createVirtualRegister(MRI->getRegClass(ImageReg));
+    auto *ImageDef = cast<GIntrinsic>(getVRegDef(*MRI, ImageReg));
+    if (!loadHandleBeforePosition(NewImageReg, GR.getSPIRVTypeForVReg(ImageReg),
+                                  *ImageDef, I)) {
+      return false;
+    }
+
+    Register IdxReg = IntPtrDef->getOperand(3).getReg();
+    return generateImageRead(ResVReg, ResType, NewImageReg, IdxReg,
+                             I.getDebugLoc(), I);
+  }
+
   auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
                  .addDef(ResVReg)
                  .addUse(GR.getSPIRVTypeID(ResType))
@@ -1037,6 +1062,29 @@ bool SPIRVInstructionSelector::selectStore(MachineInstr &I) const {
   unsigned OpOffset = isa<GIntrinsic>(I) ? 1 : 0;
   Register StoreVal = I.getOperand(0 + OpOffset).getReg();
   Register Ptr = I.getOperand(1 + OpOffset).getReg();
+
+  auto *PtrDef = getVRegDef(*MRI, Ptr);
+  auto *IntPtrDef = dyn_cast<GIntrinsic>(PtrDef);
+  if (IntPtrDef &&
+      IntPtrDef->getIntrinsicID() == Intrinsic::spv_resource_getpointer) {
+    Register ImageReg = IntPtrDef->getOperand(2).getReg();
+    Register NewImageReg =
+        MRI->createVirtualRegister(MRI->getRegClass(ImageReg));
+    auto *ImageDef = cast<GIntrinsic>(getVRegDef(*MRI, ImageReg));
+    if (!loadHandleBeforePosition(NewImageReg, GR.getSPIRVTypeForVReg(ImageReg),
+                                  *ImageDef, I)) {
+      return false;
+    }
+
+    Register IdxReg = IntPtrDef->getOperand(3).getReg();
+    return BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpImageWrite))
+        .addUse(NewImageReg)
+        .addUse(IdxReg)
+        .addUse(StoreVal)
+        .constrainAllUses(TII, TRI, RBI);
+  }
+
   MachineBasicBlock &BB = *I.getParent();
   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpStore))
                  .addUse(Ptr)
@@ -3007,6 +3055,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_resource_load_typedbuffer: {
     return selectReadImageIntrinsic(ResVReg, ResType, I);
   }
+  case Intrinsic::spv_resource_getpointer: {
+    return selectResourceGetPointer(ResVReg, ResType, I);
+  }
   case Intrinsic::spv_discard: {
     return selectDiscard(ResVReg, ResType, I);
   }
@@ -3024,27 +3075,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
 bool SPIRVInstructionSelector::selectHandleFromBinding(Register &ResVReg,
                                                        const SPIRVType *ResType,
                                                        MachineInstr &I) const {
-
-  uint32_t Set = foldImm(I.getOperand(2), MRI);
-  uint32_t Binding = foldImm(I.getOperand(3), MRI);
-  uint32_t ArraySize = foldImm(I.getOperand(4), MRI);
-  Register IndexReg = I.getOperand(5).getReg();
-  bool IsNonUniform = ArraySize > 1 && foldImm(I.getOperand(6), MRI);
-
-  MachineIRBuilder MIRBuilder(I);
-  Register VarReg = buildPointerToResource(ResType, Set, Binding, ArraySize,
-                                           IndexReg, IsNonUniform, MIRBuilder);
-
-  if (IsNonUniform)
-    buildOpDecorate(ResVReg, I, TII, SPIRV::Decoration::NonUniformEXT, {});
-
-  // TODO: For now we assume the resource is an image, which needs to be
-  // loaded to get the handle. That will not be true for storage buffers.
-  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
-      .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(VarReg)
-      .constrainAllUses(TII, TRI, RBI);
+  return true;
 }
 
 bool SPIRVInstructionSelector::selectReadImageIntrinsic(
@@ -3057,34 +3088,49 @@ bool SPIRVInstructionSelector::selectReadImageIntrinsic(
   // We will do that when we can, but for now trying to move forward with other
   // issues.
   Register ImageReg = I.getOperand(2).getReg();
-  assert(MRI->getVRegDef(ImageReg)->getParent() == I.getParent() &&
-         "The image must be loaded in the same basic block as its use.");
+  auto *ImageDef = cast<GIntrinsic>(getVRegDef(*MRI, ImageReg));
+  Register NewImageReg = MRI->createVirtualRegister(MRI->getRegClass(ImageReg));
+  if (!loadHandleBeforePosition(NewImageReg, GR.getSPIRVTypeForVReg(ImageReg),
+                                *ImageDef, I)) {
+    return false;
+  }
+
+  Register IdxReg = I.getOperand(3).getReg();
+  DebugLoc Loc = I.getDebugLoc();
+  MachineInstr &Pos = I;
 
+  return generateImageRead(ResVReg, ResType, NewImageReg, IdxReg, Loc, Pos);
+}
+
+bool SPIRVInstructionSelector::generateImageRead(Register &ResVReg,
+                                                 const SPIRVType *ResType,
+                                                 Register ImageReg,
+                                                 Register IdxReg, DebugLoc Loc,
+                                                 MachineInstr &Pos) const {
   uint64_t ResultSize = GR.getScalarOrVectorComponentCount(ResType);
   if (ResultSize == 4) {
-    return BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                   TII.get(SPIRV::OpImageRead))
+    return BuildMI(*Pos.getParent(), Pos, Loc, TII.get(SPIRV::OpImageRead))
         .addDef(ResVReg)
         .addUse(GR.getSPIRVTypeID(ResType))
         .addUse(ImageReg)
-        .addUse(I.getOperand(3).getReg())
+        .addUse(IdxReg)
         .constrainAllUses(TII, TRI, RBI);
   }
 
-  SPIRVType *ReadType = widenTypeToVec4(ResType, I);
+  SPIRVType *ReadType = widenTypeToVec4(ResType, Pos);
   Register ReadReg = MRI->createVirtualRegister(GR.getRegClass(ReadType));
   bool Succeed =
-      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpImageRead))
+      BuildMI(*Pos.getParent(), Pos, Loc, TII.get(SPIRV::OpImageRead))
           .addDef(ReadReg)
           .addUse(GR.getSPIRVTypeID(ReadType))
           .addUse(ImageReg)
-          .addUse(I.getOperand(3).getReg())
+          .addUse(IdxReg)
           .constrainAllUses(TII, TRI, RBI);
   if (!Succeed)
     return false;
 
   if (ResultSize == 1) {
-    return BuildMI(*I.getParent(), I, I.getDebugLoc(),
+    return BuildMI(*Pos.getParent(), Pos, Loc,
                    TII.get(SPIRV::OpCompositeExtract))
         .addDef(ResVReg)
         .addUse(GR.getSPIRVTypeID(ResType))
@@ -3092,7 +3138,25 @@ bool SPIRVInstructionSelector::selectReadImageIntrinsic(
         .addImm(0)
         .constrainAllUses(TII, TRI, RBI);
   }
-  return extractSubvector(ResVReg, ResType, ReadReg, I);
+  return extractSubvector(ResVReg, ResType, ReadReg, Pos);
+}
+
+bool SPIRVInstructionSelector::selectResourceGetPointer(
+    Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+#ifdef ASSERT
+  // For now, the operand is an image. This will change once we start handling
+  // more resource types.
+  Register ResourcePtr = I.getOperand(2).getReg();
+  SPIRVType *RegType = GR.getResultType(ResourcePtr);
+  assert(RegType->getOpcode() == SPIRV::OpTypeImage &&
+         "Can only handle texel buffers for now.");
+#endif
+
+  // For texel buffers, the index into the image is part of the OpImageRead or
+  // OpImageWrite instructions. So we will do nothing in this case. This
+  // intrinsic will be combined with the load or store when selecting the load
+  // or store.
+  return true;
 }
 
 bool SPIRVInstructionSelector::extractSubvector(
@@ -3144,15 +3208,20 @@ bool SPIRVInstructionSelector::selectImageWriteIntrinsic(
   // We will do that when we can, but for now trying to move forward with other
   // issues.
   Register ImageReg = I.getOperand(1).getReg();
-  assert(MRI->getVRegDef(ImageReg)->getParent() == I.getParent() &&
-         "The image must be loaded in the same basic block as its use.");
+  auto *ImageDef = cast<GIntrinsic>(getVRegDef(*MRI, ImageReg));
+  Register NewImageReg = MRI->createVirtualRegister(MRI->getRegClass(ImageReg));
+  if (!loadHandleBeforePosition(NewImageReg, GR.getSPIRVTypeForVReg(ImageReg),
+                                *ImageDef, I)) {
+    return false;
+  }
+
   Register CoordinateReg = I.getOperand(2).getReg();
   Register DataReg = I.getOperand(3).getReg();
   assert(GR.getResultType(DataReg)->getOpcode() == SPIRV::OpTypeVector);
   assert(GR.getScalarOrVectorComponentCount(GR.getResultType(DataReg)) == 4);
   return BuildMI(*I.getParent(), I, I.getDebugLoc(),
                  TII.get(SPIRV::OpImageWrite))
-      .addUse(ImageReg)
+      .addUse(NewImageReg)
       .addUse(CoordinateReg)
       .addUse(DataReg)
       .constrainAllUses(TII, TRI, RBI);
@@ -3677,6 +3746,36 @@ SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
   return GR.getOrCreateSPIRVVectorType(ScalarType, 4, MIRBuilder);
 }
 
+bool SPIRVInstructionSelector::loadHandleBeforePosition(
+    Register &HandleReg, const SPIRVType *ResType, GIntrinsic &HandleDef,
+    MachineInstr &Pos) const {
+
+  assert(HandleDef.getIntrinsicID() ==
+         Intrinsic::spv_resource_handlefrombinding);
+  uint32_t Set = foldImm(HandleDef.getOperand(2), MRI);
+  uint32_t Binding = foldImm(HandleDef.getOperand(3), MRI);
+  uint32_t ArraySize = foldImm(HandleDef.getOperand(4), MRI);
+  Register IndexReg = HandleDef.getOperand(5).getReg();
+  bool IsNonUniform = ArraySize > 1 && foldImm(HandleDef.getOperand(6), MRI);
+
+  MachineIRBuilder MIRBuilder(HandleDef);
+  Register VarReg = buildPointerToResource(ResType, Set, Binding, ArraySize,
+                                           IndexReg, IsNonUniform, MIRBuilder);
+
+  if (IsNonUniform)
+    buildOpDecorate(HandleReg, HandleDef, TII, SPIRV::Decoration::NonUniformEXT,
+                    {});
+
+  // TODO: For now we assume the resource is an image, which needs to be
+  // loaded to get the handle. That will not be true for storage buffers.
+  return BuildMI(*Pos.getParent(), Pos, HandleDef.getDebugLoc(),
+                 TII.get(SPIRV::OpLoad))
+      .addDef(HandleReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(VarReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 namespace llvm {
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 020c11a3af4e16..fee22e751a50f4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1696,14 +1696,16 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   case SPIRV::OpImageRead: {
     Register ImageReg = MI.getOperand(2).getReg();
-    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
+    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
+        ImageReg, const_cast<MachineFunction *>(MI.getMF()));
     if (isImageTypeWithUnknownFormat(TypeDef))
       Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
     break;
   }
   case SPIRV::OpImageWrite: {
     Register ImageReg = MI.getOperand(0).getReg();
-    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(ImageReg);
+    SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
+        ImageReg, const_cast<MachineFunction *>(MI.getMF()));
     if (isImageTypeWithUnknownFormat(TypeDef))
       Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
     break;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/BufferLoadStore.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/BufferLoadStore.ll
new file mode 100644
index 00000000000000..25dcc90cb61cda
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/BufferLoadStore.ll
@@ -0,0 +1,60 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-library %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-library %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: [[float:%[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: [[v4float:%[0-9]+]] = OpTypeVector [[float]] 4
+; CHECK-DAG: [[int:%[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: [[zero:%[0-9]+]] = OpConstant [[int]] 0
+; CHECK-DAG: [[one:%[0-9]+]] = OpConstant [[int]] 1
+; CHECK-DAG: [[twenty:%[0-9]+]] = OpConstant [[int]] 20
+; CHECK-DAG: [[twenty_three:%[0-9]+]] = OpConstant [[int]] 23
+; CHECK-DAG: [[ImageType:%[0-9]+]] = OpTypeImage [[float]] Buffer 2 0 0 2 Rgba32f 
+; CHECK-DAG: [[ImagePtr:%[0-9]+]] = OpTypePointer UniformConstant [[ImageType]]
+; CHECK: [[Var:%[0-9]+]] = OpVariable [[ImagePtr]] UniformConstant
+
+; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none)
+define void @main() local_unnamed_addr #0 {
+entry:
+; CHECK: [[H:%[0-9]+]] = OpLoad [[ImageType]] [[Var]]
+  %s_h.i = tail call target("spirv.Image", float, 5, 2, 0, 0, 2, 1) @llvm.spv.resource.handlefrombinding.tspirv.Image_f32_5_2_0_0_2_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+
+; CHECK: [[R:%[0-9]+]] = OpImageRead [[v4float]] [[H]] [[one]]
+; CHECK: [[V:%[0-9]+]] = OpCompositeExtract [[float]] [[R]] 0
+  %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", float, 5, 2, 0, 0, 2, 1) %s_h.i, i32 1)
+  %1 = load float, ptr %0, align 4
+; CHECK: OpBranch [[bb_store:%[0-9]+]]
+  br label %bb_store
+
+; CHECK: [[bb_store]] = OpLabel
+bb_store:
+
+; CHECK: [[H:%[0-9]+]] = OpLoad [[ImageType]] [[Var]]
+; CHECK: OpImageWrite [[H]] [[zero]] [[V]]
+  %2 = tail call noundef nonnull align 4 dereferenceable(4) ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", float, 5, 2, 0, 0, 2, 1) %s_h.i, i32 0)
+  store float %1, ptr %2, align 4
+; CHECK: OpBranch [[bb_both:%[0-9]+]]
+  br label %bb_both
+
+; CHECK: [[bb_both]] = OpLabel
+bb_both:  
+; CHECK: [[H:%[0-9]+]] = OpLoad [[ImageType]] [[Var]]
+; CHECK: [[R:%[0-9]+]] = OpImageRead [[v4float]] [[H]] [[twenty_three]]
+; CHECK: [[V:%[0-9]+]] = OpCompositeExtract [[float]] [[R]] 0
+  %3 = tail call noundef nonnull align 4 dereferenceable(4) ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", float, 5, 2, 0, 0, 2, 1) %s_h.i, i32 23)
+  %4 = load float, ptr %3, align 4
+  
+; CHECK: [[H:%[0-9]+]] = OpLoad [[ImageType]] [[Var]]
+; CHECK: OpImageWrite [[H]] [[twenty]] [[V]]
+  %5 = tail call noundef nonnull align 4 dereferenceable(4) ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", float, 5, 2, 0, 0, 2, 1) %s_h.i, i32 20)
+  store float %4, ptr %5, align 4
+  ret void
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare ptr @llvm.spv.resource.getpointer.p0.tspirv.Image_f32_5_2_0_0_2_0t(target("spirv.Image", float, 5, 2, 0, 0, 2, 1), i32) #1
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare target("spirv.Image", float, 5, 2, 0, 0, 2, 1) @llvm.spv.resource.handlefrombinding.tspirv.Image_f32_5_2_0_0_2_0t(i32, i32, i32, i32, i1) #1
+
+attributes #0 = { mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none) "frame-pointer"="all" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { mustprogress nocallb...
[truncated]

Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall.
I wonder how optimizations will do with this IR (is the getpointer considered to have a side-effect to prevent moving it? and other transforms like reg2mem), but we might want to start with this, and then see how it breaks.

@bogner bogner self-requested a review January 13, 2025 18:06
@s-perron s-perron merged commit 4b692a9 into llvm:main Jan 17, 2025
8 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants