From 76faf4a9657efeed089aeedc98a769410c32d3d7 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 23 Jun 2016 14:27:20 +0200 Subject: [PATCH] Introduce a NumTraits::Literal type to be used for literals, and improve mixing type support in operations between arrays and scalars: - 2 * ArrayXcf is now optimized in the sense that the integer 2 is properly promoted to a float instead of a complex (fix a regression) - 2.1 * ArrayXi is now forbiden (previously, 2.1 was converted to 2) - This mechanism should be applicable to any custom scalar type, assuming NumTraits::Literal is properly defined (it defaults to T) --- Eigen/src/Core/NumTraits.h | 11 ++++++++--- Eigen/src/Core/util/Macros.h | 30 ++++++++---------------------- Eigen/src/Core/util/XprHelper.h | 28 ++++++++++++++++++++++++++++ test/mixingtypes.cpp | 5 +++++ test/nesting_ops.cpp | 4 ++-- 5 files changed, 51 insertions(+), 27 deletions(-) diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h index e065fa714..03f64a8e9 100644 --- a/Eigen/src/Core/NumTraits.h +++ b/Eigen/src/Core/NumTraits.h @@ -22,14 +22,16 @@ namespace Eigen { * This class stores enums, typedefs and static methods giving information about a numeric type. * * The provided data consists of: - * \li A typedef \a Real, giving the "real part" type of \a T. If \a T is already real, - * then \a Real is just a typedef to \a T. If \a T is \c std::complex then \a Real + * \li A typedef \c Real, giving the "real part" type of \a T. If \a T is already real, + * then \c Real is just a typedef to \a T. If \a T is \c std::complex then \c Real * is a typedef to \a U. - * \li A typedef \a NonInteger, giving the type that should be used for operations producing non-integral values, + * \li A typedef \c NonInteger, giving the type that should be used for operations producing non-integral values, * such as quotients, square roots, etc. If \a T is a floating-point type, then this typedef just gives * \a T again. Note however that many Eigen functions such as internal::sqrt simply refuse to * take integers. Outside of a few cases, Eigen doesn't do automatic type promotion. Thus, this typedef is * only intended as a helper for code that needs to explicitly promote types. + * \li A typedef \c Literal giving the type to use for numeric literals such as "2" or "0.5". For instance, for \c std::complex, Literal is defined as \c U. + * Of course, this type must be fully compatible with \a T. In doubt, just use \a T here. * \li A typedef \a Nested giving the type to use to nest a value inside of the expression tree. If you don't know what * this means, just use \a T here. * \li An enum value \a IsComplex. It is equal to 1 if \a T is a \c std::complex @@ -84,6 +86,7 @@ template struct GenericNumTraits T >::type NonInteger; typedef T Nested; + typedef T Literal; EIGEN_DEVICE_FUNC static inline Real epsilon() @@ -145,6 +148,7 @@ template struct NumTraits > : GenericNumTraits > { typedef _Real Real; + typedef typename NumTraits<_Real>::Literal Literal; enum { IsComplex = 1, RequireInitialization = NumTraits<_Real>::RequireInitialization, @@ -168,6 +172,7 @@ struct NumTraits > typedef typename NumTraits::NonInteger NonIntegerScalar; typedef Array NonInteger; typedef ArrayType & Nested; + typedef typename NumTraits::Literal Literal; enum { IsComplex = NumTraits::IsComplex, diff --git a/Eigen/src/Core/util/Macros.h b/Eigen/src/Core/util/Macros.h index 87cc44657..6de21d2bb 100644 --- a/Eigen/src/Core/util/Macros.h +++ b/Eigen/src/Core/util/Macros.h @@ -906,35 +906,21 @@ namespace Eigen { const typename internal::plain_constant_type::type, const EXPR> #define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) \ - EIGEN_DEVICE_FUNC inline \ - const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME) \ - (METHOD)(const Scalar& scalar) const { \ - return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME)(derived(), \ - typename internal::plain_constant_type::type(derived().rows(), derived().cols(), scalar)); \ - } \ - \ template EIGEN_DEVICE_FUNC inline \ - typename internal::enable_if >::Defined, \ - const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME) >::type \ + const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg >::Defined>::type,OPNAME) \ (METHOD)(const T& scalar) const { \ - return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME)(derived(), \ - typename internal::plain_constant_type::type(derived().rows(), derived().cols(), scalar)); \ + typedef typename internal::promote_scalar_arg >::Defined>::type PromotedT; \ + return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedT,OPNAME)(derived(), \ + typename internal::plain_constant_type::type(derived().rows(), derived().cols(), internal::scalar_constant_op(scalar))); \ } #define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD,OPNAME) \ - EIGEN_DEVICE_FUNC inline friend \ - const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME) \ - (METHOD)(const Scalar& scalar, const StorageBaseType& matrix) { \ - return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME)( \ - typename internal::plain_constant_type::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \ - } \ - \ template EIGEN_DEVICE_FUNC inline friend \ - typename internal::enable_if >::Defined, \ - const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME) >::type \ + const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename internal::promote_scalar_arg >::Defined>::type,Derived,OPNAME) \ (METHOD)(const T& scalar, const StorageBaseType& matrix) { \ - return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME)( \ - typename internal::plain_constant_type::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \ + typedef typename internal::promote_scalar_arg >::Defined>::type PromotedT; \ + return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(PromotedT,Derived,OPNAME)( \ + typename internal::plain_constant_type::type(matrix.derived().rows(), matrix.derived().cols(), internal::scalar_constant_op(scalar)), matrix.derived()); \ } #define EIGEN_MAKE_SCALAR_BINARY_OP(METHOD,OPNAME) \ diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index c41c408b0..b372ac1ad 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -45,6 +45,34 @@ inline IndexDest convert_index(const IndexSrc& idx) { } +// promote_scalar_arg is an helper used in operation between an expression and a scalar, like: +// expression * scalar +// Its role is to determine how the type T of the scalar operand should be promoted given the scalar type ExprScalar of the given expression. +// The IsSupported template parameter must be provided by the caller as: ScalarBinaryOpTraits::Defined using the proper order for ExprScalar and T. +// Then the logic is as follows: +// - if the operation is natively supported as defined by IsSupported, then the scalar type is not promoted, and T is returned. +// - otherwise, NumTraits::Literal is returned if T is implicitly convertible to NumTraits::Literal AND that this does not imply a float to integer conversion. +// - In all other cases, the promoted type is not defined, and the respective operation is thus invalid and not available (SFINAE). +template::Literal>::value, + bool IsSafe = NumTraits::IsInteger || !NumTraits::Literal>::IsInteger> +struct promote_scalar_arg +{ +}; + +template +struct promote_scalar_arg +{ + typedef T type; +}; + +template +struct promote_scalar_arg +{ + typedef typename NumTraits::Literal type; +}; + //classes inheriting no_assignment_operator don't generate a default operator=. class no_assignment_operator { diff --git a/test/mixingtypes.cpp b/test/mixingtypes.cpp index fe8c16470..57ef85c32 100644 --- a/test/mixingtypes.cpp +++ b/test/mixingtypes.cpp @@ -79,6 +79,11 @@ template void mixingtypes(int size = SizeAtCompileType) VERIFY_MIX_SCALAR(vf * scf , vf.template cast >() * scf); VERIFY_MIX_SCALAR(scd * vd , scd * vd.template cast >()); + VERIFY_MIX_SCALAR(vcf * 2 , vcf * complex(2)); + VERIFY_MIX_SCALAR(vcf * 2.1 , vcf * complex(2.1)); + VERIFY_MIX_SCALAR(2 * vcf, vcf * complex(2)); + VERIFY_MIX_SCALAR(2.1 * vcf , vcf * complex(2.1)); + // check scalar quotients VERIFY_MIX_SCALAR(vcf / sf , vcf / complex(sf)); VERIFY_MIX_SCALAR(vf / scf , vf.template cast >() / scf); diff --git a/test/nesting_ops.cpp b/test/nesting_ops.cpp index 2f5025305..a419b0e44 100644 --- a/test/nesting_ops.cpp +++ b/test/nesting_ops.cpp @@ -75,8 +75,8 @@ template void run_nesting_ops_2(const MatrixType& _m) } else { - VERIFY( verify_eval_type<1>(2*m1, 2*m1) ); - VERIFY( verify_eval_type<2>(2*m1, m1) ); + VERIFY( verify_eval_type<2>(2*m1, 2*m1) ); + VERIFY( verify_eval_type<3>(2*m1, m1) ); } VERIFY( verify_eval_type<2>(m1+m1, m1+m1) ); VERIFY( verify_eval_type<3>(m1+m1, m1) );