MultiShift, Masked shift and masked interleave #2431

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions g3doc/quick_reference.md
@@ -966,6 +966,49 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
(a[i] << ((sizeof(T)*8 - b[i]) & shift_amt_mask))`, where `shift_amt_mask` is
equal to `sizeof(T)*8 - 1`.

A compound shift on 64-bit values:

* `V`: `{u,i}64`, `VI`: `{u,i}8` \
<code>V **MultiShift**(V vals, VI indices)</code>: returns a
  vector with `(vals[i] >> indices[i*8+j]) & 0xff` in byte `j` of `r[i]` (where
  `r` denotes the result vector), for each `j` between 0 and 7.

> **Member:** Let's define r, for example "vector r".

If `indices[i*8+j]` is less than 0 or greater than 63, byte `j` of `r[i]` is
implementation-defined.

`VI` must be either `Vec<Repartition<int8_t, DFromV<V>>>` or
`Vec<Repartition<uint8_t, DFromV<V>>>`.

`MultiShift(V vals, VI indices)` is equivalent to the following loop (where `N` is
equal to `Lanes(DFromV<V>())`):
```
for (size_t i = 0; i < N; i++) {
  uint64_t shift_result = 0;
  for (int j = 0; j < 8; j++) {
    const uint64_t rot_result = (vals[i] >> indices[i*8+j]) |
                                (vals[i] << ((64 - indices[i*8+j]) & 63));
    shift_result |= (rot_result & 0xff) << (j * 8);
  }
  r[i] = shift_result;
}
```

> **Member:** Because we're only using the lower 8 bits, we can just talk about the right-shift and not a rotate, right?
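
As an illustration (not part of this patch), a minimal sketch of calling `MultiShift` to read each byte of every lane at a 4-bit offset; the helper name and index values are invented for the example:

```
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Returns, in byte j of each u64 lane, the byte starting at bit 8*j + 4.
template <class V>  // V is a vector of uint64_t
V NibbleAlignedBytes(V vals) {
  const hn::DFromV<V> d64;
  const hn::Repartition<uint8_t, decltype(d64)> du8;
  // Bit offsets 4, 12, ..., 60 within each 64-bit lane.
  const auto indices = hn::Dup128VecFromValues(
      du8, 4, 12, 20, 28, 36, 44, 52, 60, 4, 12, 20, 28, 36, 44, 52, 60);
  return hn::MultiShift(vals, indices);
}
```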

#### Masked Shifts
* `V`: `{u,i}` \
  <code>V **MaskedShiftLeftOrZero**&lt;int&gt;(M mask, V a)</code> returns `a[i] << int` or `0` if
  `mask[i]` is false (a usage sketch follows this list).

> **Member:** As before, let's drop the OrZero suffix?

* `V`: `{u,i}` \
<code>V **MaskedShiftRightOrZero**&lt;int&gt;(M mask, V a)</code> returns `a[i] >> int` or `0` if
`mask[i]` is false.

* `V`: `{u,i}` \
<code>V **MaskedShiftRightOr**&lt;int&gt;(V no, M mask, V a)</code> returns `a[i] >> int` or `no` if
`mask[i]` is false.

* `V`: `{u,i}` \
<code>V **MaskedShrOr**(V no, M mask, V a, V shifts)</code> returns `a[i] >> shifts[i]` or `no` if
`mask[i]` is false.
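
As a usage sketch (helper name invented for illustration), combining the ops above: shift right by one in lanes selected by the mask, keeping a fallback value elsewhere:

```
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Lanes where mask is true become a[i] >> 1; other lanes keep fallback[i].
template <class V, class M>
V HalveWhere(M mask, V a, V fallback) {
  return hn::MaskedShiftRightOr<1>(fallback, mask, a);
}
```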

#### Floating-point rounding

* `V`: `{f}` \
@@ -2081,6 +2124,12 @@ Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
`InterleaveOdd(d, a, b)` is usually more efficient than `OddEven(b,
DupOdd(a))`.

* <code>V **InterleaveEvenOrZero**(M m, V a, V b)</code>: Performs the same
  operation as InterleaveEven, but returns zero in lanes where `m[i]` is false
  (see the sketch after these items).

> **Member:** Add Masked prefix and remove OrZero suffix? I notice this doesn't have an SVE implementation yet?

* <code>V **InterleaveOddOrZero**(M m, V a, V b)</code>: Performs the same
operation as InterleaveOdd, but returns zero in lanes where `m[i]` is false.
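
A brief sketch (helper name invented for illustration) of masking an interleave to the lower half of the vector:

```
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Interleaves even lanes of a and b; lanes in the upper half are zeroed.
template <class D, class V>
V InterleaveEvenLowerHalf(D d, V a, V b) {
  const auto m = hn::FirstN(d, hn::Lanes(d) / 2);
  return hn::InterleaveEvenOrZero(m, a, b);
}
```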

#### Zip

* `Ret`: `MakeWide<T>`; `V`: `{u,i}{8,16,32}` \
29 changes: 29 additions & 0 deletions hwy/ops/arm_sve-inl.h
@@ -1067,6 +1067,35 @@ HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ MaskedShift[Left/Right]OrZero

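// Shifts each active lane by the compile-time constant kBits; inactive lanes
// are zeroed via the intrinsics' _z (zeroing predication) form.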
#define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  template <int kBits>                                                      \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
    auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits);                \
    return sv##OP##_##CHAR##BITS##_z(m, v, shifts);                         \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeftOrZero, lsl_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRightOrZero, lsr_n)

#undef HWY_SVE_SHIFT_Z

// ------------------------------ MaskedShiftRightOr

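// Shifts active lanes right by kBits; inactive lanes take the corresponding
// lane of `no`, selected with svsel.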
#define HWY_SVE_SHIFT_OR(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  template <int kBits>                                                       \
  HWY_API HWY_SVE_V(BASE, BITS)                                              \
      NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) v) {  \
    auto shifts = static_cast<HWY_SVE_T(uint, BITS)>(kBits);                 \
    return svsel##_##CHAR##BITS(m, sv##OP##_##CHAR##BITS##_z(m, v, shifts),  \
                                no);                                         \
  }
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, asr_n)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, lsr_n)

#undef HWY_SVE_SHIFT_OR
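
// For reference (not part of the patch), the BASE=uint, BITS=32 instantiation
// of HWY_SVE_SHIFT_OR above expands approximately to:
//   template <int kBits>
//   HWY_API svuint32_t MaskedShiftRightOr(svuint32_t no, svbool_t m,
//                                         svuint32_t v) {
//     auto shifts = static_cast<uint32_t>(kBits);
//     return svsel_u32(m, svlsr_n_u32_z(m, v, shifts), no);
//   }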

// ------------------------------ RotateRight

#if HWY_SVE_HAVE_2
95 changes: 95 additions & 0 deletions hwy/ops/generic_ops-inl.h
@@ -488,6 +488,24 @@ HWY_API V InterleaveEven(V a, V b) {
}
#endif

// ------------------------------ InterleaveEvenOrZero

#if HWY_TARGET != HWY_SCALAR || HWY_IDE
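// Generic fallback: interleave, then zero the lanes where m is false.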
template <class V, class M>
HWY_API V InterleaveEvenOrZero(M m, V a, V b) {
  return IfThenElseZero(m, InterleaveEven(DFromV<V>(), a, b));
}
#endif

// ------------------------------ InterleaveOddOrZero

#if HWY_TARGET != HWY_SCALAR || HWY_IDE
template <class V, class M>
HWY_API V InterleaveOddOrZero(M m, V a, V b) {
  return IfThenElseZero(m, InterleaveOdd(DFromV<V>(), a, b));
}
#endif

// ------------------------------ AddSub

template <class V, HWY_IF_LANES_D(DFromV<V>, 1)>
@@ -574,6 +592,27 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH

// ------------------------------ MaskedShift
template <int kshift, class V, class M>
HWY_API V MaskedShiftLeftOrZero(M m, V a) {
  return IfThenElseZero(m, ShiftLeft<kshift>(a));
}

> **Member:** Nit: please rename kshift -> kShift.

template <int kshift, class V, class M>
HWY_API V MaskedShiftRightOrZero(M m, V a) {
  return IfThenElseZero(m, ShiftRight<kshift>(a));
}

template <int kshift, class V, class M>
HWY_API V MaskedShiftRightOr(V no, M m, V a) {
  return IfThenElse(m, ShiftRight<kshift>(a), no);
}

template <class V, class M>
HWY_API V MaskedShrOr(V no, M m, V a, V shifts) {
  return IfThenElse(m, Shr(a, shifts), no);
}

// ------------------------------ IfNegativeThenNegOrUndefIfZero

#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
@@ -7299,6 +7338,62 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ MultiShift (Rol)

> **Member:** The ops in parentheses are ops that the implementation of the one in this section uses, for purposes of sorting the ops in the source file. I think this is a copy-paste remnant? We can remove it because this implementation does not seem to use any ops defined in generic_ops-inl.h.

#if (defined(HWY_NATIVE_MULTISHIFT) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MULTISHIFT
#undef HWY_NATIVE_MULTISHIFT
#else
#define HWY_NATIVE_MULTISHIFT
#endif

template <class V, class VI, HWY_IF_UI64(TFromV<V>), HWY_IF_UI8(TFromV<VI>)>
HWY_API V MultiShift(V v, VI idx) {
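  // Each output byte j of a 64-bit lane is an unaligned byte read at bit
  // offset idx[i*8+j]. Gather the two source bytes straddling that offset
  // with TableLookupBytes, then shift within a 16-bit window. Even and odd
  // lanes are processed separately because TableLookupBytes selects within
  // 16-byte blocks.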
  const DFromV<V> d64;
  const Repartition<uint8_t, decltype(d64)> du8;
  const Repartition<uint16_t, decltype(d64)> du16;
  const auto k7 = Set(du8, uint8_t{0x07});
  const auto k63 = Set(du8, uint8_t{0x3F});

  const auto masked_idx = And(k63, BitCast(du8, idx));
  const auto byte_idx = ShiftRight<3>(masked_idx);
  const auto idx_shift = And(k7, masked_idx);

  // Calculate even lanes
  const auto even_src = DupEven(v);
  // Expand indexes to pull out 16 bit segments of idx and idx + 1
  const auto even_idx =
      InterleaveLower(byte_idx, Add(byte_idx, Set(du8, uint8_t{1})));
  // TableLookupBytes indexes select from within a 16 byte block
  const auto even_segments = TableLookupBytes(even_src, even_idx);
  // Extract unaligned bytes from 16 bit segments
  const auto even_idx_shift = ZipLower(idx_shift, Zero(du8));
  const auto extracted_even_bytes =
      Shr(BitCast(du16, even_segments), even_idx_shift);

  // Calculate odd lanes
  const auto odd_src = DupOdd(v);
  // Expand indexes to pull out 16 bit segments of idx and idx + 1
  const auto odd_idx =
      InterleaveUpper(du8, byte_idx, Add(byte_idx, Set(du8, uint8_t{1})));
  // TableLookupBytes indexes select from within a 16 byte block
  const auto odd_segments = TableLookupBytes(odd_src, odd_idx);
  // Extract unaligned bytes from 16 bit segments
  const auto odd_idx_shift = ZipUpper(du16, idx_shift, Zero(du8));
  const auto extracted_odd_bytes =
      Shr(BitCast(du16, odd_segments), odd_idx_shift);

  // Extract the even bytes of each 128 bit block and pack into lower 64 bits
  const auto extract_mask = Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14,
                                                0, 0, 0, 0, 0, 0, 0, 0);
  const auto even_lanes =
      BitCast(d64, TableLookupBytes(extracted_even_bytes, extract_mask));
  const auto odd_lanes =
      BitCast(d64, TableLookupBytes(extracted_odd_bytes, extract_mask));
  // Interleave at 64 bit level
  return InterleaveLower(even_lanes, odd_lanes);
}

> **Member:** I think we could use ConcatEven here?

#endif
// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
60 changes: 60 additions & 0 deletions hwy/tests/blockwise_test.cc
@@ -274,12 +274,72 @@ struct TestInterleaveOdd {
}
};

struct TestMaskedInterleaveEven {
  template <class T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const size_t N = Lanes(d);
    const MFromD<D> first_3 = FirstN(d, 3);
    auto even_lanes = AllocateAligned<T>(N);
    auto odd_lanes = AllocateAligned<T>(N);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(even_lanes && odd_lanes && expected);
    for (size_t i = 0; i < N; ++i) {
      even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
      odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
    }
    const auto even = Load(d, even_lanes.get());
    const auto odd = Load(d, odd_lanes.get());

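    // InterleaveEven gives r[2k] = a[2k], r[2k+1] = b[2k]; the mask keeps
    // only the first 3 lanes and zeroes the rest.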
    for (size_t i = 0; i < N; ++i) {
      if (i < 3) {
        expected[i] = ConvertScalarTo<T>(2 * i - (i & 1));
      } else {
        expected[i] = ConvertScalarTo<T>(0);
      }
    }

    HWY_ASSERT_VEC_EQ(d, expected.get(),
                      InterleaveEvenOrZero(first_3, even, odd));
  }
};

struct TestMaskedInterleaveOdd {
  template <class T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const size_t N = Lanes(d);
    const MFromD<D> first_3 = FirstN(d, 3);
    auto even_lanes = AllocateAligned<T>(N);
    auto odd_lanes = AllocateAligned<T>(N);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(even_lanes && odd_lanes && expected);
    for (size_t i = 0; i < N; ++i) {
      even_lanes[i] = ConvertScalarTo<T>(2 * i + 0);
      odd_lanes[i] = ConvertScalarTo<T>(2 * i + 1);
    }
    const auto even = Load(d, even_lanes.get());
    const auto odd = Load(d, odd_lanes.get());

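    // InterleaveOdd gives r[2k] = a[2k+1], r[2k+1] = b[2k+1]; the mask keeps
    // only the first 3 lanes and zeroes the rest.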
    for (size_t i = 0; i < N; ++i) {
      if (i < 3) {
        expected[i] = ConvertScalarTo<T>((2 * i) - (i & 1) + 2);
      } else {
        expected[i] = ConvertScalarTo<T>(0);
      }
    }

    HWY_ASSERT_VEC_EQ(d, expected.get(),
                      InterleaveOddOrZero(first_3, even, odd));
  }
};

HWY_NOINLINE void TestAllInterleave() {
  // Not DemoteVectors because this cannot be supported by HWY_SCALAR.
  ForAllTypes(ForShrinkableVectors<TestInterleaveLower>());
  ForAllTypes(ForShrinkableVectors<TestInterleaveUpper>());
  ForAllTypes(ForShrinkableVectors<TestInterleaveEven>());
  ForAllTypes(ForShrinkableVectors<TestInterleaveOdd>());
  ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveEven>());
  ForAllTypes(ForShrinkableVectors<TestMaskedInterleaveOdd>());
}

struct TestZipLower {