diff --git a/Eigen/Core b/Eigen/Core index bf0b9c736..857cffa15 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -183,7 +183,6 @@ using std::ptrdiff_t; // Generic half float support #include "src/Core/arch/Default/Half.h" #include "src/Core/arch/Default/BFloat16.h" -#include "src/Core/arch/Default/TypeCasting.h" #include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h" #if defined EIGEN_VECTORIZE_AVX512 diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index e233efba1..02c0600b0 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -621,6 +621,110 @@ protected: Data m_d; }; +// ----------------------- Casting --------------------- +template +struct unary_evaluator, ArgType>, IndexBased> { + using CastOp = scalar_cast_op; + using XprType = CwiseUnaryOp; + using SrcPacketType = typename packet_traits::type; + + static constexpr int SrcPacketSize = packet_traits::size; + static constexpr int SrcPacketSizeBytes = SrcPacketSize * sizeof(SrcType); + + enum { + CoeffReadCost = int(evaluator::CoeffReadCost) + int(functor_traits::Cost), + Flags = evaluator::Flags & + (HereditaryBits | LinearAccessBit | (functor_traits::PacketAccess ? PacketAccessBit : 0)), + Alignment = evaluator::Alignment + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& xpr) + : m_argImpl(xpr.nestedExpression()) { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index row, Index col) const { + return CastOp()(m_argImpl.coeff(row, col)); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index index) const { + return CastOp()(m_argImpl.coeff(index)); + } + + template + EIGEN_ALWAYS_INLINE SrcPacketType srcPacket(Index row, Index col, Index offset) const { + EIGEN_STATIC_ASSERT((LoadMode & (LoadMode - 1)) == 0, LoadMode must be a power of two) + constexpr bool ArgIsRowMajor = evaluator::Flags & RowMajorBit; + return m_argImpl.template packet(ArgIsRowMajor ? row : row + (offset * SrcPacketSize), + ArgIsRowMajor ? col + (offset * SrcPacketSize) : col); + } + template + EIGEN_ALWAYS_INLINE SrcPacketType srcPacket(Index index, Index offset) const { + EIGEN_STATIC_ASSERT((LoadMode & (LoadMode - 1)) == 0, LoadMode must be a power of two) + return m_argImpl.template packet(index + (offset * SrcPacketSize)); + } + + template + using SrcPacketArgs1 = std::enable_if_t::size <= (1 * SrcPacketSize), bool>; + template + using SrcPacketArgs2 = std::enable_if_t::size == (2 * SrcPacketSize), bool>; + template + using SrcPacketArgs4 = std::enable_if_t::size == (4 * SrcPacketSize), bool>; + + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int DstPacketSize = unpacket_traits::size; + constexpr int SrcIncrementBytes = DstPacketSize * sizeof(SrcType); + constexpr int SrcLoadMode = plain_enum_min(SrcIncrementBytes, LoadMode); + return CastOp().template packetOp(srcPacket(row, col, 0)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int SrcLoadMode0 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode1 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode); + return CastOp().template packetOp(srcPacket(row, col, 0), + srcPacket(row, col, 1)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const { + constexpr int SrcLoadMode0 = plain_enum_min(4 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode1 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode2 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode3 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode); + return CastOp().template packetOp( + srcPacket(row, col, 0), srcPacket(row, col, 1), + srcPacket(row, col, 2), srcPacket(row, col, 3)); + } + + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int DstPacketSize = unpacket_traits::size; + constexpr int SrcIncrementBytes = DstPacketSize * sizeof(SrcType); + constexpr int SrcLoadMode = plain_enum_min(SrcIncrementBytes, LoadMode); + return CastOp().template packetOp(srcPacket(index, 0)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int SrcLoadMode0 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode1 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode); + return CastOp().template packetOp(srcPacket(index, 0), + srcPacket(index, 1)); + } + template = true> + EIGEN_STRONG_INLINE DstPacketType packet(Index index) const { + constexpr int SrcLoadMode0 = plain_enum_min(4 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode1 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode2 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode); + constexpr int SrcLoadMode3 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode); + return CastOp().template packetOp( + srcPacket(index, 0), srcPacket(index, 1), + srcPacket(index, 2), srcPacket(index, 3)); + } + + protected: + const evaluator m_argImpl; +}; + // -------------------- CwiseTernaryOp -------------------- // this is a ternary expression diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 40ee3f551..5fa114f52 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -430,6 +430,12 @@ struct cast_impl } }; +template +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline bool run(const OldType& x) { return x != OldType(0); } +}; + // Casting from S -> Complex leads to an implicit conversion from S to T, // generating warnings on clang. Here we explicitly cast the real component. template diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h index 386543e66..41db0355d 100644 --- a/Eigen/src/Core/arch/AVX/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX/TypeCasting.h @@ -62,6 +62,15 @@ struct type_casting_traits { TgtCoeffRatio = 1 }; }; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 2 + }; +}; #endif // EIGEN_VECTORIZE_AVX512 template<> EIGEN_STRONG_INLINE Packet8i pcast(const Packet8f& a) { @@ -80,6 +89,10 @@ template<> EIGEN_STRONG_INLINE Packet8i pcast(const Packet4d return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a)); } +template<> EIGEN_STRONG_INLINE Packet4d pcast(const Packet8f& a) { + return _mm256_cvtps_pd(_mm256_castps256_ps128(a)); +} + template <> EIGEN_STRONG_INLINE Packet16b pcast(const Packet8f& a, const Packet8f& b) { diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index c8ca33a5b..c08b7c5da 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -1014,4 +1014,49 @@ struct hash { } // end namespace std #endif +namespace Eigen { +namespace internal { + +template <> +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline half run(const float& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __float2half(a); +#else + return Eigen::half(a); +#endif + } +}; + +template <> +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline half run(const int& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __float2half(static_cast(a)); +#else + return half(static_cast(a)); +#endif + } +}; + +template <> +struct cast_impl { + EIGEN_DEVICE_FUNC + static inline float run(const half& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __half2float(a); +#else + return static_cast(a); +#endif + } +}; + +} // namespace internal +} // namespace Eigen + #endif // EIGEN_HALF_H diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h deleted file mode 100644 index dc779a725..000000000 --- a/Eigen/src/Core/arch/Default/TypeCasting.h +++ /dev/null @@ -1,116 +0,0 @@ -// This file is part of Eigen, a lightweight C++ template library -// for linear algebra. -// -// Copyright (C) 2016 Benoit Steiner -// Copyright (C) 2019 Rasmus Munk Larsen -// -// This Source Code Form is subject to the terms of the Mozilla -// Public License v. 2.0. If a copy of the MPL was not distributed -// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#ifndef EIGEN_GENERIC_TYPE_CASTING_H -#define EIGEN_GENERIC_TYPE_CASTING_H - -#include "../../InternalHeaderCheck.h" - -namespace Eigen { - -namespace internal { - -template<> -struct scalar_cast_op { - typedef Eigen::half result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const { - #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __float2half(a); - #else - return Eigen::half(a); - #endif - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef Eigen::half result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const { - #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __float2half(static_cast(a)); - #else - return Eigen::half(static_cast(a)); - #endif - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef float result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const { - #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __half2float(a); - #else - return static_cast(a); - #endif - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef Eigen::bfloat16 result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const { - return Eigen::bfloat16(a); - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef Eigen::bfloat16 result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const { - return Eigen::bfloat16(static_cast(a)); - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -template<> -struct scalar_cast_op { - typedef float result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const { - return static_cast(a); - } -}; - -template<> -struct functor_traits > -{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; - - -} -} - -#endif // EIGEN_GENERIC_TYPE_CASTING_H diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 8354c0a76..8d7f59b02 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -173,22 +173,40 @@ struct functor_traits> { * * \sa class CwiseUnaryOp, MatrixBase::cast() */ -template +template struct scalar_cast_op { - typedef NewType result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast(a); } + + using result_type = DstType; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstType operator()(const SrcType& a) const { + return cast(a); + } + + using SrcPacket = typename packet_traits::type; + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a) const { + return pcast(a); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a, const SrcPacket& b) const { + return pcast(a, b); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a, const SrcPacket& b, + const SrcPacket& c, const SrcPacket& d) const { + return pcast(a, b, c, d); + } }; -template -struct scalar_cast_op { - typedef bool result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { return a != Scalar(0); } +template +struct functor_traits> { + enum { + Cost = is_same::value ? 0 : NumTraits::AddCost, + PacketAccess = (type_casting_traits::VectorizedCast != 0) && + (type_casting_traits::SrcCoeffRatio <= 4) + }; }; -template -struct functor_traits > -{ enum { Cost = is_same::value ? 0 : NumTraits::AddCost, PacketAccess = false }; }; - /** \internal * \brief Template functor to arithmetically shift a scalar right by a number of bits * diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 4f8111bc6..1ddd22b32 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -9,6 +9,7 @@ #include #include "main.h" +#include "random_without_cast_overflow.h" template ::IsInteger,int> = 0> std::vector special_values() { @@ -1183,6 +1184,59 @@ void typed_logicals_test(const ArrayType& m) { typed_logicals_test_impl::run(m); } +template +struct cast_test_impl { + using SrcArray = ArrayX; + using DstArray = ArrayX; + + static constexpr int SrcPacketSize = internal::packet_traits::size; + static constexpr int DstPacketSize = internal::packet_traits::size; + static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize); + + static void run() { + const Index testSize = 100 * MaxPacketSize; + SrcArray src(testSize); + for (Index i = 0; i < testSize; i++) src(i) = internal::random_without_cast_overflow::value(); + DstArray dst = src.template cast(); + for (Index i = 0; i < testSize; i++) { + DstType ref = static_cast(src(i)); + bool all_nan = ((numext::isnan)(src(i)) && (numext::isnan)(ref) && (numext::isnan)(dst(i))); + bool is_equal = ref == dst(i); + bool pass = all_nan || is_equal; + if (!pass) { + std::cout << typeid(SrcType).name() << ": [" << +src(i) << "] to " << typeid(DstType).name() << ": [" << +dst(i) + << "] != [" << +ref << "]\n"; + } + VERIFY(pass); + } + } +}; + +template +struct cast_tests_impl { + using ScalarTuple = std::tuple; + static constexpr size_t ScalarTupleSize = std::tuple_size::value; + + template = ScalarTupleSize - 1) || (j >= ScalarTupleSize)> + static std::enable_if_t run() {} + + template = ScalarTupleSize - 1) || (j >= ScalarTupleSize)> + static std::enable_if_t run() { + using Type1 = typename std::tuple_element::type; + using Type2 = typename std::tuple_element::type; + cast_test_impl::run(); + cast_test_impl::run(); + static constexpr size_t next_i = (j == ScalarTupleSize - 1) ? (i + 1) : (i + 0); + static constexpr size_t next_j = (j == ScalarTupleSize - 1) ? (i + 2) : (j + 1); + run(); + } +}; + +void cast_test() { + cast_tests_impl::run(); +} + EIGEN_DECLARE_TEST(array_cwise) { for(int i = 0; i < g_repeat; i++) { @@ -1238,6 +1292,9 @@ EIGEN_DECLARE_TEST(array_cwise) CALL_SUBTEST_3( typed_logicals_test(ArrayX>(internal::random(1, EIGEN_TEST_MAX_SIZE)))); CALL_SUBTEST_3( typed_logicals_test(ArrayX>(internal::random(1, EIGEN_TEST_MAX_SIZE)))); } + for (int i = 0; i < g_repeat; i++) { + cast_test(); + } VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, int >::value)); VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, float >::value));