Skip to content

Commit

Permalink
Do a component-wise bitcast instead of reinterpreting the pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
cassiebeckley committed Oct 14, 2024
1 parent 6813054 commit 7f4a6cd
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 47 deletions.
113 changes: 91 additions & 22 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11415,6 +11415,80 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
return nullptr;
}

void SpirvEmitter::splitDouble(SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range) {
const QualType uintType = astContext.UnsignedIntTy;
const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2);

SpirvInstruction *uints = spvBuilder.createUnaryOp(
spv::Op::OpBitcast, uintVec2Type, value, loc, range);

lowbits = spvBuilder.createCompositeExtract(uintType, uints, {0}, loc, range);
highbits =
spvBuilder.createCompositeExtract(uintType, uints, {1}, loc, range);
}

void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
QualType outputType,
SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {
llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

for (uint32_t i = 0; i < count; ++i) {
SpirvInstruction *elem =
spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDouble(elem, lowbitsResult, highbitsResult, loc, range);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {

llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

QualType colType = astContext.getExtVectorType(elemType, rowCount);

const QualType uintType = astContext.UnsignedIntTy;
const QualType outputColType = astContext.getExtVectorType(uintType, rowCount);

// TODO: check if row/columns need to be swappable

for (uint32_t i = 0; i < colCount; ++i) {
SpirvInstruction *column =
spvBuilder.createCompositeExtract(colType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDoubleVector(elemType, colCount, outputColType, column, lowbitsResult,
highbitsResult, loc, range);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
// This function handles the following intrinsics:
Expand Down Expand Up @@ -11527,37 +11601,32 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
SourceLocation arg2loc = arg2->getExprLoc();
SourceRange arg2range = arg2->getSourceRange();

SpirvInstruction *value = doExpr(arg0->IgnoreParenLValueCasts());
if (!value->isLValue()) {
value = turnIntoLValue(argType, value, arg0->getExprLoc());
}
SpirvInstruction *value = doExpr(arg0);
SpirvInstruction *lowbits = doExpr(arg1);
SpirvInstruction *highbits = doExpr(arg2);

QualType outType = arg1->getType();
QualType arrayType = astContext.getConstantArrayType(
outType, llvm::APInt(32, 2), clang::ArrayType::Normal, 0);
QualType elemType = QualType();
uint32_t rowCount = 0;
uint32_t colCount = 0;

const SpirvType *arrayPtrType =
spvContext.getPointerType(arrayType, spv::StorageClass::Function);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;

auto *arrayResultPtr =
spvBuilder.createUnaryOp(spv::Op::OpBitcast, arrayPtrType, value, loc);
if (isScalarType(argType)) {
splitDouble(value, lowbitsResult, highbitsResult, loc, range);
} else if (isVectorType(argType, &elemType, &rowCount)) {
splitDoubleVector(elemType, rowCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
} else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
}

SpirvInstruction *lowbitsResultPtr = spvBuilder.createAccessChain(
outType, arrayResultPtr, {getValueZero(astContext.UnsignedIntTy)},
arg1loc);
SpirvInstruction *lowbitsResult =
spvBuilder.createLoad(outType, lowbitsResultPtr, arg1loc, arg1range);
spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range);

SpirvInstruction *highbitsResultPtr = spvBuilder.createAccessChain(
outType, arrayResultPtr, {getValueOne(astContext.UnsignedIntTy)},
arg2loc);
SpirvInstruction *highbitsResult =
spvBuilder.createLoad(outType, highbitsResultPtr, arg2loc, arg2range);
spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range);

// TODO: handle matrices

return nullptr;
}
default:
Expand Down
16 changes: 16 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,22 @@ class SpirvEmitter : public ASTConsumer {
/// the Vulkan memory model capability has been added to the module.
bool UpgradeToVulkanMemoryModelIfNeeded(std::vector<uint32_t> *module);

// TODO: docs
void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

void splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

public:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
Expand Down
81 changes: 56 additions & 25 deletions tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -76,34 +76,65 @@ void main() {
double value;
uint lowbits;
uint highbits;
// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %value
// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]]
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value
// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(value, lowbits, highbits);

double4 value4;
uint4 lowbits4;
uint4 highbits4;
// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_v4uint_uint_2 %value4
// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %v4uint [[chain0]]
// CHECK-NEXT: OpStore %lowbits4 [[resultVec0]]
// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_v4uint [[resultArr]] %uint_1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %v4uint [[chain1]]
// CHECK-NEXT: OpStore %highbits4 [[resultVec1]]
asuint(value4, lowbits4, highbits4);
double3 value3;
uint3 lowbits3;
uint3 highbits3;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v3double %value3
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[value]] 2
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v3uint [[low0]] [[low1]] [[low2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v3uint [[high0]] [[high1]] [[high2]]
// CHECK-NEXT: OpStore %lowbits3 [[low]]
// CHECK-NEXT: OpStore %highbits3 [[high]]
asuint(value3, lowbits3, highbits3);

// CHECK-NEXT: OpStore %temp_var_double %double_1234
// CHECK-NEXT: [[resultArr:%[0-9]+]] = OpBitcast %_ptr_Function__arr_uint_uint_2 %temp_var_double
// CHECK-NEXT: [[chain0:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpLoad %uint [[chain0]]
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[chain1:%[0-9]+]] = OpAccessChain %_ptr_Function_uint [[resultArr]] %uint_1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpLoad %uint [[chain1]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(1234.0, lowbits, highbits);
double2x2 value2x2;
uint2x2 lowbits2x2;
uint2x2 highbits2x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat2v2double %value2x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[lowRow0]] [[lowRow1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[highRow0]] [[highRow1]]
// CHECK-NEXT: OpStore %lowbits2x2 [[low]]
// CHECK-NEXT: OpStore %highbits2x2 [[high]]
asuint(value2x2, lowbits2x2, highbits2x2);
}

0 comments on commit 7f4a6cd

Please sign in to comment.