Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Handle vectors passed to asuint #6953

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 107 additions & 17 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
@@ -11415,6 +11415,79 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
return nullptr;
}

void SpirvEmitter::splitDouble(SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range) {
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved
const QualType uintType = astContext.UnsignedIntTy;
const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2);

SpirvInstruction *uints = spvBuilder.createUnaryOp(
spv::Op::OpBitcast, uintVec2Type, value, loc, range);

lowbits = spvBuilder.createCompositeExtract(uintType, uints, {0}, loc, range);
highbits =
spvBuilder.createCompositeExtract(uintType, uints, {1}, loc, range);
}

void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
QualType outputType,
SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {
llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

for (uint32_t i = 0; i < count; ++i) {
SpirvInstruction *elem =
spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDouble(elem, lowbitsResult, highbitsResult, loc, range);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {

llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

QualType colType = astContext.getExtVectorType(elemType, colCount);

const QualType uintType = astContext.UnsignedIntTy;
const QualType outputColType =
astContext.getExtVectorType(uintType, colCount);

for (uint32_t i = 0; i < rowCount; ++i) {
SpirvInstruction *column =
spvBuilder.createCompositeExtract(colType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDoubleVector(elemType, colCount, outputColType, column, lowbitsResult,
highbitsResult, loc, range);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
// This function handles the following intrinsics:
@@ -11519,23 +11592,40 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
}
case 3: {
// Handling Method 6.
auto *value = doExpr(arg0);
auto *lowbits = doExpr(callExpr->getArg(1));
auto *highbits = doExpr(callExpr->getArg(2));
const auto uintType = astContext.UnsignedIntTy;
const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
auto *vecResult = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type,
value, loc, range);
spvBuilder.createStore(
lowbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {0},
arg0->getLocStart(), range),
loc, range);
spvBuilder.createStore(
highbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {1},
arg0->getLocStart(), range),
loc, range);
const Expr *arg1 = callExpr->getArg(1);
SourceLocation arg1loc = arg1->getExprLoc();
SourceRange arg1range = arg1->getSourceRange();

const Expr *arg2 = callExpr->getArg(2);
SourceLocation arg2loc = arg2->getExprLoc();
SourceRange arg2range = arg2->getSourceRange();

SpirvInstruction *value = doExpr(arg0);
SpirvInstruction *lowbits = doExpr(arg1);
SpirvInstruction *highbits = doExpr(arg2);

QualType elemType = QualType();
uint32_t rowCount = 0;
uint32_t colCount = 0;

SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;

if (isScalarType(argType)) {
splitDouble(value, lowbitsResult, highbitsResult, loc, range);
} else if (isVectorType(argType, &elemType, &rowCount)) {
splitDoubleVector(elemType, rowCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
} else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
}
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved

spvBuilder.createStore(lowbits, lowbitsResult, arg1loc, arg1range);
spvBuilder.createStore(highbits, highbitsResult, arg2loc, arg2range);
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved

// TODO: handle matrices
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved

return nullptr;
}
default:
16 changes: 16 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
@@ -1312,6 +1312,22 @@ class SpirvEmitter : public ASTConsumer {
/// the Vulkan memory model capability has been added to the module.
bool UpgradeToVulkanMemoryModelIfNeeded(std::vector<uint32_t> *module);

// TODO: docs
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved
void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

void splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

public:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
99 changes: 98 additions & 1 deletion tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl
Original file line number Diff line number Diff line change
@@ -79,8 +79,105 @@ void main() {
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value
// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(value, lowbits, highbits);

double3 value3;
uint3 lowbits3;
uint3 highbits3;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v3double %value3
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[value]] 2
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v3uint [[low0]] [[low1]] [[low2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v3uint [[high0]] [[high1]] [[high2]]
// CHECK-NEXT: OpStore %lowbits3 [[low]]
// CHECK-NEXT: OpStore %highbits3 [[high]]
asuint(value3, lowbits3, highbits3);

double2x2 value2x2;
uint2x2 lowbits2x2;
uint2x2 highbits2x2;
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat2v2double %value2x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[lowRow0]] [[lowRow1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[highRow0]] [[highRow1]]
// CHECK-NEXT: OpStore %lowbits2x2 [[low]]
// CHECK-NEXT: OpStore %highbits2x2 [[high]]
asuint(value2x2, lowbits2x2, highbits2x2);

double3x2 value3x2;
uint3x2 lowbits3x2;
uint3x2 highbits3x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat3v2double %value3x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[row2:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 2
// CHECK-NEXT: [[value4:%[0-9]+]] = OpCompositeExtract %double [[row2]] 0
// CHECK-NEXT: [[resultVec4:%[0-9]+]] = OpBitcast %v2uint [[value4]]
// CHECK-NEXT: [[low4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 0
// CHECK-NEXT: [[high4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 1
// CHECK-NEXT: [[value5:%[0-9]+]] = OpCompositeExtract %double [[row2]] 1
// CHECK-NEXT: [[resultVec5:%[0-9]+]] = OpBitcast %v2uint [[value5]]
// CHECK-NEXT: [[low5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 0
// CHECK-NEXT: [[high5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 1
// CHECK-NEXT: [[lowRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[low4]] [[low5]]
// CHECK-NEXT: [[highRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[high4]] [[high5]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[lowRow0]] [[lowRow1]] [[lowRow2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[highRow0]] [[highRow1]] [[highRow2]]
// CHECK-NEXT: OpStore %lowbits3x2 [[low]]
// CHECK-NEXT: OpStore %highbits3x2 [[high]]
asuint(value3x2, lowbits3x2, highbits3x2);
}
Loading