Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate vector constructions more efficiently when sizes match #3628

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions SPIRV/SpvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ Op Builder::getMostBasicTypeClass(Id typeId) const
}
}

int Builder::getNumTypeConstituents(Id typeId) const
unsigned int Builder::getNumTypeConstituents(Id typeId) const
{
Instruction* instr = module.getInstruction(typeId);

Expand Down Expand Up @@ -2924,7 +2924,7 @@ Id Builder::createLvalueSwizzle(Id typeId, Id target, Id source, const std::vect
swizzle->reserveOperands(2);
swizzle->addIdOperand(target);

assert(getNumComponents(source) == (int)channels.size());
assert(getNumComponents(source) == channels.size());
assert(isVector(source));
swizzle->addIdOperand(source);

Expand Down Expand Up @@ -3371,7 +3371,7 @@ Id Builder::createCompositeCompare(Decoration precision, Id value1, Id value2, b
Id Builder::createCompositeConstruct(Id typeId, const std::vector<Id>& constituents)
{
assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 &&
getNumTypeConstituents(typeId) == (int)constituents.size()));
getNumTypeConstituents(typeId) == constituents.size()));

if (generatingOpCodeForSpecConst) {
// Sometime, even in spec-constant-op mode, the constant composite to be
Expand Down Expand Up @@ -3424,6 +3424,12 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc
if (sources.size() == 1 && isScalar(sources[0]) && numTargetComponents > 1)
return smearScalar(precision, sources[0], resultTypeId);

// Special case: 2 vectors of equal size
if (sources.size() == 1 && isVector(sources[0]) && numTargetComponents == getNumComponents(sources[0])) {
assert(resultTypeId == getTypeId(sources[0]));
return sources[0];
}

// accumulate the arguments for OpCompositeConstruct
std::vector<Id> constituents;
Id scalarTypeId = getScalarTypeId(resultTypeId);
Expand Down Expand Up @@ -3458,8 +3464,8 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc
if (sourcesToUse + targetComponent > numTargetComponents)
sourcesToUse = numTargetComponents - targetComponent;

int col = 0;
int row = 0;
unsigned int col = 0;
unsigned int row = 0;
for (unsigned int s = 0; s < sourcesToUse; ++s) {
if (row >= getNumRows(sourceArg)) {
row = 0;
Expand Down Expand Up @@ -3504,8 +3510,8 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc
Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>& sources, Id resultTypeId)
{
Id componentTypeId = getScalarTypeId(resultTypeId);
int numCols = getTypeNumColumns(resultTypeId);
int numRows = getTypeNumRows(resultTypeId);
unsigned int numCols = getTypeNumColumns(resultTypeId);
unsigned int numRows = getTypeNumRows(resultTypeId);

Instruction* instr = module.getInstruction(componentTypeId);
const unsigned bitCount = instr->getImmediateOperand(0);
Expand All @@ -3520,11 +3526,11 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&
Id sourceColumnTypeId = getContainedTypeId(getTypeId(matrix));

std::vector<unsigned> channels;
for (int row = 0; row < numRows; ++row)
for (unsigned int row = 0; row < numRows; ++row)
channels.push_back(row);

std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
for (unsigned int col = 0; col < numCols; ++col) {
std::vector<unsigned> indexes;
indexes.push_back(col);
Id colv = createCompositeExtract(matrix, sourceColumnTypeId, indexes);
Expand All @@ -3542,7 +3548,7 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&

// Detect a matrix being constructed from a repeated vector of the correct size.
// Create the composite directly from it.
if ((int)sources.size() == numCols && isVector(sources[0]) && getNumComponents(sources[0]) == numRows &&
if (sources.size() == numCols && isVector(sources[0]) && getNumComponents(sources[0]) == numRows &&
std::equal(sources.begin() + 1, sources.end(), sources.begin())) {
return setPrecision(createCompositeConstruct(resultTypeId, sources), precision);
}
Expand Down Expand Up @@ -3574,12 +3580,12 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&
} else if (isMatrix(sources[0])) {
// constructing from another matrix; copy over the parts that exist in both the argument and constructee
Id matrix = sources[0];
int minCols = std::min(numCols, getNumColumns(matrix));
int minRows = std::min(numRows, getNumRows(matrix));
for (int col = 0; col < minCols; ++col) {
unsigned int minCols = std::min(numCols, getNumColumns(matrix));
unsigned int minRows = std::min(numRows, getNumRows(matrix));
for (unsigned int col = 0; col < minCols; ++col) {
std::vector<unsigned> indexes;
indexes.push_back(col);
for (int row = 0; row < minRows; ++row) {
for (unsigned int row = 0; row < minRows; ++row) {
indexes.push_back(row);
ids[col][row] = createCompositeExtract(matrix, componentTypeId, indexes);
indexes.pop_back();
Expand All @@ -3588,12 +3594,12 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&
}
} else {
// fill in the matrix in column-major order with whatever argument components are available
int row = 0;
int col = 0;
unsigned int row = 0;
unsigned int col = 0;

for (int arg = 0; arg < (int)sources.size() && col < numCols; ++arg) {
for (unsigned int arg = 0; arg < sources.size() && col < numCols; ++arg) {
Id argComp = sources[arg];
for (int comp = 0; comp < getNumComponents(sources[arg]); ++comp) {
for (unsigned int comp = 0; comp < getNumComponents(sources[arg]); ++comp) {
if (getNumComponents(sources[arg]) > 1) {
argComp = createCompositeExtract(sources[arg], componentTypeId, comp);
setPrecision(argComp, precision);
Expand All @@ -3617,9 +3623,9 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&
// make the column vectors
Id columnTypeId = getContainedTypeId(resultTypeId);
std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
for (unsigned int col = 0; col < numCols; ++col) {
std::vector<Id> vectorComponents;
for (int row = 0; row < numRows; ++row)
for (unsigned int row = 0; row < numRows; ++row)
vectorComponents.push_back(ids[col][row]);
Id column = createCompositeConstruct(columnTypeId, vectorComponents);
setPrecision(column, precision);
Expand Down Expand Up @@ -3846,7 +3852,7 @@ void Builder::accessChainStore(Id rvalue, Decoration nonUniform, spv::MemoryAcce

// If a swizzle exists and is not full and is not dynamic, then the swizzle will be broken into individual stores.
if (accessChain.swizzle.size() > 0 &&
getNumTypeComponents(getResultingAccessChainType()) != (int)accessChain.swizzle.size() &&
getNumTypeComponents(getResultingAccessChainType()) != accessChain.swizzle.size() &&
accessChain.component == NoResult) {
for (unsigned int i = 0; i < accessChain.swizzle.size(); ++i) {
accessChain.indexChain.push_back(makeUintConstant(accessChain.swizzle[i]));
Expand Down Expand Up @@ -4166,7 +4172,7 @@ void Builder::simplifyAccessChainSwizzle()
{
// If the swizzle has fewer components than the vector, it is subsetting, and must stay
// to preserve that fact.
if (getNumTypeComponents(accessChain.preSwizzleBaseType) > (int)accessChain.swizzle.size())
if (getNumTypeComponents(accessChain.preSwizzleBaseType) > accessChain.swizzle.size())
return;

// if components are out of order, it is a swizzle
Expand Down
14 changes: 7 additions & 7 deletions SPIRV/SpvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ class Builder {
Op getOpCode(Id id) const { return module.getInstruction(id)->getOpCode(); }
Op getTypeClass(Id typeId) const { return getOpCode(typeId); }
Op getMostBasicTypeClass(Id typeId) const;
int getNumComponents(Id resultId) const { return getNumTypeComponents(getTypeId(resultId)); }
int getNumTypeConstituents(Id typeId) const;
int getNumTypeComponents(Id typeId) const { return getNumTypeConstituents(typeId); }
unsigned int getNumComponents(Id resultId) const { return getNumTypeComponents(getTypeId(resultId)); }
unsigned int getNumTypeConstituents(Id typeId) const;
unsigned int getNumTypeComponents(Id typeId) const { return getNumTypeConstituents(typeId); }
Id getScalarTypeId(Id typeId) const;
Id getContainedTypeId(Id typeId) const;
Id getContainedTypeId(Id typeId, int) const;
Expand Down Expand Up @@ -334,18 +334,18 @@ class Builder {
return module.getInstruction(scalarTypeId)->getImmediateOperand(0);
}

int getTypeNumColumns(Id typeId) const
unsigned int getTypeNumColumns(Id typeId) const
{
assert(isMatrixType(typeId));
return getNumTypeConstituents(typeId);
}
int getNumColumns(Id resultId) const { return getTypeNumColumns(getTypeId(resultId)); }
int getTypeNumRows(Id typeId) const
unsigned int getNumColumns(Id resultId) const { return getTypeNumColumns(getTypeId(resultId)); }
unsigned int getTypeNumRows(Id typeId) const
{
assert(isMatrixType(typeId));
return getNumTypeComponents(getContainedTypeId(typeId));
}
int getNumRows(Id resultId) const { return getTypeNumRows(getTypeId(resultId)); }
unsigned int getNumRows(Id resultId) const { return getTypeNumRows(getTypeId(resultId)); }

Dim getTypeDimensionality(Id typeId) const
{
Expand Down
Loading