diff --git a/Eigen/Core b/Eigen/Core index 1a9b4700b..1e7e38cb1 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -186,7 +186,6 @@ using std::ptrdiff_t; // Generic half float support #include "src/Core/arch/Default/Half.h" #include "src/Core/arch/Default/BFloat16.h" -#include "src/Core/arch/Default/TypeCasting.h" #include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h" #if defined EIGEN_VECTORIZE_AVX512 diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index e233efba1..32371c55b 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -621,6 +621,207 @@ protected: Data m_d; }; +// ----------------------- Casting --------------------- + +template +struct unary_evaluator, ArgType>, IndexBased> { + using CastOp = core_cast_op; + using XprType = CwiseUnaryOp; + + // Use the largest packet type by default + using SrcPacketType = typename packet_traits::type; + static constexpr int SrcPacketSize = unpacket_traits::size; + static constexpr int SrcPacketBytes = SrcPacketSize * sizeof(SrcType); + + enum { + CoeffReadCost = int(evaluator::CoeffReadCost) + int(functor_traits::Cost), + PacketAccess = functor_traits::PacketAccess, + ActualPacketAccessBit = PacketAccess ? 
PacketAccessBit : 0, + Flags = evaluator::Flags & (HereditaryBits | LinearAccessBit | ActualPacketAccessBit), + IsRowMajor = (evaluator::Flags & RowMajorBit), + Alignment = evaluator::Alignment + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& xpr) + : m_argImpl(xpr.nestedExpression()), m_rows(xpr.rows()), m_cols(xpr.cols()) { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + template + using AltSrcScalarOp = std::enable_if_t<(unpacket_traits::size < SrcPacketSize && !find_packet_by_size::size>::value), bool>; + template + using SrcPacketArgs1 = std::enable_if_t<(find_packet_by_size::size>::value), bool>; + template + using SrcPacketArgs2 = std::enable_if_t<(unpacket_traits::size) == (2 * SrcPacketSize), bool>; + template + using SrcPacketArgs4 = std::enable_if_t<(unpacket_traits::size) == (4 * SrcPacketSize), bool>; + template + using SrcPacketArgs8 = std::enable_if_t<(unpacket_traits::size) == (8 * SrcPacketSize), bool>; + + template = true> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_array_bounds(Index, Index col, Index packetSize) const { + return col + packetSize <= cols(); + } + template = true> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_array_bounds(Index row, Index, Index packetSize) const { + return row + packetSize <= rows(); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_array_bounds(Index index, Index packetSize) const { + return index + packetSize <= size(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SrcType srcCoeff(Index row, Index col, Index offset) const { + Index actualRow = IsRowMajor ? row : row + offset; + Index actualCol = IsRowMajor ? 
col + offset : col; + return m_argImpl.coeff(actualRow, actualCol); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SrcType srcCoeff(Index index, Index offset) const { + Index actualIndex = index + offset; + return m_argImpl.coeff(actualIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index row, Index col) const { + return cast(srcCoeff(row, col, 0)); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index index) const { return cast(srcCoeff(index, 0)); } + + template + EIGEN_STRONG_INLINE PacketType srcPacket(Index row, Index col, Index offset) const { + constexpr int PacketSize = unpacket_traits::size; + Index actualRow = IsRowMajor ? row : row + (offset * PacketSize); + Index actualCol = IsRowMajor ? col + (offset * PacketSize) : col; + eigen_assert(check_array_bounds(actualRow, actualCol, PacketSize) && "Array index out of bounds"); + return m_argImpl.template packet(actualRow, actualCol); + } + template + EIGEN_STRONG_INLINE PacketType srcPacket(Index index, Index offset) const { + constexpr int PacketSize = unpacket_traits::size; + Index actualIndex = index + (offset * PacketSize); + eigen_assert(check_array_bounds(actualIndex, PacketSize) && "Array index out of bounds"); + return m_argImpl.template packet(actualIndex); + } + + // There is no source packet type with equal or fewer elements than DstPacketType. + // This is problematic as the evaluation loop may attempt to access data outside the bounds of the array. + // For example, consider the cast utilizing pcast with an array of size 4: {0.0f,1.0f,2.0f,3.0f}. + // The first iteration of the evaluation loop will load 16 bytes: {0.0f,1.0f,2.0f,3.0f} and cast to {0.0,1.0}, which is acceptable. + // The second iteration will load 16 bytes: {2.0f,3.0f,?,?}, which is outside the bounds of the array. + + // Instead, perform runtime check to determine if the load would access data outside the bounds of the array. + // If not, perform full load. 
Otherwise, revert to a scalar loop to perform a partial load. + // In either case, perform a vectorized cast of the source packet. + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int DstPacketSize = unpacket_traits::size; + constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType); + constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode); + SrcPacketType src; + if (EIGEN_PREDICT_TRUE(check_array_bounds(row, col, SrcPacketSize))) { + src = srcPacket(row, col, 0); + } else { + Array srcArray; + for (size_t k = 0; k < DstPacketSize; k++) srcArray[k] = srcCoeff(row, col, k); + for (size_t k = DstPacketSize; k < SrcPacketSize; k++) srcArray[k] = SrcType(0); + src = pload(srcArray.data()); + } + return pcast(src); + } + // Use the source packet type with the same size as DstPacketType, if it exists + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int DstPacketSize = unpacket_traits::size; + using SizedSrcPacketType = typename find_packet_by_size::type; + constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType); + constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode); + return pcast( + srcPacket(row, col, 0)); + } + // unpacket_traits::size == 2 * SrcPacketSize + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode); + return pcast( + srcPacket(row, col, 0), srcPacket(row, col, 1)); + } + // unpacket_traits::size == 4 * SrcPacketSize + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode); + return pcast( + srcPacket(row, col, 0), srcPacket(row, col, 1), + srcPacket(row, col, 2), srcPacket(row, col, 3)); + } + // unpacket_traits::size == 8 * SrcPacketSize + template = true> + EIGEN_STRONG_INLINE 
DstPacketType packet(Index row, Index col) const { + constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode); + return pcast( + srcPacket(row, col, 0), srcPacket(row, col, 1), + srcPacket(row, col, 2), srcPacket(row, col, 3), + srcPacket(row, col, 4), srcPacket(row, col, 5), + srcPacket(row, col, 6), srcPacket(row, col, 7)); + } + + // Analogous routines for linear access. + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int DstPacketSize = unpacket_traits::size; + constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType); + constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode); + SrcPacketType src; + if (EIGEN_PREDICT_TRUE(check_array_bounds(index, SrcPacketSize))) { + src = srcPacket(index, 0); + } else { + Array srcArray; + for (size_t k = 0; k < DstPacketSize; k++) srcArray[k] = srcCoeff(index, k); + for (size_t k = DstPacketSize; k < SrcPacketSize; k++) srcArray[k] = SrcType(0); + src = pload(srcArray.data()); + } + return pcast(src); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int DstPacketSize = unpacket_traits::size; + using SizedSrcPacketType = typename find_packet_by_size::type; + constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType); + constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode); + return pcast( + srcPacket(index, 0)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode); + return pcast( + srcPacket(index, 0), srcPacket(index, 1)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode); + return pcast( + srcPacket(index, 0), srcPacket(index, 1), + srcPacket(index, 2), srcPacket(index, 3)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const 
{ + constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode); + return pcast( + srcPacket(index, 0), srcPacket(index, 1), + srcPacket(index, 2), srcPacket(index, 3), + srcPacket(index, 4), srcPacket(index, 5), + srcPacket(index, 6), srcPacket(index, 7)); + } + + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rows() const { return m_rows; } + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index cols() const { return m_cols; } + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_rows * m_cols; } + + protected: + const evaluator m_argImpl; + const variable_if_dynamic m_rows; + const variable_if_dynamic m_cols; +}; + // -------------------- CwiseTernaryOp -------------------- // this is a ternary expression diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 8cb80bb81..bfc7ae68a 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -146,14 +146,67 @@ template struct unpacket_traits template struct unpacket_traits : unpacket_traits { }; -template struct type_casting_traits { +/** \internal A convenience utility for determining if the type is a scalar. + * This is used to enable some generic packet implementations. + */ +template +struct is_scalar { + using Scalar = typename unpacket_traits::type; + enum { value = internal::is_same::value }; +}; + +// automatically and succinctly define combinations of pcast when +// 1) the packets are the same type, or +// 2) the packets differ only in sign. 
+// In both of these cases, preinterpret (bit_cast) is equivalent to pcast (static_cast) +template ::value && is_scalar::value> +struct is_degenerate_helper : is_same {}; +template <> +struct is_degenerate_helper : std::true_type {}; +template <> +struct is_degenerate_helper : std::true_type {}; +template <> +struct is_degenerate_helper : std::true_type {}; +template <> +struct is_degenerate_helper : std::true_type {}; + +template +struct is_degenerate_helper { + using SrcScalar = typename unpacket_traits::type; + static constexpr int SrcSize = unpacket_traits::size; + using TgtScalar = typename unpacket_traits::type; + static constexpr int TgtSize = unpacket_traits::size; + static constexpr bool value = is_degenerate_helper::value && (SrcSize == TgtSize); +}; + +// is_degenerate::value == is_degenerate::value +template +struct is_degenerate { + static constexpr bool value = + is_degenerate_helper::value || is_degenerate_helper::value; +}; + +template +struct is_half { + using Scalar = typename unpacket_traits::type; + static constexpr int Size = unpacket_traits::size; + using DefaultPacket = typename packet_traits::type; + static constexpr int DefaultSize = unpacket_traits::size; + static constexpr bool value = Size < DefaultSize; +}; + +template +struct type_casting_traits { enum { - VectorizedCast = 0, + VectorizedCast = + is_degenerate::value && packet_traits::Vectorizable && packet_traits::Vectorizable, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; }; + /** \internal Wrapper to ensure that multiple packet types can map to the same same underlying vector type. */ template @@ -171,45 +224,84 @@ struct eigen_packet_wrapper T m_val; }; +template ::value> +struct preinterpret_generic; -/** \internal A convenience utility for determining if the type is a scalar. - * This is used to enable some generic packet implementations. 
- */ -template -struct is_scalar { - using Scalar = typename unpacket_traits::type; - enum { - value = internal::is_same::value - }; +template +struct preinterpret_generic { + // the packets are not the same, attempt scalar bit_cast + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Target run(const Packet& a) { + return numext::bit_cast(a); + } }; -/** \internal \returns static_cast(a) (coeff-wise) */ -template -EIGEN_DEVICE_FUNC inline TgtPacket -pcast(const SrcPacket& a) { - return static_cast(a); -} -template -EIGEN_DEVICE_FUNC inline TgtPacket -pcast(const SrcPacket& a, const SrcPacket& /*b*/) { - return static_cast(a); -} -template -EIGEN_DEVICE_FUNC inline TgtPacket -pcast(const SrcPacket& a, const SrcPacket& /*b*/, const SrcPacket& /*c*/, const SrcPacket& /*d*/) { - return static_cast(a); -} -template -EIGEN_DEVICE_FUNC inline TgtPacket -pcast(const SrcPacket& a, const SrcPacket& /*b*/, const SrcPacket& /*c*/, const SrcPacket& /*d*/, - const SrcPacket& /*e*/, const SrcPacket& /*f*/, const SrcPacket& /*g*/, const SrcPacket& /*h*/) { - return static_cast(a); -} +template +struct preinterpret_generic { + // the packets are the same type: do nothing + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& a) { return a; } +}; /** \internal \returns reinterpret_cast(a) */ template -EIGEN_DEVICE_FUNC inline Target -preinterpret(const Packet& a); /* { return reinterpret_cast(a); } */ +EIGEN_DEVICE_FUNC inline Target preinterpret(const Packet& a) { + return preinterpret_generic::run(a); +} + +template ::value, bool TgtIsHalf = is_half::value> +struct pcast_generic; + +template +struct pcast_generic { + // the packets are not degenerate: attempt scalar static_cast + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket run(const SrcPacket& a) { + return cast_impl::run(a); + } +}; + +template +struct pcast_generic { + // the packets are the same: do nothing + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& a) { return a; } +}; + 
+template +struct pcast_generic { + // the packets are degenerate: preinterpret is equivalent to pcast + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket run(const SrcPacket& a) { return preinterpret(a); } +}; + + + +/** \internal \returns static_cast(a) (coeff-wise) */ +template +EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a) { + return pcast_generic::run(a); +} +template +EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a, const SrcPacket& b) { + return pcast_generic::run(a, b); +} +template +EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a, const SrcPacket& b, const SrcPacket& c, + const SrcPacket& d) { + return pcast_generic::run(a, b, c, d); +} +template +EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a, const SrcPacket& b, const SrcPacket& c, const SrcPacket& d, + const SrcPacket& e, const SrcPacket& f, const SrcPacket& g, + const SrcPacket& h) { + return pcast_generic::run(a, b, c, d, e, f, g, h); +} + +template +struct pcast_generic { + // TgtPacket is a half packet of some other type + // perform cast and truncate result + using DefaultTgtPacket = typename is_half::DefaultPacket; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket run(const SrcPacket& a) { + return preinterpret(pcast(a)); + } +}; /** \internal \returns a + b (coeff-wise) */ template EIGEN_DEVICE_FUNC inline Packet diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 40ee3f551..d7851e319 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -430,6 +430,13 @@ struct cast_impl } }; +template +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline bool run(const OldType& x) { return x != OldType(0); } +}; + + // Casting from S -> Complex leads to an implicit conversion from S to T, // generating warnings on clang. Here we explicitly cast the real component. 
template diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h index 386543e66..461f3a637 100644 --- a/Eigen/src/Core/arch/AVX/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX/TypeCasting.h @@ -80,6 +80,14 @@ template<> EIGEN_STRONG_INLINE Packet8i pcast(const Packet4d return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a)); } +template <> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4d& a) { + return _mm256_cvtpd_ps(a); +} + +template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet4d& a) { + return _mm256_cvttpd_epi32(a); +} + template <> EIGEN_STRONG_INLINE Packet16b pcast(const Packet8f& a, const Packet8f& b) { @@ -118,6 +126,44 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret(const Pa return _mm256_castsi256_ps(a); } +template<> EIGEN_STRONG_INLINE Packet8ui preinterpret(const Packet8i& a) { + return Packet8ui(a); +} + +template<> EIGEN_STRONG_INLINE Packet8i preinterpret(const Packet8ui& a) { + return Packet8i(a); +} + +// truncation operations + +template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet8f& a) { + return _mm256_castps256_ps128(a); +} + +template<> EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet4d& a) { + return _mm256_castpd256_pd128(a); +} + +template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet8i& a) { + return _mm256_castsi256_si128(a); +} + +template<> EIGEN_STRONG_INLINE Packet4ui preinterpret(const Packet8ui& a) { + return _mm256_castsi256_si128(a); +} + + +#ifdef EIGEN_VECTORIZE_AVX2 +template<> EIGEN_STRONG_INLINE Packet4ul preinterpret(const Packet4l& a) { + return Packet4ul(a); +} + +template<> EIGEN_STRONG_INLINE Packet4l preinterpret(const Packet4ul& a) { + return Packet4l(a); +} + +#endif + template<> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8h& a) { return half2float(a); } diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index 02e633552..2f38d7f80 100644 --- 
a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -59,6 +59,13 @@ template<> EIGEN_STRONG_INLINE Packet16i pcast(const Packet return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b)); } +template<> EIGEN_STRONG_INLINE Packet8i pcast(const Packet8d& a) { + return _mm512_cvttpd_epi32(a); /* truncate toward zero: pcast must match static_cast, like the two-packet overload above */ +} +template<> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8d& a) { + return _mm512_cvtpd_ps(a); +} + template<> EIGEN_STRONG_INLINE Packet16i preinterpret(const Packet16f& a) { return _mm512_castps_si512(a); } @@ -107,12 +114,19 @@ template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const P return _mm512_castpd128_pd512(a); } -template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet16f& a) { - return a; +template<> EIGEN_STRONG_INLINE Packet8i preinterpret(const Packet16i& a) { + return _mm512_castsi512_si256(a); +} +template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet16i& a) { + return _mm512_castsi512_si128(a); } -template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet8d& a) { - return a; +template<> EIGEN_STRONG_INLINE Packet8h preinterpret(const Packet16h& a) { + return _mm256_castsi256_si128(a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf preinterpret(const Packet16bf& a) { + return _mm256_castsi256_si128(a); } #ifndef EIGEN_VECTORIZE_AVX512FP16 @@ -191,6 +205,13 @@ struct type_casting_traits { }; }; +template<> EIGEN_STRONG_INLINE Packet16h preinterpret(const Packet32h& a) { + return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0)); +} +template<> EIGEN_STRONG_INLINE Packet8h preinterpret(const Packet32h& a) { + return _mm256_castsi256_si128(preinterpret(a)); +} + template <> EIGEN_STRONG_INLINE Packet16f pcast(const Packet32h& a) { // Discard second-half of input. 
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index c8ca33a5b..17ce13581 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -1014,4 +1014,49 @@ struct hash { } // end namespace std #endif +namespace Eigen { +namespace internal { + +template <> +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline half run(const float& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __float2half(a); +#else + return half(a); +#endif + } +}; + +template <> +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline half run(const int& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __float2half(static_cast(a)); +#else + return half(static_cast(a)); +#endif + } +}; + +template <> +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline float run(const half& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __half2float(a); +#else + return static_cast(a); +#endif + } +}; + +} // namespace internal +} // namespace Eigen + #endif // EIGEN_HALF_H diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h deleted file mode 100644 index dc779a725..000000000 --- a/Eigen/src/Core/arch/Default/TypeCasting.h +++ /dev/null @@ -1,116 +0,0 @@ -// This file is part of Eigen, a lightweight C++ template library -// for linear algebra. -// -// Copyright (C) 2016 Benoit Steiner -// Copyright (C) 2019 Rasmus Munk Larsen -// -// This Source Code Form is subject to the terms of the Mozilla -// Public License v. 2.0. 
If a copy of the MPL was not distributed -// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#ifndef EIGEN_GENERIC_TYPE_CASTING_H -#define EIGEN_GENERIC_TYPE_CASTING_H - -#include "../../InternalHeaderCheck.h" - -namespace Eigen { - -namespace internal { - -template<> -struct scalar_cast_op { - typedef Eigen::half result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const { - #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __float2half(a); - #else - return Eigen::half(a); - #endif - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef Eigen::half result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const { - #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __float2half(static_cast(a)); - #else - return Eigen::half(static_cast(a)); - #endif - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef float result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const { - #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __half2float(a); - #else - return static_cast(a); - #endif - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef Eigen::bfloat16 result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const { - return 
Eigen::bfloat16(a); - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef Eigen::bfloat16 result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const { - return Eigen::bfloat16(static_cast(a)); - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef float result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const { - return static_cast(a); - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -} -} - -#endif // EIGEN_GENERIC_TYPE_CASTING_H diff --git a/Eigen/src/Core/arch/NEON/TypeCasting.h b/Eigen/src/Core/arch/NEON/TypeCasting.h index add31b917..834fcf53e 100644 --- a/Eigen/src/Core/arch/NEON/TypeCasting.h +++ b/Eigen/src/Core/arch/NEON/TypeCasting.h @@ -17,6 +17,61 @@ namespace Eigen { namespace internal { + +//============================================================================== +// preinterpret (truncation operations) +//============================================================================== + +template <> +EIGEN_STRONG_INLINE Packet8c preinterpret(const Packet16c& a) { + return Packet8c(vget_low_s8(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4c preinterpret(const Packet8c& a) { + return Packet4c(vget_lane_s32(vreinterpret_s32_s8(a), 0)); +} +template <> +EIGEN_STRONG_INLINE Packet4c preinterpret(const Packet16c& a) { + return preinterpret(preinterpret(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet8uc preinterpret(const Packet16uc& a) { + return Packet8uc(vget_low_u8(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc preinterpret(const Packet8uc& a) { + return Packet4uc(vget_lane_u32(vreinterpret_u32_u8(a), 0)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc preinterpret(const 
Packet16uc& a) { + return preinterpret(preinterpret(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet4s preinterpret(const Packet8s& a) { + return Packet4s(vget_low_s16(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet4us preinterpret(const Packet8us& a) { + return Packet4us(vget_low_u16(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet2i preinterpret(const Packet4i& a) { + return Packet2i(vget_low_s32(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2ui preinterpret(const Packet4ui& a) { + return Packet2ui(vget_low_u32(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet2f preinterpret(const Packet4f& a) { + return Packet2f(vget_low_f32(a)); +} + //============================================================================== // preinterpret //============================================================================== @@ -37,6 +92,7 @@ EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet4ui& return Packet4f(vreinterpretq_f32_u32(a)); } + template <> EIGEN_STRONG_INLINE Packet4c preinterpret(const Packet4uc& a) { return static_cast(a); @@ -50,6 +106,7 @@ EIGEN_STRONG_INLINE Packet16c preinterpret(const Packet16 return Packet16c(vreinterpretq_s8_u8(a)); } + template <> EIGEN_STRONG_INLINE Packet4uc preinterpret(const Packet4c& a) { return static_cast(a); @@ -71,7 +128,6 @@ template <> EIGEN_STRONG_INLINE Packet8s preinterpret(const Packet8us& a) { return Packet8s(vreinterpretq_s16_u16(a)); } - template <> EIGEN_STRONG_INLINE Packet4us preinterpret(const Packet4s& a) { return Packet4us(vreinterpret_u16_s16(a)); @@ -127,18 +183,7 @@ EIGEN_STRONG_INLINE Packet2ul preinterpret(const Packet2l& //============================================================================== // pcast, SrcType = float //============================================================================== -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet4f pcast(const Packet4f& a) { - 
return a; -} -template <> -EIGEN_STRONG_INLINE Packet2f pcast(const Packet2f& a) { - return a; -} + template <> struct type_casting_traits { @@ -156,10 +201,18 @@ EIGEN_STRONG_INLINE Packet2l pcast(const Packet4f& a) { return vcvtq_s64_f64(vcvt_f64_f32(vget_low_f32(a))); } template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2f& a) { + return vcvtq_s64_f64(vcvt_f64_f32(a)); +} +template <> EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4f& a) { // Discard second half of input. return vcvtq_u64_f64(vcvt_f64_f32(vget_low_f32(a))); } +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2f& a) { + return vcvtq_u64_f64(vcvt_f64_f32(a)); +} #else template <> EIGEN_STRONG_INLINE Packet2l pcast(const Packet4f& a) { @@ -167,10 +220,19 @@ EIGEN_STRONG_INLINE Packet2l pcast(const Packet4f& a) { return vmovl_s32(vget_low_s32(vcvtq_s32_f32(a))); } template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2f& a) { + return vmovl_s32(vcvt_s32_f32(a)); +} +template <> EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4f& a) { // Discard second half of input. return vmovl_u32(vget_low_u32(vcvtq_u32_f32(a))); } +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2f& a) { + // Discard second half of input. 
+ return vmovl_u32(vcvt_u32_f32(a)); +} #endif // EIGEN_ARCH_ARM64 template <> @@ -208,6 +270,10 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet4f& a, const return vcombine_s16(vmovn_s32(vcvtq_s32_f32(a)), vmovn_s32(vcvtq_s32_f32(b))); } template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4f& a) { + return vmovn_s32(vcvtq_s32_f32(a)); +} +template <> EIGEN_STRONG_INLINE Packet4s pcast(const Packet2f& a, const Packet2f& b) { return vmovn_s32(vcombine_s32(vcvt_s32_f32(a), vcvt_s32_f32(b))); } @@ -221,6 +287,10 @@ EIGEN_STRONG_INLINE Packet8us pcast(const Packet4f& a, cons return vcombine_u16(vmovn_u32(vcvtq_u32_f32(a)), vmovn_u32(vcvtq_u32_f32(b))); } template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4f& a) { + return vmovn_u32(vcvtq_u32_f32(a)); +} +template <> EIGEN_STRONG_INLINE Packet4us pcast(const Packet2f& a, const Packet2f& b) { return vmovn_u32(vcombine_u32(vcvt_u32_f32(a), vcvt_u32_f32(b))); } @@ -237,12 +307,25 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet4f& a, cons return vcombine_s8(vmovn_s16(ab_s16), vmovn_s16(cd_s16)); } template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet4f& a, const Packet4f& b) { + const int16x8_t ab_s16 = pcast(a, b); + return vmovn_s16(ab_s16); +} +template <> EIGEN_STRONG_INLINE Packet8c pcast(const Packet2f& a, const Packet2f& b, const Packet2f& c, const Packet2f& d) { const int16x4_t ab_s16 = pcast(a, b); const int16x4_t cd_s16 = pcast(c, d); return vmovn_s16(vcombine_s16(ab_s16, cd_s16)); } +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4f& a) { + const int32x4_t a_s32x4 = vcvtq_s32_f32(a); + const int16x4_t a_s16x4 = vmovn_s32(a_s32x4); + const int16x8_t aa_s16x8 = vcombine_s16(a_s16x4, a_s16x4); + const int8x8_t aa_s8x8 = vmovn_s16(aa_s16x8); + return vget_lane_s32(vreinterpret_s32_s8(aa_s8x8), 0); +} template <> struct type_casting_traits { @@ -251,16 +334,20 @@ struct type_casting_traits { template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4f& a, const Packet4f& 
b, const Packet4f& c, const Packet4f& d) { - const uint16x8_t ab_u16 = pcast(a, b); - const uint16x8_t cd_u16 = pcast(c, d); - return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16)); + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4f& a, const Packet4f& b) { + return preinterpret(pcast(a, b)); } template <> EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2f& a, const Packet2f& b, const Packet2f& c, const Packet2f& d) { - const uint16x4_t ab_u16 = pcast(a, b); - const uint16x4_t cd_u16 = pcast(c, d); - return vmovn_u16(vcombine_u16(ab_u16, cd_u16)); + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4f& a) { + return static_cast(pcast(a)); } //============================================================================== @@ -276,6 +363,10 @@ EIGEN_STRONG_INLINE Packet4f pcast(const Packet16c& a) { return vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a))))); } template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4c& a) { + return vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vreinterpret_s8_s32(vdup_n_s32(a)))))); +} +template <> EIGEN_STRONG_INLINE Packet2f pcast(const Packet8c& a) { // Discard all but first 2 bytes. return vcvt_f32_s32(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(a))))); @@ -310,11 +401,20 @@ EIGEN_STRONG_INLINE Packet4i pcast(const Packet16c& a) { return vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a)))); } template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet8c& a) { + return vmovl_s16(vget_low_s16(vmovl_s8(a))); +} +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4c& a) { + return pcast(vreinterpret_s8_s32(vdup_n_s32(a))); +} +template <> EIGEN_STRONG_INLINE Packet2i pcast(const Packet8c& a) { // Discard all but first 2 bytes. 
return vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(a)))); } + template <> struct type_casting_traits { enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; @@ -327,6 +427,10 @@ template <> EIGEN_STRONG_INLINE Packet2ui pcast(const Packet8c& a) { return preinterpret(pcast(a)); } +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4c& a) { + return preinterpret(pcast(a)); +} template <> struct type_casting_traits { @@ -338,10 +442,18 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet16c& a) { return vmovl_s8(vget_low_s8(a)); } template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet8c& a) { + return vmovl_s8(a); +} +template <> EIGEN_STRONG_INLINE Packet4s pcast(const Packet8c& a) { // Discard second half of input. return vget_low_s16(vmovl_s8(a)); } +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4c& a) { + return pcast(vreinterpret_s8_s32(vdup_n_s32(a))); +} template <> struct type_casting_traits { @@ -352,43 +464,18 @@ EIGEN_STRONG_INLINE Packet8us pcast(const Packet16c& a) { return preinterpret(pcast(a)); } template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet8c& a) { + return preinterpret(pcast(a)); +} +template <> EIGEN_STRONG_INLINE Packet4us pcast(const Packet8c& a) { return preinterpret(pcast(a)); } - template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet16c pcast(const Packet16c& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet8c pcast(const Packet8c& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet4c pcast(const Packet4c& a) { - return a; +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4c& a) { + return preinterpret(pcast(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet16uc pcast(const Packet16c& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE 
Packet8uc pcast(const Packet8c& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4c& a) { - return static_cast(a); -} //============================================================================== // pcast, SrcType = uint8_t @@ -403,6 +490,10 @@ EIGEN_STRONG_INLINE Packet4f pcast(const Packet16uc& a) { return vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a))))); } template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4uc& a) { + return vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(a)))))); +} +template <> EIGEN_STRONG_INLINE Packet2f pcast(const Packet8uc& a) { // Discard all but first 2 bytes. return vcvt_f32_u32(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(a))))); @@ -437,10 +528,18 @@ EIGEN_STRONG_INLINE Packet4ui pcast(const Packet16uc& a) return vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a)))); } template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet8uc& a) { + return vmovl_u16(vget_low_u16(vmovl_u8(a))); +} +template <> EIGEN_STRONG_INLINE Packet2ui pcast(const Packet8uc& a) { // Discard all but first 2 bytes. return vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(a)))); } +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4uc& a) { + return pcast(vreinterpret_u8_u32(vdup_n_u32(a))); +} template <> struct type_casting_traits { @@ -454,6 +553,10 @@ template <> EIGEN_STRONG_INLINE Packet2i pcast(const Packet8uc& a) { return preinterpret(pcast(a)); } +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4uc& a) { + return preinterpret(pcast(a)); +} template <> struct type_casting_traits { @@ -465,10 +568,14 @@ EIGEN_STRONG_INLINE Packet8us pcast(const Packet16uc& a) return vmovl_u8(vget_low_u8(a)); } template <> -EIGEN_STRONG_INLINE Packet4us pcast(const Packet8uc& a) { - // Discard second half of input. 
- return vget_low_u16(vmovl_u8(a)); +EIGEN_STRONG_INLINE Packet8us pcast(const Packet8uc& a) { + return vmovl_u8(a); } +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4uc& a) { + return vget_low_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(a)))); +} + template <> struct type_casting_traits { @@ -479,43 +586,14 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet16uc& a) { return preinterpret(pcast(a)); } template <> -EIGEN_STRONG_INLINE Packet4s pcast(const Packet8uc& a) { - return preinterpret(pcast(a)); +EIGEN_STRONG_INLINE Packet8s pcast(const Packet8uc& a) { + return preinterpret(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4uc& a) { + return preinterpret(pcast(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet16uc pcast(const Packet16uc& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet8uc pcast(const Packet8uc& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4uc& a) { - return a; -} - -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet16c pcast(const Packet16uc& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet8c pcast(const Packet8uc& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet4c pcast(const Packet4uc& a) { - return static_cast(a); -} //============================================================================== // pcast, SrcType = int16_t @@ -530,6 +608,10 @@ EIGEN_STRONG_INLINE Packet4f pcast(const Packet8s& a) { return vcvtq_f32_s32(vmovl_s16(vget_low_s16(a))); } template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4s& a) { + return vcvtq_f32_s32(vmovl_s16(a)); +} +template <> EIGEN_STRONG_INLINE Packet2f pcast(const Packet4s& a) { // Discard second half of input. 
return vcvt_f32_s32(vget_low_s32(vmovl_s16(a))); @@ -564,6 +646,10 @@ EIGEN_STRONG_INLINE Packet4i pcast(const Packet8s& a) { return vmovl_s16(vget_low_s16(a)); } template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4s& a) { + return vmovl_s16(a); +} +template <> EIGEN_STRONG_INLINE Packet2i pcast(const Packet4s& a) { // Discard second half of input. return vget_low_s32(vmovl_s16(a)); @@ -578,35 +664,14 @@ EIGEN_STRONG_INLINE Packet4ui pcast(const Packet8s& a) { return preinterpret(pcast(a)); } template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4s& a) { + return preinterpret(pcast(a)); +} +template <> EIGEN_STRONG_INLINE Packet2ui pcast(const Packet4s& a) { return preinterpret(pcast(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet8s pcast(const Packet8s& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet4s pcast(const Packet4s& a) { - return a; -} - -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet8us pcast(const Packet8s& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet4us pcast(const Packet4s& a) { - return preinterpret(a); -} template <> struct type_casting_traits { @@ -617,9 +682,18 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet8s& a, cons return vcombine_s8(vmovn_s16(a), vmovn_s16(b)); } template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet8s& a) { + return vmovn_s16(a); +} +template <> EIGEN_STRONG_INLINE Packet8c pcast(const Packet4s& a, const Packet4s& b) { return vmovn_s16(vcombine_s16(a, b)); } +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4s& a) { + const int8x8_t aa_s8x8 = pcast(a, a); + return vget_lane_s32(vreinterpret_s32_s8(aa_s8x8), 0); +} template <> struct type_casting_traits { @@ -627,11 +701,19 @@ struct type_casting_traits { }; template <> 
EIGEN_STRONG_INLINE Packet16uc pcast(const Packet8s& a, const Packet8s& b) { - return vcombine_u8(vmovn_u16(vreinterpretq_u16_s16(a)), vmovn_u16(vreinterpretq_u16_s16(b))); + return preinterpret(pcast(a, b)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet8s& a) { + return preinterpret(pcast(a)); } template <> EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4s& a, const Packet4s& b) { - return vmovn_u16(vcombine_u16(vreinterpret_u16_s16(a), vreinterpret_u16_s16(b))); + return preinterpret(pcast(a, b)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4s& a) { + return static_cast(pcast(a)); } //============================================================================== @@ -647,6 +729,10 @@ EIGEN_STRONG_INLINE Packet4f pcast(const Packet8us& a) { return vcvtq_f32_u32(vmovl_u16(vget_low_u16(a))); } template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4us& a) { + return vcvtq_f32_u32(vmovl_u16(a)); +} +template <> EIGEN_STRONG_INLINE Packet2f pcast(const Packet4us& a) { // Discard second half of input. return vcvt_f32_u32(vget_low_u32(vmovl_u16(a))); @@ -681,6 +767,10 @@ EIGEN_STRONG_INLINE Packet4ui pcast(const Packet8us& a) { return vmovl_u16(vget_low_u16(a)); } template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4us& a) { + return vmovl_u16(a); +} +template <> EIGEN_STRONG_INLINE Packet2ui pcast(const Packet4us& a) { // Discard second half of input. 
return vget_low_u32(vmovl_u16(a)); @@ -695,35 +785,14 @@ EIGEN_STRONG_INLINE Packet4i pcast(const Packet8us& a) { return preinterpret(pcast(a)); } template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4us& a) { + return preinterpret(pcast(a)); +} +template <> EIGEN_STRONG_INLINE Packet2i pcast(const Packet4us& a) { return preinterpret(pcast(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet8us pcast(const Packet8us& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet4us pcast(const Packet4us& a) { - return a; -} - -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet8s pcast(const Packet8us& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet4s pcast(const Packet4us& a) { - return preinterpret(a); -} template <> struct type_casting_traits { @@ -734,9 +803,18 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet8us& a, return vcombine_u8(vmovn_u16(a), vmovn_u16(b)); } template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet8us& a) { + return vmovn_u16(a); +} +template <> EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4us& a, const Packet4us& b) { return vmovn_u16(vcombine_u16(a, b)); } +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4us& a) { + uint8x8_t aa_u8x8 = pcast(a, a); + return vget_lane_u32(vreinterpret_u32_u8(aa_u8x8), 0); +} template <> struct type_casting_traits { @@ -747,9 +825,17 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet8us& a, co return preinterpret(pcast(a, b)); } template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet8us& a) { + return preinterpret(pcast(a)); +} +template <> EIGEN_STRONG_INLINE Packet8c pcast(const Packet4us& a, const Packet4us& b) { return preinterpret(pcast(a, b)); } +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4us& a) { + return 
static_cast(pcast(a)); +} //============================================================================== // pcast, SrcType = int32_t @@ -776,6 +862,10 @@ EIGEN_STRONG_INLINE Packet2l pcast(const Packet4i& a) { // Discard second half of input. return vmovl_s32(vget_low_s32(a)); } +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2i& a) { + return vmovl_s32(a); +} template <> struct type_casting_traits { @@ -785,32 +875,11 @@ template <> EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4i& a) { return preinterpret(pcast(a)); } - template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet4i pcast(const Packet4i& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet2i pcast(const Packet2i& a) { - return a; +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2i& a) { + return preinterpret(pcast(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4i& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2i& a) { - return preinterpret(a); -} template <> struct type_casting_traits { @@ -821,6 +890,10 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet4i& a, const return vcombine_s16(vmovn_s32(a), vmovn_s32(b)); } template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4i& a) { + return vmovn_s32(a); +} +template <> EIGEN_STRONG_INLINE Packet4s pcast(const Packet2i& a, const Packet2i& b) { return vmovn_s32(vcombine_s32(a, b)); } @@ -834,6 +907,10 @@ EIGEN_STRONG_INLINE Packet8us pcast(const Packet4i& a, cons return vcombine_u16(vmovn_u32(vreinterpretq_u32_s32(a)), vmovn_u32(vreinterpretq_u32_s32(b))); } template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4i& a) { + return vmovn_u32(vreinterpretq_u32_s32(a)); +} +template <> EIGEN_STRONG_INLINE Packet4us pcast(const Packet2i& a, 
const Packet2i& b) { return vmovn_u32(vreinterpretq_u32_s32(vcombine_s32(a, b))); } @@ -850,12 +927,24 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet4i& a, cons return vcombine_s8(vmovn_s16(ab_s16), vmovn_s16(cd_s16)); } template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet4i& a, const Packet4i& b) { + const int16x8_t ab_s16 = pcast(a, b); + return vmovn_s16(ab_s16); +} +template <> EIGEN_STRONG_INLINE Packet8c pcast(const Packet2i& a, const Packet2i& b, const Packet2i& c, const Packet2i& d) { const int16x4_t ab_s16 = vmovn_s32(vcombine_s32(a, b)); const int16x4_t cd_s16 = vmovn_s32(vcombine_s32(c, d)); return vmovn_s16(vcombine_s16(ab_s16, cd_s16)); } +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4i& a) { + const int16x4_t a_s16x4 = vmovn_s32(a); + const int16x8_t aa_s16x8 = vcombine_s16(a_s16x4, a_s16x4); + const int8x8_t aa_s8x8 = vmovn_s16(aa_s16x8); + return vget_lane_s32(vreinterpret_s32_s8(aa_s8x8), 0); +} template <> struct type_casting_traits { @@ -864,16 +953,20 @@ struct type_casting_traits { template <> EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4i& a, const Packet4i& b, const Packet4i& c, const Packet4i& d) { - const uint16x8_t ab_u16 = pcast(a, b); - const uint16x8_t cd_u16 = pcast(c, d); - return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16)); + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4i& a, const Packet4i& b) { + return preinterpret(pcast(a, b)); } template <> EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2i& a, const Packet2i& b, const Packet2i& c, const Packet2i& d) { - const uint16x4_t ab_u16 = pcast(a, b); - const uint16x4_t cd_u16 = pcast(c, d); - return vmovn_u16(vcombine_u16(ab_u16, cd_u16)); + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4i& a) { + return static_cast(pcast(a)); } //============================================================================== @@ -901,6 +994,10 @@ 
EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4ui& a) { // Discard second half of input. return vmovl_u32(vget_low_u32(a)); } +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2ui& a) { + return vmovl_u32(a); +} template <> struct type_casting_traits { @@ -910,32 +1007,11 @@ template <> EIGEN_STRONG_INLINE Packet2l pcast(const Packet4ui& a) { return preinterpret(pcast(a)); } - template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4ui& a) { - return a; -} -template <> -EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2ui& a) { - return a; +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2ui& a) { + return preinterpret(pcast(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet4i pcast(const Packet4ui& a) { - return preinterpret(a); -} -template <> -EIGEN_STRONG_INLINE Packet2i pcast(const Packet2ui& a) { - return preinterpret(a); -} template <> struct type_casting_traits { @@ -949,6 +1025,10 @@ template <> EIGEN_STRONG_INLINE Packet4us pcast(const Packet2ui& a, const Packet2ui& b) { return vmovn_u32(vcombine_u32(a, b)); } +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4ui& a) { + return vmovn_u32(a); +} template <> struct type_casting_traits { @@ -962,6 +1042,10 @@ template <> EIGEN_STRONG_INLINE Packet4s pcast(const Packet2ui& a, const Packet2ui& b) { return preinterpret(pcast(a, b)); } +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4ui& a) { + return preinterpret(pcast(a)); +} template <> struct type_casting_traits { @@ -975,12 +1059,24 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4ui& a, return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16)); } template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4ui& a, const Packet4ui& b) { + const uint16x8_t ab_u16 = 
vcombine_u16(vmovn_u32(a), vmovn_u32(b)); + return vmovn_u16(ab_u16); +} +template <> EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c, const Packet2ui& d) { const uint16x4_t ab_u16 = vmovn_u32(vcombine_u32(a, b)); const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(c, d)); return vmovn_u16(vcombine_u16(ab_u16, cd_u16)); } +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4ui& a) { + const uint16x4_t a_u16x4 = vmovn_u32(a); + const uint16x8_t aa_u16x8 = vcombine_u16(a_u16x4, a_u16x4); + const uint8x8_t aa_u8x8 = vmovn_u16(aa_u16x8); + return vget_lane_u32(vreinterpret_u32_u8(aa_u8x8), 0); +} template <> struct type_casting_traits { @@ -992,10 +1088,18 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet4ui& a, co return preinterpret(pcast(a, b, c, d)); } template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet4ui& a, const Packet4ui& b) { + return preinterpret(pcast(a, b)); +} +template <> EIGEN_STRONG_INLINE Packet8c pcast(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c, const Packet2ui& d) { return preinterpret(pcast(a, b, c, d)); } +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4ui& a) { + return static_cast(pcast(a)); +} //============================================================================== // pcast, SrcType = int64_t @@ -1008,24 +1112,11 @@ template <> EIGEN_STRONG_INLINE Packet4f pcast(const Packet2l& a, const Packet2l& b) { return vcvtq_f32_s32(vcombine_s32(vmovn_s64(a), vmovn_s64(b))); } - template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet2l pcast(const Packet2l& a) { - return a; +EIGEN_STRONG_INLINE Packet2f pcast(const Packet2l& a) { + return vcvt_f32_s32(vmovn_s64(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2l& a) { - 
return preinterpret(a); -} template <> struct type_casting_traits { @@ -1035,6 +1126,10 @@ template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet2l& a, const Packet2l& b) { return vcombine_s32(vmovn_s64(a), vmovn_s64(b)); } +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet2l& a) { + return vmovn_s64(a); +} template <> struct type_casting_traits { @@ -1044,6 +1139,10 @@ template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2l& a, const Packet2l& b) { return vcombine_u32(vmovn_u64(vreinterpretq_u64_s64(a)), vmovn_u64(vreinterpretq_u64_s64(b))); } +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2l& a) { + return vmovn_u64(vreinterpretq_u64_s64(a)); +} template <> struct type_casting_traits { @@ -1056,6 +1155,11 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet2l& a, const const int32x4_t cd_s32 = pcast(c, d); return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32)); } +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet2l& a, const Packet2l& b) { + const int32x4_t ab_s32 = pcast(a, b); + return vmovn_s32(ab_s32); +} template <> struct type_casting_traits { @@ -1064,9 +1168,11 @@ struct type_casting_traits { template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, const Packet2l& d) { - const uint32x4_t ab_u32 = pcast(a, b); - const uint32x4_t cd_u32 = pcast(c, d); - return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32)); + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet2l& a, const Packet2l& b) { + return preinterpret(pcast(a, b)); } template <> @@ -1081,6 +1187,19 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet2l& a, cons const int16x8_t efgh_s16 = pcast(e, f, g, h); return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16)); } +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, + const Packet2l& d) { + const int16x8_t abcd_s16 = pcast(a, b, c, d); + 
return vmovn_s16(abcd_s16); +} +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet2l& a, const Packet2l& b) { + const int16x4_t ab_s16 = pcast(a, b); + const int16x8_t abab_s16 = vcombine_s16(ab_s16, ab_s16); + const int8x8_t abab_s8 = vmovn_s16(abab_s16); + return vget_lane_s32(vreinterpret_s32_s8(abab_s8), 0); +} template <> struct type_casting_traits { @@ -1094,6 +1213,15 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2l& a, co const uint16x8_t efgh_u16 = pcast(e, f, g, h); return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16)); } +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, + const Packet2l& d) { + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet2l& a, const Packet2l& b) { + return static_cast(pcast(a, b)); +} //============================================================================== // pcast, SrcType = uint64_t @@ -1106,24 +1234,11 @@ template <> EIGEN_STRONG_INLINE Packet4f pcast(const Packet2ul& a, const Packet2ul& b) { return vcvtq_f32_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b))); } - template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2ul& a) { - return a; +EIGEN_STRONG_INLINE Packet2f pcast(const Packet2ul& a) { + return vcvt_f32_u32(vmovn_u64(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet2l pcast(const Packet2ul& a) { - return preinterpret(a); -} template <> struct type_casting_traits { @@ -1133,6 +1248,10 @@ template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2ul& a, const Packet2ul& b) { return vcombine_u32(vmovn_u64(a), vmovn_u64(b)); } +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2ul& a) { + return vmovn_u64(a); +} template <> struct 
type_casting_traits { @@ -1142,6 +1261,10 @@ template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet2ul& a, const Packet2ul& b) { return preinterpret(pcast(a, b)); } +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet2ul& a) { + return preinterpret(pcast(a)); +} template <> struct type_casting_traits { @@ -1154,6 +1277,10 @@ EIGEN_STRONG_INLINE Packet8us pcast(const Packet2ul& a, co const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(vmovn_u64(c), vmovn_u64(d))); return vcombine_u16(ab_u16, cd_u16); } +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet2ul& a, const Packet2ul& b) { + return vmovn_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b))); +} template <> struct type_casting_traits { @@ -1164,6 +1291,10 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet2ul& a, cons const Packet2ul& d) { return preinterpret(pcast(a, b, c, d)); } +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet2ul& a, const Packet2ul& b) { + return preinterpret(pcast(a, b)); +} template <> struct type_casting_traits { @@ -1177,6 +1308,19 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2ul& a, const uint16x8_t efgh_u16 = pcast(e, f, g, h); return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16)); } +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, + const Packet2ul& d) { + const uint16x8_t abcd_u16 = pcast(a, b, c, d); + return vmovn_u16(abcd_u16); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet2ul& a, const Packet2ul& b) { + const uint16x4_t ab_u16 = pcast(a, b); + const uint16x8_t abab_u16 = vcombine_u16(ab_u16, ab_u16); + const uint8x8_t abab_u8 = vmovn_u16(abab_u16); + return vget_lane_u32(vreinterpret_u32_u8(abab_u8), 0); +} template <> struct type_casting_traits { @@ -1188,6 +1332,15 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet2ul& a, co const Packet2ul& g, const Packet2ul& h) { return preinterpret(pcast(a, b, c, d, e, f, g, h)); } +template <> +EIGEN_STRONG_INLINE 
Packet8c pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, + const Packet2ul& d) { + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet2ul& a, const Packet2ul& b) { + return static_cast(pcast(a, b)); +} #if EIGEN_ARCH_ARM64 @@ -1220,14 +1373,6 @@ EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet2d& a) return Packet4i(vreinterpretq_s32_f64(a)); } -template <> -struct type_casting_traits { - enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; -}; -template <> -EIGEN_STRONG_INLINE Packet2d pcast(const Packet2d& a) { - return a; -} template <> struct type_casting_traits { @@ -1237,6 +1382,10 @@ template <> EIGEN_STRONG_INLINE Packet4f pcast(const Packet2d& a, const Packet2d& b) { return vcombine_f32(vcvt_f32_f64(a), vcvt_f32_f64(b)); } +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet2d& a) { + return vcvt_f32_f64(a); +} template <> struct type_casting_traits { @@ -1264,6 +1413,10 @@ template <> EIGEN_STRONG_INLINE Packet4i pcast(const Packet2d& a, const Packet2d& b) { return vcombine_s32(vmovn_s64(vcvtq_s64_f64(a)), vmovn_s64(vcvtq_s64_f64(b))); } +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet2d& a) { + return vmovn_s64(vcvtq_s64_f64(a)); +} template <> struct type_casting_traits { @@ -1273,6 +1426,10 @@ template <> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2d& a, const Packet2d& b) { return vcombine_u32(vmovn_u64(vcvtq_u64_f64(a)), vmovn_u64(vcvtq_u64_f64(b))); } +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2d& a) { + return vmovn_u64(vcvtq_u64_f64(a)); +} template <> struct type_casting_traits { @@ -1285,6 +1442,11 @@ EIGEN_STRONG_INLINE Packet8s pcast(const Packet2d& a, const const int32x4_t cd_s32 = pcast(c, d); return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32)); } +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet2d& a, const Packet2d& b) { + const int32x4_t ab_s32 = pcast(a, b); + return vmovn_s32(ab_s32); 
+} template <> struct type_casting_traits { @@ -1293,9 +1455,11 @@ struct type_casting_traits { template <> EIGEN_STRONG_INLINE Packet8us pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, const Packet2d& d) { - const uint32x4_t ab_u32 = pcast(a, b); - const uint32x4_t cd_u32 = pcast(c, d); - return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32)); + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet2d& a, const Packet2d& b) { + return preinterpret(pcast(a, b)); } template <> @@ -1310,6 +1474,17 @@ EIGEN_STRONG_INLINE Packet16c pcast(const Packet2d& a, cons const int16x8_t efgh_s16 = pcast(e, f, g, h); return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16)); } +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, + const Packet2d& d) { + const int16x8_t abcd_s16 = pcast(a, b, c, d); + return vmovn_s16(abcd_s16); +} +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet2d& a, const Packet2d& b) { + const int32x4_t ab_s32 = pcast(a, b); + return pcast(ab_s32); +} template <> struct type_casting_traits { @@ -1323,6 +1498,15 @@ EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2d& a, co const uint16x8_t efgh_u16 = pcast(e, f, g, h); return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16)); } +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, + const Packet2d& d) { + return preinterpret(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet2d& a, const Packet2d& b) { + return static_cast(pcast(a, b)); +} template <> struct type_casting_traits { @@ -1333,6 +1517,10 @@ EIGEN_STRONG_INLINE Packet2d pcast(const Packet4f& a) { // Discard second-half of input. 
return vcvt_f64_f32(vget_low_f32(a)); } +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet2f& a) { + return vcvt_f64_f32(a); +} template <> struct type_casting_traits { @@ -1388,6 +1576,10 @@ EIGEN_STRONG_INLINE Packet2d pcast(const Packet4i& a) { // Discard second half of input. return vcvtq_f64_s64(vmovl_s32(vget_low_s32(a))); } +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet2i& a) { + return vcvtq_f64_s64(vmovl_s32(a)); +} template <> struct type_casting_traits { @@ -1398,6 +1590,10 @@ EIGEN_STRONG_INLINE Packet2d pcast(const Packet4ui& a) { // Discard second half of input. return vcvtq_f64_u64(vmovl_u32(vget_low_u32(a))); } +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet2ui& a) { + return vcvtq_f64_u64(vmovl_u32(a)); +} template <> struct type_casting_traits { diff --git a/Eigen/src/Core/arch/SSE/TypeCasting.h b/Eigen/src/Core/arch/SSE/TypeCasting.h index df5c72c42..0b5aa1c78 100644 --- a/Eigen/src/Core/arch/SSE/TypeCasting.h +++ b/Eigen/src/Core/arch/SSE/TypeCasting.h @@ -135,6 +135,13 @@ template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Pa return _mm_castpd_si128(a); } +template<> EIGEN_STRONG_INLINE Packet4ui preinterpret(const Packet4i& a) { + return Packet4ui(a); +} + +template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4ui& a) { + return Packet4i(a); +} // Disable the following code since it's broken on too many platforms / compilers. 
//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC) #if 0 diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 8354c0a76..4760d9b59 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -179,16 +179,28 @@ struct scalar_cast_op { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast(a); } }; -template -struct scalar_cast_op { - typedef bool result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { return a != Scalar(0); } -}; - template struct functor_traits > { enum { Cost = is_same::value ? 0 : NumTraits::AddCost, PacketAccess = false }; }; +/** \internal + * `core_cast_op` serves to distinguish the vectorized implementation from that of the legacy `scalar_cast_op` for backwards + * compatibility. The manner in which packet ops are handled is defined by the specialized unary_evaluator: + * `unary_evaluator, ArgType>, IndexBased>` in CoreEvaluators.h + * Otherwise, the non-vectorized behavior is identical to that of `scalar_cast_op` + */ +template +struct core_cast_op : scalar_cast_op {}; + +template +struct functor_traits> { + using CastingTraits = type_casting_traits; + enum { + Cost = is_same::value ? 
0 : NumTraits::AddCost, + PacketAccess = CastingTraits::VectorizedCast && (CastingTraits::SrcCoeffRatio <= 8) + }; +}; + /** \internal * \brief Template functor to arithmetically shift a scalar right by a number of bits * diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index b5f91bf75..49fa37c0c 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -190,6 +190,30 @@ struct find_best_packet typedef typename find_best_packet_helper::type>::type type; }; +template ::size) || + is_same::half>::value> +struct find_packet_by_size_helper; +template +struct find_packet_by_size_helper { + using type = PacketType; +}; +template +struct find_packet_by_size_helper { + using type = typename find_packet_by_size_helper::half>::type; +}; + +template +struct find_packet_by_size { + using type = typename find_packet_by_size_helper::type>::type; + static constexpr bool value = (Size == unpacket_traits::size); +}; +template +struct find_packet_by_size { + using type = typename unpacket_traits::type; + static constexpr bool value = (unpacket_traits::size == 1); +}; + #if EIGEN_MAX_STATIC_ALIGN_BYTES>0 constexpr inline int compute_default_alignment_helper(int ArrayBytes, int AlignmentBytes) { if((ArrayBytes % AlignmentBytes) == 0) { diff --git a/Eigen/src/plugins/CommonCwiseUnaryOps.h b/Eigen/src/plugins/CommonCwiseUnaryOps.h index 390759cd0..1c6b28451 100644 --- a/Eigen/src/plugins/CommonCwiseUnaryOps.h +++ b/Eigen/src/plugins/CommonCwiseUnaryOps.h @@ -45,7 +45,7 @@ inline const NegativeReturnType operator-() const { return NegativeReturnType(derived()); } -template struct CastXpr { typedef typename internal::cast_return_type, const Derived> >::type Type; }; +template struct CastXpr { typedef typename internal::cast_return_type, const Derived> >::type Type; }; /// \returns an expression of \c *this with the \a Scalar type casted to /// \a NewScalar. 
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 989359fe1..1648c7cb4 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -9,6 +9,7 @@ #include #include "main.h" +#include "random_without_cast_overflow.h" // suppress annoying unsigned integer warnings template ::IsSigned> @@ -1213,6 +1214,109 @@ void typed_logicals_test(const ArrayType& m) { typed_logicals_test_impl::run(m); } +template +struct cast_test_impl { + using SrcArray = Array; + using DstArray = Array; + struct RandomOp { + inline SrcType operator()(const SrcType&) const { + return internal::random_without_cast_overflow::value(); + } + }; + + static constexpr int SrcPacketSize = internal::packet_traits::size; + static constexpr int DstPacketSize = internal::packet_traits::size; + static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize); + + // print non-mangled typenames + template + static std::string printTypeInfo(const T&) { + if (internal::is_same::value) + return "bool"; + else if (internal::is_same::value) + return "int8_t"; + else if (internal::is_same::value) + return "int16_t"; + else if (internal::is_same::value) + return "int32_t"; + else if (internal::is_same::value) + return "int64_t"; + else if (internal::is_same::value) + return "uint8_t"; + else if (internal::is_same::value) + return "uint16_t"; + else if (internal::is_same::value) + return "uint32_t"; + else if (internal::is_same::value) + return "uint64_t"; + else if (internal::is_same::value) + return "float"; + else if (internal::is_same::value) + return "double"; + //else if (internal::is_same::value) + // return "long double"; + else if (internal::is_same::value) + return "half"; + else if (internal::is_same::value) + return "bfloat16"; + else + return typeid(T).name(); + } + + static void run() { + const Index testRows = RowsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : RowsAtCompileTime; + const Index testCols = ColsAtCompileTime == Dynamic ? 
((10 * MaxPacketSize) + 1) : ColsAtCompileTime; + const Index testSize = testRows * testCols; + const Index minTestSize = 100; + const Index repeats = numext::div_ceil(minTestSize, testSize); + SrcArray src(testRows, testCols); + DstArray dst(testRows, testCols); + for (Index repeat = 0; repeat < repeats; repeat++) { + src = src.unaryExpr(RandomOp()); + dst = src.template cast(); + for (Index i = 0; i < testRows; i++) + for (Index j = 0; j < testCols; j++) { + DstType ref = internal::cast_impl::run(src(i, j)); + bool all_nan = ((numext::isnan)(src(i, j)) && (numext::isnan)(ref) && (numext::isnan)(dst(i, j))); + bool is_equal = ref == dst(i, j); + bool pass = all_nan || is_equal; + if (!pass) { + std::cout << printTypeInfo(SrcType()) << ": [" << +src(i, j) << "] to " << printTypeInfo(DstType()) << ": [" + << +dst(i, j) << "] != [" << +ref << "]\n"; + } + VERIFY(pass); + } + } + } +}; + +template +struct cast_tests_impl { + using ScalarTuple = std::tuple; + static constexpr size_t ScalarTupleSize = std::tuple_size::value; + + template = ScalarTupleSize - 1) || (j >= ScalarTupleSize)> + static std::enable_if_t run() {} + + template = ScalarTupleSize - 1) || (j >= ScalarTupleSize)> + static std::enable_if_t run() { + using Type1 = typename std::tuple_element::type; + using Type2 = typename std::tuple_element::type; + cast_test_impl::run(); + cast_test_impl::run(); + static constexpr size_t next_i = (j == ScalarTupleSize - 1) ? (i + 1) : (i + 0); + static constexpr size_t next_j = (j == ScalarTupleSize - 1) ? 
(i + 2) : (j + 1); + run(); + } +}; + +// for now, remove all references to 'long double' until test passes on all platforms +template +void cast_test() { + cast_tests_impl::run(); +} + EIGEN_DECLARE_TEST(array_cwise) { for(int i = 0; i < g_repeat; i++) { @@ -1269,6 +1373,20 @@ EIGEN_DECLARE_TEST(array_cwise) CALL_SUBTEST_3( typed_logicals_test(ArrayX>(internal::random(1, EIGEN_TEST_MAX_SIZE)))); } + for (int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1((cast_test<1, 1>())); + CALL_SUBTEST_2((cast_test<3, 1>())); + CALL_SUBTEST_2((cast_test<3, 3>())); + CALL_SUBTEST_3((cast_test<5, 1>())); + CALL_SUBTEST_3((cast_test<5, 5>())); + CALL_SUBTEST_4((cast_test<9, 1>())); + CALL_SUBTEST_4((cast_test<9, 9>())); + CALL_SUBTEST_5((cast_test<17, 1>())); + CALL_SUBTEST_5((cast_test<17, 17>())); + CALL_SUBTEST_6((cast_test())); + CALL_SUBTEST_6((cast_test())); + } + VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, int >::value)); VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, float >::value)); VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, ArrayBase >::value));