mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 00:29:38 +08:00
Fix test basic stuff
- Guard fundamental types that are not available pre C++11 - Separate subsequent angle brackets >> by spaces - Allow casting of Eigen::half and Eigen::bfloat16 to complex types
This commit is contained in:
parent
8889a2c1c6
commit
ee4715ff48
@ -27,6 +27,20 @@ namespace Eigen {
|
|||||||
|
|
||||||
struct bfloat16;
|
struct bfloat16;
|
||||||
|
|
||||||
|
// explicit conversion operators are no available before C++11 so we first cast
|
||||||
|
// bfloat16 to RealScalar rather than to std::complex<RealScalar> directly
|
||||||
|
#if !EIGEN_HAS_CXX11
|
||||||
|
namespace internal {
|
||||||
|
template <typename RealScalar>
|
||||||
|
struct cast_impl<bfloat16, std::complex<RealScalar> > {
|
||||||
|
EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const bfloat16 &x)
|
||||||
|
{
|
||||||
|
return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace internal
|
||||||
|
#endif // EIGEN_HAS_CXX11
|
||||||
|
|
||||||
namespace bfloat16_impl {
|
namespace bfloat16_impl {
|
||||||
|
|
||||||
// Make our own __bfloat16_raw definition.
|
// Make our own __bfloat16_raw definition.
|
||||||
|
@ -36,7 +36,7 @@
|
|||||||
#ifndef EIGEN_HALF_H
|
#ifndef EIGEN_HALF_H
|
||||||
#define EIGEN_HALF_H
|
#define EIGEN_HALF_H
|
||||||
|
|
||||||
#if __cplusplus > 199711L
|
#if EIGEN_HAS_CXX11
|
||||||
#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
|
#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
|
||||||
#else
|
#else
|
||||||
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
|
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
|
||||||
@ -48,6 +48,20 @@ namespace Eigen {
|
|||||||
|
|
||||||
struct half;
|
struct half;
|
||||||
|
|
||||||
|
// explicit conversion operators are no available before C++11 so we first cast
|
||||||
|
// half to RealScalar rather than to std::complex<RealScalar> directly
|
||||||
|
#if !EIGEN_HAS_CXX11
|
||||||
|
namespace internal {
|
||||||
|
template <typename RealScalar>
|
||||||
|
struct cast_impl<half, std::complex<RealScalar> > {
|
||||||
|
EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const half &x)
|
||||||
|
{
|
||||||
|
return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace internal
|
||||||
|
#endif // EIGEN_HAS_CXX11
|
||||||
|
|
||||||
namespace half_impl {
|
namespace half_impl {
|
||||||
|
|
||||||
#if !defined(EIGEN_HAS_GPU_FP16)
|
#if !defined(EIGEN_HAS_GPU_FP16)
|
||||||
@ -737,7 +751,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr)
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#if defined(EIGEN_GPU_COMPILE_PHASE)
|
#if defined(EIGEN_GPU_COMPILE_PHASE)
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
namespace numext {
|
namespace numext {
|
||||||
|
@ -195,41 +195,73 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m)
|
|||||||
VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
|
VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename SrcScalar, typename TgtScalar, bool SrcIsHalfOrBF16 = (internal::is_same<SrcScalar, half>::value || internal::is_same<SrcScalar, bfloat16>::value)> struct casting_test;
|
||||||
|
|
||||||
|
|
||||||
template<typename SrcScalar, typename TgtScalar>
|
template<typename SrcScalar, typename TgtScalar>
|
||||||
void casting_test()
|
struct casting_test<SrcScalar, TgtScalar, false> {
|
||||||
{
|
static void run() {
|
||||||
Matrix<SrcScalar,4,4> m;
|
Matrix<SrcScalar,4,4> m;
|
||||||
for (int i=0; i<m.rows(); ++i) {
|
for (int i=0; i<m.rows(); ++i) {
|
||||||
for (int j=0; j<m.cols(); ++j) {
|
for (int j=0; j<m.cols(); ++j) {
|
||||||
m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
|
m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
|
||||||
|
for (int i=0; i<m.rows(); ++i) {
|
||||||
|
for (int j=0; j<m.cols(); ++j) {
|
||||||
|
VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
|
};
|
||||||
for (int i=0; i<m.rows(); ++i) {
|
|
||||||
for (int j=0; j<m.cols(); ++j) {
|
template<typename SrcScalar, typename TgtScalar>
|
||||||
VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
|
struct casting_test<SrcScalar, TgtScalar, true> {
|
||||||
|
static void run() {
|
||||||
|
casting_test<SrcScalar, TgtScalar, false>::run();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename SrcScalar, typename RealScalar>
|
||||||
|
struct casting_test<SrcScalar, std::complex<RealScalar>, true> {
|
||||||
|
static void run() {
|
||||||
|
typedef std::complex<RealScalar> TgtScalar;
|
||||||
|
Matrix<SrcScalar,4,4> m;
|
||||||
|
for (int i=0; i<m.rows(); ++i) {
|
||||||
|
for (int j=0; j<m.cols(); ++j) {
|
||||||
|
m(i, j) = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
|
||||||
|
for (int i=0; i<m.rows(); ++i) {
|
||||||
|
for (int j=0; j<m.cols(); ++j) {
|
||||||
|
VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(static_cast<RealScalar>(m(i, j))));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
template<typename SrcScalar, typename EnableIf = void>
|
template<typename SrcScalar, typename EnableIf = void>
|
||||||
struct casting_test_runner {
|
struct casting_test_runner {
|
||||||
static void run() {
|
static void run() {
|
||||||
casting_test<SrcScalar, bool>();
|
casting_test<SrcScalar, bool>::run();
|
||||||
casting_test<SrcScalar, int8_t>();
|
casting_test<SrcScalar, int8_t>::run();
|
||||||
casting_test<SrcScalar, uint8_t>();
|
casting_test<SrcScalar, uint8_t>::run();
|
||||||
casting_test<SrcScalar, int16_t>();
|
casting_test<SrcScalar, int16_t>::run();
|
||||||
casting_test<SrcScalar, uint16_t>();
|
casting_test<SrcScalar, uint16_t>::run();
|
||||||
casting_test<SrcScalar, int32_t>();
|
casting_test<SrcScalar, int32_t>::run();
|
||||||
casting_test<SrcScalar, uint32_t>();
|
casting_test<SrcScalar, uint32_t>::run();
|
||||||
casting_test<SrcScalar, int64_t>();
|
#if EIGEN_HAS_CXX11
|
||||||
casting_test<SrcScalar, uint64_t>();
|
casting_test<SrcScalar, int64_t>::run();
|
||||||
casting_test<SrcScalar, half>();
|
casting_test<SrcScalar, uint64_t>::run();
|
||||||
casting_test<SrcScalar, bfloat16>();
|
#endif
|
||||||
casting_test<SrcScalar, float>();
|
casting_test<SrcScalar, half>::run();
|
||||||
casting_test<SrcScalar, double>();
|
casting_test<SrcScalar, bfloat16>::run();
|
||||||
casting_test<SrcScalar, std::complex<float>>();
|
casting_test<SrcScalar, float>::run();
|
||||||
casting_test<SrcScalar, std::complex<double>>();
|
casting_test<SrcScalar, double>::run();
|
||||||
|
casting_test<SrcScalar, std::complex<float> >::run();
|
||||||
|
casting_test<SrcScalar, std::complex<double> >::run();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -238,10 +270,10 @@ struct casting_test_runner<SrcScalar, typename internal::enable_if<(NumTraits<Sr
|
|||||||
{
|
{
|
||||||
static void run() {
|
static void run() {
|
||||||
// Only a few casts from std::complex<T> are defined.
|
// Only a few casts from std::complex<T> are defined.
|
||||||
casting_test<SrcScalar, half>();
|
casting_test<SrcScalar, half>::run();
|
||||||
casting_test<SrcScalar, bfloat16>();
|
casting_test<SrcScalar, bfloat16>::run();
|
||||||
casting_test<SrcScalar, std::complex<float>>();
|
casting_test<SrcScalar, std::complex<float> >::run();
|
||||||
casting_test<SrcScalar, std::complex<double>>();
|
casting_test<SrcScalar, std::complex<double> >::run();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -253,14 +285,16 @@ void casting_all() {
|
|||||||
casting_test_runner<uint16_t>::run();
|
casting_test_runner<uint16_t>::run();
|
||||||
casting_test_runner<int32_t>::run();
|
casting_test_runner<int32_t>::run();
|
||||||
casting_test_runner<uint32_t>::run();
|
casting_test_runner<uint32_t>::run();
|
||||||
|
#if EIGEN_HAS_CXX11
|
||||||
casting_test_runner<int64_t>::run();
|
casting_test_runner<int64_t>::run();
|
||||||
casting_test_runner<uint64_t>::run();
|
casting_test_runner<uint64_t>::run();
|
||||||
|
#endif
|
||||||
casting_test_runner<half>::run();
|
casting_test_runner<half>::run();
|
||||||
casting_test_runner<bfloat16>::run();
|
casting_test_runner<bfloat16>::run();
|
||||||
casting_test_runner<float>::run();
|
casting_test_runner<float>::run();
|
||||||
casting_test_runner<double>::run();
|
casting_test_runner<double>::run();
|
||||||
casting_test_runner<std::complex<float>>::run();
|
casting_test_runner<std::complex<float> >::run();
|
||||||
casting_test_runner<std::complex<double>>::run();
|
casting_test_runner<std::complex<double> >::run();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user