diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index ac455c547f..6a2892a332 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11519,23 +11519,45 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) { } case 3: { // Handling Method 6. - auto *value = doExpr(arg0); - auto *lowbits = doExpr(callExpr->getArg(1)); - auto *highbits = doExpr(callExpr->getArg(2)); - const auto uintType = astContext.UnsignedIntTy; - const auto uintVec2Type = astContext.getExtVectorType(uintType, 2); - auto *vecResult = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, - value, loc, range); - spvBuilder.createStore( - lowbits, - spvBuilder.createCompositeExtract(uintType, vecResult, {0}, - arg0->getLocStart(), range), - loc, range); - spvBuilder.createStore( - highbits, - spvBuilder.createCompositeExtract(uintType, vecResult, {1}, - arg0->getLocStart(), range), - loc, range); + const Expr *arg1 = callExpr->getArg(1); + SourceLocation arg1loc = arg1->getExprLoc(); + SourceRange arg1range = arg1->getSourceRange(); + + const Expr *arg2 = callExpr->getArg(2); + SourceLocation arg2loc = arg2->getExprLoc(); + SourceRange arg2range = arg2->getSourceRange(); + + SpirvInstruction *value = doExpr(arg0->IgnoreParenLValueCasts()); + if (!value->isLValue()) { + value = turnIntoLValue(argType, value, arg0->getExprLoc()); + } + SpirvInstruction *lowbits = doExpr(arg1); + SpirvInstruction *highbits = doExpr(arg2); + + QualType outType = arg1->getType(); + QualType arrayType = astContext.getConstantArrayType( + outType, llvm::APInt(32, 2), clang::ArrayType::Normal, 0); + + const SpirvType *arrayPtrType = + spvContext.getPointerType(arrayType, spv::StorageClass::Function); + + auto *arrayResultPtr = + spvBuilder.createUnaryOp(spv::Op::OpBitcast, arrayPtrType, value, loc); + + SpirvInstruction *lowbitsResultPtr = spvBuilder.createAccessChain( + outType, arrayResultPtr, {getValueZero(astContext.UnsignedIntTy)}, + arg1loc); + SpirvInstruction *lowbitsResult = + spvBuilder.createLoad(outType, lowbitsResultPtr, arg1loc, arg1range); + spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range); + + SpirvInstruction *highbitsResultPtr = spvBuilder.createAccessChain( + outType, arrayResultPtr, {getValueOne(astContext.UnsignedIntTy)}, + arg2loc); + SpirvInstruction *highbitsResult = + spvBuilder.createLoad(outType, highbitsResultPtr, arg2loc, arg2range); + spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range); + return nullptr; } default: diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl index 8b6998fbe0..73a767ce5e 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl @@ -76,11 +76,34 @@ void main() { double value; uint lowbits; uint highbits; -// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value -// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]] -// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0 +// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %value +// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]] // CHECK-NEXT: OpStore %lowbits [[resultVec0]] -// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1 +// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]] // CHECK-NEXT: OpStore %highbits [[resultVec1]] asuint(value, lowbits, highbits); + + double4 value4; + uint4 lowbits4; + uint4 highbits4; +// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_v4uint_uint_2 %value4 +// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %v4uint [[chain0]] +// CHECK-NEXT: OpStore %lowbits4 [[resultVec0]] +// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %v4uint [[chain1]] +// CHECK-NEXT: OpStore %highbits4 [[resultVec1]] + asuint(value4, lowbits4, highbits4); + +// CHECK-NEXT: OpStore %temp_var_double %double_1234 +// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %temp_var_double +// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0 +// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]] +// CHECK-NEXT: OpStore %lowbits [[resultVec0]] +// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1 +// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]] +// CHECK-NEXT: OpStore %highbits [[resultVec1]] + asuint(1234.0, lowbits, highbits); }