From de50084e919947018b840c190ac73328bfc05155 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 10 Jan 2025 16:01:59 -0600 Subject: [PATCH] Support EXT_integer_dot_product --- SPIRV/GLSL.ext.KHR.h | 2 +- SPIRV/GlslangToSpv.cpp | 44 ++- SPIRV/doc.cpp | 39 +++ Test/baseResults/spv.int_dot.frag.out | 288 ++++++++++++++++++ Test/baseResults/spv.int_dot_Error.frag.out | 12 + Test/spv.int_dot.frag | 143 +++++++++ Test/spv.int_dot_Error.frag | 28 ++ glslang/Include/intermediate.h | 5 + glslang/MachineIndependent/Constant.cpp | 3 + glslang/MachineIndependent/Initialize.cpp | 155 ++++++++++ glslang/MachineIndependent/Versions.cpp | 4 + glslang/MachineIndependent/Versions.h | 2 + glslang/MachineIndependent/intermOut.cpp | 3 + .../propagateNoContraction.cpp | 3 + gtests/Spv.FromFile.cpp | 2 + 15 files changed, 731 insertions(+), 2 deletions(-) create mode 100644 Test/baseResults/spv.int_dot.frag.out create mode 100644 Test/baseResults/spv.int_dot_Error.frag.out create mode 100644 Test/spv.int_dot.frag create mode 100644 Test/spv.int_dot_Error.frag diff --git a/SPIRV/GLSL.ext.KHR.h b/SPIRV/GLSL.ext.KHR.h index 38d3b974b0..8a44b17a6c 100644 --- a/SPIRV/GLSL.ext.KHR.h +++ b/SPIRV/GLSL.ext.KHR.h @@ -63,5 +63,5 @@ static const char* const E_SPV_KHR_subgroup_rotate = "SPV_KHR_subgr static const char* const E_SPV_KHR_expect_assume = "SPV_KHR_expect_assume"; static const char* const E_SPV_EXT_replicated_composites = "SPV_EXT_replicated_composites"; static const char* const E_SPV_KHR_relaxed_extended_instruction = "SPV_KHR_relaxed_extended_instruction"; - +static const char* const E_SPV_KHR_integer_dot_product = "SPV_KHR_integer_dot_product"; #endif // #ifndef GLSLextKHR_H diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index 4dc754ffd5..0fdec041f2 100644 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -8760,7 +8760,49 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv:: libCall = spv::GLSLstd450Pow; break; case glslang::EOpDot: - opCode = spv::OpDot; + case glslang::EOpDotPackedEXT: + case glslang::EOpDotAccSatEXT: + case glslang::EOpDotPackedAccSatEXT: + { + if (builder.isFloatType(builder.getScalarTypeId(typeId0)) || + // HLSL supports dot(int,int) which is just a multiply + glslangIntermediate->getSource() == glslang::EShSourceHlsl) { + opCode = spv::OpDot; + } else { + builder.addExtension(spv::E_SPV_KHR_integer_dot_product); + builder.addCapability(spv::CapabilityDotProductKHR); + const unsigned int vectorSize = builder.getNumComponents(operands[0]); + if (op == glslang::EOpDotPackedEXT || op == glslang::EOpDotPackedAccSatEXT) { + builder.addCapability(spv::CapabilityDotProductInput4x8BitPackedKHR); + } else if (vectorSize == 4 && builder.getScalarTypeWidth(typeId0) == 8) { + builder.addCapability(spv::CapabilityDotProductInput4x8BitKHR); + } else { + builder.addCapability(spv::CapabilityDotProductInputAllKHR); + } + const bool type0isSigned = builder.isIntType(builder.getScalarTypeId(typeId0)); + const bool type1isSigned = builder.isIntType(builder.getScalarTypeId(typeId1)); + const bool accSat = (op == glslang::EOpDotAccSatEXT || op == glslang::EOpDotPackedAccSatEXT); + if (!type0isSigned && !type1isSigned) { + opCode = accSat ? spv::OpUDotAccSatKHR : spv::OpUDotKHR; + } else if (type0isSigned && type1isSigned) { + opCode = accSat ? spv::OpSDotAccSatKHR : spv::OpSDotKHR; + } else { + opCode = accSat ? spv::OpSUDotAccSatKHR : spv::OpSUDotKHR; + // the spir-v opcode assumes the operands to be "signed, unsigned" in that order, so swap if needed + if (type1isSigned) { + std::swap(operands[0], operands[1]); + } + } + std::vector operands2; + for (auto &o : operands) { + operands2.push_back({true, o}); + } + if (op == glslang::EOpDotPackedEXT || op == glslang::EOpDotPackedAccSatEXT) { + operands2.push_back({false, 0}); + } + return builder.createOp(opCode, typeId, operands2); + } + } break; case glslang::EOpAtan: libCall = spv::GLSLstd450Atan2; diff --git a/SPIRV/doc.cpp b/SPIRV/doc.cpp index dce8a71a97..346a09f93e 100644 --- a/SPIRV/doc.cpp +++ b/SPIRV/doc.cpp @@ -1089,6 +1089,11 @@ const char* CapabilityString(int info) case CapabilityReplicatedCompositesEXT: return "CapabilityReplicatedCompositesEXT"; + case CapabilityDotProductKHR: return "DotProductKHR"; + case CapabilityDotProductInputAllKHR: return "DotProductInputAllKHR"; + case CapabilityDotProductInput4x8BitKHR: return "DotProductInput4x8BitKHR"; + case CapabilityDotProductInput4x8BitPackedKHR: return "DotProductInput4x8BitPackedKHR"; + default: return "Bad"; } } @@ -1631,6 +1636,13 @@ const char* OpcodeString(int op) case OpSpecConstantCompositeReplicateEXT: return "OpSpecConstantCompositeReplicateEXT"; case OpCompositeConstructReplicateEXT: return "OpCompositeConstructReplicateEXT"; + case OpSDotKHR: return "OpSDotKHR"; + case OpUDotKHR: return "OpUDotKHR"; + case OpSUDotKHR: return "OpSUDotKHR"; + case OpSDotAccSatKHR: return "OpSDotAccSatKHR"; + case OpUDotAccSatKHR: return "OpUDotAccSatKHR"; + case OpSUDotAccSatKHR: return "OpSUDotAccSatKHR"; + default: return "Bad"; } @@ -3592,6 +3604,33 @@ void Parameterize() InstructionDesc[OpTensorViewSetClipNV].operands.push(OperandId, "'ClipRowSpan'"); InstructionDesc[OpTensorViewSetClipNV].operands.push(OperandId, "'ClipColOffset'"); InstructionDesc[OpTensorViewSetClipNV].operands.push(OperandId, "'ClipColSpan'"); + + InstructionDesc[OpSDotKHR].operands.push(OperandId, "'Vector1'"); + InstructionDesc[OpSDotKHR].operands.push(OperandId, "'Vector2'"); + InstructionDesc[OpSDotKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'"); + + InstructionDesc[OpUDotKHR].operands.push(OperandId, "'Vector1'"); + InstructionDesc[OpUDotKHR].operands.push(OperandId, "'Vector2'"); + InstructionDesc[OpUDotKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'"); + + InstructionDesc[OpSUDotKHR].operands.push(OperandId, "'Vector1'"); + InstructionDesc[OpSUDotKHR].operands.push(OperandId, "'Vector2'"); + InstructionDesc[OpSUDotKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'"); + + InstructionDesc[OpSDotAccSatKHR].operands.push(OperandId, "'Vector1'"); + InstructionDesc[OpSDotAccSatKHR].operands.push(OperandId, "'Vector2'"); + InstructionDesc[OpSDotAccSatKHR].operands.push(OperandId, "'Accumulator'"); + InstructionDesc[OpSDotAccSatKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'"); + + InstructionDesc[OpUDotAccSatKHR].operands.push(OperandId, "'Vector1'"); + InstructionDesc[OpUDotAccSatKHR].operands.push(OperandId, "'Vector2'"); + InstructionDesc[OpUDotAccSatKHR].operands.push(OperandId, "'Accumulator'"); + InstructionDesc[OpUDotAccSatKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'"); + + InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandId, "'Vector1'"); + InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandId, "'Vector2'"); + InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandId, "'Accumulator'"); + InstructionDesc[OpSUDotAccSatKHR].operands.push(OperandLiteralNumber, "'PackedVectorFormat'"); }); } diff --git a/Test/baseResults/spv.int_dot.frag.out b/Test/baseResults/spv.int_dot.frag.out new file mode 100644 index 0000000000..ae83da920d --- /dev/null +++ b/Test/baseResults/spv.int_dot.frag.out @@ -0,0 +1,288 @@ +spv.int_dot.frag +// Module Version 10000 +// Generated by (magic number): 8000b +// Id's are bound by 254 + + Capability Shader + Capability Int64 + Capability Int16 + Capability Int8 + Capability DotProductInputAllKHR + Capability DotProductInput4x8BitKHR + Capability DotProductInput4x8BitPackedKHR + Capability DotProductKHR + Extension "SPV_KHR_integer_dot_product" + 1: ExtInstImport "GLSL.std.450" + MemoryModel Logical GLSL450 + EntryPoint Fragment 4 "main" + ExecutionMode 4 OriginUpperLeft + Source GLSL 450 + SourceExtension "GL_EXT_integer_dot_product" + SourceExtension "GL_EXT_shader_explicit_arithmetic_types" + Name 4 "main" + Name 8 "i32" + Name 12 "i16" + Name 16 "ui32" + Name 20 "ui16" + Name 24 "i64" + Name 28 "ui64" + 2: TypeVoid + 3: TypeFunction 2 + 6: TypeInt 32 1 + 7: TypePointer Private 6(int) + 8(i32): 7(ptr) Variable Private + 9: 6(int) Constant 0 + 10: TypeInt 16 1 + 11: TypePointer Private 10(int16_t) + 12(i16): 11(ptr) Variable Private + 13: 10(int16_t) Constant 0 + 14: TypeInt 32 0 + 15: TypePointer Private 14(int) + 16(ui32): 15(ptr) Variable Private + 17: 14(int) Constant 0 + 18: TypeInt 16 0 + 19: TypePointer Private 18(int16_t) + 20(ui16): 19(ptr) Variable Private + 21: 18(int16_t) Constant 0 + 22: TypeInt 64 1 + 23: TypePointer Private 22(int64_t) + 24(i64): 23(ptr) Variable Private + 25: 22(int64_t) Constant 0 0 + 26: TypeInt 64 0 + 27: TypePointer Private 26(int64_t) + 28(ui64): 27(ptr) Variable Private + 29: 26(int64_t) Constant 0 0 + 58: TypeInt 8 1 + 59: TypeVector 58(int8_t) 4 + 60: 58(int8_t) Constant 0 + 61: 59(i8vec4) ConstantComposite 60 60 60 60 + 63: TypeInt 8 0 + 64: TypeVector 63(int8_t) 4 + 65: 63(int8_t) Constant 0 + 66: 64(i8vec4) ConstantComposite 65 65 65 65 + 78: TypeVector 58(int8_t) 3 + 79: 78(i8vec3) ConstantComposite 60 60 60 + 81: TypeVector 63(int8_t) 3 + 82: 81(i8vec3) ConstantComposite 65 65 65 + 94: TypeVector 58(int8_t) 2 + 95: 94(i8vec2) ConstantComposite 60 60 + 97: TypeVector 63(int8_t) 2 + 98: 97(i8vec2) ConstantComposite 65 65 + 110: TypeVector 10(int16_t) 4 + 111:110(i16vec4) ConstantComposite 13 13 13 13 + 113: TypeVector 18(int16_t) 4 + 114:113(i16vec4) ConstantComposite 21 21 21 21 + 126: TypeVector 10(int16_t) 3 + 127:126(i16vec3) ConstantComposite 13 13 13 + 129: TypeVector 18(int16_t) 3 + 130:129(i16vec3) ConstantComposite 21 21 21 + 142: TypeVector 10(int16_t) 2 + 143:142(i16vec2) ConstantComposite 13 13 + 145: TypeVector 18(int16_t) 2 + 146:145(i16vec2) ConstantComposite 21 21 + 158: TypeVector 6(int) 4 + 159: 158(ivec4) ConstantComposite 9 9 9 9 + 161: TypeVector 14(int) 4 + 162: 161(ivec4) ConstantComposite 17 17 17 17 + 174: TypeVector 6(int) 3 + 175: 174(ivec3) ConstantComposite 9 9 9 + 177: TypeVector 14(int) 3 + 178: 177(ivec3) ConstantComposite 17 17 17 + 190: TypeVector 6(int) 2 + 191: 190(ivec2) ConstantComposite 9 9 + 193: TypeVector 14(int) 2 + 194: 193(ivec2) ConstantComposite 17 17 + 206: TypeVector 22(int64_t) 4 + 207:206(i64vec4) ConstantComposite 25 25 25 25 + 209: TypeVector 26(int64_t) 4 + 210:209(i64vec4) ConstantComposite 29 29 29 29 + 222: TypeVector 22(int64_t) 3 + 223:222(i64vec3) ConstantComposite 25 25 25 + 225: TypeVector 26(int64_t) 3 + 226:225(i64vec3) ConstantComposite 29 29 29 + 238: TypeVector 22(int64_t) 2 + 239:238(i64vec2) ConstantComposite 25 25 + 241: TypeVector 26(int64_t) 2 + 242:241(i64vec2) ConstantComposite 29 29 + 4(main): 2 Function None 3 + 5: Label + Store 8(i32) 9 + Store 12(i16) 13 + Store 16(ui32) 17 + Store 20(ui16) 21 + Store 24(i64) 25 + Store 28(ui64) 29 + 30: 6(int) Load 8(i32) + 31: 6(int) Load 8(i32) + 32: 6(int) SDotKHR 30 31 0 + 33: 14(int) Load 16(ui32) + 34: 14(int) Load 16(ui32) + 35: 14(int) UDotKHR 33 34 0 + 36: 14(int) Load 16(ui32) + 37: 6(int) Load 8(i32) + 38: 6(int) SUDotKHR 37 36 0 + 39: 6(int) Load 8(i32) + 40: 14(int) Load 16(ui32) + 41: 6(int) SUDotKHR 39 40 0 + 42: 6(int) Load 8(i32) + 43: 6(int) Load 8(i32) + 44: 6(int) Load 8(i32) + 45: 6(int) SDotAccSatKHR 42 43 44 0 + 46: 14(int) Load 16(ui32) + 47: 14(int) Load 16(ui32) + 48: 14(int) Load 16(ui32) + 49: 14(int) UDotAccSatKHR 46 47 48 0 + 50: 14(int) Load 16(ui32) + 51: 6(int) Load 8(i32) + 52: 6(int) Load 8(i32) + 53: 6(int) SUDotAccSatKHR 51 50 52 0 + 54: 6(int) Load 8(i32) + 55: 14(int) Load 16(ui32) + 56: 6(int) Load 8(i32) + 57: 6(int) SUDotAccSatKHR 54 55 56 0 + 62: 6(int) SDotKHR 61 61 + 67: 14(int) UDotKHR 66 66 + 68: 6(int) SUDotKHR 61 66 + 69: 6(int) SUDotKHR 61 66 + 70: 6(int) Load 8(i32) + 71: 6(int) SDotAccSatKHR 61 61 70 + 72: 14(int) Load 16(ui32) + 73: 14(int) UDotAccSatKHR 66 66 72 + 74: 6(int) Load 8(i32) + 75: 6(int) SUDotAccSatKHR 61 66 74 + 76: 6(int) Load 8(i32) + 77: 6(int) SUDotAccSatKHR 61 66 76 + 80: 6(int) SDotKHR 79 79 + 83: 14(int) UDotKHR 82 82 + 84: 6(int) SUDotKHR 79 82 + 85: 6(int) SUDotKHR 79 82 + 86: 6(int) Load 8(i32) + 87: 6(int) SDotAccSatKHR 79 79 86 + 88: 14(int) Load 16(ui32) + 89: 14(int) UDotAccSatKHR 82 82 88 + 90: 6(int) Load 8(i32) + 91: 6(int) SUDotAccSatKHR 79 82 90 + 92: 6(int) Load 8(i32) + 93: 6(int) SUDotAccSatKHR 79 82 92 + 96: 6(int) SDotKHR 95 95 + 99: 14(int) UDotKHR 98 98 + 100: 6(int) SUDotKHR 95 98 + 101: 6(int) SUDotKHR 95 98 + 102: 6(int) Load 8(i32) + 103: 6(int) SDotAccSatKHR 95 95 102 + 104: 14(int) Load 16(ui32) + 105: 14(int) UDotAccSatKHR 98 98 104 + 106: 6(int) Load 8(i32) + 107: 6(int) SUDotAccSatKHR 95 98 106 + 108: 6(int) Load 8(i32) + 109: 6(int) SUDotAccSatKHR 95 98 108 + 112: 6(int) SDotKHR 111 111 + 115: 14(int) UDotKHR 114 114 + 116: 6(int) SUDotKHR 111 114 + 117: 6(int) SUDotKHR 111 114 + 118: 6(int) Load 8(i32) + 119: 6(int) SDotAccSatKHR 111 111 118 + 120: 14(int) Load 16(ui32) + 121: 14(int) UDotAccSatKHR 114 114 120 + 122: 6(int) Load 8(i32) + 123: 6(int) SUDotAccSatKHR 111 114 122 + 124: 6(int) Load 8(i32) + 125: 6(int) SUDotAccSatKHR 111 114 124 + 128: 6(int) SDotKHR 127 127 + 131: 14(int) UDotKHR 130 130 + 132: 6(int) SUDotKHR 127 130 + 133: 6(int) SUDotKHR 127 130 + 134: 6(int) Load 8(i32) + 135: 6(int) SDotAccSatKHR 127 127 134 + 136: 14(int) Load 16(ui32) + 137: 14(int) UDotAccSatKHR 130 130 136 + 138: 6(int) Load 8(i32) + 139: 6(int) SUDotAccSatKHR 127 130 138 + 140: 6(int) Load 8(i32) + 141: 6(int) SUDotAccSatKHR 127 130 140 + 144: 6(int) SDotKHR 143 143 + 147: 14(int) UDotKHR 146 146 + 148: 6(int) SUDotKHR 143 146 + 149: 6(int) SUDotKHR 143 146 + 150: 6(int) Load 8(i32) + 151: 6(int) SDotAccSatKHR 143 143 150 + 152: 14(int) Load 16(ui32) + 153: 14(int) UDotAccSatKHR 146 146 152 + 154: 6(int) Load 8(i32) + 155: 6(int) SUDotAccSatKHR 143 146 154 + 156: 6(int) Load 8(i32) + 157: 6(int) SUDotAccSatKHR 143 146 156 + 160: 6(int) SDotKHR 159 159 + 163: 14(int) UDotKHR 162 162 + 164: 6(int) SUDotKHR 159 162 + 165: 6(int) SUDotKHR 159 162 + 166: 6(int) Load 8(i32) + 167: 6(int) SDotAccSatKHR 159 159 166 + 168: 14(int) Load 16(ui32) + 169: 14(int) UDotAccSatKHR 162 162 168 + 170: 6(int) Load 8(i32) + 171: 6(int) SUDotAccSatKHR 159 162 170 + 172: 6(int) Load 8(i32) + 173: 6(int) SUDotAccSatKHR 159 162 172 + 176: 6(int) SDotKHR 175 175 + 179: 14(int) UDotKHR 178 178 + 180: 6(int) SUDotKHR 175 178 + 181: 6(int) SUDotKHR 175 178 + 182: 6(int) Load 8(i32) + 183: 6(int) SDotAccSatKHR 175 175 182 + 184: 14(int) Load 16(ui32) + 185: 14(int) UDotAccSatKHR 178 178 184 + 186: 6(int) Load 8(i32) + 187: 6(int) SUDotAccSatKHR 175 178 186 + 188: 6(int) Load 8(i32) + 189: 6(int) SUDotAccSatKHR 175 178 188 + 192: 6(int) SDotKHR 191 191 + 195: 14(int) UDotKHR 194 194 + 196: 6(int) SUDotKHR 191 194 + 197: 6(int) SUDotKHR 191 194 + 198: 6(int) Load 8(i32) + 199: 6(int) SDotAccSatKHR 191 191 198 + 200: 14(int) Load 16(ui32) + 201: 14(int) UDotAccSatKHR 194 194 200 + 202: 6(int) Load 8(i32) + 203: 6(int) SUDotAccSatKHR 191 194 202 + 204: 6(int) Load 8(i32) + 205: 6(int) SUDotAccSatKHR 191 194 204 + 208: 22(int64_t) SDotKHR 207 207 + 211: 26(int64_t) UDotKHR 210 210 + 212: 22(int64_t) SUDotKHR 207 210 + 213: 22(int64_t) SUDotKHR 207 210 + 214: 22(int64_t) Load 24(i64) + 215: 22(int64_t) SDotAccSatKHR 207 207 214 + 216: 26(int64_t) Load 28(ui64) + 217: 26(int64_t) UDotAccSatKHR 210 210 216 + 218: 22(int64_t) Load 24(i64) + 219: 22(int64_t) SUDotAccSatKHR 207 210 218 + 220: 22(int64_t) Load 24(i64) + 221: 22(int64_t) SUDotAccSatKHR 207 210 220 + 224: 22(int64_t) SDotKHR 223 223 + 227: 26(int64_t) UDotKHR 226 226 + 228: 22(int64_t) SUDotKHR 223 226 + 229: 22(int64_t) SUDotKHR 223 226 + 230: 22(int64_t) Load 24(i64) + 231: 22(int64_t) SDotAccSatKHR 223 223 230 + 232: 26(int64_t) Load 28(ui64) + 233: 26(int64_t) UDotAccSatKHR 226 226 232 + 234: 22(int64_t) Load 24(i64) + 235: 22(int64_t) SUDotAccSatKHR 223 226 234 + 236: 22(int64_t) Load 24(i64) + 237: 22(int64_t) SUDotAccSatKHR 223 226 236 + 240: 22(int64_t) SDotKHR 239 239 + 243: 26(int64_t) UDotKHR 242 242 + 244: 22(int64_t) SUDotKHR 239 242 + 245: 22(int64_t) SUDotKHR 239 242 + 246: 22(int64_t) Load 24(i64) + 247: 22(int64_t) SDotAccSatKHR 239 239 246 + 248: 26(int64_t) Load 28(ui64) + 249: 26(int64_t) UDotAccSatKHR 242 242 248 + 250: 22(int64_t) Load 24(i64) + 251: 22(int64_t) SUDotAccSatKHR 239 242 250 + 252: 22(int64_t) Load 24(i64) + 253: 22(int64_t) SUDotAccSatKHR 239 242 252 + Return + FunctionEnd diff --git a/Test/baseResults/spv.int_dot_Error.frag.out b/Test/baseResults/spv.int_dot_Error.frag.out new file mode 100644 index 0000000000..18624f2966 --- /dev/null +++ b/Test/baseResults/spv.int_dot_Error.frag.out @@ -0,0 +1,12 @@ +spv.int_dot_Error.frag +ERROR: 0:16: 'dotPacked4x8EXT' : no matching overloaded function found +ERROR: 0:17: 'dotPacked4x8EXT' : no matching overloaded function found +ERROR: 0:18: 'dotPacked4x8EXT' : no matching overloaded function found +ERROR: 0:19: 'dotPacked4x8AccSatEXT' : no matching overloaded function found +ERROR: 0:20: 'dotPacked4x8AccSatEXT' : no matching overloaded function found +ERROR: 0:21: 'dotEXT' : no matching overloaded function found +ERROR: 0:22: 'dotAccSatEXT' : no matching overloaded function found +ERROR: 7 compilation errors. No code generated. + + +SPIR-V is not generated for failed compile or link diff --git a/Test/spv.int_dot.frag b/Test/spv.int_dot.frag new file mode 100644 index 0000000000..3bf7920636 --- /dev/null +++ b/Test/spv.int_dot.frag @@ -0,0 +1,143 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types: enable +#extension GL_EXT_integer_dot_product: enable + +int32_t i32 = 0; +int16_t i16 = int16_t(0); +uint32_t ui32 = 0; +uint16_t ui16 = uint16_t(0); +int64_t i64 = int64_t(0); +uint64_t ui64 = uint64_t(0); + +void main (void) +{ + // DotProductInput4x8BitPackedKHR + dotPacked4x8EXT(i32,i32); + dotPacked4x8EXT(ui32,ui32); + dotPacked4x8EXT(ui32,i32); + dotPacked4x8EXT(i32,ui32); + dotPacked4x8AccSatEXT(i32,i32,i32); + dotPacked4x8AccSatEXT(ui32,ui32,ui32); + dotPacked4x8AccSatEXT(ui32,i32,i32); + dotPacked4x8AccSatEXT(i32,ui32,i32); + + // 8bit vec4 + dotEXT(i8vec4(0),i8vec4(0)); + dotEXT(u8vec4(0),u8vec4(0)); + dotEXT(u8vec4(0),i8vec4(0)); + dotEXT(i8vec4(0),u8vec4(0)); + dotAccSatEXT(i8vec4(0),i8vec4(0),i32); + dotAccSatEXT(u8vec4(0),u8vec4(0),ui32); + dotAccSatEXT(u8vec4(0),i8vec4(0),i32); + dotAccSatEXT(i8vec4(0),u8vec4(0),i32); + + // 8bit vec3 + dotEXT(i8vec3(0),i8vec3(0)); + dotEXT(u8vec3(0),u8vec3(0)); + dotEXT(u8vec3(0),i8vec3(0)); + dotEXT(i8vec3(0),u8vec3(0)); + dotAccSatEXT(i8vec3(0),i8vec3(0),i32); + dotAccSatEXT(u8vec3(0),u8vec3(0),ui32); + dotAccSatEXT(u8vec3(0),i8vec3(0),i32); + dotAccSatEXT(i8vec3(0),u8vec3(0),i32); + + // 8bit vec2 + dotEXT(i8vec2(0),i8vec2(0)); + dotEXT(u8vec2(0),u8vec2(0)); + dotEXT(u8vec2(0),i8vec2(0)); + dotEXT(i8vec2(0),u8vec2(0)); + dotAccSatEXT(i8vec2(0),i8vec2(0),i32); + dotAccSatEXT(u8vec2(0),u8vec2(0),ui32); + dotAccSatEXT(u8vec2(0),i8vec2(0),i32); + dotAccSatEXT(i8vec2(0),u8vec2(0),i32); + + // 16bit vec4 + dotEXT(i16vec4(0),i16vec4(0)); + dotEXT(u16vec4(0),u16vec4(0)); + dotEXT(i16vec4(0),u16vec4(0)); + dotEXT(u16vec4(0),i16vec4(0)); + dotAccSatEXT(i16vec4(0),i16vec4(0),i32); + dotAccSatEXT(u16vec4(0),u16vec4(0),ui32); + dotAccSatEXT(i16vec4(0),u16vec4(0),i32); + dotAccSatEXT(u16vec4(0),i16vec4(0),i32); + + // 16bit vec3 + dotEXT(i16vec3(0),i16vec3(0)); + dotEXT(u16vec3(0),u16vec3(0)); + dotEXT(i16vec3(0),u16vec3(0)); + dotEXT(u16vec3(0),i16vec3(0)); + dotAccSatEXT(i16vec3(0),i16vec3(0),i32); + dotAccSatEXT(u16vec3(0),u16vec3(0),ui32); + dotAccSatEXT(i16vec3(0),u16vec3(0),i32); + dotAccSatEXT(u16vec3(0),i16vec3(0),i32); + + // 16bit vec2 + dotEXT(i16vec2(0),i16vec2(0)); + dotEXT(u16vec2(0),u16vec2(0)); + dotEXT(i16vec2(0),u16vec2(0)); + dotEXT(u16vec2(0),i16vec2(0)); + dotAccSatEXT(i16vec2(0),i16vec2(0),i32); + dotAccSatEXT(u16vec2(0),u16vec2(0),ui32); + dotAccSatEXT(i16vec2(0),u16vec2(0),i32); + dotAccSatEXT(u16vec2(0),i16vec2(0),i32); + + // 32bit vec4 + dotEXT(i32vec4(0),i32vec4(0)); + dotEXT(u32vec4(0),u32vec4(0)); + dotEXT(i32vec4(0),u32vec4(0)); + dotEXT(u32vec4(0),i32vec4(0)); + dotAccSatEXT(i32vec4(0),i32vec4(0),i32); + dotAccSatEXT(u32vec4(0),u32vec4(0),ui32); + dotAccSatEXT(i32vec4(0),u32vec4(0),i32); + dotAccSatEXT(u32vec4(0),i32vec4(0),i32); + + // 32bit vec3 + dotEXT(i32vec3(0),i32vec3(0)); + dotEXT(u32vec3(0),u32vec3(0)); + dotEXT(i32vec3(0),u32vec3(0)); + dotEXT(u32vec3(0),i32vec3(0)); + dotAccSatEXT(i32vec3(0),i32vec3(0),i32); + dotAccSatEXT(u32vec3(0),u32vec3(0),ui32); + dotAccSatEXT(i32vec3(0),u32vec3(0),i32); + dotAccSatEXT(u32vec3(0),i32vec3(0),i32); + + // 32bit vec2 + dotEXT(i32vec2(0),i32vec2(0)); + dotEXT(u32vec2(0),u32vec2(0)); + dotEXT(i32vec2(0),u32vec2(0)); + dotEXT(u32vec2(0),i32vec2(0)); + dotAccSatEXT(i32vec2(0),i32vec2(0),i32); + dotAccSatEXT(u32vec2(0),u32vec2(0),ui32); + dotAccSatEXT(i32vec2(0),u32vec2(0),i32); + dotAccSatEXT(u32vec2(0),i32vec2(0),i32); + + // 64bit vec4 + dotEXT(i64vec4(0),i64vec4(0)); + dotEXT(u64vec4(0),u64vec4(0)); + dotEXT(i64vec4(0),u64vec4(0)); + dotEXT(u64vec4(0),i64vec4(0)); + dotAccSatEXT(i64vec4(0),i64vec4(0),i64); + dotAccSatEXT(u64vec4(0),u64vec4(0),ui64); + dotAccSatEXT(i64vec4(0),u64vec4(0),i64); + dotAccSatEXT(u64vec4(0),i64vec4(0),i64); + + // 64bit vec3 + dotEXT(i64vec3(0),i64vec3(0)); + dotEXT(u64vec3(0),u64vec3(0)); + dotEXT(i64vec3(0),u64vec3(0)); + dotEXT(u64vec3(0),i64vec3(0)); + dotAccSatEXT(i64vec3(0),i64vec3(0),i64); + dotAccSatEXT(u64vec3(0),u64vec3(0),ui64); + dotAccSatEXT(i64vec3(0),u64vec3(0),i64); + dotAccSatEXT(u64vec3(0),i64vec3(0),i64); + + // 64bit vec2 + dotEXT(i64vec2(0),i64vec2(0)); + dotEXT(u64vec2(0),u64vec2(0)); + dotEXT(i64vec2(0),u64vec2(0)); + dotEXT(u64vec2(0),i64vec2(0)); + dotAccSatEXT(i64vec2(0),i64vec2(0),i64); + dotAccSatEXT(u64vec2(0),u64vec2(0),ui64); + dotAccSatEXT(i64vec2(0),u64vec2(0),i64); + dotAccSatEXT(u64vec2(0),i64vec2(0),i64); +} diff --git a/Test/spv.int_dot_Error.frag b/Test/spv.int_dot_Error.frag new file mode 100644 index 0000000000..52f82c853f --- /dev/null +++ b/Test/spv.int_dot_Error.frag @@ -0,0 +1,28 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types: enable +#extension GL_EXT_integer_dot_product: enable + +int32_t i32 = 0; +int16_t i16 = int16_t(0); +uint32_t ui32 = 0; +uint16_t ui16 = uint16_t(0); +int64_t i64 = int64_t(0); +uint64_t ui64 = uint64_t(0); + + +void overload_errors() +{ + dotPacked4x8EXT(u8vec4(0),i32); + dotPacked4x8EXT(i32,i64); + dotPacked4x8EXT(i64,i32); + dotPacked4x8AccSatEXT(u8vec4(0),i32,i32); + dotPacked4x8AccSatEXT(i32,i32,i64); + dotEXT(i32,i32); + dotAccSatEXT(i32,i32,i32); + +} + +void main (void) +{ +} diff --git a/glslang/Include/intermediate.h b/glslang/Include/intermediate.h index e574b80d07..1294a6d26a 100644 --- a/glslang/Include/intermediate.h +++ b/glslang/Include/intermediate.h @@ -387,6 +387,11 @@ enum TOperator { EOpSubgroupPartitionedExclusiveXor, EOpSubgroupGuardStop, + + // Integer dot product + EOpDotPackedEXT, + EOpDotAccSatEXT, + EOpDotPackedAccSatEXT, EOpMinInvocations, EOpMaxInvocations, diff --git a/glslang/MachineIndependent/Constant.cpp b/glslang/MachineIndependent/Constant.cpp index 2878be8d1f..488ac81616 100644 --- a/glslang/MachineIndependent/Constant.cpp +++ b/glslang/MachineIndependent/Constant.cpp @@ -1119,6 +1119,9 @@ TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode) break; } case EOpDot: + if (!children[0]->getAsTyped()->isFloatingDomain()) { + return aggrNode; + } newConstArray[0].setDConst(childConstUnions[0].dot(childConstUnions[1])); break; case EOpCross: diff --git a/glslang/MachineIndependent/Initialize.cpp b/glslang/MachineIndependent/Initialize.cpp index abab968f21..0fc75d885d 100644 --- a/glslang/MachineIndependent/Initialize.cpp +++ b/glslang/MachineIndependent/Initialize.cpp @@ -2082,6 +2082,143 @@ void TBuiltIns::initialize(int version, EProfile profile, const SpvVersion& spvV "\n"); } + // GL_EXT_integer_dot_product + if ((profile == EEsProfile && version >= 300) || + (profile != EEsProfile && version >= 450)) { + commonBuiltins.append( + + "uint dotEXT(uvec2 a, uvec2 b);" + "int dotEXT(ivec2 a, ivec2 b);" + "int dotEXT(ivec2 a, uvec2 b);" + "int dotEXT(uvec2 a, ivec2 b);" + + "uint dotEXT(uvec3 a, uvec3 b);" + "int dotEXT(ivec3 a, ivec3 b);" + "int dotEXT(ivec3 a, uvec3 b);" + "int dotEXT(uvec3 a, ivec3 b);" + + "uint dotEXT(uvec4 a, uvec4 b);" + "int dotEXT(ivec4 a, ivec4 b);" + "int dotEXT(ivec4 a, uvec4 b);" + "int dotEXT(uvec4 a, ivec4 b);" + + "uint dotPacked4x8EXT(uint a, uint b);" + "int dotPacked4x8EXT(int a, uint b);" + "int dotPacked4x8EXT(uint a, int b);" + "int dotPacked4x8EXT(int a, int b);" + + "uint dotEXT(u8vec2 a, u8vec2 b);" + "int dotEXT(i8vec2 a, u8vec2 b);" + "int dotEXT(u8vec2 a, i8vec2 b);" + "int dotEXT(i8vec2 a, i8vec2 b);" + + "uint dotEXT(u8vec3 a, u8vec3 b);" + "int dotEXT(i8vec3 a, u8vec3 b);" + "int dotEXT(u8vec3 a, i8vec3 b);" + "int dotEXT(i8vec3 a, i8vec3 b);" + + "uint dotEXT(u8vec4 a, u8vec4 b);" + "int dotEXT(i8vec4 a, u8vec4 b);" + "int dotEXT(u8vec4 a, i8vec4 b);" + "int dotEXT(i8vec4 a, i8vec4 b);" + + "uint dotEXT(u16vec2 a, u16vec2 b);" + "int dotEXT(i16vec2 a, u16vec2 b);" + "int dotEXT(u16vec2 a, i16vec2 b);" + "int dotEXT(i16vec2 a, i16vec2 b);" + + "uint dotEXT(u16vec3 a, u16vec3 b);" + "int dotEXT(i16vec3 a, u16vec3 b);" + "int dotEXT(u16vec3 a, i16vec3 b);" + "int dotEXT(i16vec3 a, i16vec3 b);" + + "uint dotEXT(u16vec4 a, u16vec4 b);" + "int dotEXT(i16vec4 a, u16vec4 b);" + "int dotEXT(u16vec4 a, i16vec4 b);" + "int dotEXT(i16vec4 a, i16vec4 b);" + + "uint64_t dotEXT(u64vec2 a, u64vec2 b);" + "int64_t dotEXT(i64vec2 a, u64vec2 b);" + "int64_t dotEXT(u64vec2 a, i64vec2 b);" + "int64_t dotEXT(i64vec2 a, i64vec2 b);" + + "uint64_t dotEXT(u64vec3 a, u64vec3 b);" + "int64_t dotEXT(i64vec3 a, u64vec3 b);" + "int64_t dotEXT(u64vec3 a, i64vec3 b);" + "int64_t dotEXT(i64vec3 a, i64vec3 b);" + + "uint64_t dotEXT(u64vec4 a, u64vec4 b);" + "int64_t dotEXT(i64vec4 a, u64vec4 b);" + "int64_t dotEXT(u64vec4 a, i64vec4 b);" + "int64_t dotEXT(i64vec4 a, i64vec4 b);" + + "uint dotAccSatEXT(uvec2 a, uvec2 b, uint c);" + "int dotAccSatEXT(ivec2 a, uvec2 b, int c);" + "int dotAccSatEXT(uvec2 a, ivec2 b, int c);" + "int dotAccSatEXT(ivec2 a, ivec2 b, int c);" + + "uint dotAccSatEXT(uvec3 a, uvec3 b, uint c);" + "int dotAccSatEXT(ivec3 a, uvec3 b, int c);" + "int dotAccSatEXT(uvec3 a, ivec3 b, int c);" + "int dotAccSatEXT(ivec3 a, ivec3 b, int c);" + + "uint dotAccSatEXT(uvec4 a, uvec4 b, uint c);" + "int dotAccSatEXT(ivec4 a, uvec4 b, int c);" + "int dotAccSatEXT(uvec4 a, ivec4 b, int c);" + "int dotAccSatEXT(ivec4 a, ivec4 b, int c);" + + "uint dotPacked4x8AccSatEXT(uint a, uint b, uint c);" + "int dotPacked4x8AccSatEXT(int a, uint b, int c);" + "int dotPacked4x8AccSatEXT(uint a, int b, int c);" + "int dotPacked4x8AccSatEXT(int a, int b, int c);" + + "uint dotAccSatEXT(u8vec2 a, u8vec2 b, uint c);" + "int dotAccSatEXT(i8vec2 a, u8vec2 b, int c);" + "int dotAccSatEXT(u8vec2 a, i8vec2 b, int c);" + "int dotAccSatEXT(i8vec2 a, i8vec2 b, int c);" + + "uint dotAccSatEXT(u8vec3 a, u8vec3 b, uint c);" + "int dotAccSatEXT(i8vec3 a, u8vec3 b, int c);" + "int dotAccSatEXT(u8vec3 a, i8vec3 b, int c);" + "int dotAccSatEXT(i8vec3 a, i8vec3 b, int c);" + + "uint dotAccSatEXT(u8vec4 a, u8vec4 b, uint c);" + "int dotAccSatEXT(i8vec4 a, u8vec4 b, int c);" + "int dotAccSatEXT(u8vec4 a, i8vec4 b, int c);" + "int dotAccSatEXT(i8vec4 a, i8vec4 b, int c);" + + "uint dotAccSatEXT(u16vec2 a, u16vec2 b, uint c);" + "int dotAccSatEXT(i16vec2 a, u16vec2 b, int c);" + "int dotAccSatEXT(u16vec2 a, i16vec2 b, int c);" + "int dotAccSatEXT(i16vec2 a, i16vec2 b, int c);" + + "uint dotAccSatEXT(u16vec3 a, u16vec3 b, uint c);" + "int dotAccSatEXT(i16vec3 a, u16vec3 b, int c);" + "int dotAccSatEXT(u16vec3 a, i16vec3 b, int c);" + "int dotAccSatEXT(i16vec3 a, i16vec3 b, int c);" + + "uint dotAccSatEXT(u16vec4 a, u16vec4 b, uint c);" + "int dotAccSatEXT(i16vec4 a, u16vec4 b, int c);" + "int dotAccSatEXT(u16vec4 a, i16vec4 b, int c);" + "int dotAccSatEXT(i16vec4 a, i16vec4 b, int c);" + + "uint64_t dotAccSatEXT(u64vec2 a, u64vec2 b, uint64_t c);" + "int64_t dotAccSatEXT(i64vec2 a, u64vec2 b, int64_t c);" + "int64_t dotAccSatEXT(u64vec2 a, i64vec2 b, int64_t c);" + "int64_t dotAccSatEXT(i64vec2 a, i64vec2 b, int64_t c);" + + "uint64_t dotAccSatEXT(u64vec3 a, u64vec3 b, uint64_t c);" + "int64_t dotAccSatEXT(i64vec3 a, u64vec3 b, int64_t c);" + "int64_t dotAccSatEXT(u64vec3 a, i64vec3 b, int64_t c);" + "int64_t dotAccSatEXT(i64vec3 a, i64vec3 b, int64_t c);" + + "uint64_t dotAccSatEXT(u64vec4 a, u64vec4 b, uint64_t c);" + "int64_t dotAccSatEXT(i64vec4 a, u64vec4 b, int64_t c);" + "int64_t dotAccSatEXT(u64vec4 a, i64vec4 b, int64_t c);" + "int64_t dotAccSatEXT(i64vec4 a, i64vec4 b, int64_t c);" + "\n"); + } + // GL_KHR_shader_subgroup if ((profile == EEsProfile && version >= 310) || (profile != EEsProfile && version >= 140)) { @@ -9144,6 +9281,15 @@ void TBuiltIns::identifyBuiltIns(int version, EProfile profile, const SpvVersion symbolTable.setFunctionExtensions("fetchMicroTriangleVertexPositionNV", 1, &E_GL_NV_displacement_micromap); symbolTable.setFunctionExtensions("fetchMicroTriangleVertexBarycentricNV", 1, &E_GL_NV_displacement_micromap); } + + // GL_EXT_integer_dot_product + if ((profile == EEsProfile && version >= 300) || + (profile != EEsProfile && version >= 450)) { + symbolTable.setFunctionExtensions("dotEXT", 1, &E_GL_EXT_integer_dot_product); + symbolTable.setFunctionExtensions("dotPacked4x8EXT", 1, &E_GL_EXT_integer_dot_product); + symbolTable.setFunctionExtensions("dotAccSatEXT", 1, &E_GL_EXT_integer_dot_product); + symbolTable.setFunctionExtensions("dotPacked4x8AccSatEXT", 1, &E_GL_EXT_integer_dot_product); + } break; case EShLangRayGen: @@ -10082,6 +10228,15 @@ void TBuiltIns::identifyBuiltIns(int version, EProfile profile, const SpvVersion symbolTable.relateToOperator("fragmentFetchAMD", EOpFragmentFetch); } + // GL_EXT_integer_dot_product + if ((profile == EEsProfile && version >= 300) || + (profile != EEsProfile && version >= 450)) { + symbolTable.relateToOperator("dotEXT", EOpDot); + symbolTable.relateToOperator("dotPacked4x8EXT", EOpDotPackedEXT); + symbolTable.relateToOperator("dotAccSatEXT", EOpDotAccSatEXT); + symbolTable.relateToOperator("dotPacked4x8AccSatEXT", EOpDotPackedAccSatEXT); + } + // GL_KHR_shader_subgroup if ((profile == EEsProfile && version >= 310) || (profile != EEsProfile && version >= 140)) { diff --git a/glslang/MachineIndependent/Versions.cpp b/glslang/MachineIndependent/Versions.cpp index cbfae6c72e..88b7b1c076 100644 --- a/glslang/MachineIndependent/Versions.cpp +++ b/glslang/MachineIndependent/Versions.cpp @@ -396,6 +396,8 @@ void TParseVersions::initializeExtensionBehavior() extensionBehavior[E_GL_EXT_shader_atomic_float] = EBhDisable; extensionBehavior[E_GL_EXT_shader_atomic_float2] = EBhDisable; + extensionBehavior[E_GL_EXT_integer_dot_product] = EBhDisable; + // Record extensions not for spv. spvUnsupportedExt.push_back(E_GL_ARB_bindless_texture); } @@ -603,6 +605,8 @@ void TParseVersions::getPreamble(std::string& preamble) "#define GL_EXT_texture_array 1\n" "#define GL_EXT_control_flow_attributes2 1\n" + + "#define GL_EXT_integer_dot_product 1\n" ; if (spvVersion.spv == 0) { diff --git a/glslang/MachineIndependent/Versions.h b/glslang/MachineIndependent/Versions.h index 4541381ab2..7c3e0f828b 100644 --- a/glslang/MachineIndependent/Versions.h +++ b/glslang/MachineIndependent/Versions.h @@ -348,6 +348,8 @@ const char* const E_GL_EXT_shader_tile_image = "GL_EXT_shader_tile_image"; const char* const E_GL_EXT_texture_shadow_lod = "GL_EXT_texture_shadow_lod"; +const char* const E_GL_EXT_integer_dot_product = "GL_EXT_integer_dot_product"; + // Arrays of extensions for the above AEP duplications const char* const AEP_geometry_shader[] = { E_GL_EXT_geometry_shader, E_GL_OES_geometry_shader }; diff --git a/glslang/MachineIndependent/intermOut.cpp b/glslang/MachineIndependent/intermOut.cpp index ccfff38c4b..750ba7095f 100644 --- a/glslang/MachineIndependent/intermOut.cpp +++ b/glslang/MachineIndependent/intermOut.cpp @@ -667,6 +667,9 @@ bool TOutputTraverser::visitAggregate(TVisit /* visit */, TIntermAggregate* node case EOpDistance: out.debug << "distance"; break; case EOpDot: out.debug << "dot-product"; break; + case EOpDotPackedEXT: out.debug << "dot-product-packed";break; + case EOpDotAccSatEXT: out.debug << "dot-product-accumulate-saturate";break; + case EOpDotPackedAccSatEXT: out.debug << "dot-product-packed-accumulate-saturate";break; case EOpCross: out.debug << "cross-product"; break; case EOpFaceForward: out.debug << "face-forward"; break; case EOpReflect: out.debug << "reflect"; break; diff --git a/glslang/MachineIndependent/propagateNoContraction.cpp b/glslang/MachineIndependent/propagateNoContraction.cpp index 7b5cd03fa6..600541f613 100644 --- a/glslang/MachineIndependent/propagateNoContraction.cpp +++ b/glslang/MachineIndependent/propagateNoContraction.cpp @@ -174,6 +174,9 @@ bool isArithmeticOperation(glslang::TOperator op) case glslang::EOpMatrixTimesMatrix: case glslang::EOpDot: + case glslang::EOpDotPackedEXT: + case glslang::EOpDotAccSatEXT: + case glslang::EOpDotPackedAccSatEXT: case glslang::EOpPostIncrement: case glslang::EOpPostDecrement: diff --git a/gtests/Spv.FromFile.cpp b/gtests/Spv.FromFile.cpp index b75b6d8a52..cf52bd8993 100644 --- a/gtests/Spv.FromFile.cpp +++ b/gtests/Spv.FromFile.cpp @@ -418,6 +418,8 @@ INSTANTIATE_TEST_SUITE_P( "spv.GeometryShaderPassthrough.geom", "spv.funcall.array.frag", "spv.load.bool.array.interface.block.frag", + "spv.int_dot.frag", + "spv.int_dot_Error.frag", "spv.interpOps.frag", "spv.int64.frag", "spv.intcoopmat.comp",