Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support EXT_integer_dot_product #3832

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion SPIRV/GLSL.ext.KHR.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ static const char* const E_SPV_KHR_subgroup_rotate = "SPV_KHR_subgr
static const char* const E_SPV_KHR_expect_assume = "SPV_KHR_expect_assume";
static const char* const E_SPV_EXT_replicated_composites = "SPV_EXT_replicated_composites";
static const char* const E_SPV_KHR_relaxed_extended_instruction = "SPV_KHR_relaxed_extended_instruction";

static const char* const E_SPV_KHR_integer_dot_product = "SPV_KHR_integer_dot_product";
#endif // #ifndef GLSLextKHR_H
44 changes: 43 additions & 1 deletion SPIRV/GlslangToSpv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8760,7 +8760,49 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
libCall = spv::GLSLstd450Pow;
break;
case glslang::EOpDot:
opCode = spv::OpDot;
case glslang::EOpDotPackedEXT:
case glslang::EOpDotAccSatEXT:
case glslang::EOpDotPackedAccSatEXT:
{
if (builder.isFloatType(builder.getScalarTypeId(typeId0)) ||
// HLSL supports dot(int,int) which is just a multiply
glslangIntermediate->getSource() == glslang::EShSourceHlsl) {
opCode = spv::OpDot;
} else {
builder.addExtension(spv::E_SPV_KHR_integer_dot_product);
builder.addCapability(spv::CapabilityDotProductKHR);
const unsigned int vectorSize = builder.getNumComponents(operands[0]);
if (op == glslang::EOpDotPackedEXT || op == glslang::EOpDotPackedAccSatEXT) {
builder.addCapability(spv::CapabilityDotProductInput4x8BitPackedKHR);
} else if (vectorSize == 4 && builder.getScalarTypeWidth(typeId0) == 8) {
builder.addCapability(spv::CapabilityDotProductInput4x8BitKHR);
} else {
builder.addCapability(spv::CapabilityDotProductInputAllKHR);
}
const bool type0isSigned = builder.isIntType(builder.getScalarTypeId(typeId0));
const bool type1isSigned = builder.isIntType(builder.getScalarTypeId(typeId1));
const bool accSat = (op == glslang::EOpDotAccSatEXT || op == glslang::EOpDotPackedAccSatEXT);
if (!type0isSigned && !type1isSigned) {
opCode = accSat ? spv::OpUDotAccSatKHR : spv::OpUDotKHR;
} else if (type0isSigned && type1isSigned) {
opCode = accSat ? spv::OpSDotAccSatKHR : spv::OpSDotKHR;
} else {
opCode = accSat ? spv::OpSUDotAccSatKHR : spv::OpSUDotKHR;
// the spir-v opcode assumes the operands to be "signed, unsigned" in that order, so swap if needed
if (type1isSigned) {
std::swap(operands[0], operands[1]);
}
}
std::vector<spv::IdImmediate> operands2;
for (auto &o : operands) {
operands2.push_back({true, o});
}
if (op == glslang::EOpDotPackedEXT || op == glslang::EOpDotPackedAccSatEXT) {
operands2.push_back({false, 0});
}
return builder.createOp(opCode, typeId, operands2);
}
}
break;
case glslang::EOpAtan:
libCall = spv::GLSLstd450Atan2;
Expand Down
39 changes: 39 additions & 0 deletions SPIRV/doc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,11 @@ const char* CapabilityString(int info)

case CapabilityReplicatedCompositesEXT: return "CapabilityReplicatedCompositesEXT";

case CapabilityDotProductKHR: return "DotProductKHR";
case CapabilityDotProductInputAllKHR: return "DotProductInputAllKHR";
case CapabilityDotProductInput4x8BitKHR: return "DotProductInput4x8BitKHR";
case CapabilityDotProductInput4x8BitPackedKHR: return "DotProductInput4x8BitPackedKHR";

default: return "Bad";
}
}
Expand Down Expand Up @@ -1631,6 +1636,13 @@ const char* OpcodeString(int op)
case OpSpecConstantCompositeReplicateEXT: return "OpSpecConstantCompositeReplicateEXT";
case OpCompositeConstructReplicateEXT: return "OpCompositeConstructReplicateEXT";

case OpSDotKHR: return "OpSDotKHR";
case OpUDotKHR: return "OpUDotKHR";
case OpSUDotKHR: return "OpSUDotKHR";
case OpSDotAccSatKHR: return "OpSDotAccSatKHR";
case OpUDotAccSatKHR: return "OpUDotAccSatKHR";
case OpSUDotAccSatKHR: return "OpSUDotAccSatKHR";

default:
return "Bad";
}
Expand Down Expand Up @@ -3592,6 +3604,33 @@ void Parameterize()
InstructionDesc[OpTensorViewSetClipNV].operands.push(OperandId, "'ClipRowSpan'");
InstructionDesc[OpTensorViewSetClipNV].operands.push(OperandId, "'ClipColOffset'");
InstructionDesc[OpTensorViewSetClipNV].operands.push(OperandId, "'ClipColSpan'");

InstructionDesc[OpSDotKHR].operands.push(OperandId, "'Vector1'");
InstructionDesc[OpSDotKHR].operands.push(OperandId, "'Vector2'");
InstructionDesc[OpSDotKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'");

InstructionDesc[OpUDotKHR].operands.push(OperandId, "'Vector1'");
InstructionDesc[OpUDotKHR].operands.push(OperandId, "'Vector2'");
InstructionDesc[OpUDotKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'");

InstructionDesc[OpSUDotKHR].operands.push(OperandId, "'Vector1'");
InstructionDesc[OpSUDotKHR].operands.push(OperandId, "'Vector2'");
InstructionDesc[OpSUDotKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'");

InstructionDesc[OpSDotAccSatKHR].operands.push(OperandId, "'Vector1'");
InstructionDesc[OpSDotAccSatKHR].operands.push(OperandId, "'Vector2'");
InstructionDesc[OpSDotAccSatKHR].operands.push(OperandId, "'Accumulator'");
InstructionDesc[OpSDotAccSatKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'");

InstructionDesc[OpUDotAccSatKHR].operands.push(OperandId, "'Vector1'");
InstructionDesc[OpUDotAccSatKHR].operands.push(OperandId, "'Vector2'");
InstructionDesc[OpUDotAccSatKHR].operands.push(OperandId, "'Accumulator'");
InstructionDesc[OpUDotAccSatKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'");

InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandId, "'Vector1'");
InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandId, "'Vector2'");
InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandId, "'Accumulator'");
InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'");
});
}

Expand Down
Loading