fix a couple of issues related to recent products

This commit is contained in:
Gael Guennebaud 2009-07-28 18:11:30 +02:00
parent 1ba35248e9
commit 864171df5c
9 changed files with 35 additions and 18 deletions

View File

@ -173,14 +173,16 @@ template<typename Derived> class MapBase
using Base::operator=;
using Base::operator*=;
using Base::operator+=;
using Base::operator-=;
template<typename Lhs,typename Rhs>
Derived& operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProduct>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{ return Base::operator+=(other); }
template<typename Lhs,typename Rhs>
Derived& operator-=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProduct>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{ return Base::operator-=(other); }
// template<typename Lhs,typename Rhs>
// Derived& operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProduct>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
// { return Base::operator+=(other); }
//
// template<typename Lhs,typename Rhs>
// Derived& operator-=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProduct>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
// { return Base::operator-=(other); }
template<typename OtherDerived>
Derived& operator+=(const MatrixBase<OtherDerived>& other)

View File

@ -334,7 +334,10 @@ class Matrix
template<typename OtherDerived,typename OtherEvalType>
EIGEN_STRONG_INLINE Matrix& operator=(const ReturnByValue<OtherDerived,OtherEvalType>& func)
{ return Base::operator=(func); }
{
resize(func.rows(), func.cols());
return Base::operator=(func);
}
using Base::operator +=;
using Base::operator -=;
@ -438,6 +441,7 @@ class Matrix
EIGEN_STRONG_INLINE Matrix(const ReturnByValue<OtherDerived,OtherEvalType>& other)
{
_check_template_params();
resize(other.rows(), other.cols());
other.evalTo(*this);
}
/** Destructor */

View File

@ -409,15 +409,8 @@ template<typename Derived> class MatrixBase
const typename ProductReturnType<Derived,OtherDerived>::Type
operator*(const MatrixBase<OtherDerived> &other) const;
/** replaces \c *this by \c *this * \a other.
*
* \returns a reference to \c *this
*/
template<typename OtherDerived>
Derived& operator*=(const MultiplierBase<OtherDerived>& other)
{
return *this = *this * other.derived();
}
Derived& operator*=(const MultiplierBase<OtherDerived>& other);
template<typename DiagonalDerived>
const DiagonalProduct<Derived, DiagonalDerived, DiagonalOnTheRight>

View File

@ -294,7 +294,7 @@ MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
template<typename Derived>
template<typename OtherDerived>
inline Derived &
MatrixBase<Derived>::operator*=(const MatrixBase<OtherDerived> &other)
MatrixBase<Derived>::operator*=(const MultiplierBase<OtherDerived> &other)
{
return derived() = derived() * other.derived();
}

View File

@ -76,6 +76,8 @@ template<typename Functor, typename _Scalar,int _Rows,int _Cols,int _Options,int
{ EvalType res; evalTo(res); dst += res; }
template<typename Dest> inline void _subTo(Dest& dst) const
{ EvalType res; evalTo(res); dst -= res; }
inline int rows() const { return static_cast<const Functor* const>(this)->rows(); }
inline int cols() const { return static_cast<const Functor* const>(this)->cols(); }
};
template<typename Derived>

View File

@ -228,6 +228,9 @@ struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,0,true>
: m_lhs(lhs), m_rhs(rhs)
{}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_lhs.cols(); }
template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void _subTo(Dest& dst) const
@ -278,6 +281,9 @@ struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>
: m_lhs(lhs), m_rhs(rhs)
{}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_lhs.cols(); }
typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;

View File

@ -330,6 +330,9 @@ struct ei_triangular_product_returntype<Mode,LhsIsTriangular,Lhs,false,Rhs,false
: m_lhs(lhs), m_rhs(rhs)
{}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_lhs.cols(); }
typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;

View File

@ -141,6 +141,9 @@ struct ei_triangular_product_returntype<Mode,true,Lhs,false,Rhs,true>
: m_lhs(lhs), m_rhs(rhs)
{}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_lhs.cols(); }
template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void _subTo(Dest& dst) const

View File

@ -102,9 +102,13 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) += m1.block(r0,r0,r1,r1).template selfadjointView<UpperTriangular>() * (s1*m2.block(c0,r0,c1,r1)) ), 0);
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) = m1.block(r0,r0,r1,r1).template selfadjointView<UpperTriangular>() * m2.block(c0,r0,c1,r1) ), 0);
VERIFY_EVALUATION_COUNT(( m3 = m1.block(r0,r0,r1,r1).template selfadjointView<LowerTriangular>() * m2.block(c0,r0,c1,r1) ), 0);
VERIFY_EVALUATION_COUNT( m3.template selfadjointView<LowerTriangular>().rankUpdate(m2.adjoint()), 0);
m3.resize(1,1);
VERIFY_EVALUATION_COUNT(( m3 = m1.block(r0,r0,r1,r1).template selfadjointView<LowerTriangular>() * m2.block(c0,r0,c1,r1) ), 0);
m3.resize(1,1);
VERIFY_EVALUATION_COUNT(( m3 = m1.block(r0,r0,r1,r1).template triangularView<UnitUpperTriangular>() * m2.block(c0,r0,c1,r1) ), 0);
}
void test_product_notemporary()