Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use UMad for dot products on uint vectors #7059

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/dxc/HlslIntrinsicOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ enum class IntrinsicOp {
IOP_WavePrefixUSum,
IOP_uabs,
IOP_uclamp,
IOP_udot,
IOP_ufirstbithigh,
IOP_umad,
IOP_umax,
Expand All @@ -391,6 +392,7 @@ inline bool HasUnsignedIntrinsicOpcode(IntrinsicOp opcode) {
case IntrinsicOp::IOP_WavePrefixSum:
case IntrinsicOp::IOP_abs:
case IntrinsicOp::IOP_clamp:
case IntrinsicOp::IOP_dot:
case IntrinsicOp::IOP_firstbithigh:
case IntrinsicOp::IOP_mad:
case IntrinsicOp::IOP_max:
Expand Down Expand Up @@ -432,6 +434,8 @@ inline unsigned GetUnsignedIntrinsicOpcode(IntrinsicOp opcode) {
return static_cast<unsigned>(IntrinsicOp::IOP_uabs);
case IntrinsicOp::IOP_clamp:
return static_cast<unsigned>(IntrinsicOp::IOP_uclamp);
case IntrinsicOp::IOP_dot:
return static_cast<unsigned>(IntrinsicOp::IOP_udot);
case IntrinsicOp::IOP_firstbithigh:
return static_cast<unsigned>(IntrinsicOp::IOP_ufirstbithigh);
case IntrinsicOp::IOP_mad:
Expand Down
4 changes: 3 additions & 1 deletion lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2480,7 +2480,8 @@ Value *TranslateDot(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
if (Ty->getScalarType()->isFloatingPointTy()) {
return TranslateFDot(arg0, arg1, vecSize, hlslOP, Builder);
} else {
return TranslateIDot(arg0, arg1, vecSize, hlslOP, Builder);
return TranslateIDot(arg0, arg1, vecSize, hlslOP, Builder,
IOP == IntrinsicOp::IOP_udot);
}
}

Expand Down Expand Up @@ -6789,6 +6790,7 @@ IntrinsicLower gLowerTable[] = {
DXIL::OpCode::WavePrefixOp},
{IntrinsicOp::IOP_uabs, TranslateUAbs, DXIL::OpCode::NumOpCodes},
{IntrinsicOp::IOP_uclamp, TranslateClamp, DXIL::OpCode::NumOpCodes},
{IntrinsicOp::IOP_udot, TranslateDot, DXIL::OpCode::NumOpCodes},
{IntrinsicOp::IOP_ufirstbithigh, TranslateFirstbitHi,
DXIL::OpCode::FirstbitHi},
{IntrinsicOp::IOP_umad, TranslateFUITrinary, DXIL::OpCode::UMad},
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8953,6 +8953,7 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
return nullptr;
}
case hlsl::IntrinsicOp::IOP_dot:
case hlsl::IntrinsicOp::IOP_udot:
retVal = processIntrinsicDot(callExpr);
break;
case hlsl::IntrinsicOp::IOP_GroupMemoryBarrier:
Expand Down
135 changes: 135 additions & 0 deletions tools/clang/test/CodeGenHLSL/dot.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// RUN: %dxc -T vs_6_0 -DFUNC=dot %s | FileCheck %s
// RUN: %dxc -T vs_6_0 -DFUNC=mul %s | FileCheck %s

// Verifies correct implementation of dot and mul with vectors for various sizes and types.

// Partially pilfered from SPIRV's intrinsic.dot.hlsl

// Test entry point. FUNC is supplied on the command line (-DFUNC=dot or
// -DFUNC=mul in the run lines at the top of this file) and is applied to pairs
// of int, float and uint vectors of sizes 1 through 4. Each call is preceded by
// the FileCheck patterns for the DXIL expected from it, so the statements and
// the patterns must be kept in the same order.
//
// In the loadInput patterns below, the second i32 operand follows the
// parameter's position in the signature (0-11), the third i32 matches the
// array element ([0] or [1]) and the i8 matches the vector component.
//
// The three accumulators are packed into the returned position, presumably so
// that every call stays live in the generated code.
float4 main(int1 i1[2] : IO, int2 i2[2] : IT, int3 i3[2] : IH, int4 i4[2] : IF,
float1 f1[2] : FO, float2 f2[2] : FT, float3 f3[2] : FH, float4 f4[2] : FF,
uint1 u1[2] : UO, uint2 u2[2] : UT, uint3 u3[2] : UH, uint4 u4[2] : UF) : SV_Position {
// ---- Signed integer vectors: a plain mul for the first component, then one
// ---- IMad (tertiary opcode 48) per additional component.
int i = 0;
// CHECK-DAG: [[I0:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I1:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 1, i8 0, i32 undef)
// CHECK: mul i32 [[I0]], [[I1]]
i += FUNC(i1[0], i1[1]);

// CHECK-DAG: [[I00:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 1, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I01:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 1, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[I10:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 1, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[I11:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 1, i32 1, i8 1, i32 undef)

// CHECK: [[MUL:%.*]] = mul i32 [[I00]], [[I10]]
// CHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 [[I01]], i32 [[I11]], i32 [[MUL]]) ; IMad(a,b,c)
i += FUNC(i2[0], i2[1]);

// CHECK-DAG: [[I00:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 2, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I01:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 2, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[I02:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 2, i32 0, i8 2, i32 undef)
// CHECK-DAG: [[I10:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 2, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[I11:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 2, i32 1, i8 1, i32 undef)
// CHECK-DAG: [[I12:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 2, i32 1, i8 2, i32 undef)

// PING and PONG are alternating capture names used to track the result as it
// accumulates, because a single pattern line cannot both match the previous
// value and capture the new one under the same variable name.
// CHECK: [[PING:%.*]] = mul i32 [[I00]], [[I10]]
// CHECK: [[PONG:%.*]] = call i32 @dx.op.tertiary.i32(i32 48, i32 [[I01]], i32 [[I11]], i32 [[PING]]) ; IMad(a,b,c)
// CHECK: [[PING:%.*]] = call i32 @dx.op.tertiary.i32(i32 48, i32 [[I02]], i32 [[I12]], i32 [[PONG]]) ; IMad(a,b,c)
i += FUNC(i3[0], i3[1]);

// CHECK-DAG: [[I00:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I01:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[I02:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 0, i8 2, i32 undef)
// CHECK-DAG: [[I03:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 0, i8 3, i32 undef)
// CHECK-DAG: [[I10:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[I11:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 1, i8 1, i32 undef)
// CHECK-DAG: [[I12:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 1, i8 2, i32 undef)
// CHECK-DAG: [[I13:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 3, i32 1, i8 3, i32 undef)

// CHECK: [[PING:%.*]] = mul i32 [[I00]], [[I10]]
// CHECK: [[PONG:%.*]] = call i32 @dx.op.tertiary.i32(i32 48, i32 [[I01]], i32 [[I11]], i32 [[PING]]) ; IMad(a,b,c)
// CHECK: [[PING:%.*]] = call i32 @dx.op.tertiary.i32(i32 48, i32 [[I02]], i32 [[I12]], i32 [[PONG]]) ; IMad(a,b,c)
// CHECK: [[PONG:%.*]] = call i32 @dx.op.tertiary.i32(i32 48, i32 [[I03]], i32 [[I13]], i32 [[PING]]) ; IMad(a,b,c)
i += FUNC(i4[0], i4[1]);

// ---- Float vectors: size 1 is a fast-math mul; sizes 2-4 are expected to use
// ---- the dedicated dot2/dot3/dot4 operations (opcodes 54/55/56).
float f = 0.0;

// CHECK-DAG: [[F0:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 4, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[F1:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 4, i32 1, i8 0, i32 undef)
// CHECK: mul fast float [[F0]], [[F1]]
f += FUNC(f1[0], f1[1]);

// CHECK-DAG: [[F00:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 5, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[F01:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 5, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[F10:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 5, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[F11:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 5, i32 1, i8 1, i32 undef)

// CHECK: call float @dx.op.dot2.f32(i32 54, float [[F00]], float [[F01]], float [[F10]], float [[F11]])
f += FUNC(f2[0], f2[1]);

// CHECK-DAG: [[F00:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 6, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[F01:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 6, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[F02:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 6, i32 0, i8 2, i32 undef)
// CHECK-DAG: [[F10:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 6, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[F11:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 6, i32 1, i8 1, i32 undef)
// CHECK-DAG: [[F12:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 6, i32 1, i8 2, i32 undef)

// CHECK: call float @dx.op.dot3.f32(i32 55, float [[F00]], float [[F01]], float [[F02]], float [[F10]], float [[F11]], float [[F12]])
f += FUNC(f3[0], f3[1]);

// CHECK-DAG: [[F00:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[F01:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[F02:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 0, i8 2, i32 undef)
// CHECK-DAG: [[F03:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 0, i8 3, i32 undef)
// CHECK-DAG: [[F10:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[F11:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 1, i8 1, i32 undef)
// CHECK-DAG: [[F12:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 1, i8 2, i32 undef)
// CHECK-DAG: [[F13:%.*]] = call float @dx.op.loadInput.f32(i32 4, i32 7, i32 1, i8 3, i32 undef)

// CHECK: call float @dx.op.dot4.f32(i32 56, float [[F00]], float [[F01]], float [[F02]], float [[F03]], float [[F10]], float [[F11]], float [[F12]], float [[F13]])
f += FUNC(f4[0], f4[1]);

// ---- Unsigned integer vectors: same shape as the signed case, but the
// ---- accumulation is expected to use UMad (tertiary opcode 49) instead of IMad.
// NOTE(review): u is declared int although it accumulates results computed from
// uint operands -- confirm whether uint was intended.
int u = 0;
// CHECK-DAG: [[I0:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 8, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I1:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 8, i32 1, i8 0, i32 undef)
// CHECK: mul i32 [[I0]], [[I1]]
u += FUNC(u1[0], u1[1]);

// CHECK-DAG: [[I00:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 9, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I01:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 9, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[I10:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 9, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[I11:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 9, i32 1, i8 1, i32 undef)

// CHECK: [[MUL:%.*]] = mul i32 [[I00]], [[I10]]
// CHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 [[I01]], i32 [[I11]], i32 [[MUL]]) ; UMad(a,b,c)
u += FUNC(u2[0], u2[1]);

// CHECK-DAG: [[I00:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 10, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I01:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 10, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[I02:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 10, i32 0, i8 2, i32 undef)
// CHECK-DAG: [[I10:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 10, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[I11:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 10, i32 1, i8 1, i32 undef)
// CHECK-DAG: [[I12:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 10, i32 1, i8 2, i32 undef)

// CHECK: [[PING:%.*]] = mul i32 [[I00]], [[I10]]
// CHECK: [[PONG:%.*]] = call i32 @dx.op.tertiary.i32(i32 49, i32 [[I01]], i32 [[I11]], i32 [[PING]]) ; UMad(a,b,c)
// CHECK: [[PING:%.*]] = call i32 @dx.op.tertiary.i32(i32 49, i32 [[I02]], i32 [[I12]], i32 [[PONG]]) ; UMad(a,b,c)
u += FUNC(u3[0], u3[1]);

// CHECK-DAG: [[I00:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 0, i8 0, i32 undef)
// CHECK-DAG: [[I01:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 0, i8 1, i32 undef)
// CHECK-DAG: [[I02:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 0, i8 2, i32 undef)
// CHECK-DAG: [[I03:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 0, i8 3, i32 undef)
// CHECK-DAG: [[I10:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 1, i8 0, i32 undef)
// CHECK-DAG: [[I11:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 1, i8 1, i32 undef)
// CHECK-DAG: [[I12:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 1, i8 2, i32 undef)
// CHECK-DAG: [[I13:%.*]] = call i32 @dx.op.loadInput.i32(i32 4, i32 11, i32 1, i8 3, i32 undef)

// CHECK: [[PING:%.*]] = mul i32 [[I00]], [[I10]]
// CHECK: [[PONG:%.*]] = call i32 @dx.op.tertiary.i32(i32 49, i32 [[I01]], i32 [[I11]], i32 [[PING]]) ; UMad(a,b,c)
// CHECK: [[PING:%.*]] = call i32 @dx.op.tertiary.i32(i32 49, i32 [[I02]], i32 [[I12]], i32 [[PONG]]) ; UMad(a,b,c)
// CHECK: [[PONG:%.*]] = call i32 @dx.op.tertiary.i32(i32 49, i32 [[I03]], i32 [[I13]], i32 [[PING]]) ; UMad(a,b,c)
u += FUNC(u4[0], u4[1]);

// Combine the int, float and uint accumulators into the output position.
return float4(i, f, u, 0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// CHECK: bufferLoad
// CHECK: FMax
// CHECK: FMin
// CHECK: IMad
// CHECK: UMad
// CHECK: bufferStore

//--------------------------------------------------------------------------------------
Expand Down
14 changes: 7 additions & 7 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ $match<0, 1> float_like [[rn]] determinant(in float_like<r, r> x);
void [[]] DeviceMemoryBarrier() : syncdevicememory_ug;
void [[]] DeviceMemoryBarrierWithGroupSync() : syncgroupanddevicememory_ug;
$match<0, 1> float_like [[rn]] distance(in float_like<c> a, in $type1 b);
$match<0, 1> numeric [[rn]] dot(in numeric<c> a, in $type1 b);
$match<0, 1> numeric [[rn,unsigned_op=udot]] dot(in numeric<c> a, in $type1 b);
$type1 [[rn]] dst(in numeric<4> a, in $type1 b);
// void errorf(in string Format, ...);
$type1 [[rn]] EvaluateAttributeAtSample(in numeric<> value, in uint index);
Expand Down Expand Up @@ -198,13 +198,13 @@ $type1 [[rn,unsigned_op=umax]] max(in numeric<> a, in $type1 b);
$type1 [[rn,unsigned_op=umin]] min(in numeric<> a, in $type1 b);
$type1 [[]] modf(in float_like<> x, out $type1 ip);
uint<4> [[rn]] msad4(in uint reference, in uint<2> source, in uint<4> accum);
numeric [[rn]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric b) : mul_ss;
numeric<c2> [[rn]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<c2> b) : mul_sv;
numeric<r2, c2> [[rn]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<r2, c2> b) : mul_sm;
numeric<c> [[rn]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric b) : mul_vs;
numeric [[rn]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric<c> b) : mul_vv;
numeric [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric b) : mul_ss;
numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<c2> b) : mul_sv;
numeric<r2, c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<r2, c2> b) : mul_sm;
numeric<c> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric b) : mul_vs;
numeric [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric<c> b) : mul_vv;
numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_vm;
numeric<r, c> [[rn]] mul(in $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric b) : mul_ms;
numeric<r, c> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric b) : mul_ms;
numeric<r> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric<c> b) : mul_mv;
numeric<r, c2> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_mm;
$type1 [[rn]] normalize(in float_like<c> x);
Expand Down
Loading