mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 02:29:33 +08:00
bugfix in DiagonalProduct: a "DiagonalProduct<SomeXpr>" expression
is now evaluated as a "DiagonalProduct<Matrix<SomeXpr::Eval> >". Note that currently this only happens in DiagonalProduct.
This commit is contained in:
parent
ba9a53f9c6
commit
beabf008b0
@ -62,10 +62,19 @@ class DiagonalMatrix : ei_no_assignment_operator,
|
||||
|
||||
EIGEN_GENERIC_PUBLIC_INTERFACE(DiagonalMatrix)
|
||||
|
||||
// needed to evaluate a DiagonalMatrix<Xpr> to a DiagonalMatrix<NestByValue<Vector> >
|
||||
template<typename OtherCoeffsVectorType>
|
||||
inline DiagonalMatrix(const DiagonalMatrix<OtherCoeffsVectorType>& other) : m_coeffs(other.diagonal())
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(CoeffsVectorType);
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherCoeffsVectorType);
|
||||
ei_assert(m_coeffs.size() > 0);
|
||||
}
|
||||
|
||||
inline DiagonalMatrix(const CoeffsVectorType& coeffs) : m_coeffs(coeffs)
|
||||
{
|
||||
ei_assert(CoeffsVectorType::IsVectorAtCompileTime
|
||||
&& coeffs.size() > 0);
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(CoeffsVectorType);
|
||||
ei_assert(coeffs.size() > 0);
|
||||
}
|
||||
|
||||
inline int rows() const { return m_coeffs.size(); }
|
||||
@ -76,6 +85,8 @@ class DiagonalMatrix : ei_no_assignment_operator,
|
||||
return row == col ? m_coeffs.coeff(row) : static_cast<Scalar>(0);
|
||||
}
|
||||
|
||||
inline const CoeffsVectorType& diagonal() const { return m_coeffs; }
|
||||
|
||||
protected:
|
||||
const typename CoeffsVectorType::Nested m_coeffs;
|
||||
};
|
||||
|
@ -26,12 +26,31 @@
|
||||
#ifndef EIGEN_DIAGONALPRODUCT_H
|
||||
#define EIGEN_DIAGONALPRODUCT_H
|
||||
|
||||
/** \internal Specialization of ei_nested for DiagonalMatrix.
|
||||
* Unlike ei_nested, if the argument is a DiagonalMatrix and if it must be evaluated,
|
||||
* then it evaluated to a DiagonalMatrix having its own argument evaluated.
|
||||
*/
|
||||
template<typename T, int N> struct ei_nested_diagonal : ei_nested<T,N> {};
|
||||
template<typename T, int N> struct ei_nested_diagonal<DiagonalMatrix<T>,N >
|
||||
: ei_nested<DiagonalMatrix<T>, N, DiagonalMatrix<NestByValue<typename ei_eval<T>::type> > >
|
||||
{};
|
||||
|
||||
// specialization of ProductReturnType
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct ProductReturnType<Lhs,Rhs,DiagonalProduct>
|
||||
{
|
||||
typedef typename ei_nested_diagonal<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
|
||||
typedef typename ei_nested_diagonal<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
|
||||
|
||||
typedef Product<LhsNested, RhsNested, DiagonalProduct> Type;
|
||||
};
|
||||
|
||||
template<typename LhsNested, typename RhsNested>
|
||||
struct ei_traits<Product<LhsNested, RhsNested, DiagonalProduct> >
|
||||
{
|
||||
// clean the nested types:
|
||||
typedef typename ei_unconst<typename ei_unref<LhsNested>::type>::type _LhsNested;
|
||||
typedef typename ei_unconst<typename ei_unref<RhsNested>::type>::type _RhsNested;
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
|
||||
typedef typename _LhsNested::Scalar Scalar;
|
||||
|
||||
enum {
|
||||
|
@ -62,6 +62,7 @@ struct ProductReturnType
|
||||
};
|
||||
|
||||
// cache friendly specialization
|
||||
// note that there is a DiagonalProduct specialization in DiagonalProduct.h
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
||||
{
|
||||
@ -77,7 +78,8 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
||||
/* Helper class to determine the type of the product, can be either:
|
||||
* - NormalProduct
|
||||
* - CacheFriendlyProduct
|
||||
* - NormalProduct
|
||||
* - DiagonalProduct
|
||||
* - SparseProduct
|
||||
*/
|
||||
template<typename Lhs, typename Rhs> struct ei_product_mode
|
||||
{
|
||||
|
@ -33,4 +33,13 @@ void test_product_large()
|
||||
CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) );
|
||||
CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
||||
}
|
||||
|
||||
{
|
||||
// test a specific issue in DiagonalProduct
|
||||
int N = 1000000;
|
||||
VectorXf v = VectorXf::Ones(N);
|
||||
MatrixXf m = MatrixXf::Ones(N,3);
|
||||
m = (v+v).asDiagonal() * m;
|
||||
VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2));
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user