mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Add evaluator support for diagonal products
This commit is contained in:
parent
94acccc126
commit
bffa15142c
@ -66,6 +66,7 @@ class DiagonalBase : public EigenBase<Derived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline Index cols() const { return diagonal().size(); }
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
/** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
|
||||
*/
|
||||
template<typename MatrixDerived>
|
||||
@ -75,6 +76,15 @@ class DiagonalBase : public EigenBase<Derived>
|
||||
{
|
||||
return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived());
|
||||
}
|
||||
#else
|
||||
template<typename MatrixDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
const Product<Derived,MatrixDerived,LazyProduct>
|
||||
operator*(const MatrixBase<MatrixDerived> &matrix) const
|
||||
{
|
||||
return Product<Derived, MatrixDerived, LazyProduct>(derived(),matrix.derived());
|
||||
}
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> >
|
||||
@ -270,7 +280,8 @@ struct traits<DiagonalWrapper<_DiagonalVectorType> >
|
||||
ColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
||||
MaxRowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
||||
MaxColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
||||
Flags = traits<DiagonalVectorType>::Flags & LvalueBit
|
||||
Flags = traits<DiagonalVectorType>::Flags & LvalueBit,
|
||||
CoeffReadCost = traits<_DiagonalVectorType>::CoeffReadCost
|
||||
};
|
||||
};
|
||||
}
|
||||
@ -341,6 +352,29 @@ bool MatrixBase<Derived>::isDiagonal(const RealScalar& prec) const
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef EIGEN_ENABLE_EVALUATORS
|
||||
namespace internal {
|
||||
|
||||
// TODO currently a diagonal expression has the form DiagonalMatrix<> or DiagonalWrapper
|
||||
// in the future diagonal-ness should be defined by the expression traits
|
||||
template<typename _Scalar, int SizeAtCompileTime, int MaxSizeAtCompileTime>
|
||||
struct evaluator_traits<DiagonalMatrix<_Scalar,SizeAtCompileTime,MaxSizeAtCompileTime> >
|
||||
{
|
||||
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||
typedef DiagonalShape Shape;
|
||||
static const int AssumeAliasing = 0;
|
||||
};
|
||||
template<typename Derived>
|
||||
struct evaluator_traits<DiagonalWrapper<Derived> >
|
||||
{
|
||||
typedef typename storage_kind_to_evaluator_kind<typename Derived::StorageKind>::Kind Kind;
|
||||
typedef DiagonalShape Shape;
|
||||
static const int AssumeAliasing = 0;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
#endif // EIGEN_ENABLE_EVALUATORS
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_DIAGONALMATRIX_H
|
||||
|
@ -309,9 +309,9 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
|
||||
};
|
||||
|
||||
// This specialization enforces the use of a coefficient-based evaluation strategy
|
||||
// template<typename Lhs, typename Rhs>
|
||||
// struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
|
||||
// : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
|
||||
: generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
|
||||
|
||||
// Case 2: Evaluate coeff by coeff
|
||||
//
|
||||
@ -764,6 +764,146 @@ protected:
|
||||
PlainObject m_result;
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* Diagonal products
|
||||
***************************************************************************/
|
||||
|
||||
template<typename MatrixType, typename DiagonalType, typename Derived>
|
||||
struct diagonal_product_evaluator_base
|
||||
: evaluator_base<Derived>
|
||||
{
|
||||
typedef typename MatrixType::Index Index;
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::PacketScalar PacketScalar;
|
||||
public:
|
||||
diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
|
||||
: m_diagImpl(diag), m_matImpl(mat)
|
||||
{
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
|
||||
{
|
||||
return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
|
||||
}
|
||||
|
||||
protected:
|
||||
template<int LoadMode>
|
||||
EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
|
||||
{
|
||||
return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
|
||||
internal::pset1<PacketScalar>(m_diagImpl.coeff(id)));
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
|
||||
{
|
||||
enum {
|
||||
InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
|
||||
DiagonalPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned)
|
||||
};
|
||||
return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
|
||||
m_diagImpl.template packet<DiagonalPacketLoadMode>(id));
|
||||
}
|
||||
|
||||
typename evaluator<DiagonalType>::nestedType m_diagImpl;
|
||||
typename evaluator<MatrixType>::nestedType m_matImpl;
|
||||
};
|
||||
|
||||
// diagonal * dense
|
||||
template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
|
||||
struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||
: diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
|
||||
{
|
||||
typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
|
||||
using Base::m_diagImpl;
|
||||
using Base::m_matImpl;
|
||||
using Base::coeff;
|
||||
using Base::packet_impl;
|
||||
typedef typename Base::Scalar Scalar;
|
||||
typedef typename Base::Index Index;
|
||||
typedef typename Base::PacketScalar PacketScalar;
|
||||
|
||||
typedef Product<Lhs, Rhs, ProductKind> XprType;
|
||||
typedef typename XprType::PlainObject PlainObject;
|
||||
|
||||
product_evaluator(const XprType& xpr)
|
||||
: Base(xpr.rhs(), xpr.lhs().diagonal())
|
||||
{
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
|
||||
{
|
||||
return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
|
||||
{
|
||||
enum {
|
||||
StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
|
||||
};
|
||||
return this->template packet_impl<LoadMode>(row,col, row,
|
||||
typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
|
||||
{
|
||||
enum {
|
||||
StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
|
||||
};
|
||||
return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// dense * diagonal
|
||||
template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
|
||||
struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||
: diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
|
||||
{
|
||||
typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
|
||||
using Base::m_diagImpl;
|
||||
using Base::m_matImpl;
|
||||
using Base::coeff;
|
||||
using Base::packet_impl;
|
||||
typedef typename Base::Scalar Scalar;
|
||||
typedef typename Base::Index Index;
|
||||
typedef typename Base::PacketScalar PacketScalar;
|
||||
|
||||
typedef Product<Lhs, Rhs, ProductKind> XprType;
|
||||
typedef typename XprType::PlainObject PlainObject;
|
||||
|
||||
product_evaluator(const XprType& xpr)
|
||||
: Base(xpr.lhs(), xpr.rhs().diagonal())
|
||||
{
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
|
||||
{
|
||||
return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
|
||||
{
|
||||
enum {
|
||||
StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
|
||||
};
|
||||
return this->template packet_impl<LoadMode>(row,col, col,
|
||||
typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
|
||||
{
|
||||
enum {
|
||||
StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
|
||||
};
|
||||
return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
@ -151,19 +151,19 @@ void test_evaluators()
|
||||
c = a*a;
|
||||
copy_using_evaluator(a, prod(a,a));
|
||||
VERIFY_IS_APPROX(a,c);
|
||||
|
||||
|
||||
// check compound assignment of products
|
||||
d = c;
|
||||
add_assign_using_evaluator(c.noalias(), prod(a,b));
|
||||
d.noalias() += a*b;
|
||||
VERIFY_IS_APPROX(c, d);
|
||||
|
||||
|
||||
d = c;
|
||||
subtract_assign_using_evaluator(c.noalias(), prod(a,b));
|
||||
d.noalias() -= a*b;
|
||||
VERIFY_IS_APPROX(c, d);
|
||||
}
|
||||
|
||||
|
||||
{
|
||||
// test product with all possible sizes
|
||||
int s = internal::random<int>(1,100);
|
||||
@ -458,4 +458,15 @@ void test_evaluators()
|
||||
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
|
||||
|
||||
}
|
||||
|
||||
{
|
||||
// test diagonal shapes
|
||||
VectorXd d = VectorXd::Random(6);
|
||||
MatrixXd A = MatrixXd::Random(6,6), B(6,6);
|
||||
A.setRandom();B.setRandom();
|
||||
|
||||
VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(d.asDiagonal(),A), MatrixXd(d.asDiagonal()*A));
|
||||
VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(A,d.asDiagonal()), MatrixXd(A*d.asDiagonal()));
|
||||
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user