mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 08:09:36 +08:00
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex<T>`, which was missing, to make it consistent with `Eigen::bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations.
This commit is contained in:
parent
145e51516f
commit
9cb8771e9c
@ -65,13 +65,8 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
|
||||
// Following the convention of numpy, converting between complex and
|
||||
// float will lead to loss of imag value.
|
||||
// Single precision complex.
|
||||
typedef std::complex<float> complex64;
|
||||
// Double precision complex.
|
||||
typedef std::complex<double> complex128;
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {}
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val)
|
||||
template<typename RealScalar>
|
||||
explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
|
||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
||||
@ -114,11 +109,9 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
|
||||
return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const {
|
||||
return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const {
|
||||
return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0));
|
||||
template<typename RealScalar>
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const {
|
||||
return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const {
|
||||
return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this));
|
||||
|
@ -86,7 +86,7 @@ struct half_base : public __half_raw {
|
||||
#if (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000)
|
||||
EIGEN_DEVICE_FUNC half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
@ -133,6 +133,11 @@ struct half : public half_impl::half_base {
|
||||
: half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
|
||||
explicit EIGEN_DEVICE_FUNC half(float f)
|
||||
: half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
|
||||
// Construct a half from a complex number of any real scalar type.
// Following the convention of numpy, converting from complex to a real
|
||||
// type keeps only the real part (cast to float first); the imag value is lost.
|
||||
template<typename RealScalar>
|
||||
explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
|
||||
: half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
||||
// +0.0 and -0.0 become false, everything else becomes true.
|
||||
@ -174,6 +179,11 @@ struct half : public half_impl::half_base {
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
|
||||
return static_cast<double>(half_impl::half_to_float(*this));
|
||||
}
|
||||
|
||||
// Explicit conversion to std::complex<RealScalar> for any RealScalar:
// the real part is this value cast to RealScalar, the imaginary part is zero.
template<typename RealScalar>
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const {
|
||||
return std::complex<RealScalar>(static_cast<RealScalar>(*this), RealScalar(0));
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -10,6 +10,7 @@
|
||||
#define EIGEN_NO_STATIC_ASSERT
|
||||
|
||||
#include "main.h"
|
||||
#include "random_without_cast_overflow.h"
|
||||
|
||||
template<typename MatrixType> void basicStuff(const MatrixType& m)
|
||||
{
|
||||
@ -90,7 +91,7 @@ template<typename MatrixType> void basicStuff(const MatrixType& m)
|
||||
Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> cv(rows);
|
||||
rv = square.row(r);
|
||||
cv = square.col(r);
|
||||
|
||||
|
||||
VERIFY_IS_APPROX(rv, cv.transpose());
|
||||
|
||||
if(cols!=1 && rows!=1 && MatrixType::SizeAtCompileTime!=Dynamic)
|
||||
@ -120,28 +121,28 @@ template<typename MatrixType> void basicStuff(const MatrixType& m)
|
||||
m1 = m2;
|
||||
VERIFY(m1==m2);
|
||||
VERIFY(!(m1!=m2));
|
||||
|
||||
|
||||
// check automatic transposition
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i) = sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,sm1.transpose());
|
||||
|
||||
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i).noalias() = sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,sm1.transpose());
|
||||
|
||||
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i).noalias() += sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,sm1.transpose());
|
||||
|
||||
|
||||
sm2.setZero();
|
||||
for(Index i=0;i<rows;++i)
|
||||
sm2.col(i).noalias() -= sm1.row(i);
|
||||
VERIFY_IS_APPROX(sm2,-sm1.transpose());
|
||||
|
||||
|
||||
// check ternary usage
|
||||
{
|
||||
bool b = internal::random<int>(0,10)>5;
|
||||
@ -194,14 +195,72 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m)
|
||||
VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
|
||||
}
|
||||
|
||||
template<int>
|
||||
void casting()
|
||||
template<typename SrcScalar, typename TgtScalar>
|
||||
void casting_test()
|
||||
{
|
||||
Matrix4f m = Matrix4f::Random(), m2;
|
||||
Matrix4d n = m.cast<double>();
|
||||
VERIFY(m.isApprox(n.cast<float>()));
|
||||
m2 = m.cast<float>(); // check the specialization when NewType == Type
|
||||
VERIFY(m.isApprox(m2));
|
||||
Matrix<SrcScalar,4,4> m;
|
||||
for (int i=0; i<m.rows(); ++i) {
|
||||
for (int j=0; j<m.cols(); ++j) {
|
||||
m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
|
||||
}
|
||||
}
|
||||
Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
|
||||
for (int i=0; i<m.rows(); ++i) {
|
||||
for (int j=0; j<m.cols(); ++j) {
|
||||
VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs casting_test from SrcScalar to every target scalar type listed in
// run(). This primary template is used when the complex-source
// specialization (selected through EnableIf) does not apply.
template<typename SrcScalar, typename EnableIf = void>
|
||||
struct casting_test_runner {
|
||||
static void run() {
|
||||
casting_test<SrcScalar, bool>();
|
||||
casting_test<SrcScalar, int8_t>();
|
||||
casting_test<SrcScalar, uint8_t>();
|
||||
casting_test<SrcScalar, int16_t>();
|
||||
casting_test<SrcScalar, uint16_t>();
|
||||
casting_test<SrcScalar, int32_t>();
|
||||
casting_test<SrcScalar, uint32_t>();
|
||||
casting_test<SrcScalar, int64_t>();
|
||||
casting_test<SrcScalar, uint64_t>();
|
||||
casting_test<SrcScalar, half>();
|
||||
casting_test<SrcScalar, bfloat16>();
|
||||
casting_test<SrcScalar, float>();
|
||||
casting_test<SrcScalar, double>();
|
||||
casting_test<SrcScalar, std::complex<float>>();
|
||||
casting_test<SrcScalar, std::complex<double>>();
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization chosen when NumTraits<SrcScalar>::IsComplex is set: a complex
// source is only tested against half, bfloat16 and the two complex targets.
template<typename SrcScalar>
|
||||
struct casting_test_runner<SrcScalar, typename internal::enable_if<(NumTraits<SrcScalar>::IsComplex)>::type>
|
||||
{
|
||||
static void run() {
|
||||
// Only a few casts from std::complex<T> are defined.
|
||||
casting_test<SrcScalar, half>();
|
||||
casting_test<SrcScalar, bfloat16>();
|
||||
casting_test<SrcScalar, std::complex<float>>();
|
||||
casting_test<SrcScalar, std::complex<double>>();
|
||||
}
|
||||
};
|
||||
|
||||
// Entry point for the cast tests: invokes casting_test_runner<Src>::run() once
// for each source scalar type listed below.
void casting_all() {
|
||||
casting_test_runner<bool>::run();
|
||||
casting_test_runner<int8_t>::run();
|
||||
casting_test_runner<uint8_t>::run();
|
||||
casting_test_runner<int16_t>::run();
|
||||
casting_test_runner<uint16_t>::run();
|
||||
casting_test_runner<int32_t>::run();
|
||||
casting_test_runner<uint32_t>::run();
|
||||
casting_test_runner<int64_t>::run();
|
||||
casting_test_runner<uint64_t>::run();
|
||||
casting_test_runner<half>::run();
|
||||
casting_test_runner<bfloat16>::run();
|
||||
casting_test_runner<float>::run();
|
||||
casting_test_runner<double>::run();
|
||||
casting_test_runner<std::complex<float>>::run();
|
||||
casting_test_runner<std::complex<double>>::run();
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
@ -210,12 +269,12 @@ void fixedSizeMatrixConstruction()
|
||||
Scalar raw[4];
|
||||
for(int k=0; k<4; ++k)
|
||||
raw[k] = internal::random<Scalar>();
|
||||
|
||||
|
||||
{
|
||||
Matrix<Scalar,4,1> m(raw);
|
||||
Array<Scalar,4,1> a(raw);
|
||||
for(int k=0; k<4; ++k) VERIFY(m(k) == raw[k]);
|
||||
for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]);
|
||||
for(int k=0; k<4; ++k) VERIFY(a(k) == raw[k]);
|
||||
VERIFY_IS_EQUAL(m,(Matrix<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3])));
|
||||
VERIFY((a==(Array<Scalar,4,1>(raw[0],raw[1],raw[2],raw[3]))).all());
|
||||
}
|
||||
@ -277,6 +336,7 @@ EIGEN_DECLARE_TEST(basicstuff)
|
||||
CALL_SUBTEST_5( basicStuff(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_6( basicStuff(Matrix<float, 100, 100>()) );
|
||||
CALL_SUBTEST_7( basicStuff(Matrix<long double,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE),internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_8( casting_all() );
|
||||
|
||||
CALL_SUBTEST_3( basicStuffComplex(MatrixXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_5( basicStuffComplex(MatrixXcd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
@ -288,6 +348,4 @@ EIGEN_DECLARE_TEST(basicstuff)
|
||||
CALL_SUBTEST_1(fixedSizeMatrixConstruction<int>());
|
||||
CALL_SUBTEST_1(fixedSizeMatrixConstruction<long int>());
|
||||
CALL_SUBTEST_1(fixedSizeMatrixConstruction<std::ptrdiff_t>());
|
||||
|
||||
CALL_SUBTEST_2(casting<0>());
|
||||
}
|
||||
|
@ -8,8 +8,8 @@
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <limits>
|
||||
#include "packetmath_test_shared.h"
|
||||
#include "random_without_cast_overflow.h"
|
||||
|
||||
template <typename T>
|
||||
inline T REF_ADD(const T& a, const T& b) {
|
||||
@ -126,129 +126,6 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, fals
|
||||
static void run() {}
|
||||
};
|
||||
|
||||
// Generates random values that fit in both SrcScalar and TgtScalar without
|
||||
// overflowing when cast.
|
||||
template <typename SrcScalar, typename TgtScalar, typename EnableIf = void>
|
||||
struct random_without_cast_overflow {
|
||||
static SrcScalar value() { return internal::random<SrcScalar>(); }
|
||||
};
|
||||
|
||||
// Widening integer cast signed to unsigned.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<TgtScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits ||
|
||||
(std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits &&
|
||||
NumTraits<SrcScalar>::IsSigned))>::type> {
|
||||
static SrcScalar value() {
|
||||
SrcScalar a = internal::random<SrcScalar>();
|
||||
return a < SrcScalar(0) ? -(a + 1) : a;
|
||||
}
|
||||
};
|
||||
|
||||
// Narrowing integer cast to unsigned.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
TgtScalar b = internal::random<TgtScalar>();
|
||||
return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b);
|
||||
}
|
||||
};
|
||||
|
||||
// Narrowing integer cast to signed.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
TgtScalar b = internal::random<TgtScalar>();
|
||||
return static_cast<SrcScalar>(b);
|
||||
}
|
||||
};
|
||||
|
||||
// Unsigned to signed narrowing cast.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits ==
|
||||
std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return internal::random<SrcScalar>() / 2; }
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct is_floating_point {
|
||||
enum { value = 0 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<float> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<double> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<half> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
template <>
|
||||
struct is_floating_point<bfloat16> {
|
||||
enum { value = 1 };
|
||||
};
|
||||
|
||||
// Floating-point to integer, full precision.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger &&
|
||||
(std::numeric_limits<TgtScalar>::digits <=
|
||||
std::numeric_limits<SrcScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
// Floating-point to integer, narrowing precision.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<is_floating_point<SrcScalar>::value && NumTraits<TgtScalar>::IsInteger &&
|
||||
(std::numeric_limits<TgtScalar>::digits >
|
||||
std::numeric_limits<SrcScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
static const int BitShift = std::numeric_limits<TgtScalar>::digits - std::numeric_limits<SrcScalar>::digits;
|
||||
return static_cast<SrcScalar>(internal::random<TgtScalar>() >> BitShift);
|
||||
}
|
||||
};
|
||||
|
||||
// Floating-point target from integer, re-use above logic.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && is_floating_point<TgtScalar>::value>::type> {
|
||||
static SrcScalar value() {
|
||||
return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value());
|
||||
}
|
||||
};
|
||||
|
||||
// Floating-point narrowing conversion.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<is_floating_point<SrcScalar>::value && is_floating_point<TgtScalar>::value &&
|
||||
(std::numeric_limits<SrcScalar>::digits >
|
||||
std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
|
||||
struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true> {
|
||||
static void run() {
|
||||
@ -266,10 +143,12 @@ struct test_cast_helper<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio, true
|
||||
|
||||
// Construct a packet of scalars that will not overflow when casting
|
||||
for (int i = 0; i < DataSize; ++i) {
|
||||
data1[i] = random_without_cast_overflow<SrcScalar, TgtScalar>::value();
|
||||
data1[i] = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value();
|
||||
}
|
||||
|
||||
for (int i = 0; i < DataSize; ++i) ref[i] = static_cast<const TgtScalar>(data1[i]);
|
||||
for (int i = 0; i < DataSize; ++i) {
|
||||
ref[i] = static_cast<const TgtScalar>(data1[i]);
|
||||
}
|
||||
|
||||
pcast_array<SrcPacket, TgtPacket, SrcCoeffRatio, TgtCoeffRatio>::cast(data1, DataSize, data2);
|
||||
|
||||
@ -318,21 +197,37 @@ struct test_cast_runner<SrcPacket, TgtScalar, TgtPacket, false, false> {
|
||||
static void run() {}
|
||||
};
|
||||
|
||||
// Runs test_cast_runner for Packet against each target scalar type listed in
// run(). EnableIf lets a partial specialization replace this list for
// particular Scalar types.
template <typename Scalar, typename Packet, typename EnableIf = void>
|
||||
struct packetmath_pcast_ops_runner {
|
||||
static void run() {
|
||||
test_cast_runner<Packet, float>::run();
|
||||
test_cast_runner<Packet, double>::run();
|
||||
test_cast_runner<Packet, int8_t>::run();
|
||||
test_cast_runner<Packet, uint8_t>::run();
|
||||
test_cast_runner<Packet, int16_t>::run();
|
||||
test_cast_runner<Packet, uint16_t>::run();
|
||||
test_cast_runner<Packet, int32_t>::run();
|
||||
test_cast_runner<Packet, uint32_t>::run();
|
||||
test_cast_runner<Packet, int64_t>::run();
|
||||
test_cast_runner<Packet, uint64_t>::run();
|
||||
test_cast_runner<Packet, bool>::run();
|
||||
test_cast_runner<Packet, std::complex<float>>::run();
|
||||
test_cast_runner<Packet, std::complex<double>>::run();
|
||||
test_cast_runner<Packet, half>::run();
|
||||
test_cast_runner<Packet, bfloat16>::run();
|
||||
}
|
||||
};
|
||||
|
||||
// Only some types support cast from std::complex<>.
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath_pcast_ops() {
|
||||
test_cast_runner<Packet, float>::run();
|
||||
test_cast_runner<Packet, double>::run();
|
||||
test_cast_runner<Packet, int8_t>::run();
|
||||
test_cast_runner<Packet, uint8_t>::run();
|
||||
test_cast_runner<Packet, int16_t>::run();
|
||||
test_cast_runner<Packet, uint16_t>::run();
|
||||
test_cast_runner<Packet, int32_t>::run();
|
||||
test_cast_runner<Packet, uint32_t>::run();
|
||||
test_cast_runner<Packet, int64_t>::run();
|
||||
test_cast_runner<Packet, uint64_t>::run();
|
||||
test_cast_runner<Packet, bool>::run();
|
||||
test_cast_runner<Packet, half>::run();
|
||||
}
|
||||
struct packetmath_pcast_ops_runner<Scalar, Packet, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
|
||||
static void run() {
|
||||
test_cast_runner<Packet, std::complex<float>>::run();
|
||||
test_cast_runner<Packet, std::complex<double>>::run();
|
||||
test_cast_runner<Packet, half>::run();
|
||||
test_cast_runner<Packet, bfloat16>::run();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath_boolean_mask_ops() {
|
||||
@ -356,10 +251,8 @@ void packetmath_boolean_mask_ops() {
|
||||
|
||||
// Packet16b representing bool does not support ptrue, pandnot or pcmp_eq, since the scalar path
|
||||
// (for some compilers) computes the bitwise AND of the results with 0x1 to keep the value in [0,1].
|
||||
#ifdef EIGEN_PACKET_MATH_SSE_H
|
||||
template <>
|
||||
void packetmath_boolean_mask_ops<bool, internal::Packet16b>() {}
|
||||
#endif
|
||||
template<>
|
||||
void packetmath_boolean_mask_ops<bool, typename internal::packet_traits<bool>::type>() {}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
void packetmath() {
|
||||
@ -560,7 +453,7 @@ void packetmath() {
|
||||
CHECK_CWISE2_IF(true, internal::pand, internal::pand);
|
||||
|
||||
packetmath_boolean_mask_ops<Scalar, Packet>();
|
||||
packetmath_pcast_ops<Scalar, Packet>();
|
||||
packetmath_pcast_ops_runner<Scalar, Packet>::run();
|
||||
}
|
||||
|
||||
template <typename Scalar, typename Packet>
|
||||
@ -975,9 +868,7 @@ EIGEN_DECLARE_TEST(packetmath) {
|
||||
CALL_SUBTEST_11(test::runner<std::complex<float> >::run());
|
||||
CALL_SUBTEST_12(test::runner<std::complex<double> >::run());
|
||||
CALL_SUBTEST_13((packetmath<half, internal::packet_traits<half>::type>()));
|
||||
#ifdef EIGEN_PACKET_MATH_SSE_H
|
||||
CALL_SUBTEST_14((packetmath<bool, internal::packet_traits<bool>::type>()));
|
||||
#endif
|
||||
CALL_SUBTEST_15((packetmath<bfloat16, internal::packet_traits<bfloat16>::type>()));
|
||||
g_first_pass = false;
|
||||
}
|
||||
|
152
test/random_without_cast_overflow.h
Normal file
152
test/random_without_cast_overflow.h
Normal file
@ -0,0 +1,152 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2020 C. Antonio Sanchez <cantonios@google.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Utilities for generating random numbers without overflows, which might
|
||||
// otherwise result in undefined behavior.
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
// Default implementation assuming SrcScalar fits into TgtScalar: value()
// returns an unconstrained internal::random<SrcScalar>().
// EnableIf is the slot that partial specializations fill with
// internal::enable_if<...>::type to take over for specific type pairs.
|
||||
template <typename SrcScalar, typename TgtScalar, typename EnableIf = void>
|
||||
struct random_without_cast_overflow {
|
||||
static SrcScalar value() { return internal::random<SrcScalar>(); }
|
||||
};
|
||||
|
||||
// Signed to unsigned integer widening cast.
// Selected when both types are integers, TgtScalar is unsigned, and SrcScalar
// has either fewer digits than TgtScalar or the same number of digits while
// being signed. The condition also admits an unsigned SrcScalar with fewer
// digits, for which the negative branch in value() is never taken.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<TgtScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits < std::numeric_limits<TgtScalar>::digits ||
|
||||
(std::numeric_limits<SrcScalar>::digits == std::numeric_limits<TgtScalar>::digits &&
|
||||
NumTraits<SrcScalar>::IsSigned))>::type> {
|
||||
static SrcScalar value() {
|
||||
SrcScalar a = internal::random<SrcScalar>();
|
||||
// A negative draw is replaced by -(a + 1), which is non-negative.
return a < SrcScalar(0) ? -(a + 1) : a;
|
||||
}
|
||||
};
|
||||
|
||||
// Unsigned integer source, narrowing cast to an integer with fewer digits.
// Note that the condition constrains SrcScalar (not TgtScalar) to be unsigned.
// The value is drawn in TgtScalar and widened back to SrcScalar.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && !NumTraits<SrcScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
TgtScalar b = internal::random<TgtScalar>();
|
||||
// A negative draw is replaced by -(b + 1), which is non-negative, before
// the conversion to the unsigned SrcScalar.
return static_cast<SrcScalar>(b < TgtScalar(0) ? -(b + 1) : b);
|
||||
}
|
||||
};
|
||||
|
||||
// Signed integer source, narrowing cast to an integer with fewer digits.
// Note that the condition constrains SrcScalar (not TgtScalar) to be signed.
// The value is drawn as a random TgtScalar and converted to SrcScalar.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger && NumTraits<SrcScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits > std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
// Unsigned to signed integer narrowing cast, for the case where both types
// report the same std::numeric_limits<>::digits. Returns half of a random
// SrcScalar.
// NOTE(review): an unsigned type normally has one more digit than the signed
// type of the same width, so this condition may never hold for the standard
// fixed-width integers (that case matches the "digits >" specialization
// instead) - confirm.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<SrcScalar>::IsSigned && NumTraits<TgtScalar>::IsSigned &&
|
||||
(std::numeric_limits<SrcScalar>::digits ==
|
||||
std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return internal::random<SrcScalar>() / 2; }
|
||||
};
|
||||
|
||||
// Floating-point to integer, full precision.
// Selected for a non-integer, non-complex SrcScalar and an integer TgtScalar
// with no more digits than SrcScalar. The value is drawn as a random
// TgtScalar and converted to SrcScalar.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
!NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsInteger &&
|
||||
(std::numeric_limits<TgtScalar>::digits <= std::numeric_limits<SrcScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
// Floating-point to integer, narrowing precision.
// Selected for a non-integer, non-complex SrcScalar and an integer TgtScalar
// with more digits than SrcScalar. The value is drawn as a random TgtScalar,
// reduced to as many low bits as SrcScalar has digits, then converted.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<
|
||||
!NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsInteger &&
|
||||
(std::numeric_limits<TgtScalar>::digits > std::numeric_limits<SrcScalar>::digits)>::type> {
|
||||
static SrcScalar value() {
|
||||
// NOTE: internal::random<T>() is limited by RAND_MAX, so random<int64_t> is always within that range.
|
||||
// This prevents us from simply shifting bits, which would result in only 0 or -1.
|
||||
// Instead, keep least-significant K bits and sign.
|
||||
// KeepMask has the low std::numeric_limits<SrcScalar>::digits bits set.
static const TgtScalar KeepMask = (static_cast<TgtScalar>(1) << std::numeric_limits<SrcScalar>::digits) - 1;
|
||||
const TgtScalar a = internal::random<TgtScalar>();
|
||||
// Positive draws return the masked bits; all other draws return them negated.
return static_cast<SrcScalar>(a > TgtScalar(0) ? (a & KeepMask) : -(a & KeepMask));
|
||||
}
|
||||
};
|
||||
|
||||
// Integer to floating-point, re-use above logic.
// Selected for an integer SrcScalar and a non-integer, non-complex TgtScalar.
// value() asks the reversed pair <TgtScalar, SrcScalar> for a value and
// converts the result back to SrcScalar.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsInteger && !NumTraits<TgtScalar>::IsInteger &&
|
||||
!NumTraits<TgtScalar>::IsComplex>::type> {
|
||||
static SrcScalar value() {
|
||||
return static_cast<SrcScalar>(random_without_cast_overflow<TgtScalar, SrcScalar>::value());
|
||||
}
|
||||
};
|
||||
|
||||
// Floating-point narrowing conversion.
// Selected when both types are non-integer and non-complex and SrcScalar has
// more digits than TgtScalar. The value is drawn as a random TgtScalar and
// converted to SrcScalar.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<!NumTraits<SrcScalar>::IsInteger && !NumTraits<SrcScalar>::IsComplex &&
|
||||
!NumTraits<TgtScalar>::IsInteger && !NumTraits<TgtScalar>::IsComplex &&
|
||||
(std::numeric_limits<SrcScalar>::digits >
|
||||
std::numeric_limits<TgtScalar>::digits)>::type> {
|
||||
static SrcScalar value() { return static_cast<SrcScalar>(internal::random<TgtScalar>()); }
|
||||
};
|
||||
|
||||
// Complex to non-complex.
// The real part comes from the <SrcReal, TgtScalar> generator, where SrcReal
// is NumTraits<SrcScalar>::Real; the imaginary part is set to 0.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsComplex && !NumTraits<TgtScalar>::IsComplex>::type> {
|
||||
typedef typename NumTraits<SrcScalar>::Real SrcReal;
|
||||
static SrcScalar value() { return SrcScalar(random_without_cast_overflow<SrcReal, TgtScalar>::value(), 0); }
|
||||
};
|
||||
|
||||
// Non-complex to complex.
// Forwards to the <SrcScalar, TgtReal> generator, where TgtReal is
// NumTraits<TgtScalar>::Real.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<!NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsComplex>::type> {
|
||||
typedef typename NumTraits<TgtScalar>::Real TgtReal;
|
||||
static SrcScalar value() { return random_without_cast_overflow<SrcScalar, TgtReal>::value(); }
|
||||
};
|
||||
|
||||
// Complex to complex.
// Real and imaginary parts are each drawn separately from the
// <SrcReal, TgtReal> generator of the underlying real types.
|
||||
template <typename SrcScalar, typename TgtScalar>
|
||||
struct random_without_cast_overflow<
|
||||
SrcScalar, TgtScalar,
|
||||
typename internal::enable_if<NumTraits<SrcScalar>::IsComplex && NumTraits<TgtScalar>::IsComplex>::type> {
|
||||
typedef typename NumTraits<SrcScalar>::Real SrcReal;
|
||||
typedef typename NumTraits<TgtScalar>::Real TgtReal;
|
||||
static SrcScalar value() {
|
||||
return SrcScalar(random_without_cast_overflow<SrcReal, TgtReal>::value(),
|
||||
random_without_cast_overflow<SrcReal, TgtReal>::value());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
@ -51,7 +51,10 @@ struct nested<TensorConversionOp<TargetType, XprType>, 1, typename eval<TensorCo
|
||||
|
||||
|
||||
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
|
||||
struct PacketConverter {
|
||||
struct PacketConverter;
|
||||
|
||||
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
|
||||
struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 1> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
PacketConverter(const TensorEvaluator& impl)
|
||||
: m_impl(impl) {}
|
||||
@ -109,7 +112,33 @@ struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 4, 1> {
|
||||
};
|
||||
|
||||
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
|
||||
struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> {
|
||||
struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 8, 1> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
PacketConverter(const TensorEvaluator& impl)
|
||||
: m_impl(impl) {}
|
||||
|
||||
template<int LoadMode, typename Index>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
|
||||
const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
|
||||
|
||||
SrcPacket src1 = m_impl.template packet<LoadMode>(index);
|
||||
SrcPacket src2 = m_impl.template packet<LoadMode>(index + 1 * SrcPacketSize);
|
||||
SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize);
|
||||
SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize);
|
||||
SrcPacket src5 = m_impl.template packet<LoadMode>(index + 4 * SrcPacketSize);
|
||||
SrcPacket src6 = m_impl.template packet<LoadMode>(index + 5 * SrcPacketSize);
|
||||
SrcPacket src7 = m_impl.template packet<LoadMode>(index + 6 * SrcPacketSize);
|
||||
SrcPacket src8 = m_impl.template packet<LoadMode>(index + 7 * SrcPacketSize);
|
||||
TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4, src5, src6, src7, src8);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
const TensorEvaluator& m_impl;
|
||||
};
|
||||
|
||||
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int TgtCoeffRatio>
|
||||
struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, TgtCoeffRatio> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
PacketConverter(const TensorEvaluator& impl)
|
||||
: m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {}
|
||||
|
@ -8,6 +8,7 @@
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include "main.h"
|
||||
#include "random_without_cast_overflow.h"
|
||||
|
||||
#include <Eigen/CXX11/Tensor>
|
||||
|
||||
@ -104,12 +105,82 @@ static void test_small_to_big_type_cast()
|
||||
}
|
||||
}
|
||||
|
||||
// Fills a 100x200 tensor of FromType with values from
// internal::random_without_cast_overflow<FromType,ToType>, casts the whole
// tensor with cast<ToType>(), and verifies every coefficient against the
// scalar internal::cast<FromType,ToType> of the same input.
template <typename FromType, typename ToType>
|
||||
static void test_type_cast() {
|
||||
Tensor<FromType, 2> ftensor(100, 200);
|
||||
// Generate random values for a valid cast.
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
for (int j = 0; j < 200; ++j) {
|
||||
ftensor(i, j) = internal::random_without_cast_overflow<FromType,ToType>::value();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<ToType, 2> ttensor(100, 200);
|
||||
ttensor = ftensor.template cast<ToType>();
|
||||
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
for (int j = 0; j < 200; ++j) {
|
||||
const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j));
|
||||
VERIFY_IS_APPROX(ttensor(i, j), ref);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs test_type_cast from Scalar to every target type listed in run().
// This primary template is used when the complex-source specialization
// (selected through EnableIf) does not apply.
template<typename Scalar, typename EnableIf = void>
|
||||
struct test_cast_runner {
|
||||
static void run() {
|
||||
test_type_cast<Scalar, bool>();
|
||||
test_type_cast<Scalar, int8_t>();
|
||||
test_type_cast<Scalar, int16_t>();
|
||||
test_type_cast<Scalar, int32_t>();
|
||||
test_type_cast<Scalar, int64_t>();
|
||||
test_type_cast<Scalar, uint8_t>();
|
||||
test_type_cast<Scalar, uint16_t>();
|
||||
test_type_cast<Scalar, uint32_t>();
|
||||
test_type_cast<Scalar, uint64_t>();
|
||||
test_type_cast<Scalar, half>();
|
||||
test_type_cast<Scalar, bfloat16>();
|
||||
test_type_cast<Scalar, float>();
|
||||
test_type_cast<Scalar, double>();
|
||||
test_type_cast<Scalar, std::complex<float>>();
|
||||
test_type_cast<Scalar, std::complex<double>>();
|
||||
}
|
||||
};
|
||||
|
||||
// Only certain types allow cast from std::complex<>.
// Chosen when NumTraits<Scalar>::IsComplex is set; tests casts to half,
// bfloat16, std::complex<float> and std::complex<double> only.
|
||||
template<typename Scalar>
|
||||
struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
|
||||
static void run() {
|
||||
test_type_cast<Scalar, half>();
|
||||
test_type_cast<Scalar, bfloat16>();
|
||||
test_type_cast<Scalar, std::complex<float>>();
|
||||
test_type_cast<Scalar, std::complex<double>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_casts)
|
||||
{
|
||||
CALL_SUBTEST(test_simple_cast());
|
||||
CALL_SUBTEST(test_vectorized_cast());
|
||||
CALL_SUBTEST(test_float_to_int_cast());
|
||||
CALL_SUBTEST(test_big_to_small_type_cast());
|
||||
CALL_SUBTEST(test_small_to_big_type_cast());
|
||||
CALL_SUBTEST(test_simple_cast());
|
||||
CALL_SUBTEST(test_vectorized_cast());
|
||||
CALL_SUBTEST(test_float_to_int_cast());
|
||||
CALL_SUBTEST(test_big_to_small_type_cast());
|
||||
CALL_SUBTEST(test_small_to_big_type_cast());
|
||||
|
||||
CALL_SUBTEST(test_cast_runner<bool>::run());
|
||||
CALL_SUBTEST(test_cast_runner<int8_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<int16_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<int32_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<int64_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<uint8_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<uint16_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<uint32_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<uint64_t>::run());
|
||||
CALL_SUBTEST(test_cast_runner<half>::run());
|
||||
CALL_SUBTEST(test_cast_runner<bfloat16>::run());
|
||||
CALL_SUBTEST(test_cast_runner<float>::run());
|
||||
CALL_SUBTEST(test_cast_runner<double>::run());
|
||||
CALL_SUBTEST(test_cast_runner<std::complex<float>>::run());
|
||||
CALL_SUBTEST(test_cast_runner<std::complex<double>>::run());
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user