From 68130541239ca60d7fd8b2d126d774b5deb6de94 Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Wed, 9 Oct 2024 16:09:59 -0700 Subject: [PATCH 1/5] [SPIR-V] Handle vectors passed to asuint `asuint` should be able to take vectors in addition to scalar values. Previously, it would be lowered as a bitcast from the input value to a vector of uints with a width of 2, which is not large enough if the input value is larger than a scalar value. In order to handle, for example, an input value that is a `double4`, we instead perform a component-wise bitcast. Fixes #6735 --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 56 +++++++++++++------ .../test/CodeGenSPIRV/intrinsics.asuint.hlsl | 31 ++++++++-- 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index ac455c547f..6a2892a332 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11519,23 +11519,45 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { } case 3: { // Handling Method 6. - auto *value = doExpr(arg0); - auto *lowbits = doExpr(callExpr->getArg(1)); - auto *highbits = doExpr(callExpr->getArg(2)); - const auto uintType = astContext.UnsignedIntTy; - const auto uintVec2Type = astContext.getExtVectorType(uintType, 2); - auto *vecResult = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, - value, loc, range); - spvBuilder.createStore( - lowbits, - spvBuilder.createCompositeExtract(uintType, vecResult, {0}, - arg0->getLocStart(), range), - loc, range); - spvBuilder.createStore( - highbits, - spvBuilder.createCompositeExtract(uintType, vecResult, {1}, - arg0->getLocStart(), range), - loc, range); + const Expr *arg1 = callExpr->getArg(1); + SourceLocation arg1loc = arg1->getExprLoc(); + SourceRange arg1range = arg1->getSourceRange(); + + const Expr *arg2 = callExpr->getArg(2); + SourceLocation arg2loc = arg2->getExprLoc(); + SourceRange arg2range = arg2->getSourceRange(); + + SpirvInstruction *value = doExpr(arg0->IgnoreParenLValueCasts()); + if (!value->isLValue()) { + value = turnIntoLValue(argType, value, arg0->getExprLoc()); + } + SpirvInstruction *lowbits = doExpr(arg1); + SpirvInstruction *highbits = doExpr(arg2); + + QualType outType = arg1->getType(); + QualType arrayType = astContext.getConstantArrayType( + outType, llvm::APInt(32, 2), clang::ArrayType::Normal, 0); + + const SpirvType *arrayPtrType = + spvContext.getPointerType(arrayType, spv::StorageClass::Function); + + auto *arrayResultPtr = + spvBuilder.createUnaryOp(spv::Op::OpBitcast, arrayPtrType, value, loc); + + SpirvInstruction *lowbitsResultPtr = spvBuilder.createAccessChain( + outType, arrayResultPtr, {getValueZero(astContext.UnsignedIntTy)}, + arg1loc); + SpirvInstruction *lowbitsResult = + spvBuilder.createLoad(outType, lowbitsResultPtr, arg1loc, arg1range); + spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range); + + SpirvInstruction *highbitsResultPtr = spvBuilder.createAccessChain( + outType, arrayResultPtr, {getValueOne(astContext.UnsignedIntTy)}, + arg2loc); + SpirvInstruction *highbitsResult = + spvBuilder.createLoad(outType, highbitsResultPtr, arg2loc, arg2range); + spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range); + return nullptr; } default: diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl index 8b6998fbe0..73a767ce5e 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl @@ -76,11 +76,34 @@ void main() { double value; uint lowbits; uint highbits; -// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value -// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]] -// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0 +// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %value +// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]] // CHECK-NEXT: OpStore %lowbits [[resultVec0]] -// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1 +// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]] // CHECK-NEXT: OpStore %highbits [[resultVec1]] asuint(value, lowbits, highbits); + + double4 value4; + uint4 lowbits4; + uint4 highbits4; +// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_v4uint_uint_2 %value4 +// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %v4uint [[chain0]] +// CHECK-NEXT: OpStore %lowbits4 [[resultVec0]] +// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %v4uint [[chain1]] +// CHECK-NEXT: OpStore %highbits4 [[resultVec1]] + asuint(value4, lowbits4, highbits4); + +// CHECK-NEXT: OpStore %temp_var_double %double_1234 +// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %temp_var_double +// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]] +// CHECK-NEXT: OpStore %lowbits [[resultVec0]] +// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]] +// CHECK-NEXT: OpStore %highbits [[resultVec1]] + asuint(1234.0, lowbits, highbits); } From 6f5e5ddbb1473308bcaf38ae63f78cbc66e1693e Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Mon, 14 Oct 2024 14:44:24 -0700 Subject: [PATCH 2/5] Do a component-wise bitcast instead of reinterpreting the pointer --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 112 ++++++++++++++---- tools/clang/lib/SPIRV/SpirvEmitter.h | 16 +++ .../test/CodeGenSPIRV/intrinsics.asuint.hlsl | 81 +++++++++---- 3 files changed, 162 insertions(+), 47 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 6a2892a332..fc33d02d90 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11415,6 +11415,79 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr, return nullptr; } +void SpirvEmitter::splitDouble(SpirvInstruction *value, + SpirvInstruction *&lowbits, + SpirvInstruction *&highbits, SourceLocation loc, + SourceRange range) { + const QualType uintType = astContext.UnsignedIntTy; + const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2); + + SpirvInstruction *uints = spvBuilder.createUnaryOp( + spv::Op::OpBitcast, uintVec2Type, value, loc, range); + + lowbits = spvBuilder.createCompositeExtract(uintType, uints, {0}, loc, range); + highbits = + spvBuilder.createCompositeExtract(uintType, uints, {1}, loc, range); +} + +void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count, + QualType outputType, + SpirvInstruction *value, + SpirvInstruction *&lowbits, + SpirvInstruction *&highbits, + SourceLocation loc, SourceRange range) { + llvm::SmallVector lowElems; + llvm::SmallVector highElems; + + for (uint32_t i = 0; i < count; ++i) { + SpirvInstruction *elem = + spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range); + SpirvInstruction *lowbitsResult = nullptr; + SpirvInstruction *highbitsResult = nullptr; + splitDouble(elem, lowbitsResult, highbitsResult, loc, range); + lowElems.push_back(lowbitsResult); + highElems.push_back(highbitsResult); + } + + lowbits = + spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range); + highbits = + spvBuilder.createCompositeConstruct(outputType, highElems, loc, range); +} + +void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount, + uint32_t colCount, QualType outputType, + SpirvInstruction *value, + SpirvInstruction *&lowbits, + SpirvInstruction *&highbits, + SourceLocation loc, SourceRange range) { + + llvm::SmallVector lowElems; + llvm::SmallVector highElems; + + QualType colType = astContext.getExtVectorType(elemType, rowCount); + + const QualType uintType = astContext.UnsignedIntTy; + const QualType outputColType = + astContext.getExtVectorType(uintType, rowCount); + + for (uint32_t i = 0; i < colCount; ++i) { + SpirvInstruction *column = + spvBuilder.createCompositeExtract(colType, value, {i}, loc, range); + SpirvInstruction *lowbitsResult = nullptr; + SpirvInstruction *highbitsResult = nullptr; + splitDoubleVector(elemType, colCount, outputColType, column, lowbitsResult, + highbitsResult, loc, range); + lowElems.push_back(lowbitsResult); + highElems.push_back(highbitsResult); + } + + lowbits = + spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range); + highbits = + spvBuilder.createCompositeConstruct(outputType, highElems, loc, range); +} + SpirvInstruction * SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { // This function handles the following intrinsics: @@ -11527,37 +11600,32 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { SourceLocation arg2loc = arg2->getExprLoc(); SourceRange arg2range = arg2->getSourceRange(); - SpirvInstruction *value = doExpr(arg0->IgnoreParenLValueCasts()); - if (!value->isLValue()) { - value = turnIntoLValue(argType, value, arg0->getExprLoc()); - } + SpirvInstruction *value = doExpr(arg0); SpirvInstruction *lowbits = doExpr(arg1); SpirvInstruction *highbits = doExpr(arg2); - QualType outType = arg1->getType(); - QualType arrayType = astContext.getConstantArrayType( - outType, llvm::APInt(32, 2), clang::ArrayType::Normal, 0); + QualType elemType = QualType(); + uint32_t rowCount = 0; + uint32_t colCount = 0; - const SpirvType *arrayPtrType = - spvContext.getPointerType(arrayType, spv::StorageClass::Function); + SpirvInstruction *lowbitsResult = nullptr; + SpirvInstruction *highbitsResult = nullptr; - auto *arrayResultPtr = - spvBuilder.createUnaryOp(spv::Op::OpBitcast, arrayPtrType, value, loc); + if (isScalarType(argType)) { + splitDouble(value, lowbitsResult, highbitsResult, loc, range); + } else if (isVectorType(argType, &elemType, &rowCount)) { + splitDoubleVector(elemType, rowCount, arg1->getType(), value, + lowbitsResult, highbitsResult, loc, range); + } else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) { + splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value, + lowbitsResult, highbitsResult, loc, range); + } - SpirvInstruction *lowbitsResultPtr = spvBuilder.createAccessChain( - outType, arrayResultPtr, {getValueZero(astContext.UnsignedIntTy)}, - arg1loc); - SpirvInstruction *lowbitsResult = - spvBuilder.createLoad(outType, lowbitsResultPtr, arg1loc, arg1range); spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range); - - SpirvInstruction *highbitsResultPtr = spvBuilder.createAccessChain( - outType, arrayResultPtr, {getValueOne(astContext.UnsignedIntTy)}, - arg2loc); - SpirvInstruction *highbitsResult = - spvBuilder.createLoad(outType, highbitsResultPtr, arg2loc, arg2range); spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range); + // TODO: handle matrices + return nullptr; } default: diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 6589e642c6..32c4323d57 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -1312,6 +1312,22 @@ class SpirvEmitter : public ASTConsumer { /// the Vulkan memory model capability has been added to the module. bool UpgradeToVulkanMemoryModelIfNeeded(std::vector *module); + // TODO: docs + void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits, SourceLocation loc, + SourceRange range); + + void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType, + SpirvInstruction *value, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits, SourceLocation loc, + SourceRange range); + + void splitDoubleMatrix(QualType elemType, uint32_t rowCount, + uint32_t colCount, QualType outputType, + SpirvInstruction *value, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits, SourceLocation loc, + SourceRange range); + public: /// \brief Wrapper method to create a fatal error message and report it /// in the diagnostic engine associated with this consumer. diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl index 73a767ce5e..3c7dd6083b 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl @@ -76,34 +76,65 @@ void main() { double value; uint lowbits; uint highbits; -// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %value -// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0 -// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]] +// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value +// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]] +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1 // CHECK-NEXT: OpStore %lowbits [[resultVec0]] -// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1 -// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]] // CHECK-NEXT: OpStore %highbits [[resultVec1]] asuint(value, lowbits, highbits); - double4 value4; - uint4 lowbits4; - uint4 highbits4; -// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_v4uint_uint_2 %value4 -// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_0 -// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %v4uint [[chain0]] -// CHECK-NEXT: OpStore %lowbits4 [[resultVec0]] -// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_1 -// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %v4uint [[chain1]] -// CHECK-NEXT: OpStore %highbits4 [[resultVec1]] - asuint(value4, lowbits4, highbits4); + double3 value3; + uint3 lowbits3; + uint3 highbits3; +// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v3double %value3 +// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]] +// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0 +// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1 +// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]] +// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0 +// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1 +// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[value]] 2 +// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]] +// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0 +// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1 +// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v3uint [[low0]] [[low1]] [[low2]] +// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v3uint [[high0]] [[high1]] [[high2]] +// CHECK-NEXT: OpStore %lowbits3 [[low]] +// CHECK-NEXT: OpStore %highbits3 [[high]] + asuint(value3, lowbits3, highbits3); -// CHECK-NEXT: OpStore %temp_var_double %double_1234 -// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %temp_var_double -// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0 -// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]] -// CHECK-NEXT: OpStore %lowbits [[resultVec0]] -// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1 -// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]] -// CHECK-NEXT: OpStore %highbits [[resultVec1]] - asuint(1234.0, lowbits, highbits); + double2x2 value2x2; + uint2x2 lowbits2x2; + uint2x2 highbits2x2; +// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat2v2double %value2x2 +// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0 +// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]] +// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0 +// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1 +// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]] +// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0 +// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1 +// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]] +// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]] +// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1 +// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0 +// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]] +// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0 +// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1 +// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1 +// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]] +// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0 +// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1 +// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]] +// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]] +// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[lowRow0]] [[lowRow1]] +// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[highRow0]] [[highRow1]] +// CHECK-NEXT: OpStore %lowbits2x2 [[low]] +// CHECK-NEXT: OpStore %highbits2x2 [[high]] + asuint(value2x2, lowbits2x2, highbits2x2); } From 6f15e7c2f89680d8c7b178b78e74224ec870dfaa Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Mon, 14 Oct 2024 21:01:58 -0700 Subject: [PATCH 3/5] Fix handling of rows and columns --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 6 +-- .../test/CodeGenSPIRV/intrinsics.asuint.hlsl | 43 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index fc33d02d90..3cbed75f22 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11465,13 +11465,13 @@ void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount, llvm::SmallVector lowElems; llvm::SmallVector highElems; - QualType colType = astContext.getExtVectorType(elemType, rowCount); + QualType colType = astContext.getExtVectorType(elemType, colCount); const QualType uintType = astContext.UnsignedIntTy; const QualType outputColType = - astContext.getExtVectorType(uintType, rowCount); + astContext.getExtVectorType(uintType, colCount); - for (uint32_t i = 0; i < colCount; ++i) { + for (uint32_t i = 0; i < rowCount; ++i) { SpirvInstruction *column = spvBuilder.createCompositeExtract(colType, value, {i}, loc, range); SpirvInstruction *lowbitsResult = nullptr; diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl index 3c7dd6083b..3f84a4970f 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl @@ -137,4 +137,47 @@ void main() { // CHECK-NEXT: OpStore %lowbits2x2 [[low]] // CHECK-NEXT: OpStore %highbits2x2 [[high]] asuint(value2x2, lowbits2x2, highbits2x2); + + double3x2 value3x2; + uint3x2 lowbits3x2; + uint3x2 highbits3x2; +// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat3v2double %value3x2 +// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0 +// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]] +// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0 +// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1 +// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]] +// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0 +// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1 +// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]] +// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]] +// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1 +// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0 +// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]] +// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0 +// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1 +// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1 +// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]] +// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0 +// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1 +// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]] +// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]] +// CHECK-NEXT: [[row2:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 2 +// CHECK-NEXT: [[value4:%[0-9]+]] = OpCompositeExtract %double [[row2]] 0 +// CHECK-NEXT: [[resultVec4:%[0-9]+]] = OpBitcast %v2uint [[value4]] +// CHECK-NEXT: [[low4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 0 +// CHECK-NEXT: [[high4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 1 +// CHECK-NEXT: [[value5:%[0-9]+]] = OpCompositeExtract %double [[row2]] 1 +// CHECK-NEXT: [[resultVec5:%[0-9]+]] = OpBitcast %v2uint [[value5]] +// CHECK-NEXT: [[low5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 0 +// CHECK-NEXT: [[high5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 1 +// CHECK-NEXT: [[lowRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[low4]] [[low5]] +// CHECK-NEXT: [[highRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[high4]] [[high5]] +// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[lowRow0]] [[lowRow1]] [[lowRow2]] +// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[highRow0]] [[highRow1]] [[highRow2]] +// CHECK-NEXT: OpStore %lowbits3x2 [[low]] +// CHECK-NEXT: OpStore %highbits3x2 [[high]] + asuint(value3x2, lowbits3x2, highbits3x2); } From 4668b5de0d96f731f85f41b3b4fe75f0f94a02cc Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Tue, 15 Oct 2024 15:00:41 -0700 Subject: [PATCH 4/5] Address comments --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 11 ++---- tools/clang/lib/SPIRV/SpirvEmitter.h | 25 ++++++++++++- .../test/CodeGenSPIRV/intrinsics.asuint.hlsl | 36 +++++++++++++++++++ 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 3cbed75f22..44e20248a4 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11593,12 +11593,7 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { case 3: { // Handling Method 6. const Expr *arg1 = callExpr->getArg(1); - SourceLocation arg1loc = arg1->getExprLoc(); - SourceRange arg1range = arg1->getSourceRange(); - const Expr *arg2 = callExpr->getArg(2); - SourceLocation arg2loc = arg2->getExprLoc(); - SourceRange arg2range = arg2->getSourceRange(); SpirvInstruction *value = doExpr(arg0); SpirvInstruction *lowbits = doExpr(arg1); @@ -11621,10 +11616,8 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { lowbitsResult, highbitsResult, loc, range); } - spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range); - spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range); - - // TODO: handle matrices + spvBuilder.createStore(lowbits, lowbitsResult, loc, range); + spvBuilder.createStore(highbits, highbitsResult, loc, range); return nullptr; } diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 32c4323d57..7b0bcc20f3 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -1312,16 +1312,39 @@ class SpirvEmitter : public ASTConsumer { /// the Vulkan memory model capability has been added to the module. bool UpgradeToVulkanMemoryModelIfNeeded(std::vector *module); - // TODO: docs + // Splits the `value`, which must be a 64-bit scalar, into two 32-bit wide + // uints, stored in `lowbits` and `highbits`. void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits, SpirvInstruction *&highbits, SourceLocation loc, SourceRange range); + // Splits the value, which must be a vector with element type `elemType` and + // size `count`, into two composite values of size `count` and type + // `outputType`. The elements are split component-wise: the vector + // {0x0123456789abcdef, 0x0123456789abcdef} is split into `lowbits` + // {0x89abcdef, 0x89abcdef} and and `highbits` {0x01234567, 0x01234567}. void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType, SpirvInstruction *value, SpirvInstruction *&lowbits, SpirvInstruction *&highbits, SourceLocation loc, SourceRange range); + // Splits the value, which must be a matrix with element type `elemType` and + // dimensions `rowCount` and `colCount`, into two composite values of + // dimensions `rowCount` and `colCount`. The elements are split + // component-wise: the matrix + // + // { 0x0123456789abcdef, 0x0123456789abcdef, + // 0x0123456789abcdef, 0x0123456789abcdef } + // + // is split into `lowbits` + // + // { 0x89abcdef, 0x89abcdef, + // 0x89abcdef, 0x89abcdef } + // + // and `highbits` + // + // { 0x012345678, 0x012345678, + // 0x012345678, 0x012345678 }. void splitDoubleMatrix(QualType elemType, uint32_t rowCount, uint32_t colCount, QualType outputType, SpirvInstruction *value, SpirvInstruction *&lowbits, diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl index 3f84a4970f..75df8e620e 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl @@ -180,4 +180,40 @@ void main() { // CHECK-NEXT: OpStore %lowbits3x2 [[low]] // CHECK-NEXT: OpStore %highbits3x2 [[high]] asuint(value3x2, lowbits3x2, highbits3x2); + + double2x1 value2x1; + uint2x1 lowbits2x1; + uint2x1 highbits2x1; +// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v2double %value2x1 +// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]] +// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0 +// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1 +// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]] +// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0 +// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1 +// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]] +// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]] +// CHECK-NEXT: OpStore %lowbits2x1 [[low]] +// CHECK-NEXT: OpStore %highbits2x1 [[high]] + asuint(value2x1, lowbits2x1, highbits2x1); + + double1x2 value1x2; + uint1x2 lowbits1x2; + uint1x2 highbits1x2; +// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v2double %value1x2 +// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]] +// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0 +// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1 +// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]] +// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0 +// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1 +// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]] +// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]] +// CHECK-NEXT: OpStore %lowbits1x2 [[low]] +// CHECK-NEXT: OpStore %highbits1x2 [[high]] + asuint(value1x2, lowbits1x2, highbits1x2); } From b82dec11d59b5ab909dc2d8c0c68e9f26edd58e8 Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Tue, 29 Oct 2024 14:45:59 -0700 Subject: [PATCH 5/5] Reorder output parameters and add llvm_unreachable assert --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 33 ++++++++++++++------------ tools/clang/lib/SPIRV/SpirvEmitter.h | 18 +++++++------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 44e20248a4..27b0aa6f69 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11415,10 +11415,9 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr, return nullptr; } -void SpirvEmitter::splitDouble(SpirvInstruction *value, - SpirvInstruction *&lowbits, - SpirvInstruction *&highbits, SourceLocation loc, - SourceRange range) { +void SpirvEmitter::splitDouble(SpirvInstruction *value, SourceLocation loc, + SourceRange range, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits) { const QualType uintType = astContext.UnsignedIntTy; const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2); @@ -11433,9 +11432,9 @@ void SpirvEmitter::splitDouble(SpirvInstruction *value, void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count, QualType outputType, SpirvInstruction *value, + SourceLocation loc, SourceRange range, SpirvInstruction *&lowbits, - SpirvInstruction *&highbits, - SourceLocation loc, SourceRange range) { + SpirvInstruction *&highbits) { llvm::SmallVector lowElems; llvm::SmallVector highElems; @@ -11444,7 +11443,7 @@ void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count, spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range); SpirvInstruction *lowbitsResult = nullptr; SpirvInstruction *highbitsResult = nullptr; - splitDouble(elem, lowbitsResult, highbitsResult, loc, range); + splitDouble(elem, loc, range, lowbitsResult, highbitsResult); lowElems.push_back(lowbitsResult); highElems.push_back(highbitsResult); } @@ -11458,9 +11457,9 @@ void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count, void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount, uint32_t colCount, QualType outputType, SpirvInstruction *value, + SourceLocation loc, SourceRange range, SpirvInstruction *&lowbits, - SpirvInstruction *&highbits, - SourceLocation loc, SourceRange range) { + SpirvInstruction *&highbits) { llvm::SmallVector lowElems; llvm::SmallVector highElems; @@ -11476,8 +11475,8 @@ void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount, spvBuilder.createCompositeExtract(colType, value, {i}, loc, range); SpirvInstruction *lowbitsResult = nullptr; SpirvInstruction *highbitsResult = nullptr; - splitDoubleVector(elemType, colCount, outputColType, column, lowbitsResult, - highbitsResult, loc, range); + splitDoubleVector(elemType, colCount, outputColType, column, loc, range, + lowbitsResult, highbitsResult); lowElems.push_back(lowbitsResult); highElems.push_back(highbitsResult); } @@ -11607,13 +11606,17 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { SpirvInstruction *highbitsResult = nullptr; if (isScalarType(argType)) { - splitDouble(value, lowbitsResult, highbitsResult, loc, range); + splitDouble(value, loc, range, lowbitsResult, highbitsResult); } else if (isVectorType(argType, &elemType, &rowCount)) { - splitDoubleVector(elemType, rowCount, arg1->getType(), value, - lowbitsResult, highbitsResult, loc, range); + splitDoubleVector(elemType, rowCount, arg1->getType(), value, loc, range, + lowbitsResult, highbitsResult); } else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) { splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value, - lowbitsResult, highbitsResult, loc, range); + loc, range, lowbitsResult, highbitsResult); + } else { + llvm_unreachable( + "unexpected argument type is not scalar, vector, or matrix"); + return nullptr; } spvBuilder.createStore(lowbits, lowbitsResult, loc, range); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 7b0bcc20f3..2044837e14 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -1314,9 +1314,9 @@ class SpirvEmitter : public ASTConsumer { // Splits the `value`, which must be a 64-bit scalar, into two 32-bit wide // uints, stored in `lowbits` and `highbits`. - void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits, - SpirvInstruction *&highbits, SourceLocation loc, - SourceRange range); + void splitDouble(SpirvInstruction *value, SourceLocation loc, + SourceRange range, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits); // Splits the value, which must be a vector with element type `elemType` and // size `count`, into two composite values of size `count` and type @@ -1324,9 +1324,9 @@ class SpirvEmitter : public ASTConsumer { // {0x0123456789abcdef, 0x0123456789abcdef} is split into `lowbits` // {0x89abcdef, 0x89abcdef} and and `highbits` {0x01234567, 0x01234567}. void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType, - SpirvInstruction *value, SpirvInstruction *&lowbits, - SpirvInstruction *&highbits, SourceLocation loc, - SourceRange range); + SpirvInstruction *value, SourceLocation loc, + SourceRange range, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits); // Splits the value, which must be a matrix with element type `elemType` and // dimensions `rowCount` and `colCount`, into two composite values of @@ -1347,9 +1347,9 @@ class SpirvEmitter : public ASTConsumer { // 0x012345678, 0x012345678 }. void splitDoubleMatrix(QualType elemType, uint32_t rowCount, uint32_t colCount, QualType outputType, - SpirvInstruction *value, SpirvInstruction *&lowbits, - SpirvInstruction *&highbits, SourceLocation loc, - SourceRange range); + SpirvInstruction *value, SourceLocation loc, + SourceRange range, SpirvInstruction *&lowbits, + SpirvInstruction *&highbits); public: /// \brief Wrapper method to create a fatal error message and report it