Add evaluator support for diagonal products

This commit is contained in:
Gael Guennebaud 2014-02-17 16:10:55 +01:00
parent 94acccc126
commit bffa15142c
3 changed files with 192 additions and 7 deletions

View File

@ -66,6 +66,7 @@ class DiagonalBase : public EigenBase<Derived>
EIGEN_DEVICE_FUNC
inline Index cols() const { return diagonal().size(); }
#ifndef EIGEN_TEST_EVALUATORS
/** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
*/
template<typename MatrixDerived>
@ -75,6 +76,15 @@ class DiagonalBase : public EigenBase<Derived>
{
return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived());
}
#else
template<typename MatrixDerived>
EIGEN_DEVICE_FUNC
const Product<Derived,MatrixDerived,LazyProduct>
operator*(const MatrixBase<MatrixDerived> &matrix) const
{
return Product<Derived, MatrixDerived, LazyProduct>(derived(),matrix.derived());
}
#endif // EIGEN_TEST_EVALUATORS
EIGEN_DEVICE_FUNC
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> >
@ -270,7 +280,8 @@ struct traits<DiagonalWrapper<_DiagonalVectorType> >
ColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
MaxRowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
MaxColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
Flags = traits<DiagonalVectorType>::Flags & LvalueBit
Flags = traits<DiagonalVectorType>::Flags & LvalueBit,
CoeffReadCost = traits<_DiagonalVectorType>::CoeffReadCost
};
};
}
@ -341,6 +352,29 @@ bool MatrixBase<Derived>::isDiagonal(const RealScalar& prec) const
return true;
}
#ifdef EIGEN_ENABLE_EVALUATORS
namespace internal {
// TODO currently a diagonal expression has the form DiagonalMatrix<> or DiagonalWrapper
// in the future diagonal-ness should be defined by the expression traits
template<typename _Scalar, int SizeAtCompileTime, int MaxSizeAtCompileTime>
struct evaluator_traits<DiagonalMatrix<_Scalar,SizeAtCompileTime,MaxSizeAtCompileTime> >
{
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
typedef DiagonalShape Shape;
static const int AssumeAliasing = 0;
};
template<typename Derived>
struct evaluator_traits<DiagonalWrapper<Derived> >
{
typedef typename storage_kind_to_evaluator_kind<typename Derived::StorageKind>::Kind Kind;
typedef DiagonalShape Shape;
static const int AssumeAliasing = 0;
};
} // namespace internal
#endif // EIGEN_ENABLE_EVALUATORS
} // end namespace Eigen
#endif // EIGEN_DIAGONALMATRIX_H

View File

@ -309,9 +309,9 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
};
// This specialization enforces the use of a coefficient-based evaluation strategy
// template<typename Lhs, typename Rhs>
// struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
// : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
template<typename Lhs, typename Rhs>
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
: generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
// Case 2: Evaluate coeff by coeff
//
@ -764,6 +764,146 @@ protected:
PlainObject m_result;
};
/***************************************************************************
* Diagonal products
***************************************************************************/
template<typename MatrixType, typename DiagonalType, typename Derived>
struct diagonal_product_evaluator_base
: evaluator_base<Derived>
{
typedef typename MatrixType::Index Index;
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::PacketScalar PacketScalar;
public:
diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
: m_diagImpl(diag), m_matImpl(mat)
{
}
EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
{
return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
}
protected:
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
{
return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
internal::pset1<PacketScalar>(m_diagImpl.coeff(id)));
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
{
enum {
InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
DiagonalPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned)
};
return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
m_diagImpl.template packet<DiagonalPacketLoadMode>(id));
}
typename evaluator<DiagonalType>::nestedType m_diagImpl;
typename evaluator<MatrixType>::nestedType m_matImpl;
};
// diagonal * dense
template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
: diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
{
typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
using Base::m_diagImpl;
using Base::m_matImpl;
using Base::coeff;
using Base::packet_impl;
typedef typename Base::Scalar Scalar;
typedef typename Base::Index Index;
typedef typename Base::PacketScalar PacketScalar;
typedef Product<Lhs, Rhs, ProductKind> XprType;
typedef typename XprType::PlainObject PlainObject;
product_evaluator(const XprType& xpr)
: Base(xpr.rhs(), xpr.lhs().diagonal())
{
}
EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
{
return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
{
enum {
StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
};
return this->template packet_impl<LoadMode>(row,col, row,
typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
{
enum {
StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
};
return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
}
};
// dense * diagonal
template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
: diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
{
typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
using Base::m_diagImpl;
using Base::m_matImpl;
using Base::coeff;
using Base::packet_impl;
typedef typename Base::Scalar Scalar;
typedef typename Base::Index Index;
typedef typename Base::PacketScalar PacketScalar;
typedef Product<Lhs, Rhs, ProductKind> XprType;
typedef typename XprType::PlainObject PlainObject;
product_evaluator(const XprType& xpr)
: Base(xpr.lhs(), xpr.rhs().diagonal())
{
}
EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
{
return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
{
enum {
StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
};
return this->template packet_impl<LoadMode>(row,col, col,
typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
{
enum {
StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
};
return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
}
};
} // end namespace internal

View File

@ -151,19 +151,19 @@ void test_evaluators()
c = a*a;
copy_using_evaluator(a, prod(a,a));
VERIFY_IS_APPROX(a,c);
// check compound assignment of products
d = c;
add_assign_using_evaluator(c.noalias(), prod(a,b));
d.noalias() += a*b;
VERIFY_IS_APPROX(c, d);
d = c;
subtract_assign_using_evaluator(c.noalias(), prod(a,b));
d.noalias() -= a*b;
VERIFY_IS_APPROX(c, d);
}
{
// test product with all possible sizes
int s = internal::random<int>(1,100);
@ -458,4 +458,15 @@ void test_evaluators()
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
}
{
// test diagonal shapes
VectorXd d = VectorXd::Random(6);
MatrixXd A = MatrixXd::Random(6,6), B(6,6);
A.setRandom();B.setRandom();
VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(d.asDiagonal(),A), MatrixXd(d.asDiagonal()*A));
VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(A,d.asDiagonal()), MatrixXd(A*d.asDiagonal()));
}
}