Added support for component-wise pow (equivalent to Matlab's operator .^).

This commit is contained in:
Hauke Heibel 2012-03-07 08:58:42 +01:00
parent aee0db2e2c
commit 81c1336ab8
3 changed files with 56 additions and 28 deletions

View File

@ -178,6 +178,18 @@ struct functor_traits<scalar_hypot_op<Scalar> > {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess=0 }; enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess=0 };
}; };
/** \internal
* \brief Template functor to compute the pow of two scalars
*/
template<typename Scalar, typename OtherScalar> struct scalar_binary_pow_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op)
inline Scalar operator() (const Scalar& a, const OtherScalar& b) const { return internal::pow(a, b); }
};
template<typename Scalar, typename OtherScalar>
struct functor_traits<scalar_binary_pow_op<Scalar,OtherScalar> > {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};
// other binary functors: // other binary functors:
/** \internal /** \internal
@ -813,18 +825,18 @@ template<typename Scalar>
struct functor_traits<scalar_pow_op<Scalar> > struct functor_traits<scalar_pow_op<Scalar> >
{ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; }; { enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; };
/** \internal /** \internal
* \brief Template functor to compute the quotient between a scalar and array entries. * \brief Template functor to compute the quotient between a scalar and array entries.
* \sa class CwiseUnaryOp, Cwise::inverse() * \sa class CwiseUnaryOp, Cwise::inverse()
*/ */
template<typename Scalar> template<typename Scalar>
struct scalar_inverse_mult_op { struct scalar_inverse_mult_op {
scalar_inverse_mult_op(const Scalar& other) : m_other(other) {} scalar_inverse_mult_op(const Scalar& other) : m_other(other) {}
inline Scalar operator() (const Scalar& a) const { return m_other / a; } inline Scalar operator() (const Scalar& a) const { return m_other / a; }
template<typename Packet> template<typename Packet>
inline const Packet packetOp(const Packet& a) const inline const Packet packetOp(const Packet& a) const
{ return internal::pdiv(pset1<Packet>(m_other),a); } { return internal::pdiv(pset1<Packet>(m_other),a); }
Scalar m_other; Scalar m_other;
}; };
/** \internal /** \internal

View File

@ -66,24 +66,34 @@ namespace std
template<typename Derived> template<typename Derived>
inline const Eigen::CwiseUnaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar>, const Derived> inline const Eigen::CwiseUnaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar>, const Derived>
pow(const Eigen::ArrayBase<Derived>& x, const typename Derived::Scalar& exponent) { \ pow(const Eigen::ArrayBase<Derived>& x, const typename Derived::Scalar& exponent) {
return x.derived().pow(exponent); \ return x.derived().pow(exponent);
}
template<typename Derived>
inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_binary_pow_op<typename Derived::Scalar, typename Derived::Scalar>, const Derived, const Derived>
pow(const Eigen::ArrayBase<Derived>& x, const Eigen::ArrayBase<Derived>& exponents)
{
return Eigen::CwiseBinaryOp<Eigen::internal::scalar_binary_pow_op<typename Derived::Scalar, typename Derived::Scalar>, const Derived, const Derived>(
x.derived(),
exponents.derived()
);
} }
} }
namespace Eigen namespace Eigen
{ {
/** /**
* \brief Component-wise division of a scalar by array elements. * \brief Component-wise division of a scalar by array elements.
**/ **/
template <typename Derived> template <typename Derived>
inline const Eigen::CwiseUnaryOp<Eigen::internal::scalar_inverse_mult_op<typename Derived::Scalar>, const Derived> inline const Eigen::CwiseUnaryOp<Eigen::internal::scalar_inverse_mult_op<typename Derived::Scalar>, const Derived>
operator/(typename Derived::Scalar s, const Eigen::ArrayBase<Derived>& a) operator/(typename Derived::Scalar s, const Eigen::ArrayBase<Derived>& a)
{ {
return Eigen::CwiseUnaryOp<Eigen::internal::scalar_inverse_mult_op<typename Derived::Scalar>, const Derived>( return Eigen::CwiseUnaryOp<Eigen::internal::scalar_inverse_mult_op<typename Derived::Scalar>, const Derived>(
a.derived(), a.derived(),
Eigen::internal::scalar_inverse_mult_op<typename Derived::Scalar>(s) Eigen::internal::scalar_inverse_mult_op<typename Derived::Scalar>(s)
); );
} }
namespace internal namespace internal

View File

@ -45,9 +45,6 @@ template<typename ArrayType> void array(const ArrayType& m)
Scalar s1 = internal::random<Scalar>(), Scalar s1 = internal::random<Scalar>(),
s2 = internal::random<Scalar>(); s2 = internal::random<Scalar>();
// scalar by array division
VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse());
// scalar addition // scalar addition
VERIFY_IS_APPROX(m1 + s1, s1 + m1); VERIFY_IS_APPROX(m1 + s1, s1 + m1);
VERIFY_IS_APPROX(m1 + s1, ArrayType::Constant(rows,cols,s1) + m1); VERIFY_IS_APPROX(m1 + s1, ArrayType::Constant(rows,cols,s1) + m1);
@ -212,9 +209,18 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m1.pow(2), m1.square()); VERIFY_IS_APPROX(m1.pow(2), m1.square());
VERIFY_IS_APPROX(std::pow(m1,2), m1.square()); VERIFY_IS_APPROX(std::pow(m1,2), m1.square());
ArrayType exponents = ArrayType::Constant(rows, cols, RealScalar(2));
VERIFY_IS_APPROX(std::pow(m1,exponents), m1.square());
m3 = m1.abs(); m3 = m1.abs();
VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt());
VERIFY_IS_APPROX(std::pow(m3,RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(std::pow(m3,RealScalar(0.5)), m3.sqrt());
// scalar by array division
const auto t1 = (s1/m1).eval();
const auto t2 = (s1 * m1.inverse()).eval();
VERIFY_IS_APPROX(s1/m1, s1 * m1.inverse());
} }
template<typename ArrayType> void array_complex(const ArrayType& m) template<typename ArrayType> void array_complex(const ArrayType& m)