Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Handle vectors passed to asuint #6953

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 100 additions & 17 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11415,6 +11415,79 @@ SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
return nullptr;
}

void SpirvEmitter::splitDouble(SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range) {
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved
const QualType uintType = astContext.UnsignedIntTy;
const QualType uintVec2Type = astContext.getExtVectorType(uintType, 2);

SpirvInstruction *uints = spvBuilder.createUnaryOp(
spv::Op::OpBitcast, uintVec2Type, value, loc, range);

lowbits = spvBuilder.createCompositeExtract(uintType, uints, {0}, loc, range);
highbits =
spvBuilder.createCompositeExtract(uintType, uints, {1}, loc, range);
}

void SpirvEmitter::splitDoubleVector(QualType elemType, uint32_t count,
QualType outputType,
SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {
llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

for (uint32_t i = 0; i < count; ++i) {
SpirvInstruction *elem =
spvBuilder.createCompositeExtract(elemType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDouble(elem, lowbitsResult, highbitsResult, loc, range);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

void SpirvEmitter::splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value,
SpirvInstruction *&lowbits,
SpirvInstruction *&highbits,
SourceLocation loc, SourceRange range) {

llvm::SmallVector<SpirvInstruction *, 4> lowElems;
llvm::SmallVector<SpirvInstruction *, 4> highElems;

QualType colType = astContext.getExtVectorType(elemType, colCount);

const QualType uintType = astContext.UnsignedIntTy;
const QualType outputColType =
astContext.getExtVectorType(uintType, colCount);

for (uint32_t i = 0; i < rowCount; ++i) {
SpirvInstruction *column =
spvBuilder.createCompositeExtract(colType, value, {i}, loc, range);
SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;
splitDoubleVector(elemType, colCount, outputColType, column, lowbitsResult,
highbitsResult, loc, range);
lowElems.push_back(lowbitsResult);
highElems.push_back(highbitsResult);
}

lowbits =
spvBuilder.createCompositeConstruct(outputType, lowElems, loc, range);
highbits =
spvBuilder.createCompositeConstruct(outputType, highElems, loc, range);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
// This function handles the following intrinsics:
Expand Down Expand Up @@ -11519,23 +11592,33 @@ SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
}
case 3: {
// Handling Method 6.
auto *value = doExpr(arg0);
auto *lowbits = doExpr(callExpr->getArg(1));
auto *highbits = doExpr(callExpr->getArg(2));
const auto uintType = astContext.UnsignedIntTy;
const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
auto *vecResult = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type,
value, loc, range);
spvBuilder.createStore(
lowbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {0},
arg0->getLocStart(), range),
loc, range);
spvBuilder.createStore(
highbits,
spvBuilder.createCompositeExtract(uintType, vecResult, {1},
arg0->getLocStart(), range),
loc, range);
const Expr *arg1 = callExpr->getArg(1);
const Expr *arg2 = callExpr->getArg(2);

SpirvInstruction *value = doExpr(arg0);
SpirvInstruction *lowbits = doExpr(arg1);
SpirvInstruction *highbits = doExpr(arg2);

QualType elemType = QualType();
uint32_t rowCount = 0;
uint32_t colCount = 0;

SpirvInstruction *lowbitsResult = nullptr;
SpirvInstruction *highbitsResult = nullptr;

if (isScalarType(argType)) {
splitDouble(value, lowbitsResult, highbitsResult, loc, range);
} else if (isVectorType(argType, &elemType, &rowCount)) {
splitDoubleVector(elemType, rowCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
} else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
splitDoubleMatrix(elemType, rowCount, colCount, arg1->getType(), value,
lowbitsResult, highbitsResult, loc, range);
}
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved

spvBuilder.createStore(lowbits, lowbitsResult, loc, range);
spvBuilder.createStore(highbits, highbitsResult, loc, range);

return nullptr;
}
default:
Expand Down
39 changes: 39 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,45 @@ class SpirvEmitter : public ASTConsumer {
/// the Vulkan memory model capability has been added to the module.
bool UpgradeToVulkanMemoryModelIfNeeded(std::vector<uint32_t> *module);

// Splits the `value`, which must be a 64-bit scalar, into two 32-bit wide
// uints, stored in `lowbits` and `highbits`.
void splitDouble(SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

// Splits the value, which must be a vector with element type `elemType` and
// size `count`, into two composite values of size `count` and type
// `outputType`. The elements are split component-wise: the vector
// {0x0123456789abcdef, 0x0123456789abcdef} is split into `lowbits`
// {0x89abcdef, 0x89abcdef} and and `highbits` {0x01234567, 0x01234567}.
void splitDoubleVector(QualType elemType, uint32_t count, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

// Splits the value, which must be a matrix with element type `elemType` and
// dimensions `rowCount` and `colCount`, into two composite values of
// dimensions `rowCount` and `colCount`. The elements are split
// component-wise: the matrix
//
// { 0x0123456789abcdef, 0x0123456789abcdef,
// 0x0123456789abcdef, 0x0123456789abcdef }
//
// is split into `lowbits`
//
// { 0x89abcdef, 0x89abcdef,
// 0x89abcdef, 0x89abcdef }
//
// and `highbits`
//
// { 0x012345678, 0x012345678,
// 0x012345678, 0x012345678 }.
void splitDoubleMatrix(QualType elemType, uint32_t rowCount,
uint32_t colCount, QualType outputType,
SpirvInstruction *value, SpirvInstruction *&lowbits,
SpirvInstruction *&highbits, SourceLocation loc,
SourceRange range);

public:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
Expand Down
135 changes: 134 additions & 1 deletion tools/clang/test/CodeGenSPIRV/intrinsics.asuint.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,141 @@ void main() {
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %double %value
// CHECK-NEXT: [[resultVec:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 0
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec]] 1
// CHECK-NEXT: OpStore %lowbits [[resultVec0]]
// CHECK-NEXT: OpStore %highbits [[resultVec1]]
asuint(value, lowbits, highbits);

double3 value3;
uint3 lowbits3;
uint3 highbits3;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v3double %value3
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[value]] 2
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v3uint [[low0]] [[low1]] [[low2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v3uint [[high0]] [[high1]] [[high2]]
// CHECK-NEXT: OpStore %lowbits3 [[low]]
// CHECK-NEXT: OpStore %highbits3 [[high]]
asuint(value3, lowbits3, highbits3);

double2x2 value2x2;
uint2x2 lowbits2x2;
uint2x2 highbits2x2;
cassiebeckley marked this conversation as resolved.
Show resolved Hide resolved
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat2v2double %value2x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[lowRow0]] [[lowRow1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_2 [[highRow0]] [[highRow1]]
// CHECK-NEXT: OpStore %lowbits2x2 [[low]]
// CHECK-NEXT: OpStore %highbits2x2 [[high]]
asuint(value2x2, lowbits2x2, highbits2x2);

double3x2 value3x2;
uint3x2 lowbits3x2;
uint3x2 highbits3x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %mat3v2double %value3x2
// CHECK-NEXT: [[row0:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 0
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[row0]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[row0]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[lowRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[highRow0:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: [[row1:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 1
// CHECK-NEXT: [[value2:%[0-9]+]] = OpCompositeExtract %double [[row1]] 0
// CHECK-NEXT: [[resultVec2:%[0-9]+]] = OpBitcast %v2uint [[value2]]
// CHECK-NEXT: [[low2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 0
// CHECK-NEXT: [[high2:%[0-9]+]] = OpCompositeExtract %uint [[resultVec2]] 1
// CHECK-NEXT: [[value3:%[0-9]+]] = OpCompositeExtract %double [[row1]] 1
// CHECK-NEXT: [[resultVec3:%[0-9]+]] = OpBitcast %v2uint [[value3]]
// CHECK-NEXT: [[low3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 0
// CHECK-NEXT: [[high3:%[0-9]+]] = OpCompositeExtract %uint [[resultVec3]] 1
// CHECK-NEXT: [[lowRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[low2]] [[low3]]
// CHECK-NEXT: [[highRow1:%[0-9]+]] = OpCompositeConstruct %v2uint [[high2]] [[high3]]
// CHECK-NEXT: [[row2:%[0-9]+]] = OpCompositeExtract %v2double [[value]] 2
// CHECK-NEXT: [[value4:%[0-9]+]] = OpCompositeExtract %double [[row2]] 0
// CHECK-NEXT: [[resultVec4:%[0-9]+]] = OpBitcast %v2uint [[value4]]
// CHECK-NEXT: [[low4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 0
// CHECK-NEXT: [[high4:%[0-9]+]] = OpCompositeExtract %uint [[resultVec4]] 1
// CHECK-NEXT: [[value5:%[0-9]+]] = OpCompositeExtract %double [[row2]] 1
// CHECK-NEXT: [[resultVec5:%[0-9]+]] = OpBitcast %v2uint [[value5]]
// CHECK-NEXT: [[low5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 0
// CHECK-NEXT: [[high5:%[0-9]+]] = OpCompositeExtract %uint [[resultVec5]] 1
// CHECK-NEXT: [[lowRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[low4]] [[low5]]
// CHECK-NEXT: [[highRow2:%[0-9]+]] = OpCompositeConstruct %v2uint [[high4]] [[high5]]
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[lowRow0]] [[lowRow1]] [[lowRow2]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %_arr_v2uint_uint_3 [[highRow0]] [[highRow1]] [[highRow2]]
// CHECK-NEXT: OpStore %lowbits3x2 [[low]]
// CHECK-NEXT: OpStore %highbits3x2 [[high]]
asuint(value3x2, lowbits3x2, highbits3x2);

double2x1 value2x1;
uint2x1 lowbits2x1;
uint2x1 highbits2x1;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v2double %value2x1
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: OpStore %lowbits2x1 [[low]]
// CHECK-NEXT: OpStore %highbits2x1 [[high]]
asuint(value2x1, lowbits2x1, highbits2x1);

double1x2 value1x2;
uint1x2 lowbits1x2;
uint1x2 highbits1x2;
// CHECK-NEXT: [[value:%[0-9]+]] = OpLoad %v2double %value1x2
// CHECK-NEXT: [[value0:%[0-9]+]] = OpCompositeExtract %double [[value]] 0
// CHECK-NEXT: [[resultVec0:%[0-9]+]] = OpBitcast %v2uint [[value0]]
// CHECK-NEXT: [[low0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 0
// CHECK-NEXT: [[high0:%[0-9]+]] = OpCompositeExtract %uint [[resultVec0]] 1
// CHECK-NEXT: [[value1:%[0-9]+]] = OpCompositeExtract %double [[value]] 1
// CHECK-NEXT: [[resultVec1:%[0-9]+]] = OpBitcast %v2uint [[value1]]
// CHECK-NEXT: [[low1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 0
// CHECK-NEXT: [[high1:%[0-9]+]] = OpCompositeExtract %uint [[resultVec1]] 1
// CHECK-NEXT: [[low:%[0-9]+]] = OpCompositeConstruct %v2uint [[low0]] [[low1]]
// CHECK-NEXT: [[high:%[0-9]+]] = OpCompositeConstruct %v2uint [[high0]] [[high1]]
// CHECK-NEXT: OpStore %lowbits1x2 [[low]]
// CHECK-NEXT: OpStore %highbits1x2 [[high]]
asuint(value1x2, lowbits1x2, highbits1x2);
}
Loading