Implement expr+scalar, scalar+expr, expr-scalar, and scalar-expr as binary expressions, and generalize supported scalar types.

The following functors are now deprecated: scalar_add_op, scalar_sub_op, and scalar_rsub_op.
This commit is contained in:
Gael Guennebaud 2016-06-14 12:06:10 +02:00
parent 756ac4a93d
commit a8c08e8b8e
5 changed files with 70 additions and 47 deletions

View File

@ -32,7 +32,13 @@ template<typename LhsScalar,typename RhsScalar>
struct scalar_sum_op : binary_op_base<LhsScalar,RhsScalar> struct scalar_sum_op : binary_op_base<LhsScalar,RhsScalar>
{ {
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_sum_op>::ReturnType result_type; typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_sum_op>::ReturnType result_type;
#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
#else
scalar_sum_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a + b; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a + b; }
template<typename Packet> template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
@ -315,7 +321,13 @@ template<typename LhsScalar,typename RhsScalar>
struct scalar_difference_op : binary_op_base<LhsScalar,RhsScalar> struct scalar_difference_op : binary_op_base<LhsScalar,RhsScalar>
{ {
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_difference_op>::ReturnType result_type; typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_difference_op>::ReturnType result_type;
#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN
EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
#else
scalar_difference_op() {
EIGEN_SCALAR_BINARY_OP_PLUGIN
}
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a - b; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return a - b; }
template<typename Packet> template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
@ -584,7 +596,7 @@ struct functor_traits<scalar_add_op<Scalar> >
/** \internal /** \internal
* \brief Template functor to subtract a fixed scalar to another one * \brief Template functor to subtract a fixed scalar to another one
* \sa class CwiseUnaryOp, Array::operator-, struct scalar_add_op, struct scalar_rsub_op * \sa class CwiseUnaryOp, Array::operator-, struct scalar_add_op
*/ */
template<typename Scalar> template<typename Scalar>
struct scalar_sub_op { struct scalar_sub_op {
@ -600,23 +612,6 @@ template<typename Scalar>
struct functor_traits<scalar_sub_op<Scalar> > struct functor_traits<scalar_sub_op<Scalar> >
{ enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasAdd }; }; { enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasAdd }; };
/** \internal
* \brief Template functor to subtract a scalar to fixed another one
* \sa class CwiseUnaryOp, Array::operator-, struct scalar_add_op, struct scalar_sub_op
*/
template<typename Scalar>
struct scalar_rsub_op {
EIGEN_DEVICE_FUNC inline scalar_rsub_op(const scalar_rsub_op& other) : m_other(other.m_other) { }
EIGEN_DEVICE_FUNC inline scalar_rsub_op(const Scalar& other) : m_other(other) { }
EIGEN_DEVICE_FUNC inline Scalar operator() (const Scalar& a) const { return m_other - a; }
template <typename Packet>
EIGEN_DEVICE_FUNC inline const Packet packetOp(const Packet& a) const
{ return internal::psub(pset1<Packet>(m_other), a); }
const Scalar m_other;
};
template<typename Scalar>
struct functor_traits<scalar_rsub_op<Scalar> >
{ enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasAdd }; };
/** \internal /** \internal
* \brief Template functor to raise a scalar to a power * \brief Template functor to raise a scalar to a power

View File

@ -202,7 +202,6 @@ template<typename Scalar> struct scalar_square_op;
template<typename Scalar> struct scalar_cube_op; template<typename Scalar> struct scalar_cube_op;
template<typename Scalar, typename NewType> struct scalar_cast_op; template<typename Scalar, typename NewType> struct scalar_cast_op;
template<typename Scalar> struct scalar_random_op; template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_add_op;
template<typename Scalar> struct scalar_constant_op; template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op; template<typename Scalar> struct scalar_identity_op;
template<typename Scalar,bool iscpx> struct scalar_sign_op; template<typename Scalar,bool iscpx> struct scalar_sign_op;

View File

@ -192,48 +192,49 @@ EIGEN_MAKE_CWISE_COMP_OP(operator!=, NEQ)
#undef EIGEN_MAKE_CWISE_COMP_R_OP #undef EIGEN_MAKE_CWISE_COMP_R_OP
// scalar addition // scalar addition
#ifndef EIGEN_PARSED_BY_DOXYGEN
EIGEN_MAKE_SCALAR_BINARY_OP(operator+,sum);
#else
/** \returns an expression of \c *this with each coeff incremented by the constant \a scalar /** \returns an expression of \c *this with each coeff incremented by the constant \a scalar
*
* \tparam T is the scalar type of \a scalar. It must be compatible with the scalar type of the given expression.
* *
* Example: \include Cwise_plus.cpp * Example: \include Cwise_plus.cpp
* Output: \verbinclude Cwise_plus.out * Output: \verbinclude Cwise_plus.out
* *
* \sa operator+=(), operator-() * \sa operator+=(), operator-()
*/ */
EIGEN_DEVICE_FUNC template<typename T>
inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> const CwiseBinaryOp<internal::scalar_sum_op<Scalar,T>,Derived,Constant<Scalar> > operator+(const T& scalar) const;
operator+(const Scalar& scalar) const /** \returns an expression of \a expr with each coeff incremented by the constant \a scalar
{ *
return CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>(derived(), internal::scalar_add_op<Scalar>(scalar)); * \tparam T is the scalar type of \a scalar. It must be compatible with the scalar type of the given expression.
} */
template<typename T> friend
EIGEN_DEVICE_FUNC const CwiseBinaryOp<internal::scalar_sum_op<T,Scalar>,Constant<Scalar>,Derived> operator+(const T& scalar, const StorageBaseType& expr);
friend inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> #endif
operator+(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& other)
{
return other + scalar;
}
#ifndef EIGEN_PARSED_BY_DOXYGEN
EIGEN_MAKE_SCALAR_BINARY_OP(operator-,difference);
#else
/** \returns an expression of \c *this with each coeff decremented by the constant \a scalar /** \returns an expression of \c *this with each coeff decremented by the constant \a scalar
*
* \tparam T is the scalar type of \a scalar. It must be compatible with the scalar type of the given expression.
* *
* Example: \include Cwise_minus.cpp * Example: \include Cwise_minus.cpp
* Output: \verbinclude Cwise_minus.out * Output: \verbinclude Cwise_minus.out
* *
* \sa operator+(), operator-=() * \sa operator+=(), operator-()
*/ */
EIGEN_DEVICE_FUNC template<typename T>
inline const CwiseUnaryOp<internal::scalar_sub_op<Scalar>, const Derived> const CwiseBinaryOp<internal::scalar_difference_op<Scalar,T>,Derived,Constant<Scalar> > operator-(const T& scalar) const;
operator-(const Scalar& scalar) const /** \returns an expression of the constant matrix of value \a scalar decremented by the coefficients of \a expr
{ *
return CwiseUnaryOp<internal::scalar_sub_op<Scalar>, const Derived>(derived(), internal::scalar_sub_op<Scalar>(scalar));; * \tparam T is the scalar type of \a scalar. It must be compatible with the scalar type of the given expression.
} */
template<typename T> friend
EIGEN_DEVICE_FUNC const CwiseBinaryOp<internal::scalar_difference_op<T,Scalar>,Constant<Scalar>,Derived> operator-(const T& scalar, const StorageBaseType& expr);
friend inline const CwiseUnaryOp<internal::scalar_rsub_op<Scalar>, const Derived> #endif
operator-(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& other)
{
return CwiseUnaryOp<internal::scalar_rsub_op<Scalar>, const Derived>(other.derived(), internal::scalar_rsub_op<Scalar>(scalar));;
}
/** \returns an expression of the coefficient-wise && operator of *this and \a other /** \returns an expression of the coefficient-wise && operator of *this and \a other
* *

View File

@ -93,6 +93,22 @@ template<typename MatrixType> void real_complex(DenseIndex rows = MatrixType::Ro
g_called = false; g_called = false;
VERIFY_IS_APPROX(m1/s, m1/Scalar(s)); VERIFY_IS_APPROX(m1/s, m1/Scalar(s));
VERIFY(g_called && "matrix<complex> / real not properly optimized"); VERIFY(g_called && "matrix<complex> / real not properly optimized");
g_called = false;
VERIFY_IS_APPROX(s+m1.array(), Scalar(s)+m1.array());
VERIFY(g_called && "real + matrix<complex> not properly optimized");
g_called = false;
VERIFY_IS_APPROX(m1.array()+s, m1.array()+Scalar(s));
VERIFY(g_called && "matrix<complex> + real not properly optimized");
g_called = false;
VERIFY_IS_APPROX(s-m1.array(), Scalar(s)-m1.array());
VERIFY(g_called && "real - matrix<complex> not properly optimized");
g_called = false;
VERIFY_IS_APPROX(m1.array()-s, m1.array()-Scalar(s));
VERIFY(g_called && "matrix<complex> - real not properly optimized");
} }
void test_linearstructure() void test_linearstructure()

View File

@ -75,6 +75,18 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType)
VERIFY_IS_APPROX(vcf / sf , vcf / complex<float>(sf)); VERIFY_IS_APPROX(vcf / sf , vcf / complex<float>(sf));
VERIFY_IS_APPROX(vf / scf , vf.template cast<complex<float> >() / scf); VERIFY_IS_APPROX(vf / scf , vf.template cast<complex<float> >() / scf);
// check scalar increment
VERIFY_IS_APPROX(vcf.array() + sf , vcf.array() + complex<float>(sf));
VERIFY_IS_APPROX(sd + vcd.array(), complex<double>(sd) + vcd.array());
VERIFY_IS_APPROX(vf.array() + scf, vf.template cast<complex<float> >().array() + scf);
VERIFY_IS_APPROX(scd + vd.array() , scd + vd.template cast<complex<double> >().array());
// check scalar subtractions
VERIFY_IS_APPROX(vcf.array() - sf , vcf.array() - complex<float>(sf));
VERIFY_IS_APPROX(sd - vcd.array(), complex<double>(sd) - vcd.array());
VERIFY_IS_APPROX(vf.array() - scf, vf.template cast<complex<float> >().array() - scf);
VERIFY_IS_APPROX(scd - vd.array() , scd - vd.template cast<complex<double> >().array());
// check dot product // check dot product
vf.dot(vf); vf.dot(vf);
#if 0 // we get other compilation errors here than just static asserts #if 0 // we get other compilation errors here than just static asserts