Skip to content

Commit

Permalink
[SPIR-V] Handle vectors passed to asuint
Browse files Browse the repository at this point in the history
The three-argument form of `asuint` (value, lowbits, highbits) should
accept vectors in addition to scalar values. Previously, it was lowered
as a bitcast from the input value to a two-component vector of uints,
which is not large enough when the input is a vector of doubles rather
than a single double. To handle, for example, an input value that is a
`double4`, we instead bitcast the pointer to the input value to a
pointer to a two-element array of `uint4`.
  • Loading branch information
cassiebeckley committed Oct 9, 2024
1 parent b26fd80 commit 5961b83
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 21 deletions.
56 changes: 39 additions & 17 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11519,23 +11519,45 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
}
case 3: {
// Handling Method 6.
auto *value = doExpr(arg0);
auto *lowbits = doExpr(callExpr->getArg(1));
auto *highbits = doExpr(callExpr->getArg(2));
const auto uintType = astContext.UnsignedIntTy;
const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
auto *vecResult = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type,
value, loc, range);
spvBuilder.createStore(
lowbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {0},
arg0->getLocStart(), range),
loc, range);
spvBuilder.createStore(
highbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {1},
arg0->getLocStart(), range),
loc, range);
const Expr *arg1 = callExpr->getArg(1);
SourceLocation arg1loc = arg1->getExprLoc();
SourceRange arg1range = arg1->getSourceRange();

const Expr *arg2 = callExpr->getArg(2);
SourceLocation arg2loc = arg2->getExprLoc();
SourceRange arg2range = arg2->getSourceRange();

SpirvInstruction *value = doExpr(arg0->IgnoreParenLValueCasts());
if (!value->isLValue()) {
value = turnIntoLValue(argType, value, arg0->getExprLoc());
}
SpirvInstruction *lowbits = doExpr(arg1);
SpirvInstruction *highbits = doExpr(arg2);

QualType outType = arg1->getType();
QualType arrayType = astContext.getConstantArrayType(
outType, llvm::APInt(32, 2), clang::ArrayType::Normal, 0);

const SpirvType *arrayPtrType =
spvContext.getPointerType(arrayType, spv::StorageClass::Function);

auto *arrayResultPtr =
spvBuilder.createUnaryOp(spv::Op::OpBitcast, arrayPtrType, value, loc);

SpirvInstruction *lowbitsResultPtr = spvBuilder.createAccessChain(
outType, arrayResultPtr, {getValueZero(astContext.UnsignedIntTy)},
arg1loc);
SpirvInstruction *lowbitsResult =
spvBuilder.createLoad(outType, lowbitsResultPtr, arg1loc, arg1range);
spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range);

SpirvInstruction *highbitsResultPtr = spvBuilder.createAccessChain(
outType, arrayResultPtr, {getValueOne(astContext.UnsignedIntTy)},
arg2loc);
SpirvInstruction *highbitsResult =
spvBuilder.createLoad(outType, highbitsResultPtr, arg2loc, arg2range);
spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range);

return nullptr;
}
default:
Expand Down
31 changes: 27 additions & 4 deletions tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,34 @@ void main() {
double value;
uint lowbits;
uint highbits;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value
// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0
// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %value
// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]]
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1
// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(value, lowbits, highbits);

double4 value4;
uint4 lowbits4;
uint4 highbits4;
// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_v4uint_uint_2 %value4
// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %v4uint [[chain0]]
// CHECK-NEXT: OpStore %lowbits4 [[resultVec0]]
// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %v4uint [[chain1]]
// CHECK-NEXT: OpStore %highbits4 [[resultVec1]]
asuint(value4, lowbits4, highbits4);

// CHECK-NEXT: OpStore %temp_var_double %double_1234
// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %temp_var_double
// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]]
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(1234.0, lowbits, highbits);
}

0 comments on commit 5961b83

Please sign in to comment.