diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index c3725d473..99ce99a27 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -65,13 +65,8 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
   // Following the convention of numpy, converting between complex and
   // float will lead to loss of imag value.
-  // Single precision complex.
-  typedef std::complex<float> complex64;
-  // Double precision complex.
-  typedef std::complex<double> complex128;
-  explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val)
-      : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {}
-  explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val)
+  template<typename RealScalar>
+  explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}

   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
@@ -114,11 +109,9 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
     return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
   }
-  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const {
-    return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0));
-  }
-  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const {
-    return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0));
+  template<typename RealScalar>
+  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const {
+    return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0));
   }
   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const {
     return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this));
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index cfd0bdc06..b84cfc7db 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -86,7 +86,7 @@ struct half_base : public __half_raw {
 #if (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000)
   EIGEN_DEVICE_FUNC half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
 #endif
-  #endif
+  #endif
 #endif
 };

@@ -133,6 +133,11 @@ struct half : public half_impl::half_base {
       : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
   explicit EIGEN_DEVICE_FUNC half(float f)
       : half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
+  // Following the convention of numpy, converting between complex and
+  // float will lead to loss of imag value.
+  template<typename RealScalar>
+  explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
+      : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}

   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
     // +0.0 and -0.0 become false, everything else becomes true.
@@ -174,6 +179,11 @@ struct half : public half_impl::half_base {
   EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
     return static_cast<double>(half_impl::half_to_float(*this));
   }
+
+  template<typename RealScalar>
+  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const {
+    return std::complex<RealScalar>(static_cast<RealScalar>(*this), RealScalar(0));
+  }
 };

 } // end namespace Eigen
diff --git a/test/basicstuff.cpp b/test/basicstuff.cpp
index 85af603d8..80fc8a07f 100644
--- a/test/basicstuff.cpp
+++ b/test/basicstuff.cpp
@@ -10,6 +10,7 @@
 #define EIGEN_NO_STATIC_ASSERT

 #include "main.h"
+#include "random_without_cast_overflow.h"

 template<typename MatrixType> void basicStuff(const MatrixType& m)
 {
@@ -90,7 +91,7 @@ template<typename MatrixType> void basicStuff(const MatrixType& m)
   Matrix cv(rows);
   rv = square.row(r);
   cv = square.col(r);
-  
+
   VERIFY_IS_APPROX(rv, cv.transpose());

   if(cols!=1 && rows!=1 && MatrixType::SizeAtCompileTime!=Dynamic)
@@ -120,28 +121,28 @@ template<typename MatrixType> void basicStuff(const MatrixType& m)
   m1 = m2;
   VERIFY(m1==m2);
   VERIFY(!(m1!=m2));
-  
+
   // check automatic transposition
   sm2.setZero();
   for(Index i=0;i<rows;++i)
     sm2.col(i) = sm1.row(i);
   VERIFY_IS_APPROX(sm2,sm1.transpose());
@@ -194,14 +195,72 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m)
   VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
 }

-template<int>
-void casting()
+template<typename SrcScalar, typename TgtScalar>
+void casting_test()
 {
-  Matrix4f m = Matrix4f::Random(), m2;
-  Matrix4d n = m.cast<double>();
-  VERIFY(m.isApprox(n.cast<float>()));
-  m2 = m.cast<float>(); // check the specialization when NewType == Type
-  VERIFY(m.isApprox(m2));
+  Matrix<SrcScalar,4,4> m;
+  for (int i=0; i<m.rows(); ++i) {
+    for (int j=0; j<m.cols(); ++j) {
+      m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
+    }
+  }
+  Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
+  for (int i=0; i<m.rows(); ++i) {
+    for (int j=0; j<m.cols(); ++j) {
+      VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
+    }
+  }
+}
+
+template<typename SrcScalar, typename EnableIf = void>
+struct casting_test_runner {
+  static void run() {
+    casting_test<SrcScalar, bool>();
+    casting_test<SrcScalar, int8_t>();
+    casting_test<SrcScalar, uint8_t>();
+    casting_test<SrcScalar, int16_t>();
+    casting_test<SrcScalar, uint16_t>();
+    casting_test<SrcScalar, int32_t>();
+    casting_test<SrcScalar, uint32_t>();
+    casting_test<SrcScalar, int64_t>();
+    casting_test<SrcScalar, uint64_t>();
+    casting_test<SrcScalar, half>();
+    casting_test<SrcScalar, bfloat16>();
+    casting_test<SrcScalar, float>();
+    casting_test<SrcScalar, double>();
+    casting_test<SrcScalar, std::complex<float> >();
+    casting_test<SrcScalar, std::complex<double> >();
+  }
+};
+
+template<typename SrcScalar>
+struct casting_test_runner<SrcScalar, typename internal::enable_if<(NumTraits<SrcScalar>::IsComplex)>::type>
+{
+  static void run() {
+    // Only a few casts from std::complex are defined.
+    casting_test<SrcScalar, half>();
+    casting_test<SrcScalar, bfloat16>();
+    casting_test<SrcScalar, std::complex<float> >();
+    casting_test<SrcScalar, std::complex<double> >();
+  }
+};
+
+void casting_all() {
+  casting_test_runner<bool>::run();
+  casting_test_runner<int8_t>::run();
+  casting_test_runner<uint8_t>::run();
+  casting_test_runner<int16_t>::run();
+  casting_test_runner<uint16_t>::run();
+  casting_test_runner<int32_t>::run();
+  casting_test_runner<uint32_t>::run();
+  casting_test_runner<int64_t>::run();
+  casting_test_runner<uint64_t>::run();
+  casting_test_runner<half>::run();
+  casting_test_runner<bfloat16>::run();
+  casting_test_runner<float>::run();
+  casting_test_runner<double>::run();
+  casting_test_runner<std::complex<float> >::run();
+  casting_test_runner<std::complex<double> >::run();
+}

 template<typename Scalar>
@@ -210,12 +269,12 @@ void fixedSizeMatrixConstruction()
 {
   Scalar raw[4];
   for(int k=0; k<4; ++k)
     raw[k] = internal::random<Scalar>();
-  
+
   {
     Matrix<Scalar,4,1> m(raw);
     Array<Scalar,4,1> a(raw);
     for(int k=0; k<4; ++k) VERIFY(m(k) == raw[k]);
-    for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]);
+    for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]);
     VERIFY_IS_EQUAL(m,(Matrix<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3])));
     VERIFY((a==(Array<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3]))).all());
   }
@@ -277,6 +336,7 @@ EIGEN_DECLARE_TEST(basicstuff)
     CALL_SUBTEST_5( basicStuff(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
     CALL_SUBTEST_6( basicStuff(Matrix()) );
     CALL_SUBTEST_7( basicStuff(Matrix(internal::random<int>(1,EIGEN_TEST_MAX_SIZE),internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
+    CALL_SUBTEST_8( casting_all() );

     CALL_SUBTEST_3( basicStuffComplex(MatrixXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
     CALL_SUBTEST_5( basicStuffComplex(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
@@ -288,6 +348,4 @@ EIGEN_DECLARE_TEST(basicstuff)
   CALL_SUBTEST_1(fixedSizeMatrixConstruction());
   CALL_SUBTEST_1(fixedSizeMatrixConstruction());
   CALL_SUBTEST_1(fixedSizeMatrixConstruction());
-
-  CALL_SUBTEST_2(casting<0>());
 }
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index dbc1d3f5a..7821877db 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -8,8 +8,8 @@
 // Public License v. 2.0. If a copy of the MPL was not distributed
 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

-#include
 #include "packetmath_test_shared.h"
+#include "random_without_cast_overflow.h"

 template <typename T>
 inline T REF_ADD(const T& a, const T& b) {
@@ -126,129 +126,6 @@ struct test_cast_helper
-template <typename SrcScalar, typename TgtScalar, typename EnableIf = void>
-struct random_without_cast_overflow {
-  static SrcScalar value() { return internal::random<SrcScalar>(); }
-};
-
-// Widening integer cast signed to unsigned.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
-                                 !NumTraits<TgtScalar>::IsSigned &&
-                                 (std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits ||
-                                  (std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits &&
-                                   NumTraits<SrcScalar>::IsSigned))>::type> {
-  static SrcScalar value() {
-    SrcScalar a = internal::random<SrcScalar>();
-    return a < SrcScalar(0) ? -(a + 1) : a;
-  }
-};
-
-// Narrowing integer cast to unsigned.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<
-        NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned &&
-        (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
-  static SrcScalar value() {
-    TgtScalar b = internal::random<TgtScalar>();
-    return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b);
-  }
-};
-
-// Narrowing integer cast to signed.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<
-        NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned &&
-        (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
-  static SrcScalar value() {
-    TgtScalar b = internal::random<TgtScalar>();
-    return static_cast<SrcScalar>(b);
-  }
-};
-
-// Unsigned to signed narrowing cast.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
-                                 !NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned &&
-                                 (std::numeric_limits<SrcScalar>::digits ==
-                                  std::numeric_limits<TgtScalar>::digits)>::type> {
-  static SrcScalar value() { return internal::random<SrcScalar>() / 2; }
-};
-
-template <typename T>
-struct is_floating_point {
-  enum { value = 0 };
-};
-template <>
-struct is_floating_point<float> {
-  enum { value = 1 };
-};
-template <>
-struct is_floating_point<double> {
-  enum { value = 1 };
-};
-template <>
-struct is_floating_point<half> {
-  enum { value = 1 };
-};
-template <>
-struct is_floating_point<bfloat16> {
-  enum { value = 1 };
-};
-
-// Floating-point to integer, full precision.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger &&
-                                 (std::numeric_limits<TgtScalar>::digits <=
-                                  std::numeric_limits<SrcScalar>::digits)>::type> {
-  static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
-};
-
-// Floating-point to integer, narrowing precision.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger &&
-                                 (std::numeric_limits<TgtScalar>::digits >
-                                  std::numeric_limits<SrcScalar>::digits)>::type> {
-  static SrcScalar value() {
-    static const int BitShift = std::numeric_limits<TgtScalar>::digits - std::numeric_limits<SrcScalar>::digits;
-    return static_cast<SrcScalar>(internal::random<TgtScalar>() >> BitShift);
-  }
-};
-
-// Floating-point target from integer, re-use above logic.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && is_floating_point<TgtScalar>::value>::type> {
-  static SrcScalar value() {
-    return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value());
-  }
-};
-
-// Floating-point narrowing conversion.
-template <typename SrcScalar, typename TgtScalar>
-struct random_without_cast_overflow<
-    SrcScalar, TgtScalar,
-    typename internal::enable_if<is_floating_point<SrcScalar>::value && is_floating_point<TgtScalar>::value &&
-                                 (std::numeric_limits<SrcScalar>::digits >
-                                  std::numeric_limits<TgtScalar>::digits)>::type> {
-  static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
-};
-
 template
 struct test_cast_helper {
   static void run() {
@@ -266,10 +143,12 @@ struct test_cast_helper
     for (int i = 0; i < DataSize; ++i) {
-      data1[i] = random_without_cast_overflow<SrcScalar, TgtScalar>::value();
+      data1[i] = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value();
     }
-    for (int i = 0; i < DataSize; ++i) ref[i] = static_cast<TgtScalar>(data1[i]);
+    for (int i = 0; i < DataSize; ++i) {
+      ref[i] = static_cast<TgtScalar>(data1[i]);
+    }

     pcast_array<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio>::cast(data1, DataSize, data2);
@@ -318,21 +197,37 @@ struct test_cast_runner
   static void run() {}
 };

+template <typename Scalar, typename Packet, typename EnableIf = void>
+struct packetmath_pcast_ops_runner {
+  static void run() {
+    test_cast_runner<Packet, float>::run();
+    test_cast_runner<Packet, double>::run();
+    test_cast_runner<Packet, int8_t>::run();
+    test_cast_runner<Packet, uint8_t>::run();
+    test_cast_runner<Packet, int16_t>::run();
+    test_cast_runner<Packet, uint16_t>::run();
+    test_cast_runner<Packet, int32_t>::run();
+    test_cast_runner<Packet, uint32_t>::run();
+    test_cast_runner<Packet, int64_t>::run();
+    test_cast_runner<Packet, uint64_t>::run();
+    test_cast_runner<Packet, bool>::run();
+    test_cast_runner<Packet, std::complex<float> >::run();
+    test_cast_runner<Packet, std::complex<double> >::run();
+    test_cast_runner<Packet, half>::run();
+    test_cast_runner<Packet, bfloat16>::run();
+  }
+};
+
+// Only some types support cast from std::complex<>.
 template <typename Scalar, typename Packet>
-void packetmath_pcast_ops() {
-  test_cast_runner<Packet, float>::run();
-  test_cast_runner<Packet, double>::run();
-  test_cast_runner<Packet, int8_t>::run();
-  test_cast_runner<Packet, uint8_t>::run();
-  test_cast_runner<Packet, int16_t>::run();
-  test_cast_runner<Packet, uint16_t>::run();
-  test_cast_runner<Packet, int32_t>::run();
-  test_cast_runner<Packet, uint32_t>::run();
-  test_cast_runner<Packet, int64_t>::run();
-  test_cast_runner<Packet, uint64_t>::run();
-  test_cast_runner<Packet, bool>::run();
-  test_cast_runner<Packet, half>::run();
-}
+struct packetmath_pcast_ops_runner<Scalar, Packet, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
+  static void run() {
+    test_cast_runner<Packet, std::complex<float> >::run();
+    test_cast_runner<Packet, std::complex<double> >::run();
+    test_cast_runner<Packet, half>::run();
+    test_cast_runner<Packet, bfloat16>::run();
+  }
+};

 template <typename Scalar, typename Packet>
 void packetmath_boolean_mask_ops() {
@@ -356,10 +251,8 @@ void packetmath_boolean_mask_ops() {

 // Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
 // (for some compilers) compute the bitwise and with 0x1 of the results to keep the value in [0,1].
-#ifdef EIGEN_PACKET_MATH_SSE_H
-template <>
-void packetmath_boolean_mask_ops<bool, Packet16b>() {}
-#endif
+template<>
+void packetmath_boolean_mask_ops<bool, internal::packet_traits<bool>::type>() {}

 template <typename Scalar, typename Packet>
 void packetmath() {
@@ -560,7 +453,7 @@ void packetmath() {
   CHECK_CWISE2_IF(true, internal::pand, internal::pand);

   packetmath_boolean_mask_ops<Scalar, Packet>();
-  packetmath_pcast_ops<Scalar, Packet>();
+  packetmath_pcast_ops_runner<Scalar, Packet>::run();
 }

 template <typename Scalar, typename Packet>
@@ -975,9 +868,7 @@ EIGEN_DECLARE_TEST(packetmath) {
     CALL_SUBTEST_11(test::runner<std::complex<float> >::run());
     CALL_SUBTEST_12(test::runner<std::complex<double> >::run());
     CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>()));
-#ifdef EIGEN_PACKET_MATH_SSE_H
     CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>()));
-#endif
     CALL_SUBTEST_15((packetmath<bfloat16, internal::packet_traits<bfloat16>::type>()));
     g_first_pass = false;
   }
diff --git a/test/random_without_cast_overflow.h b/test/random_without_cast_overflow.h
new file mode 100644
index 000000000..000345110
--- /dev/null
+++ b/test/random_without_cast_overflow.h
@@ -0,0 +1,152 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2020 C. Antonio Sanchez
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// Utilities for generating random numbers without overflows, which might
+// otherwise result in undefined behavior.
+
+namespace Eigen {
+namespace internal {
+
+// Default implementation assuming SrcScalar fits into TgtScalar.
+template <typename SrcScalar, typename TgtScalar, typename EnableIf = void>
+struct random_without_cast_overflow {
+  static SrcScalar value() { return internal::random<SrcScalar>(); }
+};
+
+// Signed to unsigned integer widening cast.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
+                                 !NumTraits<TgtScalar>::IsSigned &&
+                                 (std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits ||
+                                  (std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits &&
+                                   NumTraits<SrcScalar>::IsSigned))>::type> {
+  static SrcScalar value() {
+    SrcScalar a = internal::random<SrcScalar>();
+    return a < SrcScalar(0) ? -(a + 1) : a;
+  }
+};
+
+// Integer to unsigned narrowing cast.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<
+        NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned &&
+        (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
+  static SrcScalar value() {
+    TgtScalar b = internal::random<TgtScalar>();
+    return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b);
+  }
+};
+
+// Integer to signed narrowing cast.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<
+        NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned &&
+        (std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
+  static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
+};
+
+// Unsigned to signed integer narrowing cast.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
+                                 !NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned &&
+                                 (std::numeric_limits<SrcScalar>::digits ==
+                                  std::numeric_limits<TgtScalar>::digits)>::type> {
+  static SrcScalar value() { return internal::random<SrcScalar>() / 2; }
+};
+
+// Floating-point to integer, full precision.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<
+        !NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsInteger &&
+        (std::numeric_limits<TgtScalar>::digits <= std::numeric_limits<SrcScalar>::digits)>::type> {
+  static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
+};
+
+// Floating-point to integer, narrowing precision.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<
+        !NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsInteger &&
+        (std::numeric_limits<TgtScalar>::digits > std::numeric_limits<SrcScalar>::digits)>::type> {
+  static SrcScalar value() {
+    // NOTE: internal::random<T>() is limited by RAND_MAX, so random<int64_t> is always within that range.
+    // This prevents us from simply shifting bits, which would result in only 0 or -1.
+    // Instead, keep least-significant K bits and sign.
+    static const TgtScalar KeepMask = (static_cast<TgtScalar>(1) << std::numeric_limits<SrcScalar>::digits) - 1;
+    const TgtScalar a = internal::random<TgtScalar>();
+    return static_cast<SrcScalar>(a > TgtScalar(0) ? (a & KeepMask) : -(a & KeepMask));
+  }
+};
+
+// Integer to floating-point, re-use above logic.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && !NumTraits<TgtScalar>::IsInteger &&
+                                 !NumTraits<TgtScalar>::IsComplex>::type> {
+  static SrcScalar value() {
+    return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value());
+  }
+};
+
+// Floating-point narrowing conversion.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<!NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex &&
+                                 !NumTraits<TgtScalar>::IsInteger && !NumTraits<TgtScalar>::IsComplex &&
+                                 (std::numeric_limits<SrcScalar>::digits >
+                                  std::numeric_limits<TgtScalar>::digits)>::type> {
+  static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
+};
+
+// Complex to non-complex.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<NumTraits<SrcScalar>::IsComplex && !NumTraits<TgtScalar>::IsComplex>::type> {
+  typedef typename NumTraits<SrcScalar>::Real SrcReal;
+  static SrcScalar value() { return SrcScalar(random_without_cast_overflow<SrcReal, TgtScalar>::value(), 0); }
+};
+
+// Non-complex to complex.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<!NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsComplex>::type> {
+  typedef typename NumTraits<TgtScalar>::Real TgtReal;
+  static SrcScalar value() { return random_without_cast_overflow<SrcScalar, TgtReal>::value(); }
+};
+
+// Complex to complex.
+template <typename SrcScalar, typename TgtScalar>
+struct random_without_cast_overflow<
+    SrcScalar, TgtScalar,
+    typename internal::enable_if<NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsComplex>::type> {
+  typedef typename NumTraits<SrcScalar>::Real SrcReal;
+  typedef typename NumTraits<TgtScalar>::Real TgtReal;
+  static SrcScalar value() {
+    return SrcScalar(random_without_cast_overflow<SrcReal, TgtReal>::value(),
+                     random_without_cast_overflow<SrcReal, TgtReal>::value());
+  }
+};
+
+}  // namespace internal
+}  // namespace Eigen
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
index cdbafbbb1..44493906d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
@@ -51,7 +51,10 @@ struct nested<TensorConversionOp<TargetType, XprType>, 1, typename eval<TensorConversionOp<TargetType, XprType> >::type>
 };

 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
-struct PacketConverter {
+struct PacketConverter;
+
+template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
+struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 1> {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
   PacketConverter(const TensorEvaluator& impl)
       : m_impl(impl) {}
@@ -109,7 +112,33 @@ struct PacketConverter
 };

 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
-struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> {
+struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 8, 1> {
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  PacketConverter(const TensorEvaluator& impl)
+      : m_impl(impl) {}
+
+  template <int LoadMode, typename Index>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
+    const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
+
+    SrcPacket src1 = m_impl.template packet<LoadMode>(index);
+    SrcPacket src2 = m_impl.template packet<LoadMode>(index + 1 * SrcPacketSize);
+    SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize);
+    SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize);
+    SrcPacket src5 = m_impl.template packet<LoadMode>(index + 4 * SrcPacketSize);
+    SrcPacket src6 = m_impl.template packet<LoadMode>(index + 5 * SrcPacketSize);
+    SrcPacket src7 = m_impl.template packet<LoadMode>(index + 6 * SrcPacketSize);
+    SrcPacket src8 = m_impl.template packet<LoadMode>(index + 7 * SrcPacketSize);
+    TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4, src5, src6, src7, src8);
+    return result;
+  }
+
+ private:
+  const TensorEvaluator& m_impl;
+};
+
+template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
+struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
   PacketConverter(const TensorEvaluator& impl)
       : m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {}
diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp
index c4fe9a798..45456f3ef 100644
--- a/unsupported/test/cxx11_tensor_casts.cpp
+++ b/unsupported/test/cxx11_tensor_casts.cpp
@@ -8,6 +8,7 @@
 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

 #include "main.h"
+#include "random_without_cast_overflow.h"

 #include <Eigen/CXX11/Tensor>

@@ -104,12 +105,82 @@ static void test_small_to_big_type_cast()
   }
 }

+template <typename FromType, typename ToType>
+static void test_type_cast() {
+  Tensor<FromType, 2> ftensor(100, 200);
+  // Generate random values for a valid cast.
+  for (int i = 0; i < 100; ++i) {
+    for (int j = 0; j < 200; ++j) {
+      ftensor(i, j) = internal::random_without_cast_overflow<FromType, ToType>::value();
+    }
+  }
+
+  Tensor<ToType, 2> ttensor(100, 200);
+  ttensor = ftensor.template cast<ToType>();
+
+  for (int i = 0; i < 100; ++i) {
+    for (int j = 0; j < 200; ++j) {
+      const ToType ref = internal::cast<FromType, ToType>(ftensor(i, j));
+      VERIFY_IS_APPROX(ttensor(i, j), ref);
+    }
+  }
+}
+
+template <typename Scalar, typename EnableIf = void>
+struct test_cast_runner {
+  static void run() {
+    test_type_cast<Scalar, bool>();
+    test_type_cast<Scalar, int8_t>();
+    test_type_cast<Scalar, int16_t>();
+    test_type_cast<Scalar, int32_t>();
+    test_type_cast<Scalar, int64_t>();
+    test_type_cast<Scalar, uint8_t>();
+    test_type_cast<Scalar, uint16_t>();
+    test_type_cast<Scalar, uint32_t>();
+    test_type_cast<Scalar, uint64_t>();
+    test_type_cast<Scalar, half>();
+    test_type_cast<Scalar, bfloat16>();
+    test_type_cast<Scalar, float>();
+    test_type_cast<Scalar, double>();
+    test_type_cast<Scalar, std::complex<float> >();
+    test_type_cast<Scalar, std::complex<double> >();
+  }
+};
+
+// Only certain types allow cast from std::complex<>.
+template <typename Scalar>
+struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
+  static void run() {
+    test_type_cast<Scalar, half>();
+    test_type_cast<Scalar, bfloat16>();
+    test_type_cast<Scalar, std::complex<float> >();
+    test_type_cast<Scalar, std::complex<double> >();
+  }
+};
+
 EIGEN_DECLARE_TEST(cxx11_tensor_casts)
 {
-  CALL_SUBTEST(test_simple_cast());
-  CALL_SUBTEST(test_vectorized_cast());
-  CALL_SUBTEST(test_float_to_int_cast());
-  CALL_SUBTEST(test_big_to_small_type_cast());
-  CALL_SUBTEST(test_small_to_big_type_cast());
+  CALL_SUBTEST(test_simple_cast());
+  CALL_SUBTEST(test_vectorized_cast());
+  CALL_SUBTEST(test_float_to_int_cast());
+  CALL_SUBTEST(test_big_to_small_type_cast());
+  CALL_SUBTEST(test_small_to_big_type_cast());
+
+  CALL_SUBTEST(test_cast_runner<bool>::run());
+  CALL_SUBTEST(test_cast_runner<int8_t>::run());
+  CALL_SUBTEST(test_cast_runner<int16_t>::run());
+  CALL_SUBTEST(test_cast_runner<int32_t>::run());
+  CALL_SUBTEST(test_cast_runner<int64_t>::run());
+  CALL_SUBTEST(test_cast_runner<uint8_t>::run());
+  CALL_SUBTEST(test_cast_runner<uint16_t>::run());
+  CALL_SUBTEST(test_cast_runner<uint32_t>::run());
+  CALL_SUBTEST(test_cast_runner<uint64_t>::run());
+  CALL_SUBTEST(test_cast_runner<half>::run());
+  CALL_SUBTEST(test_cast_runner<bfloat16>::run());
+  CALL_SUBTEST(test_cast_runner<float>::run());
+  CALL_SUBTEST(test_cast_runner<double>::run());
+  CALL_SUBTEST(test_cast_runner<std::complex<float> >::run());
+  CALL_SUBTEST(test_cast_runner<std::complex<double> >::run());
 }
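For readers of this patch, a minimal usage sketch (not part of the diff) of the complex conversions it introduces. Following the numpy convention cited in the code comments, the imaginary part is dropped when converting to bfloat16/half and comes back as zero when converting to std::complex:

```cpp
#include <complex>
#include <Eigen/Core>

int main() {
  std::complex<float> z(1.5f, -2.0f);

  // New constructors: complex -> bfloat16/half keep only the real part.
  Eigen::bfloat16 b(z);  // == Eigen::bfloat16(1.5f)
  Eigen::half h(z);      // == Eigen::half(1.5f)

  // New explicit casts: bfloat16/half -> std::complex<T> with zero imaginary part.
  std::complex<double> zb = static_cast<std::complex<double> >(b);  // (1.5, 0.0)
  std::complex<float>  zh = static_cast<std::complex<float> >(h);   // (1.5f, 0.0f)

  return (zb.imag() == 0.0 && zh.real() == 1.5f) ? 0 : 1;
}
```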
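Likewise, a hedged sketch of how the relocated test helper is meant to be used; the header lives under test/ and the struct now sits in Eigen::internal, as the diff shows, and the include path below assumes the test directory is on the include path. The helper returns a source-typed value that can be cast to the target type without overflow or undefined behavior:

```cpp
// Illustrative only; mirrors the usage in basicstuff.cpp and packetmath.cpp above.
#include <cstdint>
#include <Eigen/Core>
#include "random_without_cast_overflow.h"  // test helper added by this patch

template <typename Src, typename Tgt>
void fill_safely(Src* data, int n) {
  for (int i = 0; i < n; ++i) {
    // The returned value is representable in both Src and Tgt,
    // so the static_cast below cannot overflow.
    data[i] = Eigen::internal::random_without_cast_overflow<Src, Tgt>::value();
  }
}

void example() {
  float src[8];
  std::int8_t dst[8];
  fill_safely<float, std::int8_t>(src, 8);
  for (int i = 0; i < 8; ++i) dst[i] = static_cast<std::int8_t>(src[i]);
}
```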
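Finally, the test runners added in basicstuff.cpp, packetmath.cpp, and cxx11_tensor_casts.cpp all share one dispatch idiom: a primary template exercises the full list of target types, and a partial specialization selected by enable_if restricts complex sources to the casts that are actually defined. A stripped-down sketch of that pattern (hypothetical names, using Eigen's internal::enable_if as the diff does):

```cpp
#include <Eigen/Core>

// Primary template: used for real scalars, exercises the full target-type list.
template <typename Scalar, typename EnableIf = void>
struct cast_runner_sketch {
  static void run() { /* all target types */ }
};

// Partial specialization chosen only when Scalar is complex:
// restricts targets to half, bfloat16, and the two std::complex types.
template <typename Scalar>
struct cast_runner_sketch<Scalar,
    typename Eigen::internal::enable_if<Eigen::NumTraits<Scalar>::IsComplex>::type> {
  static void run() { /* half, bfloat16, std::complex<float>, std::complex<double> */ }
};
```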