mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-13 16:41:50 +08:00
Partially Vectorize Cast
This commit is contained in:
parent
7d7576f326
commit
59b3ef5409
@ -186,7 +186,6 @@ using std::ptrdiff_t;
|
|||||||
// Generic half float support
|
// Generic half float support
|
||||||
#include "src/Core/arch/Default/Half.h"
|
#include "src/Core/arch/Default/Half.h"
|
||||||
#include "src/Core/arch/Default/BFloat16.h"
|
#include "src/Core/arch/Default/BFloat16.h"
|
||||||
#include "src/Core/arch/Default/TypeCasting.h"
|
|
||||||
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
||||||
|
|
||||||
#if defined EIGEN_VECTORIZE_AVX512
|
#if defined EIGEN_VECTORIZE_AVX512
|
||||||
|
@ -621,6 +621,207 @@ protected:
|
|||||||
Data m_d;
|
Data m_d;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ----------------------- Casting ---------------------
|
||||||
|
|
||||||
|
template <typename SrcType, typename DstType, typename ArgType>
|
||||||
|
struct unary_evaluator<CwiseUnaryOp<core_cast_op<SrcType, DstType>, ArgType>, IndexBased> {
|
||||||
|
using CastOp = core_cast_op<SrcType, DstType>;
|
||||||
|
using XprType = CwiseUnaryOp<CastOp, ArgType>;
|
||||||
|
|
||||||
|
// Use the largest packet type by default
|
||||||
|
using SrcPacketType = typename packet_traits<SrcType>::type;
|
||||||
|
static constexpr int SrcPacketSize = unpacket_traits<SrcPacketType>::size;
|
||||||
|
static constexpr int SrcPacketBytes = SrcPacketSize * sizeof(SrcType);
|
||||||
|
|
||||||
|
enum {
|
||||||
|
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<CastOp>::Cost),
|
||||||
|
PacketAccess = functor_traits<CastOp>::PacketAccess,
|
||||||
|
ActualPacketAccessBit = PacketAccess ? PacketAccessBit : 0,
|
||||||
|
Flags = evaluator<ArgType>::Flags & (HereditaryBits | LinearAccessBit | ActualPacketAccessBit),
|
||||||
|
IsRowMajor = (evaluator<ArgType>::Flags & RowMajorBit),
|
||||||
|
Alignment = evaluator<ArgType>::Alignment
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& xpr)
|
||||||
|
: m_argImpl(xpr.nestedExpression()), m_rows(xpr.rows()), m_cols(xpr.cols()) {
|
||||||
|
EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<CastOp>::Cost);
|
||||||
|
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename DstPacketType>
|
||||||
|
using AltSrcScalarOp = std::enable_if_t<(unpacket_traits<DstPacketType>::size < SrcPacketSize && !find_packet_by_size<SrcType, unpacket_traits<DstPacketType>::size>::value), bool>;
|
||||||
|
template <typename DstPacketType>
|
||||||
|
using SrcPacketArgs1 = std::enable_if_t<(find_packet_by_size<SrcType, unpacket_traits<DstPacketType>::size>::value), bool>;
|
||||||
|
template <typename DstPacketType>
|
||||||
|
using SrcPacketArgs2 = std::enable_if_t<(unpacket_traits<DstPacketType>::size) == (2 * SrcPacketSize), bool>;
|
||||||
|
template <typename DstPacketType>
|
||||||
|
using SrcPacketArgs4 = std::enable_if_t<(unpacket_traits<DstPacketType>::size) == (4 * SrcPacketSize), bool>;
|
||||||
|
template <typename DstPacketType>
|
||||||
|
using SrcPacketArgs8 = std::enable_if_t<(unpacket_traits<DstPacketType>::size) == (8 * SrcPacketSize), bool>;
|
||||||
|
|
||||||
|
template <bool UseRowMajor = IsRowMajor, std::enable_if_t<UseRowMajor, bool> = true>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_array_bounds(Index, Index col, Index packetSize) const {
|
||||||
|
return col + packetSize <= cols();
|
||||||
|
}
|
||||||
|
template <bool UseRowMajor = IsRowMajor, std::enable_if_t<!UseRowMajor, bool> = true>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_array_bounds(Index row, Index, Index packetSize) const {
|
||||||
|
return row + packetSize <= rows();
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_array_bounds(Index index, Index packetSize) const {
|
||||||
|
return index + packetSize <= size();
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SrcType srcCoeff(Index row, Index col, Index offset) const {
|
||||||
|
Index actualRow = IsRowMajor ? row : row + offset;
|
||||||
|
Index actualCol = IsRowMajor ? col + offset : col;
|
||||||
|
return m_argImpl.coeff(actualRow, actualCol);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SrcType srcCoeff(Index index, Index offset) const {
|
||||||
|
Index actualIndex = index + offset;
|
||||||
|
return m_argImpl.coeff(actualIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index row, Index col) const {
|
||||||
|
return cast<SrcType, DstType>(srcCoeff(row, col, 0));
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DstType coeff(Index index) const { return cast<SrcType, DstType>(srcCoeff(index, 0)); }
|
||||||
|
|
||||||
|
template <int LoadMode, typename PacketType = SrcPacketType>
|
||||||
|
EIGEN_STRONG_INLINE PacketType srcPacket(Index row, Index col, Index offset) const {
|
||||||
|
constexpr int PacketSize = unpacket_traits<PacketType>::size;
|
||||||
|
Index actualRow = IsRowMajor ? row : row + (offset * PacketSize);
|
||||||
|
Index actualCol = IsRowMajor ? col + (offset * PacketSize) : col;
|
||||||
|
eigen_assert(check_array_bounds(actualRow, actualCol, PacketSize) && "Array index out of bounds");
|
||||||
|
return m_argImpl.template packet<LoadMode, PacketType>(actualRow, actualCol);
|
||||||
|
}
|
||||||
|
template <int LoadMode, typename PacketType = SrcPacketType>
|
||||||
|
EIGEN_STRONG_INLINE PacketType srcPacket(Index index, Index offset) const {
|
||||||
|
constexpr int PacketSize = unpacket_traits<PacketType>::size;
|
||||||
|
Index actualIndex = index + (offset * PacketSize);
|
||||||
|
eigen_assert(check_array_bounds(actualIndex, PacketSize) && "Array index out of bounds");
|
||||||
|
return m_argImpl.template packet<LoadMode, PacketType>(actualIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is no source packet type with equal or fewer elements than DstPacketType.
|
||||||
|
// This is problematic as the evaluation loop may attempt to access data outside the bounds of the array.
|
||||||
|
// For example, consider the cast utilizing pcast<Packet4f,Packet2d> with an array of size 4: {0.0f,1.0f,2.0f,3.0f}.
|
||||||
|
// The first iteration of the evaulation loop will load 16 bytes: {0.0f,1.0f,2.0f,3.0f} and cast to {0.0,1.0}, which is acceptable.
|
||||||
|
// The second iteration will load 16 bytes: {2.0f,3.0f,?,?}, which is outside the bounds of the array.
|
||||||
|
|
||||||
|
// Instead, perform runtime check to determine if the load would access data outside the bounds of the array.
|
||||||
|
// If not, perform full load. Otherwise, revert to a scalar loop to perform a partial load.
|
||||||
|
// In either case, perform a vectorized cast of the source packet.
|
||||||
|
template <int LoadMode, typename DstPacketType, AltSrcScalarOp<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
||||||
|
constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
|
||||||
|
constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType);
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode);
|
||||||
|
SrcPacketType src;
|
||||||
|
if (EIGEN_PREDICT_TRUE(check_array_bounds(row, col, SrcPacketSize))) {
|
||||||
|
src = srcPacket<SrcLoadMode>(row, col, 0);
|
||||||
|
} else {
|
||||||
|
Array<SrcType, SrcPacketSize, 1> srcArray;
|
||||||
|
for (size_t k = 0; k < DstPacketSize; k++) srcArray[k] = srcCoeff(row, col, k);
|
||||||
|
for (size_t k = DstPacketSize; k < SrcPacketSize; k++) srcArray[k] = SrcType(0);
|
||||||
|
src = pload<SrcPacketType>(srcArray.data());
|
||||||
|
}
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(src);
|
||||||
|
}
|
||||||
|
// Use the source packet type with the same size as DstPacketType, if it exists
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs1<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
||||||
|
constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
|
||||||
|
using SizedSrcPacketType = typename find_packet_by_size<SrcType, DstPacketSize>::type;
|
||||||
|
constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType);
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode);
|
||||||
|
return pcast<SizedSrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode, SizedSrcPacketType>(row, col, 0));
|
||||||
|
}
|
||||||
|
// unpacket_traits<DstPacketType>::size == 2 * SrcPacketSize
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs2<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode);
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 0), srcPacket<SrcLoadMode>(row, col, 1));
|
||||||
|
}
|
||||||
|
// unpacket_traits<DstPacketType>::size == 4 * SrcPacketSize
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs4<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode);
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 0), srcPacket<SrcLoadMode>(row, col, 1),
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 2), srcPacket<SrcLoadMode>(row, col, 3));
|
||||||
|
}
|
||||||
|
// unpacket_traits<DstPacketType>::size == 8 * SrcPacketSize
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs8<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index row, Index col) const {
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode);
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 0), srcPacket<SrcLoadMode>(row, col, 1),
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 2), srcPacket<SrcLoadMode>(row, col, 3),
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 4), srcPacket<SrcLoadMode>(row, col, 5),
|
||||||
|
srcPacket<SrcLoadMode>(row, col, 6), srcPacket<SrcLoadMode>(row, col, 7));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Analagous routines for linear access.
|
||||||
|
template <int LoadMode, typename DstPacketType, AltSrcScalarOp<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
||||||
|
constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
|
||||||
|
constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType);
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode);
|
||||||
|
SrcPacketType src;
|
||||||
|
if (EIGEN_PREDICT_TRUE(check_array_bounds(index, SrcPacketSize))) {
|
||||||
|
src = srcPacket<SrcLoadMode>(index, 0);
|
||||||
|
} else {
|
||||||
|
Array<SrcType, SrcPacketSize, 1> srcArray;
|
||||||
|
for (size_t k = 0; k < DstPacketSize; k++) srcArray[k] = srcCoeff(index, k);
|
||||||
|
for (size_t k = DstPacketSize; k < SrcPacketSize; k++) srcArray[k] = SrcType(0);
|
||||||
|
src = pload<SrcPacketType>(srcArray.data());
|
||||||
|
}
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(src);
|
||||||
|
}
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs1<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
||||||
|
constexpr int DstPacketSize = unpacket_traits<DstPacketType>::size;
|
||||||
|
using SizedSrcPacketType = typename find_packet_by_size<SrcType, DstPacketSize>::type;
|
||||||
|
constexpr int SrcBytesIncrement = DstPacketSize * sizeof(SrcType);
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcBytesIncrement, LoadMode);
|
||||||
|
return pcast<SizedSrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode, SizedSrcPacketType>(index, 0));
|
||||||
|
}
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs2<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode);
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode>(index, 0), srcPacket<SrcLoadMode>(index, 1));
|
||||||
|
}
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs4<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode);
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode>(index, 0), srcPacket<SrcLoadMode>(index, 1),
|
||||||
|
srcPacket<SrcLoadMode>(index, 2), srcPacket<SrcLoadMode>(index, 3));
|
||||||
|
}
|
||||||
|
template <int LoadMode, typename DstPacketType, SrcPacketArgs8<DstPacketType> = true>
|
||||||
|
EIGEN_STRONG_INLINE DstPacketType packet(Index index) const {
|
||||||
|
constexpr int SrcLoadMode = plain_enum_min(SrcPacketBytes, LoadMode);
|
||||||
|
return pcast<SrcPacketType, DstPacketType>(
|
||||||
|
srcPacket<SrcLoadMode>(index, 0), srcPacket<SrcLoadMode>(index, 1),
|
||||||
|
srcPacket<SrcLoadMode>(index, 2), srcPacket<SrcLoadMode>(index, 3),
|
||||||
|
srcPacket<SrcLoadMode>(index, 4), srcPacket<SrcLoadMode>(index, 5),
|
||||||
|
srcPacket<SrcLoadMode>(index, 6), srcPacket<SrcLoadMode>(index, 7));
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rows() const { return m_rows; }
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index cols() const { return m_cols; }
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_rows * m_cols; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const evaluator<ArgType> m_argImpl;
|
||||||
|
const variable_if_dynamic<Index, XprType::RowsAtCompileTime> m_rows;
|
||||||
|
const variable_if_dynamic<Index, XprType::ColsAtCompileTime> m_cols;
|
||||||
|
};
|
||||||
|
|
||||||
// -------------------- CwiseTernaryOp --------------------
|
// -------------------- CwiseTernaryOp --------------------
|
||||||
|
|
||||||
// this is a ternary expression
|
// this is a ternary expression
|
||||||
|
@ -146,14 +146,67 @@ template<typename T> struct unpacket_traits
|
|||||||
|
|
||||||
template<typename T> struct unpacket_traits<const T> : unpacket_traits<T> { };
|
template<typename T> struct unpacket_traits<const T> : unpacket_traits<T> { };
|
||||||
|
|
||||||
template <typename Src, typename Tgt> struct type_casting_traits {
|
/** \internal A convenience utility for determining if the type is a scalar.
|
||||||
|
* This is used to enable some generic packet implementations.
|
||||||
|
*/
|
||||||
|
template <typename Packet>
|
||||||
|
struct is_scalar {
|
||||||
|
using Scalar = typename unpacket_traits<Packet>::type;
|
||||||
|
enum { value = internal::is_same<Packet, Scalar>::value };
|
||||||
|
};
|
||||||
|
|
||||||
|
// automatically and succinctly define combinations of pcast<SrcPacket,TgtPacket> when
|
||||||
|
// 1) the packets are the same type, or
|
||||||
|
// 2) the packets differ only in sign.
|
||||||
|
// In both of these cases, preinterpret (bit_cast) is equivalent to pcast (static_cast)
|
||||||
|
template <typename SrcPacket, typename TgtPacket,
|
||||||
|
bool Scalar = is_scalar<SrcPacket>::value && is_scalar<TgtPacket>::value>
|
||||||
|
struct is_degenerate_helper : is_same<SrcPacket, TgtPacket> {};
|
||||||
|
template <>
|
||||||
|
struct is_degenerate_helper<int8_t, uint8_t, true> : std::true_type {};
|
||||||
|
template <>
|
||||||
|
struct is_degenerate_helper<int16_t, uint16_t, true> : std::true_type {};
|
||||||
|
template <>
|
||||||
|
struct is_degenerate_helper<int32_t, uint32_t, true> : std::true_type {};
|
||||||
|
template <>
|
||||||
|
struct is_degenerate_helper<int64_t, uint64_t, true> : std::true_type {};
|
||||||
|
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
struct is_degenerate_helper<SrcPacket, TgtPacket, false> {
|
||||||
|
using SrcScalar = typename unpacket_traits<SrcPacket>::type;
|
||||||
|
static constexpr int SrcSize = unpacket_traits<SrcPacket>::size;
|
||||||
|
using TgtScalar = typename unpacket_traits<TgtPacket>::type;
|
||||||
|
static constexpr int TgtSize = unpacket_traits<TgtPacket>::size;
|
||||||
|
static constexpr bool value = is_degenerate_helper<SrcScalar, TgtScalar, true>::value && (SrcSize == TgtSize);
|
||||||
|
};
|
||||||
|
|
||||||
|
// is_degenerate<T1,T2>::value == is_degenerate<T2,T1>::value
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
struct is_degenerate {
|
||||||
|
static constexpr bool value =
|
||||||
|
is_degenerate_helper<SrcPacket, TgtPacket>::value || is_degenerate_helper<TgtPacket, SrcPacket>::value;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
struct is_half {
|
||||||
|
using Scalar = typename unpacket_traits<Packet>::type;
|
||||||
|
static constexpr int Size = unpacket_traits<Packet>::size;
|
||||||
|
using DefaultPacket = typename packet_traits<Scalar>::type;
|
||||||
|
static constexpr int DefaultSize = unpacket_traits<DefaultPacket>::size;
|
||||||
|
static constexpr bool value = Size < DefaultSize;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Src, typename Tgt>
|
||||||
|
struct type_casting_traits {
|
||||||
enum {
|
enum {
|
||||||
VectorizedCast = 0,
|
VectorizedCast =
|
||||||
|
is_degenerate<Src, Tgt>::value && packet_traits<Src>::Vectorizable && packet_traits<Tgt>::Vectorizable,
|
||||||
SrcCoeffRatio = 1,
|
SrcCoeffRatio = 1,
|
||||||
TgtCoeffRatio = 1
|
TgtCoeffRatio = 1
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/** \internal Wrapper to ensure that multiple packet types can map to the same
|
/** \internal Wrapper to ensure that multiple packet types can map to the same
|
||||||
same underlying vector type. */
|
same underlying vector type. */
|
||||||
template<typename T, int unique_id = 0>
|
template<typename T, int unique_id = 0>
|
||||||
@ -171,45 +224,84 @@ struct eigen_packet_wrapper
|
|||||||
T m_val;
|
T m_val;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Target, typename Packet, bool IsSame = is_same<Target, Packet>::value>
|
||||||
|
struct preinterpret_generic;
|
||||||
|
|
||||||
/** \internal A convenience utility for determining if the type is a scalar.
|
template <typename Target, typename Packet>
|
||||||
* This is used to enable some generic packet implementations.
|
struct preinterpret_generic<Target, Packet, false> {
|
||||||
*/
|
// the packets are not the same, attempt scalar bit_cast
|
||||||
template<typename Packet>
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Target run(const Packet& a) {
|
||||||
struct is_scalar {
|
return numext::bit_cast<Target, Packet>(a);
|
||||||
using Scalar = typename unpacket_traits<Packet>::type;
|
}
|
||||||
enum {
|
|
||||||
value = internal::is_same<Packet, Scalar>::value
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/** \internal \returns static_cast<TgtType>(a) (coeff-wise) */
|
template <typename Packet>
|
||||||
template <typename SrcPacket, typename TgtPacket>
|
struct preinterpret_generic<Packet, Packet, true> {
|
||||||
EIGEN_DEVICE_FUNC inline TgtPacket
|
// the packets are the same type: do nothing
|
||||||
pcast(const SrcPacket& a) {
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& a) { return a; }
|
||||||
return static_cast<TgtPacket>(a);
|
};
|
||||||
}
|
|
||||||
template <typename SrcPacket, typename TgtPacket>
|
|
||||||
EIGEN_DEVICE_FUNC inline TgtPacket
|
|
||||||
pcast(const SrcPacket& a, const SrcPacket& /*b*/) {
|
|
||||||
return static_cast<TgtPacket>(a);
|
|
||||||
}
|
|
||||||
template <typename SrcPacket, typename TgtPacket>
|
|
||||||
EIGEN_DEVICE_FUNC inline TgtPacket
|
|
||||||
pcast(const SrcPacket& a, const SrcPacket& /*b*/, const SrcPacket& /*c*/, const SrcPacket& /*d*/) {
|
|
||||||
return static_cast<TgtPacket>(a);
|
|
||||||
}
|
|
||||||
template <typename SrcPacket, typename TgtPacket>
|
|
||||||
EIGEN_DEVICE_FUNC inline TgtPacket
|
|
||||||
pcast(const SrcPacket& a, const SrcPacket& /*b*/, const SrcPacket& /*c*/, const SrcPacket& /*d*/,
|
|
||||||
const SrcPacket& /*e*/, const SrcPacket& /*f*/, const SrcPacket& /*g*/, const SrcPacket& /*h*/) {
|
|
||||||
return static_cast<TgtPacket>(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** \internal \returns reinterpret_cast<Target>(a) */
|
/** \internal \returns reinterpret_cast<Target>(a) */
|
||||||
template <typename Target, typename Packet>
|
template <typename Target, typename Packet>
|
||||||
EIGEN_DEVICE_FUNC inline Target
|
EIGEN_DEVICE_FUNC inline Target preinterpret(const Packet& a) {
|
||||||
preinterpret(const Packet& a); /* { return reinterpret_cast<const Target&>(a); } */
|
return preinterpret_generic<Target, Packet>::run(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcPacket, typename TgtPacket, bool Degenerate = is_degenerate<SrcPacket, TgtPacket>::value, bool TgtIsHalf = is_half<TgtPacket>::value>
|
||||||
|
struct pcast_generic;
|
||||||
|
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
struct pcast_generic<SrcPacket, TgtPacket, false, false> {
|
||||||
|
// the packets are not degenerate: attempt scalar static_cast
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket run(const SrcPacket& a) {
|
||||||
|
return cast_impl<SrcPacket, TgtPacket>::run(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
struct pcast_generic<Packet, Packet, true, false> {
|
||||||
|
// the packets are the same: do nothing
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& a) { return a; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SrcPacket, typename TgtPacket, bool TgtIsHalf>
|
||||||
|
struct pcast_generic<SrcPacket, TgtPacket, true, TgtIsHalf> {
|
||||||
|
// the packets are degenerate: preinterpret is equivalent to pcast
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket run(const SrcPacket& a) { return preinterpret<TgtPacket>(a); }
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/** \internal \returns static_cast<TgtType>(a) (coeff-wise) */
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a) {
|
||||||
|
return pcast_generic<SrcPacket, TgtPacket>::run(a);
|
||||||
|
}
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a, const SrcPacket& b) {
|
||||||
|
return pcast_generic<SrcPacket, TgtPacket>::run(a, b);
|
||||||
|
}
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a, const SrcPacket& b, const SrcPacket& c,
|
||||||
|
const SrcPacket& d) {
|
||||||
|
return pcast_generic<SrcPacket, TgtPacket>::run(a, b, c, d);
|
||||||
|
}
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
EIGEN_DEVICE_FUNC inline TgtPacket pcast(const SrcPacket& a, const SrcPacket& b, const SrcPacket& c, const SrcPacket& d,
|
||||||
|
const SrcPacket& e, const SrcPacket& f, const SrcPacket& g,
|
||||||
|
const SrcPacket& h) {
|
||||||
|
return pcast_generic<SrcPacket, TgtPacket>::run(a, b, c, d, e, f, g, h);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SrcPacket, typename TgtPacket>
|
||||||
|
struct pcast_generic<SrcPacket, TgtPacket, false, true> {
|
||||||
|
// TgtPacket is a half packet of some other type
|
||||||
|
// perform cast and truncate result
|
||||||
|
using DefaultTgtPacket = typename is_half<TgtPacket>::DefaultPacket;
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket run(const SrcPacket& a) {
|
||||||
|
return preinterpret<TgtPacket>(pcast<SrcPacket, DefaultTgtPacket>(a));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/** \internal \returns a + b (coeff-wise) */
|
/** \internal \returns a + b (coeff-wise) */
|
||||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
|
@ -430,6 +430,13 @@ struct cast_impl
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename OldType>
|
||||||
|
struct cast_impl<OldType, bool> {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static inline bool run(const OldType& x) { return x != OldType(0); }
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
|
// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
|
||||||
// generating warnings on clang. Here we explicitly cast the real component.
|
// generating warnings on clang. Here we explicitly cast the real component.
|
||||||
template<typename OldType, typename NewType>
|
template<typename OldType, typename NewType>
|
||||||
|
@ -80,6 +80,14 @@ template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet4d, Packet8i>(const Packet4d
|
|||||||
return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
|
return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <> EIGEN_STRONG_INLINE Packet4f pcast<Packet4d, Packet4f>(const Packet4d& a) {
|
||||||
|
return _mm256_cvtpd_ps(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> EIGEN_STRONG_INLINE Packet4i pcast<Packet4d, Packet4i>(const Packet4d& a) {
|
||||||
|
return _mm256_cvttpd_epi32(a);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
|
EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
|
||||||
const Packet8f& b) {
|
const Packet8f& b) {
|
||||||
@ -118,6 +126,44 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f,Packet8i>(const Pa
|
|||||||
return _mm256_castsi256_ps(a);
|
return _mm256_castsi256_ps(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8ui preinterpret<Packet8ui, Packet8i>(const Packet8i& a) {
|
||||||
|
return Packet8ui(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet8ui>(const Packet8ui& a) {
|
||||||
|
return Packet8i(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncation operations
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet8f>(const Packet8f& a) {
|
||||||
|
return _mm256_castps256_ps128(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4d>(const Packet4d& a) {
|
||||||
|
return _mm256_castpd256_pd128(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet8i>(const Packet8i& a) {
|
||||||
|
return _mm256_castsi256_si128(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet8ui>(const Packet8ui& a) {
|
||||||
|
return _mm256_castsi256_si128(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef EIGEN_VECTORIZE_AVX2
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4ul preinterpret<Packet4ul, Packet4l>(const Packet4l& a) {
|
||||||
|
return Packet4ul(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4l preinterpret<Packet4l, Packet4ul>(const Packet4ul& a) {
|
||||||
|
return Packet4l(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
|
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
|
||||||
return half2float(a);
|
return half2float(a);
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,13 @@ template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet
|
|||||||
return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
|
return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8d, Packet8i>(const Packet8d& a) {
|
||||||
|
return _mm512_cvtpd_epi32(a);
|
||||||
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8d, Packet8f>(const Packet8d& a) {
|
||||||
|
return _mm512_cvtpd_ps(a);
|
||||||
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
|
template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
|
||||||
return _mm512_castps_si512(a);
|
return _mm512_castps_si512(a);
|
||||||
}
|
}
|
||||||
@ -107,12 +114,19 @@ template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const P
|
|||||||
return _mm512_castpd128_pd512(a);
|
return _mm512_castpd128_pd512(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
|
template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet16i>(const Packet16i& a) {
|
||||||
return a;
|
return _mm512_castsi512_si256(a);
|
||||||
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i& a) {
|
||||||
|
return _mm512_castsi512_si128(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const Packet8d& a) {
|
template<> EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
|
||||||
return a;
|
return _mm256_castsi256_si128(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
|
||||||
|
return _mm256_castsi256_si128(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||||
@ -191,6 +205,13 @@ struct type_casting_traits<float, half> {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet32h>(const Packet32h& a) {
|
||||||
|
return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
|
||||||
|
}
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet32h>(const Packet32h& a) {
|
||||||
|
return _mm256_castsi256_si128(preinterpret<Packet16h>(a));
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
|
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
|
||||||
// Discard second-half of input.
|
// Discard second-half of input.
|
||||||
|
@ -1014,4 +1014,49 @@ struct hash<Eigen::half> {
|
|||||||
} // end namespace std
|
} // end namespace std
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct cast_impl<float, half> {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static inline half run(const float& a) {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
|
return __float2half(a);
|
||||||
|
#else
|
||||||
|
return half(a);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct cast_impl<int, half> {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static inline half run(const int& a) {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
|
return __float2half(static_cast<float>(a));
|
||||||
|
#else
|
||||||
|
return half(static_cast<float>(a));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct cast_impl<half, float> {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static inline float run(const half& a) {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
|
return __half2float(a);
|
||||||
|
#else
|
||||||
|
return static_cast<float>(a);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
} // namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_HALF_H
|
#endif // EIGEN_HALF_H
|
||||||
|
@ -1,116 +0,0 @@
|
|||||||
// This file is part of Eigen, a lightweight C++ template library
|
|
||||||
// for linear algebra.
|
|
||||||
//
|
|
||||||
// Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
|
||||||
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
|
|
||||||
//
|
|
||||||
// This Source Code Form is subject to the terms of the Mozilla
|
|
||||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
|
||||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
||||||
|
|
||||||
#ifndef EIGEN_GENERIC_TYPE_CASTING_H
|
|
||||||
#define EIGEN_GENERIC_TYPE_CASTING_H
|
|
||||||
|
|
||||||
#include "../../InternalHeaderCheck.h"
|
|
||||||
|
|
||||||
namespace Eigen {
|
|
||||||
|
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct scalar_cast_op<float, Eigen::half> {
|
|
||||||
typedef Eigen::half result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const {
|
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
|
||||||
return __float2half(a);
|
|
||||||
#else
|
|
||||||
return Eigen::half(a);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct functor_traits<scalar_cast_op<float, Eigen::half> >
|
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct scalar_cast_op<int, Eigen::half> {
|
|
||||||
typedef Eigen::half result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const {
|
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
|
||||||
return __float2half(static_cast<float>(a));
|
|
||||||
#else
|
|
||||||
return Eigen::half(static_cast<float>(a));
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct functor_traits<scalar_cast_op<int, Eigen::half> >
|
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct scalar_cast_op<Eigen::half, float> {
|
|
||||||
typedef float result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const {
|
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
|
||||||
return __half2float(a);
|
|
||||||
#else
|
|
||||||
return static_cast<float>(a);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct functor_traits<scalar_cast_op<Eigen::half, float> >
|
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct scalar_cast_op<float, Eigen::bfloat16> {
|
|
||||||
typedef Eigen::bfloat16 result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
|
|
||||||
return Eigen::bfloat16(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
|
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct scalar_cast_op<int, Eigen::bfloat16> {
|
|
||||||
typedef Eigen::bfloat16 result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
|
|
||||||
return Eigen::bfloat16(static_cast<float>(a));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
|
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
|
||||||
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct scalar_cast_op<Eigen::bfloat16, float> {
|
|
||||||
typedef float result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
|
|
||||||
return static_cast<float>(a);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
|
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // EIGEN_GENERIC_TYPE_CASTING_H
|
|
File diff suppressed because it is too large
Load Diff
@ -135,6 +135,13 @@ template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet2d>(const Pa
|
|||||||
return _mm_castpd_si128(a);
|
return _mm_castpd_si128(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4ui preinterpret<Packet4ui, Packet4i>(const Packet4i& a) {
|
||||||
|
return Packet4ui(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet4ui>(const Packet4ui& a) {
|
||||||
|
return Packet4i(a);
|
||||||
|
}
|
||||||
// Disable the following code since it's broken on too many platforms / compilers.
|
// Disable the following code since it's broken on too many platforms / compilers.
|
||||||
//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
|
//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC)
|
||||||
#if 0
|
#if 0
|
||||||
|
@ -179,16 +179,28 @@ struct scalar_cast_op {
|
|||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast<Scalar, NewType>(a); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast<Scalar, NewType>(a); }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Scalar>
|
|
||||||
struct scalar_cast_op<Scalar, bool> {
|
|
||||||
typedef bool result_type;
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { return a != Scalar(0); }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Scalar, typename NewType>
|
template<typename Scalar, typename NewType>
|
||||||
struct functor_traits<scalar_cast_op<Scalar,NewType> >
|
struct functor_traits<scalar_cast_op<Scalar,NewType> >
|
||||||
{ enum { Cost = is_same<Scalar, NewType>::value ? 0 : NumTraits<NewType>::AddCost, PacketAccess = false }; };
|
{ enum { Cost = is_same<Scalar, NewType>::value ? 0 : NumTraits<NewType>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
/** \internal
|
||||||
|
* `core_cast_op` serves to distinguish the vectorized implementation from that of the legacy `scalar_cast_op` for backwards
|
||||||
|
* compatibility. The manner in which packet ops are handled is defined by the specialized unary_evaluator:
|
||||||
|
* `unary_evaluator<CwiseUnaryOp<core_cast_op<SrcType, DstType>, ArgType>, IndexBased>` in CoreEvaluators.h
|
||||||
|
* Otherwise, the non-vectorized behavior is identical to that of `scalar_cast_op`
|
||||||
|
*/
|
||||||
|
template <typename SrcType, typename DstType>
|
||||||
|
struct core_cast_op : scalar_cast_op<SrcType, DstType> {};
|
||||||
|
|
||||||
|
template <typename SrcType, typename DstType>
|
||||||
|
struct functor_traits<core_cast_op<SrcType, DstType>> {
|
||||||
|
using CastingTraits = type_casting_traits<SrcType, DstType>;
|
||||||
|
enum {
|
||||||
|
Cost = is_same<SrcType, DstType>::value ? 0 : NumTraits<DstType>::AddCost,
|
||||||
|
PacketAccess = CastingTraits::VectorizedCast && (CastingTraits::SrcCoeffRatio <= 8)
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
/** \internal
|
/** \internal
|
||||||
* \brief Template functor to arithmetically shift a scalar right by a number of bits
|
* \brief Template functor to arithmetically shift a scalar right by a number of bits
|
||||||
*
|
*
|
||||||
|
@ -190,6 +190,30 @@ struct find_best_packet
|
|||||||
typedef typename find_best_packet_helper<Size,typename packet_traits<T>::type>::type type;
|
typedef typename find_best_packet_helper<Size,typename packet_traits<T>::type>::type type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <int Size, typename PacketType,
|
||||||
|
bool Stop = (Size == unpacket_traits<PacketType>::size) ||
|
||||||
|
is_same<PacketType, typename unpacket_traits<PacketType>::half>::value>
|
||||||
|
struct find_packet_by_size_helper;
|
||||||
|
template <int Size, typename PacketType>
|
||||||
|
struct find_packet_by_size_helper<Size, PacketType, true> {
|
||||||
|
using type = PacketType;
|
||||||
|
};
|
||||||
|
template <int Size, typename PacketType>
|
||||||
|
struct find_packet_by_size_helper<Size, PacketType, false> {
|
||||||
|
using type = typename find_packet_by_size_helper<Size, typename unpacket_traits<PacketType>::half>::type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int Size>
|
||||||
|
struct find_packet_by_size {
|
||||||
|
using type = typename find_packet_by_size_helper<Size, typename packet_traits<T>::type>::type;
|
||||||
|
static constexpr bool value = (Size == unpacket_traits<type>::size);
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
struct find_packet_by_size<T, 1> {
|
||||||
|
using type = typename unpacket_traits<T>::type;
|
||||||
|
static constexpr bool value = (unpacket_traits<type>::size == 1);
|
||||||
|
};
|
||||||
|
|
||||||
#if EIGEN_MAX_STATIC_ALIGN_BYTES>0
|
#if EIGEN_MAX_STATIC_ALIGN_BYTES>0
|
||||||
constexpr inline int compute_default_alignment_helper(int ArrayBytes, int AlignmentBytes) {
|
constexpr inline int compute_default_alignment_helper(int ArrayBytes, int AlignmentBytes) {
|
||||||
if((ArrayBytes % AlignmentBytes) == 0) {
|
if((ArrayBytes % AlignmentBytes) == 0) {
|
||||||
|
@ -45,7 +45,7 @@ inline const NegativeReturnType
|
|||||||
operator-() const { return NegativeReturnType(derived()); }
|
operator-() const { return NegativeReturnType(derived()); }
|
||||||
|
|
||||||
|
|
||||||
template<class NewType> struct CastXpr { typedef typename internal::cast_return_type<Derived,const CwiseUnaryOp<internal::scalar_cast_op<Scalar, NewType>, const Derived> >::type Type; };
|
template<class NewType> struct CastXpr { typedef typename internal::cast_return_type<Derived,const CwiseUnaryOp<internal::core_cast_op<Scalar, NewType>, const Derived> >::type Type; };
|
||||||
|
|
||||||
/// \returns an expression of \c *this with the \a Scalar type casted to
|
/// \returns an expression of \c *this with the \a Scalar type casted to
|
||||||
/// \a NewScalar.
|
/// \a NewScalar.
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "main.h"
|
#include "main.h"
|
||||||
|
#include "random_without_cast_overflow.h"
|
||||||
|
|
||||||
// suppress annoying unsigned integer warnings
|
// suppress annoying unsigned integer warnings
|
||||||
template <typename Scalar, bool IsSigned = NumTraits<Scalar>::IsSigned>
|
template <typename Scalar, bool IsSigned = NumTraits<Scalar>::IsSigned>
|
||||||
@ -1213,6 +1214,109 @@ void typed_logicals_test(const ArrayType& m) {
|
|||||||
typed_logicals_test_impl<ArrayType>::run(m);
|
typed_logicals_test_impl<ArrayType>::run(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename SrcType, typename DstType, int RowsAtCompileTime, int ColsAtCompileTime>
|
||||||
|
struct cast_test_impl {
|
||||||
|
using SrcArray = Array<SrcType, RowsAtCompileTime, ColsAtCompileTime>;
|
||||||
|
using DstArray = Array<DstType, RowsAtCompileTime, ColsAtCompileTime>;
|
||||||
|
struct RandomOp {
|
||||||
|
inline SrcType operator()(const SrcType&) const {
|
||||||
|
return internal::random_without_cast_overflow<SrcType, DstType>::value();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr int SrcPacketSize = internal::packet_traits<SrcType>::size;
|
||||||
|
static constexpr int DstPacketSize = internal::packet_traits<DstType>::size;
|
||||||
|
static constexpr int MaxPacketSize = internal::plain_enum_max(SrcPacketSize, DstPacketSize);
|
||||||
|
|
||||||
|
// print non-mangled typenames
|
||||||
|
template <typename T>
|
||||||
|
static std::string printTypeInfo(const T&) {
|
||||||
|
if (internal::is_same<bool, T>::value)
|
||||||
|
return "bool";
|
||||||
|
else if (internal::is_same<int8_t, T>::value)
|
||||||
|
return "int8_t";
|
||||||
|
else if (internal::is_same<int16_t, T>::value)
|
||||||
|
return "int16_t";
|
||||||
|
else if (internal::is_same<int32_t, T>::value)
|
||||||
|
return "int32_t";
|
||||||
|
else if (internal::is_same<int64_t, T>::value)
|
||||||
|
return "int64_t";
|
||||||
|
else if (internal::is_same<uint8_t, T>::value)
|
||||||
|
return "uint8_t";
|
||||||
|
else if (internal::is_same<uint16_t, T>::value)
|
||||||
|
return "uint16_t";
|
||||||
|
else if (internal::is_same<uint32_t, T>::value)
|
||||||
|
return "uint32_t";
|
||||||
|
else if (internal::is_same<uint64_t, T>::value)
|
||||||
|
return "uint64_t";
|
||||||
|
else if (internal::is_same<float, T>::value)
|
||||||
|
return "float";
|
||||||
|
else if (internal::is_same<double, T>::value)
|
||||||
|
return "double";
|
||||||
|
//else if (internal::is_same<long double, T>::value)
|
||||||
|
// return "long double";
|
||||||
|
else if (internal::is_same<half, T>::value)
|
||||||
|
return "half";
|
||||||
|
else if (internal::is_same<bfloat16, T>::value)
|
||||||
|
return "bfloat16";
|
||||||
|
else
|
||||||
|
return typeid(T).name();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void run() {
|
||||||
|
const Index testRows = RowsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : RowsAtCompileTime;
|
||||||
|
const Index testCols = ColsAtCompileTime == Dynamic ? ((10 * MaxPacketSize) + 1) : ColsAtCompileTime;
|
||||||
|
const Index testSize = testRows * testCols;
|
||||||
|
const Index minTestSize = 100;
|
||||||
|
const Index repeats = numext::div_ceil(minTestSize, testSize);
|
||||||
|
SrcArray src(testRows, testCols);
|
||||||
|
DstArray dst(testRows, testCols);
|
||||||
|
for (Index repeat = 0; repeat < repeats; repeat++) {
|
||||||
|
src = src.unaryExpr(RandomOp());
|
||||||
|
dst = src.template cast<DstType>();
|
||||||
|
for (Index i = 0; i < testRows; i++)
|
||||||
|
for (Index j = 0; j < testCols; j++) {
|
||||||
|
DstType ref = internal::cast_impl<SrcType, DstType>::run(src(i, j));
|
||||||
|
bool all_nan = ((numext::isnan)(src(i, j)) && (numext::isnan)(ref) && (numext::isnan)(dst(i, j)));
|
||||||
|
bool is_equal = ref == dst(i, j);
|
||||||
|
bool pass = all_nan || is_equal;
|
||||||
|
if (!pass) {
|
||||||
|
std::cout << printTypeInfo(SrcType()) << ": [" << +src(i, j) << "] to " << printTypeInfo(DstType()) << ": ["
|
||||||
|
<< +dst(i, j) << "] != [" << +ref << "]\n";
|
||||||
|
}
|
||||||
|
VERIFY(pass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int RowsAtCompileTime, int ColsAtCompileTime, typename... ScalarTypes>
|
||||||
|
struct cast_tests_impl {
|
||||||
|
using ScalarTuple = std::tuple<ScalarTypes...>;
|
||||||
|
static constexpr size_t ScalarTupleSize = std::tuple_size<ScalarTuple>::value;
|
||||||
|
|
||||||
|
template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
|
||||||
|
static std::enable_if_t<Done> run() {}
|
||||||
|
|
||||||
|
template <size_t i = 0, size_t j = i + 1, bool Done = (i >= ScalarTupleSize - 1) || (j >= ScalarTupleSize)>
|
||||||
|
static std::enable_if_t<!Done> run() {
|
||||||
|
using Type1 = typename std::tuple_element<i, ScalarTuple>::type;
|
||||||
|
using Type2 = typename std::tuple_element<j, ScalarTuple>::type;
|
||||||
|
cast_test_impl<Type1, Type2, RowsAtCompileTime, ColsAtCompileTime>::run();
|
||||||
|
cast_test_impl<Type2, Type1, RowsAtCompileTime, ColsAtCompileTime>::run();
|
||||||
|
static constexpr size_t next_i = (j == ScalarTupleSize - 1) ? (i + 1) : (i + 0);
|
||||||
|
static constexpr size_t next_j = (j == ScalarTupleSize - 1) ? (i + 2) : (j + 1);
|
||||||
|
run<next_i, next_j>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// for now, remove all references to 'long double' until test passes on all platforms
|
||||||
|
template <int RowsAtCompileTime, int ColsAtCompileTime>
|
||||||
|
void cast_test() {
|
||||||
|
cast_tests_impl<RowsAtCompileTime, ColsAtCompileTime, bool, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
|
||||||
|
uint32_t, uint64_t, float, double, /*long double, */half, bfloat16>::run();
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DECLARE_TEST(array_cwise)
|
EIGEN_DECLARE_TEST(array_cwise)
|
||||||
{
|
{
|
||||||
for(int i = 0; i < g_repeat; i++) {
|
for(int i = 0; i < g_repeat; i++) {
|
||||||
@ -1269,6 +1373,20 @@ EIGEN_DECLARE_TEST(array_cwise)
|
|||||||
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
CALL_SUBTEST_3( typed_logicals_test(ArrayX<std::complex<double>>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < g_repeat; i++) {
|
||||||
|
CALL_SUBTEST_1((cast_test<1, 1>()));
|
||||||
|
CALL_SUBTEST_2((cast_test<3, 1>()));
|
||||||
|
CALL_SUBTEST_2((cast_test<3, 3>()));
|
||||||
|
CALL_SUBTEST_3((cast_test<5, 1>()));
|
||||||
|
CALL_SUBTEST_3((cast_test<5, 5>()));
|
||||||
|
CALL_SUBTEST_4((cast_test<9, 1>()));
|
||||||
|
CALL_SUBTEST_4((cast_test<9, 9>()));
|
||||||
|
CALL_SUBTEST_5((cast_test<17, 1>()));
|
||||||
|
CALL_SUBTEST_5((cast_test<17, 17>()));
|
||||||
|
CALL_SUBTEST_6((cast_test<Dynamic, 1>()));
|
||||||
|
CALL_SUBTEST_6((cast_test<Dynamic, Dynamic>()));
|
||||||
|
}
|
||||||
|
|
||||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
|
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
|
||||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
|
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
|
||||||
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<Array2i>::type, ArrayBase<Array2i> >::value));
|
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<Array2i>::type, ArrayBase<Array2i> >::value));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user