mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Add evaluator support for diagonal products
This commit is contained in:
parent
94acccc126
commit
bffa15142c
@ -66,6 +66,7 @@ class DiagonalBase : public EigenBase<Derived>
|
|||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline Index cols() const { return diagonal().size(); }
|
inline Index cols() const { return diagonal().size(); }
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
/** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
|
/** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
|
||||||
*/
|
*/
|
||||||
template<typename MatrixDerived>
|
template<typename MatrixDerived>
|
||||||
@ -75,6 +76,15 @@ class DiagonalBase : public EigenBase<Derived>
|
|||||||
{
|
{
|
||||||
return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived());
|
return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived());
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
template<typename MatrixDerived>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
const Product<Derived,MatrixDerived,LazyProduct>
|
||||||
|
operator*(const MatrixBase<MatrixDerived> &matrix) const
|
||||||
|
{
|
||||||
|
return Product<Derived, MatrixDerived, LazyProduct>(derived(),matrix.derived());
|
||||||
|
}
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> >
|
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> >
|
||||||
@ -270,7 +280,8 @@ struct traits<DiagonalWrapper<_DiagonalVectorType> >
|
|||||||
ColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
ColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
||||||
MaxRowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
MaxRowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
||||||
MaxColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
MaxColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
|
||||||
Flags = traits<DiagonalVectorType>::Flags & LvalueBit
|
Flags = traits<DiagonalVectorType>::Flags & LvalueBit,
|
||||||
|
CoeffReadCost = traits<_DiagonalVectorType>::CoeffReadCost
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -341,6 +352,29 @@ bool MatrixBase<Derived>::isDiagonal(const RealScalar& prec) const
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef EIGEN_ENABLE_EVALUATORS
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// TODO currently a diagonal expression has the form DiagonalMatrix<> or DiagonalWrapper
|
||||||
|
// in the future diagonal-ness should be defined by the expression traits
|
||||||
|
template<typename _Scalar, int SizeAtCompileTime, int MaxSizeAtCompileTime>
|
||||||
|
struct evaluator_traits<DiagonalMatrix<_Scalar,SizeAtCompileTime,MaxSizeAtCompileTime> >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef DiagonalShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
template<typename Derived>
|
||||||
|
struct evaluator_traits<DiagonalWrapper<Derived> >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<typename Derived::StorageKind>::Kind Kind;
|
||||||
|
typedef DiagonalShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
#endif // EIGEN_ENABLE_EVALUATORS
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_DIAGONALMATRIX_H
|
#endif // EIGEN_DIAGONALMATRIX_H
|
||||||
|
@ -309,9 +309,9 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
|
|||||||
};
|
};
|
||||||
|
|
||||||
// This specialization enforces the use of a coefficient-based evaluation strategy
|
// This specialization enforces the use of a coefficient-based evaluation strategy
|
||||||
// template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
// struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
|
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
|
||||||
// : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
|
: generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
|
||||||
|
|
||||||
// Case 2: Evaluate coeff by coeff
|
// Case 2: Evaluate coeff by coeff
|
||||||
//
|
//
|
||||||
@ -764,6 +764,146 @@ protected:
|
|||||||
PlainObject m_result;
|
PlainObject m_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/***************************************************************************
|
||||||
|
* Diagonal products
|
||||||
|
***************************************************************************/
|
||||||
|
|
||||||
|
template<typename MatrixType, typename DiagonalType, typename Derived>
|
||||||
|
struct diagonal_product_evaluator_base
|
||||||
|
: evaluator_base<Derived>
|
||||||
|
{
|
||||||
|
typedef typename MatrixType::Index Index;
|
||||||
|
typedef typename MatrixType::Scalar Scalar;
|
||||||
|
typedef typename MatrixType::PacketScalar PacketScalar;
|
||||||
|
public:
|
||||||
|
diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
|
||||||
|
: m_diagImpl(diag), m_matImpl(mat)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
|
||||||
|
{
|
||||||
|
return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
|
||||||
|
{
|
||||||
|
return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
|
||||||
|
internal::pset1<PacketScalar>(m_diagImpl.coeff(id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
|
||||||
|
DiagonalPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned)
|
||||||
|
};
|
||||||
|
return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
|
||||||
|
m_diagImpl.template packet<DiagonalPacketLoadMode>(id));
|
||||||
|
}
|
||||||
|
|
||||||
|
typename evaluator<DiagonalType>::nestedType m_diagImpl;
|
||||||
|
typename evaluator<MatrixType>::nestedType m_matImpl;
|
||||||
|
};
|
||||||
|
|
||||||
|
// diagonal * dense
|
||||||
|
template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
|
||||||
|
{
|
||||||
|
typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
|
||||||
|
using Base::m_diagImpl;
|
||||||
|
using Base::m_matImpl;
|
||||||
|
using Base::coeff;
|
||||||
|
using Base::packet_impl;
|
||||||
|
typedef typename Base::Scalar Scalar;
|
||||||
|
typedef typename Base::Index Index;
|
||||||
|
typedef typename Base::PacketScalar PacketScalar;
|
||||||
|
|
||||||
|
typedef Product<Lhs, Rhs, ProductKind> XprType;
|
||||||
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
|
|
||||||
|
product_evaluator(const XprType& xpr)
|
||||||
|
: Base(xpr.rhs(), xpr.lhs().diagonal())
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
|
||||||
|
{
|
||||||
|
return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
|
||||||
|
};
|
||||||
|
return this->template packet_impl<LoadMode>(row,col, row,
|
||||||
|
typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
|
||||||
|
};
|
||||||
|
return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
// dense * diagonal
|
||||||
|
template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
|
||||||
|
{
|
||||||
|
typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
|
||||||
|
using Base::m_diagImpl;
|
||||||
|
using Base::m_matImpl;
|
||||||
|
using Base::coeff;
|
||||||
|
using Base::packet_impl;
|
||||||
|
typedef typename Base::Scalar Scalar;
|
||||||
|
typedef typename Base::Index Index;
|
||||||
|
typedef typename Base::PacketScalar PacketScalar;
|
||||||
|
|
||||||
|
typedef Product<Lhs, Rhs, ProductKind> XprType;
|
||||||
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
|
|
||||||
|
product_evaluator(const XprType& xpr)
|
||||||
|
: Base(xpr.lhs(), xpr.rhs().diagonal())
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
|
||||||
|
{
|
||||||
|
return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
|
||||||
|
};
|
||||||
|
return this->template packet_impl<LoadMode>(row,col, col,
|
||||||
|
typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
|
||||||
|
};
|
||||||
|
return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
@ -151,19 +151,19 @@ void test_evaluators()
|
|||||||
c = a*a;
|
c = a*a;
|
||||||
copy_using_evaluator(a, prod(a,a));
|
copy_using_evaluator(a, prod(a,a));
|
||||||
VERIFY_IS_APPROX(a,c);
|
VERIFY_IS_APPROX(a,c);
|
||||||
|
|
||||||
// check compound assignment of products
|
// check compound assignment of products
|
||||||
d = c;
|
d = c;
|
||||||
add_assign_using_evaluator(c.noalias(), prod(a,b));
|
add_assign_using_evaluator(c.noalias(), prod(a,b));
|
||||||
d.noalias() += a*b;
|
d.noalias() += a*b;
|
||||||
VERIFY_IS_APPROX(c, d);
|
VERIFY_IS_APPROX(c, d);
|
||||||
|
|
||||||
d = c;
|
d = c;
|
||||||
subtract_assign_using_evaluator(c.noalias(), prod(a,b));
|
subtract_assign_using_evaluator(c.noalias(), prod(a,b));
|
||||||
d.noalias() -= a*b;
|
d.noalias() -= a*b;
|
||||||
VERIFY_IS_APPROX(c, d);
|
VERIFY_IS_APPROX(c, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// test product with all possible sizes
|
// test product with all possible sizes
|
||||||
int s = internal::random<int>(1,100);
|
int s = internal::random<int>(1,100);
|
||||||
@ -458,4 +458,15 @@ void test_evaluators()
|
|||||||
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
|
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test diagonal shapes
|
||||||
|
VectorXd d = VectorXd::Random(6);
|
||||||
|
MatrixXd A = MatrixXd::Random(6,6), B(6,6);
|
||||||
|
A.setRandom();B.setRandom();
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(d.asDiagonal(),A), MatrixXd(d.asDiagonal()*A));
|
||||||
|
VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(A,d.asDiagonal()), MatrixXd(A*d.asDiagonal()));
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user