From 81c1336ab8135d81b261728c3560fc3bfffbb5c5 Mon Sep 17 00:00:00 2001 From: Hauke Heibel Date: Wed, 7 Mar 2012 08:58:42 +0100 Subject: [PATCH] Added support for component-wise pow (equivalent to Matlab's operator .^). --- Eigen/src/Core/Functors.h | 36 +++++++++++++++++++++----------- Eigen/src/Core/GlobalFunctions.h | 36 ++++++++++++++++++++------------ test/array.cpp | 12 ++++++++--- 3 files changed, 56 insertions(+), 28 deletions(-) diff --git a/Eigen/src/Core/Functors.h b/Eigen/src/Core/Functors.h index 38e1b57dc..cbda77c2d 100644 --- a/Eigen/src/Core/Functors.h +++ b/Eigen/src/Core/Functors.h @@ -178,6 +178,18 @@ struct functor_traits > { enum { Cost = 5 * NumTraits::MulCost, PacketAccess=0 }; }; +/** \internal + * \brief Template functor to compute the pow of two scalars + */ +template struct scalar_binary_pow_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op) + inline Scalar operator() (const Scalar& a, const OtherScalar& b) const { return internal::pow(a, b); } +}; +template +struct functor_traits > { + enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; +}; + // other binary functors: /** \internal @@ -813,18 +825,18 @@ template struct functor_traits > { enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; }; -/** \internal - * \brief Template functor to compute the quotient between a scalar and array entries. - * \sa class CwiseUnaryOp, Cwise::inverse() - */ -template -struct scalar_inverse_mult_op { - scalar_inverse_mult_op(const Scalar& other) : m_other(other) {} - inline Scalar operator() (const Scalar& a) const { return m_other / a; } - template - inline const Packet packetOp(const Packet& a) const - { return internal::pdiv(pset1(m_other),a); } - Scalar m_other; +/** \internal + * \brief Template functor to compute the quotient between a scalar and array entries. + * \sa class CwiseUnaryOp, Cwise::inverse() + */ +template +struct scalar_inverse_mult_op { + scalar_inverse_mult_op(const Scalar& other) : m_other(other) {} + inline Scalar operator() (const Scalar& a) const { return m_other / a; } + template + inline const Packet packetOp(const Packet& a) const + { return internal::pdiv(pset1(m_other),a); } + Scalar m_other; }; /** \internal diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index 472e101a5..94605252d 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -66,24 +66,34 @@ namespace std template inline const Eigen::CwiseUnaryOp, const Derived> - pow(const Eigen::ArrayBase& x, const typename Derived::Scalar& exponent) { \ - return x.derived().pow(exponent); \ + pow(const Eigen::ArrayBase& x, const typename Derived::Scalar& exponent) { + return x.derived().pow(exponent); + } + + template + inline const Eigen::CwiseBinaryOp, const Derived, const Derived> + pow(const Eigen::ArrayBase& x, const Eigen::ArrayBase& exponents) + { + return Eigen::CwiseBinaryOp, const Derived, const Derived>( + x.derived(), + exponents.derived() + ); } } namespace Eigen { - /** - * \brief Component-wise division of a scalar by array elements. - **/ - template - inline const Eigen::CwiseUnaryOp, const Derived> - operator/(typename Derived::Scalar s, const Eigen::ArrayBase& a) - { - return Eigen::CwiseUnaryOp, const Derived>( - a.derived(), - Eigen::internal::scalar_inverse_mult_op(s) - ); + /** + * \brief Component-wise division of a scalar by array elements. + **/ + template + inline const Eigen::CwiseUnaryOp, const Derived> + operator/(typename Derived::Scalar s, const Eigen::ArrayBase& a) + { + return Eigen::CwiseUnaryOp, const Derived>( + a.derived(), + Eigen::internal::scalar_inverse_mult_op(s) + ); } namespace internal diff --git a/test/array.cpp b/test/array.cpp index 481862266..0f38aff65 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -45,9 +45,6 @@ template void array(const ArrayType& m) Scalar s1 = internal::random(), s2 = internal::random(); - // scalar by array division - VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse()); - // scalar addition VERIFY_IS_APPROX(m1 + s1, s1 + m1); VERIFY_IS_APPROX(m1 + s1, ArrayType::Constant(rows,cols,s1) + m1); @@ -212,9 +209,18 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m1.pow(2), m1.square()); VERIFY_IS_APPROX(std::pow(m1,2), m1.square()); + + ArrayType exponents = ArrayType::Constant(rows, cols, RealScalar(2)); + VERIFY_IS_APPROX(std::pow(m1,exponents), m1.square()); + m3 = m1.abs(); VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(std::pow(m3,RealScalar(0.5)), m3.sqrt()); + + // scalar by array division + const auto t1 = (s1/m1).eval(); + const auto t2 = (s1 * m1.inverse()).eval(); + VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse()); } template void array_complex(const ArrayType& m)