Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved swap_pair into its callable #2044

Merged
merged 6 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/eve/module/core/detail/simd/x86/basic_shuffle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ basic_shuffle_(EVE_SUPPORTS(avx_),
[](auto i, auto c)
{
Pattern r;
return (i < c / 2 ? r(i, c) : r(i - c / 2, c)) << 1;
return (r(i,c) % 2) << 1;
});

auto const m = as_indexes<wide<T, N>>(fixed_pattern);
Expand Down
6 changes: 3 additions & 3 deletions include/eve/module/core/regular/bit_shr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ namespace eve
template<typename Options>
struct bit_shr_t : strict_elementwise_callable<bit_shr_t, Options>
{
template<eve::value T, integral_value S>
constexpr EVE_FORCEINLINE as_wide_as_t<T, S> operator()(T v, S s) const
template<integral_value T, integral_value S>
constexpr EVE_FORCEINLINE as_wide_as_t<T, S> operator()(T v, S s) const
{
return EVE_DISPATCH_CALL(v, s);
}

template<eve::integral_value T, std::ptrdiff_t S>
template<integral_value T, std::ptrdiff_t S>
constexpr EVE_FORCEINLINE T operator()(T v, index_t<S> s) const
{
constexpr std::ptrdiff_t l = sizeof(element_type_t<T>) * 8;
Expand Down
45 changes: 38 additions & 7 deletions include/eve/module/core/regular/bit_swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,34 @@
#include <eve/module/core/regular/bit_cast.hpp>
#include <eve/module/core/regular/bit_xor.hpp>
#include <eve/module/core/regular/bit_shl.hpp>
#include <eve/module/core/regular/convert.hpp>
#include <eve/module/core/constant/one.hpp>
#include <eve/traits/max_lanes.hpp>

namespace eve
{

template<typename Options>
struct bit_swap_pairs_t : strict_elementwise_callable<bit_swap_pairs_t, Options>
{
template<eve::integral_value T, integral_value I0, integral_value I1>
constexpr EVE_FORCEINLINE T operator()(T v, I0 i0, I1 i1) const noexcept
template<typename T, typename I0, typename I1>
struct result
{
using type = std::conditional_t<scalar_value<T> && scalar_value<I0> && scalar_value<I1>, T, as_wide_t<T, max_lanes_t<T, I0, I1>>>;
};

template<integral_value T, integral_value I0, integral_value I1>
EVE_FORCEINLINE constexpr typename result<T, I0, I1>::type operator()(T v, I0 i0, I1 i1) const noexcept
requires same_lanes_or_scalar<T, I0, I1>
{
return EVE_DISPATCH_CALL(v, i0, i1);
}

template<integral_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE constexpr T operator()(T a, index_t<I0> i0, index_t<I1> i1) const noexcept
{
return EVE_DISPATCH_CALL(a, i0, i1);
}

EVE_CALLABLE_OBJECT(bit_swap_pairs_t, bit_swap_pairs_);
};
Expand Down Expand Up @@ -83,16 +98,24 @@ namespace eve
namespace detail
{
template<callable_options O, conditional_expr C, value T, integral_value I0, integral_value I1>
constexpr T bit_swap_pairs_(EVE_REQUIRES(cpu_), C const& cx, O const&, T a, I0 i0, I1 i1) noexcept
constexpr auto bit_swap_pairs_(EVE_REQUIRES(cpu_), C const& cx, O const&, T a, I0 i0, I1 i1) noexcept
{
auto i0m = if_else(cx, i0, zero);
auto i1m = if_else(cx, i1, zero);
if constexpr (scalar_value<T> && scalar_value<I0> && scalar_value<I1>)
{
return bit_swap_pairs(a, cx ? i0 : 0, cx ? i1 : 0);
}
else
{
using MC = max_lanes_t<T, I0, I1>;

return bit_swap_pairs(a, i0m, i1m);
auto i0m = if_else(cx, as_wide_t<I0, MC>{i0}, zero);
auto i1m = if_else(cx, as_wide_t<I1, MC>{i1}, zero);
return bit_swap_pairs(as_wide_t<T, MC>{a}, i0m, i1m);
}
}

template<callable_options O, value T, integral_value I0, integral_value I1>
constexpr T bit_swap_pairs_(EVE_REQUIRES(cpu_), O const&, T a, I0 i0, I1 i1) noexcept
constexpr auto bit_swap_pairs_(EVE_REQUIRES(cpu_), O const&, T a, I0 i0, I1 i1) noexcept
{
// 1 if the bits of a at i0 and i1 are different, 0 otherwise
auto x = bit_and(
Expand All @@ -106,5 +129,13 @@ namespace eve
// if the bits are different, swap them by toggling both
return bit_xor(a, bit_shl(x, i1), bit_shl(x, i0));
}

template<callable_options O, typename T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T bit_swap_pairs_(EVE_REQUIRES(cpu_), O const& o, T x, index_t<I0>, index_t<I1>) noexcept
{
constexpr std::ptrdiff_t C = sizeof(element_type_t<T>) * 8;
static_assert((I0 >= 0) && (I1 >= 0) && (I0 < C) && (I1 < C), "some index(es) are out or range");
return bit_swap_pairs[o](x, I0, I1);
}
}
}
6 changes: 4 additions & 2 deletions include/eve/module/core/regular/byte_swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ namespace eve
struct byte_swap_pairs_t : strict_elementwise_callable<byte_swap_pairs_t, Options>
{
template<integral_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T operator()(T a, index_t<I0> const & i0, index_t<I1> const & i1) const noexcept
{ return EVE_DISPATCH_CALL(a, i0, i1); }
EVE_FORCEINLINE T operator()(T a, index_t<I0> i0, index_t<I1> i1) const noexcept
{
return EVE_DISPATCH_CALL(a, i0, i1);
}

EVE_CALLABLE_OBJECT(byte_swap_pairs_t, byte_swap_pairs_);
};
Expand Down
13 changes: 2 additions & 11 deletions include/eve/module/core/regular/impl/byte_swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
namespace eve::detail
{
template<typename T, std::ptrdiff_t I0, std::ptrdiff_t I1, callable_options O>
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_),
O const &,
T x ,
index_t<I0> const & ,
index_t<I1> const &) noexcept
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_), O const &, T x, index_t<I0>, index_t<I1>) noexcept
{
if constexpr(simd_value<T>)
{
Expand Down Expand Up @@ -54,12 +50,7 @@ namespace eve::detail

// Masked case
template<conditional_expr C, typename T, std::ptrdiff_t I0, std::ptrdiff_t I1, callable_options O>
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_),
C const& cond,
O const &,
T t,
index_t<I0> const & i0,
index_t<I1> const & i1) noexcept
EVE_FORCEINLINE T byte_swap_pairs_(EVE_REQUIRES(cpu_), C const& cond, O const&, T t, index_t<I0> i0, index_t<I1> i1) noexcept
{
return mask_op(cond, eve::byte_swap_pairs, t, i0, i1);
}
Expand Down
28 changes: 10 additions & 18 deletions include/eve/module/core/regular/impl/swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,16 @@

namespace eve::detail
{
template<value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T
swap_pairs_(EVE_SUPPORTS(cpu_), T x
, index_t<I0> const &
, index_t<I1> const & ) noexcept
template<callable_options O, simd_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE constexpr T swap_pairs_(EVE_REQUIRES(cpu_), O const&, T x, index_t<I0>, index_t<I1>) noexcept
{
[[maybe_unused]] constexpr std::ptrdiff_t C = scalar_value<T> ? 1 : cardinal_v<T>;
EVE_ASSERT((I0 >= 0) && (I1 >= 0) && (I0 < C) && (I1 < C), "some index(es) are out or range");
if constexpr(simd_value<T>)
{
auto p = [](auto i, auto){
return (i == I0) ? I1 :(i == I1 ? I0 : i) ;
};
return eve::shuffle(x, eve::as_pattern(p));
}
else if constexpr(scalar_value<T>)
{
return x;
}
constexpr std::ptrdiff_t C = cardinal_v<T>;
static_assert((I0 >= 0) && (I1 >= 0) && (I0 < C) && (I1 < C), "some index(es) are out or range");

auto p = [](auto i, auto){
return (i == I0) ? I1 :(i == I1 ? I0 : i) ;
};

return eve::shuffle(x, eve::as_pattern(p));
}
}
92 changes: 51 additions & 41 deletions include/eve/module/core/regular/swap_pairs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,57 @@

namespace eve
{
//TODO DOC
//================================================================================================
//! @addtogroup core_bitops
//! @{
//! @var swap_pairs
//! @brief swap chosen pair of elements.
//!
//! @groupheader{Header file}
//!
//! @code
//! #include <eve/module/core.hpp>
//! @endcode
//!
//! @groupheader{Callable Signatures}
//!
//! @code
//! namespace eve
//! {
//! template<value T, std::ptrdiff_t I0, std::ptrdiff_t I1 >
//! T swap_pairs(T x, index_t<I0> const & i0, index_t<I1> const & i1);
//! @endcode
//!
//! **Parameters**
//!
//! * `x` : [argument](@ref eve::value).
//! * `i0` : first index
//! * `i1` : second index
//!
//! **Return value**
//!
//! Return x with element i0 and i1 swapped. Action on scalar is identity.
//! Assert if i0 or i1 are out of range.
//!
//! @groupheader{Example}
//!
//! @godbolt{doc/core/swap_pairs.cpp}
//================================================================================================
EVE_MAKE_CALLABLE(swap_pairs_, swap_pairs);
//================================================================================================
//! @}
//================================================================================================
template<typename Options>
struct swap_pairs_t : callable<swap_pairs_t, Options>
{
template<simd_value T, std::ptrdiff_t I0, std::ptrdiff_t I1>
EVE_FORCEINLINE T operator()(T x, index_t<I0> i0, index_t<I1> i1) const noexcept
{
return EVE_DISPATCH_CALL(x, i0, i1);
}

EVE_CALLABLE_OBJECT(swap_pairs_t, swap_pairs_);
};

//================================================================================================
//! @addtogroup core_bitops
//! @{
//! @var swap_pairs
//! @brief swap chosen pair of elements.
//!
//! @groupheader{Header file}
//!
//! @code
//! #include <eve/module/core.hpp>
//! @endcode
//!
//! @groupheader{Callable Signatures}
//!
//! @code
//! namespace eve
//! {
//! template<simd_value T, std::ptrdiff_t I0, std::ptrdiff_t I1 >
//! T swap_pairs(T x, index_t<I0> i0, index_t<I1> i1);
//! @endcode
//!
//! **Parameters**
//!
//! * `x` : [argument](@ref eve::simd_value).
//! * `i0` : first index
//! * `i1` : second index
//!
//! **Return value**
//!
//! Return x with element i0 and i1 swapped.
//!
//! @groupheader{Example}
//!
//! @godbolt{doc/core/swap_pairs.cpp}
//================================================================================================
inline constexpr auto swap_pairs = functor<swap_pairs_t>;
//================================================================================================
//! @}
//================================================================================================
}

#include <eve/module/core/regular/impl/swap_pairs.hpp>
51 changes: 51 additions & 0 deletions include/eve/traits/max_lanes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//==================================================================================================
/*
EVE - Expressive Vector Engine
Copyright : EVE Project Contributors
SPDX-License-Identifier: BSL-1.0
*/
//==================================================================================================
#pragma once

namespace eve
{
namespace detail
{
template<typename... Ts>
consteval auto compute_max_lanes()
{
std::ptrdiff_t cards[] = { cardinal_v<Ts>... };

auto max_card = cards[0];
for(auto c : cards) max_card = max_card < c ? c : max_card;

return max_card;
}
}

//================================================================================================
//! @addtogroup traits
//! @{
//! @var max_lanes
//!
//! @tparam Ts Types to process
//!
//! @brief A meta function for getting a maximum lane count of given wide or scalar types.
//! @}
//================================================================================================
template <typename... Ts>
inline constexpr auto max_lanes_v = detail::compute_max_lanes<Ts...>();

//================================================================================================
//! @addtogroup traits
//! @{
//! @var max_lanes
//!
//! @tparam Ts Types to process
//!
//! @brief The cardinal type of the maximum lane count of given wide or scalar types.
//! @}
//================================================================================================
template <typename... Ts>
using max_lanes_t = fixed<max_lanes_v<Ts...>>;
}
14 changes: 12 additions & 2 deletions test/doc/core/bit_swap_pairs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@ int main()
std::cout << std::showbase << std::hex;
std::cout << "<- wi0 = " << wi0 << "\n";
std::cout << "<- wi1 = " << wi1 << "\n";
std::cout << "<- wi2 = " << wi2 << "\n";
std::cout << "<- wi2 = " << wi2 << "\n\n";

std::cout << "-> bit_swap_pairs(wi0, wi1, wi2) = " << eve::bit_swap_pairs(wi0, wi1, wi2) << "\n";
std::cout << "-> bit_swap_pairs[ignore_last(2)](wi0, wi1, wi2) = " << eve::bit_swap_pairs[eve::ignore_last(2)](wi0, wi1, wi2) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, wi1, wi3) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, wi1, wi3) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, wi1, wi3) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, wi1, wi3) << "\n\n";

std::cout << "-> bit_swap_pairs(wi0, 3, 2) = " << eve::bit_swap_pairs(wi0, 3, 2) << "\n";
std::cout << "-> bit_swap_pairs[ignore_last(2)](wi0, 3, 2) = " << eve::bit_swap_pairs[eve::ignore_last(2)](wi0, 3, 2) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, 3, 2) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, 3, 2) << "\n\n";

auto i3 = eve::index_t<3>{};
auto i2 = eve::index_t<2>{};
std::cout << "-> bit_swap_pairs(wi0, i3, i2) = " << eve::bit_swap_pairs(wi0, i3, i2) << "\n";
std::cout << "-> bit_swap_pairs[ignore_last(2)](wi0, i3, i2) = " << eve::bit_swap_pairs[eve::ignore_last(2)](wi0, i3, i2) << "\n";
std::cout << "-> bit_swap_pairs[wi3 > 0](wi0, i3, i2) = " << eve::bit_swap_pairs[wi3 >= 0](wi0, i3, i2) << "\n\n";
}
17 changes: 15 additions & 2 deletions test/unit/module/core/bit_swap_pairs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,21 @@ TTS_CASE_WITH("Check behavior of bit_swap_pairs(simd) on integral types",
using v_t = eve::element_type_t<T>;
using eve::bit_swap_pairs;

TTS_EQUAL(bit_swap_pairs(a0, 0u, 7u), tts::map([](auto e) -> v_t { return eve::bit_swap_pairs(e, 0u, 7u); }, a0)) << a0 << '\n';
TTS_EQUAL(eve::bit_swap_pairs[t](a0, 0u, 7u), eve::if_else(t, eve::bit_swap_pairs(a0, 0u, 7u), a0));
// full scalar
TTS_EQUAL(eve::bit_swap_pairs(0b01010101, 0, 7), 0b11010100);
TTS_EQUAL(eve::bit_swap_pairs(0b01010101, eve::index<0>, eve::index<7>), 0b11010100);

// scalar into wide
using wt = eve::wide<int, eve::fixed<4>>;
TTS_EQUAL(eve::bit_swap_pairs(0b01010101, wt{0}, wt{7}), wt{0b11010100});

// wide
TTS_EQUAL(bit_swap_pairs(a0, 0u, 7), tts::map([](auto e) -> v_t { return eve::bit_swap_pairs(e, 0, 7u); }, a0)) << a0 << '\n';
TTS_EQUAL(bit_swap_pairs(a0, eve::index<0>, eve::index<7>), tts::map([](auto e) -> v_t { return eve::bit_swap_pairs(e, eve::index<0>, eve::index<7>); }, a0)) << a0 << '\n';

// wide masked
TTS_EQUAL(eve::bit_swap_pairs[t](a0, 0u, 7), eve::if_else(t, eve::bit_swap_pairs(a0, 0, 7u), a0));
TTS_EQUAL(eve::bit_swap_pairs[t](a0, eve::index<0>, eve::index<7>), eve::if_else(t, eve::bit_swap_pairs(a0, eve::index<0>, eve::index<7>), a0));

eve::wide<int, typename T::cardinal_type> wn{[](auto i, auto) { return -i; }};
TTS_EQUAL(eve::bit_swap_pairs[wn > 0](a0, wn, 7), a0);
Expand Down
Loading
Loading