mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 08:39:37 +08:00
Revert "Vectorize cast"
This reverts commit eb5ff1861a4783876564a1a79573c3b9ff566863
This commit is contained in:
parent
8999525c29
commit
2d0c6ad873
@ -183,6 +183,7 @@ using std::ptrdiff_t;
|
|||||||
// Generic half float support
|
// Generic half float support
|
||||||
#include "src/Core/arch/Default/Half.h"
|
#include "src/Core/arch/Default/Half.h"
|
||||||
#include "src/Core/arch/Default/BFloat16.h"
|
#include "src/Core/arch/Default/BFloat16.h"
|
||||||
|
#include "src/Core/arch/Default/TypeCasting.h"
|
||||||
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
||||||
|
|
||||||
#if defined EIGEN_VECTORIZE_AVX512
|
#if defined EIGEN_VECTORIZE_AVX512
|
||||||
|
@ -621,110 +621,6 @@ protected:
|
|||||||
Data m_d;
|
Data m_d;
|
||||||
};
|
};
|
||||||
|
|
||||||
// ----------------------- Casting ---------------------
|
|
||||||
template <typename SrcType, typename DstType, typename ArgType>
|
|
||||||
struct unary_evaluator<CwiseUnaryOp<scalar_cast_op<SrcType, DstType>, ArgType>, IndexBased> {
|
|
||||||
using CastOp = scalar_cast_op<SrcType, DstType>;
|
|
||||||
using XprType = CwiseUnaryOp<CastOp, ArgType>;
|
|
||||||
using SrcPacketType = typename packet_traits<SrcType>::type;
|
|
||||||
|
|
||||||
static constexpr int SrcPacketSize = packet_traits<SrcType>::size;
|
|
||||||
static constexpr int SrcPacketSizeBytes = SrcPacketSize * sizeof(SrcType);
|
|
||||||
|
|
||||||
enum {
|
|
||||||
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<CastOp>::Cost),
|
|
||||||
Flags = evaluator<ArgType>::Flags &
|
|
||||||
(HereditaryBits | LinearAccessBit | (functor_traits<CastOp>::PacketAccess ? PacketAccessBit : 0)),
|
|
||||||
Alignment = evaluator<ArgType>::Alignment
|
|
||||||
};
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& xpr)
|
|
||||||
: m_argImpl(xpr.nestedExpression()) {
|
|
||||||
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<CastOp>::Cost);
|
|
||||||
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index row, Index col) const {
|
|
||||||
return CastOp()(m_argImpl.coeff(row, col));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index index) const {
|
|
||||||
return CastOp()(m_argImpl.coeff(index));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int LoadMode>
|
|
||||||
EIGEN_ALWAYS_INLINE SrcPacketType srcPacket(Index row, Index col, Index offset) const {
|
|
||||||
EIGEN_STATIC_ASSERT((LoadMode & (LoadMode - 1)) == 0, LoadMode must be a power of two)
|
|
||||||
constexpr bool ArgIsRowMajor = evaluator<ArgType>::Flags & RowMajorBit;
|
|
||||||
return m_argImpl.template packet<LoadMode, SrcPacketType>(ArgIsRowMajor ? row : row + (offset * SrcPacketSize),
|
|
||||||
ArgIsRowMajor ? col + (offset * SrcPacketSize) : col);
|
|
||||||
}
|
|
||||||
template <int LoadMode>
|
|
||||||
EIGEN_ALWAYS_INLINE SrcPacketType srcPacket(Index index, Index offset) const {
|
|
||||||
EIGEN_STATIC_ASSERT((LoadMode & (LoadMode - 1)) == 0, LoadMode must be a power of two)
|
|
||||||
return m_argImpl.template packet<LoadMode, SrcPacketType>(index + (offset * SrcPacketSize));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename DstPacketType>
|
|
||||||
using SrcPacketArgs1 = std::enable_if_t<unpacket_traits<DstPacketType>::size <= (1 * SrcPacketSize), bool>;
|
|
||||||
template <typename DstPacketType>
|
|
||||||
using SrcPacketArgs2 = std::enable_if_t<unpacket_traits<DstPacketType>::size == (2 * SrcPacketSize), bool>;
|
|
||||||
template <typename DstPacketType>
|
|
||||||
using SrcPacketArgs4 = std::enable_if_t<unpacket_traits<DstPacketType>::size == (4 * SrcPacketSize), bool>;
|
|
||||||
|
|
||||||
template <int LoadMode, typename DstPacketType, SrcPacketArgs1<DstPacketType> = true>
|
|
||||||
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
|
||||||
constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
|
|
||||||
constexpr int SrcIncrementBytes = DstPacketSize * sizeof(SrcType);
|
|
||||||
constexpr int SrcLoadMode = plain_enum_min(SrcIncrementBytes, LoadMode);
|
|
||||||
return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode>(row, col, 0));
|
|
||||||
}
|
|
||||||
template <int LoadMode, typename DstPacketType, SrcPacketArgs2<DstPacketType> = true>
|
|
||||||
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
|
||||||
constexpr int SrcLoadMode0 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode1 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode0>(row, col, 0),
|
|
||||||
srcPacket<SrcLoadMode1>(row, col, 1));
|
|
||||||
}
|
|
||||||
template <int LoadMode, typename DstPacketType, SrcPacketArgs4<DstPacketType> = true>
|
|
||||||
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
|
||||||
constexpr int SrcLoadMode0 = plain_enum_min(4 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode1 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode2 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode3 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
return CastOp().template packetOp<DstPacketType>(
|
|
||||||
srcPacket<SrcLoadMode0>(row, col, 0), srcPacket<SrcLoadMode1>(row, col, 1),
|
|
||||||
srcPacket<SrcLoadMode2>(row, col, 2), srcPacket<SrcLoadMode3>(row, col, 3));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int LoadMode, typename DstPacketType, SrcPacketArgs1<DstPacketType> = true>
|
|
||||||
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
|
||||||
constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
|
|
||||||
constexpr int SrcIncrementBytes = DstPacketSize * sizeof(SrcType);
|
|
||||||
constexpr int SrcLoadMode = plain_enum_min(SrcIncrementBytes, LoadMode);
|
|
||||||
return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode>(index, 0));
|
|
||||||
}
|
|
||||||
template <int LoadMode, typename DstPacketType, SrcPacketArgs2<DstPacketType> = true>
|
|
||||||
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
|
||||||
constexpr int SrcLoadMode0 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode1 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
return CastOp().template packetOp<DstPacketType>(srcPacket<SrcLoadMode0>(index, 0),
|
|
||||||
srcPacket<SrcLoadMode1>(index, 1));
|
|
||||||
}
|
|
||||||
template <int LoadMode, typename DstPacketType, SrcPacketArgs4<DstPacketType> = true>
|
|
||||||
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
|
||||||
constexpr int SrcLoadMode0 = plain_enum_min(4 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode1 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode2 = plain_enum_min(2 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
constexpr int SrcLoadMode3 = plain_enum_min(1 * SrcPacketSizeBytes, LoadMode);
|
|
||||||
return CastOp().template packetOp<DstPacketType>(
|
|
||||||
srcPacket<SrcLoadMode0>(index, 0), srcPacket<SrcLoadMode1>(index, 1),
|
|
||||||
srcPacket<SrcLoadMode2>(index, 2), srcPacket<SrcLoadMode3>(index, 3));
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
const evaluator<ArgType> m_argImpl;
|
|
||||||
};
|
|
||||||
|
|
||||||
// -------------------- CwiseTernaryOp --------------------
|
// -------------------- CwiseTernaryOp --------------------
|
||||||
|
|
||||||
// this is a ternary expression
|
// this is a ternary expression
|
||||||
|
@ -430,12 +430,6 @@ struct cast_impl
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename OldType>
|
|
||||||
struct cast_impl<OldType, bool> {
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
static inline bool run(const OldType& x) { return x != OldType(0); }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
|
// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
|
||||||
// generating warnings on clang. Here we explicitly cast the real component.
|
// generating warnings on clang. Here we explicitly cast the real component.
|
||||||
template<typename OldType, typename NewType>
|
template<typename OldType, typename NewType>
|
||||||
|
@ -62,15 +62,6 @@ struct type_casting_traits<float, bool> {
|
|||||||
TgtCoeffRatio = 1
|
TgtCoeffRatio = 1
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct type_casting_traits<float, double> {
|
|
||||||
enum {
|
|
||||||
VectorizedCast = 1,
|
|
||||||
SrcCoeffRatio = 1,
|
|
||||||
TgtCoeffRatio = 2
|
|
||||||
};
|
|
||||||
};
|
|
||||||
#endif // EIGEN_VECTORIZE_AVX512
|
#endif // EIGEN_VECTORIZE_AVX512
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
|
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
|
||||||
@ -89,10 +80,6 @@ template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet4d, Packet8i>(const Packet4d
|
|||||||
return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
|
return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4d pcast<Packet8f, Packet4d>(const Packet8f& a) {
|
|
||||||
return _mm256_cvtps_pd(_mm256_castps256_ps128(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
|
EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
|
||||||
const Packet8f& b) {
|
const Packet8f& b) {
|
||||||
|
@ -1014,49 +1014,4 @@ struct hash<Eigen::half> {
|
|||||||
} // end namespace std
|
} // end namespace std
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace Eigen {
|
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct cast_impl<float, half> {
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
static inline half run(const float& a) {
|
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
|
||||||
return __float2half(a);
|
|
||||||
#else
|
|
||||||
return Eigen::half(a);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct cast_impl<int, half> {
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
static inline half run(const int& a) {
|
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
|
||||||
return __float2half(static_cast<float>(a));
|
|
||||||
#else
|
|
||||||
return half(static_cast<float>(a));
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct cast_impl<half, float> {
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
static inline float run(const half& a) {
|
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
|
||||||
return __half2float(a);
|
|
||||||
#else
|
|
||||||
return static_cast<float>(a);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace internal
|
|
||||||
} // namespace Eigen
|
|
||||||
|
|
||||||
#endif // EIGEN_HALF_H
|
#endif // EIGEN_HALF_H
|
||||||
|
116
Eigen/src/Core/arch/Default/TypeCasting.h
Normal file
116
Eigen/src/Core/arch/Default/TypeCasting.h
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#ifndef EIGEN_GENERIC_TYPE_CASTING_H
|
||||||
|
#define EIGEN_GENERIC_TYPE_CASTING_H
|
||||||
|
|
||||||
|
#include "../../InternalHeaderCheck.h"
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<float, Eigen::half> {
|
||||||
|
typedef Eigen::half result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
|
return __float2half(a);
|
||||||
|
#else
|
||||||
|
return Eigen::half(a);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<float, Eigen::half> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<int, Eigen::half> {
|
||||||
|
typedef Eigen::half result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
|
return __float2half(static_cast<float>(a));
|
||||||
|
#else
|
||||||
|
return Eigen::half(static_cast<float>(a));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<int, Eigen::half> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<Eigen::half, float> {
|
||||||
|
typedef float result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
|
return __half2float(a);
|
||||||
|
#else
|
||||||
|
return static_cast<float>(a);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<Eigen::half, float> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<float, Eigen::bfloat16> {
|
||||||
|
typedef Eigen::bfloat16 result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
|
||||||
|
return Eigen::bfloat16(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<int, Eigen::bfloat16> {
|
||||||
|
typedef Eigen::bfloat16 result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
|
||||||
|
return Eigen::bfloat16(static_cast<float>(a));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<Eigen::bfloat16, float> {
|
||||||
|
typedef float result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
|
||||||
|
return static_cast<float>(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // EIGEN_GENERIC_TYPE_CASTING_H
|
@ -173,40 +173,22 @@ struct functor_traits<scalar_carg_op<Scalar>> {
|
|||||||
*
|
*
|
||||||
* \sa class CwiseUnaryOp, MatrixBase::cast()
|
* \sa class CwiseUnaryOp, MatrixBase::cast()
|
||||||
*/
|
*/
|
||||||
template <typename SrcType, typename DstType>
|
template<typename Scalar, typename NewType>
|
||||||
struct scalar_cast_op {
|
struct scalar_cast_op {
|
||||||
|
typedef NewType result_type;
|
||||||
using result_type = DstType;
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast<Scalar, NewType>(a); }
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstType operator()(const SrcType& a) const {
|
|
||||||
return cast<SrcType, DstType>(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
using SrcPacket = typename packet_traits<SrcType>::type;
|
|
||||||
|
|
||||||
template <typename DstPacket>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a) const {
|
|
||||||
return pcast<SrcPacket, DstPacket>(a);
|
|
||||||
}
|
|
||||||
template <typename DstPacket>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a, const SrcPacket& b) const {
|
|
||||||
return pcast<SrcPacket, DstPacket>(a, b);
|
|
||||||
}
|
|
||||||
template <typename DstPacket>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DstPacket packetOp(const SrcPacket& a, const SrcPacket& b,
|
|
||||||
const SrcPacket& c, const SrcPacket& d) const {
|
|
||||||
return pcast<SrcPacket, DstPacket>(a, b, c, d);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename SrcType, typename DstType>
|
template <typename Scalar>
|
||||||
struct functor_traits<scalar_cast_op<SrcType, DstType>> {
|
struct scalar_cast_op<Scalar, bool> {
|
||||||
enum {
|
typedef bool result_type;
|
||||||
Cost = is_same<SrcType, DstType>::value ? 0 : NumTraits<DstType>::AddCost,
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { return a != Scalar(0); }
|
||||||
PacketAccess = (type_casting_traits<SrcType, DstType>::VectorizedCast != 0) &&
|
|
||||||
(type_casting_traits<SrcType, DstType>::SrcCoeffRatio <= 4)
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename Scalar, typename NewType>
|
||||||
|
struct functor_traits<scalar_cast_op<Scalar,NewType> >
|
||||||
|
{ enum { Cost = is_same<Scalar, NewType>::value ? 0 : NumTraits<NewType>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
/** \internal
|
/** \internal
|
||||||
* \brief Template functor to arithmetically shift a scalar right by a number of bits
|
* \brief Template functor to arithmetically shift a scalar right by a number of bits
|
||||||
*
|
*
|
||||||
|
@ -9,7 +9,6 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "main.h"
|
#include "main.h"
|
||||||
#include "random_without_cast_overflow.h"
|
|
||||||
|
|
||||||
template <typename Scalar, std::enable_if_t<NumTraits<Scalar>::IsInteger,int> = 0>
|
template <typename Scalar, std::enable_if_t<NumTraits<Scalar>::IsInteger,int> = 0>
|
||||||
std::vector<Scalar> special_values() {
|
std::vector<Scalar> special_values() {
|
||||||
@ -1184,59 +1183,6 @@ void typed_logicals_test(const ArrayType& m) {
|
|||||||
typed_logicals_test_impl<ArrayType>::run(m);
|
typed_logicals_test_impl<ArrayType>::run(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcType, typename DstType>
|
|
||||||
struct cast_test_impl {
|
|
||||||
using SrcArray = ArrayX<SrcType>;
|
|
||||||
using DstArray = ArrayX<DstType>;
|
|
||||||
|
|
||||||
static constexpr int SrcPacketSize = internal::packet_traits<SrcType>::size;
|
|
||||||
static constexpr int DstPacketSize = internal::packet_traits<DstType>::size;
|
|
||||||
static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize);
|
|
||||||
|
|
||||||
static void run() {
|
|
||||||
const Index testSize = 100 * MaxPacketSize;
|
|
||||||
SrcArray src(testSize);
|
|
||||||
for (Index i = 0; i < testSize; i++) src(i) = internal::random_without_cast_overflow<SrcType, DstType>::value();
|
|
||||||
DstArray dst = src.template cast<DstType>();
|
|
||||||
for (Index i = 0; i < testSize; i++) {
|
|
||||||
DstType ref = static_cast<DstType>(src(i));
|
|
||||||
bool all_nan = ((numext::isnan)(src(i)) && (numext::isnan)(ref) && (numext::isnan)(dst(i)));
|
|
||||||
bool is_equal = ref == dst(i);
|
|
||||||
bool pass = all_nan || is_equal;
|
|
||||||
if (!pass) {
|
|
||||||
std::cout << typeid(SrcType).name() << ": [" << +src(i) << "] to " << typeid(DstType).name() << ": [" << +dst(i)
|
|
||||||
<< "] != [" << +ref << "]\n";
|
|
||||||
}
|
|
||||||
VERIFY(pass);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename... ScalarTypes>
|
|
||||||
struct cast_tests_impl {
|
|
||||||
using ScalarTuple = std::tuple<ScalarTypes...>;
|
|
||||||
static constexpr size_t ScalarTupleSize = std::tuple_size<ScalarTuple>::value;
|
|
||||||
|
|
||||||
template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
|
|
||||||
static std::enable_if_t<Done> run() {}
|
|
||||||
|
|
||||||
template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
|
|
||||||
static std::enable_if_t<!Done> run() {
|
|
||||||
using Type1 = typename std::tuple_element<i, ScalarTuple>::type;
|
|
||||||
using Type2 = typename std::tuple_element<j, ScalarTuple>::type;
|
|
||||||
cast_test_impl<Type1, Type2>::run();
|
|
||||||
cast_test_impl<Type2, Type1>::run();
|
|
||||||
static constexpr size_t next_i = (j == ScalarTupleSize - 1) ? (i + 1) : (i + 0);
|
|
||||||
static constexpr size_t next_j = (j == ScalarTupleSize - 1) ? (i + 2) : (j + 1);
|
|
||||||
run<next_i, next_j>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void cast_test() {
|
|
||||||
cast_tests_impl<bool, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, float, double,
|
|
||||||
long double, half, bfloat16>::run();
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DECLARE_TEST(array_cwise)
|
EIGEN_DECLARE_TEST(array_cwise)
|
||||||
{
|
{
|
||||||
for(int i = 0; i < g_repeat; i++) {
|
for(int i = 0; i < g_repeat; i++) {
|
||||||
@ -1293,9 +1239,6 @@ EIGEN_DECLARE_TEST(array_cwise)
|
|||||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<float>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<float>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||||
}
|
}
|
||||||
for (int i = 0; i < g_repeat; i++) {
|
|
||||||
cast_test();
|
|
||||||
}
|
|
||||||
|
|
||||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
|
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
|
||||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
|
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user